<?php

/** @noinspection SqlResolve */

namespace dokuwiki\plugin\aichat\Storage;

use dokuwiki\plugin\aichat\AIChat;
use dokuwiki\plugin\aichat\Chunk;
use dokuwiki\plugin\sqlite\SQLiteDB;
use KMeans\Cluster;
use KMeans\Space;

/**
 * Implements the storage backend using a SQLite database
 *
 * Chunks live in the `embeddings` table, cluster centroids in the `clusters` table.
 * Similarity searches only look at the chunks of the cluster nearest to the query vector.
 *
 * Note: all embeddings are stored and returned as normalized vectors
 */
class SQLiteStorage extends AbstractStorage
{
    /** @var int Number of documents to randomly sample to create the clusters */
    final public const SAMPLE_SIZE = 2000;
    /** @var int The average size of each cluster */
    final public const CLUSTER_SIZE = 400;

    /** @var SQLiteDB */
    protected $db;

    /** @var bool Whether clusters are created and looked up separately per chunk language */
    protected $useLanguageClusters = false;

    /** @var float minimum similarity to consider a chunk a match */
    protected $similarityThreshold = 0;

    /** @inheritdoc */
    public function __construct(array $config)
    {
        $this->db = new SQLiteDB('aichat', DOKU_PLUGIN . 'aichat/db/');
        // make the similarity calculation available to SQL queries as COSIM(a, b)
        $this->db->getPdo()->sqliteCreateFunction('COSIM', $this->sqliteCosineSimilarityCallback(...), 2);

        $helper = plugin_load('helper', 'aichat');
        $this->useLanguageClusters = $helper->getConf('preferUIlanguage') >= AIChat::LANG_UI_LIMITED;

        // the configured value is scaled down by 100 (presumably it is given as a percentage);
        // a missing setting is treated as 0
        $this->similarityThreshold = ($config['similarityThreshold'] ?? 0) / 100;
    }

    /** @inheritdoc */
    public function getChunk($chunkID)
    {
        $record = $this->db->queryRecord('SELECT * FROM embeddings WHERE id = ?', [$chunkID]);
        if (!$record) return null;

        return $this->recordToChunk($record);
    }

    /** @inheritdoc */
    public function startCreation($clear = false)
    {
        if ($clear) {
            /** @noinspection SqlWithoutWhere */
            $this->db->exec('DELETE FROM embeddings');
        }
    }

    /** @inheritdoc */
    public function reusePageChunks($page, $firstChunkID)
    {
        // no-op
    }

    /** @inheritdoc */
    public function deletePageChunks($page, $firstChunkID)
    {
        // chunks are addressed by their page here, $firstChunkID is not needed
        $this->db->exec('DELETE FROM embeddings WHERE page = ?', [$page]);
    }

    /** @inheritdoc */
    public function addPageChunks($chunks)
    {
        foreach ($chunks as $chunk) {
            $this->db->saveRecord('embeddings', [
                'page' => $chunk->getPage(),
                'id' => $chunk->getId(),
                'chunk' => $chunk->getText(),
                'embedding' => json_encode($chunk->getEmbedding(), JSON_THROW_ON_ERROR),
                'created' => $chunk->getCreated(),
                'lang' => $chunk->getLanguage(),
            ]);
        }
    }

    /** @inheritdoc */
    public function finalizeCreation()
    {
        // existing clusters are kept, only chunks without a cluster get one assigned
        if (!$this->hasClusters()) {
            $this->createClusters();
        }
        $this->setChunkClusters();

        $this->db->exec('VACUUM');
    }

    /** @inheritdoc */
    public function runMaintenance()
    {
        // rebuild all clusters from scratch and reassign every chunk
        $this->createClusters();
        $this->setChunkClusters();
    }

    /** @inheritdoc */
    public function getPageChunks($page, $firstChunkID)
    {
        $result = $this->db->queryAll(
            'SELECT * FROM embeddings WHERE page = ?',
            [$page]
        );
        $chunks = [];
        foreach ($result as $record) {
            $chunks[] = $this->recordToChunk($record);
        }
        return $chunks;
    }

    /** @inheritdoc */
    public function getSimilarChunks($vector, $lang = '', $limit = 4)
    {
        // only the chunks of the nearest cluster are compared against the query vector
        $cluster = $this->getCluster($vector, $lang);
        if ($this->logger) $this->logger->info(
            'Using cluster {cluster} for similarity search',
            ['cluster' => $cluster]
        );

        // GETACCESSLEVEL is not registered in this file - presumably provided by the sqlite plugin
        $result = $this->db->queryAll(
            'SELECT *, COSIM(?, embedding) AS similarity
               FROM embeddings
              WHERE cluster = ?
                AND GETACCESSLEVEL(page) > 0
                AND similarity > CAST(? AS FLOAT)
           ORDER BY similarity DESC
              LIMIT ?',
            [json_encode($vector, JSON_THROW_ON_ERROR), $cluster, $this->similarityThreshold, $limit]
        );
        $chunks = [];
        foreach ($result as $record) {
            $chunks[] = $this->recordToChunk($record, true);
        }
        return $chunks;
    }

