xref: /plugin/aichat/Embeddings.php (revision 303d0c59f1b12d14fa2d9b91127b31f070c16911)
1<?php
2
3namespace dokuwiki\plugin\aichat;
4
5use dokuwiki\Extension\Event;
6use dokuwiki\Extension\PluginInterface;
7use dokuwiki\plugin\aichat\Model\ChatInterface;
8use dokuwiki\plugin\aichat\Model\EmbeddingInterface;
9use dokuwiki\plugin\aichat\Storage\AbstractStorage;
10use dokuwiki\Search\Indexer;
11use splitbrain\phpcli\CLI;
12use TikToken\Encoder;
13use Vanderlee\Sentence\Sentence;
14
15/**
16 * Manage the embeddings index
17 *
 * Pages are split into chunks whose size in tokens is derived from the plugin configuration and the
 * input limits of the configured models. For each chunk the embedding vector is fetched from the
 * embedding model and stored in the Storage backend.
20 */
class Embeddings
{
    /** @var int maximum overlap between chunks in tokens */
    final public const MAX_OVERLAP_LEN = 200;

    /** @var ChatInterface chat model, its maximum input length limits chunk size and context size */
    protected $chatModel;

    /** @var EmbeddingInterface model used to fetch the embedding vectors */
    protected $embedModel;

    /** @var CLI|null optional logger, only available after setLogger() was called */
    protected $logger;

    /** @var Encoder|null lazily created in getTokenEncoder() */
    protected $tokenEncoder;

    /** @var AbstractStorage backend the chunks are stored in and retrieved from */
    protected $storage;

    /** @var string[] remember sentences when chunking, used as overlap at the start of the next chunk */
    private $sentenceQueue = [];

    /** @var int|float the time in seconds spent for the last similar chunk retrieval */
    public $timeSpent = 0;

    /** @var int|string the configured chunk size in tokens (config option 'chunkSize') */
    protected $configChunkSize;

    /** @var int|string the configured maximum number of chunks to fetch (config option 'contextChunks') */
    protected $configContextChunks;

    /** @var float minimum score a chunk needs to be used (config option 'similarityThreshold' / 100) */
    protected $similarityThreshold;

    /**
     * Embeddings constructor.
     *
     * @param ChatInterface $chatModel
     * @param EmbeddingInterface $embedModel
     * @param AbstractStorage $storage
     * @param array $config The plugin configuration, needs the keys 'chunkSize', 'contextChunks'
     *                      and 'similarityThreshold' (the latter is given in percent)
     */
    public function __construct(
        ChatInterface $chatModel,
        EmbeddingInterface $embedModel,
        AbstractStorage $storage,
        $config
    ) {
        $this->chatModel = $chatModel;
        $this->embedModel = $embedModel;
        $this->storage = $storage;
        $this->configChunkSize = $config['chunkSize'];
        $this->configContextChunks = $config['contextChunks'];
        $this->similarityThreshold = $config['similarityThreshold'] / 100;
    }

    /**
     * Access storage
     *
     * @return AbstractStorage
     */
    public function getStorage()
    {
        return $this->storage;
    }

    /**
     * Add a logger instance
     *
     * Without a logger all progress and error messages are silently dropped.
     *
     * @param CLI $logger
     * @return void
     */
    public function setLogger(CLI $logger)
    {
        $this->logger = $logger;
    }

    /**
     * Get the token encoder instance
     *
     * The encoder is created on first use and reused afterwards.
     *
     * @return Encoder
     */
    public function getTokenEncoder()
    {
        if (!$this->tokenEncoder instanceof Encoder) {
            $this->tokenEncoder = new Encoder();
        }
        return $this->tokenEncoder;
    }

    /**
     * Return the chunk size to use
     *
     * This is the smallest of the configured chunk size and the limits imposed by the models.
     *
     * @return int
     */
    public function getChunkSize()
    {
        // min() returns one of its arguments unchanged, which is a float (from floor) or the raw config
        // value. Cast to make sure callers always get the documented integer.
        return (int)min(
            floor($this->chatModel->getMaxInputTokenLength() / 4), // be able to fit 4 chunks into the max input
            floor($this->embedModel->getMaxInputTokenLength() * 0.9), // only use 90% of the embedding model to be safe
            $this->configChunkSize, // this is usually the smallest
        );
    }

