1<?php
2
3namespace TikToken;
4
/**
 * Byte-level byte-pair-encoding (BPE) tokenizer.
 *
 * The vocabulary is loaded lazily, on the first non-empty call to encode(), from
 * three files in ../data relative to this file: characters.json, encoder.json and
 * vocab.bpe.
 *
 * NOTE(review): the file names and the splitting pattern look like the GPT-2
 * tokenizer layout - confirm against the shipped data files.
 */
class Encoder
{
    /** True once the data files have been loaded and the rank table has been built. */
    private bool $initialized = false;

    /**
     * Memo of bpe() results: mapped word => space-separated merged symbols.
     * Never evicted, so it grows with the number of distinct words encoded.
     *
     * @var array<string, string>
     */
    private array $bpeCache = [];

    /**
     * Byte value (0-255) => replacement character, as decoded from characters.json.
     *
     * @var array<int, string>
     */
    private array $rawCharacters = [];

    /**
     * Merged symbol => token id, exactly as decoded from encoder.json
     * (presumably integer ids - confirm against the data file).
     *
     * @var array<string, mixed>
     */
    private array $encoder = [];

    /**
     * Merge priority table: $bpeRanks[$left][$right] = rank, lower rank merges first.
     *
     * @var array<string, array<string, int>>
     */
    private array $bpeRanks = [];

    /**
     * Load the three data files and build the merge-rank table. Idempotent.
     *
     * @throws \RuntimeException when a data file cannot be read or vocab.bpe cannot be split
     * @throws \JsonException    when characters.json or encoder.json is not valid JSON
     */
    private function initialize(): void
    {
        if ($this->initialized) {
            return;
        }

        $this->rawCharacters = json_decode($this->readDataFile('characters.json'), true, 512, JSON_THROW_ON_ERROR);
        $this->encoder = json_decode($this->readDataFile('encoder.json'), true, 512, JSON_THROW_ON_ERROR);

        $lines = preg_split('#\r\n|\r|\n#', $this->readDataFile('vocab.bpe'));
        if (false === $lines) {
            throw new \RuntimeException('Unable to split vocab.bpe');
        }

        $bpeMerges = [];
        // The first line of vocab.bpe is skipped (it is not treated as a merge rule).
        foreach (array_slice($lines, 1) as $line) {
            // PREG_SPLIT_NO_EMPTY yields a zero-based list, so a line with leading
            // whitespace still exposes its symbols at indexes 0 and 1.
            $symbols = preg_split('#\s+#', (string) $line, -1, PREG_SPLIT_NO_EMPTY);
            if (false === $symbols || [] === $symbols) {
                continue;
            }
            $bpeMerges[] = $symbols;
        }

        $this->bpeRanks = $this->buildBpeRanks($bpeMerges);
        $this->initialized = true;
    }

    /**
     * Read one file from the data directory next to this source file.
     *
     * @throws \RuntimeException when the file cannot be read
     */
    private function readDataFile(string $fileName): string
    {
        $contents = file_get_contents(__DIR__.'/../data/'.$fileName);
        if (false === $contents) {
            throw new \RuntimeException('Unable to load '.$fileName);
        }

        return $contents;
    }

    /**
     * Tokenize $text.
     *
     * The text is split into words by a regular expression; each word's bytes are
     * mapped through the character table, merged by bpe() and looked up in the
     * encoder table. A merged symbol that is missing from the encoder table is
     * returned as the raw symbol string instead of an id.
     *
     * Returns an empty list for the empty string, and also when the pattern matches
     * nothing (which includes preg_match_all() failing, e.g. on invalid UTF-8).
     *
     * @return array<int, mixed> token ids in input order (raw strings for unknown symbols)
     *
     * @throws \RuntimeException|\JsonException when the vocabulary cannot be loaded
     */
    public function encode(string $text): array
    {
        // Compare against '' explicitly: empty() would also reject the valid input "0".
        if ('' === $text) {
            return [];
        }

        $this->initialize();

        preg_match_all("#'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+#u", $text, $matches);
        if ([] === ($matches[0] ?? [])) {
            return [];
        }

        $bpeTokens = [];
        foreach ($matches[0] as $token) {
            $token = (string) $token;

            // Map every raw byte of the word through the character table; bytes
            // without an entry are dropped.
            $resultWord = '';
            $byteCount = strlen($token);
            for ($offset = 0; $offset < $byteCount; ++$offset) {
                $resultWord .= $this->rawCharacters[ord($token[$offset])] ?? '';
            }

            // Nothing mappable in this word: do not emit an empty-string token.
            if ('' === $resultWord) {
                continue;
            }

            foreach (explode(' ', $this->bpe($resultWord)) as $symbol) {
                $bpeTokens[] = $this->encoder[$symbol] ?? $symbol;
            }
        }

        return $bpeTokens;
    }

