xref: /plugin/aichat/Storage/SQLiteStorage.php (revision 3379af09b7ec10f96a8d4f23b1563bd7f9ae79ac)
1*3379af09SAndreas Gohr<?php /** @noinspection SqlResolve */
2f6ef2e50SAndreas Gohr
3f6ef2e50SAndreas Gohrnamespace dokuwiki\plugin\aichat\Storage;
4f6ef2e50SAndreas Gohr
5f6ef2e50SAndreas Gohruse dokuwiki\plugin\aichat\Chunk;
6f6ef2e50SAndreas Gohruse dokuwiki\plugin\sqlite\SQLiteDB;
7*3379af09SAndreas Gohruse KMeans\Cluster;
8*3379af09SAndreas Gohruse KMeans\Point;
9*3379af09SAndreas Gohruse KMeans\Space;
10f6ef2e50SAndreas Gohr
11f6ef2e50SAndreas Gohr/**
12f6ef2e50SAndreas Gohr * Implements the storage backend using a SQLite database
1335555bacSAndreas Gohr *
1435555bacSAndreas Gohr * Note: all embeddings are stored and returned as normalized vectors
15f6ef2e50SAndreas Gohr */
16f6ef2e50SAndreas Gohrclass SQLiteStorage extends AbstractStorage
17f6ef2e50SAndreas Gohr{
1881b450c8SAndreas Gohr    /** @var float minimum similarity to consider a chunk a match */
1981b450c8SAndreas Gohr    const SIMILARITY_THRESHOLD = 0.75;
2081b450c8SAndreas Gohr
21*3379af09SAndreas Gohr    /** @var int Number of documents to randomly sample to create the clusters */
22*3379af09SAndreas Gohr    const SAMPLE_SIZE = 2000;
23*3379af09SAndreas Gohr    /** @var int The average size of each cluster */
24*3379af09SAndreas Gohr    const CLUSTER_SIZE = 400;
25*3379af09SAndreas Gohr
26f6ef2e50SAndreas Gohr    /** @var SQLiteDB */
27f6ef2e50SAndreas Gohr    protected $db;
28f6ef2e50SAndreas Gohr
29f6ef2e50SAndreas Gohr    /**
30f6ef2e50SAndreas Gohr     * Initializes the database connection and registers our custom function
31f6ef2e50SAndreas Gohr     *
32f6ef2e50SAndreas Gohr     * @throws \Exception
33f6ef2e50SAndreas Gohr     */
34f6ef2e50SAndreas Gohr    public function __construct()
35f6ef2e50SAndreas Gohr    {
36f6ef2e50SAndreas Gohr        $this->db = new SQLiteDB('aichat', DOKU_PLUGIN . 'aichat/db/');
37f6ef2e50SAndreas Gohr        $this->db->getPdo()->sqliteCreateFunction('COSIM', [$this, 'sqliteCosineSimilarityCallback'], 2);
38f6ef2e50SAndreas Gohr    }
39f6ef2e50SAndreas Gohr
40f6ef2e50SAndreas Gohr    /** @inheritdoc */
41f6ef2e50SAndreas Gohr    public function getChunk($chunkID)
42f6ef2e50SAndreas Gohr    {
43f6ef2e50SAndreas Gohr        $record = $this->db->queryRecord('SELECT * FROM embeddings WHERE id = ?', [$chunkID]);
44f6ef2e50SAndreas Gohr        if (!$record) return null;
45f6ef2e50SAndreas Gohr
46f6ef2e50SAndreas Gohr        return new Chunk(
47f6ef2e50SAndreas Gohr            $record['page'],
48f6ef2e50SAndreas Gohr            $record['id'],
49f6ef2e50SAndreas Gohr            $record['chunk'],
50f6ef2e50SAndreas Gohr            json_decode($record['embedding'], true),
51f6ef2e50SAndreas Gohr            $record['created']
52f6ef2e50SAndreas Gohr        );
53f6ef2e50SAndreas Gohr    }
54f6ef2e50SAndreas Gohr
55f6ef2e50SAndreas Gohr    /** @inheritdoc */
56f6ef2e50SAndreas Gohr    public function startCreation($clear = false)
57f6ef2e50SAndreas Gohr    {
58f6ef2e50SAndreas Gohr        if ($clear) {
59f6ef2e50SAndreas Gohr            /** @noinspection SqlWithoutWhere */
60f6ef2e50SAndreas Gohr            $this->db->exec('DELETE FROM embeddings');
61f6ef2e50SAndreas Gohr        }
62f6ef2e50SAndreas Gohr    }
63f6ef2e50SAndreas Gohr
64f6ef2e50SAndreas Gohr    /** @inheritdoc */
65f6ef2e50SAndreas Gohr    public function reusePageChunks($page, $firstChunkID)
66f6ef2e50SAndreas Gohr    {
67f6ef2e50SAndreas Gohr        // no-op
68f6ef2e50SAndreas Gohr    }
69f6ef2e50SAndreas Gohr
70f6ef2e50SAndreas Gohr    /** @inheritdoc */
71f6ef2e50SAndreas Gohr    public function deletePageChunks($page, $firstChunkID)
72f6ef2e50SAndreas Gohr    {
73f6ef2e50SAndreas Gohr        $this->db->exec('DELETE FROM embeddings WHERE page = ?', [$page]);
74f6ef2e50SAndreas Gohr    }
75f6ef2e50SAndreas Gohr
76f6ef2e50SAndreas Gohr    /** @inheritdoc */
77f6ef2e50SAndreas Gohr    public function addPageChunks($chunks)
78f6ef2e50SAndreas Gohr    {
79f6ef2e50SAndreas Gohr        foreach ($chunks as $chunk) {
80f6ef2e50SAndreas Gohr            $this->db->saveRecord('embeddings', [
81f6ef2e50SAndreas Gohr                'page' => $chunk->getPage(),
82f6ef2e50SAndreas Gohr                'id' => $chunk->getId(),
83f6ef2e50SAndreas Gohr                'chunk' => $chunk->getText(),
84f6ef2e50SAndreas Gohr                'embedding' => json_encode($chunk->getEmbedding()),
85f6ef2e50SAndreas Gohr                'created' => $chunk->getCreated()
86f6ef2e50SAndreas Gohr            ]);
87f6ef2e50SAndreas Gohr        }
88f6ef2e50SAndreas Gohr    }
89f6ef2e50SAndreas Gohr
90f6ef2e50SAndreas Gohr    /** @inheritdoc */
91f6ef2e50SAndreas Gohr    public function finalizeCreation()
92f6ef2e50SAndreas Gohr    {
93*3379af09SAndreas Gohr        if (!$this->hasClusters()) {
94*3379af09SAndreas Gohr            $this->createClusters();
95*3379af09SAndreas Gohr        }
96*3379af09SAndreas Gohr        $this->setChunkClusters();
97*3379af09SAndreas Gohr
98f6ef2e50SAndreas Gohr        $this->db->exec('VACUUM');
99f6ef2e50SAndreas Gohr    }
100f6ef2e50SAndreas Gohr
101f6ef2e50SAndreas Gohr    /** @inheritdoc */
102*3379af09SAndreas Gohr    public function runMaintenance()
103*3379af09SAndreas Gohr    {
104*3379af09SAndreas Gohr        $this->createClusters();
105*3379af09SAndreas Gohr        $this->setChunkClusters();
106*3379af09SAndreas Gohr    }
107*3379af09SAndreas Gohr
108*3379af09SAndreas Gohr
109*3379af09SAndreas Gohr    /** @inheritdoc */
11001f06932SAndreas Gohr    public function getPageChunks($page, $firstChunkID)
11101f06932SAndreas Gohr    {
11201f06932SAndreas Gohr        $result = $this->db->queryAll(
11301f06932SAndreas Gohr            'SELECT * FROM embeddings WHERE page = ?',
11401f06932SAndreas Gohr            [$page]
11501f06932SAndreas Gohr        );
11601f06932SAndreas Gohr        $chunks = [];
11701f06932SAndreas Gohr        foreach ($result as $record) {
11801f06932SAndreas Gohr            $chunks[] = new Chunk(
11901f06932SAndreas Gohr                $record['page'],
12001f06932SAndreas Gohr                $record['id'],
12101f06932SAndreas Gohr                $record['chunk'],
12201f06932SAndreas Gohr                json_decode($record['embedding'], true),
12301f06932SAndreas Gohr                $record['created']
12401f06932SAndreas Gohr            );
12501f06932SAndreas Gohr        }
12601f06932SAndreas Gohr        return $chunks;
12701f06932SAndreas Gohr    }
12801f06932SAndreas Gohr
12901f06932SAndreas Gohr
13001f06932SAndreas Gohr    /** @inheritdoc */
131f6ef2e50SAndreas Gohr    public function getSimilarChunks($vector, $limit = 4)
132f6ef2e50SAndreas Gohr    {
133*3379af09SAndreas Gohr        $cluster = $this->getCluster($vector);
134*3379af09SAndreas Gohr        if ($this->logger) $this->logger->info('Using cluster {cluster} for similarity search', ['cluster' => $cluster]);
135*3379af09SAndreas Gohr
136f6ef2e50SAndreas Gohr        $result = $this->db->queryAll(
137f6ef2e50SAndreas Gohr            'SELECT *, COSIM(?, embedding) AS similarity
138f6ef2e50SAndreas Gohr               FROM embeddings
139*3379af09SAndreas Gohr              WHERE cluster = ?
140*3379af09SAndreas Gohr                AND GETACCESSLEVEL(page) > 0
14181b450c8SAndreas Gohr                AND similarity > CAST(? AS FLOAT)
142f6ef2e50SAndreas Gohr           ORDER BY similarity DESC
143f6ef2e50SAndreas Gohr              LIMIT ?',
144*3379af09SAndreas Gohr            [json_encode($vector), $cluster, self::SIMILARITY_THRESHOLD, $limit]
145f6ef2e50SAndreas Gohr        );
146f6ef2e50SAndreas Gohr        $chunks = [];
147f6ef2e50SAndreas Gohr        foreach ($result as $record) {
148f6ef2e50SAndreas Gohr            $chunks[] = new Chunk(
149f6ef2e50SAndreas Gohr                $record['page'],
150f6ef2e50SAndreas Gohr                $record['id'],
151f6ef2e50SAndreas Gohr                $record['chunk'],
152f6ef2e50SAndreas Gohr                json_decode($record['embedding'], true),
1539b3d1b36SAndreas Gohr                $record['created'],
1549b3d1b36SAndreas Gohr                $record['similarity']
155f6ef2e50SAndreas Gohr            );
156f6ef2e50SAndreas Gohr        }
157f6ef2e50SAndreas Gohr        return $chunks;
158f6ef2e50SAndreas Gohr    }
159f6ef2e50SAndreas Gohr
160f6ef2e50SAndreas Gohr    /** @inheritdoc */
161f6ef2e50SAndreas Gohr    public function statistics()
162f6ef2e50SAndreas Gohr    {
163f6ef2e50SAndreas Gohr        $items = $this->db->queryValue('SELECT COUNT(*) FROM embeddings');
164f6ef2e50SAndreas Gohr        $size = $this->db->queryValue(
165f6ef2e50SAndreas Gohr            'SELECT page_count * page_size as size FROM pragma_page_count(), pragma_page_size()'
166f6ef2e50SAndreas Gohr        );
167*3379af09SAndreas Gohr        $query = "SELECT cluster, COUNT(*) || ' chunks' as cnt FROM embeddings GROUP BY cluster ORDER BY cluster";
168*3379af09SAndreas Gohr        $clusters = $this->db->queryKeyValueList($query);
169*3379af09SAndreas Gohr
170f6ef2e50SAndreas Gohr        return [
171f6ef2e50SAndreas Gohr            'storage type' => 'SQLite',
172f6ef2e50SAndreas Gohr            'chunks' => $items,
173*3379af09SAndreas Gohr            'db size' => filesize_h($size),
174*3379af09SAndreas Gohr            'clusters' => $clusters,
175f6ef2e50SAndreas Gohr        ];
176f6ef2e50SAndreas Gohr    }
177f6ef2e50SAndreas Gohr
178f6ef2e50SAndreas Gohr    /**
179f6ef2e50SAndreas Gohr     * Method registered as SQLite callback to calculate the cosine similarity
180f6ef2e50SAndreas Gohr     *
181f6ef2e50SAndreas Gohr     * @param string $query JSON encoded vector array
182f6ef2e50SAndreas Gohr     * @param string $embedding JSON encoded vector array
183f6ef2e50SAndreas Gohr     * @return float
184f6ef2e50SAndreas Gohr     */
185f6ef2e50SAndreas Gohr    public function sqliteCosineSimilarityCallback($query, $embedding)
186f6ef2e50SAndreas Gohr    {
187f6ef2e50SAndreas Gohr        return (float)$this->cosineSimilarity(json_decode($query), json_decode($embedding));
188f6ef2e50SAndreas Gohr    }
189f6ef2e50SAndreas Gohr
190f6ef2e50SAndreas Gohr    /**
191f6ef2e50SAndreas Gohr     * Calculate the cosine similarity between two vectors
192f6ef2e50SAndreas Gohr     *
19335555bacSAndreas Gohr     * Actually just calculating the dot product of the two vectors, since they are normalized
19435555bacSAndreas Gohr     *
19535555bacSAndreas Gohr     * @param float[] $queryVector The normalized vector of the search phrase
19635555bacSAndreas Gohr     * @param float[] $embedding The normalized vector of the chunk
197f6ef2e50SAndreas Gohr     * @return float
198f6ef2e50SAndreas Gohr     */
199f6ef2e50SAndreas Gohr    protected function cosineSimilarity($queryVector, $embedding)
200f6ef2e50SAndreas Gohr    {
201f6ef2e50SAndreas Gohr        $dotProduct = 0;
202f6ef2e50SAndreas Gohr        foreach ($queryVector as $key => $value) {
203f6ef2e50SAndreas Gohr            $dotProduct += $value * $embedding[$key];
204f6ef2e50SAndreas Gohr        }
20535555bacSAndreas Gohr        return $dotProduct;
206f6ef2e50SAndreas Gohr    }
207*3379af09SAndreas Gohr
208*3379af09SAndreas Gohr    /**
209*3379af09SAndreas Gohr     * Create new clusters based on random chunks
210*3379af09SAndreas Gohr     *
211*3379af09SAndreas Gohr     * @noinspection SqlWithoutWhere
212*3379af09SAndreas Gohr     */
213*3379af09SAndreas Gohr    protected function createClusters()
214*3379af09SAndreas Gohr    {
215*3379af09SAndreas Gohr        if ($this->logger) $this->logger->info('Creating new clusters...');
216*3379af09SAndreas Gohr        $this->db->getPdo()->beginTransaction();
217*3379af09SAndreas Gohr        try {
218*3379af09SAndreas Gohr            // clean up old cluster data
219*3379af09SAndreas Gohr            $query = 'DELETE FROM clusters';
220*3379af09SAndreas Gohr            $this->db->exec($query);
221*3379af09SAndreas Gohr            $query = 'UPDATE embeddings SET cluster = NULL';
222*3379af09SAndreas Gohr            $this->db->exec($query);
223*3379af09SAndreas Gohr
224*3379af09SAndreas Gohr            // get a random selection of chunks
225*3379af09SAndreas Gohr            $query = 'SELECT id, embedding FROM embeddings ORDER BY RANDOM() LIMIT ?';
226*3379af09SAndreas Gohr            $result = $this->db->queryAll($query, [self::SAMPLE_SIZE]);
227*3379af09SAndreas Gohr            if (!$result) return; // no data to cluster
228*3379af09SAndreas Gohr            $dimensions = count(json_decode($result[0]['embedding'], true));
229*3379af09SAndreas Gohr
230*3379af09SAndreas Gohr            // get the number of all chunks, to calculate the number of clusters
231*3379af09SAndreas Gohr            $query = 'SELECT COUNT(*) FROM embeddings';
232*3379af09SAndreas Gohr            $total = $this->db->queryValue($query);
233*3379af09SAndreas Gohr            $clustercount = ceil($total / self::CLUSTER_SIZE);
234*3379af09SAndreas Gohr            if ($this->logger) $this->logger->info('Creating {clusters} clusters', ['clusters' => $clustercount]);
235*3379af09SAndreas Gohr
236*3379af09SAndreas Gohr            // cluster them using kmeans
237*3379af09SAndreas Gohr            $space = new Space($dimensions);
238*3379af09SAndreas Gohr            foreach ($result as $record) {
239*3379af09SAndreas Gohr                $space->addPoint(json_decode($record['embedding'], true));
240*3379af09SAndreas Gohr            }
241*3379af09SAndreas Gohr            $clusters = $space->solve($clustercount, function ($space, $clusters) {
242*3379af09SAndreas Gohr                static $iterations = 0;
243*3379af09SAndreas Gohr                ++$iterations;
244*3379af09SAndreas Gohr                if ($this->logger) {
245*3379af09SAndreas Gohr                    $clustercounts = join(',', array_map('count', $clusters));
246*3379af09SAndreas Gohr                    $this->logger->info('Iteration {iteration}: [{clusters}]', [
247*3379af09SAndreas Gohr                        'iteration' => $iterations, 'clusters' => $clustercounts
248*3379af09SAndreas Gohr                    ]);
249*3379af09SAndreas Gohr                }
250*3379af09SAndreas Gohr            }, Cluster::INIT_KMEANS_PLUS_PLUS);
251*3379af09SAndreas Gohr
252*3379af09SAndreas Gohr            // store the clusters
253*3379af09SAndreas Gohr            foreach ($clusters as $clusterID => $cluster) {
254*3379af09SAndreas Gohr                /** @var Cluster $cluster */
255*3379af09SAndreas Gohr                $centroid = $cluster->getCoordinates();
256*3379af09SAndreas Gohr                $query = 'INSERT INTO clusters (cluster, centroid) VALUES (?, ?)';
257*3379af09SAndreas Gohr                $this->db->exec($query, [$clusterID, json_encode($centroid)]);
258*3379af09SAndreas Gohr            }
259*3379af09SAndreas Gohr
260*3379af09SAndreas Gohr            $this->db->getPdo()->commit();
261*3379af09SAndreas Gohr            if ($this->logger) $this->logger->success('Created {clusters} clusters', ['clusters' => count($clusters)]);
262*3379af09SAndreas Gohr        } catch (\Exception $e) {
263*3379af09SAndreas Gohr            $this->db->getPdo()->rollBack();
264*3379af09SAndreas Gohr            throw new \RuntimeException('Clustering failed', 0, $e);
265*3379af09SAndreas Gohr        }
266*3379af09SAndreas Gohr    }
267*3379af09SAndreas Gohr
268*3379af09SAndreas Gohr    /**
269*3379af09SAndreas Gohr     * Assign the nearest cluster for all chunks that don't have one
270*3379af09SAndreas Gohr     *
271*3379af09SAndreas Gohr     * @return void
272*3379af09SAndreas Gohr     */
273*3379af09SAndreas Gohr    protected function setChunkClusters()
274*3379af09SAndreas Gohr    {
275*3379af09SAndreas Gohr        if ($this->logger) $this->logger->info('Assigning clusters to chunks...');
276*3379af09SAndreas Gohr        $query = 'SELECT id, embedding FROM embeddings WHERE cluster IS NULL';
277*3379af09SAndreas Gohr        $handle = $this->db->query($query);
278*3379af09SAndreas Gohr
279*3379af09SAndreas Gohr        while ($record = $handle->fetch(\PDO::FETCH_ASSOC)) {
280*3379af09SAndreas Gohr            $vector = json_decode($record['embedding'], true);
281*3379af09SAndreas Gohr            $cluster = $this->getCluster($vector);
282*3379af09SAndreas Gohr            $query = 'UPDATE embeddings SET cluster = ? WHERE id = ?';
283*3379af09SAndreas Gohr            $this->db->exec($query, [$cluster, $record['id']]);
284*3379af09SAndreas Gohr            if ($this->logger) $this->logger->success(
285*3379af09SAndreas Gohr                'Chunk {id} assigned to cluster {cluster}', ['id' => $record['id'], 'cluster' => $cluster]
286*3379af09SAndreas Gohr            );
287*3379af09SAndreas Gohr        }
288*3379af09SAndreas Gohr        $handle->closeCursor();
289*3379af09SAndreas Gohr    }
290*3379af09SAndreas Gohr
291*3379af09SAndreas Gohr    /**
292*3379af09SAndreas Gohr     * Get the nearest cluster for the given vector
293*3379af09SAndreas Gohr     *
294*3379af09SAndreas Gohr     * @param float[] $vector
295*3379af09SAndreas Gohr     * @return int|null
296*3379af09SAndreas Gohr     */
297*3379af09SAndreas Gohr    protected function getCluster($vector)
298*3379af09SAndreas Gohr    {
299*3379af09SAndreas Gohr        $query = 'SELECT cluster, centroid FROM clusters ORDER BY COSIM(centroid, ?) DESC LIMIT 1';
300*3379af09SAndreas Gohr        $result = $this->db->queryRecord($query, [json_encode($vector)]);
301*3379af09SAndreas Gohr        if (!$result) return null;
302*3379af09SAndreas Gohr        return $result['cluster'];
303*3379af09SAndreas Gohr    }
304*3379af09SAndreas Gohr
305*3379af09SAndreas Gohr    /**
306*3379af09SAndreas Gohr     * Check if clustering has been done before
307*3379af09SAndreas Gohr     * @return bool
308*3379af09SAndreas Gohr     */
309*3379af09SAndreas Gohr    protected function hasClusters()
310*3379af09SAndreas Gohr    {
311*3379af09SAndreas Gohr        $query = 'SELECT COUNT(*) FROM clusters';
312*3379af09SAndreas Gohr        return $this->db->queryValue($query) > 0;
313*3379af09SAndreas Gohr    }
314f6ef2e50SAndreas Gohr}
315