# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Cluster Manager Backend for GHPCFE"""
import json
import logging
import uuid

from google.api_core.exceptions import AlreadyExists
from google.cloud import pubsub

from . import utils

# Note: We can't import Models here, because this module gets run as part of
# startup, and the Models haven't yet been created.
# Instead, import Models in the Callback Functions as appropriate

# pylint: disable=import-outside-toplevel

# Module-level logger used by all handlers and helpers below.
logger = logging.getLogger(__name__)

# Current design:
#  1 topic for our overall system
#  The FE subscription uses a filter that matches messages WITHOUT a "target"
#  attribute.  Each cluster subscription filters on a "target" attribute that
#  equals the cluster's subscription ID ("cluster_{id}").  (Ideally this would
#  be something less guessable, like a hash or a unique key.)

# Message data should be json-encoded.
# Message attributes are used by the C2-infrastructure - data is for programmer
# use
# Messages should have the following attributes:
#   * target={subscription_name}  (or no target, to come to FE)
#   * command=('ping', 'sub_job', etc...)
# If there is no 'target' (i.e., the message is coming from a cluster to FE):
#   * source=cluster_{id} - who sent it

# Command with response callback
#
# Commands that require a response should encode a unique key as a message
# field ('ackid').
# When receiver finishes the command, they should then send an ACK with that
# same 'ackid', and any associated data.

# Command dispatch table: maps a message's "command" attribute to the handler
# that _c2_response_callback invokes.  Filled in via register_command().
# Handlers are called as handler(decoded_json_payload, source_id=...); a truthy
# return acks the Pub/Sub message, a falsy one nacks it.
_c2_callbackMap = {}


def c2_ping(message, source_id):
    """Answer a PING sent by a cluster.

    A PING that carries an "id" is answered with a PONG echoing that id,
    addressed back to ``source_id`` (expected to look like 'cluster_{id}').
    An anonymous PING is only logged.

    Returns:
        Always True, so the Pub/Sub message is acknowledged.
    """
    if "id" not in message:
        logger.info("Received anonymous PING from cluster %s", source_id)
        return True

    ping_id = message["id"]
    logger.info(
        "Received PING id %s from cluster %s. Sending PONG", ping_id, source_id
    )
    _C2STATE.send_message("PONG", {"id": ping_id}, target=source_id)
    return True


def c2_pong(message, source_id):
    """Log a PONG received from a cluster.

    ``source_id`` is expected to look like 'cluster_{id}'.

    Returns:
        Always True, so the Pub/Sub message is acknowledged.
    """
    if "id" not in message:
        logger.info("Received PONG from cluster %s", source_id)
    else:
        logger.info(
            "Received PONG id %s from cluster %s.", message["id"], source_id
        )
    return True


# Difference between UPDATE and ACK:  ACK removes the callback, UPDATE leaves it
# in place
def cb_ack(message, source_id):
    """Handle an ACK: unregister the stored callback for its ackid, then run it.

    Difference between UPDATE and ACK: ACK removes the callback, UPDATE
    leaves it in place.

    Args:
        message: Decoded message payload; should carry an "ackid" field
            holding the string form of a C2Callback UUID.
        source_id: Sender's subscription ID (e.g. 'cluster_{id}').

    Returns:
        Always True, so the Pub/Sub message is acknowledged.  A missing,
        malformed or unknown ackid is logged and ignored.
    """
    from ..models import C2Callback

    ackid = message.get("ackid", None)
    logger.info("Received ACK to message %s from %s", ackid, source_id)
    if not ackid:
        logger.error("No ackid in ACK.  Ignoring")
        return True

    try:
        key = uuid.UUID(ackid)
    except (ValueError, TypeError, AttributeError):
        # uuid.UUID raises ValueError on a bad string and TypeError /
        # AttributeError on non-string input; treat all of them as a
        # malformed ackid instead of letting them escape the handler.
        logger.error("Malformed ackid %r in ACK.  Ignoring", ackid)
        return True

    # Keep the try narrow so DoesNotExist raised by the callback itself is not
    # misreported as "no callback registered".
    try:
        entry = C2Callback.objects.get(ackid=key)
    except C2Callback.DoesNotExist:
        logger.warning("No Callback registered for the ACK")
        return True

    logger.info("Calling Callback registered for this ACK")
    cb = entry.callback
    # ACK is final: unregister before invoking.
    entry.delete()
    cb(message)
    return True


