# AttestationClient.py
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import json
import time
from enum import Enum
from base64 import urlsafe_b64decode
from src.OsInfo import OsInfo
from src.Isolation import IsolationType, Isolation, TdxEvidence, SnpEvidence
from src.Logger import Logger
from src.ReportParser import ReportParser
from src.ImdsClient import ImdsClient
from src.AttestationProvider import MAAProvider, ITAProvider
from AttestationTypes import TpmInfo
from src.measurements import get_measurements
from src.Encoder import Encoder
from tpm_wrapper import TssWrapper
from requests.exceptions import RequestException

from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import padding
from cryptography.hazmat.primitives.ciphers.aead import AESGCM

from snp import AttestationReport

# The version number of the attestation protocol between the client and the service.
# It is emitted as the 'AttestationProtocolVersion' field by
# GuestAttestationParameters.toJson().
PROTOCOL_VERSION = "2.0"


class GuestAttestationParameters:
  """
  Holds the guest evidence that is serialized into a guest attestation request.

  Attributes
  ----------
  os_info
    Operating system details; toJson() reads its type, distro_name,
    major_version, minor_version and build attributes.
  tcg_logs
    The measurement logs; base64 encoded by toJson().
  tpm_info
    TPM evidence; toJson() calls its get_values() method.
  isolation
    Isolation evidence; toJson() calls its get_values() method.
  """

  def __init__(self, os_info=None, tcg_logs=None, tpm_info=None, isolation=None):
    self.isolation = isolation
    self.tpm_info = tpm_info
    self.tcg_logs = tcg_logs
    self.os_info = os_info

  def toJson(self):
    """
    Serializes the parameters into the JSON document sent to the provider.

    Returns
    -------
    str
      A JSON object whose textual OS fields are base64 encoded, whose version
      numbers are plain strings and whose client payload is always empty.
    """
    os_info = self.os_info
    encode_text = Encoder.base64_encode_string

    # Keys are inserted in the order they must appear in the serialized output
    payload = {}
    payload['AttestationProtocolVersion'] = PROTOCOL_VERSION
    payload['OSType'] = encode_text(str(os_info.type))
    payload['OSDistro'] = encode_text(os_info.distro_name)
    payload['OSVersionMajor'] = str(os_info.major_version)
    payload['OSVersionMinor'] = str(os_info.minor_version)
    payload['OSBuild'] = encode_text(os_info.build)
    payload['TcgLogs'] = Encoder.base64_encode(self.tcg_logs)
    payload['ClientPayload'] = encode_text("")
    payload['TpmInfo'] = self.tpm_info.get_values()
    payload['IsolationInfo'] = self.isolation.get_values()

    return json.dumps(payload)


class HardwareEvidence:
  """
  A class to represent hardware evidence.

  Attributes
  ----------
  type : str
    The report type.
  hardware_report : bytes
    The hardware report.
  runtime_data : bytes
    The runtime data.
  """

  def __init__(self, report_type: str, hardware_report: bytes, runtime_data: bytes):
    """
    Validates the argument types and stores the evidence.

    Parameters
    ----------
    report_type : str
      The report type; stored in the `type` attribute.
    hardware_report : bytes
      The hardware report.
    runtime_data : bytes
      The runtime data.

    Raises
    ------
    TypeError
      If report_type is not a str, or if hardware_report or runtime_data
      is not a bytes object.
    """
    if not isinstance(report_type, str):
      # The expected type named in the message must match the check (str, not bytes)
      raise TypeError(f"Expected str for report_type, got {type(report_type).__name__}")
    if not isinstance(hardware_report, bytes):
      raise TypeError(f"Expected bytes for hardware_report, got {type(hardware_report).__name__}")
    if not isinstance(runtime_data, bytes):
      raise TypeError(f"Expected bytes for runtime_data, got {type(runtime_data).__name__}")

    self.type = report_type
    self.hardware_report = hardware_report
    self.runtime_data = runtime_data


class Verifier(Enum):
  """
  The attestation services that can verify the collected evidence.
  """

  # No verifier selected
  UNDEFINED = 0

  # Microsoft Azure Attestation
  MAA = 1

  # Intel Trust Authority
  ITA = 2


