xref: /plugin/aichat/Storage/SQLiteStorage.php (revision 8285fff93bda4602771c807cbe9218d1f3b88d32)
1<?php /** @noinspection SqlResolve */
2
3namespace dokuwiki\plugin\aichat\Storage;
4
5use dokuwiki\plugin\aichat\Chunk;
6use dokuwiki\plugin\sqlite\SQLiteDB;
7use KMeans\Cluster;
8use KMeans\Point;
9use KMeans\Space;
10
/**
 * Implements the storage backend using a SQLite database
 *
 * Note: all embeddings are stored and returned as normalized vectors
 *
 * Chunks are partitioned into k-means clusters. A similarity search only compares the
 * query against the chunks of the cluster whose centroid is nearest to the query vector.
 */
class SQLiteStorage extends AbstractStorage
{
    /** @var float minimum similarity to consider a chunk a match */
    public const SIMILARITY_THRESHOLD = 0.75;

    /** @var int Number of documents to randomly sample to create the clusters */
    public const SAMPLE_SIZE = 2000;
    /** @var int The average size of each cluster */
    public const CLUSTER_SIZE = 400;

    /** @var SQLiteDB */
    protected $db;

    /**
     * Initializes the database connection and registers our custom function
     *
     * After construction, COSIM(a, b) is available in SQL and returns the similarity
     * of two JSON encoded vectors (see sqliteCosineSimilarityCallback()).
     *
     * @throws \Exception
     */
    public function __construct()
    {
        $this->db = new SQLiteDB('aichat', DOKU_PLUGIN . 'aichat/db/');
        $this->db->getPdo()->sqliteCreateFunction('COSIM', [$this, 'sqliteCosineSimilarityCallback'], 2);
    }

    /** @inheritdoc */
    public function getChunk($chunkID)
    {
        $record = $this->db->queryRecord('SELECT * FROM embeddings WHERE id = ?', [$chunkID]);
        if (!$record) return null;

        return $this->createChunkFromRecord($record);
    }

    /**
     * @inheritdoc
     *
     * With $clear set, all stored chunks are deleted. The clusters table is not touched.
     */
    public function startCreation($clear = false)
    {
        if ($clear) {
            /** @noinspection SqlWithoutWhere */
            $this->db->exec('DELETE FROM embeddings');
        }
    }

    /**
     * @inheritdoc
     *
     * No-op for this backend: existing rows are simply left in place.
     */
    public function reusePageChunks($page, $firstChunkID)
    {
        // no-op
    }

    /**
     * @inheritdoc
     *
     * Rows are matched by page name only; $firstChunkID is not used by this backend.
     */
    public function deletePageChunks($page, $firstChunkID)
    {
        $this->db->exec('DELETE FROM embeddings WHERE page = ?', [$page]);
    }

    /**
     * @inheritdoc
     *
     * Embeddings are stored JSON encoded. The cluster column is not set here; clusters are
     * assigned by finalizeCreation() or runMaintenance().
     */
    public function addPageChunks($chunks)
    {
        foreach ($chunks as $chunk) {
            $this->db->saveRecord('embeddings', [
                'page' => $chunk->getPage(),
                'id' => $chunk->getId(),
                'chunk' => $chunk->getText(),
                'embedding' => json_encode($chunk->getEmbedding()),
                'created' => $chunk->getCreated()
            ]);
        }
    }

    /**
     * @inheritdoc
     *
     * Creates clusters if none exist yet, assigns every unclustered chunk to its nearest
     * cluster and finally compacts the database file.
     */
    public function finalizeCreation()
    {
        if (!$this->hasClusters()) {
            $this->createClusters();
        }
        $this->setChunkClusters();

        $this->db->exec('VACUUM');
    }

    /**
     * @inheritdoc
     *
     * Always rebuilds the clusters from a fresh random sample and reassigns all chunks.
     */
    public function runMaintenance()
    {
        $this->createClusters();
        $this->setChunkClusters();
    }

    /** @inheritdoc */
    public function getPageChunks($page, $firstChunkID)
    {
        $result = $this->db->queryAll(
            'SELECT * FROM embeddings WHERE page = ?',
            [$page]
        );
        $chunks = [];
        foreach ($result as $record) {
            $chunks[] = $this->createChunkFromRecord($record);
        }
        return $chunks;
    }

