xref: /plugin/aichat/cli/simulate.php (revision ab1f8dde36106432cc0a6f320220da5fae6971fe)
1<?php
2
3use dokuwiki\plugin\aichat\AbstractCLI;
4use splitbrain\phpcli\Colors;
5use splitbrain\phpcli\Options;
6
/**
 * DokuWiki Plugin aichat (CLI Component)
 *
 * Runs a prepared sequence of chat questions against every available chat model
 * (optionally restricted by a regex on the model name) and writes a CSV that lists the
 * rephrased questions, sources, answers, costs and timings of all models side by side.
 *
 * @license GPL 2 http://www.gnu.org/licenses/gpl-2.0.html
 * @author  Andreas Gohr <gohr@cosmocode.de>
 */
class cli_plugin_aichat_simulate extends AbstractCLI
{
    /** @inheritDoc */
    protected function setup(Options $options)
    {
        parent::setup($options);

        $options->setHelp('Run a prepared chat session against multiple models');
        $options->registerArgument('input', 'A file with the chat questions. Each question separated by two newlines');
        $options->registerArgument('output', 'Where to write the result CSV to');

        $options->registerOption(
            'filter',
            'Use only models matching this case-insensitive regex (no delimiters)',
            'f',
            'regex'
        );
    }

    /** @inheritDoc */
    protected function main(Options $options)
    {
        parent::main($options);

        [$input, $output] = $options->getArgs();
        $questions = $this->readInputFile($input);

        // select the models before opening the output, so a bad filter does not truncate an existing file
        $models = $this->filterModels(
            $this->helper->factory->getModels(true, 'chat'),
            (string)$options->getOpt('filter', '')
        );
        if ($models === []) {
            throw new \Exception('No chat models matched the given filter');
        }

        // open the output before the (slow) simulation runs, so an unwritable target fails early
        $outFH = @fopen($output, 'w');
        if (!$outFH) throw new \Exception("Could not open $output for writing");

        try {
            $results = [];
            foreach ($models as $name => $info) {
                $this->success("Running on $name...");
                $results[$name] = $this->simulate($questions, $info);
            }

            foreach ($this->records2rows($results) as $row) {
                // the escape character is passed explicitly because relying on its default is deprecated in PHP 8.4
                if (fputcsv($outFH, $row, ',', '"', '\\') === false) {
                    throw new \Exception("Could not write to $output");
                }
            }
        } finally {
            fclose($outFH);
        }
        $this->success("Results written to $output");
    }

    /**
     * Reduce the given models to those whose name matches the user supplied filter
     *
     * @param iterable $models model name => model info, as returned by the model factory
     * @param string $filter case-insensitive regex without delimiters; an empty string selects all models
     * @return array the selected models, keyed by name, in their original order
     * @throws \Exception when the filter is not a valid regular expression
     */
    protected function filterModels(iterable $models, string $filter): array
    {
        $selected = [];
        foreach ($models as $name => $info) {
            $selected[$name] = $info;
        }
        if ($filter === '') return $selected;

        $pattern = $this->filterToPattern($filter);
        // validate once up front: an invalid pattern makes preg_match() return false, which would
        // otherwise silently exclude every model
        if (@preg_match($pattern, '') === false) {
            throw new \Exception("Invalid model filter regex: $filter");
        }

        return array_filter(
            $selected,
            static fn($name) => preg_match($pattern, (string)$name) === 1,
            ARRAY_FILTER_USE_KEY
        );
    }

    /**
     * Wrap a delimiter-less regex in slashes, escaping any slash the user did not escape
     *
     * @param string $filter the regex as given on the command line
     * @return string a complete, case-insensitive PCRE pattern
     */
    protected function filterToPattern(string $filter): string
    {
        $escaped = '';
        $length = strlen($filter);
        for ($i = 0; $i < $length; $i++) {
            $char = $filter[$i];
            if ($char === '\\' && $i + 1 < $length) {
                // keep existing escape sequences (including an already escaped slash) untouched
                $escaped .= $char . $filter[++$i];
            } elseif ($char === '/') {
                $escaped .= '\\/';
            } else {
                $escaped .= $char;
            }
        }
        return '/' . $escaped . '/i';
    }

