xref: /plugin/aichat/helper.php (revision 04afb84f6cb8a0c9b1d4d807e18f90fe739ec371)
1<?php
2
3use dokuwiki\Extension\CLIPlugin;
4use dokuwiki\Extension\Plugin;
5use dokuwiki\plugin\aichat\AIChat;
6use dokuwiki\plugin\aichat\Chunk;
7use dokuwiki\plugin\aichat\Embeddings;
8use dokuwiki\plugin\aichat\Model\ChatInterface;
9use dokuwiki\plugin\aichat\Model\EmbeddingInterface;
10use dokuwiki\plugin\aichat\Model\OpenAI\Embedding3Small;
11use dokuwiki\plugin\aichat\Storage\AbstractStorage;
12
/**
 * DokuWiki Plugin aichat (Helper Component)
 *
 * Lazily creates the chat model, embedding model, storage backend and embeddings
 * interface from the plugin configuration, builds the prompts and message lists
 * sent to the chat model, and persists meta data about the last run.
 *
 * @license GPL 2 http://www.gnu.org/licenses/gpl-2.0.html
 * @author  Andreas Gohr <gohr@cosmocode.de>
 */
class helper_plugin_aichat extends Plugin
{
    /** @var CLIPlugin|null $logger optional CLI plugin used for logging, see setLogger() */
    protected $logger;
    /** @var ChatInterface|null lazily created by getChatModel() */
    protected $chatModel;
    /** @var EmbeddingInterface|null lazily created by getEmbedModel() */
    protected $embedModel;
    /** @var Embeddings|null lazily created by getEmbeddings() */
    protected $embeddings;
    /** @var AbstractStorage|null lazily created by getStorage() */
    protected $storage;

    /** @var string path of the JSON file where meta data on the last run is stored */
    protected $runDataFile;

    /**
     * Constructor. Initializes vendor autoloader
     *
     * Also determines the location of the run data file and loads the plugin configuration.
     */
    public function __construct()
    {
        require_once __DIR__ . '/vendor/autoload.php'; // FIXME obsolete from Kaos onwards
        global $conf;
        $this->runDataFile = $conf['metadir'] . '/aichat__run.json';
        $this->loadConfig();
    }

    /**
     * Use the given CLI plugin for logging
     *
     * The logger is handed on to the storage and embeddings objects when they are created
     * afterwards; objects that were already created are not updated.
     *
     * @param CLIPlugin $logger
     * @return void
     */
    public function setLogger($logger)
    {
        $this->logger = $logger;
    }

    /**
     * Check if the current user is allowed to use the plugin (if it has been restricted)
     *
     * Access is granted when no auth backend is active or when the 'restrict' option is empty.
     * Otherwise the current user has to match the configured restriction.
     *
     * @return bool
     */
    public function userMayAccess()
    {
        global $auth;
        global $USERINFO;
        global $INPUT;

        if (!$auth) return true;
        if (!$this->getConf('restrict')) return true;
        if (!isset($USERINFO)) return false;

        // a user info array without group data is treated as "member of no group"
        $groups = (array)($USERINFO['grps'] ?? []);

        return auth_isMember($this->getConf('restrict'), $INPUT->server->str('REMOTE_USER'), $groups);
    }

    /**
     * Access the Chat Model
     *
     * The 'chatmodel' option is split at the first space into a namespace and a model name.
     * The instance is created on first use and reused afterwards.
     *
     * @return ChatInterface
     * @throws \RuntimeException when no ChatModel class exists for the configured namespace
     */
    public function getChatModel()
    {
        if ($this->chatModel instanceof ChatInterface) {
            return $this->chatModel;
        }

        [$namespace, $name] = sexplode(' ', $this->getConf('chatmodel'), 2);
        $class = '\\dokuwiki\\plugin\\aichat\\Model\\' . $namespace . '\\ChatModel';

        if (!class_exists($class)) {
            throw new \RuntimeException('No ChatModel found for ' . $namespace);
        }

        $this->chatModel = new $class($name, $this->conf);
        return $this->chatModel;
    }

    /**
     * Access the Embedding Model
     *
     * The 'embedmodel' option is split at the first space into a namespace and a model name.
     * The instance is created on first use and reused afterwards.
     *
     * @return EmbeddingInterface
     * @throws \RuntimeException when no EmbeddingModel class exists for the configured namespace
     */
    public function getEmbedModel()
    {
        if ($this->embedModel instanceof EmbeddingInterface) {
            return $this->embedModel;
        }

        [$namespace, $name] = sexplode(' ', $this->getConf('embedmodel'), 2);
        $class = '\\dokuwiki\\plugin\\aichat\\Model\\' . $namespace . '\\EmbeddingModel';

        if (!class_exists($class)) {
            throw new \RuntimeException('No EmbeddingModel found for ' . $namespace);
        }

        $this->embedModel = new $class($name, $this->conf);
        return $this->embedModel;
    }


