# Copyright 2020 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Magic command wrapper for whistle mapping language."""

import json
import os

from google.cloud import storage
from googleapiclient import discovery
import grpc
from IPython.core import magic_arguments
from IPython.core.magic import cell_magic
from IPython.core.magic import line_magic
from IPython.core.magic import Magics
from IPython.core.magic import magics_class
from IPython.display import JSON

from google.protobuf import json_format
from wstl.magics import _constants
from wstl.magics import _location
from wstl.proto import wstlservice_pb2
from wstl.proto import wstlservice_pb2_grpc

# Deadline, in seconds, passed to gRPC calls that take a timeout. Environment
# variables are always strings and gRPC needs a number, so convert explicitly.
_GRPC_TIMEOUT = float(os.environ.get("NOTEBOOK_GRPC_TIMEOUT", 10.0))
# Host and port of the Whistle gRPC service the magics connect to.
_DEFAULT_HOST = os.environ.get("NOTEBOOK_GRPC_HOST", "localhost")
_DEFAULT_PORT = os.environ.get("NOTEBOOK_GRPC_PORT", "50051")


@magics_class
class WSTLMagics(Magics):
  """Evaluates whistle mapping language (.wstl) within a cell.

  Every magic sends its request to the Whistle gRPC service at
  `self.grpc_target` (NOTEBOOK_GRPC_HOST:NOTEBOOK_GRPC_PORT, by default
  localhost:50051). Incremental transformation state on the service is keyed
  by this IPython kernel's history session number.
  """

  def __init__(self, shell):
    super(WSTLMagics, self).__init__(shell)
    # "host:port" address of the Whistle gRPC service.
    self.grpc_target = "{}:{}".format(_DEFAULT_HOST, _DEFAULT_PORT)

  @line_magic("wstl-reset")
  def wstl_reset(self, line):
    """Line magic to clear all variables and functions from incremental transform.

    Deletes the service-side incremental session associated with this IPython
    session, discarding what earlier %%wstl cells defined.

    Args:
      line: the text following the magic name; unused.

    Returns:
      The service response rendered as IPython JSON, or the grpc.RpcError
      raised by the call.
    """
    with grpc.insecure_channel(self.grpc_target) as channel:
      stub = wstlservice_pb2_grpc.WhistleServiceStub(channel)
      # Same session key that _get_or_create_session derives for %%wstl.
      session_id = str(self.shell.history_manager.session_number)
      req = wstlservice_pb2.DeleteIncrementalSessionRequest(
          session_id=session_id)
      try:
        # NOTE(review): unlike the other RPCs used in this file, this method
        # name repeats the request message name; confirm against the proto.
        resp = stub.DeleteIncrementalSessionRequest(req)
      except grpc.RpcError as rpc_error:
        return rpc_error
      else:
        return JSON(json_format.MessageToDict(resp))

  @magic_arguments.magic_arguments()
  @magic_arguments.argument(
      "--input",
      type=str,
      required=False,
      help="""The input. Supports the following prefix notations:
      py://<name_of_python_variable>
      json://<inline_json_object_or_array> : python inline dict and list
      expressions are supported. e.g. json://{"field":"value"} or
      json://[{"first":"value"},{"second":"value"}]
      file://<path_to_local_file_system> , supports glob wildcard expressions
      and will only load .json or .ndjson file extensions. Each json object/list
      defined within an ndjson will be a separate input to the mapping.
      """)
  @magic_arguments.argument(
      "--library_config",
      type=str,
      required=False,
      help="""Path to the directory where the library mapping files are located."""
  )
  @magic_arguments.argument(
      "--code_config",
      type=str,
      required=False,
      help="""Path to the directory of FHIR ConceptMaps used for code harmonization."""
  )
  @magic_arguments.argument(
      "--unit_config",
      type=str,
      required=False,
      help="""Path to a unit harmonization file (textproto).""")
  @magic_arguments.argument(
      "--output",
      type=str,
      required=False,
      help="""Name of python variable to store result. e.g. --output temp_var"""
  )
  @cell_magic("wstl")
  def wstl(self, line, cell):
    """Cell magic to evaluate whistle mapping language from iPython kernel.

    The cell body is run as an incremental transform in this IPython
    session's incremental session, so definitions from earlier cells remain
    visible until %wstl-reset is run.

    Args:
      line: the magic's arguments (--input, --library_config, --code_config,
        --unit_config, --output).
      cell: the Whistle code in the cell body.

    Returns:
      The transform result rendered as IPython JSON, or the error reported by
      the session or transform helpers. When --output is given, the result is
      also pushed into the user namespace under that variable name.
    """
    args = magic_arguments.parse_argstring(self.wstl, line)

    # TODO (): migrate to secure channel.
    with grpc.insecure_channel(self.grpc_target) as channel:
      stub = wstlservice_pb2_grpc.WhistleServiceStub(channel)

      (incremental_session, err) = _get_or_create_session(stub, self.shell)
      if err:
        return err

      (transform,
       err) = _get_incremental_transform(stub, self.shell,
                                         incremental_session.session_id, args,
                                         cell)
      if err:
        return err

      result = _response_to_json(transform)
      if args.output:
        self.shell.push({args.output: result})
      return JSON(result)

  @magic_arguments.magic_arguments()
  @magic_arguments.argument(
      "--version",
      choices=["stu3", "r4"],
      default="r4",
      type=str,
      help="""The fhir version to apply to the validation.
      The default is r4.""")
  @magic_arguments.argument(
      "--input",
      type=str,
      help="""The input. Supports the following prefix notations:
      py://<name_of_python_variable>
      json://<inline_json_object_or_array> : python inline dict and list
      expressions are supported. e.g. json://{"field":"value"} or
      json://[{"first":"value"},{"second":"value"}]
      file://<path_to_local_file_system> , supports glob wildcard expressions
      and will only load .json or .ndjson file extensions. Each json object/list
      defined within an ndjson will be a separate input to the validation.""")
  @line_magic("fhir_validate")
  def fhir_validate(self, line):
    """Line magic to validate json FHIR resource(s) from iPython kernel.

    Args:
      line: the magic's arguments (--version, --input).

    Returns:
      The validation response rendered as IPython JSON, or the error reported
      by _get_validation.
    """
    args = magic_arguments.parse_argstring(self.fhir_validate, line)

    with grpc.insecure_channel(self.grpc_target) as channel:
      stub = wstlservice_pb2_grpc.WhistleServiceStub(channel)

      (resp, err) = _get_validation(stub, self.shell, args.version, args.input)
      if err:
        return err
      return JSON(json_format.MessageToDict(resp))