    /** @inheritdoc */
    public function statistics()
    {
        $items = $this->db->queryValue('SELECT COUNT(*) FROM embeddings');
        $size = $this->db->queryValue(
            'SELECT page_count * page_size as size FROM pragma_page_count(), pragma_page_size()'
        );
        $query = "SELECT cluster || ' ' || lang, COUNT(*) || ' chunks' as cnt
                    FROM embeddings
                GROUP BY cluster
                ORDER BY cluster";
        $clusters = $this->db->queryKeyValueList($query);

        return [
            'storage type' => 'SQLite',
            'chunks' => $items,
            'db size' => filesize_h($size),
            'clusters' => $clusters,
        ];
    }

    /**
     * Method registered as SQLite callback to calculate the cosine similarity
     *
     * @param string $query JSON encoded vector array
     * @param string $embedding JSON encoded vector array
     * @return float
     * @throws \JsonException when one of the vectors is not valid JSON
     */
    public function sqliteCosineSimilarityCallback($query, $embedding)
    {
        return (float)$this->cosineSimilarity(
            json_decode($query, true, 512, JSON_THROW_ON_ERROR),
            json_decode($embedding, true, 512, JSON_THROW_ON_ERROR)
        );
    }

    /**
     * Calculate the cosine similarity between two vectors
     *
     * Actually just calculating the dot product of the two vectors, since they are normalized
     *
     * @param float[] $queryVector The normalized vector of the search phrase
     * @param float[] $embedding The normalized vector of the chunk
     * @return float
     */
    protected function cosineSimilarity($queryVector, $embedding)
    {
        $dotProduct = 0;
        foreach ($queryVector as $key => $value) {
            $dotProduct += $value * $embedding[$key];
        }
        return $dotProduct;
    }

    /**
     * Build a Chunk object from a row of the embeddings table
     *
     * @param array $record The database row, the embedding is expected as JSON encoded vector
     * @param bool $withSimilarity Pass the row's `similarity` column as additional constructor argument
     * @return Chunk
     * @throws \JsonException when the stored embedding is not valid JSON
     */
    private function recordToChunk(array $record, $withSimilarity = false)
    {
        $arguments = [
            $record['page'],
            $record['id'],
            $record['chunk'],
            json_decode((string) $record['embedding'], true, 512, JSON_THROW_ON_ERROR),
            $record['lang'],
            $record['created'],
        ];
        if ($withSimilarity) {
            // only passed on when requested, otherwise the constructor's own default applies
            $arguments[] = $record['similarity'];
        }

        return new Chunk(...$arguments);
    }

    /**
     * Create new clusters based on random chunks
     *
     * Depending on the configuration, either one set of clusters per language or a single
     * set of clusters for all chunks is created.
     *
     * @return void
     */
    protected function createClusters()
    {
        if ($this->useLanguageClusters) {
            $result = $this->db->queryAll('SELECT DISTINCT lang FROM embeddings');
            $langs = array_column($result, 'lang');
            foreach ($langs as $lang) {
                $this->createLanguageClusters($lang);
            }
        } else {
            $this->createLanguageClusters('');
        }
    }

    /**
     * Create new clusters based on random chunks for the given Language
     *
     * All work happens in a single transaction: existing clusters are removed, the cluster
     * assignment of the affected chunks is reset and the newly calculated centroids are stored.
     *
     * @param string $lang The language to cluster, empty when all languages go into the same cluster
     * @return void
     * @throws \RuntimeException when the clustering fails, the transaction is rolled back then
     * @noinspection SqlWithoutWhere
     */
    protected function createLanguageClusters($lang)
    {
        if ($lang != '') {
            $where = 'WHERE lang = ' . $this->db->getPdo()->quote($lang);
        } else {
            $where = '';
        }

        if ($this->logger) $this->logger->info('Creating new {lang} clusters...', ['lang' => $lang]);
        $pdo = $this->db->getPdo();
        $pdo->beginTransaction();
        try {
            // clean up old cluster data
            $query = "DELETE FROM clusters $where";
            $this->db->exec($query);
            $query = "UPDATE embeddings SET cluster = NULL $where";
            $this->db->exec($query);

            // get a random selection of chunks
            $query = "SELECT id, embedding FROM embeddings $where ORDER BY RANDOM() LIMIT ?";
            $result = $this->db->queryAll($query, [self::SAMPLE_SIZE]);
            if (!$result) {
                // no data to cluster. The transaction still has to be closed, otherwise later
                // statements like VACUUM or another beginTransaction() would fail
                $pdo->commit();
                return;
            }
            $dimensions = count(json_decode((string) $result[0]['embedding'], true, 512, JSON_THROW_ON_ERROR));

            // how many clusters?
            if (count($result) < self::CLUSTER_SIZE * 3) {
                // there would be less than 3 clusters, so just use one
                $clustercount = 1;
            } else {
                // get the number of all chunks, to calculate the number of clusters
                $query = "SELECT COUNT(*) FROM embeddings $where";
                $total = $this->db->queryValue($query);
                $clustercount = (int) ceil($total / self::CLUSTER_SIZE);
                // never ask for more clusters than there are sampled points to build them from
                $clustercount = min($clustercount, count($result));
            }
            if ($this->logger) $this->logger->info('Creating {clusters} clusters', ['clusters' => $clustercount]);

            // cluster them using kmeans
            $space = new Space($dimensions);
            foreach ($result as $record) {
                $space->addPoint(json_decode((string) $record['embedding'], true, 512, JSON_THROW_ON_ERROR));
            }
            $iterations = 0; // counts the calls of the progress callback below
            $clusters = $space->solve($clustercount, function ($space, $clusters) use (&$iterations) {
                ++$iterations;
                if ($this->logger) {
                    $clustercounts = implode(',', array_map(count(...), $clusters));
                    $this->logger->info('Iteration {iteration}: [{clusters}]', [
                        'iteration' => $iterations, 'clusters' => $clustercounts
                    ]);
                }
            }, Cluster::INIT_KMEANS_PLUS_PLUS);

            // store the clusters
            foreach ($clusters as $cluster) {
                /** @var Cluster $cluster */
                $centroid = $cluster->getCoordinates();
                $query = 'INSERT INTO clusters (lang, centroid) VALUES (?, ?)';
                $this->db->exec($query, [$lang, json_encode($centroid, JSON_THROW_ON_ERROR)]);
            }

            $pdo->commit();
            if ($this->logger) $this->logger->success('Created {clusters} clusters', ['clusters' => count($clusters)]);
        } catch (\Exception $e) {
            // the transaction may already be closed when the failure happened during/after the commit
            if ($pdo->inTransaction()) {
                $pdo->rollBack();
            }
            throw new \RuntimeException('Clustering failed: ' . $e->getMessage(), 0, $e);
        }
    }

