# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
"""IO utils.

isort:skip_file
"""
import logging
from future import standard_library
standard_library.install_aliases()

import os
import traceback
from subprocess import getstatusoutput

import six
import tensorflow as tf
from six.moves import http_client
from six.moves import urllib
import json
if six.PY2:
  from urllib import quote
else:
  from urllib.parse import quote

if tf.__version__ >= '2.0':
  tf = tf.compat.v1

EASY_REC_RES_DIR = 'easy_rec_user_resources'
HTTP_MAX_NUM_RETRY = 5
HTTP_MAX_TIMEOUT = 600


def http_read(url, timeout=HTTP_MAX_TIMEOUT, max_retry=HTTP_MAX_NUM_RETRY):
  """Read data from url with maximum retry.

  Args:
    url: http url to be read
    timeout: specifies a timeout in seconds for blocking operations.
    max_retry: http max retry times.
  """
  num_read_try = 0
  data = None
  while num_read_try < max_retry:
    try:
      if six.PY2:
        url = url.encode('utf-8')
      url = quote(url, safe='%/:?=&')
      data = urllib.request.urlopen(url, timeout=timeout).read()
      break
    except http_client.IncompleteRead:
      tf.logging.warning('incomplete read exception, will retry: %s' % url)
      num_read_try += 1
    except Exception:
      tf.logging.error(traceback.format_exc())
      break

  if data is None:
    tf.logging.error('http read %s failed' % url)

  return data


def download(oss_or_url, dst_dir=''):
  """Download file.

  Args:
    oss_or_url: http or oss path
    dst_dir: destination directory
  Return:
    dst_file: local path for the downloaded file
  """
  _, basename = os.path.split(oss_or_url)
  if oss_or_url[:3] == 'oss':
    with tf.gfile.GFile(oss_or_url, 'rb') as infile:
      file_content = infile.read()
  elif oss_or_url[:4] == 'http':
    try:
      response = urllib.request.urlopen(oss_or_url, timeout=HTTP_MAX_TIMEOUT)
      file_content = response.read()
    except Exception as e:
      raise RuntimeError('Download %s failed: %s\n %s' %
                         (oss_or_url, str(e), traceback.format_exc()))
  else:
    tf.logging.warning('skip downloading %s, seems to be a local file' %
                       oss_or_url)
    return oss_or_url

  if dst_dir != '' and not os.path.exists(dst_dir):
    os.makedirs(dst_dir)
  dst_file = os.path.join(dst_dir, basename)
  with tf.gfile.GFile(dst_file, 'wb') as ofile:
    ofile.write(file_content)

  return dst_file


def create_module_dir(dst_dir):
  if not os.path.exists(dst_dir):
    os.makedirs(dst_dir)
    with open(os.path.join(dst_dir, '__init__.py'), 'w') as ofile:
      ofile.write('\n')


def download_resource(resource_path, dst_dir=EASY_REC_RES_DIR):
  """Download user resource.

  Args:
    resource_path: http or oss path
    dst_dir: destination directory
  """
  create_module_dir(dst_dir)
  _, basename = os.path.split(resource_path)
  if not basename.endswith('.py'):
    raise ValueError('resource %s should be python file' % resource_path)

  target = download(resource_path, dst_dir)

  return target


def download_and_uncompress_resource(resource_path, dst_dir=EASY_REC_RES_DIR):
  """Download user resource and uncompress it if necessary.

  Args:
    resource_path: http or oss path
    dst_dir: download destination directory
  """
  create_module_dir(dst_dir)

  _, basename = os.path.split(resource_path)
  if not basename.endswith('.tar.gz') and not basename.endswith('.zip') and \
     not basename.endswith('.py'):
    raise ValueError('resource %s should be tar.gz or zip or py' %
                     resource_path)

  download(resource_path, dst_dir)

  stat = 0
  if basename.endswith('tar.gz'):
    stat, output = getstatusoutput('cd %s && tar -zxf %s' % (dst_dir, basename))
  elif basename.endswith('zip'):
    stat, output = getstatusoutput('cd %s && unzip %s' % (dst_dir, basename))

  if stat != 0:
    raise ValueError('uncompress resoruce %s failed: %s' % resource_path,
                     output)

  return dst_dir


