<?php

declare(strict_types=1);

namespace TikToken;

/**
 * Byte-level BPE encoder.
 *
 * The lookup tables (characters.json, encoder.json, vocab.bpe) are read lazily from
 * the ../data directory, relative to this file, on the first non-empty encode() call.
 */
class Encoder
{
    /**
     * Pre-tokenisation pattern. It matches contractions, optionally space-prefixed runs
     * of letters / numbers / other non-space characters, and whitespace runs.
     * The /u modifier makes preg_match_all() fail on invalid UTF-8 input.
     */
    private const PRE_TOKENIZE_PATTERN = "#'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+#u";

    private bool $initialized = false;

    /** @var array<array-key, string> memoised bpe() results, keyed by the byte-level word */
    private array $bpeCache = [];

    /** @var array<array-key, string> printable stand-in characters, indexed by byte value (from characters.json) */
    private array $rawCharacters = [];

    /** @var array<array-key, int> BPE symbol => emitted value, from encoder.json (assumed integer ids; not validated) */
    private array $encoder = [];

    /** @var array<array-key, array<array-key, int>> first symbol => second symbol => merge rank (lower merges first) */
    private array $bpeRanks = [];

    /**
     * Encode text into BPE token values.
     *
     * Each pre-token is mapped byte-by-byte to its printable stand-in. The result is
     * merged with bpe(), and each resulting symbol is looked up in encoder.json.
     *
     * @return list<int|string> one entry per BPE symbol, in order. A symbol missing from
     *                          encoder.json is returned as the symbol string itself.
     *                          Input that the pre-tokeniser cannot match (e.g. invalid
     *                          UTF-8) yields an empty list.
     *
     * @throws \RuntimeException when a data file cannot be loaded
     * @throws \JsonException    when a JSON data file is malformed
     */
    public function encode(string $text): array
    {
        // Strict comparison: empty() would wrongly treat the valid input "0" as empty.
        if ('' === $text) {
            return [];
        }

        $this->initialize();

        $matchCount = preg_match_all(self::PRE_TOKENIZE_PATTERN, $text, $matches);
        if (false === $matchCount || 0 === $matchCount) {
            return [];
        }

        $tokens = [];
        foreach ($matches[0] as $preToken) {
            foreach (explode(' ', $this->bpe($this->toByteLevel($preToken))) as $symbol) {
                $tokens[] = $this->encoder[$symbol] ?? $symbol;
            }
        }

        return $tokens;
    }

    /**
     * Load characters.json, encoder.json and vocab.bpe once per instance.
     *
     * @throws \RuntimeException when a data file cannot be read, decoded or split
     * @throws \JsonException    when a JSON data file is malformed
     */
    private function initialize(): void
    {
        if ($this->initialized) {
            return;
        }

        $this->rawCharacters = $this->loadJsonDataFile('characters.json');
        $this->encoder = $this->loadJsonDataFile('encoder.json');

        $lines = preg_split('#\r\n|\r|\n#', $this->loadDataFile('vocab.bpe'));
        if (false === $lines) {
            throw new \RuntimeException('Unable to split vocab.bpe');
        }

        $bpeMerges = [];
        // The first line of vocab.bpe is skipped (presumably a version header, not a merge rule).
        foreach (array_slice($lines, 1) as $rawDictionaryLine) {
            $splitLine = preg_split('#(\s+)#', $rawDictionaryLine);
            if (false === $splitLine) {
                continue;
            }
            $splitLine = array_filter($splitLine, $this->filterEmpty(...));
            if ([] !== $splitLine) {
                $bpeMerges[] = $splitLine;
            }
        }

        $this->bpeRanks = $this->buildBpeRanks($bpeMerges);
        $this->initialized = true;
    }