# Difference between UPDATE and ACK:  ACK removes the callback, UPDATE leaves it
# in place
def cb_update(message, source_id):
    """Handle an UPDATE: run the stored callback for its ackid, keeping it.

    Difference between UPDATE and ACK: ACK removes the callback, UPDATE
    leaves it in place, so several UPDATEs may precede the final ACK.

    Args:
        message: Decoded message payload; should carry an "ackid" field
            holding the string form of a C2Callback UUID.
        source_id: Sender's subscription ID (e.g. 'cluster_{id}').

    Returns:
        Always True, so the Pub/Sub message is acknowledged.  A missing,
        malformed or unknown ackid is logged and ignored.
    """
    from ..models import C2Callback

    ackid = message.get("ackid", None)
    if not ackid:
        logger.error("No ackid in UPDATE.  Ignoring")
        return True
    logger.info("Received UPDATE to message %s from %s", ackid, source_id)

    try:
        key = uuid.UUID(ackid)
    except (ValueError, TypeError, AttributeError):
        # uuid.UUID raises ValueError on a bad string and TypeError /
        # AttributeError on non-string input; treat all of them as a
        # malformed ackid instead of letting them escape the handler.
        logger.error("Malformed ackid %r in UPDATE.  Ignoring", ackid)
        return True

    # Keep the try narrow so DoesNotExist raised by the callback itself is not
    # misreported as "no callback registered".
    try:
        entry = C2Callback.objects.get(ackid=key)
    except C2Callback.DoesNotExist:
        logger.warning("No Callback registered for the UPDATE")
        return True

    logger.info("Calling Callback registered for this UPDATE")
    cb = entry.callback
    cb(message)
    return True


def cb_cluster_status(message, source_id):
    """Apply a CLUSTER_STATUS report sent by a cluster.

    The report must come from the cluster it describes: ``source_id`` has to
    equal 'cluster_{cluster_id}'.  The optional "message" text is logged, and
    a non-empty "status" is stored on the Cluster row.

    Args:
        message: Decoded payload with "cluster_id" and optionally "status"
            and "message".
        source_id: Sender's subscription ID.

    Returns:
        Always True, so the Pub/Sub message is acknowledged; any failure is
        logged rather than raised.
    """
    from ..models import Cluster

    try:
        cid = message["cluster_id"]
        if f"cluster_{cid}" != source_id:
            raise ValueError(
                f"Message comes from {source_id}, but claims cluster {cid}. "
                "Ignoring."
            )

        cluster = Cluster.objects.get(pk=cid)
        # The human-readable text is optional; its absence must not prevent
        # the status update below.
        logger.info(
            "Cluster Status message for %s: %s",
            cluster.id,
            message.get("message"),
        )
        new_status = message.get("status", None)
        if new_status:
            cluster.status = new_status
            cluster.save()
    # Handler boundary: log every failure (bad payload, spoofed source,
    # unknown cluster, DB errors) and still ack.
    except Exception as ex:  # pylint: disable=broad-except
        logger.error("Cluster status callback error!", exc_info=ex)
    return True


