xref: /plugin/aichat/Storage/SQLiteStorage.php (revision adfc54298aeea1d4c6f3a22467bfe96f73dfc1c5)
1<?php /** @noinspection SqlResolve */
2
3namespace dokuwiki\plugin\aichat\Storage;
4
5use dokuwiki\plugin\aichat\AIChat;
6use dokuwiki\plugin\aichat\Chunk;
7use dokuwiki\plugin\sqlite\SQLiteDB;
8use KMeans\Cluster;
9use KMeans\Space;
10
11/**
12 * Implements the storage backend using a SQLite database
13 *
14 * Note: all embeddings are stored and returned as normalized vectors
15 */
16class SQLiteStorage extends AbstractStorage
17{
18    /** @var float minimum similarity to consider a chunk a match */
19    const SIMILARITY_THRESHOLD = 0.75;
20
21    /** @var int Number of documents to randomly sample to create the clusters */
22    const SAMPLE_SIZE = 2000;
23    /** @var int The average size of each cluster */
24    const CLUSTER_SIZE = 400;
25
26    /** @var SQLiteDB */
27    protected $db;
28
29    protected $useLanguageClusters = false;
30
31    /**
32     * Initializes the database connection and registers our custom function
33     *
34     * @throws \Exception
35     */
36    public function __construct()
37    {
38        $this->db = new SQLiteDB('aichat', DOKU_PLUGIN . 'aichat/db/');
39        $this->db->getPdo()->sqliteCreateFunction('COSIM', [$this, 'sqliteCosineSimilarityCallback'], 2);
40
41        $helper = plugin_load('helper', 'aichat');
42        $this->useLanguageClusters = $helper->getConf('preferUIlanguage') >= AIChat::LANG_UI_LIMITED;
43    }
44
45    /** @inheritdoc */
46    public function getChunk($chunkID)
47    {
48        $record = $this->db->queryRecord('SELECT * FROM embeddings WHERE id = ?', [$chunkID]);
49        if (!$record) return null;
50
51        return new Chunk(
52            $record['page'],
53            $record['id'],
54            $record['chunk'],
55            json_decode($record['embedding'], true),
56            $record['lang'],
57            $record['created']
58        );
59    }
60
61    /** @inheritdoc */
62    public function startCreation($clear = false)
63    {
64        if ($clear) {
65            /** @noinspection SqlWithoutWhere */
66            $this->db->exec('DELETE FROM embeddings');
67        }
68    }
69
70    /** @inheritdoc */
71    public function reusePageChunks($page, $firstChunkID)
72    {
73        // no-op
74    }
75
76    /** @inheritdoc */
77    public function deletePageChunks($page, $firstChunkID)
78    {
79        $this->db->exec('DELETE FROM embeddings WHERE page = ?', [$page]);
80    }
81
82    /** @inheritdoc */
83    public function addPageChunks($chunks)
84    {
85        foreach ($chunks as $chunk) {
86            $this->db->saveRecord('embeddings', [
87                'page' => $chunk->getPage(),
88                'id' => $chunk->getId(),
89                'chunk' => $chunk->getText(),
90                'embedding' => json_encode($chunk->getEmbedding()),
91                'created' => $chunk->getCreated(),
92                'lang' => $chunk->getLanguage(),
93            ]);
94        }
95    }
96
97    /** @inheritdoc */
98    public function finalizeCreation()
99    {
100        if (!$this->hasClusters()) {
101            $this->createClusters();
102        }
103        $this->setChunkClusters();
104
105        $this->db->exec('VACUUM');
106    }
107
108    /** @inheritdoc */
109    public function runMaintenance()
110    {
111        $this->createClusters();
112        $this->setChunkClusters();
113    }
114
115    /** @inheritdoc */
116    public function getPageChunks($page, $firstChunkID)
117    {
118        $result = $this->db->queryAll(
119            'SELECT * FROM embeddings WHERE page = ?',
120            [$page]
121        );
122        $chunks = [];
123        foreach ($result as $record) {
124            $chunks[] = new Chunk(
125                $record['page'],
126                $record['id'],
127                $record['chunk'],
128                json_decode($record['embedding'], true),
129                $record['lang'],
130                $record['created']
131            );
132        }
133        return $chunks;
134    }
135
136    /** @inheritdoc */
137    public function getSimilarChunks($vector, $lang = '', $limit = 4)
138    {
139        $cluster = $this->getCluster($vector, $lang);
140        if ($this->logger) $this->logger->info(
141            'Using cluster {cluster} for similarity search', ['cluster' => $cluster]
142        );
143
144        $result = $this->db->queryAll(
145            'SELECT *, COSIM(?, embedding) AS similarity
146               FROM embeddings
147              WHERE cluster = ?
148                AND GETACCESSLEVEL(page) > 0
149                AND similarity > CAST(? AS FLOAT)
150           ORDER BY similarity DESC
151              LIMIT ?',
152            [json_encode($vector), $cluster, self::SIMILARITY_THRESHOLD, $limit]
153        );
154        $chunks = [];
155        foreach ($result as $record) {
156            $chunks[] = new Chunk(
157                $record['page'],
158                $record['id'],
159                $record['chunk'],
160                json_decode($record['embedding'], true),
161                $record['lang'],
162                $record['created'],
163                $record['similarity']
164            );
165        }
166        return $chunks;
167    }
168
169    /** @inheritdoc */
170    public function statistics()
171    {
172        $items = $this->db->queryValue('SELECT COUNT(*) FROM embeddings');
173        $size = $this->db->queryValue(
174            'SELECT page_count * page_size as size FROM pragma_page_count(), pragma_page_size()'
175        );
176        $query = "SELECT cluster || ' ' || lang, COUNT(*) || ' chunks' as cnt FROM embeddings GROUP BY cluster ORDER BY cluster";
177        $clusters = $this->db->queryKeyValueList($query);
178
179        return [
180            'storage type' => 'SQLite',
181            'chunks' => $items,
182            'db size' => filesize_h($size),
183            'clusters' => $clusters,
184        ];
185    }
186
187    /**
188     * Method registered as SQLite callback to calculate the cosine similarity
189     *
190     * @param string $query JSON encoded vector array
191     * @param string $embedding JSON encoded vector array
192     * @return float
193     */
194    public function sqliteCosineSimilarityCallback($query, $embedding)
195    {
196        return (float)$this->cosineSimilarity(json_decode($query), json_decode($embedding));
197    }
198
199    /**
200     * Calculate the cosine similarity between two vectors
201     *
202     * Actually just calculating the dot product of the two vectors, since they are normalized
203     *
204     * @param float[] $queryVector The normalized vector of the search phrase
205     * @param float[] $embedding The normalized vector of the chunk
206     * @return float
207     */
208    protected function cosineSimilarity($queryVector, $embedding)
209    {
210        $dotProduct = 0;
211        foreach ($queryVector as $key => $value) {
212            $dotProduct += $value * $embedding[$key];
213        }
214        return $dotProduct;
215    }
216
217    /**
218     * Create new clusters based on random chunks
219     *
220     * @return void
221     */
222    protected function createClusters()
223    {
224        if ($this->useLanguageClusters) {
225            $result = $this->db->queryAll('SELECT DISTINCT lang FROM embeddings');
226            $langs = array_column($result, 'lang');
227            foreach ($langs as $lang) {
228                $this->createLanguageClusters($lang);
229            }
230        } else {
231            $this->createLanguageClusters('');
232        }
233    }
234
235    /**
236     * Create new clusters based on random chunks for the given Language
237     *
238     * @param string $lang The language to cluster, empty when all languages go into the same cluster
239     * @noinspection SqlWithoutWhere
240     */
241    protected function createLanguageClusters($lang)
242    {
243        if ($lang != '') {
244            $where = 'WHERE lang = ' . $this->db->getPdo()->quote($lang);
245        } else {
246            $where = '';
247        }
248
249        if ($this->logger) $this->logger->info('Creating new {lang} clusters...', ['lang' => $lang]);
250        $this->db->getPdo()->beginTransaction();
251        try {
252            // clean up old cluster data
253            $query = "DELETE FROM clusters $where";
254            $this->db->exec($query);
255            $query = "UPDATE embeddings SET cluster = NULL $where";
256            $this->db->exec($query);
257
258            // get a random selection of chunks
259            $query = "SELECT id, embedding FROM embeddings $where ORDER BY RANDOM() LIMIT ?";
260            $result = $this->db->queryAll($query, [self::SAMPLE_SIZE]);
261            if (!$result) return; // no data to cluster
262            $dimensions = count(json_decode($result[0]['embedding'], true));
263
264            // how many clusters?
265            if (count($result) < self::CLUSTER_SIZE * 3) {
266                // there would be less than 3 clusters, so just use one
267                $clustercount = 1;
268            } else {
269                // get the number of all chunks, to calculate the number of clusters
270                $query = "SELECT COUNT(*) FROM embeddings $where";
271                $total = $this->db->queryValue($query);
272                $clustercount = ceil($total / self::CLUSTER_SIZE);
273            }
274            if ($this->logger) $this->logger->info('Creating {clusters} clusters', ['clusters' => $clustercount]);
275
276            // cluster them using kmeans
277            $space = new Space($dimensions);
278            foreach ($result as $record) {
279                $space->addPoint(json_decode($record['embedding'], true));
280            }
281            $clusters = $space->solve($clustercount, function ($space, $clusters) {
282                static $iterations = 0;
283                ++$iterations;
284                if ($this->logger) {
285                    $clustercounts = join(',', array_map('count', $clusters));
286                    $this->logger->info('Iteration {iteration}: [{clusters}]', [
287                        'iteration' => $iterations, 'clusters' => $clustercounts
288                    ]);
289                }
290            }, Cluster::INIT_KMEANS_PLUS_PLUS);
291
292            // store the clusters
293            foreach ($clusters as $clusterID => $cluster) {
294                /** @var Cluster $cluster */
295                $centroid = $cluster->getCoordinates();
296                $query = 'INSERT INTO clusters (lang, centroid) VALUES (?, ?)';
297                $this->db->exec($query, [$lang, json_encode($centroid)]);
298            }
299
300            $this->db->getPdo()->commit();
301            if ($this->logger) $this->logger->success('Created {clusters} clusters', ['clusters' => count($clusters)]);
302        } catch (\Exception $e) {
303            $this->db->getPdo()->rollBack();
304            throw new \RuntimeException('Clustering failed: ' . $e->getMessage(), 0, $e);
305        }
306    }
307
308    /**
309     * Assign the nearest cluster for all chunks that don't have one
310     *
311     * @return void
312     */
313    protected function setChunkClusters()
314    {
315        if ($this->logger) $this->logger->info('Assigning clusters to chunks...');
316        $query = 'SELECT id, embedding, lang FROM embeddings WHERE cluster IS NULL';
317        $handle = $this->db->query($query);
318
319        while ($record = $handle->fetch(\PDO::FETCH_ASSOC)) {
320            $vector = json_decode($record['embedding'], true);
321            $cluster = $this->getCluster($vector, $this->useLanguageClusters ? $record['lang'] : '');
322            $query = 'UPDATE embeddings SET cluster = ? WHERE id = ?';
323            $this->db->exec($query, [$cluster, $record['id']]);
324            if ($this->logger) $this->logger->success(
325                'Chunk {id} assigned to cluster {cluster}', ['id' => $record['id'], 'cluster' => $cluster]
326            );
327        }
328        $handle->closeCursor();
329    }
330
331    /**
332     * Get the nearest cluster for the given vector
333     *
334     * @param float[] $vector
335     * @return int|null
336     */
337    protected function getCluster($vector, $lang)
338    {
339        if ($lang != '') {
340            $where = 'WHERE lang = ' . $this->db->getPdo()->quote($lang);
341        } else {
342            $where = '';
343        }
344
345        $query = "SELECT cluster, centroid
346                    FROM clusters
347                   $where
348                ORDER BY COSIM(centroid, ?) DESC
349                   LIMIT 1";
350
351        $result = $this->db->queryRecord($query, [json_encode($vector)]);
352        if (!$result) return null;
353        return $result['cluster'];
354    }
355
356    /**
357     * Check if clustering has been done before
358     * @return bool
359     */
360    protected function hasClusters()
361    {
362        $query = 'SELECT COUNT(*) FROM clusters';
363        return $this->db->queryValue($query) > 0;
364    }
365
366    /**
367     * Writes TSV files for visualizing with http://projector.tensorflow.org/
368     *
369     * @param string $vectorfile path to the file with the vectors
370     * @param string $metafile path to the file with the metadata
371     * @return void
372     */
373    public function dumpTSV($vectorfile, $metafile)
374    {
375        $query = 'SELECT * FROM embeddings';
376        $handle = $this->db->query($query);
377
378        $header = implode("\t", ['id', 'page', 'created']);
379        file_put_contents($metafile, $header . "\n", FILE_APPEND);
380
381        while ($row = $handle->fetch(\PDO::FETCH_ASSOC)) {
382            $vector = json_decode($row['embedding'], true);
383            $vector = implode("\t", $vector);
384
385            $meta = implode("\t", [$row['id'], $row['page'], $row['created']]);
386
387            file_put_contents($vectorfile, $vector . "\n", FILE_APPEND);
388            file_put_contents($metafile, $meta . "\n", FILE_APPEND);
389        }
390    }
391}
392