<?php /** @noinspection SqlResolve */

namespace dokuwiki\plugin\aichat\Storage;

use dokuwiki\plugin\aichat\AIChat;
use dokuwiki\plugin\aichat\Chunk;
use dokuwiki\plugin\sqlite\SQLiteDB;
use KMeans\Cluster;
use KMeans\Space;

/**
 * Implements the storage backend using a SQLite database
 *
 * Chunks live in the `embeddings` table, cluster centroids in the `clusters` table.
 * Similarity searches are restricted to the cluster whose centroid is nearest to the
 * query vector, so only a subset of all chunks has to be compared.
 *
 * Note: all embeddings are stored and returned as normalized vectors
 */
class SQLiteStorage extends AbstractStorage
{
    /** @var float minimum similarity to consider a chunk a match */
    const SIMILARITY_THRESHOLD = 0.75;

    /** @var int Number of documents to randomly sample to create the clusters */
    const SAMPLE_SIZE = 2000;
    /** @var int The average size of each cluster */
    const CLUSTER_SIZE = 400;

    /** @var SQLiteDB */
    protected $db;

    /** @var bool when true, clusters are created and looked up separately per chunk language */
    protected $useLanguageClusters = false;

    /**
     * Initializes the database connection and registers our custom function
     *
     * The COSIM(a, b) SQL function is backed by sqliteCosineSimilarityCallback().
     *
     * @throws \Exception
     */
    public function __construct()
    {
        $this->db = new SQLiteDB('aichat', DOKU_PLUGIN . 'aichat/db/');
        $this->db->getPdo()->sqliteCreateFunction('COSIM', [$this, 'sqliteCosineSimilarityCallback'], 2);

        $helper = plugin_load('helper', 'aichat');
        $this->useLanguageClusters = $helper->getConf('preferUIlanguage') >= AIChat::LANG_UI_LIMITED;
    }

    /** @inheritdoc */
    public function getChunk($chunkID)
    {
        $record = $this->db->queryRecord('SELECT * FROM embeddings WHERE id = ?', [$chunkID]);
        if (!$record) return null;

        return new Chunk(
            $record['page'],
            $record['id'],
            $record['chunk'],
            json_decode($record['embedding'], true),
            $record['lang'],
            $record['created']
        );
    }

    /** @inheritdoc */
    public function startCreation($clear = false)
    {
        if ($clear) {
            /** @noinspection SqlWithoutWhere */
            $this->db->exec('DELETE FROM embeddings');
        }
    }

    /**
     * Nothing to do here: existing rows simply stay in the table
     *
     * @inheritdoc
     */
    public function reusePageChunks($page, $firstChunkID)
    {
        // no-op
    }

    /**
     * Chunks are looked up by page name, $firstChunkID is not used by this backend
     *
     * @inheritdoc
     */
    public function deletePageChunks($page, $firstChunkID)
    {
        $this->db->exec('DELETE FROM embeddings WHERE page = ?', [$page]);
    }

    /** @inheritdoc */
    public function addPageChunks($chunks)
    {
        foreach ($chunks as $chunk) {
            $this->db->saveRecord('embeddings', [
                'page' => $chunk->getPage(),
                'id' => $chunk->getId(),
                'chunk' => $chunk->getText(),
                'embedding' => json_encode($chunk->getEmbedding()),
                'created' => $chunk->getCreated(),
                'lang' => $chunk->getLanguage(),
            ]);
        }
    }

    /**
     * Creates clusters if none exist yet, assigns unclustered chunks and compacts the database
     *
     * @inheritdoc
     */
    public function finalizeCreation()
    {
        if (!$this->hasClusters()) {
            $this->createClusters();
        }
        $this->setChunkClusters();

        // VACUUM must not run inside a transaction, createLanguageClusters() always closes its own
        $this->db->exec('VACUUM');
    }

    /**
     * Rebuilds all clusters from scratch and reassigns every chunk
     *
     * @inheritdoc
     */
    public function runMaintenance()
    {
        $this->createClusters();
        $this->setChunkClusters();
    }

    /**
     * Chunks are looked up by page name, $firstChunkID is not used by this backend
     *
     * @inheritdoc
     */
    public function getPageChunks($page, $firstChunkID)
    {
        $result = $this->db->queryAll(
            'SELECT * FROM embeddings WHERE page = ?',
            [$page]
        );
        $chunks = [];
        foreach ($result as $record) {
            $chunks[] = new Chunk(
                $record['page'],
                $record['id'],
                $record['chunk'],
                json_decode($record['embedding'], true),
                $record['lang'],
                $record['created']
            );
        }
        return $chunks;
    }