    /**
     * @inheritdoc
     *
     * Only the chunks of the single nearest cluster are searched. When no clusters exist,
     * getCluster() returns null and `cluster IS ?` matches the unclustered chunks instead,
     * so searching still works before (or after a failed) clustering run.
     */
    public function getSimilarChunks($vector, $limit = 4)
    {
        $cluster = $this->getCluster($vector);
        if ($this->logger) $this->logger->info(
            'Using cluster {cluster} for similarity search', ['cluster' => $cluster]
        );

        // SQLite permits referencing the "similarity" alias inside the WHERE clause
        $result = $this->db->queryAll(
            'SELECT *, COSIM(?, embedding) AS similarity
               FROM embeddings
              WHERE cluster IS ?
                AND GETACCESSLEVEL(page) > 0
                AND similarity > CAST(? AS FLOAT)
           ORDER BY similarity DESC
              LIMIT ?',
            [json_encode($vector), $cluster, self::SIMILARITY_THRESHOLD, $limit]
        );
        $chunks = [];
        foreach ($result as $record) {
            $chunks[] = $this->createChunkFromRecord($record, true);
        }
        return $chunks;
    }

    /** @inheritdoc */
    public function statistics()
    {
        $items = $this->db->queryValue('SELECT COUNT(*) FROM embeddings');
        $size = $this->db->queryValue(
            'SELECT page_count * page_size as size FROM pragma_page_count(), pragma_page_size()'
        );
        $query = "SELECT cluster, COUNT(*) || ' chunks' as cnt FROM embeddings GROUP BY cluster ORDER BY cluster";
        $clusters = $this->db->queryKeyValueList($query);

        return [
            'storage type' => 'SQLite',
            'chunks' => $items,
            'db size' => filesize_h($size),
            'clusters' => $clusters,
        ];
    }

    /**
     * Create a Chunk object from a row of the embeddings table
     *
     * @param array $record row containing page, id, chunk, embedding (JSON encoded) and created
     * @param bool $withSimilarity also pass the computed "similarity" column to the Chunk
     * @return Chunk
     */
    private function createChunkFromRecord($record, $withSimilarity = false)
    {
        $args = [
            $record['page'],
            $record['id'],
            $record['chunk'],
            json_decode($record['embedding'], true),
            $record['created'],
        ];
        if ($withSimilarity) {
            $args[] = $record['similarity'];
        }
        return new Chunk(...$args);
    }

    /**
     * Method registered as SQLite callback to calculate the cosine similarity
     *
     * @param string $query JSON encoded vector array
     * @param string $embedding JSON encoded vector array
     * @return float
     */
    public function sqliteCosineSimilarityCallback($query, $embedding)
    {
        return (float)$this->cosineSimilarity(json_decode($query, true), json_decode($embedding, true));
    }

    /**
     * Calculate the cosine similarity between two vectors
     *
     * Actually just calculating the dot product of the two vectors, since they are normalized
     *
     * @param float[] $queryVector The normalized vector of the search phrase
     * @param float[] $embedding The normalized vector of the chunk
     * @return float
     */
    protected function cosineSimilarity($queryVector, $embedding)
    {
        $dotProduct = 0;
        foreach ($queryVector as $key => $value) {
            $dotProduct += $value * $embedding[$key];
        }
        return $dotProduct;
    }

