<?php /** @noinspection SqlResolve */

namespace dokuwiki\plugin\aichat\Storage;

use dokuwiki\plugin\aichat\AIChat;
use dokuwiki\plugin\aichat\Chunk;
use dokuwiki\plugin\sqlite\SQLiteDB;
use KMeans\Cluster;
use KMeans\Space;

/**
 * Implements the storage backend using a SQLite database
 *
 * Chunks are stored in the `embeddings` table, cluster centroids in the `clusters` table.
 * Similarity searches only look at the chunks of the cluster whose centroid is nearest
 * to the query vector.
 *
 * Note: all embeddings are stored and returned as normalized vectors
 */
class SQLiteStorage extends AbstractStorage
{
    /** @var float minimum similarity to consider a chunk a match */
    const SIMILARITY_THRESHOLD = 0.75;

    /** @var int Number of documents to randomly sample to create the clusters */
    const SAMPLE_SIZE = 2000;
    /** @var int The average size of each cluster */
    const CLUSTER_SIZE = 400;

    /** @var SQLiteDB */
    protected $db;

    /** @var bool when true, clusters are created and looked up separately per language */
    protected $useLanguageClusters = false;

    /**
     * Initializes the database connection and registers our custom function
     *
     * Registers the two-argument SQL function COSIM, backed by sqliteCosineSimilarityCallback(),
     * and reads the 'preferUIlanguage' setting to decide whether clusters are kept per language.
     *
     * @throws \Exception
     */
    public function __construct()
    {
        $this->db = new SQLiteDB('aichat', DOKU_PLUGIN . 'aichat/db/');
        $this->db->getPdo()->sqliteCreateFunction('COSIM', [$this, 'sqliteCosineSimilarityCallback'], 2);

        $helper = plugin_load('helper', 'aichat');
        $this->useLanguageClusters = $helper->getConf('preferUIlanguage') >= AIChat::LANG_UI_LIMITED;
    }

    /** @inheritdoc */
    public function getChunk($chunkID)
    {
        $record = $this->db->queryRecord('SELECT * FROM embeddings WHERE id = ?', [$chunkID]);
        if (!$record) return null;

        return new Chunk(
            $record['page'],
            $record['id'],
            $record['chunk'],
            json_decode($record['embedding'], true),
            $record['lang'],
            $record['created']
        );
    }

    /** @inheritdoc */
    public function startCreation($clear = false)
    {
        if ($clear) {
            /** @noinspection SqlWithoutWhere */
            $this->db->exec('DELETE FROM embeddings');
        }
    }

    /** @inheritdoc */
    public function reusePageChunks($page, $firstChunkID)
    {
        // no-op
    }

    /** @inheritdoc */
    public function deletePageChunks($page, $firstChunkID)
    {
        // chunks are addressed by page here, $firstChunkID is not needed
        $this->db->exec('DELETE FROM embeddings WHERE page = ?', [$page]);
    }

    /** @inheritdoc */
    public function addPageChunks($chunks)
    {
        foreach ($chunks as $chunk) {
            $this->db->saveRecord('embeddings', [
                'page' => $chunk->getPage(),
                'id' => $chunk->getId(),
                'chunk' => $chunk->getText(),
                'embedding' => json_encode($chunk->getEmbedding()),
                'created' => $chunk->getCreated(),
                'lang' => $chunk->getLanguage(),
            ]);
        }
    }

    /** @inheritdoc */
    public function finalizeCreation()
    {
        // only build clusters when there are none yet; runMaintenance() rebuilds them
        if (!$this->hasClusters()) {
            $this->createClusters();
        }
        $this->setChunkClusters();

        $this->db->exec('VACUUM');
    }

    /** @inheritdoc */
    public function runMaintenance()
    {
        $this->createClusters();
        $this->setChunkClusters();
    }

    /** @inheritdoc */
    public function getPageChunks($page, $firstChunkID)
    {
        $result = $this->db->queryAll(
            'SELECT * FROM embeddings WHERE page = ?',
            [$page]
        );
        $chunks = [];
        foreach ($result as $record) {
            $chunks[] = new Chunk(
                $record['page'],
                $record['id'],
                $record['chunk'],
                json_decode($record['embedding'], true),
                $record['lang'],
                $record['created']
            );
        }
        return $chunks;
    }

