<?php

/** @noinspection SqlResolve */

namespace dokuwiki\plugin\aichat\Storage;

use dokuwiki\plugin\aichat\AIChat;
use dokuwiki\plugin\aichat\Chunk;
use dokuwiki\plugin\sqlite\SQLiteDB;
use KMeans\Cluster;
use KMeans\Space;

/**
 * Implements the storage backend using a SQLite database
 *
 * Note: all embeddings are stored and returned as normalized vectors
 */
class SQLiteStorage extends AbstractStorage
{
    /** @var float minimum similarity to consider a chunk a match */
    final public const SIMILARITY_THRESHOLD = 0;

    /** @var int Number of documents to randomly sample to create the clusters */
    final public const SAMPLE_SIZE = 2000;
    /** @var int The average size of each cluster */
    final public const CLUSTER_SIZE = 400;

    /** @var SQLiteDB */
    protected $db;

    /** @var bool whether clusters are built and searched separately per chunk language */
    protected $useLanguageClusters = false;

    /**
     * Open the aichat SQLite database and register the COSIM() SQL function
     *
     * @inheritdoc
     */
    public function __construct(array $config)
    {
        $this->db = new SQLiteDB('aichat', DOKU_PLUGIN . 'aichat/db/');
        // COSIM(a, b) is used in queries to rank embeddings/centroids by similarity
        $this->db->getPdo()->sqliteCreateFunction('COSIM', $this->sqliteCosineSimilarityCallback(...), 2);

        $helper = plugin_load('helper', 'aichat');
        $this->useLanguageClusters = $helper->getConf('preferUIlanguage') >= AIChat::LANG_UI_LIMITED;
    }

    /** @inheritdoc */
    public function getChunk($chunkID)
    {
        $record = $this->db->queryRecord('SELECT * FROM embeddings WHERE id = ?', [$chunkID]);
        if (!$record) return null;

        return new Chunk(
            $record['page'],
            $record['id'],
            $record['chunk'],
            json_decode((string) $record['embedding'], true, 512, JSON_THROW_ON_ERROR),
            $record['lang'],
            $record['created']
        );
    }

    /** @inheritdoc */
    public function startCreation($clear = false)
    {
        if ($clear) {
            /** @noinspection SqlWithoutWhere */
            $this->db->exec('DELETE FROM embeddings');
        }
    }

    /** @inheritdoc */
    public function reusePageChunks($page, $firstChunkID)
    {
        // no-op: chunks stay in the table until explicitly deleted
    }

    /** @inheritdoc */
    public function deletePageChunks($page, $firstChunkID)
    {
        $this->db->exec('DELETE FROM embeddings WHERE page = ?', [$page]);
    }

    /** @inheritdoc */
    public function addPageChunks($chunks)
    {
        foreach ($chunks as $chunk) {
            $this->db->saveRecord('embeddings', [
                'page' => $chunk->getPage(),
                'id' => $chunk->getId(),
                'chunk' => $chunk->getText(),
                'embedding' => json_encode($chunk->getEmbedding(), JSON_THROW_ON_ERROR),
                'created' => $chunk->getCreated(),
                'lang' => $chunk->getLanguage(),
            ]);
        }
    }

    /**
     * Create clusters if none exist yet, assign unclustered chunks and compact the database
     *
     * @inheritdoc
     */
    public function finalizeCreation()
    {
        if (!$this->hasClusters()) {
            $this->createClusters();
        }
        $this->setChunkClusters();

        // VACUUM cannot run inside an open transaction
        $this->db->exec('VACUUM');
    }

    /**
     * Rebuild all clusters from scratch and reassign all chunks
     *
     * @inheritdoc
     */
    public function runMaintenance()
    {
        $this->createClusters();
        $this->setChunkClusters();
    }

    /** @inheritdoc */
    public function getPageChunks($page, $firstChunkID)
    {
        $result = $this->db->queryAll(
            'SELECT * FROM embeddings WHERE page = ?',
            [$page]
        );
        $chunks = [];
        foreach ($result as $record) {
            $chunks[] = new Chunk(
                $record['page'],
                $record['id'],
                $record['chunk'],
                json_decode((string) $record['embedding'], true, 512, JSON_THROW_ON_ERROR),
                $record['lang'],
                $record['created']
            );
        }
        return $chunks;
    }

    /**
     * Search only the nearest cluster for chunks similar to the given vector
     *
     * Chunks on pages the current user cannot read (GETACCESSLEVEL <= 0) are excluded.
     *
     * @inheritdoc
     */
    public function getSimilarChunks($vector, $lang = '', $limit = 4)
    {
        $cluster = $this->getCluster($vector, $lang);
        if ($this->logger) $this->logger->info(
            'Using cluster {cluster} for similarity search',
            ['cluster' => $cluster]
        );

        $result = $this->db->queryAll(
            'SELECT *, COSIM(?, embedding) AS similarity
               FROM embeddings
              WHERE cluster = ?
                AND GETACCESSLEVEL(page) > 0
                AND similarity > CAST(? AS FLOAT)
           ORDER BY similarity DESC
              LIMIT ?',
            [json_encode($vector, JSON_THROW_ON_ERROR), $cluster, self::SIMILARITY_THRESHOLD, $limit]
        );
        $chunks = [];
        foreach ($result as $record) {
            $chunks[] = new Chunk(
                $record['page'],
                $record['id'],
                $record['chunk'],
                json_decode((string) $record['embedding'], true, 512, JSON_THROW_ON_ERROR),
                $record['lang'],
                $record['created'],
                $record['similarity']
            );
        }
        return $chunks;
    }

