"""Client for talking to OpenStack APIs using twisted

This is not a complete implementation of all interfaces, just what juju needs.

There is a fair bit of code cleanup and feature implementation to do here
still.

* Must check https certificates, can use code in txaws to do this.
* Must support user/password authentication with keystone as well as keypair.
* Want a ProviderInteractionError subclass that can include the extra details
  returned in json form when something goes wrong and is raised by clients.
* Request flow and json handling in general needs polish.
* Need to prevent concurrent authentication attempts.
* Need to limit concurrent http api requests to 4 or something reasonable,
  can use DeferredSemaphore for this.
* Should really have authentication retry logic in case the token expires.
* Would be nice to use Agent keep alive support that twisted 12.1.0 added.
"""

import base64
import json
import logging
import operator
import urllib

import twisted
from twisted.internet.defer import (
    Deferred, inlineCallbacks, returnValue, succeed)
from twisted.internet.protocol import Protocol
from twisted.internet.interfaces import IProducer
from twisted.internet import reactor

from twisted.web.client import Agent
# Older twisted versions don't expose _newclient exceptions via client module
try:
    from twisted.web.client import ResponseDone, ResponseFailed
except ImportError:
    from twisted.web._newclient import ResponseDone, ResponseFailed
from twisted.web.http_headers import Headers
from zope.interface import implements

import juju
from juju import errors

try:
    from ._ssl import SSLError, WebVerifyingContextFactory
except ImportError:
    # Without the optional SSL helpers certificate checking is unavailable
    # (_OpenStackClient refuses check_certs=True), but SSLError must still be
    # bound because _translate_response_failed tests failure reasons against
    # it for every failed request.
    WebVerifyingContextFactory = None

    class SSLError(Exception):
        """Stand-in raised by nothing; keeps isinstance checks valid."""


# Logger used by the OpenStack client code in this module.
log = logging.getLogger("juju.openstack")


# User-Agent header value set on every request made by request() below,
# identifying both the juju and twisted versions in use.
_USER_AGENT = "juju/%s twisted/%s" % (juju.__version__, twisted.__version__)


class BytestringProducer(object):
    """Wrap basic bytestring as a needlessly fancy twisted producer.

    Instances are handed to Agent.request as the request body. The whole
    bytestring is written in one go when production starts, so the flow
    control hooks have nothing to do.
    """

    implements(IProducer)

    def __init__(self, bytestring):
        self.content = bytestring
        # 'length' (body size in bytes) is part of twisted's IBodyProducer
        # contract for request bodies.
        self.length = len(bytestring)

    def pauseProducing(self):
        """Nothing to do if production is paused"""

    def resumeProducing(self):
        """Nothing to do if production is resumed

        Required by twisted's IBodyProducer (via IPushProducer); without it
        a transport resuming a paused producer would raise AttributeError.
        """

    def startProducing(self, consumer):
        """Write entire contents when production starts

        Returns an already-fired deferred as all data has been written.
        """
        consumer.write(self.content)
        return succeed(None)

    def stopProducing(self):
        """Nothing to do when production halts"""


class ResponseReader(Protocol):
    """Protocol object suitable for use with Response.deliverBody

    The 'onConnectionLost' deferred will be called back once the connection
    is shut down with all the bytes from the body collected at that point.
    If the connection is lost with ResponseFailed (the body did not arrive
    intact) the deferred is errbacked with that failure instead, so a
    truncated body is never mistaken for a complete one.
    """

    def __init__(self):
        self.onConnectionLost = Deferred()
        # Also initialised here so connectionLost cannot hit an
        # AttributeError even if connectionMade was never called.
        self.data = []

    def connectionMade(self):
        self.data = []

    def dataReceived(self, data):
        self.data.append(data)

    def connectionLost(self, reason):
        """Called on connection shut down

        Here 'reason' can be one of ResponseDone, PotentialDataLost, or
        ResponseFailed. ResponseDone and PotentialDataLost are treated as
        success and deliver the collected bytes; ResponseFailed means the
        body is incomplete and is propagated as an error.
        """
        if reason.check(ResponseFailed):
            self.onConnectionLost.errback(reason)
            return
        self.onConnectionLost.callback("".join(self.data))


