xref: /plugin/aichat/Storage/SQLiteStorage.php (revision 04afb84f6cb8a0c9b1d4d807e18f90fe739ec371)
1<?php
2
3/** @noinspection SqlResolve */
4
5namespace dokuwiki\plugin\aichat\Storage;
6
7use dokuwiki\plugin\aichat\AIChat;
8use dokuwiki\plugin\aichat\Chunk;
9use dokuwiki\plugin\sqlite\SQLiteDB;
10use KMeans\Cluster;
11use KMeans\Space;
12
/**
 * Implements the storage backend using a SQLite database
 *
 * Chunks live in the `embeddings` table, their vectors stored as JSON encoded arrays. To speed up
 * similarity searches the chunks are grouped into k-means clusters (table `clusters`); a search only
 * compares the query vector against the chunks of the cluster with the nearest centroid.
 *
 * Note: all embeddings are stored and returned as normalized vectors
 */
class SQLiteStorage extends AbstractStorage
{
    /** @var float minimum similarity to consider a chunk a match */
    final public const SIMILARITY_THRESHOLD = 0;

    /** @var int Number of documents to randomly sample to create the clusters */
    final public const SAMPLE_SIZE = 2000;
    /** @var int The average size of each cluster */
    final public const CLUSTER_SIZE = 400;

    /** @var SQLiteDB */
    protected $db;

    /** @var bool when true, clusters are created and searched separately per chunk language */
    protected $useLanguageClusters = false;

    /**
     * Opens the database and registers the COSIM() SQL function used for similarity ranking
     *
     * @inheritdoc
     */
    public function __construct(array $config)
    {
        $this->db = new SQLiteDB('aichat', DOKU_PLUGIN . 'aichat/db/');
        $this->db->getPdo()->sqliteCreateFunction('COSIM', $this->sqliteCosineSimilarityCallback(...), 2);

        $helper = plugin_load('helper', 'aichat');
        $this->useLanguageClusters = $helper->getConf('preferUIlanguage') >= AIChat::LANG_UI_LIMITED;
    }

    /** @inheritdoc */
    public function getChunk($chunkID)
    {
        $record = $this->db->queryRecord('SELECT * FROM embeddings WHERE id = ?', [$chunkID]);
        if (!$record) return null;

        return $this->recordToChunk($record);
    }

    /**
     * @inheritdoc
     *
     * When clearing, the clusters are removed too. Their centroids were computed from the deleted
     * embeddings, and leaving them in place would keep finalizeCreation() from building new ones.
     */
    public function startCreation($clear = false)
    {
        if ($clear) {
            /** @noinspection SqlWithoutWhere */
            $this->db->exec('DELETE FROM embeddings');
            /** @noinspection SqlWithoutWhere */
            $this->db->exec('DELETE FROM clusters');
        }
    }

    /** @inheritdoc */
    public function reusePageChunks($page, $firstChunkID)
    {
        // no-op: the chunks of a page simply stay in the table
    }

    /** @inheritdoc */
    public function deletePageChunks($page, $firstChunkID)
    {
        $this->db->exec('DELETE FROM embeddings WHERE page = ?', [$page]);
    }

    /** @inheritdoc */
    public function addPageChunks($chunks)
    {
        foreach ($chunks as $chunk) {
            $this->db->saveRecord('embeddings', [
                'page' => $chunk->getPage(),
                'id' => $chunk->getId(),
                'chunk' => $chunk->getText(),
                'embedding' => json_encode($chunk->getEmbedding(), JSON_THROW_ON_ERROR),
                'created' => $chunk->getCreated(),
                'lang' => $chunk->getLanguage(),
            ]);
        }
    }

    /**
     * Creates clusters if none exist yet, assigns unclustered chunks and compacts the database
     *
     * @inheritdoc
     */
    public function finalizeCreation()
    {
        if (!$this->hasClusters()) {
            $this->createClusters();
        }
        $this->setChunkClusters();

        $this->db->exec('VACUUM');
    }

