xref: /plugin/aichat/Storage/SQLiteStorage.php (revision ab1f8dde36106432cc0a6f320220da5fae6971fe)
1<?php
2
3/** @noinspection SqlResolve */
4
5namespace dokuwiki\plugin\aichat\Storage;
6
7use dokuwiki\plugin\aichat\AIChat;
8use dokuwiki\plugin\aichat\Chunk;
9use dokuwiki\plugin\sqlite\SQLiteDB;
10use KMeans\Cluster;
11use KMeans\Space;
12
/**
 * Implements the storage backend using a SQLite database
 *
 * Chunks are kept in the `embeddings` table, cluster centroids in the `clusters` table.
 * Similarity searches only look at the chunks of the cluster nearest to the query vector.
 *
 * Note: all embeddings are stored and returned as normalized vectors
 */
class SQLiteStorage extends AbstractStorage
{
    /** @var int Number of documents to randomly sample to create the clusters */
    final public const SAMPLE_SIZE = 2000;
    /** @var int The average size of each cluster */
    final public const CLUSTER_SIZE = 400;

    /** @var SQLiteDB */
    protected $db;

    /** @var bool whether clusters are created and searched separately for each language */
    protected $useLanguageClusters = false;

    /** @var float minimum similarity to consider a chunk a match */
    protected $similarityThreshold = 0;

    /** @inheritdoc */
    public function __construct(array $config)
    {
        $this->db = new SQLiteDB('aichat', DOKU_PLUGIN . 'aichat/db/');
        // make the PHP implementation available to SQL as COSIM(a, b)
        $this->db->getPdo()->sqliteCreateFunction('COSIM', $this->sqliteCosineSimilarityCallback(...), 2);

        $helper = plugin_load('helper', 'aichat');
        $this->useLanguageClusters = $helper->getConf('preferUIlanguage') >= AIChat::LANG_UI_LIMITED;

        // the threshold is configured as a percentage; a missing setting means "no threshold"
        $this->similarityThreshold = ($config['similarityThreshold'] ?? 0) / 100;
    }

    /** @inheritdoc */
    public function getChunk($chunkID)
    {
        $record = $this->db->queryRecord('SELECT * FROM embeddings WHERE id = ?', [$chunkID]);
        if (!$record) return null;

        return $this->recordToChunk($record);
    }

    /** @inheritdoc */
    public function startCreation($clear = false)
    {
        if ($clear) {
            /** @noinspection SqlWithoutWhere */
            $this->db->exec('DELETE FROM embeddings');
            /** @noinspection SqlWithoutWhere */
            $this->db->exec('DELETE FROM clusters');
        }
    }

    /** @inheritdoc */
    public function reusePageChunks($page, $firstChunkID)
    {
        // no-op
    }

    /** @inheritdoc */
    public function deletePageChunks($page, $firstChunkID)
    {
        $this->db->exec('DELETE FROM embeddings WHERE page = ?', [$page]);
    }

    /** @inheritdoc */
    public function addPageChunks($chunks)
    {
        foreach ($chunks as $chunk) {
            $this->db->saveRecord('embeddings', [
                'page' => $chunk->getPage(),
                'id' => $chunk->getId(),
                'chunk' => $chunk->getText(),
                'embedding' => json_encode($chunk->getEmbedding(), JSON_THROW_ON_ERROR),
                'created' => $chunk->getCreated(),
                'lang' => $chunk->getLanguage(),
            ]);
        }
    }

    /** @inheritdoc */
    public function finalizeCreation()
    {
        // clusters are only built when none exist yet; use runMaintenance() to rebuild them
        if (!$this->hasClusters()) {
            $this->createClusters();
        }
        $this->setChunkClusters();

        $this->db->exec('VACUUM');
    }

    /** @inheritdoc */
    public function runMaintenance()
    {
        $this->createClusters();
        $this->setChunkClusters();
    }