    /**
     * Access the Embeddings interface
     *
     * Created on first use from the chat model, the embedding model, the storage and the
     * plugin configuration. A previously set logger is passed on.
     *
     * @return Embeddings
     */
    public function getEmbeddings()
    {
        if ($this->embeddings instanceof Embeddings) {
            return $this->embeddings;
        }

        $this->embeddings = new Embeddings(
            $this->getChatModel(),
            $this->getEmbedModel(),
            $this->getStorage(),
            $this->conf
        );
        if ($this->logger) {
            $this->embeddings->setLogger($this->logger);
        }

        return $this->embeddings;
    }

    /**
     * Access the Storage interface
     *
     * The class is derived from the 'storage' option. The instance is created on first use
     * and reused afterwards. A previously set logger is passed on.
     *
     * @return AbstractStorage
     * @throws \RuntimeException when no Storage class exists for the configured name
     */
    public function getStorage()
    {
        if ($this->storage instanceof AbstractStorage) {
            return $this->storage;
        }

        $name = $this->getConf('storage');
        $class = '\\dokuwiki\\plugin\\aichat\\Storage\\' . $name . 'Storage';

        // fail the same way the model factories do instead of raising an engine Error
        if (!class_exists($class)) {
            throw new \RuntimeException('No Storage found for ' . $name);
        }

        $this->storage = new $class($this->conf);

        if ($this->logger) {
            $this->storage->setLogger($this->logger);
        }

        return $this->storage;
    }

    /**
     * Ask a question with a chat history
     *
     * When a history is given, the question is first rephrased into a standalone question,
     * which is then asked together with the history.
     *
     * @param string $question
     * @param array[] $history The chat history [[user, ai], [user, ai], ...]
     * @return array ['question' => $question, 'answer' => $answer, 'sources' => $sources]
     * @throws Exception
     */
    public function askChatQuestion($question, $history = [])
    {
        if ($history) {
            $standaloneQuestion = $this->rephraseChatQuestion($question, $history);
        } else {
            $standaloneQuestion = $question;
        }
        return $this->askQuestion($standaloneQuestion, $history);
    }

    /**
     * Ask a single standalone question
     *
     * Chunks similar to the question are looked up and passed as context in the 'question'
     * prompt. When no similar chunks are found, the 'noanswer' prompt is used and the
     * history is dropped.
     *
     * @param string $question
     * @param array[] $history The chat history [[user, ai], [user, ai], ...]
     * @return array ['question' => $question, 'answer' => $answer, 'sources' => $sources]
     * @throws Exception
     */
    public function askQuestion($question, $history = [])
    {
        $similar = $this->getEmbeddings()->getSimilarChunks($question, $this->getLanguageLimit());
        if ($similar) {
            // each chunk is wrapped in its own fenced block
            $context = implode(
                "\n",
                array_map(static fn(Chunk $chunk) => "\n```\n" . $chunk->getText() . "\n```\n", $similar)
            );
            $prompt = $this->getPrompt('question', [
                'context' => $context,
            ]);
        } else {
            $prompt = $this->getPrompt('noanswer');
            $history = [];
        }

        $messages = $this->prepareMessages($prompt, $question, $history);
        $answer = $this->getChatModel()->getAnswer($messages);

        return [
            'question' => $question,
            'answer' => $answer,
            'sources' => $similar,
        ];
    }

    /**
     * Rephrase a question into a standalone question based on the chat history
     *
     * @param string $question The original user question
     * @param array[] $history The chat history [[user, ai], [user, ai], ...]
     * @return string The rephrased question
     * @throws Exception
     */
    public function rephraseChatQuestion($question, $history)
    {
        $prompt = $this->getPrompt('rephrase');
        $messages = $this->prepareMessages($prompt, $question, $history);
        return $this->getChatModel()->getAnswer($messages);
    }

