xref: /plugin/aichat/cli/simulate.php (revision 0de7e020fcc340c97acd36e48cdb20a9d43528b6)
1<?php
2
3use dokuwiki\Extension\CLIPlugin;
4use dokuwiki\plugin\aichat\AbstractCLI;
5use dokuwiki\plugin\aichat\ModelFactory;
6use splitbrain\phpcli\Colors;
7use splitbrain\phpcli\Options;
8
9/**
10 * DokuWiki Plugin aichat (CLI Component)
11 *
12 * @license GPL 2 http://www.gnu.org/licenses/gpl-2.0.html
13 * @author  Andreas Gohr <gohr@cosmocode.de>
14 */
class cli_plugin_aichat_simulate extends AbstractCLI
{
    /** @inheritDoc */
    protected function setup(Options $options)
    {
        parent::setup($options);

        $options->setHelp('Run a prepared chat session against multiple models');
        $options->registerArgument('input', 'A file with the chat questions. Each question separated by two newlines');
        $options->registerArgument('output', 'Where to write the result CSV to');

        $options->registerOption(
            'filter',
            'Use only models matching this case-insensitive regex (no delimiters)',
            'f',
            'regex'
        );
    }

    /**
     * Run every question of the input file against each selected chat model
     * and write the collected answers and usage stats to the output CSV.
     *
     * @inheritDoc
     * @throws \Exception when the input has no questions, the filter is invalid
     *                    or matches no model, or the output cannot be written
     */
    protected function main(Options $options)
    {
        parent::main($options);

        [$input, $output] = $options->getArgs();
        $questions = $this->readInputFile($input);
        if ($questions === []) {
            throw new \Exception("No questions found in $input");
        }

        // Select the models up front: the filter is compiled and validated only once
        // and an empty selection is reported before the output file gets truncated.
        $filter = (string) $options->getOpt('filter');
        $regex = ($filter === '') ? null : $this->buildFilterRegex($filter);

        $selected = [];
        foreach ($this->helper->factory->getModels(true, 'chat') as $name => $info) {
            if ($regex !== null && preg_match($regex, (string) $name) !== 1) {
                continue;
            }
            $selected[$name] = $info;
        }
        if ($selected === []) {
            throw new \Exception('No chat models available' . ($filter === '' ? '' : " matching '$filter'"));
        }

        $outFH = @fopen($output, 'w');
        if (!$outFH) throw new \Exception("Could not open $output for writing");

        try {
            $results = [];
            foreach ($selected as $name => $info) {
                $this->success("Running on $name...");
                $results[$name] = $this->simulate($questions, $info);
            }

            foreach ($this->records2rows($results) as $row) {
                if (fputcsv($outFH, $row) === false) {
                    throw new \Exception("Could not write to $output");
                }
            }
        } finally {
            // always release the handle, even when a model call fails half way through
            fclose($outFH);
        }

        $this->success("Results written to $output");
    }

    /**
     * Wrap the user supplied filter into a complete, case-insensitive PCRE pattern
     *
     * The delimiter is picked so that it does not occur inside the filter. This way
     * filters containing a slash work without the user having to escape anything.
     *
     * @param string $filter The regex as given on the command line (no delimiters)
     * @return string The delimited pattern including the `i` modifier
     * @throws \Exception when the filter is not a valid regular expression
     */
    private function buildFilterRegex(string $filter): string
    {
        foreach (['/', '#', '~', '%', '!', '@', ';'] as $delimiter) {
            if (str_contains($filter, $delimiter)) {
                continue;
            }

            $regex = $delimiter . $filter . $delimiter . 'i';
            // preg_match() emits a warning and returns false on a pattern that does not compile
            if (@preg_match($regex, '') === false) {
                throw new \Exception("Invalid filter regex: $filter");
            }
            return $regex;
        }

        throw new \Exception("Invalid filter regex (no usable delimiter): $filter");
    }