    /** @inheritdoc */
    public function getPageChunks($page, $firstChunkID)
    {
        $result = $this->db->queryAll(
            'SELECT * FROM embeddings WHERE page = ?',
            [$page]
        );

        $chunks = [];
        foreach ($result as $record) {
            $chunks[] = $this->recordToChunk($record);
        }
        return $chunks;
    }

    /** @inheritdoc */
    public function getSimilarChunks($vector, $lang = '', $limit = 4)
    {
        // only the chunks of the nearest cluster are compared against the vector
        $cluster = $this->getCluster($vector, $lang);
        if ($this->logger) $this->logger->info(
            'Using cluster {cluster} for similarity search',
            ['cluster' => $cluster]
        );

        // NOTE(review): GETACCESSLEVEL is not registered in this class; presumably the sqlite plugin provides it
        $result = $this->db->queryAll(
            'SELECT *, COSIM(?, embedding) AS similarity
               FROM embeddings
              WHERE cluster = ?
                AND GETACCESSLEVEL(page) > 0
                AND similarity > CAST(? AS FLOAT)
           ORDER BY similarity DESC
              LIMIT ?',
            [json_encode($vector, JSON_THROW_ON_ERROR), $cluster, $this->similarityThreshold, $limit]
        );

        $chunks = [];
        foreach ($result as $record) {
            $chunks[] = $this->recordToChunk($record);
        }
        return $chunks;
    }

    /** @inheritdoc */
    public function statistics()
    {
        $items = $this->db->queryValue('SELECT COUNT(*) FROM embeddings');
        $size = $this->db->queryValue(
            'SELECT page_count * page_size as size FROM pragma_page_count(), pragma_page_size()'
        );
        $query = "SELECT cluster || ' ' || lang, COUNT(*) || ' chunks' as cnt
                    FROM embeddings
                GROUP BY cluster
                ORDER BY cluster";
        $clusters = $this->db->queryKeyValueList($query);

        return [
            'storage type' => 'SQLite',
            'chunks' => $items,
            'db size' => filesize_h($size),
            'clusters' => $clusters,
        ];
    }

    /**
     * Method registered as SQLite callback to calculate the cosine similarity
     *
     * @param string $query JSON encoded vector array
     * @param string $embedding JSON encoded vector array
     * @return float
     * @throws \JsonException when one of the arguments is not valid JSON
     */
    public function sqliteCosineSimilarityCallback($query, $embedding)
    {
        return (float)$this->cosineSimilarity(
            json_decode($query, true, 512, JSON_THROW_ON_ERROR),
            json_decode($embedding, true, 512, JSON_THROW_ON_ERROR)
        );
    }

    /**
     * Calculate the cosine similarity between two vectors
     *
     * Actually just calculating the dot product of the two vectors, since they are normalized.
     * Components missing from $embedding are treated as zero.
     *
     * @param float[] $queryVector The normalized vector of the search phrase
     * @param float[] $embedding The normalized vector of the chunk
     * @return float
     */
    protected function cosineSimilarity($queryVector, $embedding)
    {
        $dotProduct = 0.0;
        foreach ($queryVector as $key => $value) {
            $dotProduct += $value * ($embedding[$key] ?? 0);
        }
        return $dotProduct;
    }

    /**
     * Create new clusters based on random chunks
     *
     * With language clusters enabled, a separate set of clusters is created for each language
     * found in the embeddings table. Otherwise one set is created for all chunks.
     *
     * @return void
     */
    protected function createClusters()
    {
        if ($this->useLanguageClusters) {
            $result = $this->db->queryAll('SELECT DISTINCT lang FROM embeddings');
            $langs = array_column($result, 'lang');
            foreach ($langs as $lang) {
                $this->createLanguageClusters($lang);
            }
        } else {
            $this->createLanguageClusters('');
        }
    }

