xref: /plugin/aichat/Storage/SQLiteStorage.php (revision 720bb43f9ac252f6e0b09e7b06804dec7c547a47)
1<?php
2
3/** @noinspection SqlResolve */
4
5namespace dokuwiki\plugin\aichat\Storage;
6
7use dokuwiki\plugin\aichat\AIChat;
8use dokuwiki\plugin\aichat\Chunk;
9use dokuwiki\plugin\sqlite\SQLiteDB;
10use KMeans\Cluster;
11use KMeans\Space;
12
/**
 * Implements the storage backend using a SQLite database
 *
 * Embeddings are stored as JSON encoded arrays in the `embeddings` table. To keep similarity
 * searches fast, chunks are grouped into k-means clusters (centroids live in the `clusters` table)
 * and a search only compares the query against the chunks of the nearest cluster.
 *
 * Note: all embeddings are stored and returned as normalized vectors
 */
class SQLiteStorage extends AbstractStorage
{
    /** @var int Number of documents to randomly sample to create the clusters */
    final public const SAMPLE_SIZE = 2000;
    /** @var int The average size of each cluster */
    final public const CLUSTER_SIZE = 400;

    /** @var SQLiteDB */
    protected $db;

    /** @var bool when true, clusters are created and looked up separately per chunk language */
    protected $useLanguageClusters = false;

    /** @var float minimum similarity to consider a chunk a match */
    protected $similarityThreshold = 0;

    /** @inheritdoc */
    public function __construct(array $config)
    {
        $this->db = new SQLiteDB('aichat', DOKU_PLUGIN . 'aichat/db/');
        // make the PHP similarity calculation available as COSIM() inside SQL queries
        $this->db->getPdo()->sqliteCreateFunction('COSIM', $this->sqliteCosineSimilarityCallback(...), 2);

        $helper = plugin_load('helper', 'aichat');
        $this->useLanguageClusters = $helper->getConf('preferUIlanguage') >= AIChat::LANG_UI_LIMITED;

        // configured as a percentage, compared against the raw similarity value
        $this->similarityThreshold = ($config['similarityThreshold'] ?? 0) / 100;
    }

    /** @inheritdoc */
    public function getChunk($chunkID)
    {
        $record = $this->db->queryRecord('SELECT * FROM embeddings WHERE id = ?', [$chunkID]);
        if (!$record) return null;

        return $this->recordToChunk($record);
    }

    /** @inheritdoc */
    public function startCreation($clear = false)
    {
        if ($clear) {
            /** @noinspection SqlWithoutWhere */
            $this->db->exec('DELETE FROM embeddings');
        }
    }

    /** @inheritdoc */
    public function reusePageChunks($page, $firstChunkID)
    {
        // no-op: chunks of unchanged pages simply stay in the embeddings table
    }

    /** @inheritdoc */
    public function deletePageChunks($page, $firstChunkID)
    {
        $this->db->exec('DELETE FROM embeddings WHERE page = ?', [$page]);
    }

    /** @inheritdoc */
    public function addPageChunks($chunks)
    {
        foreach ($chunks as $chunk) {
            $this->db->saveRecord('embeddings', [
                'page' => $chunk->getPage(),
                'id' => $chunk->getId(),
                'chunk' => $chunk->getText(),
                'embedding' => json_encode($chunk->getEmbedding(), JSON_THROW_ON_ERROR),
                'created' => $chunk->getCreated(),
                'lang' => $chunk->getLanguage(),
            ]);
        }
    }

    /** @inheritdoc */
    public function finalizeCreation()
    {
        if (!$this->hasClusters()) {
            $this->createClusters();
        }
        $this->setChunkClusters();

        // note: VACUUM fails if a transaction is still open on this connection
        $this->db->exec('VACUUM');
    }

    /** @inheritdoc */
    public function runMaintenance()
    {
        $this->createClusters();
        $this->setChunkClusters();
    }