    /**
     * Rebuilds all clusters from scratch and reassigns every chunk
     *
     * @inheritdoc
     */
    public function runMaintenance()
    {
        $this->createClusters();
        $this->setChunkClusters();
    }

    /** @inheritdoc */
    public function getPageChunks($page, $firstChunkID)
    {
        $result = $this->db->queryAll('SELECT * FROM embeddings WHERE page = ?', [$page]);
        return array_map($this->recordToChunk(...), $result);
    }

    /**
     * Only chunks within the cluster nearest to $vector, readable by the current user and above
     * SIMILARITY_THRESHOLD are returned, best match first.
     *
     * @inheritdoc
     */
    public function getSimilarChunks($vector, $lang = '', $limit = 4)
    {
        $cluster = $this->getCluster($vector, $lang);
        if ($this->logger) $this->logger->info(
            'Using cluster {cluster} for similarity search',
            ['cluster' => $cluster]
        );
        // without a cluster no row could satisfy `cluster = ?` below, so skip the query entirely
        if ($cluster === null) return [];

        $result = $this->db->queryAll(
            'SELECT *, COSIM(?, embedding) AS similarity
               FROM embeddings
              WHERE cluster = ?
                AND GETACCESSLEVEL(page) > 0
                AND similarity > CAST(? AS FLOAT)
           ORDER BY similarity DESC
              LIMIT ?',
            [json_encode($vector, JSON_THROW_ON_ERROR), $cluster, self::SIMILARITY_THRESHOLD, $limit]
        );
        return array_map($this->recordToChunk(...), $result);
    }

    /**
     * Overview of the storage: number of chunks, database file size and chunks per cluster
     *
     * @inheritdoc
     */
    public function statistics()
    {
        $items = $this->db->queryValue('SELECT COUNT(*) FROM embeddings');
        $size = $this->db->queryValue(
            'SELECT page_count * page_size as size FROM pragma_page_count(), pragma_page_size()'
        );
        $query = "SELECT cluster || ' ' || lang, COUNT(*) || ' chunks' as cnt
                    FROM embeddings
                GROUP BY cluster
                ORDER BY cluster";
        $clusters = $this->db->queryKeyValueList($query);

        return [
            'storage type' => 'SQLite',
            'chunks' => $items,
            'db size' => filesize_h($size),
            'clusters' => $clusters,
        ];
    }

    /**
     * Method registered as SQLite callback to calculate the cosine similarity
     *
     * @param string $query JSON encoded vector array
     * @param string $embedding JSON encoded vector array
     * @return float
     * @throws \JsonException when either argument is not valid JSON
     */
    public function sqliteCosineSimilarityCallback($query, $embedding)
    {
        return (float)$this->cosineSimilarity(
            json_decode($query, true, 512, JSON_THROW_ON_ERROR),
            json_decode($embedding, true, 512, JSON_THROW_ON_ERROR)
        );
    }

    /**
     * Calculate the cosine similarity between two vectors
     *
     * Actually just calculating the dot product of the two vectors, since they are normalized.
     * Dimensions of $queryVector missing in $embedding contribute 0 (e.g. on a dimension mismatch)
     * instead of raising an undefined-key warning for every single entry.
     *
     * @param float[] $queryVector The normalized vector of the search phrase
     * @param float[] $embedding The normalized vector of the chunk
     * @return float
     */
    protected function cosineSimilarity($queryVector, $embedding)
    {
        $dotProduct = 0;
        foreach ($queryVector as $key => $value) {
            $dotProduct += $value * ($embedding[$key] ?? 0);
        }
        return $dotProduct;
    }

    /**
     * Create new clusters based on random chunks
     *
     * Either one set of clusters per distinct chunk language, or a single set for all chunks
     *
     * @return void
     */
    protected function createClusters()
    {
        if ($this->useLanguageClusters) {
            $result = $this->db->queryAll('SELECT DISTINCT lang FROM embeddings');
            $langs = array_column($result, 'lang');
            foreach ($langs as $lang) {
                $this->createLanguageClusters($lang);
            }
        } else {
            $this->createLanguageClusters('');
        }
    }