def _c2_response_callback(message):
    """Dispatch a message from the FE subscription to its command handler.

    The handler is chosen by the message's "command" attribute (see
    register_command) and called with the JSON-decoded payload and the
    "source" attribute as ``source_id``.  The message is acked when the
    handler returns a truthy value and nacked otherwise.

    Messages with a missing/unknown command or an undecodable payload are
    logged and acked (discarded).  If the handler raises, the error is logged
    and the message is nacked so Pub/Sub can redeliver it.
    """
    logger.debug("Received message %s ", message)

    cmd = message.attributes.get("command", None)
    source = message.attributes.get("source", None)
    if not source:
        logger.error("Message had no Source ID")

    # Look the handler up separately from calling it, so a KeyError raised
    # *inside* a handler is not mistaken for an unknown command.
    callback = _c2_callbackMap.get(cmd)
    if callback is None:
        if cmd:
            logger.error(
                'Message requests unknown command "%s".  Discarding', cmd
            )
        else:
            logger.error(
                "Message has no command associated with it. Discarding"
            )
        message.ack()
        return

    try:
        payload = json.loads(message.data)
    except ValueError:
        # JSONDecodeError and UnicodeDecodeError both subclass ValueError.
        # A malformed payload will never decode, so discard it.
        logger.error(
            'Message for command "%s" has an undecodable payload.  '
            "Discarding",
            cmd,
            exc_info=True,
        )
        message.ack()
        return

    try:
        handled = callback(payload, source_id=source)
    except Exception:  # pylint: disable=broad-except
        # Don't let handler errors escape into the Pub/Sub client's callback
        # machinery; log them and ask for redelivery explicitly.
        logger.exception(
            'Handler for command "%s" failed.  Requesting redelivery', cmd
        )
        message.nack()
        return

    if handled:
        message.ack()
    else:
        message.nack()


