<?php /** @noinspection SqlResolve */

namespace dokuwiki\plugin\aichat\Storage;

use dokuwiki\plugin\aichat\Chunk;
use dokuwiki\plugin\sqlite\SQLiteDB;
use KMeans\Cluster;
use KMeans\Point;
use KMeans\Space;

/**
 * Implements the storage backend using a SQLite database
 *
 * Note: all embeddings are stored and returned as normalized vectors
 *
 * Chunks are grouped into k-means clusters. A similarity search only compares the query
 * against the chunks of the cluster whose centroid is nearest to the query vector.
 */
class SQLiteStorage extends AbstractStorage
{
    /** @var float minimum similarity to consider a chunk a match */
    const SIMILARITY_THRESHOLD = 0.75;

    /** @var int Number of documents to randomly sample to create the clusters */
    const SAMPLE_SIZE = 2000;
    /** @var int The average size of each cluster */
    const CLUSTER_SIZE = 400;

    /** @var SQLiteDB */
    protected $db;

    /**
     * Initializes the database connection and registers our custom function
     *
     * After construction the SQL function COSIM(a, b) is available on this connection.
     * It receives two JSON encoded vectors and returns their similarity.
     *
     * @throws \Exception
     */
    public function __construct()
    {
        $this->db = new SQLiteDB('aichat', DOKU_PLUGIN . 'aichat/db/');
        $this->db->getPdo()->sqliteCreateFunction('COSIM', [$this, 'sqliteCosineSimilarityCallback'], 2);
    }

    /** @inheritdoc */
    public function getChunk($chunkID)
    {
        $record = $this->db->queryRecord('SELECT * FROM embeddings WHERE id = ?', [$chunkID]);
        if (!$record) return null;

        return $this->recordToChunk($record);
    }

    /** @inheritdoc */
    public function startCreation($clear = false)
    {
        if ($clear) {
            /** @noinspection SqlWithoutWhere */
            $this->db->exec('DELETE FROM embeddings');
        }
    }

    /**
     * @inheritdoc
     *
     * This backend does not need to do anything to keep a page's existing chunks.
     */
    public function reusePageChunks($page, $firstChunkID)
    {
        // no-op
    }

    /**
     * @inheritdoc
     *
     * Chunks are removed by page name; $firstChunkID is not used by this backend.
     */
    public function deletePageChunks($page, $firstChunkID)
    {
        $this->db->exec('DELETE FROM embeddings WHERE page = ?', [$page]);
    }

    /**
     * @inheritdoc
     *
     * Embeddings are stored as JSON encoded arrays. Clusters are assigned later
     * by setChunkClusters().
     */
    public function addPageChunks($chunks)
    {
        foreach ($chunks as $chunk) {
            $this->db->saveRecord('embeddings', [
                'page' => $chunk->getPage(),
                'id' => $chunk->getId(),
                'chunk' => $chunk->getText(),
                'embedding' => json_encode($chunk->getEmbedding()),
                'created' => $chunk->getCreated(),
            ]);
        }
    }

    /**
     * @inheritdoc
     *
     * Creates clusters if none exist yet, assigns all unclustered chunks and compacts the database.
     */
    public function finalizeCreation()
    {
        if (!$this->hasClusters()) {
            $this->createClusters();
        }
        $this->setChunkClusters();

        // VACUUM cannot run inside an open transaction
        $this->db->exec('VACUUM');
    }

    /**
     * @inheritdoc
     *
     * Always rebuilds the clusters from scratch and reassigns every chunk.
     */
    public function runMaintenance()
    {
        $this->createClusters();
        $this->setChunkClusters();
    }

    /**
     * @inheritdoc
     *
     * Chunks are looked up by page name; $firstChunkID is not used by this backend.
     */
    public function getPageChunks($page, $firstChunkID)
    {
        $result = $this->db->queryAll(
            'SELECT * FROM embeddings WHERE page = ?',
            [$page]
        );

        $chunks = [];
        foreach ($result as $record) {
            $chunks[] = $this->recordToChunk($record);
        }
        return $chunks;
    }

