# Source code for libcloud.common.aws

# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import hmac
import time
import base64
import hashlib
from typing import Dict, Type, Optional
from hashlib import sha256
from datetime import datetime, timezone

from libcloud.utils.py3 import ET, b, httplib, urlquote, basestring, _real_unicode
from libcloud.utils.xml import findall_ignore_namespace, findtext_ignore_namespace
from libcloud.common.base import BaseDriver, XmlResponse, JsonResponse, ConnectionUserAndKey
from libcloud.common.types import InvalidCredsError, MalformedResponseError

try:
    import simplejson as json
except ImportError:
    import json  # type: ignore


# Public API of this module.
__all__ = [
    "AWSBaseResponse",
    "AWSGenericResponse",
    "AWSTokenConnection",
    "SignedAWSConnection",
    "AWSRequestSignerAlgorithmV2",
    "AWSRequestSignerAlgorithmV4",
    "AWSDriver",
]

# Signature version used by SignedAWSConnection when the caller does not pass
# one; SignedAWSConnection accepts "2" or "4".
DEFAULT_SIGNATURE_VERSION = "2"
# Literal sent instead of a real payload hash (X-AMZ-Content-SHA256 header and
# canonical request) when the Signature Version 4 signer does not hash the body.
UNSIGNED_PAYLOAD = "UNSIGNED-PAYLOAD"

# Message template for the ValueError raised by
# SignedAWSConnection.add_default_params when a request parameter value is not
# a simple type. It is formatted with (key, value, type(value)).
PARAMS_NOT_STRING_ERROR_MSG = """
"params" dictionary contains an attribute "%s" which value (%s, %s) is not a
string.

Parameters are sent via query parameters and not via request body and as such,
all the values need to be of a simple type (string, int, bool).

For arrays and other complex types, you should use notation similar to this
one:

params['TagSpecification.1.Tag.Value'] = 'foo'
params['TagSpecification.2.Tag.Value'] = 'bar'

See https://docs.aws.amazon.com/AWSEC2/latest/APIReference/Query-Requests.html
for details.
""".strip()


class AWSBaseResponse(XmlResponse):
    """
    Base XML response class for the AWS drivers.

    ``namespace`` is handed to the namespace-tolerant XML lookup helpers when
    error details are extracted; subclasses may override it.
    """

    namespace = None

    def _parse_error_details(self, element):
        """
        Extract the error code and message from an AWS error element.

        :param element: XML element whose ``Code`` and ``Message`` children
                        are read.

        :return: ``tuple`` with two elements: (code, message)
        :rtype: ``tuple``
        """
        ns = self.namespace

        def child_text(tag):
            # Look up a direct child's text, tolerating XML namespaces.
            return findtext_ignore_namespace(element=element, xpath=tag, namespace=ns)

        code = child_text("Code")
        message = child_text("Message")
        return code, message


class AWSGenericResponse(AWSBaseResponse):
    """
    Generic AWS XML response.

    It turns AWS error documents either into a specific exception (see
    ``exceptions``) or into an aggregated ``"code: message"`` string.
    """

    # There are multiple error messages in AWS, but they all have an Error node
    # with Code and Message child nodes. Xpath to select them
    # None if the root node *is* the Error node
    xpath = None

    # This dict maps <Error><Code>CodeName</Code></Error> to a specific
    # exception class that is raised immediately.
    # If a custom exception class is not defined, errors are accumulated and
    # returned from the parse_error method.
    exceptions = {}  # type: Dict[str, Type[Exception]]

    def success(self):
        """Return True when the HTTP status is 200 OK, 201 Created or 202 Accepted."""
        ok_statuses = (httplib.OK, httplib.CREATED, httplib.ACCEPTED)
        return self.status in ok_statuses

    def parse_error(self):
        """
        Parse an error response body.

        :raises InvalidCredsError: when the HTTP status is 403 Forbidden.
        :raises MalformedResponseError: when the body cannot be parsed as XML.
        :raises Exception: an instance of the class registered in
            ``exceptions`` for the first error code that has one. It receives
            ``value``, ``driver`` and any connection-context keys listed in
            the class's optional ``kwargs`` attribute.

        :return: Newline-separated ``"code: message"`` lines for every error
                 that has no registered exception class.
        :rtype: ``str``
        """
        context = self.connection.context

        # FIXME: Probably ditch this as the forbidden message will have
        # corresponding XML.
        if int(self.status) == httplib.FORBIDDEN:
            if self.body:
                raise InvalidCredsError(self.body)
            raise InvalidCredsError(str(self.status) + ": " + self.error)

        try:
            root = ET.XML(self.body)
        except Exception:
            raise MalformedResponseError(
                "Failed to parse XML", body=self.body, driver=self.connection.driver
            )

        if self.xpath:
            error_nodes = findall_ignore_namespace(
                element=root, xpath=self.xpath, namespace=self.namespace
            )
        else:
            error_nodes = [root]

        collected = []
        for node in error_nodes:
            code, message = self._parse_error_details(element=node)
            exception_cls = self.exceptions.get(code)

            if exception_cls is None:
                collected.append(f"{code}: {message}")
                continue

            # Custom exception class is defined, immediately throw an
            # exception, forwarding the context values it declares interest in.
            wanted_keys = getattr(exception_cls, "kwargs", ())
            extra = {name: context[name] for name in wanted_keys if name in context}
            raise exception_cls(value=message, driver=self.connection.driver, **extra)

        return "\n".join(collected)


