1<?php
2
3/** @noinspection SqlResolve */
4
5namespace dokuwiki\plugin\aichat\Storage;
6
7use dokuwiki\plugin\aichat\AIChat;
8use dokuwiki\plugin\aichat\Chunk;
9use dokuwiki\plugin\sqlite\SQLiteDB;
10use KMeans\Cluster;
11use KMeans\Space;
12
13/**
14 * Implements the storage backend using a SQLite database
15 *
16 * Note: all embeddings are stored and returned as normalized vectors
17 */
18class SQLiteStorage extends AbstractStorage
19{
20    /** @var int Number of documents to randomly sample to create the clusters */
21    final public const SAMPLE_SIZE = 2000;
22    /** @var int The average size of each cluster */
23    final public const CLUSTER_SIZE = 400;
24
25    /** @var SQLiteDB */
26    protected $db;
27
28    protected $useLanguageClusters = false;
29
30    /** @var float minimum similarity to consider a chunk a match */
31    protected $similarityThreshold = 0;
32
33    /** @inheritdoc */
34    public function __construct(array $config)
35    {
36        $this->db = new SQLiteDB('aichat', DOKU_PLUGIN . 'aichat/db/');
37        $this->db->getPdo()->sqliteCreateFunction('COSIM', $this->sqliteCosineSimilarityCallback(...), 2);
38
39        $helper = plugin_load('helper', 'aichat');
40        $this->useLanguageClusters = $helper->getConf('preferUIlanguage') >= AIChat::LANG_UI_LIMITED;
41
42        $this->similarityThreshold = $config['similarityThreshold'] / 100;
43    }
44
45    /** @inheritdoc */
46    public function getChunk($chunkID)
47    {
48        $record = $this->db->queryRecord('SELECT * FROM embeddings WHERE id = ?', [$chunkID]);
49        if (!$record) return null;
50
51        return new Chunk(
52            $record['page'],
53            $record['id'],
54            $record['chunk'],
55            json_decode((string) $record['embedding'], true, 512, JSON_THROW_ON_ERROR),
56            $record['lang'],
57            $record['created']
58        );
59    }
60
61    /** @inheritdoc */
62    public function startCreation($clear = false)
63    {
64        if ($clear) {
65            /** @noinspection SqlWithoutWhere */
66            $this->db->exec('DELETE FROM embeddings');
67            /** @noinspection SqlWithoutWhere */
68            $this->db->exec('DELETE FROM clusters');
69        }
70    }
71
72    /** @inheritdoc */
73    public function reusePageChunks($page, $firstChunkID)
74    {
75        // no-op
76    }
77
78    /** @inheritdoc */
79    public function deletePageChunks($page, $firstChunkID)
80    {
81        $this->db->exec('DELETE FROM embeddings WHERE page = ?', [$page]);
82    }
83
84    /** @inheritdoc */
85    public function addPageChunks($chunks)
86    {
87        foreach ($chunks as $chunk) {
88            $this->db->saveRecord('embeddings', [
89                'page' => $chunk->getPage(),
90                'id' => $chunk->getId(),
91                'chunk' => $chunk->getText(),
92                'embedding' => json_encode($chunk->getEmbedding(), JSON_THROW_ON_ERROR),
93                'created' => $chunk->getCreated(),
94                'lang' => $chunk->getLanguage(),
95            ]);
96        }
97    }
98
99    /** @inheritdoc */
100    public function finalizeCreation()
101    {
102        if (!$this->hasClusters()) {
103            $this->createClusters();
104        }
105        $this->setChunkClusters();
106
107        $this->db->exec('VACUUM');
108    }
109
110    /** @inheritdoc */
111    public function runMaintenance()
112    {
113        $this->createClusters();
114        $this->setChunkClusters();
115    }
116
117    /** @inheritdoc */
118    public function getPageChunks($page, $firstChunkID)
119    {
120        $result = $this->db->queryAll(
121            'SELECT * FROM embeddings WHERE page = ?',
122            [$page]
123        );
124        $chunks = [];
125        foreach ($result as $record) {
126            $chunks[] = new Chunk(
127                $record['page'],
128                $record['id'],
129                $record['chunk'],
130                json_decode((string) $record['embedding'], true, 512, JSON_THROW_ON_ERROR),
131                $record['lang'],
132                $record['created']
133            );
134        }
135        return $chunks;
136    }
137
138    /** @inheritdoc */
139    public function getSimilarChunks($vector, $lang = '', $limit = 4)
140    {
141        $cluster = $this->getCluster($vector, $lang);
142        if ($this->logger) $this->logger->info(
143            'Using cluster {cluster} for similarity search',
144            ['cluster' => $cluster]
145        );
146
147        $result = $this->db->queryAll(
148            'SELECT *, COSIM(?, embedding) AS similarity
149               FROM embeddings
150              WHERE cluster = ?
151                AND GETACCESSLEVEL(page) > 0
152                AND similarity > CAST(? AS FLOAT)
153           ORDER BY similarity DESC
154              LIMIT ?',
155            [json_encode($vector, JSON_THROW_ON_ERROR), $cluster, $this->similarityThreshold, $limit]
156        );
157        $chunks = [];
158        foreach ($result as $record) {
159            $chunks[] = new Chunk(
160                $record['page'],
161                $record['id'],
162                $record['chunk'],
163                json_decode((string) $record['embedding'], true, 512, JSON_THROW_ON_ERROR),
164                $record['lang'],
165                $record['created'],
166                $record['similarity']
167            );
168        }
169        return $chunks;
170    }
171
172    /** @inheritdoc */
173    public function statistics()
174    {
175        $items = $this->db->queryValue('SELECT COUNT(*) FROM embeddings');
176        $size = $this->db->queryValue(
177            'SELECT page_count * page_size as size FROM pragma_page_count(), pragma_page_size()'
178        );
179        $query = "SELECT cluster || ' ' || lang, COUNT(*) || ' chunks' as cnt
180                    FROM embeddings
181                GROUP BY cluster
182                ORDER BY cluster";
183        $clusters = $this->db->queryKeyValueList($query);
184
185        return [
186            'storage type' => 'SQLite',
187            'chunks' => $items,
188            'db size' => filesize_h($size),
189            'clusters' => $clusters,
190        ];
191    }
192
193    /**
194     * Method registered as SQLite callback to calculate the cosine similarity
195     *
196     * @param string $query JSON encoded vector array
197     * @param string $embedding JSON encoded vector array
198     * @return float
199     */
200    public function sqliteCosineSimilarityCallback($query, $embedding)
201    {
202        return (float)$this->cosineSimilarity(
203            json_decode($query, true, 512, JSON_THROW_ON_ERROR),
204            json_decode($embedding, true, 512, JSON_THROW_ON_ERROR)
205        );
206    }
207
208    /**
209     * Calculate the cosine similarity between two vectors
210     *
211     * Actually just calculating the dot product of the two vectors, since they are normalized
212     *
213     * @param float[] $queryVector The normalized vector of the search phrase
214     * @param float[] $embedding The normalized vector of the chunk
215     * @return float
216     */
217    protected function cosineSimilarity($queryVector, $embedding)
218    {
219        $dotProduct = 0;
220        foreach ($queryVector as $key => $value) {
221            if(!isset($embedding[$key])) break; // if the vector is shorter than the query, stop.
222            $dotProduct += $value * $embedding[$key];
223        }
224        return $dotProduct;
225    }
226
227    /**
228     * Create new clusters based on random chunks
229     *
230     * @return void
231     */
232    protected function createClusters()
233    {
234        if ($this->useLanguageClusters) {
235            $result = $this->db->queryAll('SELECT DISTINCT lang FROM embeddings');
236            $langs = array_column($result, 'lang');
237            foreach ($langs as $lang) {
238                $this->createLanguageClusters($lang);
239            }
240        } else {
241            $this->createLanguageClusters('');
242        }
243    }
244
245    /**
246     * Create new clusters based on random chunks for the given Language
247     *
248     * @param string $lang The language to cluster, empty when all languages go into the same cluster
249     * @noinspection SqlWithoutWhere
250     */
251    protected function createLanguageClusters($lang)
252    {
253        if ($lang != '') {
254            $where = 'WHERE lang = ' . $this->db->getPdo()->quote($lang);
255        } else {
256            $where = '';
257        }
258
259        if ($this->logger) $this->logger->info('Creating new {lang} clusters...', ['lang' => $lang]);
260        $this->db->getPdo()->beginTransaction();
261        try {
262            // clean up old cluster data
263            $query = "DELETE FROM clusters $where";
264            $this->db->exec($query);
265            $query = "UPDATE embeddings SET cluster = NULL $where";
266            $this->db->exec($query);
267
268            // get a random selection of chunks
269            $query = "SELECT id, embedding FROM embeddings $where ORDER BY RANDOM() LIMIT ?";
270            $result = $this->db->queryAll($query, [self::SAMPLE_SIZE]);
271            if (!$result) return; // no data to cluster
272            $dimensions = count(json_decode((string) $result[0]['embedding'], true, 512, JSON_THROW_ON_ERROR));
273
274            // how many clusters?
275            if (count($result) < self::CLUSTER_SIZE * 3) {
276                // there would be less than 3 clusters, so just use one
277                $clustercount = 1;
278            } else {
279                // get the number of all chunks, to calculate the number of clusters
280                $query = "SELECT COUNT(*) FROM embeddings $where";
281                $total = $this->db->queryValue($query);
282                $clustercount = ceil($total / self::CLUSTER_SIZE);
283            }
284            if ($this->logger) $this->logger->info('Creating {clusters} clusters', ['clusters' => $clustercount]);
285
286            // cluster them using kmeans
287            $space = new Space($dimensions);
288            foreach ($result as $record) {
289                $space->addPoint(json_decode((string) $record['embedding'], true, 512, JSON_THROW_ON_ERROR));
290            }
291            $clusters = $space->solve($clustercount, function ($space, $clusters) {
292                static $iterations = 0;
293                ++$iterations;
294                if ($this->logger) {
295                    $clustercounts = implode(',', array_map('count', $clusters));
296                    $this->logger->info('Iteration {iteration}: [{clusters}]', [
297                        'iteration' => $iterations, 'clusters' => $clustercounts
298                    ]);
299                }
300            }, Cluster::INIT_KMEANS_PLUS_PLUS);
301
302            // store the clusters
303            foreach ($clusters as $cluster) {
304                /** @var Cluster $cluster */
305                $centroid = $cluster->getCoordinates();
306                $query = 'INSERT INTO clusters (lang, centroid) VALUES (?, ?)';
307                $this->db->exec($query, [$lang, json_encode($centroid, JSON_THROW_ON_ERROR)]);
308            }
309
310            $this->db->getPdo()->commit();
311            if ($this->logger) $this->logger->success('Created {clusters} clusters', ['clusters' => count($clusters)]);
312        } catch (\Exception $e) {
313            $this->db->getPdo()->rollBack();
314            throw new \RuntimeException('Clustering failed: ' . $e->getMessage(), 4005, $e);
315        }
316    }
317
318    /**
319     * Assign the nearest cluster for all chunks that don't have one
320     *
321     * @return void
322     */
323    protected function setChunkClusters()
324    {
325        if ($this->logger) $this->logger->info('Assigning clusters to chunks...');
326        $query = 'SELECT id, embedding, lang FROM embeddings WHERE cluster IS NULL';
327        $handle = $this->db->query($query);
328
329        while ($record = $handle->fetch(\PDO::FETCH_ASSOC)) {
330            $vector = json_decode((string) $record['embedding'], true, 512, JSON_THROW_ON_ERROR);
331            $cluster = $this->getCluster($vector, $this->useLanguageClusters ? $record['lang'] : '');
332            $query = 'UPDATE embeddings SET cluster = ? WHERE id = ?';
333            $this->db->exec($query, [$cluster, $record['id']]);
334            if ($this->logger) $this->logger->success(
335                'Chunk {id} assigned to cluster {cluster}',
336                ['id' => $record['id'], 'cluster' => $cluster]
337            );
338        }
339        $handle->closeCursor();
340    }
341
342    /**
343     * Get the nearest cluster for the given vector
344     *
345     * @param float[] $vector
346     * @return int|null
347     */
348    protected function getCluster($vector, $lang)
349    {
350        if ($lang != '') {
351            $where = 'WHERE lang = ' . $this->db->getPdo()->quote($lang);
352        } else {
353            $where = '';
354        }
355
356        $query = "SELECT cluster, centroid
357                    FROM clusters
358                   $where
359                ORDER BY COSIM(centroid, ?) DESC
360                   LIMIT 1";
361
362        $result = $this->db->queryRecord($query, [json_encode($vector, JSON_THROW_ON_ERROR)]);
363        if (!$result) return null;
364        return $result['cluster'];
365    }
366
367    /**
368     * Check if clustering has been done before
369     * @return bool
370     */
371    protected function hasClusters()
372    {
373        $query = 'SELECT COUNT(*) FROM clusters';
374        return $this->db->queryValue($query) > 0;
375    }
376
377    /**
378     * Writes TSV files for visualizing with http://projector.tensorflow.org/
379     *
380     * @param string $vectorfile path to the file with the vectors
381     * @param string $metafile path to the file with the metadata
382     * @return void
383     */
384    public function dumpTSV($vectorfile, $metafile)
385    {
386        $query = 'SELECT * FROM embeddings';
387        $handle = $this->db->query($query);
388
389        $header = implode("\t", ['id', 'page', 'created']);
390        file_put_contents($metafile, $header . "\n", FILE_APPEND);
391
392        while ($row = $handle->fetch(\PDO::FETCH_ASSOC)) {
393            $vector = json_decode((string) $row['embedding'], true, 512, JSON_THROW_ON_ERROR);
394            $vector = implode("\t", $vector);
395
396            $meta = implode("\t", [$row['id'], $row['page'], $row['created']]);
397
398            file_put_contents($vectorfile, $vector . "\n", FILE_APPEND);
399            file_put_contents($metafile, $meta . "\n", FILE_APPEND);
400        }
401    }
402}
403