1<?php
2
3namespace dokuwiki\plugin\aichat\Model;
4
5use dokuwiki\HTTP\DokuHTTPClient;
6
/**
 * Base class for all models
 *
 * Model classes also need to implement one of the following interfaces:
 * - ChatInterface
 * - EmbeddingInterface
 *
 * This class already implements most of the requirements for these interfaces.
 *
 * In addition to any missing interface methods, model implementations will need to
 * extend the constructor to handle the plugin configuration and implement the
 * parseAPIResponse() method to handle the specific API response.
 *
 * Model metadata (token limits, prices, embedding dimensions) is read from a
 * models.json file located in the same directory as the concrete model class.
 */
abstract class AbstractModel implements ModelInterface
{
    /** @var string The model name */
    protected $modelName;
    /** @var string The full model name: "<directory of the concrete class> <model name>" */
    protected $modelFullName;
    /** @var array The model info from the models.json file (or defaults for unknown models) */
    protected $modelInfo;

    /** @var int input tokens used since last reset */
    protected $inputTokensUsed = 0;
    /** @var int output tokens used since last reset */
    protected $outputTokensUsed = 0;
    /** @var float total time (in seconds) spent in requests since last reset */
    protected $timeUsed = 0;
    /** @var int total number of requests made since last reset (each retry counts as a request) */
    protected $requestsMade = 0;
    /** @var float start time (microtime(true)) of the current request chain (may be multiple when retries needed) */
    protected $requestStart = 0;

    /** @var int How often to retry a request if it fails */
    public const MAX_RETRIES = 3;

    /** @var DokuHTTPClient */
    protected $http;
    /** @var bool debug API communication */
    protected $debug = false;

    // region ModelInterface

    /**
     * @inheritdoc
     *
     * Sets up an HTTP client speaking JSON and loads the info for the given model from
     * the models.json file next to the concrete model class. Models not listed there get
     * default info from loadUnknownModelInfo().
     *
     * @throws \Exception when the models.json file is missing or cannot be parsed
     */
    public function __construct(string $name, array $config)
    {
        $this->modelName = $name;
        $this->http = new DokuHTTPClient();
        $this->http->timeout = 60;
        $this->http->headers['Content-Type'] = 'application/json';
        $this->http->headers['Accept'] = 'application/json';

        // the model info file lives next to the concrete (sub)class, not next to this file
        $reflect = new \ReflectionClass($this);
        $classDir = dirname($reflect->getFileName());
        $json = $classDir . '/models.json';
        if (!file_exists($json)) {
            throw new \Exception('Model info file not found at ' . $json);
        }
        try {
            $modelinfos = json_decode(file_get_contents($json), true, 512, JSON_THROW_ON_ERROR);
        } catch (\JsonException $e) {
            throw new \Exception('Failed to parse model info file: ' . $e->getMessage(), $e->getCode(), $e);
        }

        // Only the directory goes through basename(); the model name is appended verbatim so
        // names containing slashes (e.g. "vendor/model") are not truncated.
        $this->modelFullName = basename($classDir) . ' ' . $name;

        if ($this instanceof ChatInterface) {
            if (isset($modelinfos['chat'][$name])) {
                $this->modelInfo = $modelinfos['chat'][$name];
            } else {
                $this->modelInfo = $this->loadUnknownModelInfo();
            }
        }

        // NOTE(review): for a class implementing both interfaces, the embedding lookup
        // below overwrites whatever the chat lookup above stored.
        if ($this instanceof EmbeddingInterface) {
            if (isset($modelinfos['embedding'][$name])) {
                $this->modelInfo = $modelinfos['embedding'][$name];
            } else {
                $this->modelInfo = $this->loadUnknownModelInfo();
            }
        }
    }

    /** @inheritdoc */
    public function __toString(): string
    {
        return $this->modelFullName;
    }

    /** @inheritdoc */
    public function getModelName()
    {
        return $this->modelName;
    }

    /**
     * Reset the usage statistics
     *
     * Usually not needed when only handling one operation per request, but useful in CLI
     */
    public function resetUsageStats()
    {
        $this->inputTokensUsed = 0;
        $this->outputTokensUsed = 0;
        $this->timeUsed = 0;
        $this->requestsMade = 0;
    }

    /**
     * Get the usage statistics for this instance
     *
     * Token prices are divided by one million, i.e. they are treated as prices per
     * million tokens. Output tokens are only priced for chat models.
     *
     * @return array{tokens: int, cost: string, time: float, requests: int}
     */
    public function getUsageStats()
    {
        $cost = $this->inputTokensUsed * $this->getInputTokenPrice();
        if ($this instanceof ChatInterface) {
            $cost += $this->outputTokensUsed * $this->getOutputTokenPrice();
        }

        return [
            'tokens' => $this->inputTokensUsed + $this->outputTokensUsed,
            'cost' => sprintf("%.6f", $cost / 1_000_000),
            'time' => round($this->timeUsed, 2),
            'requests' => $this->requestsMade,
        ];
    }

    /** @inheritdoc */
    public function getMaxInputTokenLength(): int
    {
        return $this->modelInfo['inputTokens'];
    }