    /**
     * Update the embeddings storage
     *
     * Walks over all pages known to the search index. Pages that no longer qualify are removed from the
     * storage, chunks of unchanged pages are reused and chunks of changed pages are recreated.
     *
     * @param string $skipRE Regular expression to filter out pages (full RE with delimiters)
     * @param string $matchRE Regular expression pages have to match to be included (full RE with delimiters)
     * @param bool $clear Should any existing storage be cleared before updating?
     * @return void
     * @throws \Exception
     */
    public function createNewIndex($skipRE = '', $matchRE = '', $clear = false)
    {
        $indexer = new Indexer();
        $pages = $indexer->getPages();

        $this->storage->startCreation($clear);
        foreach ($pages as $pid => $page) {
            // chunk IDs start at page ID * 100
            // NOTE(review): chunk IDs are incremented by one per chunk, so a page producing more than
            // 100 chunks would reach into the ID range of the next page ID - confirm this is acceptable
            $chunkID = $pid * 100;

            if (
                !page_exists($page) ||
                isHiddenPage($page) ||
                filesize(wikiFN($page)) < 150 || // skip very small pages
                ($skipRE && preg_match($skipRE, (string)$page)) ||
                ($matchRE && !preg_match($matchRE, ":$page")) // matched against the ID with a leading colon
            ) {
                // this page should not be in the index (anymore)
                $this->storage->deletePageChunks($page, $chunkID);
                continue;
            }

            $firstChunk = $this->storage->getChunk($chunkID);
            if ($firstChunk && @filemtime(wikiFN($page)) < $firstChunk->getCreated()) {
                // page is older than the chunks we have, reuse the existing chunks
                $this->storage->reusePageChunks($page, $chunkID);
                if ($this->logger instanceof CLI) $this->logger->info("Reusing chunks for $page");
            } else {
                // page is newer than the chunks we have, create new chunks
                $this->storage->deletePageChunks($page, $chunkID);
                $chunks = $this->createPageChunks($page, $chunkID);
                if ($chunks) $this->storage->addPageChunks($chunks);
            }
        }
        $this->storage->finalizeCreation();
    }

    /**
     * Split the given page, fetch embedding vectors and return Chunks
     *
     * Will use the text renderer plugin if available to get the rendered text.
     * Otherwise the raw wiki text is used.
     *
     * Parts for which no embedding could be fetched are logged and skipped. The global $ID is
     * the same after this call as it was before.
     *
     * @param string $page Name of the page to split
     * @param int $firstChunkID The ID of the first chunk of this page
     * @return Chunk[] A list of chunks created for this page
     * @emits INDEXER_PAGE_ADD support plugins that add additional data to the page
     * @throws \Exception
     */
    public function createPageChunks($page, $firstChunkID)
    {
        global $ID;

        // gathering the text may point the global $ID to the given page, do not leak that to the caller
        $originalID = $ID;
        try {
            $text = $this->getPageText($page);
        } finally {
            $ID = $originalID;
        }

        $chunkList = [];
        $parts = $this->splitIntoChunks($text);
        foreach ($parts as $part) {
            if (trim((string)$part) === '') continue; // skip empty chunks

            try {
                $embedding = $this->embedModel->getEmbedding($part);
            } catch (\Exception $e) {
                if ($this->logger instanceof CLI) {
                    $this->logger->error(
                        'Failed to get embedding for chunk of page {page}: {msg}',
                        ['page' => $page, 'msg' => $e->getMessage()]
                    );
                }
                continue; // the failed part does not consume a chunk ID
            }
            $chunkList[] = new Chunk($page, $firstChunkID, $part, $embedding);
            $firstChunkID++;
        }
        if ($this->logger instanceof CLI) {
            if ($chunkList !== []) {
                $this->logger->success(
                    '{id} split into {count} chunks',
                    ['id' => $page, 'count' => count($chunkList)]
                );
            } else {
                $this->logger->warning('{id} could not be split into chunks', ['id' => $page]);
            }
        }
        return $chunkList;
    }

    /**
     * Get the text of a page that should be split into chunks
     *
     * Uses the output of the text renderer plugin when it is available and falls back to the raw wiki
     * text otherwise. Plugins may prepend or replace the text via the INDEXER_PAGE_ADD event.
     *
     * This sets the global $ID to the given page when the text renderer is used, callers are
     * responsible for restoring it.
     *
     * @param string $page Name of the page
     * @return string
     * @emits INDEXER_PAGE_ADD support plugins that add additional data to the page
     */
    private function getPageText($page)
    {
        $textRenderer = plugin_load('renderer', 'text');
        if ($textRenderer instanceof PluginInterface) {
            global $ID;
            $ID = $page;
            try {
                $text = p_cached_output(wikiFN($page), 'text', $page);
            } catch (\Throwable $e) {
                if ($this->logger instanceof CLI) {
                    $this->logger->error(
                        'Failed to render page {page} using raw text instead. {msg}',
                        ['page' => $page, 'msg' => $e->getMessage()]
                    );
                }
                $text = rawWiki($page);
            }
        } else {
            $text = rawWiki($page);
        }

        // allow plugins to modify the text before splitting
        $eventData = [
            'page' => $page,
            'body' => '',
            'metadata' => ['title' => $page, 'relation_references' => []],
        ];
        $event = new Event('INDEXER_PAGE_ADD', $eventData);
        if ($event->advise_before()) {
            // default action: whatever the plugins added is put in front of the page text
            $text = $eventData['body'] . ' ' . $text;
        } else {
            // default was prevented: only the text provided by the plugins is used
            $text = $eventData['body'];
        }

        return $text;
    }

