in modules/SwissArmyTransformer/sat/training/deepspeed_training.py [0:0]
def training_main(args, model_cls, forward_step_function, create_dataset_function, handle_metrics_function=None, init_function=None, collate_fn=None, forward_step_eval=None):
"""Main training program."""
hooks = {
'forward_step': forward_step_function,
'init_function': init_function,
'create_dataset_function': create_dataset_function,
'handle_metrics': handle_metrics_function,
'forward_step_eval': forward_step_eval or forward_step_function
}
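    # Illustrative hook signatures (a sketch following common sat usage; these are
    # assumptions rather than anything enforced here -- check against your sat version):
    #   forward_step_function(data_iterator, model, args, timers) -> (loss, metrics_dict)
    #   create_dataset_function(path, args) -> dataset consumed by make_loaders
    #   handle_metrics_function(metrics_dict) -> dict of aggregated metrics
    #   init_function(args, model) -> None (in-place setup)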
    timers = Timers()  # Timing utilities.
# Experiment Name
if args.load and args.mode == 'pretrain': # continue training
args.experiment_name = os.path.basename(os.path.normpath(args.load))
else:
        args.experiment_name = args.experiment_name + '-' + datetime.now().strftime("%m-%d-%H-%M")
    # PyTorch distributed init must happen before the seed is set. ALREADY MOVED TO arguments.py!
    # if isinstance(model_cls, type):
    #     initialize_distributed(args)
    #     set_random_seed(args.seed)  # Random seeds for reproducibility.
    # Data: build the train/val/test loaders.
train_data, val_data, test_data = make_loaders(args, hooks['create_dataset_function'], collate_fn=collate_fn)
if args.epochs:
args.train_iters = len(train_data)
if args.eval_interval is None:
args.eval_interval = len(train_data)//args.epochs
if args.save_interval is None:
args.save_interval = len(train_data)//args.epochs
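        # Worked example with illustrative numbers, assuming len(train_data) spans all
        # args.epochs epochs (as the train_iters assignment above implies): with
        # len(train_data) == 10000 and args.epochs == 2, eval_interval and save_interval
        # both default to 5000 iterations, i.e. roughly once per epoch.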
# Build model
if isinstance(model_cls, type):
model = get_model(args, model_cls)
else:
model = model_cls
        # For a model passed in as an instance, make sure every parameter is on the
        # correct device, otherwise parameter synchronization will raise an error.
correct_device = torch.device(args.device)
for param in model.parameters():
if param.device != correct_device:
param.data = param.data.to(correct_device)
        # Move registered buffers to the correct device as well.
for name, buffer in model.named_buffers():
if buffer.device != correct_device:
buffer.data = buffer.data.to(correct_device)
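        # A caller passing a prebuilt model can do this up front instead, e.g.
        # (illustrative; MyModel is a placeholder): model = MyModel(args).to(args.device)
        # before calling training_main -- the loops above are only a safety net.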
    # Configure model I/O: resume from a checkpoint if args.load is given.
if args.load is not None:
args.iteration = load_checkpoint(model, args)
        # If we don't load optimizer states, the file lock is no longer needed.
        # with FileLock("/root/checkpoint_lock", timeout=-1):
        #     args.iteration = load_checkpoint(model, optimizer, args)
else:
args.iteration = 0
if args.save:
args.save = os.path.join(args.save, args.experiment_name)
os.makedirs(args.save, exist_ok=True)
fh = ConcurrentRotatingFileHandler(os.path.join(args.save,'logfile.log'))
fh.setFormatter(logging.Formatter('[%(asctime)s] [%(levelname)s] %(message)s'))
logger.addHandler(fh)
torch.distributed.barrier()
    # Init hook, called before building the DeepSpeed model and optimizer.
if hooks['init_function'] is not None:
hooks['init_function'](args, model)
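    # An init_function usually mutates the model in place before DeepSpeed wraps it,
    # e.g. (illustrative sketch; what to freeze depends entirely on your model):
    #   def init_function(args, model):
    #       for p in model.parameters():
    #           p.requires_grad_(False)  # freeze everything, then selectively unfreeze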
# training
iteration = 0
if args.train_iters > 0:
        # Optimization-related setup (model wrapping and optimizer construction).
model, optimizer = setup_model_untrainable_params_and_optimizer(args, model)
# initialize lr scheduler
lr_scheduler = get_learning_rate_scheduler(optimizer, args.iteration, args)
        assert isinstance(lr_scheduler, AnnealingLR), \
            'lr_scheduler must be a sat AnnealingLR, otherwise the lr in param_groups will be wrong.'
summary_writer = None
if torch.distributed.get_rank() == 0:
if args.mode == 'pretrain':
print_rank0('Pretraining or Continuing training the Model...')
elif args.mode == 'finetune':
print_rank0('Finetuning Model...')
print_args(args)
summary_writer = get_sample_writer(base=args.summary_dir, name=args.experiment_name, iteration=args.iteration)
if args.wandb:
init_wandb_writer(args)
# Resume data loader if necessary.
if args.resume_dataloader:
if not args.iterable_dataset:
if train_data is not None:
train_data.batch_sampler.start_iter = args.iteration % len(train_data)
if val_data is not None:
start_iter_val = (args.train_iters // args.save_interval) * args.eval_interval
val_data.batch_sampler.start_iter = start_iter_val % len(val_data)
else:
                print_rank0('Warning: cannot resume an iterable dataloader. Skipping...')
if args.do_train:
with ExitStack() as stack:
def save_on_exit(args_, model_, optimizer_, lr_scheduler_):
save_checkpoint(args_.iteration, model_, optimizer_, lr_scheduler_, args_)
                # Re-sync the random seeds, otherwise tensor parallelism may break (dropout, droppath).
                # TODO: add rng states for data parallel and wrap drops in the main path.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
# ---------
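                # train(...) returns the last iteration reached and the number of skipped
                # iterations (skips typically come from fp16 loss-scale overflows).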
                iteration, skipped = train(model, optimizer, lr_scheduler,
                                           train_data, val_data,
                                           timers, args,
                                           summary_writer=summary_writer, hooks=hooks)
# final save
if args.save and iteration != 0:
save_checkpoint(iteration, model, optimizer, lr_scheduler, args)
# final testing
if args.do_test and test_data is not None:
prefix = 'the end of training for test data'
test_loss = evaluate_and_print_results(prefix, iter(test_data),
model, len(test_data) if args.strict_eval else args.eval_iters, args, timers, True, split='test', hooks=hooks)
return model
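
# Illustrative usage (a minimal sketch, assuming the usual sat entry-point pattern;
# MyModel, forward_step and create_dataset are placeholders defined by the caller):
#
#   from sat import get_args
#   from sat.training.deepspeed_training import training_main
#
#   args = get_args()
#   training_main(args, model_cls=MyModel,
#                 forward_step_function=forward_step,
#                 create_dataset_function=create_dataset)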