# coding=utf-8
# Rewrite by Ming Ding, Tsinghua University
# Copyright (c) 2019, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import math
import os
import random
from collections import defaultdict
from contextlib import ExitStack
from datetime import datetime

import deepspeed
import numpy as np
import torch
import torch.distributed as dist
from concurrent_log_handler import ConcurrentRotatingFileHandler

from sat import mpu
from sat.data_utils import make_loaders
from sat.helpers import logger, print_all, print_rank0
from sat.model.base_model import get_model
from sat.transformer_defaults import NO_WD_MODULES

from .learning_rates import AnnealingLR
from .model_io import load_checkpoint, save_checkpoint
from .utils import Timers, get_sample_writer, init_wandb_writer, print_args, report_memory


try:
    import wandb
except ImportError:
    wandb = None
    print('wandb is not installed; wandb logging is disabled.')

def training_main(args, model_cls, forward_step_function, create_dataset_function, handle_metrics_function=None, init_function=None, collate_fn=None, forward_step_eval=None):
    """Main training program."""
    hooks = {
        'forward_step': forward_step_function,
        'init_function': init_function,
        'create_dataset_function': create_dataset_function,
        'handle_metrics': handle_metrics_function,
        'forward_step_eval': forward_step_eval or forward_step_function
    }

    timers = Timers()  # Timer.

    # Experiment Name
    if args.load and args.mode == 'pretrain':  # continue training
        args.experiment_name = os.path.basename(os.path.normpath(args.load))
    else:
        args.experiment_name = args.experiment_name + '-' + datetime.now().strftime("%m-%d-%H-%M")

    # PyTorch distributed init must happen before seeding. ALREADY MOVED TO arguments.py!
    # if isinstance(model_cls, type):
    #     initialize_distributed(args)
    #     set_random_seed(args.seed)  # Random seeds for reproducibility.

    # Data stuff.
    train_data, val_data, test_data = make_loaders(args, hooks['create_dataset_function'], collate_fn=collate_fn)
    if args.epochs:
        args.train_iters = len(train_data)
        if args.eval_interval is None:
            args.eval_interval = len(train_data)//args.epochs
        if args.save_interval is None:
            args.save_interval = len(train_data)//args.epochs

    # Build model
    if isinstance(model_cls, type):
        model = get_model(args, model_cls)
    else:
        model = model_cls
        # For a pre-built model, make sure all parameters are on the correct device,
        # otherwise syncing parameters will raise an error.
        correct_device = torch.device(args.device)
        for param in model.parameters():
            if param.device != correct_device:
                param.data = param.data.to(correct_device)
        # move registered buffers to the correct device as well
        for name, buffer in model.named_buffers():
            if buffer.device != correct_device:
                buffer.data = buffer.data.to(correct_device)

    # Config model IO
    if args.load is not None:
        args.iteration = load_checkpoint(model, args)
        # If we don't load optimizer states, the file lock is no longer needed.
        # with FileLock("/root/checkpoint_lock", timeout=-1):
        #     args.iteration = load_checkpoint(model, optimizer, args)
    else:
        args.iteration = 0
    if args.save:
        args.save = os.path.join(args.save, args.experiment_name)
        os.makedirs(args.save, exist_ok=True)
        # Also write logs to a rotating file inside the checkpoint directory.
        fh = ConcurrentRotatingFileHandler(os.path.join(args.save, 'logfile.log'))
        fh.setFormatter(logging.Formatter('[%(asctime)s] [%(levelname)s] %(message)s'))
        logger.addHandler(fh)
    torch.distributed.barrier()

    # init hook before building deepspeed model and optimizer
    if hooks['init_function'] is not None:
        hooks['init_function'](args, model)

    # training 
    iteration = 0
    if args.train_iters > 0:
        # Optimization related things
        model, optimizer = setup_model_untrainable_params_and_optimizer(args, model)

        # initialize lr scheduler
        lr_scheduler = get_learning_rate_scheduler(optimizer, args.iteration, args)
        assert isinstance(lr_scheduler, AnnealingLR), \
            'must be sat AnnealingLR, or the lr in param_groups will be wrong.'

        summary_writer = None
        if torch.distributed.get_rank() == 0:
            if args.mode == 'pretrain':
                print_rank0('Pretraining or Continuing training the Model...')
            elif args.mode == 'finetune':
                print_rank0('Finetuning Model...')
            print_args(args)
            summary_writer = get_sample_writer(base=args.summary_dir, name=args.experiment_name, iteration=args.iteration)
            if args.wandb:
                init_wandb_writer(args)
        # Resume data loader if necessary.
        if args.resume_dataloader:
            if not args.iterable_dataset:
                if train_data is not None:
                    train_data.batch_sampler.start_iter = args.iteration % len(train_data)
                if val_data is not None:
                    start_iter_val = (args.train_iters // args.save_interval) * args.eval_interval
                    val_data.batch_sampler.start_iter = start_iter_val % len(val_data)
            else:
                print_rank0('Warning: cannot resume an iterable dataloader; skipping...')

        if args.do_train:
            with ExitStack() as stack:
                def save_on_exit(args_, model_, optimizer_, lr_scheduler_):
                    save_checkpoint(args_.iteration, model_, optimizer_, lr_scheduler_, args_)

                # Re-sync random seeds, otherwise tensor parallelism might break (dropout, droppath).
                # TODO add rng states for data parallel and wrap drops in main path.
                random.seed(args.seed)
                np.random.seed(args.seed)
                torch.manual_seed(args.seed)
                torch.cuda.manual_seed(args.seed)
                # ---------

                iteration, skipped = train(model, optimizer,
                    lr_scheduler,
                    train_data,
                    val_data,
                    timers, args, summary_writer=summary_writer,
                    hooks=hooks
                    )

    # final save
    if args.save and iteration != 0:
        save_checkpoint(iteration, model, optimizer, lr_scheduler, args)

    # final testing
    if args.do_test and test_data is not None:
        prefix = 'the end of training for test data'
        test_loss = evaluate_and_print_results(prefix, iter(test_data),
            model, len(test_data) if args.strict_eval else args.eval_iters, args, timers, True, split='test', hooks=hooks)

    return model


def setup_model_untrainable_params_and_optimizer(args, model, config_params=None):
    """Setup model and optimizer."""

    if hasattr(model, 'disable_untrainable_params'):
        model.disable_untrainable_params() # freeze parameters that should not be trained

    param_groups = get_optimizer_param_groups(model)

    # Sync initialized parameters across ranks.
    # Under ZeRO-3 parameters are partitioned, so this sync is skipped.
    from sat.helpers import check_if_zero3
    if not check_if_zero3(args):
        print_rank0('Syncing initialized parameters...')
        for param_group in param_groups:
            for param in param_group['params']:
                if not param.model_parallel:
                    # We already use the same random seed on different ranks, but that is
                    # not fully reliable: non-model-parallel parameters could still differ at initialization.
                    dist.broadcast(
                        param.data,
                        src=0, # group is default group
                    )
                else:
                    dist.broadcast(
                        param.data,
                        src=mpu.get_model_parallel_rank(), # lowest global rank in this data-parallel group (0 .. mp_size-1)
                        group=mpu.get_data_parallel_group() # contains ranks r, r + mp_size, r + 2*mp_size, ...
                    )
        print_rank0('Finished syncing initialized parameters.')

    if args.train_data is not None:
        if args.deepspeed:
            from packaging import version
            print_rank0("DeepSpeed is enabled.", level='DEBUG')
            # checking optimizer
            optimizer_name = args.deepspeed_config.get('optimizer',{}).get('type', '')
            if optimizer_name.startswith('sat.'):
                from functools import partial
                from importlib import import_module
                # split the dotted path and import the optimizer callable
                optimizer_callable = getattr(import_module(optimizer_name.rsplit('.', maxsplit=1)[0]), optimizer_name.split('.')[-1])
                optimizer_callable = partial(optimizer_callable, **args.deepspeed_config.get('optimizer', {}).get('params', {}))
                print_rank0(f'Using optimizer {optimizer_name} from sat.')
                del args.deepspeed_config['optimizer']
            else:
                optimizer_callable = None

            model, optimizer, _, _ = deepspeed.initialize(
                model=model,
                model_parameters=param_groups,
                optimizer=optimizer_callable,
                args=args,
                mpu=mpu,
                dist_init_required=False,
                config_params=args.deepspeed_config
                    if version.parse(deepspeed.version) < version.parse("0.9.0")
                    else None
            )
        else:
            raise ValueError('Currently, we only support training with deepspeed.')
    else:
        optimizer = None

    return model, optimizer


def add_param_by_lr(dic, p, no_weight_decay=False):
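    """Bucket parameter `p` into `dic`, keyed by its `lr_scale` attribute (None if absent).
    New buckets get 'lr' set to the lr_scale, plus weight_decay=0.0 when `no_weight_decay` is True."""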
    if not hasattr(p, 'lr_scale'):
        dic[None]['params'].append(p)
    else:
        if p.lr_scale not in dic:
            dic[p.lr_scale] = {'params': [], 'lr': p.lr_scale} if not no_weight_decay else {'params': [], 'weight_decay': 0.0, 'lr': p.lr_scale}
        dic[p.lr_scale]['params'].append(p)

def get_params_for_weight_decay_optimization(module):
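    """Split a module's parameters into weight-decay and no-weight-decay groups.

    Biases, parameters of modules listed in NO_WD_MODULES, and parameters flagged with
    `no_weight_decay` get weight_decay=0.0; each group is further split by `lr_scale`.
    """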
    weight_decay_params = {None: {'params': [], 'lr': 1.}}
    no_weight_decay_params = {None: {'params': [], 'weight_decay': 0.0, 'lr': 1.}}
    print_rank0(f"{NO_WD_MODULES} is set to no_weight_decay")
    for module_ in module.modules():
        if isinstance(module_, tuple(NO_WD_MODULES)):
            for p in module_._parameters.values():
                if p is not None and p.requires_grad:
                    add_param_by_lr(no_weight_decay_params, p, no_weight_decay=True)
        else:
            for n, p in module_._parameters.items():
                if p is not None and n != 'bias' and p.requires_grad:
                    flag = bool(getattr(p, 'no_weight_decay', False))
                    if flag:
                        print_rank0(f"{n} is set to no_weight_decay")
                        add_param_by_lr(no_weight_decay_params, p, no_weight_decay=True)
                    else:
                        add_param_by_lr(weight_decay_params, p, no_weight_decay=False)
            for n, p in module_._parameters.items():
                if p is not None and n == 'bias' and p.requires_grad:
                    add_param_by_lr(no_weight_decay_params, p, no_weight_decay=True)
    ret = []
    for v in weight_decay_params.values():
        if len(v['params']) != 0:
            ret.append(v)
    for v in no_weight_decay_params.values():
        if len(v['params']) != 0:
            ret.append(v)
    return ret


def get_optimizer_param_groups(model):
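    """Build optimizer parameter groups and make sure every parameter carries
    `model_parallel` / `tensor_model_parallel` attributes."""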
    # Build parameter groups (weight decay and non-decay).
    if hasattr(model, 'module'):
        model = model.module
    param_groups = get_params_for_weight_decay_optimization(model)
    # Add model parallel attribute if it is not set.
    for param_group in param_groups:
        for param in param_group['params']:
            if not hasattr(param, 'model_parallel') and not hasattr(param, 'tensor_model_parallel'):
                param.model_parallel = False
                param.tensor_model_parallel = False
            else:
                assert hasattr(param, 'model_parallel') and hasattr(param, 'tensor_model_parallel'), "model_parallel and tensor_model_parallel should both be set or unset!"
    return param_groups


def get_learning_rate_scheduler(optimizer, iteration, args,
                                auto_warmup_steps=100, auto_warmup_rate=0.05):
    """Build the learning rate scheduler."""

    # Add linear learning rate scheduler.
    if args.lr_decay_iters is not None:
        num_iters = args.lr_decay_iters
    else:
        num_iters = args.train_iters
    num_iters = max(1, num_iters)
    init_step = max(iteration - auto_warmup_steps, 0)
    if args.mode == 'pretrain' and iteration == 0:
        auto_warmup_steps = 0
    # If init_step <= current_steps <= init_step + auto_warmup_steps,
    # lr = auto_warmup_rate * args.lr.
    # This overrides other rules.
    warmup_iter = args.warmup * num_iters
    lr_scheduler = AnnealingLR(optimizer,
                               start_lr=args.lr,
                               warmup_iter=warmup_iter,
                               num_iters=num_iters,
                               decay_style=args.lr_decay_style,
                               last_iter=init_step,
                               decay_ratio=args.lr_decay_ratio,
                               auto_warmup_steps=auto_warmup_steps,
                               auto_warmup_rate=auto_warmup_rate
                               )

    return lr_scheduler


def train(model, optimizer, lr_scheduler,
        train_data, val_data, timers, args,
        summary_writer=None, hooks={}):
    """Train the model."""
    if train_data is not None:
        train_data_iterator = iter(train_data)
    else:
        train_data_iterator = None
    if val_data is not None:
        val_data_iterator = iter(val_data)
    else:
        val_data_iterator = None

    # Turn on training mode which enables dropout.
    model.train()

    # Tracking loss.
    total_lm_loss = 0.0
    total_metrics = defaultdict(float)
    total_metrics_cnt = defaultdict(int)

    # Iterations.
    skipped_iters = 0

    timers('interval time').start()
    report_memory_flag = True
    while args.iteration < args.train_iters:
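        # args.profiling == -1 disables profiling; otherwise the CUDA profiler is started
        # at that iteration and each subsequent iteration is wrapped in an NVTX range.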
        if args.profiling != -1 and args.iteration == args.profiling:
            torch.cuda.cudart().cudaProfilerStart()

        if args.profiling != -1 and args.iteration >= args.profiling:
            torch.cuda.nvtx.range_push("iteration{}".format(args.iteration))
        lm_loss, skipped_iter, metrics = train_step(train_data_iterator,
                                                    model,
                                                    optimizer,
                                                    lr_scheduler,
                                                    args, timers, hooks=hooks)
        skipped_iters += skipped_iter
        if args.profiling != -1 and args.iteration >= args.profiling:
            torch.cuda.nvtx.range_pop()
        args.iteration += 1
        # Update losses.
        total_lm_loss += lm_loss.data.detach().float()
        for name in metrics:
            if 'eval' not in name:
                assert len(metrics[name].shape) == 0, 'metrics without eval must be scalar'
                value = metrics[name].data.detach().float().item()
                if value > -99:  # skip the -100 "no data" sentinel
                    total_metrics[name] += value
                    total_metrics_cnt[name] += 1

        # Logging.
        if args.iteration % args.log_interval == 0:
            learning_rate = optimizer.param_groups[0]['lr']
            avg_lm_loss = total_lm_loss.item() / args.log_interval
            # average the accumulated metrics over the iterations that reported them
            avg_metrics = {}
            for key in total_metrics:
                avg_metrics[key] = total_metrics[key] / total_metrics_cnt[key]

            elapsed_time = timers('interval time').elapsed()
            report_iteration_metrics(summary_writer, optimizer, learning_rate, avg_lm_loss,
                                     elapsed_time * 1000.0 / args.log_interval, args.iteration, args.train_iters, args,
                                     avg_metrics)
            total_lm_loss = 0.0
            total_metrics = defaultdict(float)
            total_metrics_cnt = defaultdict(int)
            if report_memory_flag:
                report_memory('after {} iterations'.format(args.iteration))
                report_memory_flag = False

            timers.log(['forward', 'backward', 'allreduce', 'optimizer',
                        'batch generator', 'data loader'],
                       normalizer=args.log_interval)
        # Checkpointing
        if args.save and args.save_interval and args.iteration % args.save_interval == 0:
            save_checkpoint(args.iteration, model, optimizer, lr_scheduler, args)

        # Evaluation
        if args.eval_interval and args.iteration % args.eval_interval == 0 and args.do_valid:
            if args.strict_eval:
                val_data_iterator = iter(val_data)
                eval_iters = len(val_data)
            else:
                eval_iters = args.eval_iters
            prefix = 'iteration {}'.format(args.iteration)
            evaluate_and_print_results(
                prefix, val_data_iterator, model, eval_iters, args, timers, False, step=args.iteration, split='val', summary_writer=summary_writer, hooks=hooks)

        if args.exit_interval and args.iteration % args.exit_interval == 0:
            torch.distributed.barrier()
            time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
            rank = torch.distributed.get_rank()
            print_all('rank: {} | time: {} | exiting the program at iteration {}'.
                  format(rank, time_str, args.iteration), flush=True)
            exit()
    if args.profiling != -1:
        torch.cuda.cudart().cudaProfilerStop()

    return args.iteration, skipped_iters


def train_step(data_iterator, model, optimizer, lr_scheduler,
               args, timers, hooks=None, single_step=False, **kwargs):
    """Single training step."""
    if hooks is None:
        hooks = {}
    lm_loss_total, metrics_total, count, metrics_count = 0.0, {}, 0, {}
    forward_step = hooks['forward_step']

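    # Run forward/backward on micro-batches until DeepSpeed reports a gradient-accumulation
    # boundary (i.e. a real optimizer step), then return the averaged loss and metrics.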
    while True:
        profiling_flag = (args.profiling != -1 and args.iteration >= args.profiling)
        # Forward model for one step.
        if profiling_flag:
            torch.cuda.nvtx.range_push("forward")
        timers('forward').start()
        forward_ret = forward_step(data_iterator, model, args, timers, **kwargs)
        if isinstance(forward_ret, tuple):
            lm_loss, metrics = forward_ret
        else:
            lm_loss, metrics = forward_ret, {}
        timers('forward').stop()
        if profiling_flag:
            torch.cuda.nvtx.range_pop()

        # Check for nan/inf in the forward pass to keep them from interfering with the loss scaler,
        # and all-reduce the metrics along the way.
        if profiling_flag:
            torch.cuda.nvtx.range_push("loss_and_metrics")
        lm_loss_reduced = lm_loss.detach().clone()
        torch.distributed.all_reduce(lm_loss_reduced.data)
        lm_loss_reduced.data = lm_loss_reduced.data / args.world_size

        loss_checker = lm_loss_reduced
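        # Convention: a metric value of -100 means "no valid value on this rank" for this step;
        # such ranks contribute 0 to the sum and are excluded from the count used for averaging.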
        for name in metrics:
            if 'eval' not in name:
                metrics[name] = metrics[name].detach().clone()
                if metrics[name].data.item() == -100:
                    cnt = torch.zeros(1, dtype=torch.int64, device=metrics[name].data.device)
                    metrics[name].data = torch.tensor(0., device=metrics[name].data.device)
                else:
                    cnt = torch.ones(1, dtype=torch.int64, device=metrics[name].data.device)
                torch.distributed.all_reduce(metrics[name].data)
                torch.distributed.all_reduce(cnt)
                if cnt.item() == 0:
                    metrics[name].data = torch.tensor(-100, device=metrics[name].data.device)
                else:
                    metrics[name].data /= cnt.cpu().item()
                loss_checker = loss_checker + metrics[name]
        if loss_checker.isnan().any() or loss_checker.isinf().any():
            print_all('Found nan or inf in forward loss/metrics; skipping backward and optimizer step!')
            return lm_loss.detach(), 1, metrics

        # Accumulate the statistics
        lm_loss_total += lm_loss_reduced
        for name in metrics:
            if name not in metrics_total:
                metrics_total[name] = torch.tensor(0.0, device=metrics[name].data.device)
            if name not in metrics_count:
                metrics_count[name] = 0
            if metrics[name].data.item() != -100:
                metrics_total[name] += metrics[name]
                metrics_count[name] += 1
        count += 1
        if profiling_flag:
            torch.cuda.nvtx.range_pop()

        if profiling_flag:
            torch.cuda.nvtx.range_push("backward")
        # Calculate gradients, reduce across processes, and clip.
        timers('backward').start()
        backward_step(optimizer, model, lm_loss, args, timers)
        timers('backward').stop()
        if profiling_flag:
            torch.cuda.nvtx.range_pop()
        # Update parameters.
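        # `complete` becomes True only at a gradient-accumulation boundary, where a real optimizer
        # step happens; `skipped_iter` marks steps dropped by the fp16 loss scaler on overflow.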
        skipped_iter, complete = 0, False
        if profiling_flag:
            torch.cuda.nvtx.range_push("optimizer")
        timers('optimizer').start()
        if args.deepspeed:
            if model.is_gradient_accumulation_boundary():
                model.step()
                complete = True
                if not (args.fp16 and optimizer.overflow):
                    lr_scheduler.step()
                else:
                    skipped_iter = 1
            else:
                model.step()
        else:
            raise ValueError('Currently, we only support training with deepspeed.')
        timers('optimizer').stop()
        if profiling_flag:
            torch.cuda.nvtx.range_pop()
        if complete or single_step:
            break
    lm_loss_total /= count
    metrics_total = {
        key: torch.tensor(-100, device=value.device) if metrics_count[key] == 0 else value / metrics_count[key]
        for key, value in metrics_total.items()
    }
    return lm_loss_total, skipped_iter, metrics_total


def backward_step(optimizer, model, loss, args, timers):
    """Backward step."""

    # Backward pass.
    if args.deepspeed:
        model.backward(loss)
    else:
        raise ValueError('Currently, we only support training with deepspeed.')

    if args.deepspeed:
        # DeepSpeed's backward pass already handles the all-reduce communication.
        # Reset the timer so the timer logs below are not broken.
        timers('allreduce').reset()

    return

def evaluate(data_iterator, model, eval_iters, args, timers, split, verbose=False, has_last=True, hooks={}):
    """Evaluation."""
    forward_step = hooks['forward_step_eval']
    # Turn on evaluation mode which disables dropout.
    model.eval()
    rank = torch.distributed.get_rank(group=mpu.get_data_parallel_group())
    total_lm_loss, metrics_total = 0, {}
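    # last_shape / drop_number are produced by make_loaders for strict_eval: last_shape holds the
    # per-rank real sizes of the final (possibly incomplete) batch and drop_number the number of
    # data-parallel replicas whose last batch is pure padding; both are used below to trim the
    # all_gather results.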
    if split=='val':
        last_shape = args.val_last_shape
        drop_number = args.val_drop_number
    else:
        assert split=='test'
        last_shape = args.test_last_shape
        drop_number = args.test_drop_number
    is_scalar = {}
    with torch.no_grad():
        iteration = 0
        while iteration < eval_iters:
            iteration += 1
            if verbose and iteration % args.log_interval == 0:
                print_rank0('Evaluating iter {}/{}'.format(iteration, eval_iters))
            # Forward evaluation.
            lm_loss, metrics = forward_step(data_iterator, model, args, timers)
            # When contiguous memory optimizations are enabled, the buffers allocated by
            # those optimizations are deallocated during the backward pass. Since there is
            # no backward pass here, the buffers must be reset after each forward pass.
            if args.deepspeed and args.deepspeed_activation_checkpointing:
                deepspeed.checkpointing.reset()
            total_lm_loss += lm_loss.data.detach().float().item()
            is_last = iteration == eval_iters and args.strict_eval and len(last_shape) > 0
            for name in metrics:
                if name not in metrics_total:
                    metrics_total[name] = []
                is_scalar[name] = len(metrics[name].shape) == 0
                shape = list(metrics[name].shape)
                if not is_scalar[name] and is_last and metrics[name].shape[0] != last_shape[0]:
                    # pad tensor's first dim to args.batch_size
                    metrics[name] = torch.concat([metrics[name], torch.zeros([last_shape[0]-metrics[name].shape[0]] + shape[1:], dtype=metrics[name].dtype, device=metrics[name].device)])
                # Every rank allocates the gather buffers because all_gather is collective.
                metrics_gathered = [torch.zeros_like(metrics[name]) for _ in range(args.world_size)]
                torch.distributed.all_gather(metrics_gathered, metrics[name])

                if rank==0:
                    gathered_len = len(metrics_gathered) if not is_last else len(metrics_gathered) - drop_number * args.model_parallel_size
                    for i in range(gathered_len):
                        if is_scalar[name] or not is_last:
                            metrics_total[name].append(metrics_gathered[i].data.cpu())
                        else:
                            metrics_total[name].append(metrics_gathered[i][:last_shape[i]].data.cpu())
    # Move model back to the train mode.
    model.train()

    total_lm_loss /= eval_iters
    if rank==0:
        for name in metrics_total:
            if is_scalar[name]:
                metrics_total[name] = torch.stack(metrics_total[name], dim=0)
            else:
                metrics_total[name] = torch.concat(metrics_total[name], dim=0)
        if hooks['handle_metrics'] is not None:
            metrics = hooks['handle_metrics'](metrics_total)
        else:
            for name in metrics_total:
                assert is_scalar[name], 'you must return scalar metrics or implement handle_metrics hooks'
            metrics = {key: sum(value.split(1,0))/len(value) for key, value in metrics_total.items()}
    else:
        metrics = None
    return total_lm_loss, metrics

def evaluate_and_print_results(prefix, data_iterator, model, eval_iters,
                            args, timers, has_last, split, verbose=False, step=None, summary_writer=None, hooks={}):
    """Helper function to evaluate and dump results on screen."""
    lm_loss, metrics = evaluate(data_iterator, model, eval_iters, args, timers, split, verbose, has_last, hooks=hooks)
    lm_ppl = math.exp(min(20, lm_loss))
    if torch.distributed.get_rank(group=mpu.get_data_parallel_group())==0:
        report_evaluate_metrics(summary_writer, prefix, lm_loss, lm_ppl, step, args, metrics)
    return lm_loss


def report_iteration_metrics(summary_writer, optimizer, lr, loss, elapsed_time, step, total_step, args, avg_metrics):
    log_string = ' iteration {:8d}/{:8d} |'.format(step, total_step)
    log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(elapsed_time)
    log_string += ' learning rate {:.3E} |'.format(lr)
    log_string += ' total loss {:.6E} |'.format(loss)
    for key in avg_metrics:
        log_string += ' {} {:.6E} |'.format(key, avg_metrics[key])
    if args.fp16:
        log_string += ' loss scale {:.1f} |'.format(
            optimizer.cur_scale if args.deepspeed else optimizer.loss_scale)
    log_string += ' speed {:.2f} samples/(min*GPU)'.format(
        (args.gradient_accumulation_steps * args.batch_size / args.model_parallel_size / (elapsed_time / 60000.0)))
    print_rank0(log_string)
    if summary_writer is not None:
        summary_writer.add_scalar(f'Train/lr', lr, step)
        summary_writer.add_scalar(f'Train/train_loss', loss, step)
        summary_writer.add_scalar(f'Train/elapsed_time', elapsed_time, step)
        for key in avg_metrics:
            summary_writer.add_scalar('Train/'+key, avg_metrics[key], step)
    if args.wandb and torch.distributed.get_rank() == 0:
        log_dict = {
            "Train/lr": lr,
            "Train/train_loss": loss,
            "Train/elapsed_time": elapsed_time
            }
        for key in avg_metrics:
            log_dict["Train/" + key] = avg_metrics[key]
        wandb.log(log_dict, step=step, commit=True)


def report_evaluate_metrics(summary_writer, prefix, loss, ppl, step, args, avg_metrics):
    string = ' validation loss at {} | '.format(prefix)
    string += 'loss: {:.6E} | '.format(loss)
    string += 'PPL: {:.6E}'.format(ppl)
    for key in avg_metrics:
        string += ' {} {:.6E} |'.format(key, avg_metrics[key].item())
    length = len(string) + 1
    print_rank0('-' * 100)
    print_rank0('-' * length)
    print_rank0(string)
    print_rank0('-' * length)
    if summary_writer is not None:
        summary_writer.add_scalar(f'Train/valid_ppl', ppl, step)
        summary_writer.add_scalar(f'Train/valid_loss', loss, step)
        for key in avg_metrics:
            summary_writer.add_scalar('Train/valid_'+key, avg_metrics[key], step)
    if args.wandb and torch.distributed.get_rank() == 0:
        log_dict = {
            "Train/valid_ppl": ppl,
            "Train/valid_loss": loss,
            }
        for key in avg_metrics:
            log_dict["Train/valid_" + key] = avg_metrics[key]
        wandb.log(log_dict, step=step, commit=True)