    /** @inheritdoc */
    public function getSimilarChunks($vector, $lang = '', $limit = 4)
    {
        // only chunks in the cluster nearest to the query vector are searched;
        // when no cluster is found this is null and the query below matches nothing
        $cluster = $this->getCluster($vector, $lang);
        if ($this->logger) $this->logger->info(
            'Using cluster {cluster} for similarity search', ['cluster' => $cluster]
        );

        // NOTE(review): GETACCESSLEVEL is not registered in this class; presumably it is
        // provided by the sqlite plugin - confirm
        $result = $this->db->queryAll(
            'SELECT *, COSIM(?, embedding) AS similarity
               FROM embeddings
              WHERE cluster = ?
                AND GETACCESSLEVEL(page) > 0
                AND similarity > CAST(? AS FLOAT)
           ORDER BY similarity DESC
              LIMIT ?',
            [json_encode($vector), $cluster, self::SIMILARITY_THRESHOLD, $limit]
        );
        $chunks = [];
        foreach ($result as $record) {
            $chunks[] = new Chunk(
                $record['page'],
                $record['id'],
                $record['chunk'],
                json_decode($record['embedding'], true),
                $record['lang'],
                $record['created'],
                $record['similarity']
            );
        }
        return $chunks;
    }

    /** @inheritdoc */
    public function statistics()
    {
        $items = $this->db->queryValue('SELECT COUNT(*) FROM embeddings');
        $size = $this->db->queryValue(
            'SELECT page_count * page_size as size FROM pragma_page_count(), pragma_page_size()'
        );
        $query = "SELECT cluster || ' ' || lang, COUNT(*) || ' chunks' as cnt FROM embeddings GROUP BY cluster ORDER BY cluster";
        $clusters = $this->db->queryKeyValueList($query);

        return [
            'storage type' => 'SQLite',
            'chunks' => $items,
            'db size' => filesize_h($size),
            'clusters' => $clusters,
        ];
    }

    /**
     * Method registered as SQLite callback to calculate the cosine similarity
     *
     * @param string $query JSON encoded vector array
     * @param string $embedding JSON encoded vector array
     * @return float
     */
    public function sqliteCosineSimilarityCallback($query, $embedding)
    {
        return (float)$this->cosineSimilarity(json_decode($query), json_decode($embedding));
    }

    /**
     * Calculate the cosine similarity between two vectors
     *
     * Actually just calculating the dot product of the two vectors, since they are normalized
     *
     * @param float[] $queryVector The normalized vector of the search phrase
     * @param float[] $embedding The normalized vector of the chunk
     * @return float
     */
    protected function cosineSimilarity($queryVector, $embedding)
    {
        $dotProduct = 0;
        foreach ($queryVector as $key => $value) {
            $dotProduct += $value * $embedding[$key];
        }
        return $dotProduct;
    }

    /**
     * Create new clusters based on random chunks
     *
     * With language clusters enabled, every language found in the embeddings table gets
     * its own set of clusters. Otherwise a single set is created for all chunks.
     *
     * @return void
     */
    protected function createClusters()
    {
        if ($this->useLanguageClusters) {
            $result = $this->db->queryAll('SELECT DISTINCT lang FROM embeddings');
            $langs = array_column($result, 'lang');
            foreach ($langs as $lang) {
                $this->createLanguageClusters($lang);
            }
        } else {
            $this->createLanguageClusters('');
        }
    }

    /**
     * Create new clusters based on random chunks for the given Language
     *
     * Removes the existing clusters (and the chunks' cluster assignments) for the language,
     * runs k-means on a random sample of at most SAMPLE_SIZE chunks and stores the resulting
     * centroids. Everything happens inside a single transaction.
     *
     * @param string $lang The language to cluster, empty when all languages go into the same cluster
     * @return void
     * @throws \RuntimeException when clustering fails, the transaction is rolled back then
     * @noinspection SqlWithoutWhere
     */
    protected function createLanguageClusters($lang)
    {
        if ($lang != '') {
            $where = 'WHERE lang = ' . $this->db->getPdo()->quote($lang);
        } else {
            $where = '';
        }

        if ($this->logger) $this->logger->info('Creating new {lang} clusters...', ['lang' => $lang]);
        $this->db->getPdo()->beginTransaction();
        try {
            // clean up old cluster data
            $query = "DELETE FROM clusters $where";
            $this->db->exec($query);
            $query = "UPDATE embeddings SET cluster = NULL $where";
            $this->db->exec($query);

            // get a random selection of chunks
            $query = "SELECT id, embedding FROM embeddings $where ORDER BY RANDOM() LIMIT ?";
            $result = $this->db->queryAll($query, [self::SAMPLE_SIZE]);
            if (!$result) {
                // no data to cluster: persist the cleanup and close the transaction,
                // otherwise it would stay open and block the next beginTransaction()
                $this->db->getPdo()->commit();
                return;
            }
            $dimensions = count(json_decode($result[0]['embedding'], true));

            // how many clusters?
            if (count($result) < self::CLUSTER_SIZE * 3) {
                // there would be less than 3 clusters, so just use one
                $clustercount = 1;
            } else {
                // get the number of all chunks, to calculate the number of clusters
                $query = "SELECT COUNT(*) FROM embeddings $where";
                $total = $this->db->queryValue($query);
                $clustercount = (int)ceil($total / self::CLUSTER_SIZE);
            }
            if ($this->logger) $this->logger->info('Creating {clusters} clusters', ['clusters' => $clustercount]);

            // cluster them using kmeans
            $space = new Space($dimensions);
            foreach ($result as $record) {
                $space->addPoint(json_decode($record['embedding'], true));
            }
            // the callback only logs the cluster sizes of each iteration
            $clusters = $space->solve($clustercount, function ($space, $clusters) {
                static $iterations = 0;
                ++$iterations;
                if ($this->logger) {
                    $clustercounts = join(',', array_map('count', $clusters));
                    $this->logger->info('Iteration {iteration}: [{clusters}]', [
                        'iteration' => $iterations, 'clusters' => $clustercounts
                    ]);
                }
            }, Cluster::INIT_KMEANS_PLUS_PLUS);

            // store the clusters
            foreach ($clusters as $cluster) {
                /** @var Cluster $cluster */
                $centroid = $cluster->getCoordinates();
                $query = 'INSERT INTO clusters (lang, centroid) VALUES (?, ?)';
                $this->db->exec($query, [$lang, json_encode($centroid)]);
            }

            $this->db->getPdo()->commit();
            if ($this->logger) $this->logger->success('Created {clusters} clusters', ['clusters' => count($clusters)]);
        } catch (\Exception $e) {
            $this->db->getPdo()->rollBack();
            throw new \RuntimeException('Clustering failed: ' . $e->getMessage(), 0, $e);
        }
    }