def _translate_response_failed(failure):
    """Turn internal twisted client failures into juju exceptions

    Used as an errback: when a ResponseFailed wraps an SSL error among its
    reasons, the first such error is re-raised as
    errors.SSLVerificationError. Any other failure is passed through.
    """
    txerr = failure.value
    if not isinstance(txerr, ResponseFailed):
        return failure
    ssl_error = next(
        (reason.value for reason in txerr.reasons
         if isinstance(reason.value, SSLError)),
        None)
    if ssl_error is None:
        return failure
    raise errors.SSLVerificationError(ssl_error)


@inlineCallbacks
def request(method, url, extra_headers=(), body=None, check_certs=False):
    """Make an HTTP request and collect the whole response body.

    :param method: HTTP method name, e.g. "GET".
    :param url: Full url to request.
    :param extra_headers: Iterable of (name, value) header pairs; each one
        replaces any earlier value set for the same header name.
    :param body: None for no body, a dict to send as json, or a str to send
        as application/octet-stream.
    :param check_certs: Verify the server certificate using
        WebVerifyingContextFactory.
    :return: Deferred firing with a (response, body_string) tuple.
    :raises TypeError: if 'body' is not None, a dict or a str.
    """
    headers = Headers({
        # GZ 2012-07-03: Previously passed Accept: application/json header
        #                here, but not always the right thing. Bad for swift?
        "User-Agent": [_USER_AGENT],
        })
    for header, value in extra_headers:
        headers.setRawHeaders(header, [value])
    if body is not None:
        if isinstance(body, dict):
            content_type = "application/json"
            body = json.dumps(body)
        elif isinstance(body, str):
            content_type = "application/octet-stream"
        else:
            # Without this branch content_type would be unbound below and
            # the caller would see a confusing UnboundLocalError.
            raise TypeError("Request body must be a dict or str, not %s"
                            % type(body).__name__)
        headers.setRawHeaders("Content-Type", [content_type])
        body = BytestringProducer(body)
    kwargs = {}
    if check_certs:
        kwargs['contextFactory'] = WebVerifyingContextFactory()
    agent = Agent(reactor, **kwargs)
    response = yield agent.request(method, url, headers, body).addErrback(
        _translate_response_failed)
    if response.length == 0:
        # Body known to be empty; no need to deliver it.
        returnValue((response, ""))
    reader = ResponseReader()
    response.deliverBody(reader)
    # Failures while reading the body get the same translation as failures
    # while sending the request.
    body = yield reader.onConnectionLost.addErrback(
        _translate_response_failed)
    returnValue((response, body))