    /** @inheritdoc */
    public function getPageChunks($page, $firstChunkID)
    {
        $result = $this->db->queryAll(
            'SELECT * FROM embeddings WHERE page = ?',
            [$page]
        );
        $chunks = [];
        foreach ($result as $record) {
            $chunks[] = $this->recordToChunk($record);
        }
        return $chunks;
    }

    /** @inheritdoc */
    public function getSimilarChunks($vector, $lang = '', $limit = 4)
    {
        // Without language clusters, all clusters are created with an empty language, so the
        // language must not be used to pick the cluster (same logic as in setChunkClusters())
        $cluster = $this->getCluster($vector, $this->useLanguageClusters ? $lang : '');
        if ($this->logger) $this->logger->info(
            'Using cluster {cluster} for similarity search',
            ['cluster' => $cluster]
        );

        // GETACCESSLEVEL() is not defined here - presumably registered by the sqlite plugin
        // and returning the current user's ACL level for the page (TODO confirm)
        $query = 'SELECT *, COSIM(?, embedding) AS similarity
                    FROM embeddings
                   WHERE cluster = ?
                     AND GETACCESSLEVEL(page) > 0
                     AND similarity > CAST(? AS FLOAT)';
        $params = [json_encode($vector, JSON_THROW_ON_ERROR), $cluster, $this->similarityThreshold];

        if ($lang != '') {
            // a language cluster only contains chunks of its language, so this is a no-op there;
            // for shared (non-language) clusters it restricts the results to the requested language
            $query .= ' AND lang = ?';
            $params[] = $lang;
        }

        $query .= ' ORDER BY similarity DESC LIMIT ?';
        $params[] = $limit;

        $result = $this->db->queryAll($query, $params);
        $chunks = [];
        foreach ($result as $record) {
            $chunks[] = $this->recordToChunk($record, $record['similarity']);
        }
        return $chunks;
    }

    /** @inheritdoc */
    public function statistics()
    {
        $items = $this->db->queryValue('SELECT COUNT(*) FROM embeddings');
        $size = $this->db->queryValue(
            'SELECT page_count * page_size as size FROM pragma_page_count(), pragma_page_size()'
        );
        $query = "SELECT cluster || ' ' || lang, COUNT(*) || ' chunks' as cnt
                    FROM embeddings
                GROUP BY cluster
                ORDER BY cluster";
        $clusters = $this->db->queryKeyValueList($query);

        return [
            'storage type' => 'SQLite',
            'chunks' => $items,
            'db size' => filesize_h($size),
            'clusters' => $clusters,
        ];
    }

    /**
     * Method registered as SQLite callback to calculate the cosine similarity
     *
     * @param string $query JSON encoded vector array
     * @param string $embedding JSON encoded vector array
     * @return float
     */
    public function sqliteCosineSimilarityCallback($query, $embedding)
    {
        return (float)$this->cosineSimilarity(
            json_decode($query, true, 512, JSON_THROW_ON_ERROR),
            json_decode($embedding, true, 512, JSON_THROW_ON_ERROR)
        );
    }

    /**
     * Calculate the cosine similarity between two vectors
     *
     * Actually just calculating the dot product of the two vectors, since they are normalized
     *
     * @param float[] $queryVector The normalized vector of the search phrase
     * @param float[] $embedding The normalized vector of the chunk
     * @return float
     */
    protected function cosineSimilarity($queryVector, $embedding)
    {
        $dotProduct = 0;
        foreach ($queryVector as $key => $value) {
            $dotProduct += $value * $embedding[$key];
        }
        return $dotProduct;
    }

    /**
     * Create new clusters based on random chunks
     *
     * With language clusters enabled, one set of clusters is created per language found in the
     * embeddings table; otherwise a single set (with an empty language) is created for all chunks.
     *
     * @return void
     */
    protected function createClusters()
    {
        if ($this->useLanguageClusters) {
            $result = $this->db->queryAll('SELECT DISTINCT lang FROM embeddings');
            $langs = array_column($result, 'lang');
            foreach ($langs as $lang) {
                $this->createLanguageClusters($lang);
            }
        } else {
            $this->createLanguageClusters('');
        }
    }