    /**
     * @inheritdoc
     *
     * Only chunks in the cluster nearest to $vector are compared. Matches must have a
     * similarity above SIMILARITY_THRESHOLD and a page access level above zero.
     * Returns an empty list when no clusters exist.
     */
    public function getSimilarChunks($vector, $limit = 4)
    {
        $cluster = $this->getCluster($vector);
        // strict check: cluster IDs start at 0, which is falsy
        if ($cluster === null) {
            // "cluster = NULL" never matches in SQL, so the query below would return nothing anyway
            if ($this->logger) $this->logger->info('No clusters available for similarity search');
            return [];
        }
        if ($this->logger) $this->logger->info('Using cluster {cluster} for similarity search', ['cluster' => $cluster]);

        // SQLite allows the result alias "similarity" to be referenced in WHERE
        $result = $this->db->queryAll(
            'SELECT *, COSIM(?, embedding) AS similarity
               FROM embeddings
              WHERE cluster = ?
                AND GETACCESSLEVEL(page) > 0
                AND similarity > CAST(? AS FLOAT)
           ORDER BY similarity DESC
              LIMIT ?',
            [json_encode($vector), $cluster, self::SIMILARITY_THRESHOLD, $limit]
        );

        $chunks = [];
        foreach ($result as $record) {
            $chunks[] = $this->recordToChunk($record, true);
        }
        return $chunks;
    }

    /** @inheritdoc */
    public function statistics()
    {
        $items = $this->db->queryValue('SELECT COUNT(*) FROM embeddings');
        $size = $this->db->queryValue(
            'SELECT page_count * page_size as size FROM pragma_page_count(), pragma_page_size()'
        );
        $query = "SELECT cluster, COUNT(*) || ' chunks' as cnt FROM embeddings GROUP BY cluster ORDER BY cluster";
        $clusters = $this->db->queryKeyValueList($query);

        return [
            'storage type' => 'SQLite',
            'chunks' => $items,
            'db size' => filesize_h($size),
            'clusters' => $clusters,
        ];
    }

    /**
     * Method registered as SQLite callback to calculate the cosine similarity
     *
     * @param string $query JSON encoded vector array
     * @param string $embedding JSON encoded vector array
     * @return float 0.0 if either argument does not decode to an array
     */
    public function sqliteCosineSimilarityCallback($query, $embedding)
    {
        $queryVector = json_decode((string)$query, true);
        $chunkVector = json_decode((string)$embedding, true);

        // NULL columns or broken JSON must not break the whole query
        if (!is_array($queryVector) || !is_array($chunkVector)) return 0.0;

        return (float)$this->cosineSimilarity($queryVector, $chunkVector);
    }

    /**
     * Calculate the cosine similarity between two vectors
     *
     * Actually just calculating the dot product of the two vectors, since they are normalized.
     * Components missing from $embedding are treated as 0.
     *
     * @param float[] $queryVector The normalized vector of the search phrase
     * @param float[] $embedding The normalized vector of the chunk
     * @return float
     */
    protected function cosineSimilarity($queryVector, $embedding)
    {
        $dotProduct = 0.0;
        foreach ($queryVector as $key => $value) {
            $dotProduct += $value * ($embedding[$key] ?? 0.0);
        }
        return $dotProduct;
    }

    /**
     * Create new clusters based on random chunks
     *
     * First deletes all existing clusters and unsets the cluster of every chunk. It then runs
     * k-means (k-means++ init) on a random sample of up to SAMPLE_SIZE embeddings, aiming for
     * about CLUSTER_SIZE chunks per cluster. The resulting centroids are stored in the clusters
     * table. Everything runs in a single transaction.
     *
     * @throws \RuntimeException when clustering fails; the transaction is rolled back
     * @noinspection SqlWithoutWhere
     */
    protected function createClusters()
    {
        if ($this->logger) $this->logger->info('Creating new clusters...');

        $pdo = $this->db->getPdo();
        $pdo->beginTransaction();
        try {
            // clean up old cluster data
            $this->db->exec('DELETE FROM clusters');
            $this->db->exec('UPDATE embeddings SET cluster = NULL');

            // get a random selection of chunks
            $query = 'SELECT id, embedding FROM embeddings ORDER BY RANDOM() LIMIT ?';
            $result = $this->db->queryAll($query, [self::SAMPLE_SIZE]);
            if (!$result) {
                // no data to cluster - persist the cleanup instead of leaving the transaction open
                $pdo->commit();
                if ($this->logger) $this->logger->info('No chunks available, no clusters created');
                return;
            }
            $dimensions = count(json_decode($result[0]['embedding'], true));

            // get the number of all chunks, to calculate the number of clusters
            $total = $this->db->queryValue('SELECT COUNT(*) FROM embeddings');
            // k-means cannot produce more clusters than it has sample points
            $clustercount = (int)min(ceil($total / self::CLUSTER_SIZE), count($result));
            if ($this->logger) $this->logger->info('Creating {clusters} clusters', ['clusters' => $clustercount]);

            // cluster them using kmeans
            $space = new Space($dimensions);
            foreach ($result as $record) {
                $space->addPoint(json_decode($record['embedding'], true));
            }

            // a local counter (instead of a static in the closure) restarts at 0 on every run
            $iterations = 0;
            $clusters = $space->solve($clustercount, function ($space, $clusters) use (&$iterations) {
                ++$iterations;
                if ($this->logger) {
                    $clustercounts = join(',', array_map('count', $clusters));
                    $this->logger->info('Iteration {iteration}: [{clusters}]', [
                        'iteration' => $iterations, 'clusters' => $clustercounts
                    ]);
                }
            }, Cluster::INIT_KMEANS_PLUS_PLUS);

            // store the clusters
            $query = 'INSERT INTO clusters (cluster, centroid) VALUES (?, ?)';
            foreach ($clusters as $clusterID => $cluster) {
                /** @var Cluster $cluster */
                $centroid = $cluster->getCoordinates();
                $this->db->exec($query, [$clusterID, json_encode($centroid)]);
            }

            $pdo->commit();
        } catch (\Throwable $e) {
            // only roll back what is still open (e.g. not after a successful commit)
            if ($pdo->inTransaction()) $pdo->rollBack();
            throw new \RuntimeException('Clustering failed', 0, $e);
        }

        if ($this->logger) $this->logger->success('Created {clusters} clusters', ['clusters' => count($clusters)]);
    }

