xref: /plugin/aichat/Model/AbstractModel.php (revision 7c3b69cbd638c7d1775b012bfb9b48d8136f76b9)
1<?php
2
3namespace dokuwiki\plugin\aichat\Model;
4
5use dokuwiki\HTTP\DokuHTTPClient;
6
/**
 * Base class for all models
 *
 * Model classes also need to implement one of the following interfaces:
 * - ChatInterface
 * - EmbeddingInterface
 *
 * This class already implements most of the requirements for these interfaces.
 *
 * In addition to any missing interface methods, model implementations will need to
 * extend the constructor to handle the plugin configuration and implement the
 * parseAPIResponse() method to handle the specific API response.
 */
abstract class AbstractModel implements ModelInterface
{
    /** @var string The model name */
    protected $modelName;
    /** @var string The full model name (provider directory name followed by the model name) */
    protected $modelFullName;
    /** @var array The model info from the models.json file */
    protected $modelInfo;
    /** @var string The provider name (name of the directory containing the model class) */
    protected $selfIdent;

    /** @var int input tokens used since last reset */
    protected $inputTokensUsed = 0;
    /** @var int output tokens used since last reset */
    protected $outputTokensUsed = 0;
    /** @var float total time in seconds spent in requests since last reset */
    protected $timeUsed = 0;
    /** @var int total number of requests made since last reset (retries are counted individually) */
    protected $requestsMade = 0;
    /** @var float microtime() start of the current request chain (may be multiple when retries needed) */
    protected $requestStart = 0;

    /** @var int How often to retry a request if it fails */
    public const MAX_RETRIES = 3;

    /** @var DokuHTTPClient */
    protected $http;
    /** @var bool debug API communication */
    protected $debug = false;

    /** @var array The plugin configuration */
    protected $config;

    // region ModelInterface

    /**
     * Initialize the model and load its info from the provider's models.json
     *
     * The models.json file is expected in the same directory as the concrete model class.
     * When the model is not listed there, the defaults of loadUnknownModelInfo() are used.
     *
     * @inheritdoc
     * @throws \Exception when the models.json file is missing (2001) or can not be parsed (2002)
     */
    public function __construct(string $name, array $config)
    {
        $this->modelName = $name;
        $this->config = $config;

        // the info file lives next to the concrete (inheriting) class, not next to this base class
        $reflect = new \ReflectionClass($this);
        $providerDir = dirname($reflect->getFileName());
        $json = $providerDir . '/models.json';
        if (!file_exists($json)) {
            throw new \Exception('Model info file not found at ' . $json, 2001);
        }
        try {
            $modelinfos = json_decode(file_get_contents($json), true, 512, JSON_THROW_ON_ERROR);
        } catch (\JsonException $e) {
            throw new \Exception('Failed to parse model info file: ' . $e->getMessage(), 2002, $e);
        }

        $this->selfIdent = basename($providerDir);
        $this->modelFullName = $this->selfIdent . ' ' . $name;

        // pick the section of the info file matching the implemented interface;
        // the embedding section takes precedence should a class implement both
        $section = null;
        if ($this instanceof ChatInterface) {
            $section = 'chat';
        }
        if ($this instanceof EmbeddingInterface) {
            $section = 'embedding';
        }

        if ($section !== null && isset($modelinfos[$section][$name])) {
            $this->modelInfo = $modelinfos[$section][$name];
        } else {
            // unknown model (or no known interface): fall back to safe defaults
            $this->modelInfo = $this->loadUnknownModelInfo();
        }
    }

    /** @inheritdoc */
    public function __toString(): string
    {
        return $this->modelFullName;
    }

    /** @inheritdoc */
    public function getModelName()
    {
        return $this->modelName;
    }

    /**
     * Reset the usage statistics
     *
     * Usually not needed when only handling one operation per request, but useful in CLI
     */
    public function resetUsageStats()
    {
        $this->inputTokensUsed = 0;
        $this->outputTokensUsed = 0;
        $this->timeUsed = 0;
        $this->requestsMade = 0;
    }

    /**
     * Get the usage statistics for this instance
     *
     * The cost is calculated from the token prices, which are treated as prices per one million tokens.
     * Output tokens are only taken into account for chat models.
     *
     * @return array with the keys 'tokens' (int), 'cost' (string, 6 decimals), 'time' (float) and 'requests' (int)
     */
    public function getUsageStats()
    {
        $cost = $this->inputTokensUsed * $this->getInputTokenPrice();
        if ($this instanceof ChatInterface) {
            $cost += $this->outputTokensUsed * $this->getOutputTokenPrice();
        }

        return [
            'tokens' => $this->inputTokensUsed + $this->outputTokensUsed,
            'cost' => sprintf("%.6f", $cost / 1_000_000),
            'time' => round($this->timeUsed, 2),
            'requests' => $this->requestsMade,
        ];
    }

    /** @inheritdoc */
    public function getMaxInputTokenLength(): int
    {
        return $this->modelInfo['inputTokens'] ?? 0;
    }

    /** @inheritdoc */
    public function getInputTokenPrice(): float
    {
        return $this->modelInfo['inputTokenPrice'] ?? 0;
    }

    /**
     * Default model info used when the model is not listed in models.json
     *
     * @inheritdoc
     */
    public function loadUnknownModelInfo(): array
    {
        $info = [
            'description' => $this->modelFullName,
            'inputTokens' => 0,
            'inputTokenPrice' => 0,
        ];

        if ($this instanceof ChatInterface) {
            $info['outputTokens'] = 0;
            $info['outputTokenPrice'] = 0;
        } elseif ($this instanceof EmbeddingInterface) {
            $info['dimensions'] = 512;
        }

        return $info;
    }

    // endregion

    // region EmbeddingInterface