class _C2State:
    """Internal Pub/Sub state for the C2 channel.

    Lazily creates the publisher/subscriber clients and remembers the
    project, topic and the FE streaming-pull future once startup() has run.
    """

    def __init__(self):
        self._pub_client = None
        self._sub_client = None
        self._streaming_pull_future = None
        self._project_id = None
        self._topic = None
        self._topic_path = None

    @property
    def sub_client(self):
        """The Pub/Sub SubscriberClient, created on first access."""
        if not self._sub_client:
            self._sub_client = pubsub.SubscriberClient()
        return self._sub_client

    @property
    def pub_client(self):
        """The Pub/Sub PublisherClient, created on first access."""
        if not self._pub_client:
            self._pub_client = pubsub.PublisherClient()
        return self._pub_client

    def startup(self):
        """Read the server config and start consuming the FE subscription.

        Uses ``server.gcp_project`` and ``server.c2_topic`` from the config,
        ensures the "c2resp" subscription (messages without a ``target``
        attribute) exists, and streams it into _c2_response_callback.
        """
        conf = utils.load_config()
        self._project_id = conf["server"]["gcp_project"]
        self._topic = conf["server"]["c2_topic"]
        self._topic_path = self.pub_client.topic_path(
            self._project_id, self._topic
        )

        sub_path = self.get_or_create_subscription(
            "c2resp", filter_target=False
        )

        self._streaming_pull_future = self.sub_client.subscribe(
            sub_path, callback=_c2_response_callback
        )
        # TODO: Currently no clean shutdown method

    def get_subscription_path(self, sub_id):
        """Return the full subscription path for ``sub_id``.

        Subscription names are prefixed with the topic name:
        "{topic}-{sub_id}".
        """
        sub_id = f"{self._topic}-{sub_id}"
        return self.sub_client.subscription_path(self._project_id, sub_id)

    def get_or_create_subscription(
        self, sub_id, filter_target=True, service_account=None
    ):
        """Create the subscription for ``sub_id`` unless it already exists.

        Args:
            sub_id: Short subscription ID (e.g. "cluster_3" or "c2resp").
            filter_target: If True, receive only messages whose ``target``
                attribute equals ``sub_id``; if False, receive only messages
                that have no ``target`` attribute.
            service_account: If given and the subscription is newly created,
                grant it access via setup_service_account().

        Returns:
            The full subscription path.
        """
        sub_path = self.get_subscription_path(sub_id)

        request = {"name": sub_path, "topic": self._topic_path}
        if filter_target:
            request["filter"] = f'attributes.target="{sub_id}"'
        else:
            request["filter"] = "NOT attributes:target"

        try:
            # Create subscription if it doesn't already exist
            self.sub_client.create_subscription(request=request)
            logger.info("PubSub Subscription %s created", sub_path)

            if service_account:
                self.setup_service_account(sub_id, service_account)

        except AlreadyExists:
            logger.info("PubSub Subscription %s already exists", sub_path)

        return sub_path

    def setup_service_account(self, sub_id, service_account):
        """Grant ``service_account`` access to the C2 channel.

        Two IAM bindings are added: "roles/pubsub.subscriber" on the
        subscription for ``sub_id`` (to receive commands) and
        "roles/pubsub.publisher" on the C2 topic (to send responses).
        """
        sub_path = self.get_subscription_path(sub_id)
        member = f"serviceAccount:{service_account}"

        policy = self.sub_client.get_iam_policy(request={"resource": sub_path})
        policy.bindings.add(role="roles/pubsub.subscriber", members=[member])
        self.sub_client.set_iam_policy(
            request={"resource": sub_path, "policy": policy}
        )

        policy = self.pub_client.get_iam_policy(
            request={"resource": self._topic_path}
        )
        policy.bindings.add(role="roles/pubsub.publisher", members=[member])
        self.pub_client.set_iam_policy(
            request={"resource": self._topic_path, "policy": policy}
        )

    def delete_subscription(self, sub_id, service_account):
        """Delete the subscription for ``sub_id``.

        The IAM bindings granted to ``service_account`` by
        setup_service_account() are currently NOT removed.
        """
        sub_path = self.get_subscription_path(sub_id)
        self.sub_client.delete_subscription(request={"subscription": sub_path})
        if service_account:
            # TODO: Remove the "roles/pubsub.publisher" binding for this
            # service account from the *topic* policy (self._topic_path),
            # i.e. undo the second half of setup_service_account().
            pass

    def send_message(self, command, message, target, extra_attrs=None):
        """Publish a JSON-encoded C2 message on the topic.

        Args:
            command: Value of the ``command`` attribute (e.g. "PING").
            message: JSON-serializable payload.
            target: Value of the ``target`` attribute; selects which
                target-filtered subscription receives the message.
            extra_attrs: Optional additional message attributes.

        Returns:
            The future returned by the Pub/Sub client's publish().  Callers
            may ignore it; a failed publish is logged.
        """
        extra_attrs = extra_attrs if extra_attrs else {}
        # TODO: If we want loopback, need to make 'target' optional,
        # or change up our filters
        future = self.pub_client.publish(
            self._topic_path,
            bytes(json.dumps(message), "utf-8"),
            target=target,
            command=command,
            **extra_attrs,
        )

        def _log_publish_failure(fut):
            # Invoked by the client library once the publish settles,
            # possibly on one of its own threads.
            if fut.cancelled():
                return
            exc = fut.exception()
            if exc is not None:
                logger.error(
                    'Failed to publish "%s" message to %s',
                    command,
                    target,
                    exc_info=exc,
                )

        # Otherwise publish errors would be silently lost with the future.
        future.add_done_callback(_log_publish_failure)
        return future


# Module-wide _C2State singleton; created by startup(), None until then.
_C2STATE = None


def get_cluster_sub_id(cluster_id):
    """Return the subscription ID (and message ``target``) for a cluster."""
    return "cluster_{}".format(cluster_id)


def get_cluster_subscription_path(cluster_id):
    """Return the fully-qualified Pub/Sub subscription path for a cluster."""
    cluster_sub = get_cluster_sub_id(cluster_id)
    return _C2STATE.get_subscription_path(cluster_sub)


def create_cluster_subscription(cluster_id):
    """Ensure the cluster's target-filtered subscription exists.

    Returns:
        The fully-qualified subscription path.
    """
    cluster_sub = get_cluster_sub_id(cluster_id)
    return _C2STATE.get_or_create_subscription(cluster_sub, filter_target=True)


