1<?php
2
3namespace Elastica\Transport;
4
5use Aws\Credentials\CredentialProvider;
6use Aws\Credentials\Credentials;
7use Aws\Signature\SignatureV4;
8use Elastica\Connection;
9use GuzzleHttp;
10use GuzzleHttp\Client;
11use GuzzleHttp\HandlerStack;
12use GuzzleHttp\Middleware;
13use GuzzleHttp\Psr7;
14use Psr\Http\Message\RequestInterface;
15
/**
 * Guzzle transport that signs every outgoing request with AWS Signature Version 4
 * for the "es" service.
 *
 * Connection params read by this transport:
 *  - aws_region: region used for signing (falls back to the AWS_REGION environment variable)
 *  - aws_credential_provider: callable credential provider; takes precedence over static keys
 *  - aws_access_key_id / aws_secret_access_key / aws_session_token: static credentials
 *  - ssl: when truthy the request goes over https on port 443, otherwise http on port 80
 */
class AwsAuthV4 extends Guzzle
{
    /**
     * Returns the Guzzle client used to send requests. The SigV4 signing middleware
     * is pushed onto its handler stack under the name "sign".
     *
     * The client is cached in the inherited static $_guzzleClientConnection and
     * reused while $persistent is true. A non-persistent call builds a fresh client
     * and stores it in that same static slot.
     *
     * NOTE(review): because the cache is a static property inherited from the parent
     * transport, the cached client, and the region/credentials captured by its
     * middleware, is presumably shared with other Guzzle-based transports and
     * connections in the same process. Confirm this before relying on
     * per-connection credentials.
     */
    protected function _getGuzzleClient(bool $persistent = true): Client
    {
        if (!$persistent || !self::$_guzzleClientConnection) {
            $stack = HandlerStack::create(GuzzleHttp\choose_handler());
            $stack->push($this->getSigningMiddleware(), 'sign');

            self::$_guzzleClientConnection = new Client([
                'handler' => $stack,
            ]);
        }

        return self::$_guzzleClientConnection;
    }

    /**
     * Forces the scheme and port derived from the "ssl" param, then delegates URL
     * construction to the parent transport.
     */
    protected function _getBaseUrl(Connection $connection): string
    {
        $this->initializePortAndScheme();

        return parent::_getBaseUrl($connection);
    }

    /**
     * Builds a Guzzle request-mapping middleware that signs each request with
     * SigV4 for the "es" service.
     *
     * The region is read from the "aws_region" connection param or, failing that,
     * from the AWS_REGION environment variable.
     * NOTE(review): \getenv() returns false when AWS_REGION is unset. That value is
     * passed to SignatureV4 as-is, so signing would then use an empty region.
     */
    private function getSigningMiddleware(): callable
    {
        $connection = $this->getConnection();
        $region = $connection->hasParam('aws_region')
            ? $connection->getParam('aws_region')
            : \getenv('AWS_REGION');
        $signer = new SignatureV4('es', $region);
        $credProvider = $this->getCredentialProvider();
        $transport = $this;

        return Middleware::mapRequest(static function (RequestInterface $req) use (
            $signer,
            $credProvider,
            $transport
        ) {
            // The credential provider returns a promise; wait() resolves it
            // synchronously for every request before signing.
            return $signer->signRequest($transport->sanitizeRequest($req), $credProvider()->wait());
        });
    }

    /**
     * Strips trailing dots from the Host header so the signed host matches what
     * AWS expects.
     *
     * Trailing dots are valid parts of DNS host names (see RFC 1034), but they
     * interfere with header signing, where AWS expects a stripped host name.
     * Requests without a Host header, or whose host has no trailing dot, are
     * returned unchanged.
     */
    private function sanitizeRequest(RequestInterface $request): RequestInterface
    {
        // getHeader() returns an empty array for a missing header. Fall back to ''
        // instead of indexing [0] blindly, which would raise an undefined-offset
        // notice and pass null to substr().
        $host = $request->getHeader('Host')[0] ?? '';

        if ('' === $host || '.' !== \substr($host, -1)) {
            return $request;
        }

        // withHeader() is the PSR-7 way to replace a header. It keeps the rest of
        // the request (body, protocol version, request target) intact and does not
        // depend on which guzzlehttp/psr7 helper functions are available.
        return $request->withHeader('Host', \rtrim($host, '.'));
    }

    /**
     * Resolves the credential provider used for signing, in order of precedence:
     *  1. the "aws_credential_provider" connection param, returned as-is;
     *  2. static credentials, used when "aws_secret_access_key" is set (together
     *     with "aws_access_key_id" and an optional "aws_session_token");
     *  3. CredentialProvider::defaultProvider() from the AWS SDK.
     */
    private function getCredentialProvider(): callable
    {
        $connection = $this->getConnection();

        if ($connection->hasParam('aws_credential_provider')) {
            return $connection->getParam('aws_credential_provider');
        }

        if ($connection->hasParam('aws_secret_access_key')) {
            return CredentialProvider::fromCredentials(new Credentials(
                $connection->getParam('aws_access_key_id'),
                $connection->getParam('aws_secret_access_key'),
                $connection->hasParam('aws_session_token')
                    ? $connection->getParam('aws_session_token')
                    : null
            ));
        }

        return CredentialProvider::defaultProvider();
    }

    /**
     * Sets the transport scheme and the connection port from the "ssl" param:
     * https/443 when SSL is required, http/80 otherwise.
     *
     * Note that this overwrites any port previously configured on the connection.
     */
    private function initializePortAndScheme(): void
    {
        $connection = $this->getConnection();
        if (true === $this->isSslRequired($connection)) {
            $this->_scheme = 'https';
            $connection->setPort(443);
        } else {
            $this->_scheme = 'http';
            $connection->setPort(80);
        }
    }

    /**
     * Returns the "ssl" connection param cast to bool, or $default when the param
     * is absent.
     */
    private function isSslRequired(Connection $conn, bool $default = false): bool
    {
        return $conn->hasParam('ssl')
            ? (bool) $conn->getParam('ssl')
            : $default;
    }
}
113