    /**
     * Create new clusters based on random chunks for the given Language
     *
     * Old clusters and cluster assignments are removed first. Everything runs in a single
     * transaction which is rolled back when clustering fails.
     *
     * @param string $lang The language to cluster, empty when all languages go into the same cluster
     * @return void
     * @throws \RuntimeException when clustering fails
     * @noinspection SqlWithoutWhere
     */
    protected function createLanguageClusters($lang)
    {
        // the optional language restriction is shared by all queries below
        if ($lang != '') {
            $where = 'WHERE lang = ?';
            $params = [$lang];
        } else {
            $where = '';
            $params = [];
        }

        if ($this->logger) $this->logger->info('Creating new {lang} clusters...', ['lang' => $lang]);
        $pdo = $this->db->getPdo();
        $pdo->beginTransaction();
        try {
            // clean up old cluster data
            $this->db->exec("DELETE FROM clusters $where", $params);
            $this->db->exec("UPDATE embeddings SET cluster = NULL $where", $params);

            // get a random selection of chunks
            $query = "SELECT id, embedding FROM embeddings $where ORDER BY RANDOM() LIMIT ?";
            $result = $this->db->queryAll($query, [...$params, self::SAMPLE_SIZE]);
            if (!$result) {
                // no data to cluster, but the clean up above still has to be persisted
                $pdo->commit();
                return;
            }

            $points = array_map(
                static fn($record) => json_decode((string) $record['embedding'], true, 512, JSON_THROW_ON_ERROR),
                $result
            );
            $dimensions = count($points[0]);

            // how many clusters?
            if (count($points) < self::CLUSTER_SIZE * 3) {
                // there would be less than 3 clusters, so just use one
                $clustercount = 1;
            } else {
                // get the number of all chunks, to calculate the number of clusters
                $total = $this->db->queryValue("SELECT COUNT(*) FROM embeddings $where", $params);
                $clustercount = (int) ceil($total / self::CLUSTER_SIZE);
            }
            if ($this->logger) $this->logger->info('Creating {clusters} clusters', ['clusters' => $clustercount]);

            // cluster them using kmeans
            $space = new Space($dimensions);
            foreach ($points as $point) {
                $space->addPoint($point);
            }
            $iterations = 0;
            $clusters = $space->solve($clustercount, function ($space, $clusters) use (&$iterations) {
                ++$iterations;
                if ($this->logger) {
                    $clustercounts = implode(',', array_map(count(...), $clusters));
                    $this->logger->info('Iteration {iteration}: [{clusters}]', [
                        'iteration' => $iterations, 'clusters' => $clustercounts
                    ]);
                }
            }, Cluster::INIT_KMEANS_PLUS_PLUS);

            // store the clusters
            foreach ($clusters as $cluster) {
                /** @var Cluster $cluster */
                $centroid = $cluster->getCoordinates();
                $query = 'INSERT INTO clusters (lang, centroid) VALUES (?, ?)';
                $this->db->exec($query, [$lang, json_encode($centroid, JSON_THROW_ON_ERROR)]);
            }

            $pdo->commit();
            if ($this->logger) $this->logger->success('Created {clusters} clusters', ['clusters' => count($clusters)]);
        } catch (\Exception $e) {
            throw new \RuntimeException('Clustering failed: ' . $e->getMessage(), 0, $e);
        } finally {
            // still in a transaction here means we did not commit: undo all changes
            if ($pdo->inTransaction()) {
                $pdo->rollBack();
            }
        }
    }

    /**
     * Assign the nearest cluster for all chunks that don't have one
     *
     * @return void
     */
    protected function setChunkClusters()
    {
        if ($this->logger) $this->logger->info('Assigning clusters to chunks...');
        $handle = $this->db->query('SELECT id, embedding, lang FROM embeddings WHERE cluster IS NULL');
        $update = 'UPDATE embeddings SET cluster = ? WHERE id = ?';

        while ($record = $handle->fetch(\PDO::FETCH_ASSOC)) {
            $vector = json_decode((string) $record['embedding'], true, 512, JSON_THROW_ON_ERROR);
            // without language clusters, the nearest of all clusters is used
            $cluster = $this->getCluster($vector, $this->useLanguageClusters ? $record['lang'] : '');
            $this->db->exec($update, [$cluster, $record['id']]);
            if ($this->logger) $this->logger->success(
                'Chunk {id} assigned to cluster {cluster}',
                ['id' => $record['id'], 'cluster' => $cluster]
            );
        }
        $handle->closeCursor();
    }