class _OpenStackClient(object):
    """Shared authentication and request plumbing for OpenStack services.

    Holds the auth token and 'services', a mapping of service type to
    endpoint url. Both are filled in by whichever authenticate method
    matches the configured credentials mode.
    """

    def __init__(self, credentials, check_certs):
        self.credentials = credentials
        if check_certs and WebVerifyingContextFactory is None:
            raise errors.SSLVerificationUnsupported()
        self.check_certs = check_certs
        log.debug("openstack: using auth-mode %r with %s", credentials.mode,
            credentials.url)
        if credentials.mode == "keypair":
            self.authenticate = self.authenticate_v2_keypair
        elif credentials.mode == "legacy":
            self.authenticate = self.authenticate_v1
        elif credentials.mode == "rax":
            self.authenticate = self.authenticate_rax_auth
        else:
            # Any other mode uses keystone v2 username/password auth.
            self.authenticate = self.authenticate_v2_userpass
        self.token = None

    def _make_url(self, service, parts):
        """Form full url from path components to service endpoint url

        'parts' is either a str appended verbatim, or a sequence of
        components joined with '/'. str components are used as-is; any other
        component (unicode, int, ...) is utf-8 encoded and percent-quoted.

        :raises errors.ProviderInteractionError: if authentication has not
            populated the service catalog, or it has no entry for 'service'.
        """
        services = getattr(self, "services", None)
        if services is None:
            raise errors.ProviderInteractionError(
                "Cannot access %s service before authenticating" % service)
        endpoint = services.get(service)
        if endpoint is None:
            raise errors.ProviderInteractionError(
                "No %r endpoint found in OpenStack service catalog" % service)
        if not endpoint.endswith("/"):
            endpoint += "/"
        if isinstance(parts, str):
            url = endpoint + parts
        else:
            quoted_parts = []
            for part in parts:
                if not isinstance(part, str):
                    part = urllib.quote(unicode(part).encode("utf-8"), "/~")
                quoted_parts.append(part)
            url = endpoint + "/".join(quoted_parts)
        log.debug('access %s @ %s', service, url)
        return url

    @staticmethod
    def _single_header(response, name):
        """Return the single value of header 'name', raising if not exactly one"""
        values = response.headers.getRawHeaders(name)
        # getRawHeaders returns None when the header is absent.
        if values is None or len(values) != 1:
            raise errors.ProviderInteractionError(
                "Expected exactly one %s header in auth response" % name)
        return values[0]

    @inlineCallbacks
    def authenticate_v1(self):
        """Authenticate against a legacy (v1) auth endpoint

        Credentials are sent in X-Auth-User/X-Auth-Key headers and a 204
        reply carries the token and compute url in its headers. No
        object-store url is recorded as that is not supported in this mode.
        """
        deferred = request(
            "GET",
            self.credentials.url,
            extra_headers=[
                ("X-Auth-User", self.credentials.username),
                ("X-Auth-Key", self.credentials.access_key),
                ],
            check_certs=self.check_certs,
            )
        response, body = yield deferred
        if response.code != 204:
            raise errors.ProviderInteractionError("Failed to authenticate")
        # Validate both headers before changing any state.
        nova_url = self._single_header(response, "X-Server-Management-Url")
        token = self._single_header(response, "X-Auth-Token")
        if self.check_certs:
            self._warn_if_endpoint_insecure("compute", nova_url)
        self.nova_url = nova_url
        self.services = {"compute": nova_url}
        self.token = token

    def authenticate_v2_keypair(self):
        """Authenticate with keystone v2 using an access/secret key pair"""
        deferred = request(
            "POST",
            self.credentials.url + "tokens",
            body={"auth": {
                  "apiAccessKeyCredentials": {
                    "accessKey": self.credentials.access_key,
                    "secretKey": self.credentials.secret_key,
                  },
                  "tenantName": self.credentials.project_name,
                }},
            check_certs=self.check_certs,
            )
        return deferred.addCallback(self._handle_v2_auth)

    def authenticate_v2_userpass(self):
        """Authenticate with keystone v2 using username and password"""
        deferred = request(
            "POST",
            self.credentials.url + "tokens",
            body={"auth": {
                  "passwordCredentials": {
                    "username": self.credentials.username,
                    "password": self.credentials.password,
                  },
                  "tenantName": self.credentials.project_name,
                }},
            check_certs=self.check_certs,
            )
        return deferred.addCallback(self._handle_v2_auth)

    def authenticate_rax_auth(self):
        """Authenticate with the RAX-KSKEY api key extension to keystone v2"""
        # openstack is not a product, but a kit for making snowflakes.
        deferred = request(
            "POST",
            self.credentials.url + "tokens",
            body={"auth": {
                  "RAX-KSKEY:apiKeyCredentials": {
                    "username": self.credentials.username,
                    "apiKey": self.credentials.password,
                    "tenantName": self.credentials.project_name}}},
            check_certs=self.check_certs,
            )
        return deferred.addCallback(self._handle_v2_auth)

    def _handle_v2_auth(self, result):
        """Record the token and service endpoints from a keystone v2 reply

        For each catalog entry the first endpoint matching the configured
        region is used. A region like 'az-1.region-a.geo-1' also matches
        endpoints in 'region-a.geo-1'. With no region configured the first
        endpoint of each service is taken.

        :raises errors.ProviderInteractionError: on a non-200 or non-json
            reply, or if no suitable endpoints are found. In that case the
            token is not stored, so is_authenticated() stays unchanged.
        """
        access_details = self._json(result, 200, 'access')
        token_details = access_details["token"]
        # Decoded json uses unicode for all string values, but that can upset
        # twisted when serialising headers later. Really should encode at that
        # point, but as keystone should only give ascii tokens a cast will do.
        token = token_details["id"].encode("ascii")

        # TODO: care about token_details["expires"]
        # Don't need to we're not preserving tokens.
        log.debug("openstack: authenticated til %r", token_details['expires'])

        region = self.credentials.region
        accepted_regions = None
        if region is not None:
            # HP cloud uses both az-1.region-a.geo-1 and region-a.geo-1 forms,
            # not clear what should be in config or what the correct logic
            # is, so accept either.
            accepted_regions = (region, region.split('.', 1)[-1])
        services = {}
        for catalog in access_details["serviceCatalog"]:
            for endpoint in catalog["endpoints"]:
                # Endpoints lacking a region never match a configured region.
                if (accepted_regions is not None
                        and endpoint.get("region") not in accepted_regions):
                    continue
                services[catalog["type"]] = str(endpoint["publicURL"])
                break

        if not services:
            raise errors.ProviderInteractionError("No suitable endpoints")

        # Only commit state once the catalog is known to be usable.
        self.token = token
        self.services = services
        if self.check_certs:
            for service in ("compute", "object-store"):
                if service in self.services:
                    self._warn_if_endpoint_insecure(service,
                        self.services[service])

    def _warn_if_endpoint_insecure(self, service_type, url):
        """Log a warning if 'url' does not use https"""
        # XXX: Should only warn per host ideally, otherwise is just annoying
        if not url.startswith("https:"):
            log.warning("OpenStack %s service not using secure transport",
                service_type)

    def is_authenticated(self):
        return self.token is not None

    @inlineCallbacks
    def authed_request(self, method, url, headers=None, body=None):
        """Make a request carrying the current auth token

        :param headers: Optional sequence of extra (name, value) pairs.
        :return: Deferred firing with a (response, body) tuple.
        :raises errors.ProviderInteractionError: on a 401 reply, after
            clearing the token so the next request reauthenticates.
        """
        log.debug("openstack: %s %r", method, url)
        request_headers = [("X-Auth-Token", self.token)]
        if headers:
            request_headers += headers
        response, body = yield request(method, url, request_headers, body,
            self.check_certs)
        log.debug("openstack: %d %r", response.code, body)

        # OpenStack returns 401 when using an expired token; simply
        # retry after reauthenticating
        if response.code == 401:
            self.token = None
            raise errors.ProviderInteractionError(
                "Need to reauthenticate by retrying")

        returnValue((response, body))

    def _empty(self, result, code):
        """Check a (response, body) result has the expected status code"""
        response, body = result
        if response.code != code:
            # XXX: This is a deeply unhelpful error, need context from request
            raise errors.ProviderInteractionError("Unexpected %d: %r" % (
                response.code, body))

    def _json(self, result, code, root=None):
        """Check status and content type, then decode a json response body

        :param result: (response, body) tuple as produced by request().
        :param code: The expected HTTP status code.
        :param root: If given, return only this top-level key of the data.
        :raises errors.ProviderInteractionError: on an unexpected status, or
            a reply without a json Content-Type.
        """
        response, body = result
        if response.code != code:
            raise errors.ProviderInteractionError("Unexpected %d: %r" % (
                response.code, body))
        type_headers = response.headers.getRawHeaders("Content-Type")
        # getRawHeaders returns None when the header is missing entirely.
        if not any('application/json' in h for h in type_headers or ()):
            raise errors.ProviderInteractionError(
                "Expected json response got %s" % type_headers)

        data = json.loads(body)
        if root is not None:
            return data[root]
        return data