    /** @inheritdoc */
    public function statistics()
    {
        $items = $this->db->queryValue('SELECT COUNT(*) FROM embeddings');
        $size = $this->db->queryValue(
            'SELECT page_count * page_size as size FROM pragma_page_count(), pragma_page_size()'
        );
        // lang is not aggregated; clusters are created per language (or for all languages with
        // an empty lang), so it is presumably constant within a cluster
        $query = "SELECT cluster || ' ' || lang, COUNT(*) || ' chunks' as cnt
                    FROM embeddings
                GROUP BY cluster
                ORDER BY cluster";
        $clusters = $this->db->queryKeyValueList($query);

        return [
            'storage type' => 'SQLite',
            'chunks' => $items,
            'db size' => filesize_h($size),
            'clusters' => $clusters,
        ];
    }

    /**
     * Method registered as SQLite callback to calculate the cosine similarity
     *
     * @param string $query JSON encoded vector array
     * @param string $embedding JSON encoded vector array
     * @return float
     * @throws \JsonException when either argument is not valid JSON
     */
    public function sqliteCosineSimilarityCallback($query, $embedding)
    {
        return (float)$this->cosineSimilarity(
            json_decode($query, true, 512, JSON_THROW_ON_ERROR),
            json_decode($embedding, true, 512, JSON_THROW_ON_ERROR)
        );
    }

    /**
     * Calculate the cosine similarity between two vectors
     *
     * Actually just calculating the dot product of the two vectors, since they are normalized
     *
     * @param float[] $queryVector The normalized vector of the search phrase
     * @param float[] $embedding The normalized vector of the chunk
     * @return float
     */
    protected function cosineSimilarity($queryVector, $embedding)
    {
        $dotProduct = 0;
        foreach ($queryVector as $key => $value) {
            $dotProduct += $value * $embedding[$key];
        }
        return $dotProduct;
    }

    /**
     * Create new clusters based on random chunks
     *
     * With language clusters enabled, a separate set of clusters is built for every
     * language present in the embeddings table; otherwise one set for all chunks.
     *
     * @return void
     */
    protected function createClusters()
    {
        if ($this->useLanguageClusters) {
            $result = $this->db->queryAll('SELECT DISTINCT lang FROM embeddings');
            $langs = array_column($result, 'lang');
            foreach ($langs as $lang) {
                $this->createLanguageClusters($lang);
            }
        } else {
            $this->createLanguageClusters('');
        }
    }

    /**
     * Create new clusters based on random chunks for the given Language
     *
     * Existing clusters (and chunk assignments) for the language are dropped first. All
     * changes happen in one transaction which is always either committed or rolled back.
     *
     * @param string $lang The language to cluster, empty when all languages go into the same cluster
     * @throws \RuntimeException when clustering fails (wrapping the original exception)
     * @noinspection SqlWithoutWhere
     */
    protected function createLanguageClusters($lang)
    {
        if ($lang != '') {
            $where = 'WHERE lang = ' . $this->db->getPdo()->quote($lang);
        } else {
            $where = '';
        }

        if ($this->logger) $this->logger->info('Creating new {lang} clusters...', ['lang' => $lang]);
        $pdo = $this->db->getPdo();
        $pdo->beginTransaction();
        try {
            // clean up old cluster data
            $query = "DELETE FROM clusters $where";
            $this->db->exec($query);
            $query = "UPDATE embeddings SET cluster = NULL $where";
            $this->db->exec($query);

            // get a random selection of chunks
            $query = "SELECT id, embedding FROM embeddings $where ORDER BY RANDOM() LIMIT ?";
            $result = $this->db->queryAll($query, [self::SAMPLE_SIZE]);
            if (!$result) {
                // no data to cluster - still finish the transaction, otherwise it stays open
                // and later statements (e.g. VACUUM or another beginTransaction) fail
                $pdo->commit();
                return;
            }
            $dimensions = count(json_decode((string) $result[0]['embedding'], true, 512, JSON_THROW_ON_ERROR));

            // how many clusters?
            if (count($result) < self::CLUSTER_SIZE * 3) {
                // there would be less than 3 clusters, so just use one
                $clustercount = 1;
            } else {
                // get the number of all chunks, to calculate the number of clusters
                $query = "SELECT COUNT(*) FROM embeddings $where";
                $total = $this->db->queryValue($query);
                $clustercount = (int) ceil($total / self::CLUSTER_SIZE);
                // k-means cannot create more clusters than there are sampled points
                $clustercount = min($clustercount, count($result));
            }
            if ($this->logger) $this->logger->info('Creating {clusters} clusters', ['clusters' => $clustercount]);

            // cluster them using kmeans
            $space = new Space($dimensions);
            foreach ($result as $record) {
                $space->addPoint(json_decode((string) $record['embedding'], true, 512, JSON_THROW_ON_ERROR));
            }
            $clusters = $space->solve($clustercount, function ($space, $clusters) {
                static $iterations = 0;
                ++$iterations;
                if ($this->logger) {
                    $clustercounts = implode(',', array_map('count', $clusters));
                    $this->logger->info('Iteration {iteration}: [{clusters}]', [
                        'iteration' => $iterations, 'clusters' => $clustercounts
                    ]);
                }
            }, Cluster::INIT_KMEANS_PLUS_PLUS);

            // store the clusters
            foreach ($clusters as $cluster) {
                /** @var Cluster $cluster */
                $centroid = $cluster->getCoordinates();
                $query = 'INSERT INTO clusters (lang, centroid) VALUES (?, ?)';
                $this->db->exec($query, [$lang, json_encode($centroid, JSON_THROW_ON_ERROR)]);
            }

            $pdo->commit();
            if ($this->logger) $this->logger->success('Created {clusters} clusters', ['clusters' => count($clusters)]);
        } catch (\Exception $e) {
            // the failure may happen after commit (e.g. in the logger) - rolling back then would throw
            if ($pdo->inTransaction()) $pdo->rollBack();
            throw new \RuntimeException('Clustering failed: ' . $e->getMessage(), 0, $e);
        } catch (\Throwable $e) {
            // engine errors are rethrown unchanged, but must not leave the transaction open
            if ($pdo->inTransaction()) $pdo->rollBack();
            throw $e;
        }
    }

