xref: /plugin/aichat/cli/simulate.php (revision c2b7a1f7fd0f6c6579c9ee46f0437ff89c2fc4b3)
1*c2b7a1f7SAndreas Gohr<?php
2*c2b7a1f7SAndreas Gohr
3*c2b7a1f7SAndreas Gohruse dokuwiki\plugin\aichat\ModelFactory;
4*c2b7a1f7SAndreas Gohruse splitbrain\phpcli\Colors;
5*c2b7a1f7SAndreas Gohruse splitbrain\phpcli\Options;
6*c2b7a1f7SAndreas Gohr
7*c2b7a1f7SAndreas Gohr/**
8*c2b7a1f7SAndreas Gohr * DokuWiki Plugin aichat (CLI Component)
9*c2b7a1f7SAndreas Gohr *
10*c2b7a1f7SAndreas Gohr * @license GPL 2 http://www.gnu.org/licenses/gpl-2.0.html
11*c2b7a1f7SAndreas Gohr * @author  Andreas Gohr <gohr@cosmocode.de>
12*c2b7a1f7SAndreas Gohr */
class cli_plugin_aichat_simulate extends \dokuwiki\Extension\CLIPlugin
{
    /** @var helper_plugin_aichat */
    protected $helper;

    /** @inheritdoc */
    public function __construct($autocatch = true)
    {
        parent::__construct($autocatch);
        $this->helper = plugin_load('helper', 'aichat');
        $this->helper->setLogger($this);
        $this->loadConfig();
    }


    /** @inheritDoc */
    protected function setup(Options $options)
    {
        $options->setHelp('Run a prepared chat session against multiple models');
        $options->registerArgument('input', 'A file with the chat questions. Each question separated by two newlines');
        $options->registerArgument('output', 'Where to write the result CSV to');

        $options->registerOption(
            'filter',
            'Use only models matching this case-insensitive regex (no delimiters)',
            'f',
            'regex'
        );
    }

    /**
     * Run every question of the input file against each selected model and write the results as CSV
     *
     * @inheritDoc
     * @throws Exception when the input is unusable, the filter is invalid or matches no model,
     *                   or the output file cannot be written
     */
    protected function main(Options $options)
    {
        if ($this->loglevel['debug']['enabled']) {
            $this->helper->factory->setDebug(true);
        }

        [$input, $output] = $options->getArgs();
        $questions = $this->readInputFile($input);

        // Validate the filter and pick the models before the output file is touched, so a
        // bad filter does not truncate an existing results file.
        $models = $this->selectModels(
            $this->helper->factory->getModels(true, 'chat'),
            (string) $options->getOpt('filter')
        );

        // Open the output before simulating: an unwritable path should fail before any model is queried
        $outfh = @fopen($output, 'w');
        if (!$outfh) throw new \Exception("Could not open $output for writing");

        try {
            $results = [];
            foreach ($models as $name => $info) {
                $this->success("Running on $name...");
                $results[$name] = $this->simulate($questions, $info);
            }

            foreach ($this->records2rows($results) as $row) {
                // An empty escape character makes fputcsv escape quotes by doubling only (RFC 4180 style);
                // the default backslash escape produces broken fields for text containing \"
                if (fputcsv($outfh, $row, escape: '') === false) {
                    throw new \Exception("Could not write to $output");
                }
            }
        } finally {
            // always release the handle, even when a simulation throws
            fclose($outfh);
        }

        $this->success("Results written to $output");
    }

    /**
     * Reduce the given models to those whose name matches the filter
     *
     * @param iterable $models model info keyed by model name
     * @param string $filter delimiter-less, case-insensitive regex; empty string keeps all models
     * @return array the selected models, still keyed by name
     * @throws Exception when the filter is not a valid regex or no model is left
     */
    protected function selectModels($models, $filter)
    {
        $pattern = $filter === '' ? null : $this->buildFilterPattern($filter);

        $selected = [];
        foreach ($models as $name => $info) {
            if ($pattern !== null && !preg_match($pattern, (string) $name)) {
                continue;
            }
            $selected[$name] = $info;
        }

        if ($selected === []) {
            throw new \Exception(
                $pattern === null
                    ? 'No chat models available'
                    : "No chat models match the filter: $filter"
            );
        }

        return $selected;
    }

    /**
     * Wrap the user supplied regex in delimiters and make sure it compiles
     *
     * The first delimiter that does not occur in the filter is used, so a filter may contain
     * a plain "/" (a pre-escaped "\/" keeps working with any delimiter).
     *
     * @param string $filter regex without delimiters
     * @return string full case-insensitive pattern usable with preg_match()
     * @throws Exception when the resulting pattern is not a valid regex
     */
    protected function buildFilterPattern($filter)
    {
        $delimiter = '/';
        foreach (['/', '~', '#', '%', '@', '!'] as $candidate) {
            if (!str_contains($filter, $candidate)) {
                $delimiter = $candidate;
                break;
            }
        }

        $pattern = $delimiter . $filter . $delimiter . 'i';

        // preg_match() returns false (and emits a warning) for patterns that do not compile
        if (@preg_match($pattern, '') === false) {
            throw new \Exception("Invalid filter regex: $filter");
        }

        return $pattern;
    }