    /**
     * Ask all questions in order against a single model, carrying the chat history along
     *
     * The model's instance is installed as the chat model, and a clone of it as the rephrase model,
     * so the rephrase and chat usage statistics are collected on separate objects. An exception
     * for a single question is logged and recorded as the answer so the remaining questions still run.
     *
     * @param string[] $questions the questions to ask, in order
     * @param array $model the model info as returned by the model factory; its 'instance' key holds the model
     * @return array[] one flat record per question with the keys 'question', 'rephrased', 'answer',
     *                 'source.list', 'source.time' and the prefixed 'stats.embedding.*',
     *                 'stats.rephrase.*' and 'stats.chat.*' usage values
     */
    protected function simulate($questions, $model)
    {
        // override models
        $this->helper->factory->chatModel = $model['instance'];
        $this->helper->factory->rephraseModel = clone $model['instance'];

        $records = [];

        $history = [];
        foreach ($questions as $q) {
            // stats are reset per question so each record only contains the usage of that question
            $this->helper->getChatModel()->resetUsageStats();
            $this->helper->getRephraseModel()->resetUsageStats();
            $this->helper->getEmbeddingModel()->resetUsageStats();

            $this->colors->ptln($q, Colors::C_LIGHTPURPLE);
            try {
                $result = $this->helper->askChatQuestion($q, $history);
                $history[] = [$result['question'], $result['answer']];
                $this->colors->ptln($result['question'], Colors::C_LIGHTBLUE);
            } catch (\Exception $e) {
                $this->error($e->getMessage());
                $this->debug($e->getTraceAsString());
                $result = ['question' => $q, 'answer' => "ERROR\n" . $e->getMessage(), 'sources' => []];
            }

            $record = [
                'question' => $q,
                'rephrased' => $result['question'],
                'answer' => $result['answer'],
                'source.list' => implode("\n", $result['sources']),
                'source.time' => $this->helper->getEmbeddings()->timeSpent,
                ...$this->flattenStats('stats.embedding', $this->helper->getEmbeddingModel()->getUsageStats()),
                ...$this->flattenStats('stats.rephrase', $this->helper->getRephraseModel()->getUsageStats()),
                ...$this->flattenStats('stats.chat', $this->helper->getChatModel()->getUsageStats()),
            ];
            $records[] = $record;
            $this->colors->ptln($result['answer'], Colors::C_LIGHTCYAN);
        }

        return $records;
    }

    /**
     * Reformat the result array into a CSV friendly array
     *
     * Every record becomes four rows (question, rephrased, sources, answer). Each model contributes
     * three columns per row: the value, its cost and its time. Values missing from a record are written
     * as empty cells.
     *
     * @param array $result model name => list of records as returned by simulate()
     * @return array[] the CSV rows, starting with a header row
     */
    protected function records2rows(array $result): array
    {
        // row type => record keys for the value, cost and time columns ('' = no value for that column)
        $rowkeys = [
            'question' => ['question', 'stats.embedding.cost', 'stats.embedding.time'],
            'rephrased' => ['rephrased', 'stats.rephrase.cost', 'stats.rephrase.time'],
            'sources' => ['source.list', '', 'source.time'],
            'answer' => ['answer', 'stats.chat.cost', 'stats.chat.time'],
        ];

        $models = array_keys($result);
        $rows = [];

        // write headers
        $header = ['type'];
        foreach ($models as $model) {
            array_push($header, $model, 'Cost USD', 'Time s');
        }
        $rows[] = $header;

        if ($models === []) return $rows;
        $numberOfRecords = max(array_map('count', $result));

        // write rows
        for ($i = 0; $i < $numberOfRecords; $i++) {
            foreach ($rowkeys as $type => $keys) {
                $row = [$type];
                foreach ($models as $model) {
                    foreach ($keys as $key) {
                        $row[] = $key === '' ? '' : ($result[$model][$i][$key] ?? '');
                    }
                }
                $rows[] = $row;
            }
        }

        return $rows;
    }

    /**
     * Prefix each key in the given stats array to be merged with a larger array
     *
     * @param string $prefix prepended to every key, separated by a dot
     * @param array $stats the usage statistics to prefix
     * @return array the same values with keys of the form "$prefix.$key"
     */
    protected function flattenStats(string $prefix, array $stats)
    {
        $result = [];
        foreach ($stats as $key => $value) {
            $result["$prefix.$key"] = $value;
        }
        return $result;
    }

    /**
     * Read the questions from the given file
     *
     * Questions are separated by a blank line (a line that is empty or holds only spaces/tabs).
     * Windows and old Mac line endings are normalized first, and surrounding whitespace is trimmed.
     * Empty entries, such as those produced by several consecutive blank lines, are dropped.
     *
     * @param string $file
     * @return string[] the questions, in file order
     * @throws Exception when the file cannot be read or contains no questions
     */
    protected function readInputFile(string $file): array
    {
        if (!is_file($file)) throw new \Exception("File not found: $file");
        $content = @file_get_contents($file);
        if ($content === false) throw new \Exception("Could not read $file");

        $content = str_replace(["\r\n", "\r"], "\n", $content);
        $questions = preg_split('/\n[ \t]*\n/', $content) ?: [];
        $questions = array_map('trim', $questions);
        $questions = array_values(array_filter($questions, static fn($q) => $q !== ''));

        if ($questions === []) throw new \Exception("No questions found in $file");
        return $questions;
    }
}
183