class AttestationClientParameters:
  """
  Configuration consumed by AttestationClient.

  Attributes
  ----------
  endpoint : str
    The endpoint handed to the attestation provider.
  verifier : Verifier
    The attestation service to use.
  api_key
    Optional API key; AttestationClient passes it to the ITA provider only.
  isolation_type : IsolationType
    The isolation technology of the platform.
  user_claims
    Optional claims; AttestationClient passes them to TssWrapper.get_hcl_report().
  """

  def __init__(self, endpoint: str, verifier: Verifier, isolation_type: IsolationType, claims = None, api_key = None):
    """
    Validates and stores the client configuration.

    Raises
    ------
    ValueError
      If isolation_type is not an IsolationType member or verifier is not
      a Verifier member.
    """
    # Validate the isolation type
    if not isinstance(isolation_type, IsolationType):
      raise ValueError(f"Unsupported isolation type: {isolation_type}. Supported types: {list(IsolationType)}")

    # Validate the verifier (the message must name the verifier, not the isolation type)
    if not isinstance(verifier, Verifier):
      raise ValueError(f"Unsupported verifier: {verifier}. Supported types: {list(Verifier)}")

    self.endpoint = endpoint
    self.verifier = verifier
    self.api_key = api_key
    self.isolation_type = isolation_type
    self.user_claims = claims


class UnsupportedReportTypeException(Exception):
  """
  Raised when a hardware report type cannot be handled for the configured
  isolation type.
  """