    /**
     * Assign the nearest cluster for all chunks that don't have one
     *
     * If no suitable cluster exists (getCluster() returns null), the chunk's cluster stays NULL.
     *
     * @return void
     */
    protected function setChunkClusters()
    {
        if ($this->logger) $this->logger->info('Assigning clusters to chunks...');
        $query = 'SELECT id, embedding, lang FROM embeddings WHERE cluster IS NULL';
        $handle = $this->db->query($query);

        while ($record = $handle->fetch(\PDO::FETCH_ASSOC)) {
            $vector = json_decode((string) $record['embedding'], true, 512, JSON_THROW_ON_ERROR);
            $cluster = $this->getCluster($vector, $this->useLanguageClusters ? $record['lang'] : '');
            $query = 'UPDATE embeddings SET cluster = ? WHERE id = ?';
            $this->db->exec($query, [$cluster, $record['id']]);
            if ($this->logger) $this->logger->success(
                'Chunk {id} assigned to cluster {cluster}',
                ['id' => $record['id'], 'cluster' => $cluster]
            );
        }
        $handle->closeCursor();
    }

    /**
     * Get the nearest cluster for the given vector
     *
     * @param float[] $vector The normalized vector to find a cluster for
     * @param string $lang Only consider clusters of this language, empty to consider all clusters
     * @return int|null The nearest cluster, null when no (matching) clusters exist
     */
    protected function getCluster($vector, $lang)
    {
        if ($lang != '') {
            $where = 'WHERE lang = ' . $this->db->getPdo()->quote($lang);
        } else {
            $where = '';
        }

        $query = "SELECT cluster, centroid
                    FROM clusters
                         $where
                ORDER BY COSIM(centroid, ?) DESC
                   LIMIT 1";

        $result = $this->db->queryRecord($query, [json_encode($vector, JSON_THROW_ON_ERROR)]);
        if (!$result) return null;
        return $result['cluster'];
    }

    /**
     * Check if clustering has been done before
     * @return bool
     */
    protected function hasClusters()
    {
        $query = 'SELECT COUNT(*) FROM clusters';
        return $this->db->queryValue($query) > 0;
    }

    /**
     * Writes TSV files for visualizing with http://projector.tensorflow.org/
     *
     * Note: both files are appended to, existing content is not removed.
     *
     * @param string $vectorfile path to the file with the vectors
     * @param string $metafile path to the file with the metadata
     * @return void
     */
    public function dumpTSV($vectorfile, $metafile)
    {
        $query = 'SELECT * FROM embeddings';
        $handle = $this->db->query($query);

        $header = implode("\t", ['id', 'page', 'created']);
        file_put_contents($metafile, $header . "\n", FILE_APPEND);

        while ($row = $handle->fetch(\PDO::FETCH_ASSOC)) {
            $vector = json_decode((string) $row['embedding'], true, 512, JSON_THROW_ON_ERROR);
            $vector = implode("\t", $vector);

            $meta = implode("\t", [$row['id'], $row['page'], $row['created']]);

            file_put_contents($vectorfile, $vector . "\n", FILE_APPEND);
            file_put_contents($metafile, $meta . "\n", FILE_APPEND);
        }
        $handle->closeCursor();
    }
}