xref: /plugin/aichat/cli/simulate.php (revision c2b7a1f7fd0f6c6579c9ee46f0437ff89c2fc4b3)
1<?php
2
3use dokuwiki\plugin\aichat\ModelFactory;
4use splitbrain\phpcli\Colors;
5use splitbrain\phpcli\Options;
6
7/**
8 * DokuWiki Plugin aichat (CLI Component)
9 *
10 * @license GPL 2 http://www.gnu.org/licenses/gpl-2.0.html
11 * @author  Andreas Gohr <gohr@cosmocode.de>
12 */
class cli_plugin_aichat_simulate extends \dokuwiki\Extension\CLIPlugin
{
    /** @var helper_plugin_aichat */
    protected $helper;

    /** @inheritdoc */
    public function __construct($autocatch = true)
    {
        parent::__construct($autocatch);
        $this->helper = plugin_load('helper', 'aichat');
        $this->helper->setLogger($this);
        $this->loadConfig();
    }

    /** @inheritDoc */
    protected function setup(Options $options)
    {
        $options->setHelp('Run a prepared chat session against multiple models');
        $options->registerArgument('input', 'A file with the chat questions. Each question separated by two newlines');
        $options->registerArgument('output', 'Where to write the result CSV to');

        $options->registerOption(
            'filter',
            'Use only models matching this case-insensitive regex (no delimiters)',
            'f',
            'regex'
        );
    }

    /**
     * Run the questions from the input file against every selected chat model and write a CSV report
     *
     * Cheap checks run first: input parsing, filter validation, model selection and opening the output
     * file. Only then are the model calls made.
     *
     * @inheritDoc
     * @throws \Exception when the input, filter or output is unusable or no model was selected
     */
    protected function main(Options $options)
    {
        if ($this->loglevel['debug']['enabled']) {
            $this->helper->factory->setDebug(true);
        }

        [$input, $output] = $options->getArgs();
        $questions = $this->readInputFile($input);

        // getOpt() returns false when the option was not given; that casts to '' (= no filter)
        $filter = $this->buildFilterPattern((string)$options->getOpt('filter'));
        $models = [];
        foreach ($this->helper->factory->getModels(true, 'chat') as $name => $info) {
            if ($filter === null || preg_match($filter, $name)) {
                $models[$name] = $info;
            }
        }
        if (!$models) {
            throw new \Exception(
                'No chat models available' . ($filter !== null ? ' matching the given filter' : '')
            );
        }

        // open the output before running so a bad path fails before any model calls are made
        $outfh = @fopen($output, 'w');
        if (!$outfh) throw new \Exception("Could not open $output for writing");

        try {
            $results = [];
            foreach ($models as $name => $info) {
                $this->success("Running on $name...");
                $results[$name] = $this->simulate($questions, $info);
            }

            foreach ($this->records2rows($results) as $row) {
                // an empty escape char gives RFC 4180 compliant CSV; the default "\" mangles fields
                // containing \" sequences, which is common in generated answers with code snippets
                fputcsv($outfh, $row, ',', '"', '');
            }
        } finally {
            // also release the handle when a model run throws
            fclose($outfh);
        }
        $this->success("Results written to $output");
    }

    /**
     * Turn the user supplied filter (a regex without delimiters) into a complete PCRE pattern
     *
     * Unescaped forward slashes are escaped so they do not terminate the pattern early. Already
     * escaped sequences like \/ or \\ are kept as they are. The pattern is validated once here,
     * instead of failing with a warning for every model.
     *
     * @param string $filter the raw regex, empty for no filtering
     * @return string|null the case-insensitive pattern, null when no filter was given
     * @throws \Exception when the resulting regex is invalid
     */
    protected function buildFilterPattern(string $filter): ?string
    {
        if ($filter === '') return null;

        // match either an escaped pair (kept untouched) or a bare slash (escaped)
        $escaped = (string)preg_replace_callback(
            '~\\\\.|/~s',
            fn($m) => $m[0] === '/' ? '\/' : $m[0],
            $filter
        );
        $pattern = '/' . $escaped . '/i';

        if (@preg_match($pattern, '') === false) {
            throw new \Exception("Invalid filter regex: $filter");
        }
        return $pattern;
    }