    /**
     * Do a nearest neighbor search for chunks similar to the given question
     *
     * Returns only chunks the current user is allowed to read, may return an empty result.
     * The number of returned chunks is limited by the configured number of context chunks, by the
     * similarity threshold and by the maximum input token length of the chat model.
     *
     * @param string $query The question
     * @param string $lang Limit results to this language
     * @return Chunk[]
     * @throws \Exception
     */
    public function getSimilarChunks($query, $lang = '')
    {
        global $auth;
        $vector = $this->embedModel->getEmbedding($query);

        $maxInputLength = $this->chatModel->getMaxInputTokenLength();

        // the storage is asked for a number of rows, so this has to be a whole number
        $fetch = (int)min(
            floor($maxInputLength / $this->getChunkSize()),
            $this->configContextChunks
        );

        $time = microtime(true);
        $chunks = $this->storage->getSimilarChunks($vector, $lang, $fetch);
        $this->timeSpent = round(microtime(true) - $time, 2);
        if ($this->logger instanceof CLI) {
            $this->logger->info(
                'Fetched {count} similar chunks from store in {time} seconds',
                ['count' => count($chunks), 'time' => $this->timeSpent]
            );
        }

        $size = 0;
        $result = [];
        foreach ($chunks as $chunk) {
            // filter out chunks the user is not allowed to read
            if ($auth && auth_quickaclcheck($chunk->getPage()) < AUTH_READ) continue;
            if ($chunk->getScore() < $this->similarityThreshold) continue;

            $chunkSize = count($this->getTokenEncoder()->encode($chunk->getText()));
            if ($size + $chunkSize > $maxInputLength) break; // we have enough

            $result[] = $chunk;
            $size += $chunkSize;
        }
        return $result;
    }


    /**
     * Split the given text into chunks of at most getChunkSize() tokens
     *
     * The text is split into sentences which are then combined into chunks. Each chunk but the first
     * starts with the last sentences that came before it (up to MAX_OVERLAP_LEN tokens) to keep some
     * context. Sentences longer than the chunk size are skipped with a warning.
     *
     * @param string $text
     * @return string[] the trimmed, non-empty chunks
     * @throws \Exception
     * @todo support splitting too long sentences
     */
    protected function splitIntoChunks($text)
    {
        $sentenceSplitter = new Sentence();
        $tiktok = $this->getTokenEncoder();
        $maxLen = $this->getChunkSize();

        // never carry over sentences from a text that has been split before
        $this->sentenceQueue = [];

        $chunks = [];
        $chunklen = 0;
        $chunk = '';

        // a plain foreach is used on purpose: a sentence like "0" is falsy and must not end the loop
        foreach ($sentenceSplitter->split($text) as $sentence) {
            if ((string)$sentence === '') continue;

            $slen = count($tiktok->encode($sentence));
            if ($slen > $maxLen) {
                // sentence is too long, we need to split it further
                if ($this->logger instanceof CLI) {
                    $this->logger->warning(
                        'Sentence too long, splitting not implemented yet'
                    );
                }
                continue;
            }

            if ($chunklen + $slen < $maxLen) {
                // add to current chunk
                $chunk .= $sentence;
                $chunklen += $slen;
            } else {
                // add current chunk to result
                $chunk = trim($chunk);
                if ($chunk !== '') $chunks[] = $chunk;

                // start new chunk with remembered sentences, the oldest ones are dropped again
                // when the overlap would push the new chunk over the size limit
                $overlap = $this->sentenceQueue;
                do {
                    $chunk = implode(' ', $overlap) . $sentence;
                    $chunklen = count($tiktok->encode($chunk));
                } while ($chunklen > $maxLen && array_shift($overlap) !== null);
            }

            // remember every used sentence, otherwise the overlap of a later chunk would have a hole
            $this->rememberSentence($sentence);
        }

        // add the remainder, trimmed like all the other chunks
        $chunk = trim($chunk);
        if ($chunk !== '') $chunks[] = $chunk;

        return $chunks;
    }

    /**
     * Add a sentence to the queue of remembered sentences
     *
     * The queue is shortened from the front until the remembered sentences joined by spaces are no
     * longer than MAX_OVERLAP_LEN tokens. A single sentence longer than that empties the queue.
     *
     * @param string $sentence
     * @return void
     */
    protected function rememberSentence($sentence)
    {
        // add sentence to queue
        $this->sentenceQueue[] = $sentence;

        // remove oldest sentences from queue until we are below the max overlap
        $encoder = $this->getTokenEncoder();
        while (count($encoder->encode(implode(' ', $this->sentenceQueue))) > self::MAX_OVERLAP_LEN) {
            array_shift($this->sentenceQueue);
        }
    }
}
355