<?php

/** @noinspection SqlResolve */

namespace dokuwiki\plugin\aichat\Storage;

use dokuwiki\plugin\aichat\AIChat;
use dokuwiki\plugin\aichat\Chunk;
use dokuwiki\plugin\sqlite\SQLiteDB;
use KMeans\Cluster;
use KMeans\Space;

/**
 * Implements the storage backend using a SQLite database
 *
 * Note: all embeddings are stored and returned as normalized vectors
 */
class SQLiteStorage extends AbstractStorage
{
    /** @var int Number of documents to randomly sample to create the clusters */
    final public const SAMPLE_SIZE = 2000;
    /** @var int The average size of each cluster */
    final public const CLUSTER_SIZE = 400;

    /** @var SQLiteDB */
    protected $db;

    /** @var bool when true, clusters are created and looked up separately per chunk language */
    protected $useLanguageClusters = false;

    /** @var float minimum similarity to consider a chunk a match */
    protected $similarityThreshold = 0;

    /** @inheritdoc */
    public function __construct(array $config)
    {
        $this->db = new SQLiteDB('aichat', DOKU_PLUGIN . 'aichat/db/');
        // make COSIM(a, b) available inside SQL queries
        $this->db->getPdo()->sqliteCreateFunction('COSIM', $this->sqliteCosineSimilarityCallback(...), 2);

        $helper = plugin_load('helper', 'aichat');
        $this->useLanguageClusters = $helper->getConf('preferUIlanguage') >= AIChat::LANG_UI_LIMITED;

        // configured as a percentage; similarities are compared on a 0..1 scale
        $this->similarityThreshold = (float)($config['similarityThreshold'] ?? 0) / 100;
    }

    /**
     * Create a Chunk object from a row of the embeddings table
     *
     * When the row carries a computed `similarity` column (as similarity search results do),
     * it is passed on to the Chunk as well.
     *
     * @param array $record row with page, id, chunk, embedding (JSON), lang and created columns
     * @return Chunk
     */
    protected function recordToChunk(array $record)
    {
        $args = [
            $record['page'],
            $record['id'],
            $record['chunk'],
            json_decode((string)$record['embedding'], true, 512, JSON_THROW_ON_ERROR),
            $record['lang'],
            $record['created'],
        ];
        // only similarity search results have a score attached
        if (array_key_exists('similarity', $record)) {
            $args[] = $record['similarity'];
        }
        return new Chunk(...$args);
    }

    /** @inheritdoc */
    public function getChunk($chunkID)
    {
        $record = $this->db->queryRecord('SELECT * FROM embeddings WHERE id = ?', [$chunkID]);
        if (!$record) return null;

        return $this->recordToChunk($record);
    }

    /** @inheritdoc */
    public function startCreation($clear = false)
    {
        if ($clear) {
            /** @noinspection SqlWithoutWhere */
            $this->db->exec('DELETE FROM embeddings');
            /** @noinspection SqlWithoutWhere */
            $this->db->exec('DELETE FROM clusters');
        }
    }

    /** @inheritdoc */
    public function reusePageChunks($page, $firstChunkID)
    {
        // no-op
    }

    /** @inheritdoc */
    public function deletePageChunks($page, $firstChunkID)
    {
        $this->db->exec('DELETE FROM embeddings WHERE page = ?', [$page]);
    }

    /** @inheritdoc */
    public function addPageChunks($chunks)
    {
        foreach ($chunks as $chunk) {
            $this->db->saveRecord('embeddings', [
                'page' => $chunk->getPage(),
                'id' => $chunk->getId(),
                'chunk' => $chunk->getText(),
                'embedding' => json_encode($chunk->getEmbedding(), JSON_THROW_ON_ERROR),
                'created' => $chunk->getCreated(),
                'lang' => $chunk->getLanguage(),
            ]);
        }
    }

    /** @inheritdoc */
    public function finalizeCreation()
    {
        if (!$this->hasClusters()) {
            $this->createClusters();
        }
        $this->setChunkClusters();

        // VACUUM only works when no transaction is open on this connection
        $this->db->exec('VACUUM');
    }

    /** @inheritdoc */
    public function runMaintenance()
    {
        $this->createClusters();
        $this->setChunkClusters();
    }

    /** @inheritdoc */
    public function getPageChunks($page, $firstChunkID)
    {
        $result = $this->db->queryAll(
            'SELECT * FROM embeddings WHERE page = ?',
            [$page]
        );
        return array_map($this->recordToChunk(...), $result);
    }

    /** @inheritdoc */
    public function getSimilarChunks($vector, $lang = '', $limit = 4)
    {
        // only the nearest cluster is searched; a null cluster (no clusters yet) matches nothing
        $cluster = $this->getCluster($vector, $lang);
        if ($this->logger) $this->logger->info(
            'Using cluster {cluster} for similarity search',
            ['cluster' => $cluster]
        );

        $result = $this->db->queryAll(
            'SELECT *, COSIM(?, embedding) AS similarity
               FROM embeddings
              WHERE cluster = ?
                AND GETACCESSLEVEL(page) > 0
                AND similarity > CAST(? AS FLOAT)
           ORDER BY similarity DESC
              LIMIT ?',
            [json_encode($vector, JSON_THROW_ON_ERROR), $cluster, $this->similarityThreshold, $limit]
        );
        return array_map($this->recordToChunk(...), $result);
    }