def add_cluster_subscription_service_account(cluster_id, service_account):
    """Let ``service_account`` use the cluster's C2 channel.

    Delegates to _C2State.setup_service_account, which grants subscribe
    rights on the cluster's subscription and publish rights on the topic.
    """
    cluster_sub = get_cluster_sub_id(cluster_id)
    return _C2STATE.setup_service_account(cluster_sub, service_account)


def delete_cluster_subscription(cluster_id, service_account=None):
    """Delete the cluster's Pub/Sub subscription.

    ``service_account`` is forwarded to _C2State.delete_subscription.
    """
    cluster_sub = get_cluster_sub_id(cluster_id)
    return _C2STATE.delete_subscription(
        cluster_sub, service_account=service_account
    )


def get_topic_path():
    """Return the fully-qualified path of the C2 topic.

    Only usable after startup() has created the module state.
    """
    state = _C2STATE
    return state._topic_path  # pylint: disable=protected-access


def startup():
    """Start the C2 Pub/Sub channel for this process.

    Registers the built-in command handlers, creates the module-wide
    _C2State and starts pulling the FE subscription.  A second call after a
    successful start logs an error and does nothing.  If starting fails, the
    error propagates and the module state is reset so startup() can be
    retried.
    """
    global _C2STATE
    if _C2STATE:
        logger.error("ERROR:  C&C PubSub already started!")
        return

    # Register handlers *before* the subscription starts streaming.
    # _c2_response_callback acks and discards messages whose command is not
    # registered, so anything arriving early would otherwise be lost.
    # Difference between UPDATE and ACK:  ACK removes the callback, UPDATE
    # leaves it in place
    register_command("ACK", cb_ack)
    register_command("UPDATE", cb_update)
    register_command("PING", c2_ping)
    register_command("PONG", c2_pong)
    register_command("CLUSTER_STATUS", cb_cluster_status)

    # Publish the singleton before starting so handlers that reply through
    # _C2STATE (e.g. c2_ping) work as soon as messages flow.
    _C2STATE = _C2State()
    try:
        _C2STATE.startup()
    except Exception:
        # Don't leave a half-initialised state behind: the "already started"
        # guard above would otherwise block every retry.
        _C2STATE = None
        raise


def send_command(cluster_id, cmd, data, on_response=None):
    """Send command ``cmd`` with payload ``data`` to a cluster.

    If ``on_response`` is given, it is stored as a C2Callback and its ackid
    is written into ``data`` (mutating the caller's dict).  The cluster can
    then reply with ACK/UPDATE messages carrying that ackid.

    Args:
        cluster_id: Target cluster.
        cmd: Command name, sent as the ``command`` attribute.
        data: JSON-serializable dict payload.
        on_response: Optional callback to invoke on ACK/UPDATE replies.

    Returns:
        The ackid string in ``data`` (the ``comm_id`` expected by
        send_update()), or None when ``data`` carries no ackid.
    """
    if on_response:
        from ..models import C2Callback

        callback_entry = C2Callback(callback=on_response)
        callback_entry.save()
        data["ackid"] = str(callback_entry.ackid)
    _C2STATE.send_message(
        command=cmd, message=data, target=get_cluster_sub_id(cluster_id)
    )
    # Without on_response there may be no ackid.  The message has already
    # been published at this point, so don't raise KeyError.
    return data.get("ackid")


def send_update(cluster_id, comm_id, data):
    """Send an UPDATE relating to an earlier command to a cluster.

    Args:
        cluster_id: Target cluster.
        comm_id: The ackid returned by send_command().
        data: JSON-serializable dict payload; mutated to carry ``comm_id``
            under "ackid".
    """
    data["ackid"] = comm_id
    target = get_cluster_sub_id(cluster_id)
    _C2STATE.send_message(command="UPDATE", message=data, target=target)


def register_command(command_id, callback):
    """Route messages whose "command" attribute is ``command_id`` to
    ``callback``, replacing any handler previously registered for it."""
    _c2_callbackMap.update({command_id: callback})