class AWSTokenConnection(ConnectionUserAndKey):
    """
    Connection which can carry an optional AWS session token.

    When ``token`` is truthy, this class adds it under the name
    ``x-amz-security-token`` to the default request params and to the default
    request headers.
    """

    def __init__(
        self,
        user_id,
        key,
        secure=True,
        host=None,
        port=None,
        url=None,
        timeout=None,
        proxy_url=None,
        token=None,
        retry_delay=None,
        backoff=None,
    ):
        # The token is stored before the parent initializer runs.
        self.token = token
        super().__init__(
            user_id,
            key,
            secure=secure,
            host=host,
            port=port,
            url=url,
            proxy_url=proxy_url,
            timeout=timeout,
            retry_delay=retry_delay,
            backoff=backoff,
        )

    def add_default_params(self, params):
        """Add the session token (if any) to ``params``, then defer to the parent."""
        session_token = self.token
        if session_token:
            # Even though we are adding it to the headers, we need it here too
            # so that the token is added to the signature.
            params["x-amz-security-token"] = session_token
        return super().add_default_params(params)

    def add_default_headers(self, headers):
        """Add the session token (if any) to ``headers``, then defer to the parent."""
        session_token = self.token
        if session_token:
            headers["x-amz-security-token"] = session_token
        return super().add_default_headers(headers)


class AWSRequestSigner:
    """
    Base class for objects that sign outgoing AWS requests.

    This implementation passes params and headers through unchanged;
    subclasses implement a concrete signing algorithm.
    """

    def __init__(self, access_key, access_secret, version, connection):
        """
        :param access_key: Access key.
        :type access_key: ``str``

        :param access_secret: Access secret.
        :type access_secret: ``str``

        :param version: API version.
        :type version: ``str``

        :param connection: Connection instance.
        :type connection: :class:`Connection`
        """
        self.access_key, self.access_secret = access_key, access_secret
        self.version = version
        # TODO: Remove cycling dependency between connection and signer
        self.connection = connection

    def get_request_params(self, params, method="GET", path="/"):
        """Return ``params`` unchanged (no signing in the base class)."""
        return params

    def get_request_headers(self, params, headers, method="GET", path="/", data=None):
        """Return ``(params, headers)`` unchanged (no signing in the base class)."""
        return params, headers