    /**
     * Create new clusters based on random chunks for the given Language
     *
     * Existing clusters (and the chunk assignments) for that language are removed first. The whole
     * operation runs in one transaction which is always closed before the method returns.
     *
     * @param string $lang The language to cluster, empty when all languages go into the same cluster
     * @throws \RuntimeException when clustering fails, the transaction is rolled back in that case
     * @noinspection SqlWithoutWhere
     */
    protected function createLanguageClusters($lang)
    {
        $pdo = $this->db->getPdo();

        if ($lang != '') {
            $where = 'WHERE lang = ' . $pdo->quote($lang);
        } else {
            $where = '';
        }

        if ($this->logger) $this->logger->info('Creating new {lang} clusters...', ['lang' => $lang]);
        $pdo->beginTransaction();
        try {
            // clean up old cluster data
            $this->db->exec("DELETE FROM clusters $where");
            $this->db->exec("UPDATE embeddings SET cluster = NULL $where");

            // get a random selection of chunks
            $query = "SELECT id, embedding FROM embeddings $where ORDER BY RANDOM() LIMIT ?";
            $result = $this->db->queryAll($query, [self::SAMPLE_SIZE]);
            if (!$result) {
                // No data to cluster. The cleanup above is still valid, so persist it - and most
                // importantly close the transaction, an open one would break later statements
                // (e.g. VACUUM or the next beginTransaction()).
                $pdo->commit();
                return;
            }
            $dimensions = count(json_decode((string) $result[0]['embedding'], true, 512, JSON_THROW_ON_ERROR));

            // how many clusters?
            if (count($result) < self::CLUSTER_SIZE * 3) {
                // there would be less than 3 clusters, so just use one
                $clustercount = 1;
            } else {
                // get the number of all chunks, to calculate the number of clusters
                $total = $this->db->queryValue("SELECT COUNT(*) FROM embeddings $where");
                $clustercount = (int) ceil($total / self::CLUSTER_SIZE);
            }
            if ($this->logger) $this->logger->info('Creating {clusters} clusters', ['clusters' => $clustercount]);

            // cluster them using kmeans
            $space = new Space($dimensions);
            foreach ($result as $record) {
                $space->addPoint(json_decode((string) $record['embedding'], true, 512, JSON_THROW_ON_ERROR));
            }
            $iterations = 0;
            $clusters = $space->solve(
                $clustercount,
                function ($space, $currentClusters) use (&$iterations) {
                    // progress callback, invoked by the kmeans library
                    ++$iterations;
                    if ($this->logger) {
                        $clustercounts = implode(',', array_map('count', $currentClusters));
                        $this->logger->info('Iteration {iteration}: [{clusters}]', [
                            'iteration' => $iterations, 'clusters' => $clustercounts
                        ]);
                    }
                },
                Cluster::INIT_KMEANS_PLUS_PLUS
            );

            // store the clusters
            $query = 'INSERT INTO clusters (lang, centroid) VALUES (?, ?)';
            foreach ($clusters as $cluster) {
                /** @var Cluster $cluster */
                $centroid = $cluster->getCoordinates();
                $this->db->exec($query, [$lang, json_encode($centroid, JSON_THROW_ON_ERROR)]);
            }

            $pdo->commit();
            if ($this->logger) $this->logger->success('Created {clusters} clusters', ['clusters' => count($clusters)]);
        } catch (\Exception $e) {
            // the exception may happen after the commit (e.g. in the logger), only roll back what is open
            if ($pdo->inTransaction()) $pdo->rollBack();
            throw new \RuntimeException('Clustering failed: ' . $e->getMessage(), 0, $e);
        }
    }

