# main.py
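
import math
import time

import torch

# NOTE: the helpers used below (set_up_env, get_train_val_test_data,
# TransformerSeq, get_optimizer_and_scheduler, Logger, load_checkpoint,
# save_checkpoint, full_eval, train_iteration) come from this repository's
# own modules; their import lines are omitted from this excerpt.
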
def launch(env_params,
           model_params,
           adapt_io_params,
           adapt_span_params,
           pers_mem_params,
           optim_params,
           data_params,
           trainer_params):
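    """Train or evaluate a TransformerSeq language model.

    Sets up the (possibly distributed) environment, loads the data splits,
    builds the model, optimizer and scheduler, optionally resumes from a
    checkpoint, and then either runs a single full evaluation pass
    (full_eval_mode) or the main training loop.
    """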
    # ENVIRONMENT (device, distributed, etc.)
    set_up_env(env_params)
    device = env_params['device']
    distributed = env_params['distributed']
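    # print the configuration once (only rank 0 in distributed runs, to
    # avoid duplicated output from every worker)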
    if not distributed or env_params['rank'] == 0:
        print('model_params:\t', model_params)
        print('optim_params:\t', optim_params)
        print('data_params:\t', data_params)
        print('trainer_params:\t', trainer_params)
        print('adapt_io_params:\t', adapt_io_params)
        print('adapt_span_params:\t', adapt_span_params)
        print('pers_mem_params:\t', pers_mem_params)
    # DATA
    train_data, val_data, test_data = get_train_val_test_data(
        data_params=data_params,
        env_params=env_params,
        batch_size=trainer_params['batch_size'],
        device=device,
        sort_dict=adapt_io_params['adapt_io_enabled'])
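    # each split is returned already batched, with dimension 0 equal to the
    # batch size (the hidden-state cache initialization below relies on this)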
    # MODEL
    model = TransformerSeq(
        vocab_size=data_params['vocab_size'], **model_params,
        adapt_io_params=adapt_io_params,
        adapt_span_params=adapt_span_params,
        pers_mem_params=pers_mem_params)
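    # wrap the model: DistributedDataParallel when running one process per
    # GPU, plain DataParallel otherwise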
    if distributed:
        local_rank = env_params['local_rank']
        model = model.to(device)
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank], output_device=local_rank)
    else:
        model = torch.nn.DataParallel(model)
        model = model.to(device)
    # OPTIMIZER AND SCHEDULER
    optimizer, scheduler = get_optimizer_and_scheduler(
        model=model, optim_params=optim_params)
    # create logger
    logger = Logger(data_params['data_unit'])
    # resume training from the last checkpoint, if one exists
    iter_init = load_checkpoint(
        trainer_params['checkpoint_path'], model, optimizer, scheduler,
        logger, distributed)
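
    # full_eval_mode: run one complete evaluation pass over the validation
    # and test sets, print the results and exit without training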
    if trainer_params['full_eval_mode']:
        # evaluate the model on the validation and test data
        with torch.no_grad():
            loss_val = full_eval(model, optimizer, scheduler, val_data,
                                 model_params['block_size'],
                                 model_params['hidden_size'])
            loss_test = full_eval(model, optimizer, scheduler, test_data,
                                  model_params['block_size'],
                                  model_params['hidden_size'])
            if distributed:
                # collect the per-worker losses onto rank 0 and average them
                stats = torch.tensor(
                    [loss_val, loss_test]).to(device)
                torch.distributed.reduce(stats, 0)
                if env_params['rank'] == 0:
                    loss_val = stats[0] / env_params['world_size']
                    loss_test = stats[1] / env_params['world_size']
                else:
                    return
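        # the losses are mean per-token negative log-likelihoods in nats:
        # dividing by log(2) converts them to bits-per-character, while
        # exponentiating gives perplexity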
        if data_params['data_unit'] == 'bpc':
            print('val: {:.3f}bpc'.format(loss_val / math.log(2)))
            print('test: {:.3f}bpc'.format(loss_test / math.log(2)))
        else:
            print('val: {:.2f}ppl'.format(math.exp(loss_val)))
            print('test: {:.2f}ppl'.format(math.exp(loss_test)))
        return
    # read positions within the train (index 0) and validation (index 1) streams
    data_pos = [0] * 2
    # initialize hidden-state caches for train and valid
    hid_cache = [[
        torch.zeros(
            train_data.size(0),
            layer.attn.attn.get_cache_size(),
            model_params['hidden_size']).to(device)
        for layer in model.module.layers] for _ in range(2)]
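    # each per-layer cache has shape (batch_size, cache_size, hidden_size)
    # and is passed to train_iteration, which returns it updated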
    nb_batches_per_iter = trainer_params['nb_batches_per_iter']
    for iter_no in range(iter_init, trainer_params['nb_iter']):
        t_sta = time.time()
        loss_train, data_pos[0], hid_cache[0] = train_iteration(
            model, optimizer, scheduler, train_data, nb_batches_per_iter,
            model_params['block_size'], False, data_pos[0], hid_cache[0],
            trainer_params['batch_split'])
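        # average wall-clock time per batch, in milliseconds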
        elapsed = 1000 * (time.time() - t_sta) / nb_batches_per_iter
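        # validation pass: train_iteration is reused with its eval flag set
        # to True and gradients disabled, so no parameters are updated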
        with torch.no_grad():
            loss_val, data_pos[1], hid_cache[1] = train_iteration(
                model, optimizer, scheduler, val_data, nb_batches_per_iter,
                model_params['block_size'], True, data_pos[1], hid_cache[1],
                trainer_params['batch_split'])
        if distributed:
            # collect the per-worker losses onto rank 0 and average them
            stats = torch.tensor(
                [loss_train, loss_val]).to(device)
            torch.distributed.reduce(stats, 0)
            if env_params['rank'] == 0:
                loss_train = stats[0] / env_params['world_size']
                loss_val = stats[1] / env_params['world_size']
            else:
                continue
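        # only rank 0 (or a non-distributed run) reaches this point:
        # log the iteration and write a checkpoint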
        logger.log_iter(iter_no, nb_batches_per_iter, loss_train,
                        loss_val, elapsed, model)
        save_checkpoint(trainer_params['checkpoint_path'],
                        iter_no, model, optimizer, scheduler, logger)
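

# Entry-point sketch (not part of the excerpt above): launch() expects the
# eight parameter dicts to be assembled elsewhere, typically from command-line
# arguments. The helper name and config table below are assumptions, shown
# only to illustrate how the function is meant to be called:
#
#   if __name__ == '__main__':
#       launch(**get_params(params_config=PARAMS_CONFIG))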