    /**
     * Only chunks in the cluster nearest to $vector are compared
     *
     * @inheritdoc
     */
    public function getSimilarChunks($vector, $lang = '', $limit = 4)
    {
        $cluster = $this->getCluster($vector, $lang);
        if ($this->logger) $this->logger->info(
            'Using cluster {cluster} for similarity search', ['cluster' => $cluster]
        );

        // NOTE(review): GETACCESSLEVEL is not registered in this file, presumably it is
        // provided by the sqlite plugin to apply page ACLs - confirm
        $result = $this->db->queryAll(
            'SELECT *, COSIM(?, embedding) AS similarity
               FROM embeddings
              WHERE cluster = ?
                AND GETACCESSLEVEL(page) > 0
                AND similarity > CAST(? AS FLOAT)
           ORDER BY similarity DESC
              LIMIT ?',
            [json_encode($vector), $cluster, self::SIMILARITY_THRESHOLD, $limit]
        );
        $chunks = [];
        foreach ($result as $record) {
            $chunks[] = new Chunk(
                $record['page'],
                $record['id'],
                $record['chunk'],
                json_decode($record['embedding'], true),
                $record['lang'],
                $record['created'],
                $record['similarity']
            );
        }
        return $chunks;
    }

    /** @inheritdoc */
    public function statistics()
    {
        $items = $this->db->queryValue('SELECT COUNT(*) FROM embeddings');
        $size = $this->db->queryValue(
            'SELECT page_count * page_size as size FROM pragma_page_count(), pragma_page_size()'
        );
        $query = "SELECT cluster || ' ' || lang, COUNT(*) || ' chunks' as cnt FROM embeddings GROUP BY cluster ORDER BY cluster";
        $clusters = $this->db->queryKeyValueList($query);

        return [
            'storage type' => 'SQLite',
            'chunks' => $items,
            'db size' => filesize_h($size),
            'clusters' => $clusters,
        ];
    }

    /**
     * Method registered as SQLite callback to calculate the cosine similarity
     *
     * @param string $query JSON encoded vector array
     * @param string $embedding JSON encoded vector array
     * @return float
     */
    public function sqliteCosineSimilarityCallback($query, $embedding)
    {
        return (float)$this->cosineSimilarity(json_decode($query), json_decode($embedding));
    }

    /**
     * Calculate the cosine similarity between two vectors
     *
     * Actually just calculating the dot product of the two vectors, since they are normalized
     *
     * @param float[] $queryVector The normalized vector of the search phrase
     * @param float[] $embedding The normalized vector of the chunk
     * @return float
     */
    protected function cosineSimilarity($queryVector, $embedding)
    {
        $dotProduct = 0;
        foreach ($queryVector as $key => $value) {
            $dotProduct += $value * $embedding[$key];
        }
        return $dotProduct;
    }

    /**
     * Create new clusters based on random chunks
     *
     * With language clusters enabled, one independent set of clusters is created for each
     * language found in the embeddings table. Otherwise a single set covers all chunks.
     *
     * @return void
     */
    protected function createClusters()
    {
        if ($this->useLanguageClusters) {
            $result = $this->db->queryAll('SELECT DISTINCT lang FROM embeddings');
            $langs = array_column($result, 'lang');
            foreach ($langs as $lang) {
                $this->createLanguageClusters($lang);
            }
        } else {
            $this->createLanguageClusters('');
        }
    }

