modules/SwissArmyTransformer/sat/training/deepspeed_training.py:

# coding=utf-8
# Rewrite by Ming Ding, Tsinghua University
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import math
import os
import random
from collections import defaultdict
from contextlib import ExitStack
from datetime import datetime

import deepspeed
import numpy as np
import torch
import torch.distributed as dist
from concurrent_log_handler import ConcurrentRotatingFileHandler

from sat import mpu
from sat.data_utils import make_loaders
from sat.helpers import logger, print_all, print_rank0
from sat.model.base_model import get_model
from sat.transformer_defaults import NO_WD_MODULES

from .learning_rates import AnnealingLR
from .model_io import load_checkpoint, save_checkpoint
from .utils import Timers, get_sample_writer, init_wandb_writer, print_args, report_memory

try:
    import wandb
except ImportError:
    print("wandb not installed.")


def training_main(args, model_cls, forward_step_function, create_dataset_function,
                  handle_metrics_function=None, init_function=None, collate_fn=None,
                  forward_step_eval=None):
    """Main training program."""
    hooks = {
        'forward_step': forward_step_function,
        'init_function': init_function,
        'create_dataset_function': create_dataset_function,
        'handle_metrics': handle_metrics_function,
        'forward_step_eval': forward_step_eval or forward_step_function
    }

    timers = Timers()  # Timer.

    # Experiment name.
    if args.load and args.mode == 'pretrain':  # continue training
        args.experiment_name = os.path.basename(os.path.normpath(args.load))
    else:
        args.experiment_name = args.experiment_name + '-' + datetime.now().strftime("%m-%d-%H-%M")

    # Pytorch distributed. Must come before seeding. ALREADY MOVED TO arguments.py!
    # if isinstance(model_cls, type):
    #     initialize_distributed(args)
    #     set_random_seed(args.seed)  # Random seeds for reproducibility.

    # Data stuff.
    train_data, val_data, test_data = make_loaders(args, hooks['create_dataset_function'], collate_fn=collate_fn)
    if args.epochs:
        args.train_iters = len(train_data)
        if args.eval_interval is None:
            args.eval_interval = len(train_data) // args.epochs
        if args.save_interval is None:
            args.save_interval = len(train_data) // args.epochs

    # Build model.
    if isinstance(model_cls, type):
        model = get_model(args, model_cls)
    else:
        model = model_cls
        # For a given model, make sure all the params are on the correct device,
        # otherwise syncing parameters will raise an error.
        correct_device = torch.device(args.device)
        for param in model.parameters():
            if param.device != correct_device:
                param.data = param.data.to(correct_device)
        # registered buffers
        for name, buffer in model.named_buffers():
            if buffer.device != correct_device:
                buffer.data = buffer.data.to(correct_device)

    # Config model IO.
    if args.load is not None:
        args.iteration = load_checkpoint(model, args)
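        # Note: load_checkpoint returns the iteration stored in the checkpoint, so training
        # resumes from that step: it seeds the AnnealingLR scheduler below and, together with
        # --resume-dataloader, the batch sampler's start_iter.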
        # If we don't load optim_states, the file lock is no longer needed.
        # with FileLock("/root/checkpoint_lock", timeout=-1):
        #     args.iteration = load_checkpoint(model, optimizer, args)
    else:
        args.iteration = 0

    if args.save:
        args.save = os.path.join(args.save, args.experiment_name)
        os.makedirs(args.save, exist_ok=True)
        fh = ConcurrentRotatingFileHandler(os.path.join(args.save, 'logfile.log'))
        fh.setFormatter(logging.Formatter('[%(asctime)s] [%(levelname)s] %(message)s'))
        logger.addHandler(fh)

    torch.distributed.barrier()

    # init hook before building the deepspeed model and optimizer
    if hooks['init_function'] is not None:
        hooks['init_function'](args, model)

    # training
    iteration = 0
    if args.train_iters > 0:
        # Optimization related things
        model, optimizer = setup_model_untrainable_params_and_optimizer(args, model)

        # initialize lr scheduler
        lr_scheduler = get_learning_rate_scheduler(optimizer, args.iteration, args)
        assert isinstance(lr_scheduler, AnnealingLR), \
            'must be sat AnnealingLR, or the lr in param_groups will be wrong.'

        summary_writer = None
        if torch.distributed.get_rank() == 0:
            if args.mode == 'pretrain':
                print_rank0('Pretraining or Continuing training the Model...')
            elif args.mode == 'finetune':
                print_rank0('Finetuning Model...')
            print_args(args)
            summary_writer = get_sample_writer(base=args.summary_dir, name=args.experiment_name, iteration=args.iteration)
            if args.wandb:
                init_wandb_writer(args)

        # Resume data loader if necessary.
        if args.resume_dataloader:
            if not args.iterable_dataset:
                if train_data is not None:
                    train_data.batch_sampler.start_iter = args.iteration % len(train_data)
                if val_data is not None:
                    start_iter_val = (args.train_iters // args.save_interval) * args.eval_interval
                    val_data.batch_sampler.start_iter = start_iter_val % len(val_data)
            else:
                print_rank0('Warning: cannot resume an iterable dataloader. Skipping...')

        if args.do_train:
            with ExitStack() as stack:
                def save_on_exit(args_, model_, optimizer_, lr_scheduler_):
                    save_checkpoint(args_.iteration, model_, optimizer_, lr_scheduler_, args_)
                # Re-sync random seeds, or tensor parallel might be broken (dropout, droppath).
                # TODO: add rng states for data parallel and wrap drops in the main path.
                random.seed(args.seed)
                np.random.seed(args.seed)
                torch.manual_seed(args.seed)
                torch.cuda.manual_seed(args.seed)
                # ---------
                iteration, skipped = train(model, optimizer,
                                           lr_scheduler,
                                           train_data,
                                           val_data,
                                           timers, args,
                                           summary_writer=summary_writer,
                                           hooks=hooks
                                           )

        # final save
        if args.save and iteration != 0:
            save_checkpoint(iteration, model, optimizer, lr_scheduler, args)

    # final testing
    if args.do_test and test_data is not None:
        prefix = 'the end of training for test data'
        test_loss = evaluate_and_print_results(prefix, iter(test_data),
                                               model, len(test_data) if args.strict_eval else args.eval_iters,
                                               args, timers, True, split='test', hooks=hooks)

    return model


def setup_model_untrainable_params_and_optimizer(args, model, config_params=None):
    """Setup model and optimizer."""

    if hasattr(model, 'disable_untrainable_params'):
        model.disable_untrainable_params()  # mark trainable params

    param_groups = get_optimizer_param_groups(model)

    # Sync initialized parameters.
    # ZeRO-3 does not need to sync.
    from sat.helpers import check_if_zero3
    if not check_if_zero3(args):
        print_rank0('Syncing initialized parameters...')
        for param_group in param_groups:
            for param in param_group['params']:
                if not param.model_parallel:
                    # We already keep the same random seed for different ranks.
                    # However, this is not reliable: non-model-parallel parameters could still differ after initialization.
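                    # Broadcasting rank 0's copy therefore forces every rank to start from
                    # identical values for these shared (data-parallel) parameters.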
                    dist.broadcast(
                        param.data,
                        src=0,  # group is default group
                    )
                else:
                    dist.broadcast(
                        param.data,
                        src=mpu.get_model_parallel_rank(),  # 0 -- mp_size-1
                        group=mpu.get_data_parallel_group()  # 1, mp_size + 1, ...
                    )
        print_rank0('Finished syncing initialized parameters.')

    if args.train_data is not None:
        if args.deepspeed:
            from packaging import version
            print_rank0("DeepSpeed is enabled.", level='DEBUG')
            # checking optimizer
            optimizer_name = args.deepspeed_config.get('optimizer', {}).get('type', '')
            if optimizer_name.startswith('sat.'):
                from functools import partial
                from importlib import import_module
                # split and import
                optimizer_callable = getattr(import_module(optimizer_name.rsplit('.', maxsplit=1)[0]), optimizer_name.split('.')[-1])
                optimizer_callable = partial(optimizer_callable, **args.deepspeed_config.get('optimizer', {}).get('params', {}))
                print_rank0(f'Using optimizer {optimizer_name} from sat.')
                del args.deepspeed_config['optimizer']
            else:
                optimizer_callable = None
            model, optimizer, _, _ = deepspeed.initialize(
                model=model,
                model_parameters=param_groups,
                optimizer=optimizer_callable,
                args=args,
                mpu=mpu,
                dist_init_required=False,
                config_params=args.deepspeed_config if version.parse(deepspeed.version) < version.parse("0.9.0") else None
            )
        else:
            raise ValueError('Currently, we only support training with deepspeed.')
    else:
        optimizer = None

    return model, optimizer


def add_param_by_lr(dic, p, no_weight_decay=False):
    if not hasattr(p, 'lr_scale'):
        dic[None]['params'].append(p)
    else:
        if p.lr_scale not in dic:
            dic[p.lr_scale] = {'params': [], 'lr': p.lr_scale} if not no_weight_decay else {'params': [], 'weight_decay': 0.0, 'lr': p.lr_scale}
        dic[p.lr_scale]['params'].append(p)


def get_params_for_weight_decay_optimization(module):
    weight_decay_params = {None: {'params': [], 'lr': 1.}}
    no_weight_decay_params = {None: {'params': [], 'weight_decay': 0.0, 'lr': 1.}}
    print_rank0(f"{NO_WD_MODULES} is set to no_weight_decay")
    for module_ in module.modules():
        if isinstance(module_, tuple(NO_WD_MODULES)):
            for p in module_._parameters.values():
                if p is not None and p.requires_grad:
                    add_param_by_lr(no_weight_decay_params, p, no_weight_decay=True)
        else:
            for n, p in module_._parameters.items():
                if p is not None and n != 'bias' and p.requires_grad:
                    flag = True if hasattr(p, 'no_weight_decay') and p.no_weight_decay else False
                    if flag:
                        print_rank0(f"{n} is set to no_weight_decay")
                        add_param_by_lr(no_weight_decay_params, p, no_weight_decay=True)
                    else:
                        add_param_by_lr(weight_decay_params, p, no_weight_decay=False)
            for n, p in module_._parameters.items():
                if p is not None and n == 'bias' and p.requires_grad:
                    add_param_by_lr(no_weight_decay_params, p, no_weight_decay=True)
    ret = []
    for v in weight_decay_params.values():
        if len(v['params']) != 0:
            ret.append(v)
    for v in no_weight_decay_params.values():
        if len(v['params']) != 0:
            ret.append(v)
    return ret


def get_optimizer_param_groups(model):
    # Build parameter groups (weight decay and non-decay).
    if hasattr(model, 'module'):
        model = model.module
    param_groups = get_params_for_weight_decay_optimization(model)

    # Add model parallel attribute if it is not set.
    for param_group in param_groups:
        for param in param_group['params']:
            if not hasattr(param, 'model_parallel') and not hasattr(param, 'tensor_model_parallel'):
                param.model_parallel = False
                param.tensor_model_parallel = False
            else:
                assert hasattr(param, 'model_parallel') and hasattr(param, 'tensor_model_parallel'), \
                    "model_parallel and tensor_model_parallel should both be set or unset!"
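    # Sketch of the returned structure (illustrative only; the 0.1 lr_scale is hypothetical):
    #   [{'params': [...], 'lr': 1.},                        # weight-decay group
    #    {'params': [...], 'weight_decay': 0.0, 'lr': 1.},   # biases / NO_WD_MODULES params
    #    {'params': [...], 'weight_decay': 0.0, 'lr': 0.1}]  # params tagged with p.lr_scale == 0.1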
    return param_groups


def get_learning_rate_scheduler(optimizer, iteration, args,
                                auto_warmup_steps=100, auto_warmup_rate=0.05):
    """Build the learning rate scheduler."""

    # Add linear learning rate scheduler.
    if args.lr_decay_iters is not None:
        num_iters = args.lr_decay_iters
    else:
        num_iters = args.train_iters
    num_iters = max(1, num_iters)
    init_step = max(iteration - auto_warmup_steps, 0)
    if args.mode == 'pretrain' and iteration == 0:
        auto_warmup_steps = 0
    # If init_step <= current_steps <= init_step + auto_warmup_steps,
    # lr = auto_warmup_rate * args.lr.
    # This overrides other rules.
    warmup_iter = args.warmup * num_iters
    lr_scheduler = AnnealingLR(optimizer,
                               start_lr=args.lr,
                               warmup_iter=warmup_iter,
                               num_iters=num_iters,
                               decay_style=args.lr_decay_style,
                               last_iter=init_step,
                               decay_ratio=args.lr_decay_ratio,
                               auto_warmup_steps=auto_warmup_steps,
                               auto_warmup_rate=auto_warmup_rate
                               )

    return lr_scheduler


def train(model, optimizer, lr_scheduler,
          train_data, val_data, timers, args,
          summary_writer=None, hooks={}):
    """Train the model."""
    if train_data is not None:
        train_data_iterator = iter(train_data)
    else:
        train_data_iterator = None
    if val_data is not None:
        val_data_iterator = iter(val_data)
    else:
        val_data_iterator = None

    # Turn on training mode which enables dropout.
    model.train()

    # Tracking loss.
    total_lm_loss = 0.0
    total_metrics = defaultdict(float)
    total_metrics_cnt = defaultdict(int)

    # Iterations.
    skipped_iters = 0

    timers('interval time').start()
    report_memory_flag = True
    while args.iteration < args.train_iters:
        if args.profiling != -1 and args.iteration == args.profiling:
            torch.cuda.cudart().cudaProfilerStart()
        if args.profiling != -1 and args.iteration >= args.profiling:
            torch.cuda.nvtx.range_push("iteration{}".format(args.iteration))

        lm_loss, skipped_iter, metrics = train_step(train_data_iterator,
                                                    model,
                                                    optimizer,
                                                    lr_scheduler,
                                                    args, timers, hooks=hooks)
        skipped_iters += skipped_iter

        if args.profiling != -1 and args.iteration >= args.profiling:
            torch.cuda.nvtx.range_pop()

        args.iteration += 1

        # Update losses.
        total_lm_loss += lm_loss.data.detach().float()
        for name in metrics:
            if 'eval' not in name:
                assert len(metrics[name].shape) == 0, 'metrics without eval must be scalar'
                value = metrics[name].data.detach().float().item()
                if value > -99:
                    total_metrics[name] += value
                    total_metrics_cnt[name] += 1
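        # Note: train_step reports -100 for a metric that no rank produced in that step
        # (see the all-reduce logic in train_step), so such values are skipped here and the
        # running averages only cover iterations that actually returned the metric.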
        # Logging.
        if args.iteration % args.log_interval == 0:
            learning_rate = optimizer.param_groups[0]['lr']
            avg_lm_loss = total_lm_loss.item() / args.log_interval  # average img & txt loss
            avg_metrics = {}
            for key in total_metrics:
                avg_metrics[key] = total_metrics[key] / total_metrics_cnt[key]  # args.log_interval

            elapsed_time = timers('interval time').elapsed()
            report_iteration_metrics(summary_writer, optimizer, learning_rate, avg_lm_loss,
                                     elapsed_time * 1000.0 / args.log_interval, args.iteration, args.train_iters,
                                     args, avg_metrics)
            total_lm_loss = 0.0
            total_metrics = defaultdict(float)
            total_metrics_cnt = defaultdict(int)
            if report_memory_flag:
                report_memory('after {} iterations'.format(args.iteration))
                report_memory_flag = False
            timers.log(['forward', 'backward', 'allreduce', 'optimizer',
                        'batch generator', 'data loader'],
                       normalizer=args.log_interval)

        # Checkpointing
        if args.save and args.save_interval and args.iteration % args.save_interval == 0:
            save_checkpoint(args.iteration, model, optimizer, lr_scheduler, args)

        # Evaluation
        if args.eval_interval and args.iteration % args.eval_interval == 0 and args.do_valid:
            if args.strict_eval:
                val_data_iterator = iter(val_data)
                eval_iters = len(val_data)
            else:
                eval_iters = args.eval_iters
            prefix = 'iteration {}'.format(args.iteration)
            evaluate_and_print_results(
                prefix, val_data_iterator, model, eval_iters, args, timers, False,
                step=args.iteration, split='val', summary_writer=summary_writer, hooks=hooks)

        if args.exit_interval and args.iteration % args.exit_interval == 0:
            torch.distributed.barrier()
            time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
            rank = torch.distributed.get_rank()
            print_all('rank: {} | time: {} | exiting the program at iteration {}'.
                      format(rank, time_str, args.iteration), flush=True)
            exit()

    if args.profiling != -1:
        torch.cuda.cudart().cudaProfilerStop()

    return args.iteration, skipped_iters


def train_step(data_iterator, model, optimizer, lr_scheduler,
               args, timers, hooks=None, single_step=False, **kwargs):
    """Single training step."""
    if hooks is None:
        hooks = {}
    lm_loss_total, metrics_total, count, metrics_count = 0.0, {}, 0, {}
    forward_step = hooks['forward_step']
    while True:
        profiling_flag = (args.profiling != -1 and args.iteration >= args.profiling)
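        # Each pass of this loop runs forward + backward on one micro-batch; DeepSpeed only
        # applies the optimizer at a gradient-accumulation boundary, so we loop until
        # `complete` is set (or break early when single_step=True).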
        # Forward model for one step.
        if profiling_flag: torch.cuda.nvtx.range_push("forward")
        timers('forward').start()
        forward_ret = forward_step(data_iterator, model, args, timers, **kwargs)
        if isinstance(forward_ret, tuple):
            lm_loss, metrics = forward_ret
        else:
            lm_loss, metrics = forward_ret, {}
        timers('forward').stop()
        if profiling_flag: torch.cuda.nvtx.range_pop()

        # Check for nan or inf in the forward pass to keep it from interfering with the
        # loss scaler, and all-reduce the metrics along the way.
        if profiling_flag: torch.cuda.nvtx.range_push("loss_and_metrics")
        lm_loss_reduced = lm_loss.detach().clone()
        torch.distributed.all_reduce(lm_loss_reduced.data)
        lm_loss_reduced.data = lm_loss_reduced.data / args.world_size

        loss_checker = lm_loss_reduced
        for name in metrics:
            if 'eval' not in name:
                metrics[name] = metrics[name].detach().clone()
                if metrics[name].data.item() == -100:
                    cnt = torch.zeros(1, dtype=torch.int64, device=metrics[name].data.device)
                    metrics[name].data = torch.tensor(0., device=metrics[name].data.device)
                else:
                    cnt = torch.ones(1, dtype=torch.int64, device=metrics[name].data.device)
                torch.distributed.all_reduce(metrics[name].data)
                torch.distributed.all_reduce(cnt)
                if cnt.item() == 0:
                    metrics[name].data = torch.tensor(-100, device=metrics[name].data.device)
                else:
                    metrics[name].data /= cnt.cpu().item()  # args.world_size
                loss_checker = loss_checker + metrics[name]
        if loss_checker.isnan().any() or loss_checker.isinf().any():
            print_all('Skipping backward and optimizer step: nan or inf found in the forward loss/metrics!')
            return lm_loss.detach(), 1, metrics

        # Accumulate the statistics.
        lm_loss_total += lm_loss_reduced
        for name in metrics:
            if name not in metrics_total:
                metrics_total[name] = torch.tensor(0.0, device=metrics[name].data.device)
            if name not in metrics_count:
                metrics_count[name] = 0
            if metrics[name].data.item() != -100:
                metrics_total[name] += metrics[name]
                metrics_count[name] += 1
        count += 1
        if profiling_flag: torch.cuda.nvtx.range_pop()

        if profiling_flag: torch.cuda.nvtx.range_push("backward")
        # Calculate gradients, reduce across processes, and clip.
        timers('backward').start()
        backward_step(optimizer, model, lm_loss, args, timers)
        timers('backward').stop()
        if profiling_flag: torch.cuda.nvtx.range_pop()

        # Update parameters.
        skipped_iter, complete = 0, False
        if profiling_flag: torch.cuda.nvtx.range_push("optimizer")
        timers('optimizer').start()
        if args.deepspeed:
            if model.is_gradient_accumulation_boundary():
                model.step()
                complete = True
                if not (args.fp16 and optimizer.overflow):
                    lr_scheduler.step()
                else:
                    skipped_iter = 1
            else:
                model.step()
        else:
            raise ValueError('Currently, we only support training with deepspeed.')
        timers('optimizer').stop()
        if profiling_flag: torch.cuda.nvtx.range_pop()
        if complete or single_step:
            break

    lm_loss_total /= count
    metrics_total = {
        key: torch.tensor(-100, device=metrics_total[key].data.device)
        if metrics_count[key] == 0 else value / metrics_count[key]
        for key, value in metrics_total.items()
    }
    return lm_loss_total, skipped_iter, metrics_total


def backward_step(optimizer, model, loss, args, timers):
    """Backward step."""

    # Backward pass.
    if args.deepspeed:
        model.backward(loss)
    else:
        raise ValueError('Currently, we only support training with deepspeed.')
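    # Note: under DeepSpeed, model.backward() also takes care of fp16 loss scaling and of
    # scaling/averaging gradients across accumulation steps and data-parallel ranks
    # (engine behavior, not re-implemented here).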

    if args.deepspeed:
        # DeepSpeed backward propagation already handles the all-reduce communication.
        # Reset the timer to avoid breaking the timer logs below.
        timers('allreduce').reset()

    return


def evaluate(data_iterator, model, eval_iters, args, timers, split, verbose=False, has_last=True, hooks={}):
    """Evaluation."""
    forward_step = hooks['forward_step_eval']
    # Turn on evaluation mode which disables dropout.
    model.eval()
    rank = torch.distributed.get_rank(group=mpu.get_data_parallel_group())
    total_lm_loss, metrics_total = 0, {}
    if split == 'val':
        last_shape = args.val_last_shape
        drop_number = args.val_drop_number
    else:
        assert split == 'test'
        last_shape = args.test_last_shape
        drop_number = args.test_drop_number
    is_scalar = {}
    with torch.no_grad():
        iteration = 0
        while iteration < eval_iters:
            iteration += 1
            if verbose and iteration % args.log_interval == 0:
                print_rank0('Evaluating iter {}/{}'.format(iteration, eval_iters))

            # Forward evaluation.
            # try:
            lm_loss, metrics = forward_step(data_iterator, model, args, timers)
            '''When contiguous memory optimizations are enabled, the buffers
            allocated by the optimizations are deallocated during the backward pass.
            In the absence of a backward pass, the buffers should be reset after each
            forward pass.'''
            if args.deepspeed and args.deepspeed_activation_checkpointing:
                deepspeed.checkpointing.reset()
            total_lm_loss += lm_loss.data.detach().float().item()
            is_last = True if iteration == eval_iters and args.strict_eval and len(last_shape) > 0 else False
            for name in metrics:
                if name not in metrics_total:
                    metrics_total[name] = []
                is_scalar[name] = True if len(metrics[name].shape) == 0 else False
                shape = list(metrics[name].shape)
                if not is_scalar[name] and is_last and metrics[name].shape[0] != last_shape[0]:
                    # Pad the tensor's first dim (to args.batch_size) so all_gather sees equal shapes on every rank.
                    metrics[name] = torch.concat([metrics[name], torch.zeros([last_shape[0] - metrics[name].shape[0]] + shape[1:], dtype=metrics[name].dtype, device=metrics[name].device)])
                if rank == 0:
                    metrics_gathered = [torch.zeros_like(metrics[name], dtype=metrics[name].dtype, device=metrics[name].device) for _ in range(args.world_size)]
                else:
                    # metrics_gathered = None
                    metrics_gathered = [torch.zeros_like(metrics[name], dtype=metrics[name].dtype, device=metrics[name].device) for _ in range(args.world_size)]
                # torch.distributed.gather(metrics[name], metrics_gathered, 0)
                torch.distributed.all_gather(metrics_gathered, metrics[name])
                if rank == 0:
                    gathered_len = len(metrics_gathered) if not is_last else len(metrics_gathered) - drop_number * args.model_parallel_size
                    for i in range(gathered_len):
                        if is_scalar[name] or not is_last:
                            metrics_total[name].append(metrics_gathered[i].data.cpu())
                        else:
                            metrics_total[name].append(metrics_gathered[i][:last_shape[i]].data.cpu())
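    # Under --strict-eval the last batch may be smaller than usual: each rank pads its metric
    # tensors to a common first dimension so all_gather succeeds, and rank 0 then drops the
    # trailing drop_number * model_parallel_size gathered copies and trims the rest back to
    # the per-rank sizes recorded in args.val_last_shape / args.test_last_shape.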

    # Move model back to the train mode.
    model.train()

    total_lm_loss /= eval_iters
    # metrics_avg = {key: value / eval_iters for key, value in metrics_total.items()}
    if rank == 0:
        for name in metrics_total:
            if is_scalar[name]:
                metrics_total[name] = torch.stack(metrics_total[name], dim=0)
            else:
                metrics_total[name] = torch.concat(metrics_total[name], dim=0)
        if hooks['handle_metrics'] is not None:
            metrics = hooks['handle_metrics'](metrics_total)
        else:
            for name in metrics_total:
                assert is_scalar[name], 'you must return scalar metrics or implement the handle_metrics hook'
            metrics = {key: sum(value.split(1, 0)) / len(value) for key, value in metrics_total.items()}
    else:
        metrics = None
    return total_lm_loss, metrics


def evaluate_and_print_results(prefix, data_iterator, model, eval_iters,
                               args, timers, has_last, split, verbose=False, step=None,
                               summary_writer=None, hooks={}):
    """Helper function to evaluate and dump results on screen."""
    lm_loss, metrics = evaluate(data_iterator, model, eval_iters, args, timers, split, verbose, has_last, hooks=hooks)
    lm_ppl = math.exp(min(20, lm_loss))
    if torch.distributed.get_rank(group=mpu.get_data_parallel_group()) == 0:
        report_evaluate_metrics(summary_writer, prefix, lm_loss, lm_ppl, step, args, metrics)
    return lm_loss


def report_iteration_metrics(summary_writer, optimizer, lr, loss, elapsed_time, step, total_step, args, avg_metrics):
    log_string = ' iteration {:8d}/{:8d} |'.format(step, total_step)
    log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(elapsed_time)
    log_string += ' learning rate {:.3E} |'.format(lr)
    log_string += ' total loss {:.6E} |'.format(loss)
    for key in avg_metrics:
        log_string += ' {} {:.6E} |'.format(key, avg_metrics[key])
    if args.fp16:
        log_string += ' loss scale {:.1f} |'.format(
            optimizer.cur_scale if args.deepspeed else optimizer.loss_scale)
    log_string += ' speed {:.2f} samples/(min*GPU)'.format(
        (args.gradient_accumulation_steps * args.batch_size / args.model_parallel_size / (elapsed_time / 60000.0)))
    print_rank0(log_string)
    if summary_writer is not None:
        summary_writer.add_scalar('Train/lr', lr, step)
        summary_writer.add_scalar('Train/train_loss', loss, step)
        summary_writer.add_scalar('Train/elapsed_time', elapsed_time, step)
        for key in avg_metrics:
            summary_writer.add_scalar('Train/' + key, avg_metrics[key], step)
    if args.wandb and torch.distributed.get_rank() == 0:
        log_dict = {
            "Train/lr": lr,
            "Train/train_loss": loss,
            "Train/elapsed_time": elapsed_time
        }
        for key in avg_metrics:
            log_dict["Train/" + key] = avg_metrics[key]
        wandb.log(log_dict, step=step, commit=True)


def report_evaluate_metrics(summary_writer, prefix, loss, ppl, step, args, avg_metrics):
    string = ' validation loss at {} | '.format(prefix)
    string += 'loss: {:.6E} | '.format(loss)
    string += 'PPL: {:.6E}'.format(ppl)
    for key in avg_metrics:
        string += ' {} {:.6E} |'.format(key, avg_metrics[key].item())
    length = len(string) + 1
    print_rank0('-' * 100)
    print_rank0('-' * length)
    print_rank0(string)
    print_rank0('-' * length)
    if summary_writer is not None:
        summary_writer.add_scalar('Train/valid_ppl', ppl, step)
        summary_writer.add_scalar('Train/valid_loss', loss, step)
        for key in avg_metrics:
            summary_writer.add_scalar('Train/valid_' + key, avg_metrics[key], step)
    if args.wandb and torch.distributed.get_rank() == 0:
        log_dict = {
            "Train/valid_ppl": ppl,
            "Train/valid_loss": loss,
        }
        for key in avg_metrics:
            log_dict["Train/valid_" + key] = avg_metrics[key]
        wandb.log(log_dict, step=step, commit=True)
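
# ----------------------------------------------------------------------------
# Minimal usage sketch (comment only; `MyModel`, `forward_step` and
# `create_dataset_function` below are hypothetical user code, and the
# create_dataset_function signature is assumed to match what make_loaders expects):
#
#   from sat import get_args
#
#   def forward_step(data_iterator, model, args, timers):
#       batch = next(data_iterator)
#       loss, metrics = ...          # compute loss; metrics is an optional dict of scalar tensors
#       return loss, metrics         # returning just `loss` is also accepted
#
#   def create_dataset_function(path, args):
#       return ...                   # a dataset built from `path`
#
#   if __name__ == '__main__':
#       args = get_args()
#       training_main(args, model_cls=MyModel,
#                     forward_step_function=forward_step,
#                     create_dataset_function=create_dataset_function)
# ----------------------------------------------------------------------------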