# modules/SwissArmyTransformer/sat/training/deepspeed_training.py
def evaluate(data_iterator, model, eval_iters, args, timers, split, verbose=False, has_last=True, hooks={}):
"""Evaluation."""
forward_step = hooks['forward_step_eval']
# Turn on evaluation mode which disables dropout.
model.eval()
    rank = torch.distributed.get_rank(group=mpu.get_data_parallel_group())  # rank within the data-parallel group; metrics are accumulated where it is 0
total_lm_loss, metrics_total = 0, {}
if split=='val':
last_shape = args.val_last_shape
drop_number = args.val_drop_number
else:
assert split=='test'
last_shape = args.test_last_shape
drop_number = args.test_drop_number
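    # last_shape holds the per-rank sizes of the final (possibly smaller) batch
    # under strict evaluation; drop_number is how many data-parallel ranks carry
    # only padding in that batch and are discarded after the all_gather below.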
    is_scalar = {}  # whether each metric is a 0-dim scalar or a per-sample tensor
with torch.no_grad():
iteration = 0
while iteration < eval_iters:
iteration += 1
if verbose and iteration % args.log_interval == 0:
print_rank0('Evaluating iter {}/{}'.format(iteration, eval_iters))
# Forward evaluation.
lm_loss, metrics = forward_step(data_iterator, model, args, timers)
            # When contiguous memory optimizations are enabled, the buffers they
            # allocate are normally freed during the backward pass; evaluation has
            # no backward pass, so reset them after each forward pass.
if args.deepspeed and args.deepspeed_activation_checkpointing:
deepspeed.checkpointing.reset()
            total_lm_loss += lm_loss.detach().float().item()
            is_last = iteration == eval_iters and args.strict_eval and len(last_shape) > 0
for name in metrics:
if name not in metrics_total:
metrics_total[name] = []
                is_scalar[name] = len(metrics[name].shape) == 0
                shape = list(metrics[name].shape)
                if not is_scalar[name] and is_last and metrics[name].shape[0] != last_shape[0]:
                    # pad the first dim up to last_shape[0] so every rank
                    # contributes a same-shaped tensor to the all_gather below
                    pad = torch.zeros([last_shape[0] - metrics[name].shape[0]] + shape[1:],
                                      dtype=metrics[name].dtype, device=metrics[name].device)
                    metrics[name] = torch.cat([metrics[name], pad])
                # all_gather needs an identically sized receive buffer on every
                # rank, so allocate it unconditionally (zeros_like keeps the
                # source dtype and device).
                metrics_gathered = [torch.zeros_like(metrics[name]) for _ in range(args.world_size)]
                torch.distributed.all_gather(metrics_gathered, metrics[name])
                if rank == 0:
                    # under strict eval, the last batch may leave some ranks with
                    # padding only; drop their gathered entries entirely
                    gathered_len = len(metrics_gathered) if not is_last else len(metrics_gathered) - drop_number * args.model_parallel_size
                    for i in range(gathered_len):
                        if is_scalar[name] or not is_last:
                            metrics_total[name].append(metrics_gathered[i].data.cpu())
                        else:
                            # trim the zero padding added before the all_gather
                            metrics_total[name].append(metrics_gathered[i][:last_shape[i]].data.cpu())
    # Move model back to train mode.
model.train()
total_lm_loss /= eval_iters
if rank==0:
for name in metrics_total:
if is_scalar[name]:
metrics_total[name] = torch.stack(metrics_total[name], dim=0)
else:
metrics_total[name] = torch.concat(metrics_total[name], dim=0)
if hooks['handle_metrics'] is not None:
metrics = hooks['handle_metrics'](metrics_total)
else:
            for name in metrics_total:
                assert is_scalar[name], 'you must return scalar metrics or implement the handle_metrics hook'
            # average each scalar metric over all gathered values
            metrics = {key: sum(value.split(1, 0)) / len(value) for key, value in metrics_total.items()}
else:
metrics = None
return total_lm_loss, metrics
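
# --- Usage sketch (illustrative only, not part of the original file) ---
# A minimal example of the hooks dict evaluate() expects. The keys
# 'forward_step_eval' and 'handle_metrics' are the ones read above; the hook
# body and the names my_forward_step/val_loader are hypothetical stand-ins,
# not SAT's real implementations.
#
# def my_forward_step(data_iterator, model, args, timers):
#     batch = next(data_iterator)
#     loss, logits = model(**batch)  # model-specific call; adapt to your model
#     acc = (logits.argmax(-1) == batch['labels']).float().mean()
#     return loss, {'acc': acc}  # 0-dim metrics need no handle_metrics hook
#
# hooks = {'forward_step_eval': my_forward_step, 'handle_metrics': None}
# loss, metrics = evaluate(iter(val_loader), model, eval_iters=10, args=args,
#                          timers=timers, split='val', hooks=hooks)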