<?php

/** @noinspection SqlResolve */

namespace dokuwiki\plugin\aichat\Storage;

use dokuwiki\plugin\aichat\AIChat;
use dokuwiki\plugin\aichat\Chunk;
use dokuwiki\plugin\sqlite\SQLiteDB;
use KMeans\Cluster;
use KMeans\Space;

/**
 * Implements the storage backend using a SQLite database
 *
 * Chunks live in the `embeddings` table, cluster centroids in the `clusters` table.
 * Similarity searches are restricted to the cluster whose centroid is nearest to the
 * query vector, so only a subset of all chunks has to be compared.
 *
 * Note: all embeddings are stored and returned as normalized vectors
 */
class SQLiteStorage extends AbstractStorage
{
    /** @var int Number of documents to randomly sample to create the clusters */
    final public const SAMPLE_SIZE = 2000;
    /** @var int The average size of each cluster */
    final public const CLUSTER_SIZE = 400;

    /** @var SQLiteDB */
    protected $db;

    /** @var bool Whether clusters are created and looked up per language */
    protected $useLanguageClusters = false;

    /** @var float minimum similarity to consider a chunk a match */
    protected $similarityThreshold = 0;

    /**
     * Opens the database, registers the COSIM SQL function and reads the configuration
     *
     * @inheritdoc
     */
    public function __construct(array $config)
    {
        $this->db = new SQLiteDB('aichat', DOKU_PLUGIN . 'aichat/db/');
        // make the cosine similarity available inside SQL queries as COSIM(a, b)
        $this->db->getPdo()->sqliteCreateFunction('COSIM', $this->sqliteCosineSimilarityCallback(...), 2);

        $helper = plugin_load('helper', 'aichat');
        $this->useLanguageClusters = $helper->getConf('preferUIlanguage') >= AIChat::LANG_UI_LIMITED;

        // the threshold is configured as a percentage, similarities are compared as fractions
        $this->similarityThreshold = $config['similarityThreshold'] / 100;
    }

    /** @inheritdoc */
    public function getChunk($chunkID)
    {
        $record = $this->db->queryRecord('SELECT * FROM embeddings WHERE id = ?', [$chunkID]);
        if (!$record) return null;

        return new Chunk(
            $record['page'],
            $record['id'],
            $record['chunk'],
            json_decode((string) $record['embedding'], true, 512, JSON_THROW_ON_ERROR),
            $record['lang'],
            $record['created']
        );
    }

    /** @inheritdoc */
    public function startCreation($clear = false)
    {
        if ($clear) {
            /** @noinspection SqlWithoutWhere */
            $this->db->exec('DELETE FROM embeddings');
            /** @noinspection SqlWithoutWhere */
            $this->db->exec('DELETE FROM clusters');
        }
    }

    /** @inheritdoc */
    public function reusePageChunks($page, $firstChunkID)
    {
        // no-op
    }

    /** @inheritdoc */
    public function deletePageChunks($page, $firstChunkID)
    {
        // chunks are addressed by page name here, $firstChunkID is not needed
        $this->db->exec('DELETE FROM embeddings WHERE page = ?', [$page]);
    }

    /** @inheritdoc */
    public function addPageChunks($chunks)
    {
        foreach ($chunks as $chunk) {
            $this->db->saveRecord('embeddings', [
                'page' => $chunk->getPage(),
                'id' => $chunk->getId(),
                'chunk' => $chunk->getText(),
                'embedding' => json_encode($chunk->getEmbedding(), JSON_THROW_ON_ERROR),
                'created' => $chunk->getCreated(),
                'lang' => $chunk->getLanguage(),
            ]);
        }
    }

    /** @inheritdoc */
    public function finalizeCreation()
    {
        // clusters are only built once here; runMaintenance() rebuilds them
        if (!$this->hasClusters()) {
            $this->createClusters();
        }
        $this->setChunkClusters();

        // VACUUM fails inside an open transaction, so none may be pending at this point
        $this->db->exec('VACUUM');
    }

    /** @inheritdoc */
    public function runMaintenance()
    {
        $this->createClusters();
        $this->setChunkClusters();
    }

    /** @inheritdoc */
    public function getPageChunks($page, $firstChunkID)
    {
        $result = $this->db->queryAll(
            'SELECT * FROM embeddings WHERE page = ?',
            [$page]
        );
        $chunks = [];
        foreach ($result as $record) {
            $chunks[] = new Chunk(
                $record['page'],
                $record['id'],
                $record['chunk'],
                json_decode((string) $record['embedding'], true, 512, JSON_THROW_ON_ERROR),
                $record['lang'],
                $record['created']
            );
        }
        return $chunks;
    }