    /**
     * Create new clusters based on random chunks for the given Language
     *
     * Old clusters (and the cluster assignments of the affected chunks) are removed first. Everything
     * happens in one transaction which is always either committed or rolled back before returning.
     *
     * @param string $lang The language to cluster, empty when all languages go into the same cluster
     * @throws \RuntimeException when clustering fails (the original error is attached as previous)
     * @noinspection SqlWithoutWhere
     */
    protected function createLanguageClusters($lang)
    {
        $pdo = $this->db->getPdo();
        $where = ($lang != '') ? 'WHERE lang = ' . $pdo->quote($lang) : '';

        if ($this->logger) $this->logger->info('Creating new {lang} clusters...', ['lang' => $lang]);
        $pdo->beginTransaction();
        try {
            // clean up old cluster data
            $this->db->exec("DELETE FROM clusters $where");
            $this->db->exec("UPDATE embeddings SET cluster = NULL $where");

            // get a random selection of chunks
            $query = "SELECT id, embedding FROM embeddings $where ORDER BY RANDOM() LIMIT ?";
            $result = $this->db->queryAll($query, [self::SAMPLE_SIZE]);
            if (!$result) {
                // no data to cluster - commit the cleanup so no transaction is left open
                $pdo->commit();
                return;
            }
            $dimensions = count(json_decode((string) $result[0]['embedding'], true, 512, JSON_THROW_ON_ERROR));

            // how many clusters?
            if (count($result) < self::CLUSTER_SIZE * 3) {
                // there would be less than 3 clusters, so just use one
                $clustercount = 1;
            } else {
                // get the number of all chunks, to calculate the number of clusters
                $total = $this->db->queryValue("SELECT COUNT(*) FROM embeddings $where");
                $clustercount = (int) ceil($total / self::CLUSTER_SIZE);
            }
            if ($this->logger) $this->logger->info('Creating {clusters} clusters', ['clusters' => $clustercount]);

            // cluster them using kmeans
            $space = new Space($dimensions);
            foreach ($result as $record) {
                $space->addPoint(json_decode((string) $record['embedding'], true, 512, JSON_THROW_ON_ERROR));
            }
            $iterations = 0;
            $clusters = $space->solve($clustercount, function ($space, $clusters) use (&$iterations) {
                ++$iterations;
                if ($this->logger) {
                    $clustercounts = implode(',', array_map('count', $clusters));
                    $this->logger->info('Iteration {iteration}: [{clusters}]', [
                        'iteration' => $iterations, 'clusters' => $clustercounts
                    ]);
                }
            }, Cluster::INIT_KMEANS_PLUS_PLUS);

            // store the clusters
            $insert = 'INSERT INTO clusters (lang, centroid) VALUES (?, ?)';
            foreach ($clusters as $cluster) {
                /** @var Cluster $cluster */
                $this->db->exec($insert, [$lang, json_encode($cluster->getCoordinates(), JSON_THROW_ON_ERROR)]);
            }

            $pdo->commit();
        } catch (\Throwable $e) {
            // rollBack() itself throws when no transaction is active, so check first
            if ($pdo->inTransaction()) $pdo->rollBack();
            throw new \RuntimeException('Clustering failed: ' . $e->getMessage(), 0, $e);
        }

        if ($this->logger) $this->logger->success('Created {clusters} clusters', ['clusters' => count($clusters)]);
    }