class _NovaClient(object):
    """Thin wrapper over the OpenStack compute (nova) API.

    Every call authenticates lazily through the shared client and returns a
    Deferred. json replies are reduced to the requested root key.
    """

    def __init__(self, client):
        self._client = client

    @inlineCallbacks
    def request(self, method, parts, headers=None, body=None):
        """Send a request to the compute endpoint, authenticating first if needed"""
        client = self._client
        if not client.is_authenticated():
            yield client.authenticate()
        response_tuple = yield client.authed_request(
            method, client._make_url("compute", parts), headers, body)
        returnValue(response_tuple)

    def delete(self, parts, code=202):
        """DELETE 'parts', expecting an empty reply with status 'code'"""
        return self.request("DELETE", parts).addCallback(
            self._client._empty, code)

    def get(self, parts, root, code=200):
        """GET 'parts' and return the json value under 'root'"""
        return self.request("GET", parts).addCallback(
            self._client._json, code, root)

    def post(self, parts, jsonobj, root, code=200):
        """POST 'jsonobj' and return the json value under 'root'"""
        return self.request("POST", parts, None, jsonobj).addCallback(
            self._client._json, code, root)

    def post_no_data(self, parts, root, code=200):
        """POST an empty body and return the json value under 'root'"""
        # The empty str is still sent, labelled application/octet-stream.
        return self.request("POST", parts, None, "").addCallback(
            self._client._json, code, root)

    def post_no_result(self, parts, jsonobj, code=202):
        """POST 'jsonobj', expecting an empty reply with status 'code'"""
        return self.request("POST", parts, None, jsonobj).addCallback(
            self._client._empty, code)

    def _server_action(self, server_id, action):
        """POST a single action document to servers/<server_id>/action"""
        return self.post_no_result(["servers", server_id, "action"], action)

    def list_flavors(self):
        return self.get("flavors", "flavors")

    def list_flavor_details(self):
        return self.get(["flavors", "detail"], "flavors")

    def get_server(self, server_id):
        return self.get(["servers", server_id], "server")

    def list_servers(self):
        return self.get(["servers"], "servers")

    def list_servers_detail(self):
        return self.get(["servers", "detail"], "servers")

    def delete_server(self, server_id):
        return self.delete(["servers", server_id], code=204)

    def run_server(self, image_id, flavor_id, name, security_group_names=None,
                   user_data=None, scheduler_hints=None):
        """Boot a new server and return its json description"""
        server = {
            'name': name,
            'flavorRef': flavor_id,
            'imageRef': image_id,
        }
        if user_data is not None:
            server["user_data"] = base64.b64encode(user_data)
        if security_group_names is not None:
            server["security_groups"] = [
                {'name': group_name} for group_name in security_group_names]
        post_dict = {"server": server}
        if scheduler_hints is not None:
            post_dict["OS-SCH-HNT:scheduler_hints"] = scheduler_hints
        return self.post(["servers"], post_dict, root="server", code=202)

    def get_server_security_groups(self, server_id):
        """List a server's security groups, with a fallback for old clouds"""
        def _fallback_to_server_details(failure):
            # 2012-07-12: kt Workaround lack of this api in HP cloud
            log.debug("Falling back to older/diablo sec groups api")
            details = self.get_server(server_id)
            return details.addCallback(operator.itemgetter("security_groups"))

        return self.get(
            ["servers", server_id, "os-security-groups"],
            root="security_groups").addErrback(_fallback_to_server_details)

    def get_security_group_details(self, group_id):
        return self.get(["os-security-groups", group_id], "security_group")

    def list_security_groups(self):
        return self.get(["os-security-groups"], "security_groups")

    def create_security_group(self, name, description):
        group = {
            'name': name,
            'description': description,
            }
        return self.post("os-security-groups", {'security_group': group},
                         root="security_group")

    def delete_security_group(self, group_id):
        return self.delete(["os-security-groups", group_id])

    def add_security_group_rule(self, parent_group_id, **kwargs):
        """Create a rule admitting a source group or cidr to the parent group"""
        rule = {'parent_group_id': parent_group_id}
        if "group_id" in kwargs:
            rule['group_id'] = kwargs['group_id']
        elif "cidr" in kwargs:
            rule['cidr'] = kwargs['cidr']
        # Protocol and port range are required unless the rule references a
        # group, in which case they are optional.
        if "group_id" not in kwargs or "ip_protocol" in kwargs:
            for key in ('ip_protocol', 'from_port', 'to_port'):
                rule[key] = kwargs[key]
        return self.post("os-security-group-rules",
                         {'security_group_rule': rule},
                         root="security_group_rule")

    def delete_security_group_rule(self, rule_id):
        return self.delete(["os-security-group-rules", rule_id])

    def add_server_security_group(self, server_id, group_name):
        return self._server_action(
            server_id, {"addSecurityGroup": {"name": group_name}})

    def remove_server_security_group(self, server_id, group_name):
        return self._server_action(
            server_id, {"removeSecurityGroup": {"name": group_name}})

    def list_floating_ips(self):
        return self.get(["os-floating-ips"], "floating_ips")

    def get_floating_ip(self, ip_id):
        return self.get(["os-floating-ips", ip_id], "floating_ip")

    def allocate_floating_ip(self):
        return self.post_no_data(["os-floating-ips"], "floating_ip")

    def delete_floating_ip(self, ip_id):
        return self.delete(["os-floating-ips", ip_id])

    def add_floating_ip(self, server_id, addr):
        return self._server_action(
            server_id, {'addFloatingIp': {'address': addr}})

    def remove_floating_ip(self, server_id, addr):
        return self._server_action(
            server_id, {'removeFloatingIp': {'address': addr}})