    /**
     * Ask all questions as one continuous chat session against the given model
     *
     * The model instance is used as the chat model and a clone of it as the rephrase model. Each
     * answered question is appended to the history passed to the next one. Usage stats of the chat,
     * rephrase and embedding models are reset before each question, so each record holds the stats
     * for that question only.
     *
     * @param string[] $questions the questions to ask, in order
     * @param array $model model info from the factory; its 'instance' entry is the model object
     * @return array[] one flat record per question (question, rephrased, answer, source.* and stats.* keys)
     */
    protected function simulate($questions, $model)
    {
        // override models
        $this->helper->factory->chatModel = $model['instance'];
        $this->helper->factory->rephraseModel = clone $model['instance'];

        $records = [];
        $history = [];
        foreach ($questions as $q) {
            $this->helper->getChatModel()->resetUsageStats();
            $this->helper->getRephraseModel()->resetUsageStats();
            $this->helper->getEmbeddingModel()->resetUsageStats();

            $this->colors->ptln($q, Colors::C_LIGHTPURPLE);
            $result = $this->helper->askChatQuestion($q, $history);
            $history[] = [$result['question'], $result['answer']];

            $records[] = [
                'question' => $q,
                'rephrased' => $result['question'],
                'answer' => $result['answer'],
                'source.list' => implode("\n", $result['sources']),
                // NOTE(review): timeSpent is not reset here; confirm it is per question and not cumulative
                'source.time' => $this->helper->getEmbeddings()->timeSpent,
                ...$this->flattenStats('stats.embedding', $this->helper->getEmbeddingModel()->getUsageStats()),
                ...$this->flattenStats('stats.rephrase', $this->helper->getRephraseModel()->getUsageStats()),
                ...$this->flattenStats('stats.chat', $this->helper->getChatModel()->getUsageStats()),
            ];
            $this->colors->ptln($result['answer'], Colors::C_LIGHTCYAN);
        }

        return $records;
    }

    /**
     * Reformat the per-model records into CSV rows
     *
     * The first row is a header with three columns per model (value, cost, time). Every record then
     * becomes four rows: question, rephrased, sources and answer. Each row has the matching value,
     * cost and time columns for every model side by side. With no models, only the header row is
     * returned.
     *
     * @param array<string, array[]> $result records as returned by simulate(), keyed by model name
     * @return array[] the rows to write
     */
    protected function records2rows(array $result): array
    {
        // row type => record keys for the value/cost/time columns; '' is an always-empty column
        $rowkeys = [
            'question' => ['question', 'stats.embedding.cost', 'stats.embedding.time'],
            'rephrased' => ['rephrased', 'stats.rephrase.cost', 'stats.rephrase.time'],
            'sources' => ['source.list', '', 'source.time'],
            'answer' => ['answer', 'stats.chat.cost', 'stats.chat.time'],
        ];

        $models = array_keys($result);

        // header
        $header = ['type'];
        foreach ($models as $model) {
            array_push($header, $model, 'Cost USD', 'Time s');
        }
        $rows = [$header];

        if (!$models) return $rows;

        // all models were asked the same questions, so the first one gives the record count
        $numberOfRecords = count($result[$models[0]]);
        for ($i = 0; $i < $numberOfRecords; $i++) {
            foreach ($rowkeys as $type => $keys) {
                $row = [$type];
                foreach ($models as $model) {
                    foreach ($keys as $key) {
                        // stats a model did not report are left blank instead of raising a warning
                        $row[] = $key === '' ? '' : ($result[$model][$i][$key] ?? '');
                    }
                }
                $rows[] = $row;
            }
        }

        return $rows;
    }

    /**
     * Prefix each key in the given stats array to be merged with a larger array
     *
     * Example: prefix 'stats.chat' turns ['cost' => 1] into ['stats.chat.cost' => 1]
     *
     * @param string $prefix
     * @param array $stats
     * @return array
     */
    protected function flattenStats(string $prefix, array $stats)
    {
        $result = [];
        foreach ($stats as $key => $value) {
            $result["$prefix.$key"] = $value;
        }
        return $result;
    }

    /**
     * Read the chat questions from the given file
     *
     * Questions are separated by at least one blank line; whitespace-only lines count as blank.
     * Unix, Windows and old Mac line endings are accepted. Surrounding whitespace is trimmed.
     * Empty entries, e.g. from extra blank lines or a trailing separator, are dropped so they are
     * never sent to a model.
     *
     * @param string $file
     * @return string[] the questions, in file order
     * @throws \Exception when the file is missing, unreadable or contains no questions
     */
    protected function readInputFile(string $file): array
    {
        if (!file_exists($file)) throw new \Exception("File not found: $file");
        $content = file_get_contents($file);
        if ($content === false) throw new \Exception("Could not read $file");

        $content = str_replace(["\r\n", "\r"], "\n", $content);
        $questions = array_map(trim(...), preg_split('/\n[ \t]*\n/', $content));
        $questions = array_values(array_filter($questions, fn($q) => $q !== ''));

        if (!$questions) throw new \Exception("No questions found in $file");
        return $questions;
    }
}
186
187