    /**
     * Create new clusters based on random chunks
     *
     * Removes all clusters and cluster assignments, runs k-means (k-means++ initialization)
     * on a random sample of up to SAMPLE_SIZE embeddings, and stores the resulting centroids.
     * The number of clusters targets about CLUSTER_SIZE chunks per cluster. It is never more
     * than the number of sampled points. All changes happen in one transaction: it is
     * committed on success and rolled back on any failure.
     *
     * @throws \RuntimeException when clustering fails with an exception
     * @noinspection SqlWithoutWhere
     */
    protected function createClusters()
    {
        if ($this->logger) $this->logger->info('Creating new clusters...');
        $pdo = $this->db->getPdo();
        $pdo->beginTransaction();
        try {
            // clean up old cluster data
            $query = 'DELETE FROM clusters';
            $this->db->exec($query);
            $query = 'UPDATE embeddings SET cluster = NULL';
            $this->db->exec($query);

            // get a random selection of chunks
            $query = 'SELECT id, embedding FROM embeddings ORDER BY RANDOM() LIMIT ?';
            $result = $this->db->queryAll($query, [self::SAMPLE_SIZE]);
            if (!$result) {
                // no data to cluster - still finish the transaction, an open one would
                // make later statements like VACUUM fail
                $pdo->commit();
                return;
            }
            $dimensions = count(json_decode($result[0]['embedding'], true));

            // get the number of all chunks, to calculate the number of clusters
            $query = 'SELECT COUNT(*) FROM embeddings';
            $total = $this->db->queryValue($query);
            $clustercount = (int)ceil($total / self::CLUSTER_SIZE);
            // k-means can not produce more clusters than there are sample points
            $clustercount = max(1, min($clustercount, count($result)));
            if ($this->logger) $this->logger->info('Creating {clusters} clusters', ['clusters' => $clustercount]);

            // cluster them using kmeans
            $space = new Space($dimensions);
            foreach ($result as $record) {
                $space->addPoint(json_decode($record['embedding'], true));
            }
            $iterations = 0;
            $clusters = $space->solve($clustercount, function ($space, $clusters) use (&$iterations) {
                ++$iterations;
                if ($this->logger) {
                    $clustercounts = implode(',', array_map('count', $clusters));
                    $this->logger->info('Iteration {iteration}: [{clusters}]', [
                        'iteration' => $iterations, 'clusters' => $clustercounts
                    ]);
                }
            }, Cluster::INIT_KMEANS_PLUS_PLUS);

            // store the clusters
            foreach ($clusters as $clusterID => $cluster) {
                /** @var Cluster $cluster */
                $centroid = $cluster->getCoordinates();
                $query = 'INSERT INTO clusters (cluster, centroid) VALUES (?, ?)';
                $this->db->exec($query, [$clusterID, json_encode($centroid)]);
            }

            $pdo->commit();
        } catch (\Exception $e) {
            if ($pdo->inTransaction()) $pdo->rollBack();
            throw new \RuntimeException('Clustering failed', 0, $e);
        } catch (\Throwable $e) {
            // engine errors (e.g. TypeError) are passed on unchanged,
            // but must not leave the transaction open
            if ($pdo->inTransaction()) $pdo->rollBack();
            throw $e;
        }

        if ($this->logger) $this->logger->success('Created {clusters} clusters', ['clusters' => count($clusters)]);
    }

    /**
     * Assign the nearest cluster for all chunks that don't have one
     *
     * Rows are fetched one at a time from a cursor instead of loaded all at once. All
     * updates run in a single transaction, so SQLite commits once instead of once per
     * chunk. On failure the transaction is rolled back and the chunks stay unclustered.
     *
     * @return void
     */
    protected function setChunkClusters()
    {
        if ($this->logger) $this->logger->info('Assigning clusters to chunks...');
        $pdo = $this->db->getPdo();
        $pdo->beginTransaction();
        $handle = null;
        try {
            $query = 'SELECT id, embedding FROM embeddings WHERE cluster IS NULL';
            $handle = $this->db->query($query);

            while ($record = $handle->fetch(\PDO::FETCH_ASSOC)) {
                $vector = json_decode($record['embedding'], true);
                $cluster = $this->getCluster($vector);
                $query = 'UPDATE embeddings SET cluster = ? WHERE id = ?';
                $this->db->exec($query, [$cluster, $record['id']]);
                if ($this->logger) $this->logger->success(
                    'Chunk {id} assigned to cluster {cluster}', ['id' => $record['id'], 'cluster' => $cluster]
                );
            }
            $handle->closeCursor();
            $pdo->commit();
        } catch (\Throwable $e) {
            if ($handle) $handle->closeCursor();
            if ($pdo->inTransaction()) $pdo->rollBack();
            throw $e;
        }
    }

    /**
     * Get the nearest cluster for the given vector
     *
     * Centroids are ranked with COSIM, i.e. the dot product against the given vector.
     *
     * @param float[] $vector
     * @return int|null the cluster ID, null when no clusters exist
     */
    protected function getCluster($vector)
    {
        $query = 'SELECT cluster, centroid FROM clusters ORDER BY COSIM(centroid, ?) DESC LIMIT 1';
        $result = $this->db->queryRecord($query, [json_encode($vector)]);
        if (!$result) return null;
        return $result['cluster'];
    }

    /**
     * Check if clustering has been done before
     *
     * @return bool true when at least one cluster is stored
     */
    protected function hasClusters()
    {
        $query = 'SELECT COUNT(*) FROM clusters';
        return $this->db->queryValue($query) > 0;
    }
}
314}
315