class AttestationClient():
  def __init__(self, logger: Logger, parameters: AttestationClientParameters):
    """
    Stores the logger and parameters and creates the attestation provider.

    Parameters
    ----------
    logger : Logger
      Logger used by the client and handed to the provider.
    parameters : AttestationClientParameters
      Client configuration; its verifier selects the provider.
    """
    selected = parameters.verifier
    isolation = parameters.isolation_type
    url = parameters.endpoint
    secret = parameters.api_key

    self.parameters = parameters
    self.log = logger

    if selected == Verifier.MAA:
      self.provider = MAAProvider(logger, isolation, url)
    elif selected == Verifier.ITA:
      # Only the ITA provider receives the API key
      self.provider = ITAProvider(logger, isolation, url, secret)
    else:
      # Any other verifier leaves the client without a provider
      self.provider = None

  def get_hardware_evidence(self) -> HardwareEvidence:
    """
    Collects the hardware report and runtime data of the platform.

    The HCL report is read with TssWrapper.get_hcl_report() and split into the
    report type, the hardware report and the runtime data. For an 'snp' report
    on SEV-SNP the report fields are logged. For a 'tdx' report on TDX the
    report is exchanged for a TD quote with ImdsClient.get_td_quote(), and the
    decoded quote replaces the hardware report in the returned evidence.

    Returns
    -------
    HardwareEvidence
      A new HardwareEvidence holding the report type, the hardware report
      (SNP report or TD quote) and the runtime data.

    Raises
    ------
    UnsupportedReportTypeException
      If the report type is neither 'snp' nor 'tdx', or does not match the
      configured isolation type.
    Exception
      Any other failure raised while collecting the evidence; it is logged
      and re-raised unchanged.
    """

    try:
      self.log.info('Collecting hardware evidence...')

      # Extract Hardware Report and Runtime Data
      tss_wrapper = TssWrapper(self.log)
      hcl_report = tss_wrapper.get_hcl_report(self.parameters.user_claims)
      report_type = ReportParser.extract_report_type(hcl_report)
      hw_report = ReportParser.extract_hw_report(hcl_report)
      runtime_data = ReportParser.extract_runtimes_data(hcl_report)

      isolation_type = self.parameters.isolation_type
      if report_type == 'snp' and isolation_type == IsolationType.SEV_SNP:
        # Logs important SNP fields from the hardware report
        self.log_snp_report(hw_report)
      elif report_type == 'tdx' and isolation_type == IsolationType.TDX:
        self.log.info("Fetching td quote...")

        # Exchange the TDX report for a TD quote; the quote becomes the hardware report
        imds_client = ImdsClient(self.log)
        encoded_report = Encoder.base64url_encode(hw_report)
        encoded_hw_evidence = imds_client.get_td_quote(encoded_report)
        hw_report = Encoder.base64url_decode(encoded_hw_evidence)

        self.log.info("Finished fetching td quote")

        self.log.info("Hardware report parsing for TDX not supported yet")
      else:
        raise UnsupportedReportTypeException(f"Unsupported report type: {report_type}")

      return HardwareEvidence(report_type, hw_report, runtime_data)
    except Exception as e:
      self.log.error(f"Error while reading hardware report. Exception {e}")
      # Re-raise instead of silently returning None: callers dereference the
      # result immediately, and a RequestException must stay visible so that
      # retry logic in the callers can react to it.
      raise

  def attest_guest(self):
    """
    Attest the Hardware and Guest

    Collects the hardware evidence, the TPM evidence (AIK certificate and
    public key, PCR quote and values, ephemeral key) and the measurement logs,
    submits them to the attestation provider and decrypts the token contained
    in the response. An attempt that yields no response, or that raises a
    RequestException, is retried with exponential backoff (1, 2, 4 and 8
    seconds) for up to 5 attempts. Any other exception propagates.

    Returns
    -------
    bytes or None
      The decrypted token, or None when every attempt failed.
    """

    # Attest the guest using exponential backoff
    max_retries = 5
    backoff_factor = 1
    for attempt in range(1, max_retries + 1):
      # Message logged if this turns out to be the last failed attempt
      exhausted_message = "Token was not received from attestation provider"
      try:
        self.log.info('Attesting Guest Evidence...')

        hardware_evidence = self.get_hardware_evidence()
        hw_report = hardware_evidence.hardware_report   # td_quote or snp report
        runtime_data = hardware_evidence.runtime_data
        report_type = hardware_evidence.type

        # get the isolation information for the platform
        hw_evidence = ""
        imds_client = ImdsClient(self.log)
        if report_type == 'tdx':
          hw_evidence = TdxEvidence(hw_report, runtime_data)
        elif report_type == 'snp':
          cert_chain = imds_client.get_vcek_certificate()
          hw_evidence = SnpEvidence(hw_report, runtime_data, cert_chain)
        else:
          self.log.info('Invalid Hardware Report Type')

        # Collect guest attestation parameters
        os_info = OsInfo()
        tss_wrapper = TssWrapper(self.log)
        aik_cert = tss_wrapper.get_aik_cert()
        aik_pub = tss_wrapper.get_aik_pub()
        pcr_quote, sig = tss_wrapper.get_pcr_quote(os_info.pcr_list)
        pcr_values = tss_wrapper.get_pcr_values(os_info.pcr_list)
        ephemeral_key, key_handle, tpm = tss_wrapper.get_ephemeral_key(os_info.pcr_list)
        tpm_info = TpmInfo(aik_cert, aik_pub, pcr_quote, sig, pcr_values, ephemeral_key)
        tcg_logs = get_measurements(os_info.type)
        isolation = Isolation(self.parameters.isolation_type, hw_evidence)
        param = GuestAttestationParameters(os_info, tcg_logs, tpm_info, isolation)

        # Calls attestation provider with the guest evidence
        request = {
          "AttestationInfo": Encoder.base64url_encode_string(param.toJson())
        }
        encoded_response = self.provider.attest_guest(request)

        # Check the response from the server if there is an error
        # we retry until all retries have been exhausted
        if encoded_response:
          self.log.info('Parsing encoded token...')

          # decode the response; the appended '==' supplies any base64 padding
          # missing from the encoded response
          response = urlsafe_b64decode(encoded_response + '==').decode('utf-8')
          response = json.loads(response)

          # NOTE(review): each field is passed through json.dumps before
          # decoding, which wraps string values in quotes; presumably
          # Encoder.base64decode tolerates that - confirm before simplifying.

          # parse encrypted inner key
          encrypted_inner_key = json.dumps(response['EncryptedInnerKey'])
          encrypted_inner_key_decoded = Encoder.base64decode(encrypted_inner_key)

          # parse Encryption Parameters
          encryption_params_json = response['EncryptionParams']
          iv = json.dumps(encryption_params_json['Iv'])
          iv = Encoder.base64decode(iv)

          auth_data = json.dumps(response['AuthenticationData'])
          auth_data = Encoder.base64decode(auth_data)

          # The inner (AES) key is decrypted with the TPM ephemeral key
          decrypted_inner_key = \
            tss_wrapper.decrypt_with_ephemeral_key(
              encrypted_inner_key_decoded,
              os_info.pcr_list,
              key_handle,
              tpm
            )

          # parse the encrypted token
          encrypted_jwt = json.dumps(response['Jwt'])
          encrypted_jwt = Encoder.base64decode(encrypted_jwt)

          # Create an AESGCM object with the decrypted inner key
          aesgcm = AESGCM(decrypted_inner_key)

          self.log.info('Decrypting JWT...')

          associated_data = b'Transport Key'

          # NOTE: authentication data is part of the cipher's last 16 bytes
          cipher_message = encrypted_jwt + auth_data

          # Decrypt the token using the same key, nonce, and associated data
          decrypted_data = aesgcm.decrypt(iv, cipher_message, associated_data)
          self.log.info("Decrypted JWT Successfully.")
          self.log.info('TOKEN:')

          encoded_token = decrypted_data.decode('utf-8')
          self.log.info(encoded_token)
          self.provider.print_guest_claims(encoded_token)

          return decrypted_data

        self.log.error("Token was not received from attestation provider")
      except RequestException as e:
        self.log.error(f"Request to attest guest failed with an exception: {e}")
        exhausted_message = f"Request failed after all retries have been exhausted. Error: {e}"

      # Shared failure handling: back off before the next attempt, or report
      # the final failure once the attempts are used up
      if attempt < max_retries:
        sleep_time = backoff_factor * (2 ** (attempt - 1))
        self.log.info(f"Retrying in {sleep_time} seconds...")
        time.sleep(sleep_time)
      else:
        self.log.error(exhausted_message)

    return None


  def attest_platform(self):
    """
    Attest the Hardware

    Reads the HCL report, builds the hardware evidence for the platform (a TD
    quote for TDX, or the SNP report together with the VCEK certificate chain
    for SEV-SNP) and submits it with the runtime data to the attestation
    provider. An attempt that yields no token, or that raises a
    RequestException, is retried with exponential backoff for up to 5 attempts.

    Returns
    -------
      The token returned by the provider, or None when every attempt failed.
    """

    # Attest the platform using exponential backoff
    max_retries = 5
    backoff_factor = 1
    for attempt in range(1, max_retries + 1):
      # Message logged if this turns out to be the last failed attempt
      final_error = "Token was not received from attestation provider"
      try:
        self.log.info('Attesting Platform Evidence...')

        tss_wrapper = TssWrapper(self.log)
        isolation_type = self.parameters.isolation_type

        # Extract Hardware Report and Runtime Data
        hcl_report = tss_wrapper.get_hcl_report(self.parameters.user_claims)
        report_type = ReportParser.extract_report_type(hcl_report)
        runtime_data = ReportParser.extract_runtimes_data(hcl_report)
        hw_report = ReportParser.extract_hw_report(hcl_report)

        # Set request data based on the platform
        encoded_report = Encoder.base64url_encode(hw_report)
        encoded_runtime_data = Encoder.base64url_encode(runtime_data)
        encoded_hw_evidence = ""

        imds_client = ImdsClient(self.log)
        if report_type == 'tdx' and isolation_type == IsolationType.TDX:
          encoded_hw_evidence = imds_client.get_td_quote(encoded_report)
        elif report_type == 'snp' and isolation_type == IsolationType.SEV_SNP:
          # Logs important SNP fields from the hardware report
          self.log_snp_report(hw_report)

          # The SNP evidence is a JSON document carrying the report and the
          # VCEK certificate chain, base64url encoded as a whole
          cert_chain = imds_client.get_vcek_certificate()
          snp_payload = json.dumps({
            'SnpReport': encoded_report,
            'VcekCertChain': Encoder.base64url_encode(cert_chain)
          })
          encoded_hw_evidence = Encoder.base64url_encode(bytearray(snp_payload.encode('utf-8')))
        else:
          self.log.info('Invalid Hardware Report Type')

        # verify hardware evidence
        encoded_token = self.provider.attest_platform(encoded_hw_evidence, encoded_runtime_data)

        # Check the response from the server if there is an error
        # we retry until all retries have been exhausted
        if encoded_token:
          self.log.info('TOKEN:')
          self.log.info(encoded_token)
          self.provider.print_platform_claims(encoded_token)

          return encoded_token

        self.log.error("Token was not received from attestation provider")
      except RequestException as e:
        self.log.error(f"Request to attest platform failed with an exception: {e}")
        final_error = f"Request failed after all retries have been exhausted. Error: {e}"

      # Shared failure handling: back off before the next attempt, or report
      # the final failure once the attempts are used up
      if attempt < max_retries:
        sleep_time = backoff_factor * (2 ** (attempt - 1))
        self.log.info(f"Retrying in {sleep_time} seconds...")
        time.sleep(sleep_time)
      else:
        self.log.error(final_error)

    return None


  def log_snp_report(self, hw_report):
    """
    Logs selected fields of an SNP attestation report.

    Parameters
    ----------
    hw_report
      The serialized SNP attestation report; it is parsed with
      AttestationReport.deserialize().
    """
    report = AttestationReport.deserialize(hw_report)

    self.log.info(f"Attestation report size: {len(hw_report)} bytes")
    self.log.info(f"Report version: {report.version}")
    self.log.info(f"Report guest svn: {report.guest_svn}")

    # (log label, report attribute) pairs in the order they are logged.
    # The label spelling is kept as-is so the emitted log lines do not change.
    tcb_fields = (
      ("Current", "current_tcb"),
      ("Reported", "reported_tcb"),
      ("Commited", "committed_tcb"),
      ("Launched", "launch_tcb"),
    )
    for label, attribute in tcb_fields:
      serialized = getattr(report, attribute).serialize()
      # The serialized TCB is printed in reverse order as two-digit uppercase hex
      hex_digits = [f"{byte:02X}" for byte in serialized[::-1]]
      formatted_tcb = "".join(hex_digits)
      self.log.info(f"{label} TCB version: {formatted_tcb}")