    /** @inheritdoc */
    public function getDimensions(): int
    {
        return $this->modelInfo['dimensions'];
    }

    // endregion

    // region ChatInterface

    /**
     * Maximum number of output tokens, 0 when the model info does not provide a value
     *
     * @return int
     */
    public function getMaxOutputTokenLength(): int
    {
        return $this->modelInfo['outputTokens'] ?? 0;
    }

    /**
     * Price of the output tokens, 0 when the model info does not provide a value
     *
     * @return float
     */
    public function getOutputTokenPrice(): float
    {
        return $this->modelInfo['outputTokenPrice'] ?? 0;
    }

    // endregion

    // region API communication

    /**
     * When enabled, the input/output of the API will be printed to STDOUT
     *
     * @param bool $debug
     */
    public function setDebug($debug = true)
    {
        $this->debug = $debug;
    }

    /**
     * Get the HTTP client used for API requests
     *
     * This method will create a new DokuHTTPClient instance if it does not exist yet.
     * The client will be configured with a timeout and the appropriate headers for JSON communication.
     * Inheriting models should override this method if they need to add additional headers or configuration
     * to the HTTP client.
     *
     * @return DokuHTTPClient
     */
    protected function getHttpClient()
    {
        if ($this->http === null) {
            $this->http = new DokuHTTPClient();
            $this->http->timeout = 60;
            $this->http->headers['Content-Type'] = 'application/json';
            $this->http->headers['Accept'] = 'application/json';
        }

        return $this->http;
    }

    /**
     * This method should check the response for any errors. If the API signalled an error,
     * this method should throw an Exception with a meaningful error message.
     *
     * If the response returned any info on used tokens, they should be added to
     * $this->inputTokensUsed and $this->outputTokensUsed
     *
     * The method should return the parsed response, which will be passed to the calling method.
     *
     * @param mixed $response the parsed JSON response from the API
     * @return mixed
     * @throws \Exception when the response indicates an error
     */
    abstract protected function parseAPIResponse($response);

    /**
     * Send a request to the API
     *
     * Model classes should use this method to send requests to the API.
     *
     * This method will take care of retrying and logging basic statistics. A request is retried
     * up to MAX_RETRIES times when the HTTP request fails or when parseAPIResponse() throws.
     *
     * It is assumed that all APIs speak JSON.
     *
     * @param string $method The HTTP method to use (GET, POST, PUT, DELETE, etc.)
     * @param string $url The full URL to send the request to
     * @param array|string $data Payload to send, will be encoded to JSON
     * @param int $retry How often this request has been retried, do not set externally
     * @return mixed API response as returned by parseAPIResponse
     * @throws \Exception when anything goes wrong
     */
    protected function sendAPIRequest($method, $url, $data, $retry = 0)
    {
        // init statistics
        if ($retry === 0) {
            $this->requestStart = microtime(true);
        } else {
            sleep($retry); // wait a bit between retries, longer with each attempt
        }
        $this->requestsMade++;

        // encode payload data
        try {
            $json = json_encode($data, JSON_THROW_ON_ERROR | JSON_PRETTY_PRINT);
        } catch (\JsonException $e) {
            $this->trackRequestTime();
            throw new \Exception('Failed to encode JSON for API:' . $e->getMessage(), 2003, $e);
        }

        if ($this->debug) {
            echo 'Sending ' . $method . ' request to ' . $url . ' with payload:' . "\n";
            print_r($json);
            echo "\n";
        }

        // send request and handle retries
        $http = $this->getHttpClient();
        $http->sendRequest($url, $json, $method);
        $response = $http->resp_body;
        if ($response === false || $http->error) {
            if ($retry < self::MAX_RETRIES) {
                // the time is tracked by the final call in the retry chain
                return $this->sendAPIRequest($method, $url, $data, $retry + 1);
            }
            $this->trackRequestTime();
            throw new \Exception('API returned no response. ' . $http->error, 2004);
        }

        if ($this->debug) {
            echo 'Received response:' . "\n";
            print_r($response);
            echo "\n";
        }

        // decode the response
        try {
            $result = json_decode((string)$response, true, 512, JSON_THROW_ON_ERROR);
        } catch (\JsonException $e) {
            $this->trackRequestTime();
            throw new \Exception('API returned invalid JSON: ' . $response, 2005, $e);
        }

        // parse the response, retry on error
        try {
            $result = $this->parseAPIResponse($result);
        } catch (\Exception $e) {
            if ($retry < self::MAX_RETRIES) {
                // the time is tracked by the final call in the retry chain
                return $this->sendAPIRequest($method, $url, $data, $retry + 1);
            }
            $this->trackRequestTime();
            throw $e;
        }

        $this->trackRequestTime();
        return $result;
    }

    /**
     * Add the time elapsed since the start of the current request chain to the usage statistics
     *
     * Must be called exactly once per request chain, by the call that ends the chain.
     *
     * @return void
     */
    private function trackRequestTime(): void
    {
        $this->timeUsed += microtime(true) - $this->requestStart;
    }

    // endregion

    // region Tools

    /**
     * Get a configuration value
     *
     * The given key is prefixed by the model namespace (the lowercased provider name and an underscore)
     *
     * @param string $key
     * @param mixed $default The default to return if the key is not found. When set to null an Exception is thrown.
     * @return mixed
     * @throws ModelException when the key is not found and no default is given
     */
    public function getFromConf(string $key, $default = null)
    {
        $key = strtolower($this->selfIdent) . '_' . $key;

        // note: a configured value of null is treated like a missing key
        if (isset($this->config[$key])) {
            return $this->config[$key];
        }
        if ($default !== null) {
            return $default;
        }
        throw new ModelException('Key ' . $key . ' not found in configuration', 3001);
    }

    // endregion
}
356