    /**
     * Ask all questions, one after the other, using the given model
     *
     * The model is installed as chat model on the factory and a clone of it as
     * rephrase model (presumably so both roles keep separate usage statistics).
     * The growing history of rephrased question/answer pairs is passed along
     * with every new question.
     *
     * @param string[] $questions The questions to ask, in order
     * @param array $model Model info as returned by the factory, the model object is expected in 'instance'
     * @return array[] One record per question: the texts, the source list and the flattened usage stats
     */
    protected function simulate($questions, $model)
    {
        // override models
        $this->helper->factory->chatModel = $model['instance'];
        $this->helper->factory->rephraseModel = clone $model['instance'];

        $records = [];

        $history = [];
        foreach ($questions as $q) {
            // stats are reset for each question, so every record only covers its own question
            $this->helper->getChatModel()->resetUsageStats();
            $this->helper->getRephraseModel()->resetUsageStats();
            $this->helper->getEmbeddingModel()->resetUsageStats();

            $this->colors->ptln($q, Colors::C_LIGHTPURPLE);
            $result = $this->helper->askChatQuestion($q, $history);
            $history[] = [$result['question'], $result['answer']];
            $this->colors->ptln($result['question'], Colors::C_LIGHTBLUE);

            $records[] = [
                'question' => $q,
                'rephrased' => $result['question'],
                'answer' => $result['answer'],
                'source.list' => implode("\n", $result['sources']),
                'source.time' => $this->helper->getEmbeddings()->timeSpent,
                ...$this->flattenStats('stats.embedding', $this->helper->getEmbeddingModel()->getUsageStats()),
                ...$this->flattenStats('stats.rephrase', $this->helper->getRephraseModel()->getUsageStats()),
                ...$this->flattenStats('stats.chat', $this->helper->getChatModel()->getUsageStats()),
            ];
            $this->colors->ptln($result['answer'], Colors::C_LIGHTCYAN);
        }

        return $records;
    }

    /**
     * Reformat the result array into a CSV friendly array
     *
     * The first row is a header with three columns per model (text, cost, time).
     * Each question then produces four rows (question, rephrased, sources, answer)
     * with the models placed side by side. Values a model did not report are
     * written as empty cells.
     *
     * @param array $result The records of simulate(), keyed by model name
     * @return array[] The rows to write, empty when there are no results at all
     */
    protected function records2rows(array $result): array
    {
        if ($result === []) {
            return [];
        }

        // row type => record keys for the text, cost and time column ('' means: leave the cell empty)
        $rowkeys = [
            'question' => ['question', 'stats.embedding.cost', 'stats.embedding.time'],
            'rephrased' => ['rephrased', 'stats.rephrase.cost', 'stats.rephrase.time'],
            'sources' => ['source.list', '', 'source.time'],
            'answer' => ['answer', 'stats.chat.cost', 'stats.chat.time'],
        ];

        $models = array_keys($result);
        // all models normally answer the same questions; use the longest run to never drop a record
        $numberOfRecords = max(array_map('count', $result));
        $rows = [];

        // write headers
        $row = ['type'];
        foreach ($models as $model) {
            $row[] = $model;
            $row[] = 'Cost USD';
            $row[] = 'Time s';
        }
        $rows[] = $row;

        // write rows
        for ($i = 0; $i < $numberOfRecords; $i++) {
            foreach ($rowkeys as $type => $keys) {
                $row = [$type];
                foreach ($models as $model) {
                    foreach ($keys as $key) {
                        // not every model reports every stat, keep the columns aligned anyway
                        $row[] = ($key === '') ? '' : ($result[$model][$i][$key] ?? '');
                    }
                }
                $rows[] = $row;
            }
        }

        return $rows;
    }

    /**
     * Prefix each key in the given stats array to be merged with a larger array
     *
     * @param string $prefix Prepended to each key, separated by a dot
     * @param array $stats
     * @return array The same values under their prefixed keys
     */
    protected function flattenStats(string $prefix, array $stats)
    {
        $result = [];
        foreach ($stats as $key => $value) {
            $result["$prefix.$key"] = $value;
        }
        return $result;
    }

    /**
     * Read the questions from the given file
     *
     * Questions are separated by one or more blank lines. Unix, Windows and old Mac
     * line endings are accepted. Surrounding whitespace is removed from each
     * question and empty questions are dropped.
     *
     * @param string $file
     * @return string[] The questions in file order, may be empty
     * @throws \Exception when the file does not exist or can not be read
     */
    protected function readInputFile(string $file): array
    {
        if (!file_exists($file)) throw new \Exception("File not found: $file");

        $content = @file_get_contents($file);
        if ($content === false) throw new \Exception("Could not read file: $file");

        // normalize line endings, then split wherever at least one blank line follows a line
        $content = str_replace(["\r\n", "\r"], "\n", $content);
        $questions = preg_split('/\n(?:[ \t]*\n)+/', $content);
        if ($questions === false) throw new \Exception("Could not parse file: $file");

        $questions = array_map('trim', $questions);
        // explicit comparison: a question consisting of "0" must not be dropped
        $questions = array_filter($questions, static fn($question) => $question !== '');

        return array_values($questions);
    }
}
179