1<?php
2
3namespace dokuwiki\plugin\aichat\Model;
4
5use dokuwiki\HTTP\DokuHTTPClient;
6
7/**
8 * Base class for all models
9 *
10 * Model classes also need to implement one of the following interfaces:
11 * - ChatInterface
12 * - EmbeddingInterface
13 *
14 * This class already implements most of the requirements for these interfaces.
15 *
16 * In addition to any missing interface methods, model implementations will need to
17 * extend the constructor to handle the plugin configuration and implement the
18 * parseAPIResponse() method to handle the specific API response.
19 */
20abstract class AbstractModel implements ModelInterface
21{
22    /** @var string The model name */
23    protected $modelName;
24    /** @var string The full model name */
25    protected $modelFullName;
26    /** @var array The model info from the model.json file */
27    protected $modelInfo;
28    /** @var string The provider name */
29    protected $selfIdent;
30
31    /** @var int input tokens used since last reset */
32    protected $inputTokensUsed = 0;
33    /** @var int output tokens used since last reset */
34    protected $outputTokensUsed = 0;
35    /** @var int total time spent in requests since last reset */
36    protected $timeUsed = 0;
37    /** @var int total number of requests made since last reset */
38    protected $requestsMade = 0;
39    /** @var int start time of the current request chain (may be multiple when retries needed) */
40    protected $requestStart = 0;
41
42    /** @var int How often to retry a request if it fails */
43    public const MAX_RETRIES = 3;
44
45    /** @var DokuHTTPClient */
46    protected $http;
47    /** @var bool debug API communication */
48    protected $debug = false;
49    /** @var string The base API URL. Configurable for some models */
50    protected $apiurl = '';
51
52    /** @var array The plugin configuration */
53    protected $config;
54
55    // region ModelInterface
56
57    /** @inheritdoc */
58    public function __construct(string $name, array $config)
59    {
60        $this->modelName = $name;
61        $this->config = $config;
62
63        $reflect = new \ReflectionClass($this);
64        $json = dirname($reflect->getFileName()) . '/models.json';
65        if (!file_exists($json)) {
66            throw new \Exception('Model info file not found at ' . $json, 2001);
67        }
68        try {
69            $modelinfos = json_decode(file_get_contents($json), true, 512, JSON_THROW_ON_ERROR);
70        } catch (\JsonException $e) {
71            throw new \Exception('Failed to parse model info file: ' . $e->getMessage(), 2002, $e);
72        }
73
74        $this->selfIdent = basename(dirname($reflect->getFileName()));
75        $this->modelFullName = basename(dirname($reflect->getFileName())) . ' ' . $name;
76
77        if ($this->apiurl === '') {
78            // we use an empty default here, since some models may not use this property
79            $this->apiurl = $this->getFromConf('apiurl', '');
80        }
81        $this->apiurl = rtrim($this->apiurl, '/');
82
83        if ($this instanceof ChatInterface) {
84            if (isset($modelinfos['chat'][$name])) {
85                $this->modelInfo = $modelinfos['chat'][$name];
86            } else {
87                $this->modelInfo = $this->loadUnknownModelInfo();
88            }
89
90        }
91
92        if ($this instanceof EmbeddingInterface) {
93            if (isset($modelinfos['embedding'][$name])) {
94                $this->modelInfo = $modelinfos['embedding'][$name];
95            } else {
96                $this->modelInfo = $this->loadUnknownModelInfo();
97            }
98        }
99    }
100
101    /** @inheritdoc */
102    public function __toString(): string
103    {
104        return $this->modelFullName;
105    }
106
107    /** @inheritdoc */
108    public function getModelName()
109    {
110        return $this->modelName;
111    }
112
113    /**
114     * Reset the usage statistics
115     *
116     * Usually not needed when only handling one operation per request, but useful in CLI
117     */
118    public function resetUsageStats()
119    {
120        $this->inputTokensUsed = 0;
121        $this->outputTokensUsed = 0;
122        $this->timeUsed = 0;
123        $this->requestsMade = 0;
124    }
125
126    /**
127     * Get the usage statistics for this instance
128     *
129     * @return string[]
130     */
131    public function getUsageStats()
132    {
133
134        $cost = 0;
135        $cost += $this->inputTokensUsed * $this->getInputTokenPrice();
136        if ($this instanceof ChatInterface) {
137            $cost += $this->outputTokensUsed * $this->getOutputTokenPrice();
138        }
139
140        return [
141            'tokens' => $this->inputTokensUsed + $this->outputTokensUsed,
142            'cost' => sprintf("%.6f", $cost / 1_000_000),
143            'time' => round($this->timeUsed, 2),
144            'requests' => $this->requestsMade,
145        ];
146    }
147
148    /** @inheritdoc */
149    public function getMaxInputTokenLength(): int
150    {
151        return $this->modelInfo['inputTokens'] ?? 0;
152    }
153
154    /** @inheritdoc */
155    public function getInputTokenPrice(): float
156    {
157        return $this->modelInfo['inputTokenPrice'] ?? 0;
158    }
159
160    /** @inheritdoc */
161    function loadUnknownModelInfo(): array
162    {
163        $info = [
164            'description' => $this->modelFullName,
165            'inputTokens' => 0,
166            'inputTokenPrice' => 0,
167        ];
168
169        if ($this instanceof ChatInterface) {
170            $info['outputTokens'] = 0;
171            $info['outputTokenPrice'] = 0;
172        } elseif ($this instanceof EmbeddingInterface) {
173            $info['dimensions'] = 512;
174        }
175
176        return $info;
177    }
178
179    // endregion
180
181    // region EmbeddingInterface
182
183    /** @inheritdoc */
184    public function getDimensions(): int
185    {
186        return $this->modelInfo['dimensions'];
187    }
188
189    // endregion
190
191    // region ChatInterface
192
193    public function getMaxOutputTokenLength(): int
194    {
195        return $this->modelInfo['outputTokens'];
196    }
197
198    public function getOutputTokenPrice(): float
199    {
200        return $this->modelInfo['outputTokenPrice'];
201    }
202
203    // endregion
204
205    // region API communication
206
207    /**
208     * When enabled, the input/output of the API will be printed to STDOUT
209     *
210     * @param bool $debug
211     */
212    public function setDebug($debug = true)
213    {
214        $this->debug = $debug;
215    }
216
217    /**
218     * Get the HTTP client used for API requests
219     *
220     * This method will create a new DokuHTTPClient instance if it does not exist yet.
221     * The client will be configured with a timeout and the appropriate headers for JSON communication.
222     * Inheriting models should override this method if they need to add additional headers or configuration
223     * to the HTTP client.
224     *
225     * @return DokuHTTPClient
226     */
227    protected function getHttpClient()
228    {
229        if ($this->http === null) {
230            $this->http = new DokuHTTPClient();
231            $this->http->timeout = 60;
232            $this->http->headers['Content-Type'] = 'application/json';
233            $this->http->headers['Accept'] = 'application/json';
234        }
235
236        return $this->http;
237    }
238
239    /**
240     * This method should check the response for any errors. If the API singalled an error,
241     * this method should throw an Exception with a meaningful error message.
242     *
243     * If the response returned any info on used tokens, they should be added to $this->tokensUsed
244     *
245     * The method should return the parsed response, which will be passed to the calling method.
246     *
247     * @param mixed $response the parsed JSON response from the API
248     * @return mixed
249     * @throws \Exception when the response indicates an error
250     */
251    abstract protected function parseAPIResponse($response);
252
253    /**
254     * Send a request to the API
255     *
256     * Model classes should use this method to send requests to the API.
257     *
258     * This method will take care of retrying and logging basic statistics.
259     *
260     * It is assumed that all APIs speak JSON.
261     *
262     * @param string $method The HTTP method to use (GET, POST, PUT, DELETE, etc.)
263     * @param string $url The full URL to send the request to
264     * @param array|string $data Payload to send, will be encoded to JSON
265     * @param int $retry How often this request has been retried, do not set externally
266     * @return array API response as returned by parseAPIResponse
267     * @throws \Exception when anything goes wrong
268     */
269    protected function sendAPIRequest($method, $url, $data, $retry = 0)
270    {
271        // init statistics
272        if ($retry === 0) {
273            $this->requestStart = microtime(true);
274        } else {
275            sleep($retry); // wait a bit between retries
276        }
277        $this->requestsMade++;
278
279        // encode payload data
280        try {
281            $json = json_encode($data, JSON_THROW_ON_ERROR | JSON_PRETTY_PRINT);
282        } catch (\JsonException $e) {
283            $this->timeUsed += $this->requestStart - microtime(true);
284            throw new \Exception('Failed to encode JSON for API:' . $e->getMessage(), 2003, $e);
285        }
286
287        if ($this->debug) {
288            echo 'Sending ' . $method . ' request to ' . $url . ' with payload:' . "\n";
289            print_r($json);
290            echo "\n";
291        }
292
293        // send request and handle retries
294        $http = $this->getHttpClient();
295        $http->sendRequest($url, $json, $method);
296        $response = $http->resp_body;
297        if ($response === false || $http->error) {
298            if ($retry < self::MAX_RETRIES) {
299                return $this->sendAPIRequest($method, $url, $data, $retry + 1);
300            }
301            $this->timeUsed += microtime(true) - $this->requestStart;
302            throw new \Exception('API returned no response. ' . $http->error, 2004);
303        }
304
305        if ($this->debug) {
306            echo 'Received response:' . "\n";
307            print_r($response);
308            echo "\n";
309        }
310
311        // decode the response
312        try {
313            $result = json_decode((string)$response, true, 512, JSON_THROW_ON_ERROR);
314        } catch (\JsonException $e) {
315            $this->timeUsed += microtime(true) - $this->requestStart;
316            throw new \Exception('API returned invalid JSON: ' . $response, 2005, $e);
317        }
318
319        // parse the response, retry on error
320        try {
321            $result = $this->parseAPIResponse($result);
322        } catch (\Exception $e) {
323            if ($retry < self::MAX_RETRIES) {
324                return $this->sendAPIRequest($method, $url, $data, $retry + 1);
325            }
326            $this->timeUsed += microtime(true) - $this->requestStart;
327            throw $e;
328        }
329
330        $this->timeUsed += microtime(true) - $this->requestStart;
331        return $result;
332    }
333
334    // endregion
335
336    // region Tools
337
338    /**
339     * Get a configuration value
340     *
341     * The given key is prefixed by the model namespace
342     *
343     * @param string $key
344     * @param mixed $default The default to return if the key is not found. When set to null an Exception is thrown.
345     * @return mixed
346     * @throws ModelException when the key is not found and no default is given
347     */
348    public function getFromConf(string $key, $default = null)
349    {
350        $config = $this->config;
351
352        $key = strtolower($this->selfIdent) . '_' . $key;
353        if (isset($config[$key])) {
354            return $config[$key];
355        }
356        if ($default !== null) {
357            return $default;
358        }
359        throw new ModelException('Key ' . $key . ' not found in configuration', 3001);
360    }
361
362    // endregion
363}
364