    /**
     * Create new clusters based on random chunks for the given Language
     *
     * Existing clusters (and the cluster assignment of the affected chunks) are dropped,
     * then k-means is run on a random sample of at most SAMPLE_SIZE chunks. All of this
     * happens in one transaction which is always committed or rolled back before returning.
     *
     * @param string $lang The language to cluster, empty when all languages go into the same cluster
     * @return void
     * @throws \RuntimeException when clustering fails, the database changes are rolled back
     * @noinspection SqlWithoutWhere
     */
    protected function createLanguageClusters($lang)
    {
        $pdo = $this->db->getPdo();

        if ($lang != '') {
            $where = 'WHERE lang = ' . $pdo->quote($lang);
        } else {
            $where = '';
        }

        if ($this->logger) $this->logger->info('Creating new {lang} clusters...', ['lang' => $lang]);
        $pdo->beginTransaction();
        try {
            // clean up old cluster data
            $query = "DELETE FROM clusters $where";
            $this->db->exec($query);
            $query = "UPDATE embeddings SET cluster = NULL $where";
            $this->db->exec($query);

            // get a random selection of chunks
            $query = "SELECT id, embedding FROM embeddings $where ORDER BY RANDOM() LIMIT ?";
            $result = $this->db->queryAll($query, [self::SAMPLE_SIZE]);
            if (!$result) {
                // no data to cluster: keep the cleanup (no chunks means no clusters) and make
                // sure the transaction is closed, otherwise a later VACUUM or
                // beginTransaction() on this connection would fail
                $pdo->commit();
                return;
            }
            $dimensions = count(json_decode($result[0]['embedding'], true));

            // get the number of all chunks, to calculate the number of clusters
            $query = "SELECT COUNT(*) FROM embeddings $where";
            $total = $this->db->queryValue($query);
            // ceil() returns a float, the solver wants a cluster count
            $clustercount = (int)ceil($total / self::CLUSTER_SIZE);
            if ($this->logger) $this->logger->info('Creating {clusters} clusters', ['clusters' => $clustercount]);

            // cluster them using kmeans
            $space = new Space($dimensions);
            foreach ($result as $record) {
                $space->addPoint(json_decode($record['embedding'], true));
            }
            $clusters = $space->solve($clustercount, function ($space, $clusters) {
                // progress callback, only used to log the cluster sizes of each iteration
                static $iterations = 0;
                ++$iterations;
                if ($this->logger) {
                    $clustercounts = join(',', array_map('count', $clusters));
                    $this->logger->info('Iteration {iteration}: [{clusters}]', [
                        'iteration' => $iterations, 'clusters' => $clustercounts
                    ]);
                }
            }, Cluster::INIT_KMEANS_PLUS_PLUS);

            // store the clusters
            foreach ($clusters as $cluster) {
                /** @var Cluster $cluster */
                $centroid = $cluster->getCoordinates();
                $query = 'INSERT INTO clusters (lang, centroid) VALUES (?, ?)';
                $this->db->exec($query, [$lang, json_encode($centroid)]);
            }

            $pdo->commit();
            if ($this->logger) $this->logger->success('Created {clusters} clusters', ['clusters' => count($clusters)]);
        } catch (\Exception $e) {
            // the exception may have been raised after the commit, only roll back what is still open
            if ($pdo->inTransaction()) {
                $pdo->rollBack();
            }
            throw new \RuntimeException('Clustering failed: ' . $e->getMessage(), 0, $e);
        }
    }

    /**
     * Assign the nearest cluster for all chunks that don't have one
     *
     * @return void
     */
    protected function setChunkClusters()
    {
        if ($this->logger) $this->logger->info('Assigning clusters to chunks...');
        $query = 'SELECT id, embedding, lang FROM embeddings WHERE cluster IS NULL';
        $handle = $this->db->query($query);

        while ($record = $handle->fetch(\PDO::FETCH_ASSOC)) {
            $vector = json_decode($record['embedding'], true);
            // without language clusters, all centroids are candidates
            $cluster = $this->getCluster($vector, $this->useLanguageClusters ? $record['lang'] : '');
            $query = 'UPDATE embeddings SET cluster = ? WHERE id = ?';
            $this->db->exec($query, [$cluster, $record['id']]);
            if ($this->logger) $this->logger->success(
                'Chunk {id} assigned to cluster {cluster}', ['id' => $record['id'], 'cluster' => $cluster]
            );
        }
        $handle->closeCursor();
    }

    /**
     * Get the nearest cluster for the given vector
     *
     * @param float[] $vector
     * @param string $lang Only consider clusters of this language, empty to consider all clusters
     * @return int|null The cluster ID, null when no (matching) cluster exists
     */
    protected function getCluster($vector, $lang)
    {
        if ($lang != '') {
            $where = 'WHERE lang = ' . $this->db->getPdo()->quote($lang);
        } else {
            $where = '';
        }

        $query = "SELECT cluster, centroid
                    FROM clusters
                   $where
                ORDER BY COSIM(centroid, ?) DESC
                   LIMIT 1";

        $result = $this->db->queryRecord($query, [json_encode($vector)]);
        if (!$result) return null;
        return $result['cluster'];
    }

    /**
     * Check if clustering has been done before
     *
     * @return bool
     */
    protected function hasClusters()
    {
        $query = 'SELECT COUNT(*) FROM clusters';
        return $this->db->queryValue($query) > 0;
    }

    /**
     * Writes TSV files for visualizing with http://projector.tensorflow.org/
     *
     * Both files are appended to, existing content is kept.
     *
     * @param string $vectorfile path to the file with the vectors
     * @param string $metafile path to the file with the metadata
     * @return void
     */
    public function dumpTSV($vectorfile, $metafile)
    {
        $query = 'SELECT * FROM embeddings';
        $handle = $this->db->query($query);

        // only the meta file gets a header line
        $header = implode("\t", ['id', 'page', 'created']);
        file_put_contents($metafile, $header . "\n", FILE_APPEND);

        while ($row = $handle->fetch(\PDO::FETCH_ASSOC)) {
            $vector = json_decode($row['embedding'], true);
            $vector = implode("\t", $vector);

            $meta = implode("\t", [$row['id'], $row['page'], $row['created']]);

            file_put_contents($vectorfile, $vector . "\n", FILE_APPEND);
            file_put_contents($metafile, $meta . "\n", FILE_APPEND);
        }
        $handle->closeCursor();
    }
}