    /**
     * Get the nearest cluster for the given vector
     *
     * @param float[] $vector
     * @param string $lang Only consider clusters of this language, empty to consider all clusters
     * @return int|null The cluster ID, null when no cluster was found
     */
    protected function getCluster($vector, $lang)
    {
        $params = [];
        if ($lang != '') {
            $where = 'WHERE lang = ?';
            $params[] = $lang;
        } else {
            $where = '';
        }
        // the placeholder in the ORDER BY clause comes after the one in the WHERE clause
        $params[] = json_encode($vector, JSON_THROW_ON_ERROR);

        $query = "SELECT cluster, centroid
                    FROM clusters
                   $where
                ORDER BY COSIM(centroid, ?) DESC
                   LIMIT 1";

        $result = $this->db->queryRecord($query, $params);
        if (!$result) return null;
        return $result['cluster'];
    }

    /**
     * Check if clustering has been done before
     * @return bool
     */
    protected function hasClusters()
    {
        $query = 'SELECT COUNT(*) FROM clusters';
        return $this->db->queryValue($query) > 0;
    }

    /**
     * Writes TSV files for visualizing with http://projector.tensorflow.org/
     *
     * The vector file gets one line of tab separated vector components per chunk. The meta file
     * gets a header line followed by one line with id, page and creation time per chunk.
     * Existing files are overwritten.
     *
     * @param string $vectorfile path to the file with the vectors
     * @param string $metafile path to the file with the metadata
     * @return void
     * @throws \RuntimeException when one of the files can not be opened for writing
     */
    public function dumpTSV($vectorfile, $metafile)
    {
        $handle = $this->db->query('SELECT * FROM embeddings');

        $vectorStream = false;
        $metaStream = false;
        try {
            // open each file once instead of once per row
            $vectorStream = fopen($vectorfile, 'w');
            $metaStream = fopen($metafile, 'w');
            if ($vectorStream === false || $metaStream === false) {
                throw new \RuntimeException('Could not open TSV files for writing');
            }

            $header = implode("\t", ['id', 'page', 'created']);
            fwrite($metaStream, $header . "\n");

            while ($row = $handle->fetch(\PDO::FETCH_ASSOC)) {
                $vector = json_decode((string) $row['embedding'], true, 512, JSON_THROW_ON_ERROR);
                $vector = implode("\t", $vector);

                $meta = implode("\t", [$row['id'], $row['page'], $row['created']]);

                fwrite($vectorStream, $vector . "\n");
                fwrite($metaStream, $meta . "\n");
            }
        } finally {
            $handle->closeCursor();
            if ($vectorStream !== false) fclose($vectorStream);
            if ($metaStream !== false) fclose($metaStream);
        }
    }

    /**
     * Create a Chunk object from a row of the embeddings table
     *
     * @param array $record The database row, optionally with an additional similarity column
     * @return Chunk
     * @throws \JsonException when the stored embedding is not valid JSON
     */
    private function recordToChunk(array $record): Chunk
    {
        $embedding = json_decode((string) $record['embedding'], true, 512, JSON_THROW_ON_ERROR);

        // the similarity is only available for results of a similarity search
        if (array_key_exists('similarity', $record)) {
            return new Chunk(
                $record['page'],
                $record['id'],
                $record['chunk'],
                $embedding,
                $record['lang'],
                $record['created'],
                $record['similarity']
            );
        }

        return new Chunk(
            $record['page'],
            $record['id'],
            $record['chunk'],
            $embedding,
            $record['lang'],
            $record['created']
        );
    }
}
403