class AWSRequestSignerAlgorithmV2(AWSRequestSigner):
    """
    Sign requests with AWS Signature Version 2 (HmacSHA256).

    The signature and its supporting values are added to the request
    parameters.
    """

    def get_request_params(self, params, method="GET", path="/"):
        """
        Add the Signature Version 2 authentication parameters to ``params``.

        :param params: Request parameters; modified in place.
        :type params: ``dict``

        :param method: HTTP method of the request. It is the HTTPVerb line of
                       the string that gets signed.
        :type method: ``str``

        :param path: Request path (HTTPRequestURI in the string to sign).
        :type path: ``str``

        :return: ``params`` with ``SignatureVersion``, ``SignatureMethod``,
                 ``AWSAccessKeyId``, ``Version``, ``Timestamp`` and
                 ``Signature`` set.
        :rtype: ``dict``
        """
        params["SignatureVersion"] = "2"
        params["SignatureMethod"] = "HmacSHA256"
        params["AWSAccessKeyId"] = self.access_key
        params["Version"] = self.version
        params["Timestamp"] = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
        # Signature is computed last so it covers every parameter set above.
        params["Signature"] = self._get_aws_auth_param(
            params=params, secret_key=self.access_secret, path=path, method=method
        )
        return params

    def _get_aws_auth_param(self, params, secret_key, path="/", method="GET"):
        """
        Creates the signature required for AWS, per http://bit.ly/aR7GaQ
        [docs.amazonwebservices.com]:

        StringToSign = HTTPVerb + "\n" +
                       ValueOfHostHeaderInLowercase + "\n" +
                       HTTPRequestURI + "\n" +
                       CanonicalizedQueryString <from the preceding step>

        :param params: Parameters to sign (including the auth parameters).
        :type params: ``dict``

        :param secret_key: Secret used as the HMAC-SHA256 key.
        :type secret_key: ``str``

        :param path: HTTPRequestURI.
        :type path: ``str``

        :param method: HTTPVerb. Defaults to ``"GET"``, which keeps the
                       previous behaviour for existing callers.
        :type method: ``str``

        :return: Base64-encoded HMAC-SHA256 signature.
        :rtype: ``str``
        """
        connection = self.connection

        # Parameters sorted by name; names fully percent-encoded, values
        # encoded with only "-_~" left unescaped.
        pairs = []
        for key in sorted(params.keys()):
            value = str(params[key])
            pairs.append(urlquote(key, safe="") + "=" + urlquote(value, safe="-_~"))
        canonical_query = "&".join(pairs)

        # The spec requires the Host header value in lowercase; the port is
        # included only when it is not the scheme's default.
        hostname = connection.host.lower()
        if (connection.secure and connection.port != 443) or (
            not connection.secure and connection.port != 80
        ):
            hostname += ":" + str(connection.port)

        # The actual HTTP verb is part of the signed string; previously "GET"
        # was hard-coded, which produced invalid signatures for POST requests.
        http_verb = (method or "GET").upper()

        string_to_sign = "\n".join((http_verb, hostname, path, canonical_query))
        b64_hmac = base64.b64encode(
            hmac.new(b(secret_key), b(string_to_sign), digestmod=sha256).digest()
        )

        return b64_hmac.decode("utf-8")