    /**
     * Only chunks in the cluster nearest to the given vector are compared
     *
     * @inheritdoc
     */
    public function getSimilarChunks($vector, $lang = '', $limit = 4)
    {
        $cluster = $this->getCluster($vector, $lang);
        if ($this->logger) $this->logger->info(
            'Using cluster {cluster} for similarity search',
            ['cluster' => $cluster]
        );

        // NOTE(review): GETACCESSLEVEL is not registered in this class; presumably it is
        // provided by the sqlite plugin - confirm
        $result = $this->db->queryAll(
            'SELECT *, COSIM(?, embedding) AS similarity
               FROM embeddings
              WHERE cluster = ?
                AND GETACCESSLEVEL(page) > 0
                AND similarity > CAST(? AS FLOAT)
           ORDER BY similarity DESC
              LIMIT ?',
            [json_encode($vector, JSON_THROW_ON_ERROR), $cluster, $this->similarityThreshold, $limit]
        );
        $chunks = [];
        foreach ($result as $record) {
            $chunks[] = new Chunk(
                $record['page'],
                $record['id'],
                $record['chunk'],
                json_decode((string) $record['embedding'], true, 512, JSON_THROW_ON_ERROR),
                $record['lang'],
                $record['created'],
                $record['similarity']
            );
        }
        return $chunks;
    }

    /** @inheritdoc */
    public function statistics()
    {
        $items = $this->db->queryValue('SELECT COUNT(*) FROM embeddings');
        $size = $this->db->queryValue(
            'SELECT page_count * page_size as size FROM pragma_page_count(), pragma_page_size()'
        );
        $query = "SELECT cluster || ' ' || lang, COUNT(*) || ' chunks' as cnt
                    FROM embeddings
                GROUP BY cluster
                ORDER BY cluster";
        $clusters = $this->db->queryKeyValueList($query);

        return [
            'storage type' => 'SQLite',
            'chunks' => $items,
            'db size' => filesize_h($size),
            'clusters' => $clusters,
        ];
    }

    /**
     * Method registered as SQLite callback to calculate the cosine similarity
     *
     * @param string $query JSON encoded vector array
     * @param string $embedding JSON encoded vector array
     * @return float
     * @throws \JsonException when one of the arguments is not valid JSON
     */
    public function sqliteCosineSimilarityCallback($query, $embedding)
    {
        return (float)$this->cosineSimilarity(
            json_decode($query, true, 512, JSON_THROW_ON_ERROR),
            json_decode($embedding, true, 512, JSON_THROW_ON_ERROR)
        );
    }

    /**
     * Calculate the cosine similarity between two vectors
     *
     * Actually just calculating the dot product of the two vectors, since they are normalized
     *
     * @param float[] $queryVector The normalized vector of the search phrase
     * @param float[] $embedding The normalized vector of the chunk
     * @return float
     */
    protected function cosineSimilarity($queryVector, $embedding)
    {
        $dotProduct = 0;
        foreach ($queryVector as $key => $value) {
            $dotProduct += $value * $embedding[$key];
        }
        return $dotProduct;
    }

    /**
     * Create new clusters based on random chunks
     *
     * With language clusters enabled, each language found in the embeddings is
     * clustered separately, otherwise all chunks share one set of clusters.
     *
     * @return void
     */
    protected function createClusters()
    {
        if ($this->useLanguageClusters) {
            $result = $this->db->queryAll('SELECT DISTINCT lang FROM embeddings');
            $langs = array_column($result, 'lang');
            foreach ($langs as $lang) {
                $this->createLanguageClusters($lang);
            }
        } else {
            $this->createLanguageClusters('');
        }
    }