    /**
     * Assign the nearest cluster for all chunks that don't have one
     *
     * All updates run in one transaction, unless the caller already opened one. If nothing
     * fails and this method opened the transaction, it commits it; on failure the transaction
     * is rolled back and the error is rethrown.
     *
     * @return void
     */
    protected function setChunkClusters()
    {
        if (!$this->hasClusters()) {
            // getCluster() would return null for every chunk, so no row would change
            if ($this->logger) $this->logger->info('No clusters available, skipping cluster assignment');
            return;
        }

        if ($this->logger) $this->logger->info('Assigning clusters to chunks...');

        $pdo = $this->db->getPdo();
        // one transaction for all updates instead of an implicit commit per row
        $ownTransaction = !$pdo->inTransaction();
        if ($ownTransaction) $pdo->beginTransaction();

        $handle = null;
        try {
            $handle = $this->db->query('SELECT id, embedding FROM embeddings WHERE cluster IS NULL');
            $update = 'UPDATE embeddings SET cluster = ? WHERE id = ?';

            while ($record = $handle->fetch(\PDO::FETCH_ASSOC)) {
                $vector = json_decode($record['embedding'], true);
                $cluster = $this->getCluster($vector);
                $this->db->exec($update, [$cluster, $record['id']]);
                if ($this->logger) $this->logger->success(
                    'Chunk {id} assigned to cluster {cluster}', ['id' => $record['id'], 'cluster' => $cluster]
                );
            }
            $handle->closeCursor();

            if ($ownTransaction) $pdo->commit();
        } catch (\Throwable $e) {
            if ($handle) $handle->closeCursor();
            if ($ownTransaction && $pdo->inTransaction()) $pdo->rollBack();
            throw $e;
        }
    }

    /**
     * Get the nearest cluster for the given vector
     *
     * @param float[] $vector
     * @return int|null the cluster ID with the most similar centroid, null if no clusters exist
     */
    protected function getCluster($vector)
    {
        $query = 'SELECT cluster, centroid FROM clusters ORDER BY COSIM(centroid, ?) DESC LIMIT 1';
        $result = $this->db->queryRecord($query, [json_encode($vector)]);
        if (!$result) return null;
        return $result['cluster'];
    }

    /**
     * Check if clustering has been done before
     *
     * @return bool true if the clusters table contains at least one row
     */
    protected function hasClusters()
    {
        $query = 'SELECT COUNT(*) FROM clusters';
        return $this->db->queryValue($query) > 0;
    }

    /**
     * Build a Chunk object from a row of the embeddings table
     *
     * @param array $record row with page, id, chunk, embedding (JSON) and created columns
     * @param bool $withSimilarity also pass the row's "similarity" column to the Chunk
     * @return Chunk
     */
    protected function recordToChunk(array $record, $withSimilarity = false)
    {
        $args = [
            $record['page'],
            $record['id'],
            $record['chunk'],
            json_decode($record['embedding'], true),
            $record['created'],
        ];
        if ($withSimilarity) {
            $args[] = $record['similarity'];
        }
        return new Chunk(...$args);
    }
}