class AWSRequestSignerAlgorithmV4(AWSRequestSigner):
    """
    Sign requests with AWS Signature Version 4 (``AWS4-HMAC-SHA256``).

    The signature is placed in the ``Authorization`` header. The
    ``X-AMZ-Date`` and ``X-AMZ-Content-SHA256`` headers are set alongside it.
    """

    def get_request_params(self, params, method="GET", path="/"):
        """Add the API ``Version`` parameter for GET requests; return ``params``."""
        if method == "GET":
            params["Version"] = self.version
        return params

    def get_request_headers(self, params, headers, method="GET", path="/", data=None):
        """
        Add the Signature Version 4 headers to ``headers``.

        :return: ``(params, headers)`` with ``X-AMZ-Date``,
                 ``X-AMZ-Content-SHA256`` and ``Authorization`` set.
        :rtype: ``tuple``
        """
        # datetime.utcnow() is deprecated since Python 3.12. An aware UTC
        # datetime produces exactly the same strings with the formats used here.
        now = datetime.now(timezone.utc)
        headers["X-AMZ-Date"] = now.strftime("%Y%m%dT%H%M%SZ")
        headers["X-AMZ-Content-SHA256"] = self._get_payload_hash(method, data)
        # Authorization must be computed after the two headers above, because
        # every header present at this point is signed.
        headers["Authorization"] = self._get_authorization_v4_header(
            params=params, headers=headers, dt=now, method=method, path=path, data=data
        )

        return params, headers

    def _get_authorization_v4_header(self, params, headers, dt, method="GET", path="/", data=None):
        """Build the ``Authorization`` header value (credential, signed headers, signature)."""
        credentials_scope = self._get_credential_scope(dt=dt)
        signed_headers = self._get_signed_headers(headers=headers)
        signature = self._get_signature(
            params=params, headers=headers, dt=dt, method=method, path=path, data=data
        )
        return (
            "AWS4-HMAC-SHA256 Credential=%(u)s/%(c)s, "
            "SignedHeaders=%(sh)s, Signature=%(s)s"
            % {
                "u": self.access_key,
                "c": credentials_scope,
                "sh": signed_headers,
                "s": signature,
            }
        )

    def _get_signature(self, params, headers, dt, method, path, data):
        """Return the hex HMAC-SHA256 of the string to sign, keyed with the derived key."""
        key = self._get_key_to_sign_with(dt)
        string_to_sign = self._get_string_to_sign(
            params=params, headers=headers, dt=dt, method=method, path=path, data=data
        )
        return _sign(key=key, msg=string_to_sign, hex=True)

    def _get_key_to_sign_with(self, dt):
        """
        Derive the signing key by chained HMACs: "AWS4" + secret over the date,
        then the driver's region name, the connection's service name and
        finally "aws4_request".
        """
        return _sign(
            _sign(
                _sign(
                    _sign(("AWS4" + self.access_secret), dt.strftime("%Y%m%d")),
                    self.connection.driver.region_name,
                ),
                self.connection.service_name,
            ),
            "aws4_request",
        )

    def _get_string_to_sign(self, params, headers, dt, method, path, data):
        """Join the algorithm, timestamp, credential scope and canonical-request hash."""
        canonical_request = self._get_canonical_request(
            params=params, headers=headers, method=method, path=path, data=data
        )
        return "\n".join(
            [
                "AWS4-HMAC-SHA256",
                dt.strftime("%Y%m%dT%H%M%SZ"),
                self._get_credential_scope(dt),
                _hash(canonical_request),
            ]
        )

    def _get_credential_scope(self, dt):
        """Return ``date/region/service/aws4_request``."""
        return "/".join(
            [
                dt.strftime("%Y%m%d"),
                self.connection.driver.region_name,
                self.connection.service_name,
                "aws4_request",
            ]
        )

    def _get_signed_headers(self, headers):
        """Return the lowercased header names, sorted case-insensitively, joined by ';'."""
        return ";".join([k.lower() for k in sorted(headers.keys(), key=str.lower)])

    def _get_canonical_headers(self, headers):
        """Return ``name:value`` lines (lowercased names, stripped values) plus a trailing newline."""
        return (
            "\n".join(
                [
                    ":".join([k.lower(), str(v).strip()])
                    for k, v in sorted(headers.items(), key=lambda k: k[0].lower())
                ]
            )
            + "\n"
        )

    def _get_payload_hash(self, method, data=None):
        """
        Return the value used as the payload hash.

        ``UNSIGNED_PAYLOAD`` is returned when ``data`` is the
        ``UnsignedPayloadSentinel`` class, and for POST/PUT requests whose
        body is empty or an iterator. Other POST/PUT bodies get the hex SHA-256
        of ``data``. For any other method, the hash of an empty string is used.
        """
        if data is UnsignedPayloadSentinel:
            return UNSIGNED_PAYLOAD
        if method in ("POST", "PUT"):
            if data:
                if hasattr(data, "next") or hasattr(data, "__next__"):
                    # File upload; don't try to read the entire payload
                    return UNSIGNED_PAYLOAD
                return _hash(data)
            else:
                return UNSIGNED_PAYLOAD
        else:
            return _hash("")

    def _get_request_params(self, params):
        """Return the canonical query string: params sorted by key and percent-encoded."""
        # For self.method == GET
        return "&".join(
            [
                "{}={}".format(urlquote(k, safe=""), urlquote(str(v), safe="~"))
                for k, v in sorted(params.items())
            ]
        )

    def _get_canonical_request(self, params, headers, method, path, data):
        """Join method, path, query string, headers, signed-header list and payload hash."""
        return "\n".join(
            [
                method,
                path,
                self._get_request_params(params),
                self._get_canonical_headers(headers),
                self._get_signed_headers(headers),
                self._get_payload_hash(method, data),
            ]
        )


class UnsignedPayloadSentinel:
    """
    Marker passed as request ``data`` to ask the Signature Version 4 signer
    for an ``UNSIGNED-PAYLOAD`` content hash.

    The signer compares ``data is UnsignedPayloadSentinel``, so the class
    object itself is the sentinel.
    """