@magics_class
class LoadHL7Magics(Magics):
  """Loads parsed HL7v2 message from GCS or HL7v2 Store."""

  @magic_arguments.magic_arguments()
  @magic_arguments.argument(
      "--project_id",
      type=str,
      help="""ID of the GCP project that the HL7v2 Store belongs to.""",
      required=True)
  @magic_arguments.argument(
      "--region",
      type=str,
      help="""Region of the HL7v2 Store.""",
      required=True)
  @magic_arguments.argument(
      "--dataset_id",
      type=str,
      required=True,
      help="""ID of the dataset that the HL7v2 store belongs to.""")
  @magic_arguments.argument(
      "--hl7v2_store_id",
      type=str,
      required=True,
      help="""ID of the HL7v2 store to load data from.""")
  @magic_arguments.argument(
      "--api_version",
      type=str,
      required=False,
      default="v1beta1",
      choices=["v1", "v1beta1"],
      help="""The version of healthcare api to call. Defaults to v1beta1.""")
  @magic_arguments.argument(
      "--filter",
      type=str,
      required=False,
      help="""
      filter: string, Restricts messages returned to those matching a filter. Syntax:
    https://cloud.google.com/appengine/docs/standard/python/search/query_strings

    If the filter string contains white space, it must be surrounded by single quotes.

    Fields/functions available for filtering are:

    *  `message_type`, from the MSH-9.1 field. For example,
    `NOT message_type = "ADT"`.
    *  `send_date` or `sendDate`, the YYYY-MM-DD date the message was sent in
    the dataset's time_zone, from the MSH-7 segment. For example,
    `send_date < "2017-01-02"`.
    *  `send_time`, the timestamp when the message was sent, using the
    RFC3339 time format for comparisons, from the MSH-7 segment. For example,
    `send_time < "2017-01-02T00:00:00-05:00"`.
    *  `send_facility`, the care center that the message came from, from the
    MSH-4 segment. For example, `send_facility = "ABC"`.
    *  `PatientId(value, type)`, which matches if the message lists a patient
    having an ID of the given value and type in the PID-2, PID-3, or PID-4
    segments. For example, `PatientId("123456", "MRN")`.
    *  `labels.x`, a string value of the label with key `x` as set using the
    Message.labels
    map. For example, `labels."priority"="high"`. The operator `:*` can be used
    to assert the existence of a label. For example, `labels."priority":*`.""")
  @magic_arguments.argument(
      "--dest_file_name",
      type=str,
      required=False,
      help="""
      The destination file path to store the loaded data.
      If not provided, the result will be directly returned to the IPython kernel.
      """,
      default="")
  @line_magic("load_hl7v2_datastore")
  def load_hl7v2_datastore(self, line):
    """Load parsed HL7v2 messages from the HL7v2 Store specified.

    Args:
      line: the magic's arguments, as declared by the decorators above.

    Returns:
      The messages rendered as IPython JSON or, when --dest_file_name is
      given, a confirmation string after the messages are written to that
      file as JSON.
    """
    args = magic_arguments.parse_argstring(self.load_hl7v2_datastore, line)

    hl7v2_messages = _get_message_from_hl7v2_store(args.api_version,
                                                   args.project_id, args.region,
                                                   args.dataset_id,
                                                   args.hl7v2_store_id,
                                                   args.filter)
    if args.dest_file_name:
      # The messages are a list parsed from the API response; file.write only
      # accepts str, so serialize them as JSON.
      with open(args.dest_file_name, "w") as dest_file:
        json.dump(hl7v2_messages, dest_file)
      return "The message was written to {} successfully.".format(
          args.dest_file_name)

    return JSON(hl7v2_messages)

  @magic_arguments.magic_arguments()
  @magic_arguments.argument(
      "--bucket_name",
      type=str,
      help="""The name of the GCS bucket to load data from.""",
      required=True)
  @magic_arguments.argument(
      "--source_blob_name",
      type=str,
      help="""The name of the blob to load.""",
      required=True)
  @magic_arguments.argument(
      "--dest_file_name",
      type=str,
      required=False,
      help="""
      The destination file path to store the loaded data.
      If not provided, the result will be directly returned to the IPython kernel.
      """,
      default="")
  @line_magic("load_hl7v2_gcs")
  def load_hl7v2_gcs(self, line):
    """Load and return parsed HL7v2 message from the blob in a GCS bucket specified.

    Args:
      line: the magic's arguments (--bucket_name, --source_blob_name,
        --dest_file_name).

    Returns:
      The parsed JSON rendered as IPython JSON or, when --dest_file_name is
      given, a confirmation string after the blob text is written there.

    Raises:
      ValueError: if the bucket or blob does not exist.
      json.JSONDecodeError: if the blob content is not valid JSON.
    """
    args = magic_arguments.parse_argstring(self.load_hl7v2_gcs, line)
    dest_file_name = args.dest_file_name

    storage_client = storage.Client()

    bucket = storage_client.bucket(args.bucket_name)
    if not bucket.exists():
      raise ValueError(
          "The bucket does not exist. Please check the provided bucket name.")
    blob = bucket.get_blob(args.source_blob_name)
    if not blob:
      raise ValueError(
          "The blob does not exist. Please check the provided blob name.")
    # Always work with text: a raw bytes payload cannot be written to a
    # text-mode file below.
    content = _decode_blob_content(blob.download_as_string(),
                                   blob.content_encoding)

    # check if the returned content is a json
    try:
      result = json.loads(content)
    except json.JSONDecodeError:
      print(
          "The loaded content is not a valid JSON. Please check the source bucket and blob."
      )
      raise

    if dest_file_name:
      with open(dest_file_name, "w", encoding="utf-8") as dest:
        dest.write(content)
      return "The message was written to {} successfully.".format(
          dest_file_name)

    return JSON(result)


