xref: /plugin/aichat/cli/simulate.php (revision 2071dced6f96936ea7b9bf5dbe8a117eef598448)
1<?php
2
3use dokuwiki\Extension\CLIPlugin;
4use dokuwiki\plugin\aichat\ModelFactory;
5use splitbrain\phpcli\Colors;
6use splitbrain\phpcli\Options;
7
/**
 * DokuWiki Plugin aichat (CLI Component)
 *
 * Replays a prepared list of chat questions against every available chat model
 * (optionally restricted by a name filter) and writes a side-by-side comparison
 * of questions, rephrased questions, sources, answers, costs and timings to a CSV file.
 *
 * @license GPL 2 http://www.gnu.org/licenses/gpl-2.0.html
 * @author  Andreas Gohr <gohr@cosmocode.de>
 */
class cli_plugin_aichat_simulate extends CLIPlugin
{
    /** @var helper_plugin_aichat */
    protected $helper;

    /** @inheritdoc */
    public function __construct($autocatch = true)
    {
        parent::__construct($autocatch);
        $this->helper = plugin_load('helper', 'aichat');
        $this->helper->setLogger($this);
        $this->loadConfig();
    }


    /** @inheritDoc */
    protected function setup(Options $options)
    {
        $options->setHelp('Run a prepared chat session against multiple models');
        $options->registerArgument('input', 'A file with the chat questions. Each question separated by two newlines');
        $options->registerArgument('output', 'Where to write the result CSV to');

        $options->registerOption(
            'filter',
            'Use only models matching this case-insensitive regex (no delimiters)',
            'f',
            'regex'
        );
    }

    /**
     * Run all questions against all (matching) chat models and write the result CSV
     *
     * @inheritDoc
     * @throws \Exception when the input can't be read, the filter is invalid or the output can't be written
     */
    protected function main(Options $options)
    {
        if ($this->loglevel['debug']['enabled']) {
            $this->helper->factory->setDebug(true);
        }

        [$input, $output] = $options->getArgs();
        $questions = $this->readInputFile($input);
        // validate the filter before the output file is opened (and thus truncated)
        $filter = $this->buildFilterRegex((string)$options->getOpt('filter'));

        $outfh = @fopen($output, 'w');
        if (!$outfh) throw new \Exception("Could not open $output for writing");

        try {
            $models = $this->helper->factory->getModels(true, 'chat');

            $results = [];
            foreach ($models as $name => $info) {
                if ($filter !== null && !preg_match($filter, $name)) {
                    continue;
                }
                $this->success("Running on $name...");
                $results[$name] = $this->simulate($questions, $info);
            }

            if ($results === []) {
                $this->warning('No chat models matched, only headers will be written');
            }

            foreach ($this->records2rows($results) as $row) {
                // explicit empty escape char: RFC 4180 quoting only (and avoids the PHP 8.4 default deprecation)
                fputcsv($outfh, $row, ',', '"', '');
            }
        } finally {
            // always release the handle, even when a model call throws mid-run
            fclose($outfh);
        }
        $this->success("Results written to $output");
    }

    /**
     * Turn the raw --filter option into a complete, case-insensitive PCRE pattern
     *
     * @param string $filter regex without delimiters, empty string for "no filter"
     * @return string|null the delimited pattern, null when no filter was given
     * @throws \Exception when the resulting pattern does not compile
     */
    protected function buildFilterRegex(string $filter): ?string
    {
        if ($filter === '') return null;

        $regex = '/' . $filter . '/i';
        // preg_match() returns false on a broken pattern (e.g. an unescaped "/" in the filter)
        if (@preg_match($regex, '') === false) {
            throw new \Exception("Invalid filter regex: $filter");
        }
        return $regex;
    }