    /**
     * Create new clusters based on random chunks for the given Language
     *
     * Existing clusters and cluster assignments in scope are removed first. The whole
     * operation runs in one transaction which is always finished (committed or rolled
     * back) before this method returns.
     *
     * @param string $lang The language to cluster, empty when all languages go into the same cluster
     * @return void
     * @throws \RuntimeException when clustering fails, the transaction is rolled back then
     * @noinspection SqlWithoutWhere
     */
    protected function createLanguageClusters($lang)
    {
        if ($lang != '') {
            $where = 'WHERE lang = ' . $this->db->getPdo()->quote($lang);
        } else {
            $where = '';
        }

        if ($this->logger) $this->logger->info('Creating new {lang} clusters...', ['lang' => $lang]);
        $this->db->getPdo()->beginTransaction();
        try {
            // clean up old cluster data
            $query = "DELETE FROM clusters $where";
            $this->db->exec($query);
            $query = "UPDATE embeddings SET cluster = NULL $where";
            $this->db->exec($query);

            // get a random selection of chunks
            $query = "SELECT id, embedding FROM embeddings $where ORDER BY RANDOM() LIMIT ?";
            $result = $this->db->queryAll($query, [self::SAMPLE_SIZE]);
            if (!$result) {
                // no data to cluster: still finish the transaction, an open one would break
                // the later VACUUM and any further beginTransaction() call
                $this->db->getPdo()->commit();
                return;
            }
            $dimensions = count(json_decode((string) $result[0]['embedding'], true, 512, JSON_THROW_ON_ERROR));

            // how many clusters?
            if (count($result) < self::CLUSTER_SIZE * 3) {
                // there would be fewer than 3 clusters, so just use one
                $clustercount = 1;
            } else {
                // get the number of all chunks, to calculate the number of clusters
                $query = "SELECT COUNT(*) FROM embeddings $where";
                $total = $this->db->queryValue($query);
                // ceil() returns a float, the cluster count is an integer
                $clustercount = (int) ceil($total / self::CLUSTER_SIZE);
            }
            if ($this->logger) $this->logger->info('Creating {clusters} clusters', ['clusters' => $clustercount]);

            // cluster them using kmeans
            $space = new Space($dimensions);
            foreach ($result as $record) {
                $space->addPoint(json_decode((string) $record['embedding'], true, 512, JSON_THROW_ON_ERROR));
            }
            $clusters = $space->solve($clustercount, function ($space, $clusters) {
                // progress callback, only used to log the cluster sizes per iteration
                static $iterations = 0;
                ++$iterations;
                if ($this->logger) {
                    $clustercounts = implode(',', array_map('count', $clusters));
                    $this->logger->info('Iteration {iteration}: [{clusters}]', [
                        'iteration' => $iterations, 'clusters' => $clustercounts
                    ]);
                }
            }, Cluster::INIT_KMEANS_PLUS_PLUS);

            // store the clusters
            foreach ($clusters as $cluster) {
                /** @var Cluster $cluster */
                $centroid = $cluster->getCoordinates();
                $query = 'INSERT INTO clusters (lang, centroid) VALUES (?, ?)';
                $this->db->exec($query, [$lang, json_encode($centroid, JSON_THROW_ON_ERROR)]);
            }

            $this->db->getPdo()->commit();
            if ($this->logger) $this->logger->success('Created {clusters} clusters', ['clusters' => count($clusters)]);
        } catch (\Exception $e) {
            $this->db->getPdo()->rollBack();
            throw new \RuntimeException('Clustering failed: ' . $e->getMessage(), 0, $e);
        }
    }

    /**
     * Assign the nearest cluster for all chunks that don't have one
     *
     * @return void
     */
    protected function setChunkClusters()
    {
        if ($this->logger) $this->logger->info('Assigning clusters to chunks...');
        $query = 'SELECT id, embedding, lang FROM embeddings WHERE cluster IS NULL';
        $handle = $this->db->query($query);

        while ($record = $handle->fetch(\PDO::FETCH_ASSOC)) {
            $vector = json_decode((string) $record['embedding'], true, 512, JSON_THROW_ON_ERROR);
            // without language clusters, all chunks are matched against all clusters
            $cluster = $this->getCluster($vector, $this->useLanguageClusters ? $record['lang'] : '');
            $query = 'UPDATE embeddings SET cluster = ? WHERE id = ?';
            $this->db->exec($query, [$cluster, $record['id']]);
            if ($this->logger) $this->logger->success(
                'Chunk {id} assigned to cluster {cluster}',
                ['id' => $record['id'], 'cluster' => $cluster]
            );
        }
        $handle->closeCursor();
    }

    /**
     * Get the nearest cluster for the given vector
     *
     * @param float[] $vector
     * @param string $lang Only consider clusters of this language, empty to consider all clusters
     * @return int|null The cluster ID, null when no cluster is available
     */
    protected function getCluster($vector, $lang)
    {
        if ($lang != '') {
            $where = 'WHERE lang = ' . $this->db->getPdo()->quote($lang);
        } else {
            $where = '';
        }

        $query = "SELECT cluster, centroid
                    FROM clusters
                   $where
                ORDER BY COSIM(centroid, ?) DESC
                   LIMIT 1";

        $result = $this->db->queryRecord($query, [json_encode($vector, JSON_THROW_ON_ERROR)]);
        if (!$result) return null;
        return $result['cluster'];
    }

    /**
     * Check if clustering has been done before
     * @return bool
     */
    protected function hasClusters()
    {
        $query = 'SELECT COUNT(*) FROM clusters';
        return $this->db->queryValue($query) > 0;
    }

    /**
     * Writes TSV files for visualizing with http://projector.tensorflow.org/
     *
     * Both files are appended to, existing content is kept.
     *
     * @param string $vectorfile path to the file with the vectors
     * @param string $metafile path to the file with the metadata
     * @return void
     */
    public function dumpTSV($vectorfile, $metafile)
    {
        $query = 'SELECT * FROM embeddings';
        $handle = $this->db->query($query);

        // only the metadata file gets a header line
        $header = implode("\t", ['id', 'page', 'created']);
        file_put_contents($metafile, $header . "\n", FILE_APPEND);

        while ($row = $handle->fetch(\PDO::FETCH_ASSOC)) {
            $vector = json_decode((string) $row['embedding'], true, 512, JSON_THROW_ON_ERROR);
            $vector = implode("\t", $vector);

            $meta = implode("\t", [$row['id'], $row['page'], $row['created']]);

            file_put_contents($vectorfile, $vector . "\n", FILE_APPEND);
            file_put_contents($metafile, $meta . "\n", FILE_APPEND);
        }
        // release the statement, same as setChunkClusters() does
        $handle->closeCursor();
    }
}