class SignedAWSConnection(AWSTokenConnection):
    """
    AWS connection which signs every request with Signature Version 2 or 4.

    The algorithm is picked from ``signature_version`` ("2" or "4"). Any
    other value raises ``ValueError``.
    """

    version = None  # type: Optional[str]

    def __init__(
        self,
        user_id,
        key,
        secure=True,
        host=None,
        port=None,
        url=None,
        timeout=None,
        proxy_url=None,
        token=None,
        retry_delay=None,
        backoff=None,
        signature_version=DEFAULT_SIGNATURE_VERSION,
    ):
        super().__init__(
            user_id=user_id,
            key=key,
            secure=secure,
            host=host,
            port=port,
            url=url,
            timeout=timeout,
            proxy_url=proxy_url,
            token=token,
            retry_delay=retry_delay,
            backoff=backoff,
        )
        self.signature_version = str(signature_version)

        signer_by_version = {
            "2": AWSRequestSignerAlgorithmV2,
            "4": AWSRequestSignerAlgorithmV4,
        }
        signer_cls = signer_by_version.get(self.signature_version)
        if signer_cls is None:
            raise ValueError("Unsupported signature_version: %s" % (signature_version))

        self.signer = signer_cls(
            access_key=self.user_id,
            access_secret=self.key,
            version=self.version,
            connection=self,
        )

    def add_default_params(self, params):
        """
        Let the signer add its parameters, then check that every value is a
        simple type.

        :raises ValueError: if a parameter value is not a string, int or bool.
        """
        # NOTE(review): unlike AWSTokenConnection.add_default_params this does
        # not call super(), so the session token is not added to params here
        # (it is still sent as a header) -- confirm this is intended.
        params = self.signer.get_request_params(params=params, method=self.method, path=self.action)

        # params are sent via query params, so nested dictionaries and other
        # complex values are rejected.
        simple_types = (_real_unicode, basestring, int, bool)
        for name, val in params.items():
            if isinstance(val, simple_types):
                continue
            raise ValueError(PARAMS_NOT_STRING_ERROR_MSG % (name, val, type(val)))

        return params

    def pre_connect_hook(self, params, headers):
        """Let the signer add its headers just before the request is sent."""
        signed_params, signed_headers = self.signer.get_request_headers(
            params=params,
            headers=headers,
            method=self.method,
            path=self.action,
            data=self.data,
        )
        return signed_params, signed_headers


class AWSJsonResponse(JsonResponse):
    """
    Amazon ECS response class.
    ECS API uses JSON unlike the s3, elb drivers
    """

    def parse_error(self):
        """
        Build an error message from a JSON error body.

        The code comes from the ``__type`` key. The message comes from
        ``Message`` or, failing that, ``message``; it is ``None`` when neither
        is present.

        :return: ``"<code>: <message>"``
        :rtype: ``str``
        """
        response = json.loads(self.body)
        code = response["__type"]
        # The fallback must be looked up lazily: the previous
        # response.get("Message", response["message"]) evaluated
        # response["message"] eagerly and raised KeyError whenever the body
        # only had "Message".
        message = response.get("Message", response.get("message"))
        return "{}: {}".format(code, message)


def _sign(key, msg, hex=False):
    """
    Return the HMAC-SHA256 of ``msg`` keyed with ``key``.

    :param hex: When True return the hex digest string, otherwise raw bytes.
    """
    mac = hmac.new(b(key), b(msg), hashlib.sha256)
    return mac.hexdigest() if hex else mac.digest()


def _hash(msg):
    """Return the hex SHA-256 digest of ``msg``."""
    return hashlib.sha256(b(msg)).hexdigest()


class AWSDriver(BaseDriver):
    """
    Base driver for AWS services with an optional session ``token``.

    The token is stored on the driver, passed to the parent initializer and
    added to the connection-class kwargs.
    """

    def __init__(
        self,
        key,
        secret=None,
        secure=True,
        host=None,
        port=None,
        api_version=None,
        region=None,
        token=None,
        **kwargs,
    ):
        # Stored before the parent initializer runs.
        self.token = token
        super().__init__(
            key,
            secret=secret,
            secure=secure,
            host=host,
            port=port,
            region=region,
            api_version=api_version,
            token=token,
            **kwargs,
        )

    def _ex_connection_class_kwargs(self):
        """Return the parent's connection kwargs with ``token`` added."""
        connection_kwargs = super()._ex_connection_class_kwargs()
        connection_kwargs["token"] = self.token
        return connection_kwargs