    /**
     * Prepare the messages for the AI
     *
     * The resulting list contains as much of the history as fits into the remaining token
     * budget, followed by the system prompt and finally the user question.
     *
     * @param string $prompt The fully prepared system prompt
     * @param string $question The user question
     * @param array[] $history The chat history [[user, ai], [user, ai], ...]
     * @return array An OpenAI compatible array of messages
     */
    protected function prepareMessages($prompt, $question, $history)
    {
        // calculate the space for context
        $remainingContext = $this->getChatModel()->getMaxInputTokenLength();
        $remainingContext -= $this->countTokens($prompt);
        $remainingContext -= $this->countTokens($question);
        $safetyMargin = $remainingContext * 0.05; // 5% safety margin
        $remainingContext -= $safetyMargin;
        // FIXME we may want to also have an upper limit for the history and not always use the full context

        $messages = $this->historyMessages($history, $remainingContext);
        $messages[] = [
            'role' => 'system',
            'content' => $prompt
        ];
        $messages[] = [
            'role' => 'user',
            'content' => $question
        ];
        return $messages;
    }

    /**
     * Create an array of OpenAI compatible messages from the given history
     *
     * Only as many messages are used as fit into the token limit. The newest exchanges are
     * preferred; the returned messages are in chronological order.
     *
     * @param array[] $history The chat history [[user, ai], [user, ai], ...]
     * @param int|float $tokenLimit
     * @return array
     */
    protected function historyMessages($history, $tokenLimit)
    {
        $remainingContext = $tokenLimit;

        $messages = [];
        // walk from the newest exchange to the oldest
        $history = array_reverse($history);
        foreach ($history as $row) {
            $length = $this->countTokens($row[0] . $row[1]);
            if ($length > $remainingContext) {
                break;
            }
            $remainingContext -= $length;

            // added answer first, the final array_reverse() restores the user -> assistant order
            $messages[] = [
                'role' => 'assistant',
                'content' => $row[1]
            ];
            $messages[] = [
                'role' => 'user',
                'content' => $row[0]
            ];
        }
        return array_reverse($messages);
    }

    /**
     * Get an approximation of the token count for the given text
     *
     * @param string $text
     * @return int
     */
    protected function countTokens($text)
    {
        return count($this->getEmbeddings()->getTokenEncoder()->encode($text));
    }

    /**
     * Load the given prompt template and fill in the variables
     *
     * Each variable replaces the placeholder {{NAME}} where NAME is the upper cased key.
     * The 'language' variable is always set from getLanguagePrompt().
     *
     * @param string $type
     * @param string[] $vars
     * @return string
     */
    protected function getPrompt($type, $vars = [])
    {
        $template = file_get_contents($this->localFN('prompt_' . $type));
        $vars['language'] = $this->getLanguagePrompt();

        $replace = [];
        foreach ($vars as $key => $val) {
            $replace['{{' . strtoupper($key) . '}}'] = $val;
        }

        return strtr($template, $replace);
    }

    /**
     * Construct the prompt to define the answer language
     *
     * When the 'preferUIlanguage' option is above AIChat::LANG_AUTO_ALL and the configured
     * wiki language is known, answers are requested in that language. Otherwise the user's
     * language is requested, with the wiki language (or English) as fallback.
     *
     * @return string
     */
    protected function getLanguagePrompt()
    {
        global $conf;
        $isoLangnames = include(__DIR__ . '/lang/languages.php');

        $currentLang = $isoLangnames[$conf['lang']] ?? 'English';

        if ($this->getConf('preferUIlanguage') > AIChat::LANG_AUTO_ALL) {
            if (isset($isoLangnames[$conf['lang']])) {
                $languagePrompt = 'Always answer in ' . $isoLangnames[$conf['lang']] . '.';
                return $languagePrompt;
            }
        }

        $languagePrompt = 'Always answer in the user\'s language. ' .
            "If you are unsure about the language, speak $currentLang.";
        return $languagePrompt;
    }

    /**
     * Should sources be limited to current language?
     *
     * @return string The current language code or empty string
     */
    public function getLanguageLimit()
    {
        if ($this->getConf('preferUIlanguage') >= AIChat::LANG_UI_LIMITED) {
            global $conf;
            return $conf['lang'];
        } else {
            return '';
        }
    }

    /**
     * Store info about the last run
     *
     * The data is written as pretty printed JSON to the run data file.
     *
     * @param array $data
     * @return void
     */
    public function setRunData(array $data)
    {
        file_put_contents($this->runDataFile, json_encode($data, JSON_PRETTY_PRINT));
    }

    /**
     * Get info about the last run
     *
     * @return array The stored data; empty when the file is missing, unreadable or does not hold a JSON array/object
     */
    public function getRunData()
    {
        if (!file_exists($this->runDataFile)) {
            return [];
        }

        $json = file_get_contents($this->runDataFile);
        if ($json === false) {
            return [];
        }

        // json_decode() yields null for empty or invalid input; always hand back an array
        $data = json_decode($json, true);
        return is_array($data) ? $data : [];
    }
}
390