# -*- coding: utf-8 -*-
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Library of common learning rate schedules."""

import numpy as np
import tensorflow as tf

# Fall back to the TF1-compatible API surface when running under TF 2.x.
if tf.__version__ >= '2.0':
  tf = tf.compat.v1


def exponential_decay_with_burnin(global_step,
                                  learning_rate_base,
                                  learning_rate_decay_steps,
                                  learning_rate_decay_factor,
                                  burnin_learning_rate=0.0,
                                  burnin_steps=0,
                                  min_learning_rate=0.0,
                                  staircase=True):
  """Exponential decay schedule with burn-in period.

  In this schedule, the learning rate ramps up linearly from
  burnin_learning_rate to learning_rate_base over the burn-in period (or is
  held at learning_rate_base when burnin_learning_rate is 0.0), before
  transitioning to a regular exponential decay schedule.

  Args:
    global_step: int tensor representing global step.
    learning_rate_base: base learning rate.
    learning_rate_decay_steps: number of steps between successive decays of
      the learning rate.  Note that the decay is applied to the step count
      measured from the end of the burn-in period, i.e. to
      global_step - burnin_steps.
    learning_rate_decay_factor: multiplicative factor by which to decay
      learning rate.
    burnin_learning_rate: initial learning rate during burn-in period.  If
      0.0 (which is the default), then the burn-in learning rate is simply
      set to learning_rate_base.
    burnin_steps: number of steps in the burn-in period.
    min_learning_rate: the minimum learning rate.
    staircase: whether to use staircase decay.

  Returns:
    a (scalar) float tensor representing learning rate.
  """
  if burnin_learning_rate == 0:
    burnin_rate = learning_rate_base
  else:
    # Linearly ramp the rate from burnin_learning_rate up to
    # learning_rate_base over the burn-in period.
    slope = (learning_rate_base - burnin_learning_rate) / burnin_steps
    burnin_rate = slope * tf.cast(global_step,
                                  tf.float32) + burnin_learning_rate
  post_burnin_learning_rate = tf.train.exponential_decay(
      learning_rate_base,
      global_step - burnin_steps,
      learning_rate_decay_steps,
      learning_rate_decay_factor,
      staircase=staircase)
  return tf.maximum(
      tf.where(
          tf.less(tf.cast(global_step, tf.int32), tf.constant(burnin_steps)),
          burnin_rate, post_burnin_learning_rate),
      min_learning_rate,
      name='learning_rate')
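

# Illustrative usage sketch (not part of the original library): evaluates the
# burn-in schedule at a few steps in a TF1-style graph/session. All
# hyperparameter values below are arbitrary and chosen only for demonstration.
def _example_exponential_decay_with_burnin():
  """Prints the burn-in schedule sampled at a handful of global steps."""
  graph = tf.Graph()
  with graph.as_default():
    # Feed the step through a placeholder so one graph can be probed at
    # several different steps.
    global_step = tf.placeholder(tf.int64, shape=[])
    learning_rate = exponential_decay_with_burnin(
        global_step,
        learning_rate_base=0.01,
        learning_rate_decay_steps=1000,
        learning_rate_decay_factor=0.95,
        burnin_learning_rate=0.001,
        burnin_steps=100)
    with tf.Session(graph=graph) as sess:
      for step in [0, 50, 100, 1000, 5000]:
        print(step, sess.run(learning_rate, feed_dict={global_step: step}))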


def cosine_decay_with_warmup(global_step,
                             learning_rate_base,
                             total_steps,
                             warmup_learning_rate=0.0,
                             warmup_steps=0,
                             hold_base_rate_steps=0):
  """Cosine decay schedule with warm up period.

  Cosine annealing learning rate as described in:
    Loshchilov and Hutter, SGDR: Stochastic Gradient Descent with Warm Restarts.
    ICLR 2017. https://arxiv.org/abs/1608.03983
  In this schedule, the learning rate grows linearly from warmup_learning_rate
  to learning_rate_base for warmup_steps, then transitions to a cosine decay
  schedule.

  Args:
    global_step: int64 (scalar) tensor representing global step.
    learning_rate_base: base learning rate.
    total_steps: total number of training steps.
    warmup_learning_rate: initial learning rate for warm up.
    warmup_steps: number of warmup steps.
    hold_base_rate_steps: Optional number of steps to hold base learning rate
      before decaying.

  Returns:
    a (scalar) float tensor representing learning rate (zero once global_step
    exceeds total_steps).

  Raises:
    ValueError: if warmup_learning_rate is larger than learning_rate_base,
      or if warmup_steps is larger than total_steps.
  """
  if learning_rate_base < warmup_learning_rate:
    raise ValueError('learning_rate_base must be larger than or equal to '
                     'warmup_learning_rate.')
  if total_steps < warmup_steps:
    raise ValueError('total_steps must be larger than or equal to '
                     'warmup_steps.')
  # Cosine decay over the steps remaining after warm-up and the optional hold
  # period.
  learning_rate = 0.5 * learning_rate_base * (1 + tf.cos(
      np.pi *
      (tf.cast(global_step, tf.float32) - warmup_steps - hold_base_rate_steps) /
      float(total_steps - warmup_steps - hold_base_rate_steps)))
  if hold_base_rate_steps > 0:
    learning_rate = tf.where(global_step > warmup_steps + hold_base_rate_steps,
                             learning_rate, learning_rate_base)
  if warmup_steps > 0:
    slope = (learning_rate_base - warmup_learning_rate) / warmup_steps
    warmup_rate = slope * tf.cast(global_step,
                                  tf.float32) + warmup_learning_rate
    learning_rate = tf.where(global_step < warmup_steps, warmup_rate,
                             learning_rate)
  return tf.where(
      global_step > total_steps, 0.0, learning_rate, name='learning_rate')
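

# Illustrative usage sketch (not part of the original library): samples the
# warm-up + cosine schedule at a few steps. All hyperparameter values below
# are arbitrary and chosen only for demonstration.
def _example_cosine_decay_with_warmup():
  """Prints the warm-up/cosine schedule sampled at a handful of steps."""
  graph = tf.Graph()
  with graph.as_default():
    global_step = tf.placeholder(tf.int64, shape=[])
    learning_rate = cosine_decay_with_warmup(
        global_step,
        learning_rate_base=0.04,
        total_steps=10000,
        warmup_learning_rate=0.001,
        warmup_steps=500)
    with tf.Session(graph=graph) as sess:
      for step in [0, 250, 500, 5000, 10000]:
        print(step, sess.run(learning_rate, feed_dict={global_step: step}))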


def manual_stepping(global_step, boundaries, rates, warmup=False):
  """Manually stepped learning rate schedule.

  This function provides fine grained control over learning rates.  One must
  specify a sequence of learning rates as well as a set of integer steps
  at which the current learning rate must transition to the next.  For example,
  if boundaries = [5, 10] and rates = [.1, .01, .001], then the learning
  rate returned by this function is .1 for global_step=0,...,4, .01 for
  global_step=5...9, and .001 for global_step=10 and onward.

  Args:
    global_step: int64 (scalar) tensor representing global step.
    boundaries: a list of global steps at which to switch learning
      rates.  This list is assumed to consist of strictly increasing positive
      integers.
    rates: a list of (float) learning rates corresponding to intervals between
      the boundaries.  The length of this list must be exactly
      len(boundaries) + 1.
    warmup: Whether to linearly interpolate learning rate for steps in
      [0, boundaries[0]].

  Returns:
    a (scalar) float tensor representing learning rate.

  Raises:
    ValueError: if one of the following checks fails:
      1. boundaries is a strictly increasing list of positive integers
      2. len(rates) == len(boundaries) + 1
      3. boundaries[0] != 0
  """
  if any([b < 0 for b in boundaries]) or any(
      [not isinstance(b, int) for b in boundaries]):
    raise ValueError('boundaries must be a list of positive integers')
  if any([bnext <= b for bnext, b in zip(boundaries[1:], boundaries[:-1])]):
    raise ValueError('Entries in boundaries must be strictly increasing.')
  if any([not isinstance(r, float) for r in rates]):
    raise ValueError('Learning rates must be floats')
  if len(rates) != len(boundaries) + 1:
    raise ValueError('Number of provided learning rates must exceed '
                     'number of boundary points by exactly 1.')

  if boundaries and boundaries[0] == 0:
    raise ValueError('First step cannot be zero.')

  if warmup and boundaries:
    slope = (rates[1] - rates[0]) * 1.0 / boundaries[0]
    warmup_steps = list(range(boundaries[0]))
    warmup_rates = [rates[0] + slope * step for step in warmup_steps]
    boundaries = warmup_steps + boundaries
    rates = warmup_rates + rates[1:]
  else:
    boundaries = [0] + boundaries
  num_boundaries = len(boundaries)
  # Index of the last boundary that global_step has reached (0 if none); the
  # one-hot inner product with `rates` below then selects the matching rate.
  rate_index = tf.reduce_max(
      tf.where(
          tf.greater_equal(global_step, boundaries),
          list(range(num_boundaries)), [0] * num_boundaries))
  return tf.reduce_sum(
      rates * tf.one_hot(rate_index, depth=num_boundaries),
      name='learning_rate')
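

# Illustrative usage sketch (not part of the original library): the
# boundaries/rates below mirror the example in the docstring above, so the
# printed rates should be 0.1 before step 5, 0.01 for steps 5-9, and 0.001
# from step 10 onward.
def _example_manual_stepping():
  """Prints the manually stepped schedule sampled around its boundaries."""
  graph = tf.Graph()
  with graph.as_default():
    global_step = tf.placeholder(tf.int64, shape=[])
    learning_rate = manual_stepping(
        global_step, boundaries=[5, 10], rates=[0.1, 0.01, 0.001])
    with tf.Session(graph=graph) as sess:
      for step in [0, 4, 5, 9, 10, 20]:
        print(step, sess.run(learning_rate, feed_dict={global_step: step}))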


def transformer_policy(global_step,
                       learning_rate,
                       d_model,
                       warmup_steps,
                       step_scaling_rate=1.0,
                       max_lr=None,
                       coefficient=1.0,
                       dtype=tf.float32):
  """Transformer's learning rate schedule.

  Transformer's learning rate policy from
  https://arxiv.org/pdf/1706.03762.pdf,
  capped at max_lr when provided (also called the "noam" learning rate decay
  scheme).

  Args:
    global_step: global step TensorFlow tensor.
    learning_rate (float): initial learning rate to use.
    d_model (int): model dimensionality.
    warmup_steps (int): number of warm-up steps.
    step_scaling_rate (float): multiplicative factor applied to both the step
        count and warmup_steps before computing the decay.
    max_lr (float): maximum learning rate, i.e. the cap ("hat") on the
        schedule.
    coefficient (float): optimizer adjustment.
        Recommended 0.002 if using "Adam" else 1.0.
    dtype: dtype for this policy.

  Returns:
    learning rate at step ``global_step``.
  """
  step_num = tf.cast(global_step, dtype=dtype)
  ws = tf.cast(warmup_steps, dtype=dtype)
  step_num *= step_scaling_rate
  ws *= step_scaling_rate

  # Noam decay: d_model**-0.5 * min((step + 1) * ws**-1.5, (step + 1)**-0.5),
  # scaled by `coefficient`; the +1 shift avoids a zero rate at the first step.
  decay = coefficient * d_model**-0.5 * tf.minimum((step_num + 1) * ws**-1.5,
                                                   (step_num + 1)**-0.5)

  new_lr = decay * learning_rate
  if max_lr is not None:
    return tf.minimum(max_lr, new_lr)
  return new_lr
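

# Illustrative usage sketch (not part of the original library): d_model=512
# and warmup_steps=4000 follow the Transformer paper; learning_rate=2.0 is an
# arbitrary demo value.
def _example_transformer_policy():
  """Prints the noam-style schedule sampled at a handful of steps."""
  graph = tf.Graph()
  with graph.as_default():
    global_step = tf.placeholder(tf.int64, shape=[])
    learning_rate = transformer_policy(
        global_step, learning_rate=2.0, d_model=512, warmup_steps=4000)
    with tf.Session(graph=graph) as sess:
      for step in [1, 1000, 4000, 20000]:
        print(step, sess.run(learning_rate, feed_dict={global_step: step}))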