    /** @inheritdoc */
    public function statistics()
    {
        $items = $this->db->queryValue('SELECT COUNT(*) FROM embeddings');
        $size = $this->db->queryValue(
            'SELECT page_count * page_size as size FROM pragma_page_count(), pragma_page_size()'
        );
        $query = "SELECT cluster || ' ' || lang, COUNT(*) || ' chunks' as cnt
                    FROM embeddings
                GROUP BY cluster
                ORDER BY cluster";
        $clusters = $this->db->queryKeyValueList($query);

        return [
            'storage type' => 'SQLite',
            'chunks' => $items,
            'db size' => filesize_h($size),
            'clusters' => $clusters,
        ];
    }

    /**
     * Method registered as SQLite callback to calculate the cosine similarity
     *
     * @param string $query JSON encoded vector array
     * @param string $embedding JSON encoded vector array
     * @return float
     */
    public function sqliteCosineSimilarityCallback($query, $embedding)
    {
        return (float)$this->cosineSimilarity(
            json_decode($query, true, 512, JSON_THROW_ON_ERROR),
            json_decode($embedding, true, 512, JSON_THROW_ON_ERROR)
        );
    }

    /**
     * Calculate the cosine similarity between two vectors
     *
     * Actually just calculating the dot product of the two vectors, since they are normalized
     *
     * @param float[] $queryVector The normalized vector of the search phrase
     * @param float[] $embedding The normalized vector of the chunk
     * @return float
     */
    protected function cosineSimilarity($queryVector, $embedding)
    {
        $dotProduct = 0;
        foreach ($queryVector as $key => $value) {
            $dotProduct += $value * $embedding[$key];
        }
        return $dotProduct;
    }

    /**
     * Create new clusters based on random chunks
     *
     * With language clusters enabled, each language present in the embeddings table
     * gets its own set of clusters; otherwise one set is built for all chunks.
     *
     * @return void
     */
    protected function createClusters()
    {
        if ($this->useLanguageClusters) {
            $result = $this->db->queryAll('SELECT DISTINCT lang FROM embeddings');
            $langs = array_column($result, 'lang');
            foreach ($langs as $lang) {
                $this->createLanguageClusters($lang);
            }
        } else {
            $this->createLanguageClusters('');
        }
    }

    /**
     * Create new clusters based on random chunks for the given Language
     *
     * Old clusters for the language are deleted and all its chunks are unassigned, then
     * k-means is run on a random sample of chunks and the resulting centroids are stored.
     * Everything happens in one transaction, which is always closed before returning.
     *
     * @param string $lang The language to cluster, empty when all languages go into the same cluster
     * @throws \RuntimeException when clustering fails (the transaction is rolled back)
     * @noinspection SqlWithoutWhere
     */
    protected function createLanguageClusters($lang)
    {
        $pdo = $this->db->getPdo();
        if ($lang != '') {
            $where = 'WHERE lang = ' . $pdo->quote($lang);
        } else {
            $where = '';
        }

        if ($this->logger) $this->logger->info('Creating new {lang} clusters...', ['lang' => $lang]);
        $pdo->beginTransaction();
        try {
            // clean up old cluster data
            $query = "DELETE FROM clusters $where";
            $this->db->exec($query);
            $query = "UPDATE embeddings SET cluster = NULL $where";
            $this->db->exec($query);

            // get a random selection of chunks
            $query = "SELECT id, embedding FROM embeddings $where ORDER BY RANDOM() LIMIT ?";
            $result = $this->db->queryAll($query, [self::SAMPLE_SIZE]);
            if (!$result) {
                // no data to cluster: keep the cleanup, but never leave the transaction open
                // (an open transaction breaks later VACUUM and beginTransaction calls)
                $pdo->commit();
                return;
            }
            $dimensions = count(json_decode((string)$result[0]['embedding'], true, 512, JSON_THROW_ON_ERROR));

            // how many clusters?
            if (count($result) < self::CLUSTER_SIZE * 3) {
                // there would be less than 3 clusters, so just use one
                $clustercount = 1;
            } else {
                // get the number of all chunks, to calculate the number of clusters
                $query = "SELECT COUNT(*) FROM embeddings $where";
                $total = (int)$this->db->queryValue($query);
                $clustercount = (int)ceil($total / self::CLUSTER_SIZE);
                // k-means cannot produce more clusters than there are sample points
                $clustercount = max(1, min($clustercount, count($result)));
            }
            if ($this->logger) $this->logger->info('Creating {clusters} clusters', ['clusters' => $clustercount]);

            // cluster them using kmeans
            $space = new Space($dimensions);
            foreach ($result as $record) {
                $space->addPoint(json_decode((string)$record['embedding'], true, 512, JSON_THROW_ON_ERROR));
            }
            $iterations = 0; // progress counter, local to this clustering run
            $clusters = $space->solve($clustercount, function ($space, $currentClusters) use (&$iterations) {
                ++$iterations;
                if ($this->logger) {
                    $clustercounts = implode(',', array_map('count', $currentClusters));
                    $this->logger->info('Iteration {iteration}: [{clusters}]', [
                        'iteration' => $iterations, 'clusters' => $clustercounts
                    ]);
                }
            }, Cluster::INIT_KMEANS_PLUS_PLUS);

            // store the clusters
            foreach ($clusters as $cluster) {
                /** @var Cluster $cluster */
                $centroid = $cluster->getCoordinates();
                $query = 'INSERT INTO clusters (lang, centroid) VALUES (?, ?)';
                $this->db->exec($query, [$lang, json_encode($centroid, JSON_THROW_ON_ERROR)]);
            }

            $pdo->commit();
            if ($this->logger) $this->logger->success('Created {clusters} clusters', ['clusters' => count($clusters)]);
        } catch (\Throwable $e) {
            // guard: the failure may have happened after (or during) commit
            if ($pdo->inTransaction()) {
                $pdo->rollBack();
            }
            throw new \RuntimeException('Clustering failed: ' . $e->getMessage(), 0, $e);
        }
    }