def _decode_blob_content(content, content_encoding):
  """Decodes a downloaded GCS blob payload to text.

  Args:
    content: the downloaded payload (bytes, or an already-decoded str).
    content_encoding: the blob's Content-Encoding metadata, or None.

  Returns:
    The payload as a str. content_encoding is tried as a text codec first.
    When it is unset, or is not a Python text codec (Content-Encoding usually
    names an HTTP encoding such as "gzip"), the bytes are decoded as UTF-8 and
    a leading byte-order mark is dropped.
  """
  if isinstance(content, str):
    return content
  if content_encoding:
    try:
      return content.decode(content_encoding)
    except LookupError:
      # Unknown or non-text codec; fall back to UTF-8 below.
      pass
  return content.decode("utf-8-sig")


def _get_message_from_hl7v2_store(api_version, project, region, dataset,
                                  data_store, filter_str):
  """Lists the messages of an HL7v2 store through the Cloud Healthcare API.

  The API client is built with googleapiclient discovery without explicit
  credentials. Presumably it therefore uses the environment's default
  credentials (e.g. GOOGLE_APPLICATION_CREDENTIALS).

  Args:
    api_version: Healthcare API version to call, e.g. "v1" or "v1beta1".
    project: ID of the GCP project that owns the store.
    region: location of the dataset.
    dataset: ID of the dataset containing the store.
    data_store: ID of the HL7v2 store.
    filter_str: optional message filter, or None. Surrounding single quotes
      (used to pass whitespace through the magic's argument parsing) are
      stripped.

  Returns:
    The "hl7V2Messages" list of a single messages.list call with the FULL
    view, or [] when the response contains no such key. Only the first page
    is returned.
  """
  # TODO(): add paging support for HL7v2 messages.
  service_name = "healthcare"

  client = discovery.build(service_name, api_version)
  hl7v2_messages_parent = "projects/{}/locations/{}/datasets/{}".format(
      project, region, dataset)
  hl7v2_message_path = "{}/hl7V2Stores/{}".format(hl7v2_messages_parent,
                                                  data_store)
  if filter_str:
    filter_str = filter_str.strip("'")
  return (
      client.projects().locations().datasets().hl7V2Stores().messages().list(
          parent=hl7v2_message_path, view="FULL",
          filter=filter_str).execute().get("hl7V2Messages", []))