    /** @inheritdoc */
    public function getInputTokenPrice(): float
    {
        return $this->modelInfo['inputTokenPrice'];
    }

    /**
     * Build fallback model info for models not listed in models.json
     *
     * Uses conservative token limits and zero prices. Chat models additionally get
     * output token info, embedding models get a dimension count; a class implementing
     * both interfaces gets both sets of keys.
     *
     * @return array
     */
    public function loadUnknownModelInfo(): array
    {
        $info = [
            'description' => $this->modelFullName,
            'inputTokens' => 1024,
            'inputTokenPrice' => 0,
        ];

        if ($this instanceof ChatInterface) {
            $info['outputTokens'] = 1024;
            $info['outputTokenPrice'] = 0;
        }
        if ($this instanceof EmbeddingInterface) {
            $info['dimensions'] = 512;
        }

        return $info;
    }

    // endregion

    // region EmbeddingInterface

    /** @inheritdoc */
    public function getDimensions(): int
    {
        return $this->modelInfo['dimensions'];
    }

    // endregion

    // region ChatInterface

    /** @inheritdoc */
    public function getMaxOutputTokenLength(): int
    {
        return $this->modelInfo['outputTokens'];
    }

    /** @inheritdoc */
    public function getOutputTokenPrice(): float
    {
        return $this->modelInfo['outputTokenPrice'];
    }

    // endregion

    // region API communication

    /**
     * When enabled, the input/output of the API will be printed to STDOUT
     *
     * @param bool $debug
     */
    public function setDebug($debug = true)
    {
        $this->debug = $debug;
    }

    /**
     * This method should check the response for any errors. If the API signalled an error,
     * this method should throw an Exception with a meaningful error message.
     *
     * If the response returned any info on used tokens, they should be added to
     * $this->inputTokensUsed and $this->outputTokensUsed.
     *
     * The method should return the parsed response, which will be passed to the calling method.
     *
     * Note: an exception thrown here causes sendAPIRequest() to retry the request (up to
     * MAX_RETRIES times), so this method may be called several times per logical request.
     *
     * @param mixed $response the parsed JSON response from the API
     * @return mixed
     * @throws \Exception when the response indicates an error
     */
    abstract protected function parseAPIResponse($response);

    /**
     * Send a request to the API
     *
     * Model classes should use this method to send requests to the API.
     *
     * This method will take care of retrying and logging basic statistics.
     * Retries happen when the HTTP request fails or parseAPIResponse() throws,
     * but not when the response is not valid JSON.
     *
     * It is assumed that all APIs speak JSON.
     *
     * @param string $method The HTTP method to use (GET, POST, PUT, DELETE, etc.)
     * @param string $url The full URL to send the request to
     * @param array|string $data Payload to send, will be encoded to JSON
     * @param int $retry How often this request has been retried, do not set externally
     * @return mixed API response as returned by parseAPIResponse
     * @throws \Exception when anything goes wrong
     */
    protected function sendAPIRequest($method, $url, $data, $retry = 0)
    {
        // init statistics
        if ($retry === 0) {
            $this->requestStart = microtime(true);
        } else {
            sleep($retry); // linear back-off between retries
        }
        $this->requestsMade++;

        // encode payload data
        try {
            $json = json_encode($data, JSON_THROW_ON_ERROR | JSON_PRETTY_PRINT);
        } catch (\JsonException $e) {
            // elapsed time is "now - start"; the original subtracted it instead of adding it
            $this->timeUsed += microtime(true) - $this->requestStart;
            throw new \Exception('Failed to encode JSON for API:' . $e->getMessage(), $e->getCode(), $e);
        }

        if ($this->debug) {
            echo 'Sending ' . $method . ' request to ' . $url . ' with payload:' . "\n";
            print_r($json);
            echo "\n";
        }

        // send request and handle retries
        $this->http->sendRequest($url, $json, $method);
        $response = $this->http->resp_body;
        if ($response === false || $this->http->error) {
            if ($retry < self::MAX_RETRIES) {
                return $this->sendAPIRequest($method, $url, $data, $retry + 1);
            }
            $this->timeUsed += microtime(true) - $this->requestStart;
            throw new \Exception('API returned no response. ' . $this->http->error);
        }

        if ($this->debug) {
            echo 'Received response:' . "\n";
            print_r($response);
            echo "\n";
        }

        // decode the response
        try {
            $result = json_decode((string)$response, true, 512, JSON_THROW_ON_ERROR);
        } catch (\JsonException $e) {
            $this->timeUsed += microtime(true) - $this->requestStart;
            throw new \Exception('API returned invalid JSON: ' . $response, 0, $e);
        }

        // parse the response, retry on error
        try {
            $result = $this->parseAPIResponse($result);
        } catch (\Exception $e) {
            if ($retry < self::MAX_RETRIES) {
                return $this->sendAPIRequest($method, $url, $data, $retry + 1);
            }
            $this->timeUsed += microtime(true) - $this->requestStart;
            throw $e;
        }

        // only the outermost successful call reaches here, so the chain's time is added once
        $this->timeUsed += microtime(true) - $this->requestStart;
        return $result;
    }

    // endregion
}
303