    /**
     * Turn the ordered merge list into a lookup table of ranks.
     *
     * Entries that do not provide both a left (index 0) and right (index 1) symbol
     * are ignored and do not consume a rank. Extra symbols beyond the first two are
     * ignored.
     *
     * @param array<array<mixed>> $bpes merge rules in priority order
     *
     * @return array<array<int>> $result[$left][$right] = rank (0 = highest priority)
     */
    private function buildBpeRanks(array $bpes): array
    {
        $result = [];
        $rank = 0;
        foreach ($bpes as $bpe) {
            if (!isset($bpe[0], $bpe[1])) {
                continue;
            }

            $result[$bpe[0]][$bpe[1]] = $rank++;
        }

        return $result;
    }

    /**
     * Return the adjacent symbol pairs of a word, in order.
     * Word is represented as a list of symbols (symbols being variable-length strings).
     *
     * @param array<int, string> $word zero-based list of symbols
     *
     * @return array<int, array{0: string, 1: string}>
     */
    private function buildSymbolPairs(array $word): array
    {
        $pairs = [];
        $symbolCount = count($word);
        for ($i = 1; $i < $symbolCount; ++$i) {
            $pairs[] = [$word[$i - 1], $word[$i]];
        }

        return $pairs;
    }

    /**
     * Pick the pair with the lowest merge rank.
     *
     * @param array<int, array{0: string, 1: string}> $pairs
     *
     * @return array{0: string, 1: string}|null null when no pair has a rank
     */
    private function findLowestRankedPair(array $pairs): ?array
    {
        $bestPair = null;
        $bestRank = null;
        foreach ($pairs as $pair) {
            $rank = $this->bpeRanks[$pair[0]][$pair[1]] ?? null;
            if (null === $rank) {
                continue;
            }
            if (null === $bestRank || $rank < $bestRank) {
                $bestRank = $rank;
                $bestPair = $pair;
            }
        }

        return $bestPair;
    }

    /**
     * Replace every non-overlapping, left-to-right occurrence of the adjacent
     * symbols ($first, $second) in $word with the single symbol $first.$second.
     *
     * @param array<int, string> $word zero-based list of symbols
     *
     * @return array<int, string> new zero-based list of symbols
     */
    private function mergePair(array $word, string $first, string $second): array
    {
        $merged = [];
        $lastIndex = count($word) - 1;
        $i = 0;
        while ($i <= $lastIndex) {
            if ($i < $lastIndex && $word[$i] === $first && $word[$i + 1] === $second) {
                $merged[] = $first.$second;
                $i += 2;
                continue;
            }

            $merged[] = $word[$i];
            ++$i;
        }

        return $merged;
    }

    /**
     * Apply the merge rules to one mapped word.
     *
     * The word is split into UTF-8 characters, then the lowest-ranked adjacent pair
     * is merged repeatedly until the word is a single symbol or no remaining pair
     * has a rank. Results for words of two or more characters are memoized.
     *
     * @return string the resulting symbols joined by single spaces
     */
    private function bpe(string $token): string
    {
        if (isset($this->bpeCache[$token])) {
            return $this->bpeCache[$token];
        }

        $word = mb_str_split($token, 1, 'UTF-8');
        if (count($word) < 2) {
            // No pairs to merge; returned as-is and not cached.
            return $token;
        }

        while (count($word) > 1) {
            $bigram = $this->findLowestRankedPair($this->buildSymbolPairs($word));
            if (null === $bigram) {
                break;
            }

            // The chosen pair is adjacent in $word, so every pass shortens the word
            // by at least one symbol and the loop terminates.
            $word = $this->mergePair($word, $bigram[0], $bigram[1]);
        }

        $result = implode(' ', $word);
        $this->bpeCache[$token] = $result;

        return $result;
    }
}
288