    /**
     * Assign the nearest cluster for all chunks that don't have one
     *
     * @return void
     */
    protected function setChunkClusters()
    {
        if ($this->logger) $this->logger->info('Assigning clusters to chunks...');
        $query = 'SELECT id, embedding, lang FROM embeddings WHERE cluster IS NULL';
        $handle = $this->db->query($query);

        while ($record = $handle->fetch(\PDO::FETCH_ASSOC)) {
            $vector = json_decode((string) $record['embedding'], true, 512, JSON_THROW_ON_ERROR);
            // the chunk's language only matters when clusters are kept per language
            $cluster = $this->getCluster($vector, $this->useLanguageClusters ? $record['lang'] : '');
            $query = 'UPDATE embeddings SET cluster = ? WHERE id = ?';
            $this->db->exec($query, [$cluster, $record['id']]);
            if ($this->logger) $this->logger->success(
                'Chunk {id} assigned to cluster {cluster}',
                ['id' => $record['id'], 'cluster' => $cluster]
            );
        }
        $handle->closeCursor();
    }

    /**
     * Get the nearest cluster for the given vector
     *
     * The nearest cluster is the one whose centroid has the highest COSIM value with the vector.
     *
     * @param float[] $vector
     * @param string $lang Only consider clusters of this language, empty to consider all clusters
     * @return int|null The cluster ID, null when there is no matching cluster
     */
    protected function getCluster($vector, $lang)
    {
        if ($lang != '') {
            $where = 'WHERE lang = ' . $this->db->getPdo()->quote($lang);
        } else {
            $where = '';
        }

        $query = "SELECT cluster, centroid
                    FROM clusters
                         $where
                ORDER BY COSIM(centroid, ?) DESC
                   LIMIT 1";

        $result = $this->db->queryRecord($query, [json_encode($vector, JSON_THROW_ON_ERROR)]);
        if (!$result) return null;
        return $result['cluster'];
    }

    /**
     * Check if clustering has been done before
     *
     * @return bool true when the clusters table has at least one row
     */
    protected function hasClusters()
    {
        $query = 'SELECT COUNT(*) FROM clusters';
        return $this->db->queryValue($query) > 0;
    }

    /**
     * Writes TSV files for visualizing with http://projector.tensorflow.org/
     *
     * Both files are appended to, so content of already existing files is kept.
     *
     * @param string $vectorfile path to the file with the vectors
     * @param string $metafile path to the file with the metadata
     * @return void
     */
    public function dumpTSV($vectorfile, $metafile)
    {
        $query = 'SELECT * FROM embeddings';
        $handle = $this->db->query($query);

        // only the metadata file gets a header line
        $header = implode("\t", ['id', 'page', 'created']);
        file_put_contents($metafile, $header . "\n", FILE_APPEND);

        while ($row = $handle->fetch(\PDO::FETCH_ASSOC)) {
            $vector = json_decode((string) $row['embedding'], true, 512, JSON_THROW_ON_ERROR);
            $vector = implode("\t", $vector);

            $meta = implode("\t", [$row['id'], $row['page'], $row['created']]);

            file_put_contents($vectorfile, $vector . "\n", FILE_APPEND);
            file_put_contents($metafile, $meta . "\n", FILE_APPEND);
        }
        $handle->closeCursor();
    }
}