    /**
     * Assign the nearest cluster for all chunks that don't have one
     *
     * All updates run in a single transaction instead of one implicit commit per chunk. Chunks for
     * which no cluster exists (e.g. no clusters for their language) are left unassigned.
     *
     * @return void
     */
    protected function setChunkClusters()
    {
        if ($this->logger) $this->logger->info('Assigning clusters to chunks...');
        $pdo = $this->db->getPdo();

        $pdo->beginTransaction();
        try {
            $handle = $this->db->query('SELECT id, embedding, lang FROM embeddings WHERE cluster IS NULL');
            while ($record = $handle->fetch(\PDO::FETCH_ASSOC)) {
                $vector = json_decode((string) $record['embedding'], true, 512, JSON_THROW_ON_ERROR);
                $cluster = $this->getCluster($vector, $this->useLanguageClusters ? $record['lang'] : '');
                if ($cluster === null) {
                    // setting NULL again would be a no-op; report it instead of claiming success
                    if ($this->logger) $this->logger->warning(
                        'No cluster found for chunk {id}',
                        ['id' => $record['id']]
                    );
                    continue;
                }

                $this->db->exec('UPDATE embeddings SET cluster = ? WHERE id = ?', [$cluster, $record['id']]);
                if ($this->logger) $this->logger->success(
                    'Chunk {id} assigned to cluster {cluster}',
                    ['id' => $record['id'], 'cluster' => $cluster]
                );
            }
            $handle->closeCursor();
            $pdo->commit();
        } catch (\Throwable $e) {
            if ($pdo->inTransaction()) $pdo->rollBack();
            throw $e;
        }
    }

    /**
     * Get the nearest cluster for the given vector
     *
     * @param float[] $vector normalized vector to find the nearest centroid for
     * @param string $lang only consider clusters of this language, empty for all clusters
     * @return int|null the cluster ID, null when there are no (matching) clusters
     */
    protected function getCluster($vector, $lang)
    {
        $pdo = $this->db->getPdo();
        $where = ($lang != '') ? 'WHERE lang = ' . $pdo->quote($lang) : '';

        $query = "SELECT cluster
                    FROM clusters
                   $where
                ORDER BY COSIM(centroid, ?) DESC
                   LIMIT 1";

        $result = $this->db->queryRecord($query, [json_encode($vector, JSON_THROW_ON_ERROR)]);
        if (!$result) return null;
        return $result['cluster'];
    }

    /**
     * Check if clustering has been done before
     *
     * @return bool true when at least one cluster exists (for any language)
     */
    protected function hasClusters()
    {
        $query = 'SELECT COUNT(*) FROM clusters';
        return $this->db->queryValue($query) > 0;
    }

    /**
     * Writes TSV files for visualizing with http://projector.tensorflow.org/
     *
     * Both files are appended to; the metadata file gets a header line first.
     *
     * @param string $vectorfile path to the file with the vectors
     * @param string $metafile path to the file with the metadata
     * @return void
     * @throws \RuntimeException when one of the files can't be opened
     */
    public function dumpTSV($vectorfile, $metafile)
    {
        $vectorHandle = fopen($vectorfile, 'a');
        if ($vectorHandle === false) {
            throw new \RuntimeException("Could not open $vectorfile for writing");
        }
        $metaHandle = fopen($metafile, 'a');
        if ($metaHandle === false) {
            fclose($vectorHandle);
            throw new \RuntimeException("Could not open $metafile for writing");
        }

        try {
            fwrite($metaHandle, implode("\t", ['id', 'page', 'created']) . "\n");

            $handle = $this->db->query('SELECT id, page, created, embedding FROM embeddings');
            while ($row = $handle->fetch(\PDO::FETCH_ASSOC)) {
                $vector = json_decode((string) $row['embedding'], true, 512, JSON_THROW_ON_ERROR);
                fwrite($vectorHandle, implode("\t", $vector) . "\n");
                fwrite($metaHandle, implode("\t", [$row['id'], $row['page'], $row['created']]) . "\n");
            }
            $handle->closeCursor();
        } finally {
            fclose($vectorHandle);
            fclose($metaHandle);
        }
    }

    /**
     * Create a Chunk object from a row of the embeddings table
     *
     * @param array $record the database row; a `similarity` column, when present (similarity search
     *                      results), is passed on to the Chunk as an additional constructor argument
     * @return Chunk
     * @throws \JsonException when the stored embedding is not valid JSON
     */
    protected function recordToChunk(array $record): Chunk
    {
        $args = [
            $record['page'],
            $record['id'],
            $record['chunk'],
            json_decode((string) $record['embedding'], true, 512, JSON_THROW_ON_ERROR),
            $record['lang'],
            $record['created'],
        ];
        if (array_key_exists('similarity', $record)) {
            $args[] = $record['similarity'];
        }
        return new Chunk(...$args);
    }
}
398