class _SwiftClient(object):
    """Thin wrapper over the OpenStack object-store (swift) API.

    Replies are returned as raw (response, body) tuples; callers are left
    to interpret status codes themselves.
    """

    def __init__(self, client):
        self._client = client

    @inlineCallbacks
    def request(self, method, parts, headers=None, body=None):
        """Send a request to the object-store endpoint, authenticating first if needed"""
        client = self._client
        if not client.is_authenticated():
            yield client.authenticate()
        url = client._make_url("object-store", parts)
        outcome = yield client.authed_request(method, url, headers, body)
        returnValue(outcome)

    def public_object_url(self, container, object_name):
        """Return the url of an object; requires prior authentication"""
        client = self._client
        if client.is_authenticated():
            return client._make_url("object-store", [container, object_name])
        raise ValueError("Need to have authenticated to get object url")

    def put_container(self, container_name):
        # Juju expects there to be a (semi) public url for some objects. This
        # could probably be more restrictive or placed in a separate container
        # with some refactoring, but for now just make everything public.
        public_read = [("X-Container-Read", ".r:*")]
        return self.request("PUT", [container_name], public_read, "")

    def delete_container(self, container_name):
        return self.request("DELETE", [container_name])

    def _object_request(self, method, container, object_name, body=None):
        """Issue 'method' against container/object_name with no extra headers"""
        return self.request(method, [container, object_name], None, body)

    def head_object(self, container, object_name):
        return self._object_request("HEAD", container, object_name)

    def get_object(self, container, object_name):
        return self._object_request("GET", container, object_name)

    def delete_object(self, container, object_name):
        return self._object_request("DELETE", container, object_name)

    def put_object(self, container, object_name, bytestring):
        return self._object_request("PUT", container, object_name, bytestring)