    /**
     * Assign the nearest cluster for all chunks that don't have one
     *
     * @return void
     */
    protected function setChunkClusters()
    {
        if ($this->logger) $this->logger->info('Assigning clusters to chunks...');
        $query = 'SELECT id, embedding, lang FROM embeddings WHERE cluster IS NULL';
        $handle = $this->db->query($query);

        while ($record = $handle->fetch(\PDO::FETCH_ASSOC)) {
            $vector = json_decode($record['embedding'], true);
            // without language clusters, all centroids are candidates
            $cluster = $this->getCluster($vector, $this->useLanguageClusters ? $record['lang'] : '');
            $query = 'UPDATE embeddings SET cluster = ? WHERE id = ?';
            $this->db->exec($query, [$cluster, $record['id']]);
            if ($this->logger) $this->logger->success(
                'Chunk {id} assigned to cluster {cluster}', ['id' => $record['id'], 'cluster' => $cluster]
            );
        }
        $handle->closeCursor();
    }

    /**
     * Get the nearest cluster for the given vector
     *
     * The nearest cluster is the one whose centroid has the highest COSIM value
     * with the given vector.
     *
     * @param float[] $vector
     * @param string $lang Only consider clusters of this language, empty to consider all clusters
     * @return int|null The cluster ID, null when no cluster matched
     */
    protected function getCluster($vector, $lang)
    {
        if ($lang != '') {
            $where = 'WHERE lang = ' . $this->db->getPdo()->quote($lang);
        } else {
            $where = '';
        }

        $query = "SELECT cluster, centroid
                    FROM clusters
                         $where
                ORDER BY COSIM(centroid, ?) DESC
                   LIMIT 1";

        $result = $this->db->queryRecord($query, [json_encode($vector)]);
        if (!$result) return null;
        return $result['cluster'];
    }

    /**
     * Check if clustering has been done before
     * @return bool
     */
    protected function hasClusters()
    {
        $query = 'SELECT COUNT(*) FROM clusters';
        return $this->db->queryValue($query) > 0;
    }

    /**
     * Writes TSV files for visualizing with http://projector.tensorflow.org/
     *
     * Both files are opened in append mode, so existing content is kept and the
     * metadata header is written again on every call.
     *
     * @param string $vectorfile path to the file with the vectors
     * @param string $metafile path to the file with the metadata
     * @return void
     */
    public function dumpTSV($vectorfile, $metafile)
    {
        $query = 'SELECT * FROM embeddings';
        $handle = $this->db->query($query);

        $header = implode("\t", ['id', 'page', 'created']);
        file_put_contents($metafile, $header . "\n", FILE_APPEND);

        while ($row = $handle->fetch(\PDO::FETCH_ASSOC)) {
            $vector = json_decode($row['embedding'], true);
            $vector = implode("\t", $vector);

            $meta = implode("\t", [$row['id'], $row['page'], $row['created']]);

            file_put_contents($vectorfile, $vector . "\n", FILE_APPEND);
            file_put_contents($metafile, $meta . "\n", FILE_APPEND);
        }
        // release the statement, as setChunkClusters() does
        $handle->closeCursor();
    }
}