def _get_or_create_session(stub, shell):
  """Retrieves or creates the incremental transform session.

  The session id is the IPython history session number, so every %%wstl
  cell run in the same IPython session reuses the same service-side session.

  Args:
    stub: gRPC client stub library.
    shell: an instance of the iPython shell that invoked the magic command.

  Returns:
    A (response, error) tuple: (IncrementalSessionResponse, None) on success,
    or (None, grpc.RpcError) if the call failed.
  """
  session_id = shell.history_manager.session_number
  req = wstlservice_pb2.CreateIncrementalSessionRequest(
      session_id=str(session_id))
  try:
    resp = stub.GetOrCreateIncrementalSession(req)
  except grpc.RpcError as rpc_error:
    return None, rpc_error
  else:
    return resp, None


def _get_incremental_transform(stub, shell, session_id, wstl_args, cell):
  """Invokes, through a gRPC request, an incremental Whistle transform.

  Args:
    stub: gRPC client stub library.
    shell: an instance of the iPython shell that invoked the magic command.
    session_id: the incremental transformation session id.
    wstl_args: the arguments to the wstl magic command.
    cell: the contents of the cell, containing whistle.

  Returns:
    A (response, error) tuple: (TransformResponse, None) on success;
    (None, str) when --input matched nothing, in which case no request is
    sent; (None, grpc.RpcError) when the call failed.
  """
  req = wstlservice_pb2.IncrementalTransformRequest()
  req.session_id = session_id
  req.wstl = cell
  if wstl_args.library_config:
    library_configs = _location.parse_location(
        shell,
        wstl_args.library_config,
        file_ext=_constants.WSTL_FILE_EXT,
        load_contents=False)
    if library_configs:
      req.library_config.extend(library_configs)

  if wstl_args.code_config:
    code_configs = _location.parse_location(
        shell,
        wstl_args.code_config,
        file_ext=_constants.JSON_FILE_EXT,
        load_contents=False)
    if code_configs:
      req.code_config.extend(code_configs)

  if wstl_args.unit_config:
    unit_configs = _location.parse_location(
        shell,
        wstl_args.unit_config,
        file_ext=_constants.TEXTPROTO_FILE_EXT,
        load_contents=False)
    if unit_configs:
      # Only the first match is used; the request holds a single unit config.
      # parse_location yields the same element type that library_config and
      # code_config are extended with. Protobuf Python forbids direct
      # assignment to a singular message field, so copy it in instead.
      req.unit_config.CopyFrom(unit_configs[0])

  if wstl_args.input:
    inputs = _location.parse_location(
        shell,
        wstl_args.input,
        file_ext=_constants.JSON_FILE_EXT,
        load_contents=True)
    if inputs:
      req.input.extend(inputs)
    else:
      return None, "no inputs matching argument {}".format(wstl_args.input)

  try:
    resp = stub.GetIncrementalTransform(req, timeout=_GRPC_TIMEOUT)
  except grpc.RpcError as rpc_error:
    return None, rpc_error
  else:
    return resp, None


