#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Pipeline manipulation utilities useful for many runners.

For internal use only; no backwards-compatibility guarantees.
"""

# pytype: skip-file

import collections
import copy
import re

from apache_beam.portability import common_urns
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.transforms import environments
from apache_beam.typehints import typehints


def group_by_key_input_visitor(deterministic_key_coders=True):
  """Build a pipeline visitor that forces KV typing around each GroupByKey.

  Args:
    deterministic_key_coders: when truthy, the input and output PCollections
      of every GroupByKey are flagged as requiring a deterministic key coder.

  Returns:
    A ``PipelineVisitor`` instance (see ``GroupByKeyInputVisitor`` below).
  """
  # Importing here to avoid a circular dependency
  # pylint: disable=wrong-import-order, wrong-import-position
  from apache_beam.pipeline import PipelineVisitor
  from apache_beam.transforms.core import GroupByKey

  class GroupByKeyInputVisitor(PipelineVisitor):
    """A visitor that replaces `Any` element type for input `PCollection` of
    a `GroupByKey` with a `KV` type.

    TODO(BEAM-115): Once Python SDK is compatible with the new Runner API,
    we could directly replace the coder instead of mutating the element type.
    """
    def __init__(self, deterministic_key_coders=True):
      self.deterministic_key_coders = deterministic_key_coders

    def enter_composite_transform(self, transform_node):
      # Composite nodes get the same treatment as leaf nodes.
      self.visit_transform(transform_node)

    def visit_transform(self, transform_node):
      if not isinstance(transform_node.transform, GroupByKey):
        return

      label = transform_node.full_label
      # The stored value is the transform's label when deterministic coders
      # are requested, otherwise the (falsy) flag itself.  Presumably the
      # label is kept so later coder checks can name the offending
      # transform -- TODO confirm against the consumers of this attribute.
      deterministic_flag = self.deterministic_key_coders and label

      input_pcoll = transform_node.inputs[0]
      input_pcoll.element_type = typehints.coerce_to_kv_type(
          input_pcoll.element_type, label)
      input_pcoll.requires_deterministic_key_coder = deterministic_flag
      key_type, value_type = input_pcoll.element_type.tuple_types

      node_outputs = transform_node.outputs
      if not node_outputs:
        return
      first_tag = next(iter(node_outputs.keys()))
      output_pcoll = node_outputs[first_tag]
      # A GroupByKey turns KV[K, V] into KV[K, Iterable[V]].
      output_pcoll.element_type = typehints.KV[
          key_type, typehints.Iterable[value_type]]
      output_pcoll.requires_deterministic_key_coder = deterministic_flag

  return GroupByKeyInputVisitor(deterministic_key_coders)


def validate_pipeline_graph(pipeline_proto):
  """Ensures this is a correctly constructed Beam pipeline.

  Every transform reachable from ``pipeline_proto.root_transform_ids``
  (recursing through ``subtransforms``) is checked as follows:

  * GroupByKey: exactly one input and one output. Both PCollections must use
    a KV coder with two components. The output key coder must be the input
    key coder. The output value coder must be an ITERABLE coder whose
    element coder is the input value coder.
  * AssignWindows and ParDo: at least one input.

  Args:
    pipeline_proto: the pipeline proto to validate.

  Raises:
    ValueError: if any of the checks above fails.
  """
  def get_coder(pcoll_id):
    # Coder proto attached to the given PCollection.
    return pipeline_proto.components.coders[
        pipeline_proto.components.pcollections[pcoll_id].coder_id]

  def validate_transform(transform_id):
    transform_proto = pipeline_proto.components.transforms[transform_id]

    # GBK operations must have their coders set properly; ParDo and
    # AssignWindows must have an input.
    if transform_proto.spec.urn == common_urns.primitives.GROUP_BY_KEY.urn:
      if len(transform_proto.inputs) != 1:
        raise ValueError("Unexpected number of inputs: %s" % transform_proto)
      if len(transform_proto.outputs) != 1:
        raise ValueError("Unexpected number of outputs: %s" % transform_proto)
      input_coder = get_coder(next(iter(transform_proto.inputs.values())))
      output_coder = get_coder(next(iter(transform_proto.outputs.values())))
      for role, coder in (('input', input_coder), ('output', output_coder)):
        # A KV coder needs exactly a key and a value component. Checking the
        # count here reports a malformed proto as ValueError rather than an
        # IndexError in the component lookups below.
        if (coder.spec.urn != common_urns.coders.KV.urn or
            len(coder.component_coder_ids) != 2):
          raise ValueError(
              "Bad coder for %s of %s: %s" % (role, transform_id, coder))
      output_values_coder = pipeline_proto.components.coders[
          output_coder.component_coder_ids[1]]
      if (input_coder.component_coder_ids[0] !=
          output_coder.component_coder_ids[0] or
          output_values_coder.spec.urn != common_urns.coders.ITERABLE.urn or
          not output_values_coder.component_coder_ids or
          output_values_coder.component_coder_ids[0] !=
          input_coder.component_coder_ids[1]):
        # Arguments follow the placeholder order: input, output, transform.
        raise ValueError(
            "Incompatible input coder %s and output coder %s for transform %s" %
            (input_coder, output_coder, transform_id))
    elif transform_proto.spec.urn == common_urns.primitives.ASSIGN_WINDOWS.urn:
      if not transform_proto.inputs:
        raise ValueError("Missing input for transform: %s" % transform_proto)
    elif transform_proto.spec.urn == common_urns.primitives.PAR_DO.urn:
      if not transform_proto.inputs:
        raise ValueError("Missing input for transform: %s" % transform_proto)

    for t in transform_proto.subtransforms:
      validate_transform(t)

  for t in pipeline_proto.root_transform_ids:
    validate_transform(t)


def _dep_key(dep):
  """Hashable identity key for a single artifact dependency.

  FILE and URL artifacts are identified by their sha256 when it is set, and
  otherwise by their path or URL respectively. Any other artifact type is
  identified by its raw type urn and payload. The role urn and payload are
  always part of the key.

  Returns:
    A tuple ``(type_info, role_urn, role_payload)``.
  """
  artifact_types = common_urns.artifact_types
  if dep.type_urn == artifact_types.FILE.urn:
    file_payload = beam_runner_api_pb2.ArtifactFilePayload.FromString(
        dep.type_payload)
    type_info = (('sha256', file_payload.sha256)
                 if file_payload.sha256 else ('path', file_payload.path))
  elif dep.type_urn == artifact_types.URL.urn:
    url_payload = beam_runner_api_pb2.ArtifactUrlPayload.FromString(
        dep.type_payload)
    type_info = (('sha256', url_payload.sha256)
                 if url_payload.sha256 else ('url', url_payload.url))
  else:
    type_info = (dep.type_urn, dep.type_payload)
  return type_info, dep.role_urn, dep.role_payload


def _expanded_dep_keys(dep):
  """Yield the comparison keys contributed by a single artifact dependency.

  Most dependencies contribute exactly one key, ``_dep_key(dep)``. Two
  staged files are special-cased:

  * ``submission_environment_dependencies.txt`` contributes no keys at all.
  * ``requirements.txt`` is opened and contributes one
    ``('requirements.txt', requirement)`` key per requirement line. Requirement
    sets can therefore be compared line by line rather than as opaque files.
    Comments (a ``#`` at line start or after whitespace, as pip defines
    them) and blank lines are dropped, since they specify nothing.

  Args:
    dep: an artifact-information proto with ``type_urn``/``type_payload`` and
      ``role_urn``/``role_payload`` fields.

  Yields:
    Hashable dependency keys.
  """
  if (dep.type_urn == common_urns.artifact_types.FILE.urn and
      dep.role_urn == common_urns.artifact_roles.STAGING_TO.urn):
    payload = beam_runner_api_pb2.ArtifactFilePayload.FromString(
        dep.type_payload)
    role = beam_runner_api_pb2.ArtifactStagingToRolePayload.FromString(
        dep.role_payload)
    if role.staged_name == 'submission_environment_dependencies.txt':
      return
    elif role.staged_name == 'requirements.txt':
      with open(payload.path) as fin:
        for line in fin:
          # Same comment rule pip applies; '#egg=' fragments in URLs are kept
          # because they are not preceded by whitespace.
          requirement = re.sub(r'(^|\s+)#.*$', '', line).strip()
          if requirement:
            yield 'requirements.txt', requirement
      return

  yield _dep_key(dep)


def _base_env_key(env, include_deps=True):
  return (
      env.urn,
      env.payload,
      tuple(sorted(env.capabilities)),
      tuple(sorted(env.resource_hints.items())),
      tuple(sorted(_dep_key(dep)
                   for dep in env.dependencies)) if include_deps else None)


def _env_key(env):
  """Hashable key for an environment, including any AnyOf alternatives.

  Each environment returned by ``environments.expand_anyof_environments(env)``
  is reduced with ``_base_env_key``. The resulting keys are sorted, so the
  order of AnyOf alternatives does not affect the key.
  """
  alternative_keys = [
      _base_env_key(alternative)
      for alternative in environments.expand_anyof_environments(env)
  ]
  alternative_keys.sort()
  return tuple(alternative_keys)


def merge_common_environments(pipeline_proto, inplace=False):
  """Collapse environments whose contents are identical under ``_env_key``.

  Within each group of equivalent environments, the id encountered first
  (in the environments map's iteration order) is kept. All references to the
  others are redirected to it via ``update_environments``.

  Args:
    pipeline_proto: the pipeline proto whose environments are merged.
    inplace: forwarded to ``update_environments``.

  Returns:
    ``pipeline_proto`` unchanged if no two environments are equivalent,
    otherwise the result of ``update_environments``.
  """
  env_protos = pipeline_proto.components.environments
  ids_by_key = collections.defaultdict(list)
  for env_id, env in env_protos.items():
    ids_by_key[_env_key(env)].append(env_id)

  if len(ids_by_key) == len(env_protos):
    # All environments are already sufficiently distinct.
    return pipeline_proto

  environment_remappings = {}
  for equivalent_ids in ids_by_key.values():
    canonical_id = equivalent_ids[0]
    for env_id in equivalent_ids:
      environment_remappings[env_id] = canonical_id

  return update_environments(pipeline_proto, environment_remappings, inplace)


def merge_superset_dep_environments(pipeline_proto):
  """Merges all environments A and B where A and B are equivalent except that
  A has a superset of the dependencies of B.

  Only environments that resolve to a Docker environment (possibly through
  an AnyOf alternative) are considered. Each such environment is remapped to
  the candidate with the highest ``(number of dependencies, env id)`` score
  among all environments with the same base key whose dependency keys
  include all of its own.

  Args:
    pipeline_proto: the pipeline proto whose environments are merged.

  Returns:
    The result of ``update_environments`` with the computed remappings.
  """
  docker_envs = {}
  for env_id, env in pipeline_proto.components.environments.items():
    docker_env = environments.resolve_anyof_environment(
        env, common_urns.environments.DOCKER.urn)
    if docker_env.urn == common_urns.environments.DOCKER.urn:
      docker_envs[env_id] = docker_env

  # Compute each environment's base key and expanded dependency keys exactly
  # once. Expanding a requirements.txt dependency reads the file, and both
  # passes below must see the same keys.
  env_keys = {}
  for env_id, env in docker_envs.items():
    base_key = _base_env_key(env, include_deps=False)
    dep_keys = [
        dep_key for dep in env.dependencies
        for dep_key in _expanded_dep_keys(dep)
    ]
    env_keys[env_id] = base_key, dep_keys

  # (base_key, None) -> every env with that base key;
  # (base_key, dep_key) -> every env with that base key having dep_key.
  has_base_and_dep = collections.defaultdict(set)
  for env_id, (base_key, dep_keys) in env_keys.items():
    has_base_and_dep[base_key, None].add(env_id)
    for dep_key in dep_keys:
      has_base_and_dep[base_key, dep_key].add(env_id)

  # The env id breaks ties, making the choice of maximum deterministic.
  env_scores = {
      env_id: (len(env.dependencies), env_id)
      for (env_id, env) in docker_envs.items()
  }

  environment_remappings = {}
  for env_id, (base_key, dep_keys) in env_keys.items():
    # This is the set of all environments that have at least all of env's
    # deps. It always contains env_id itself.
    candidates = set.intersection(
        has_base_and_dep[base_key, None],
        *[has_base_and_dep[base_key, dep_key] for dep_key in dep_keys])
    # Choose the maximal one.
    best = max(candidates, key=env_scores.get)
    if best != env_id:
      environment_remappings[env_id] = best

  return update_environments(pipeline_proto, environment_remappings)


def update_environments(pipeline_proto, environment_remappings, inplace=False):
  """Redirect environment references according to ``environment_remappings``.

  Transforms and windowing strategies whose non-empty ``environment_id`` is
  a key of ``environment_remappings`` are pointed at the mapped id. An id
  missing from the environments map is left untouched. Environments that
  appear only as remapping sources (never as a target) are then deleted.

  Args:
    pipeline_proto: the pipeline proto to update.
    environment_remappings: dict mapping old environment id -> new id.
    inplace: when false, a ``copy.copy`` of ``pipeline_proto`` is updated
      instead of the argument itself.

  Returns:
    ``pipeline_proto`` itself when there are no remappings, otherwise the
    updated proto.
  """
  if not environment_remappings:
    return pipeline_proto

  result = pipeline_proto if inplace else copy.copy(pipeline_proto)
  components = result.components
  known_envs = components.environments

  def _redirect(owners):
    # Rewrites environment_id on each transform / windowing strategy.
    for owner in owners:
      current_id = owner.environment_id
      if current_id not in known_envs:
        # TODO(https://github.com/apache/beam/issues/30876): Remove this
        #  workaround.
        continue
      if current_id and current_id in environment_remappings:
        owner.environment_id = environment_remappings[current_id]

  _redirect(components.transforms.values())
  _redirect(components.windowing_strategies.values())

  # An id that is still the target of some remapping must survive.
  obsolete_ids = set(environment_remappings) - set(
      environment_remappings.values())
  for env_id in obsolete_ids:
    del known_envs[env_id]
  return result