    /**
     * Read a file from the package data directory.
     *
     * @throws \RuntimeException when the file is missing or unreadable
     */
    private function loadDataFile(string $fileName): string
    {
        $path = __DIR__.'/../data/'.$fileName;
        // Check readability first so a missing file does not also emit a PHP warning.
        $contents = is_readable($path) ? file_get_contents($path) : false;
        if (false === $contents) {
            throw new \RuntimeException(sprintf('Unable to load %s', $fileName));
        }

        return $contents;
    }

    /**
     * Read and decode a JSON file from the data directory into an associative array.
     *
     * @return array<array-key, mixed>
     *
     * @throws \RuntimeException when the file cannot be read or does not decode to an array
     * @throws \JsonException    when the file is not valid JSON
     */
    private function loadJsonDataFile(string $fileName): array
    {
        $decoded = json_decode($this->loadDataFile($fileName), true, 512, JSON_THROW_ON_ERROR);
        if (!is_array($decoded)) {
            throw new \RuntimeException(sprintf('Unable to decode %s', $fileName));
        }

        return $decoded;
    }

    /**
     * Replace every byte of $preToken with its stand-in character from characters.json.
     *
     * Iterating raw bytes is what the previous Latin-1 -> UTF-8 round trip plus manual
     * UTF-8 decoding achieved. Bytes without a mapping are dropped.
     */
    private function toByteLevel(string $preToken): string
    {
        $word = '';
        $length = strlen($preToken);
        for ($i = 0; $i < $length; ++$i) {
            $word .= $this->rawCharacters[ord($preToken[$i])] ?? '';
        }

        return $word;
    }

    private function filterEmpty(mixed $var): bool
    {
        return null !== $var && false !== $var && '' !== $var;
    }

    /**
     * Turn the parsed merge lines into a nested rank lookup.
     *
     * The rank is the 0-based position among lines that have both a first (index 0) and
     * a second (index 1) symbol. Other lines are skipped without consuming a rank.
     *
     * @param array<array<mixed>> $bpes
     *
     * @return array<array-key, array<array-key, int>>
     */
    private function buildBpeRanks(array $bpes): array
    {
        $result = [];
        $rank = 0;
        foreach ($bpes as $bpe) {
            if (!isset($bpe[1], $bpe[0])) {
                continue;
            }

            $result[$bpe[0]][$bpe[1]] = $rank;
            ++$rank;
        }

        return $result;
    }

    /**
     * Apply BPE merges to a byte-level word.
     *
     * Each round picks the adjacent symbol pair with the lowest rank and merges all of
     * its non-overlapping occurrences, scanning left to right. Rounds stop when no
     * adjacent pair has a rank or only one symbol is left.
     *
     * @return string the resulting symbols joined by single spaces. Words with fewer than
     *                two characters are returned unchanged and are not cached.
     */
    private function bpe(string $token): string
    {
        if (isset($this->bpeCache[$token])) {
            return $this->bpeCache[$token];
        }

        $word = mb_str_split($token, 1, 'UTF-8');
        if (count($word) < 2) {
            return $token;
        }

        while (count($word) > 1) {
            $bestRank = null;
            $first = '';
            $second = '';
            $lastIndex = count($word) - 1;
            for ($i = 0; $i < $lastIndex; ++$i) {
                $rank = $this->bpeRanks[$word[$i]][$word[$i + 1]] ?? null;
                if (null !== $rank && (null === $bestRank || $rank < $bestRank)) {
                    $bestRank = $rank;
                    $first = $word[$i];
                    $second = $word[$i + 1];
                }
            }

            if (null === $bestRank) {
                break;
            }

            // Strict comparisons: numeric-looking symbols such as "0" and "00" must not match.
            $merged = [];
            $symbolCount = count($word);
            for ($i = 0; $i < $symbolCount; ++$i) {
                if ($i < $symbolCount - 1 && $word[$i] === $first && $word[$i + 1] === $second) {
                    $merged[] = $first.$second;
                    ++$i;
                } else {
                    $merged[] = $word[$i];
                }
            }
            $word = $merged;
        }

        $result = implode(' ', $word);
        $this->bpeCache[$token] = $result;

        return $result;
    }
}