def _convert_message_to_json(transform_record):
  """Returns the JSON value carried by a TransformedRecords proto.

  Args:
    transform_record: the TransformedRecords message to convert.

  Returns:
    The parsed "output" JSON string when that field is set. Otherwise, the
    "error" message as a dict when that field is set. Otherwise, the whole
    record as a dict.
  """
  if transform_record.HasField("output"):
    return json.loads(transform_record.output)
  has_error = transform_record.HasField("error")
  payload = transform_record.error if has_error else transform_record
  return json_format.MessageToDict(payload)


def _response_to_json(response):
  """Converts each element within a TransformResponse result into JSON.

  Args:
    response: the TransformResponse from a GetIncrementalTransform request.

  Returns:
    The JSON of the record when the response holds exactly one
    TransformedRecords. Otherwise, a list with one JSON value per record,
    which is [] when the response holds none.
  """
  results = response.results
  if len(results) == 1:
    return _convert_message_to_json(results[0])
  # Zero or several records: always a list, so an empty response yields []
  # instead of failing on results[0].
  return [_convert_message_to_json(result) for result in results]


def _get_validation(stub, shell, version, input_arg):
  """Asks the Whistle service to validate FHIR resource(s).

  Args:
    stub: gRPC client stub library.
    shell: an instance of the iPython shell that invoked the magic command.
    version: the FHIR version to validate against, "r4" or "stu3"
      (case-insensitive).
    input_arg: the location argument naming the resource(s) to validate,
      resolved with _location.parse_location.

  Returns:
    A (response, error) tuple: (ValidationResponse, None) on success,
    (None, ValueError) for an unsupported version, or (None, grpc.RpcError)
    when the call failed.
  """
  request = wstlservice_pb2.ValidationRequest(
      input=_location.parse_location(shell, input_arg))
  requested_version = version.lower()
  fhir_versions = wstlservice_pb2.ValidationRequest.FhirVersion
  if requested_version not in ("r4", "stu3"):
    return None, ValueError(
        "FHIR version {} is incorrect or not supported,\n"
        "{} are supported versions".format(version, fhir_versions.keys()))
  request.fhir_version = (
      fhir_versions.R4 if requested_version == "r4" else fhir_versions.STU3)

  try:
    response = stub.FhirValidate(request)
  except grpc.RpcError as rpc_error:
    return None, rpc_error
  else:
    return response, None