    /**
     * Assign the nearest cluster for all chunks that don't have one
     *
     * @return void
     */
    protected function setChunkClusters()
    {
        if ($this->logger) $this->logger->info('Assigning clusters to chunks...');
        $query = 'SELECT id, embedding, lang FROM embeddings WHERE cluster IS NULL';
        $handle = $this->db->query($query);

        while ($record = $handle->fetch(\PDO::FETCH_ASSOC)) {
            $vector = json_decode((string)$record['embedding'], true, 512, JSON_THROW_ON_ERROR);
            $cluster = $this->getCluster($vector, $this->useLanguageClusters ? $record['lang'] : '');
            $query = 'UPDATE embeddings SET cluster = ? WHERE id = ?';
            $this->db->exec($query, [$cluster, $record['id']]);
            if ($this->logger) $this->logger->success(
                'Chunk {id} assigned to cluster {cluster}',
                ['id' => $record['id'], 'cluster' => $cluster]
            );
        }
        $handle->closeCursor();
    }

    /**
     * Get the nearest cluster for the given vector
     *
     * @param float[] $vector normalized vector to find the nearest centroid for
     * @param string $lang restrict the search to clusters of this language, empty for all clusters
     * @return int|null the cluster ID, null when no (matching) clusters exist
     */
    protected function getCluster($vector, $lang)
    {
        if ($lang != '') {
            $where = 'WHERE lang = ' . $this->db->getPdo()->quote($lang);
        } else {
            $where = '';
        }

        $query = "SELECT cluster, centroid
                    FROM clusters
                         $where
                ORDER BY COSIM(centroid, ?) DESC
                   LIMIT 1";

        $result = $this->db->queryRecord($query, [json_encode($vector, JSON_THROW_ON_ERROR)]);
        if (!$result) return null;
        return $result['cluster'];
    }

    /**
     * Check if clustering has been done before
     * @return bool
     */
    protected function hasClusters()
    {
        $query = 'SELECT COUNT(*) FROM clusters';
        return $this->db->queryValue($query) > 0;
    }

    /**
     * Writes TSV files for visualizing with http://projector.tensorflow.org/
     *
     * Both files are appended to. A header line is written to the metadata file,
     * followed by one line per chunk in both files.
     *
     * @param string $vectorfile path to the file with the vectors
     * @param string $metafile path to the file with the metadata
     * @return void
     * @throws \RuntimeException when one of the files cannot be opened
     */
    public function dumpTSV($vectorfile, $metafile)
    {
        $query = 'SELECT * FROM embeddings';
        $handle = $this->db->query($query);

        // open each file once instead of reopening it for every row
        $vectorFh = fopen($vectorfile, 'a');
        $metaFh = fopen($metafile, 'a');
        try {
            if ($vectorFh === false || $metaFh === false) {
                throw new \RuntimeException('Could not open TSV files for writing');
            }

            $header = implode("\t", ['id', 'page', 'created']);
            fwrite($metaFh, $header . "\n");

            while ($row = $handle->fetch(\PDO::FETCH_ASSOC)) {
                $vector = json_decode((string)$row['embedding'], true, 512, JSON_THROW_ON_ERROR);
                $vector = implode("\t", $vector);

                $meta = implode("\t", [$row['id'], $row['page'], $row['created']]);

                fwrite($vectorFh, $vector . "\n");
                fwrite($metaFh, $meta . "\n");
            }
        } finally {
            $handle->closeCursor();
            if ($vectorFh !== false) fclose($vectorFh);
            if ($metaFh !== false) fclose($metaFh);
        }
    }
}