# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import functools
import json
import os
import boto3
from mypy_boto3_dynamodb.service_resource import DynamoDBServiceResource, Table
from mypy_boto3_sns import SNSClient

from threatexchange.signal_type.md5 import VideoMD5Signal
from threatexchange.signal_type.pdq import PdqSignal

from hmalib import metrics
from hmalib.common.logging import get_logger
from hmalib.common.models.bank import BanksTable
from hmalib.matchers.matchers_base import Matcher
from hmalib.common.config import HMAConfig
from hmalib.common.models.pipeline import PipelineHashRecord

# Deployment configuration, read once at import time. Each lookup uses
# os.environ[...] (not .get), so a missing variable raises KeyError while the
# module is being imported rather than later inside the handler.

# Bucket name handed to the Matcher as `index_bucket_name`.
INDEXES_BUCKET_NAME = os.environ["INDEXES_BUCKET_NAME"]
# Name of the DynamoDB table wrapped by BanksTable in the handler.
BANKS_TABLE = os.environ["BANKS_TABLE"]
# Name of the DynamoDB table passed to the matcher's write_* calls.
DYNAMODB_TABLE = os.environ["DYNAMODB_TABLE"]
# Name of the table backing HMAConfig (see initialize() call below).
HMA_CONFIG_TABLE = os.environ["HMA_CONFIG_TABLE"]
# SNS topic ARN that match notifications are published to.
MATCHES_TOPIC_ARN = os.environ["MATCHES_TOPIC_ARN"]

# Runs at import time, before any handler invocation.
HMAConfig.initialize(HMA_CONFIG_TABLE)


@functools.lru_cache(maxsize=None)
def get_dynamodb() -> DynamoDBServiceResource:
    """
    Return the boto3 DynamoDB service resource.

    The function takes no arguments, so the lru_cache decorator makes it
    build the resource on the first call and hand back that same object on
    every later call.
    """
    dynamodb_resource = boto3.resource("dynamodb")
    return dynamodb_resource


@functools.lru_cache(maxsize=None)
def get_sns_client() -> SNSClient:
    """
    Return the boto3 SNS client.

    The function takes no arguments, so the lru_cache decorator makes it
    build the client on the first call and hand back that same object on
    every later call.
    """
    sns_client = boto3.client("sns")
    return sns_client


# Module-level slot holding the single Matcher instance; filled in lazily by
# get_matcher() on its first call.
_matcher = None


def get_matcher(banks_table: BanksTable):
    """
    Return the module-wide Matcher, constructing it on first use.

    Note that `banks_table` is only consulted when the Matcher is first
    built; once an instance exists it is returned as-is and the argument is
    ignored.
    """
    global _matcher

    # Fast path: an instance was already built by an earlier call.
    if _matcher is not None:
        return _matcher

    _matcher = Matcher(
        banks_table=banks_table,
        supported_signal_types=[PdqSignal, VideoMD5Signal],
        index_bucket_name=INDEXES_BUCKET_NAME,
    )
    return _matcher


# Module-level logger, named after this module; used by lambda_handler.
logger = get_logger(__name__)


def lambda_handler(event, context):
    """
    Listens to SQS events fired when new hash is generated. Loads the index
    stored in an S3 bucket and looks for a match.

    When matched, publishes a notification to an SNS endpoint. Note this is in
    contrast with hasher and indexer. They publish to SQS directly. Publishing
    to SQS implies there can be only one consumer.

    Because, here, in the matcher, we publish to SNS, we can plug multiple
    queues behind it and profit!

    Args:
        event: Mapping with a "Records" list; each record's "body" is a JSON
            string. Messages whose "Event" is "TestEvent", and messages that
            PipelineHashRecord.could_be() rejects, are skipped.
        context: Lambda context object; not used by this handler.

    Returns:
        None.
    """
    # get_dynamodb() is cached, so a single handle serves both tables.
    dynamodb = get_dynamodb()
    table = dynamodb.Table(DYNAMODB_TABLE)
    banks_table = BanksTable(dynamodb.Table(BANKS_TABLE))

    for sqs_record in event["Records"]:
        message = json.loads(sqs_record["body"])

        if message.get("Event") == "TestEvent":
            logger.debug("Disregarding Test Event")
            continue

        if not PipelineHashRecord.could_be(message):
            # Logger.warn is a deprecated alias; warning() is the supported
            # spelling.
            logger.warning(
                "Could not de-serialize message in matcher lambda. Message was %s",
                message,
            )
            continue

        hash_record = PipelineHashRecord.from_sqs_message(message)
        logger.info(
            "HashRecord for contentId: %s with contentHash: %s",
            hash_record.content_id,
            hash_record.content_hash,
        )

        # Resolve the matcher only once we have a real record to process, so
        # it is still built lazily (never for batches of skipped messages),
        # and reuse it for every step below.
        matcher = get_matcher(banks_table)

        matches = matcher.match(hash_record.signal_type, hash_record.content_hash)
        logger.info("Found %d matches.", len(matches))

        # Two separate passes on purpose: every match record is written
        # before any signal is written, same as the original ordering.
        for match in matches:
            matcher.write_match_record_for_result(
                table=table,
                signal_type=hash_record.signal_type,
                content_hash=hash_record.content_hash,
                content_id=hash_record.content_id,
                match=match,
            )

        for match in matches:
            matcher.write_signal_if_not_found(
                table=table, signal_type=hash_record.signal_type, match=match
            )

        if matches:
            # Publish all messages together
            matcher.publish_match_message(
                content_id=hash_record.content_id,
                content_hash=hash_record.content_hash,
                matches=matches,
                sns_client=get_sns_client(),
                topic_arn=MATCHES_TOPIC_ARN,
            )

        # Flushed once per processed record; skipped records `continue`
        # before reaching this point.
        metrics.flush()