    /**
     * Ask all questions in order as one continuous chat session using the given model
     *
     * The model is installed as chat model and (as a clone) as rephrase model on the helper's
     * factory. Usage statistics of the chat, rephrase and embedding models are reset before each
     * question, so the recorded stats belong to that question only.
     *
     * @param string[] $questions the questions to ask
     * @param array $model model info; its 'instance' entry is the model object to use
     * @return array[] one record per question with the keys question, rephrased, answer,
     *                 source.list, source.time and the prefixed usage stats
     */
    protected function simulate($questions, $model)
    {
        // override models
        $this->helper->factory->chatModel = $model['instance'];
        $this->helper->factory->rephraseModel = clone $model['instance'];

        $records = [];

        // previous [question, answer] pairs, passed along with each new question
        $history = [];
        foreach ($questions as $q) {
            $this->helper->getChatModel()->resetUsageStats();
            $this->helper->getRephraseModel()->resetUsageStats();
            $this->helper->getEmbeddingModel()->resetUsageStats();

            $this->colors->ptln($q, Colors::C_LIGHTPURPLE);
            $result = $this->helper->askChatQuestion($q, $history);
            $history[] = [$result['question'], $result['answer']];

            $record = [
                'question' => $q,
                'rephrased' => $result['question'],
                'answer' => $result['answer'],
                'source.list' => join("\n", $result['sources']),
                'source.time' => $this->helper->getEmbeddings()->timeSpent,
                ...$this->flattenStats('stats.embedding', $this->helper->getEmbeddingModel()->getUsageStats()),
                ...$this->flattenStats('stats.rephrase', $this->helper->getRephraseModel()->getUsageStats()),
                ...$this->flattenStats('stats.chat', $this->helper->getChatModel()->getUsageStats()),
            ];
            $records[] = $record;
            $this->colors->ptln($result['answer'], Colors::C_LIGHTCYAN);
        }

        return $records;
    }

    /**
     * Reformat the result array into a CSV friendly array
     *
     * The first row is a header with three columns per model (model name, cost, time). Each
     * question then contributes four rows (question, rephrased, sources, answer) that hold the
     * respective value, cost and time of every model side by side.
     *
     * @param array $result the records of simulate(), keyed by model name
     * @return array[] list of rows; only the header row when $result is empty
     */
    protected function records2rows(array $result): array
    {
        // row type => record keys for the [value, cost, time] columns; '' leaves the cell empty
        $rowkeys = [
            'question' => ['question', 'stats.embedding.cost', 'stats.embedding.time'],
            'rephrased' => ['rephrased', 'stats.rephrase.cost', 'stats.rephrase.time'],
            'sources' => ['source.list', '', 'source.time'],
            'answer' => ['answer', 'stats.chat.cost', 'stats.chat.time'],
        ];

        $models = array_keys($result);

        // write headers
        $header = ['type'];
        foreach ($models as $model) {
            $header[] = $model;
            $header[] = 'Cost USD';
            $header[] = 'Time s';
        }
        $rows = [$header];

        // without any model there is nothing to index into below
        if ($models === []) {
            return $rows;
        }

        // every model was asked the same questions, so the first one defines the record count
        $numberOfRecords = count($result[$models[0]]);

        // write rows
        for ($i = 0; $i < $numberOfRecords; $i++) {
            foreach ($rowkeys as $type => $keys) {
                $row = [$type];
                foreach ($models as $model) {
                    foreach ($keys as $key) {
                        // a missing stat becomes an empty cell instead of an undefined index warning
                        $row[] = $key === '' ? '' : ($result[$model][$i][$key] ?? '');
                    }
                }
                $rows[] = $row;
            }
        }

        return $rows;
    }


    /**
     * Prefix each key in the given stats array to be merged with a larger array
     *
     * @param string $prefix
     * @param array $stats
     * @return array the same values, keyed by "$prefix.$key"
     */
    protected function flattenStats(string $prefix, array $stats) {
        $result = [];
        foreach ($stats as $key => $value) {
            $result["$prefix.$key"] = $value;
        }
        return $result;
    }

    /**
     * Read the questions from the given file
     *
     * Questions are separated by one or more blank lines. Unix, Windows and old Mac line
     * endings are accepted. Surrounding whitespace is trimmed and empty questions are dropped.
     *
     * @param string $file
     * @return string[] the non-empty questions in file order
     * @throws Exception when the file is missing, unreadable or contains no question
     */
    protected function readInputFile(string $file): array
    {
        if (!file_exists($file)) throw new \Exception("File not found: $file");

        $content = @file_get_contents($file);
        if ($content === false) throw new \Exception("Could not read file: $file");

        // normalise line endings so that CRLF files split on blank lines as well
        $content = str_replace(["\r\n", "\r"], "\n", $content);

        $questions = [];
        foreach (preg_split('/\n{2,}/', $content) as $question) {
            $question = trim($question);
            if ($question !== '') {
                $questions[] = $question;
            }
        }

        if ($questions === []) throw new \Exception("No questions found in: $file");

        return $questions;
    }
}
186*c2b7a1f7SAndreas Gohr
187