    /**
     * Ask all questions in order against a single model, carrying the chat history along
     *
     * The given model instance is installed as both chat and rephrase model on the factory.
     *
     * @param string[] $questions
     * @param array $model model info as returned by the factory's getModels(); its 'instance' entry is used
     * @return array[] one flat record (question, answer, sources and prefixed usage stats) per question
     */
    protected function simulate($questions, $model)
    {
        // override models
        $this->helper->factory->chatModel = $model['instance'];
        // NOTE(review): presumably a clone so rephrase usage stats are tracked separately from chat stats
        $this->helper->factory->rephraseModel = clone $model['instance'];

        $records = [];

        $history = [];
        foreach ($questions as $q) {
            // stats are read per question below, so start each question from zero
            $this->helper->getChatModel()->resetUsageStats();
            $this->helper->getRephraseModel()->resetUsageStats();
            $this->helper->getEmbeddingModel()->resetUsageStats();

            $this->colors->ptln($q, Colors::C_LIGHTPURPLE);
            $result = $this->helper->askChatQuestion($q, $history);
            $history[] = [$result['question'], $result['answer']];

            $record = [
                'question' => $q,
                'rephrased' => $result['question'],
                'answer' => $result['answer'],
                'source.list' => implode("\n", $result['sources']),
                'source.time' => $this->helper->getEmbeddings()->timeSpent,
                ...$this->flattenStats('stats.embedding', $this->helper->getEmbeddingModel()->getUsageStats()),
                ...$this->flattenStats('stats.rephrase', $this->helper->getRephraseModel()->getUsageStats()),
                ...$this->flattenStats('stats.chat', $this->helper->getChatModel()->getUsageStats()),
            ];
            $records[] = $record;
            $this->colors->ptln($result['answer'], Colors::C_LIGHTCYAN);
        }

        return $records;
    }

    /**
     * Reformat the result array into a CSV friendly array
     *
     * Models become column groups (value, cost, time); each question record becomes four rows
     * (question, rephrased, sources, answer). The first row is a header.
     *
     * @param array $result model name => list of records as returned by simulate()
     * @return array[] CSV rows; only the header row when $result is empty
     */
    protected function records2rows(array $result): array
    {
        // row type => record keys for the value, cost and time columns ('' leaves the cell empty)
        $rowkeys = [
            'question' => ['question', 'stats.embedding.cost', 'stats.embedding.time'],
            'rephrased' => ['rephrased', 'stats.rephrase.cost', 'stats.rephrase.time'],
            'sources' => ['source.list', '', 'source.time'],
            'answer' => ['answer', 'stats.chat.cost', 'stats.chat.time'],
        ];

        $models = array_keys($result);
        // every model answered the same questions, so the first one tells the record count
        $numberOfRecords = $models === [] ? 0 : count($result[$models[0]]);
        $rows = [];

        // write headers
        $row = [];
        $row[] = 'type';
        foreach ($models as $model) {
            $row[] = $model;
            $row[] = 'Cost USD';
            $row[] = 'Time s';
        }
        $rows[] = $row;

        // write rows
        for ($i = 0; $i < $numberOfRecords; $i++) {
            foreach ($rowkeys as $type => $keys) {
                $row = [];
                $row[] = $type;
                foreach ($models as $model) {
                    foreach ($keys as $key) {
                        // a model may not report every stat; leave such cells empty
                        $row[] = $key === '' ? '' : ($result[$model][$i][$key] ?? '');
                    }
                }
                $rows[] = $row;
            }
        }

        return $rows;
    }


    /**
     * Prefix each key in the given stats array to be merged with a larger array
     *
     * @param string $prefix prepended to every key, separated by a dot
     * @param array $stats
     * @return array the same values keyed as "$prefix.$key"
     */
    protected function flattenStats(string $prefix, array $stats)
    {
        $result = [];
        foreach ($stats as $key => $value) {
            $result["$prefix.$key"] = $value;
        }
        return $result;
    }

    /**
     * Read the questions from the given file
     *
     * Questions are separated by one or more blank lines. Line endings are normalized first,
     * and surrounding whitespace is trimmed. Empty entries are dropped.
     *
     * @param string $file
     * @return string[] non-empty, trimmed questions in file order
     * @throws Exception when the file is missing, unreadable or contains no questions
     */
    protected function readInputFile(string $file): array
    {
        if (!file_exists($file)) throw new \Exception("File not found: $file");
        $content = file_get_contents($file);
        if ($content === false) throw new \Exception("Could not read file: $file");

        // normalize Windows/old Mac line endings so the blank-line separator is recognized
        $content = str_replace(["\r\n", "\r"], "\n", $content);
        // runs of extra blank lines must not yield empty questions
        $questions = preg_split('/\n{2,}/', $content);
        $questions = array_map('trim', $questions);
        $questions = array_values(array_filter($questions, static fn($q) => $q !== ''));

        if ($questions === []) throw new \Exception("No questions found in $file");
        return $questions;
    }
}
188