    /**
     * Assign the nearest cluster for all chunks that don't have one
     *
     * Chunks for which no cluster can be found keep a NULL cluster.
     *
     * @return void
     */
    protected function setChunkClusters()
    {
        if ($this->logger) $this->logger->info('Assigning clusters to chunks...');
        $query = 'SELECT id, embedding, lang FROM embeddings WHERE cluster IS NULL';
        $handle = $this->db->query($query);

        $update = 'UPDATE embeddings SET cluster = ? WHERE id = ?';
        while ($record = $handle->fetch(\PDO::FETCH_ASSOC)) {
            $vector = json_decode((string) $record['embedding'], true, 512, JSON_THROW_ON_ERROR);
            $cluster = $this->getCluster($vector, $this->useLanguageClusters ? $record['lang'] : '');
            $this->db->exec($update, [$cluster, $record['id']]);
            if ($this->logger) $this->logger->success(
                'Chunk {id} assigned to cluster {cluster}',
                ['id' => $record['id'], 'cluster' => $cluster]
            );
        }
        $handle->closeCursor();
    }

    /**
     * Get the nearest cluster for the given vector
     *
     * @param float[] $vector the normalized vector to find the cluster for
     * @param string $lang only consider clusters of this language, empty for clusters with an empty language
     *                     (when empty, no language condition is applied at all)
     * @return int|null the cluster ID or null if no cluster exists
     */
    protected function getCluster($vector, $lang)
    {
        if ($lang != '') {
            $where = 'WHERE lang = ' . $this->db->getPdo()->quote($lang);
        } else {
            $where = '';
        }

        $query = "SELECT cluster, centroid
                    FROM clusters
                   $where
                ORDER BY COSIM(centroid, ?) DESC
                   LIMIT 1";

        $result = $this->db->queryRecord($query, [json_encode($vector, JSON_THROW_ON_ERROR)]);
        if (!$result) return null;
        return $result['cluster'];
    }

    /**
     * Check if clustering has been done before
     * @return bool
     */
    protected function hasClusters()
    {
        $query = 'SELECT COUNT(*) FROM clusters';
        return $this->db->queryValue($query) > 0;
    }

    /**
     * Writes TSV files for visualizing with http://projector.tensorflow.org/
     *
     * Both files are overwritten. The vector file gets one tab separated vector per line, the
     * meta file a header line followed by one line of metadata per vector (same order).
     *
     * @param string $vectorfile path to the file with the vectors
     * @param string $metafile path to the file with the metadata
     * @return void
     */
    public function dumpTSV($vectorfile, $metafile)
    {
        $query = 'SELECT * FROM embeddings';
        $handle = $this->db->query($query);

        // start both files fresh - appending to existing files would repeat the header mid-file
        // and mix vectors of different runs
        file_put_contents($vectorfile, '');
        $header = implode("\t", ['id', 'page', 'created']);
        file_put_contents($metafile, $header . "\n");

        while ($row = $handle->fetch(\PDO::FETCH_ASSOC)) {
            $vector = json_decode((string) $row['embedding'], true, 512, JSON_THROW_ON_ERROR);
            $vector = implode("\t", $vector);

            $meta = implode("\t", [$row['id'], $row['page'], $row['created']]);

            file_put_contents($vectorfile, $vector . "\n", FILE_APPEND);
            file_put_contents($metafile, $meta . "\n", FILE_APPEND);
        }
        $handle->closeCursor();
    }

    /**
     * Create a Chunk object from a row of the embeddings table
     *
     * @param array $record a row of the embeddings table (all columns)
     * @param float|null $similarity the similarity score, only passed on to the Chunk when given
     * @return Chunk
     */
    private function recordToChunk(array $record, $similarity = null)
    {
        $args = [
            $record['page'],
            $record['id'],
            $record['chunk'],
            json_decode((string) $record['embedding'], true, 512, JSON_THROW_ON_ERROR),
            $record['lang'],
            $record['created'],
        ];
        // only pass the score when there is one, so Chunk's own default applies otherwise
        if ($similarity !== null) {
            $args[] = $similarity;
        }
        return new Chunk(...$args);
    }
}
401