def oss_has_t_mode(target_file):
  """Test if current enviroment support t-mode written to oss."""
  if 'PAI' not in tf.__version__:
    return False
  # test if running on cluster
  test_file = target_file + '.tmp'
  try:
    with tf.gfile.GFile(test_file, 't') as ofile:
      ofile.write('a')
      pass
    tf.gfile.Remove(test_file)
    return True
  except:  # noqa: E722
    return False


def fix_oss_dir(path):
  """Make sure that oss dir endswith /."""
  if path.startswith('oss://') and not path.endswith('/'):
    return path + '/'
  return path


def save_data_to_json_path(json_path, data):
  with tf.gfile.GFile(json_path, 'w') as fout:
    fout.write(json.dumps(data))
  assert tf.gfile.Exists(json_path), 'in_save_data_to_json_path, save_failed'


def read_data_from_json_path(json_path):
  if json_path and tf.gfile.Exists(json_path):
    with tf.gfile.GFile(json_path, 'r') as fin:
      data = json.loads(fin.read())
    return data
  else:
    logging.info('json_path not exists, return None')
    return None


def convert_tf_flags_to_argparse(flags):
  """Convert tf.app.flags.FLAGS to argparse.ArgumentParser.

  Args:
      flags: tf.app.flags.FLAGS
  Returns:
      argparse.ArgumentParser: configurate ArgumentParser object
  """
  import argparse
  import ast
  parser = argparse.ArgumentParser()

  args = {}
  for flag in flags._flags().values():
    flag_name = flag.name
    if flag_name in args:
      args[flag_name][0] = True
      continue
    default = flag.value
    flag_type = type(default)
    help_str = flag.help or ''
    args[flag_name] = [
        False, flag_type, default, help_str,
        flag.choices if hasattr(flag, 'choices') else None
    ]

  def str2bool(v):
    if isinstance(v, bool):
      return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
      return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
      return False
    else:
      raise argparse.ArgumentTypeError('Boolean value expected.')

  for flag_name, (multi, flag_type, default, help_str, choices) in args.items():
    if flag_type == bool:
      parser.add_argument(
          '--' + flag_name,
          type=str2bool,
          nargs='?',
          const=True,
          default=False,
          help=help_str)
    elif flag_type == str:
      if choices:
        parser.add_argument(
            '--' + flag_name,
            type=str,
            choices=choices,
            default=default,
            help=help_str)
      elif multi:
        parser.add_argument(
            '--' + flag_name,
            type=str,
            action='append',
            default=default,
            help=help_str)
      else:
        parser.add_argument(
            '--' + flag_name, type=str, default=default, help=help_str)
    elif flag_type in (list, dict):
      parser.add_argument(
          '--' + flag_name,
          type=lambda s: ast.literal_eval(s),
          default=default,
          help=help_str)
    elif flag_type in (int, float):
      parser.add_argument(
          '--' + flag_name, type=flag_type, default=default, help=help_str)
    else:
      parser.add_argument(
          '--' + flag_name, type=str, default=default, help=help_str)
  return parser


def filter_unknown_args(flags, args):
  """Filter unknown args."""
  known_args = [args[0]]
  parser = convert_tf_flags_to_argparse(flags)
  args, unknown = parser.parse_known_args(args)
  if len(unknown) > 1:
    logging.info('undefined arguments: %s', ', '.join(unknown[1:]))
  for key, value in vars(args).items():
    if value is None:
      continue
    if type(value) in (list, dict) and not value:
      continue
    known_args.append('--' + key + '=' + str(value))
  logging.info('defined arguments: %s', ', '.join(known_args[1:]))
  return known_args
