in model/train.py [0:0]
def train():
global train_step
global best_val_nll
log_train_loss = torch.tensor(0.0).float().to(device)
log_grad_norm = torch.tensor(0.0).float().to(device)
log_token_num = torch.tensor(0).to(device)
# Adversarial-training log accumulators
log_gen_train_loss = torch.tensor(0.0).float().to(device)  # Log generator (adversarial) loss
log_gen_num = torch.tensor(0.0).float().to(device)
log_dis_train_loss = torch.tensor(0.0).float().to(device)
log_dis_num = torch.tensor(0.0).float().to(device)
dis_iterations = 0 # Num dis iters
best_gen_val_loss = np.inf
if cfg.DISCRIMINATOR.type != "Null" and cfg.DISCRIMINATOR.type != "":
dis_iterator = dis_iter()
log_start_time = time.time()
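# Gradient accumulation: each batch is split into TRAIN.batch_chunk sub-batches, and each chunk keeps its own memory slot in `mems`.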
mems = [None for _ in range(cfg.TRAIN.batch_chunk)]
assert batch_size % cfg.TRAIN.batch_chunk == 0
train_real_iter = train_iter()
for batch, (data, target, reset_mems, batch_token_num, status_vec) in enumerate(
train_real_iter
):
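# Temperature annealing: beta presumably ramps toward DISCRIMINATOR.beta_max over TRAIN.max_step updates (shape controlled by `adapt`), and the generator samples with temperature 1 / beta.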
beta = get_fixed_temperature(
cfg.DISCRIMINATOR.beta_max,
train_step,
cfg.TRAIN.max_step,
cfg.DISCRIMINATOR.adapt,
)
model.module.temperature = 1.0 / beta
model.zero_grad()
# Batch chunking
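# data, target, and status_vec are chunked along dim 1 (the batch dimension, presumably [seq_len, batch]); reset_mems is per-sequence and is chunked along dim 0.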
data_chunks = torch.chunk(data, cfg.TRAIN.batch_chunk, 1)
target_chunks = torch.chunk(target, cfg.TRAIN.batch_chunk, 1)
reset_mems_chunks = torch.chunk(reset_mems, cfg.TRAIN.batch_chunk, 0)
if status_vec is not None:
status_vec_chunks = torch.chunk(status_vec, cfg.TRAIN.batch_chunk, 1)
for i in range(cfg.TRAIN.batch_chunk):
data = data_chunks[i].contiguous()
target = target_chunks[i].contiguous()
reset_mems = reset_mems_chunks[i].contiguous()
if status_vec is not None:
status_vec = status_vec_chunks[i].contiguous()
# reset_mems = None
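# Forward pass in MLE mode; mems[i] carries this chunk's recurrent memory across batches.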
ret = model(data, target, reset_mems, "mle", mems[i], status_vec=status_vec)
loss, mems[i] = ret["mle"], ret["mems"]
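# Mask out pad tokens, then scale by 1 / batch_chunk so the gradients accumulated over all chunks average out.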
loss = loss[target != dataset.vocab.pad_id]
loss = loss.float().mean() / cfg.TRAIN.batch_chunk
log_train_loss += (
loss.item()
* (target != dataset.vocab.pad_id).sum()
* cfg.TRAIN.batch_chunk
)
if cfg.TRAIN.use_mle:
if args.fp16:
with amp.scale_loss(loss, optimizer, loss_id=1) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()
log_token_num += int(batch_token_num)
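# With MLE training enabled: clip gradients (Apex AMP master params under fp16, otherwise the generator's parameters) and step the MLE optimizer.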
if cfg.TRAIN.use_mle:
if args.fp16:
grad_norm = torch.nn.utils.clip_grad_norm_(
amp.master_params(optimizer), cfg.TRAIN.clip
)
else:
grad_norm = torch.nn.utils.clip_grad_norm_(
model.module.generator.parameters(), cfg.TRAIN.clip
)
# a = [torch.norm(w.grad) for w in model.module.generator.parameters()]
log_grad_norm += grad_norm
optimizer.step()
optimizer.zero_grad()
# Train discriminator
if train_step > cfg.DISCRIMINATOR.start_iter and (
train_step % cfg.DISCRIMINATOR.dis_loss_freq == 0
):
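# Discriminator updates start after DISCRIMINATOR.start_iter and run every dis_loss_freq steps, dis_steps inner iterations each time, unless the discriminator is frozen.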
# TODO: discriminator training perturbs the memory state maintained during batch loading
# (a separate dataloader is needed for the real data)
if not (cfg.DISCRIMINATOR.freeze_discriminator):
for dis_iterations in range(cfg.DISCRIMINATOR.dis_steps):
try:
dis_data, _ = next(dis_iterator)
except StopIteration:
# Restart the iterator and fetch a fresh batch rather than reusing a stale one
dis_iterator = dis_iter()
dis_data, _ = next(dis_iterator)
# Batch chunking for generator and discriminator
dis_data_chunks = torch.chunk(
dis_data, cfg.DISCRIMINATOR.batch_chunk, 1
)
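# Unfreeze discriminator parameters for this update; the BERT discriminator only unfreezes the parameter indices listed in unfreeze_idx.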
if cfg.DISCRIMINATOR.type == "bert":
for idx, p in enumerate(
model.module.discriminator.parameters()
):
if idx in model.module.discriminator.unfreeze_idx:
p.requires_grad = True
else:
for p in model.module.discriminator.parameters():
p.requires_grad = True
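# Accumulate the discriminator loss over batch chunks before a single clipped optimizer step.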
for i in range(cfg.DISCRIMINATOR.batch_chunk):
dis_data = dis_data_chunks[i].contiguous()
# Share the same mems with mle iter
ret = model(dis_data, None, None, "dis_loss")
dis_loss = ret["dis_loss"]
log_dis_train_loss += dis_loss.float().item()
dis_loss = (
dis_loss.float().mean() / cfg.DISCRIMINATOR.batch_chunk
)
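# Optional gradient-penalty term when 'gp' appears in the discriminator loss type (presumably a WGAN-GP-style penalty); it is backpropagated separately below.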
if (
cfg.DISCRIMINATOR.type == "bert"
and "gp" in cfg.DISCRIMINATOR.BERT.loss_type
):
gp_loss = ret["gp_loss"]
gp_loss = (
gp_loss.float().mean() / cfg.DISCRIMINATOR.batch_chunk
)
elif (
cfg.DISCRIMINATOR.type == "cnn"
and "gp" in cfg.DISCRIMINATOR.CNN.loss_type
):
gp_loss = ret["gp_loss"]
gp_loss = (
gp_loss.float().mean() / cfg.DISCRIMINATOR.batch_chunk
)
log_dis_num += 1
if args.fp16:
with amp.scale_loss(
dis_loss, dis_optimizer, loss_id=0
) as scaled_dis_loss:
scaled_dis_loss.backward()
else:
if not cfg.DISCRIMINATOR.backprop_outside:
dis_loss.backward()
if (
cfg.DISCRIMINATOR.type == "bert"
and "gp" in cfg.DISCRIMINATOR.BERT.loss_type
):
gp_loss.backward()
elif (
cfg.DISCRIMINATOR.type == "cnn"
and "gp" in cfg.DISCRIMINATOR.CNN.loss_type
):
gp_loss.backward()
# TODO: investigate discriminator-specific training tricks (e.g., a different gradient-clip value)
if args.fp16:
grad_norm = torch.nn.utils.clip_grad_norm_(
amp.master_params(dis_optimizer), cfg.TRAIN.clip
)
else:
grad_norm = torch.nn.utils.clip_grad_norm_(
model.module.discriminator.parameters(), cfg.TRAIN.clip,
)
dis_optimizer.step()
dis_optimizer.zero_grad()
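# Re-freeze the discriminator so subsequent generator updates leave it untouched.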
for p in model.module.discriminator.parameters():
p.requires_grad = False
if train_step > cfg.DISCRIMINATOR.start_iter and (
train_step % cfg.DISCRIMINATOR.gen_loss_freq == 0
):
# Train the generator (adversarial step); discriminator parameters remain frozen
try:
dis_data, _ = next(dis_iterator)
except StopIteration:
# Restart the iterator and fetch a fresh batch rather than reusing a stale one
dis_iterator = dis_iter()
dis_data, _ = next(dis_iterator)
# Batch chunking for generator and discriminator
dis_data_chunks = torch.chunk(dis_data, cfg.DISCRIMINATOR.batch_chunk, 1)
for i in range(cfg.DISCRIMINATOR.batch_chunk):
dis_data = dis_data_chunks[i].contiguous()
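# update_D0: every PPO.dis_D_update_D0_freq steps, presumably refresh the reference copy D0 used by the PPO-style objective.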
update_D0 = False
if train_step % cfg.PPO.dis_D_update_D0_freq == 0:
update_D0 = True
if 'ppo' in cfg.DISCRIMINATOR.BERT.loss_type or 'ppo' in cfg.DISCRIMINATOR.CNN.loss_type:
for p in model.module.dis_D.parameters():
p.requires_grad = True
# Use the same real batch and generate a new fake batch
# Always backprop outside
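# When a PPO-style loss is used, first update the auxiliary classifier dis_D on the same data; its backward pass presumably runs inside the model's 'classifier_loss' forward, so only clipping and the optimizer step appear here.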
ret = model(dis_data, None, None, "classifier_loss")
torch.nn.utils.clip_grad_norm_(model.module.dis_D.parameters(), cfg.TRAIN.clip)
dis_D_optimizer.step()
dis_D_optimizer.zero_grad()
for p in model.module.dis_D.parameters():
p.requires_grad = False
ret = model(dis_data, None, None, "gen_loss", update_D0=update_D0)
gen_loss = ret["gen_loss"]
log_gen_train_loss += gen_loss.float().item()
gen_loss = gen_loss.float().mean() / cfg.DISCRIMINATOR.batch_chunk
log_gen_num += 1
# if args.fp16:
# with amp.scale_loss(gen_loss, optimizer, loss_id=2) as scaled_gen_loss:
# scaled_gen_loss.backward(retain_graph=True)
# else:
# gen_loss.backward(retain_graph=True)
if args.fp16:
with amp.scale_loss(gen_loss, optimizer, loss_id=1) as scaled_loss:
scaled_loss.backward()
else:
# Backward here unless cfg.DISCRIMINATOR.backprop_outside delegates it elsewhere
if not cfg.DISCRIMINATOR.backprop_outside:
gen_loss.backward()
if args.fp16:
grad_norm = torch.nn.utils.clip_grad_norm_(
amp.master_params(optimizer), cfg.TRAIN.clip
)
else:
grad_norm = torch.nn.utils.clip_grad_norm_(
model.module.generator.parameters(), cfg.TRAIN.clip
)
gen_optimizer.step()
gen_optimizer.zero_grad()
# step-wise learning rate annealing
train_step += 1
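# MLE optimizer LR: linear warmup for TRAIN.warmup_step updates, then the configured scheduler.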
if cfg.TRAIN.scheduler in ["cosine", "constant", "dev_perf"]:
# linear warmup stage
if train_step < cfg.TRAIN.warmup_step:
curr_lr = cfg.TRAIN.lr * train_step / cfg.TRAIN.warmup_step
optimizer.param_groups[0]["lr"] = curr_lr
else:
if cfg.TRAIN.scheduler == "cosine":
scheduler.step()
elif cfg.TRAIN.scheduler == "inv_sqrt":
scheduler.step()
if cfg.DISCRIMINATOR.type != "Null" and cfg.DISCRIMINATOR.type != "":
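# With adversarial training enabled, the generator's adversarial optimizer and the discriminator optimizer each get their own warmup and scheduler.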
if cfg.DISCRIMINATOR.gen_scheduler in ["cosine", "constant", "dev_perf"]:
# linear warmup stage
if train_step < cfg.DISCRIMINATOR.gen_warmup_step:
curr_gen_lr = (
cfg.DISCRIMINATOR.gen_lr * train_step / cfg.DISCRIMINATOR.gen_warmup_step
)
gen_optimizer.param_groups[0]["lr"] = curr_gen_lr
else:
if cfg.DISCRIMINATOR.gen_scheduler == "cosine":
gen_scheduler.step()
elif cfg.DISCRIMINATOR.gen_scheduler == "inv_sqrt":
gen_scheduler.step()
if cfg.DISCRIMINATOR.dis_scheduler in ["cosine", "constant", "dev_perf"]:
# linear warmup stage
if train_step < cfg.DISCRIMINATOR.dis_warmup_step:
curr_dis_lr = (
cfg.DISCRIMINATOR.dis_lr * train_step / cfg.DISCRIMINATOR.dis_warmup_step
)
dis_optimizer.param_groups[0]["lr"] = curr_dis_lr
else:
if cfg.DISCRIMINATOR.dis_scheduler == "cosine":
dis_scheduler.step()
elif cfg.DISCRIMINATOR.dis_scheduler == "inv_sqrt":
dis_scheduler.step()
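# Periodic logging: aggregate counters across ranks, convert to averages, report, then reset.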
if train_step % cfg.TRAIN.log_interval == 0:
torch.distributed.all_reduce(log_train_loss)
torch.distributed.all_reduce(log_grad_norm)
torch.distributed.all_reduce(log_token_num)
torch.distributed.all_reduce(log_gen_train_loss)
torch.distributed.all_reduce(log_gen_num)
torch.distributed.all_reduce(log_dis_train_loss)
torch.distributed.all_reduce(log_dis_num)
log_train_loss /= log_token_num
log_grad_norm /= cfg.TRAIN.log_interval * num_gpus
log_gen_train_loss = (
log_gen_train_loss / log_gen_num
if log_gen_num != 0
else torch.tensor(0.0).float().to(device)
)
log_dis_train_loss = (
log_dis_train_loss / log_dis_num
if log_dis_num != 0
else torch.tensor(0.0).float().to(device)
)
if args.local_rank == 0:
elapsed = time.time() - log_start_time
logging.info(
"Train Step {}/{}, lr={:f}, tokens/s={:.1f},"
" nll={:.4f}, ppl={:.2f}, grad norm={}, gen_loss={:5.4f}, dis_loss={:5.4f}".format(
train_step,
cfg.TRAIN.max_step,
optimizer.param_groups[0]["lr"],
log_token_num.item() / elapsed,
log_train_loss.item(),
math.exp(log_train_loss.item()),
log_grad_norm.item(),
log_gen_train_loss.item(),
log_dis_train_loss.item(),
)
)
log_train_loss[()] = 0
log_grad_norm[()] = 0
log_token_num[()] = 0
log_gen_train_loss[()] = 0
log_gen_num[()] = 0
log_dis_train_loss[()] = 0
log_dis_num[()] = 0
log_start_time = time.time()
if train_step % cfg.TRAIN.eval_interval == 0:
eval_start_time = time.time()
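# Periodic validation: token counts and total NLL are all-reduced across ranks; the NLL is scaled by 1e-4 before the reduction, presumably to keep values in a numerically safe range.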
val_token_num, val_total_nll, val_metrics = evaluate(
eval_iter=val_iter, dis_val_iter=None, mode="eval"
)
val_token_num_pt = torch.tensor(val_token_num).to(device)
val_total_nll_pt = torch.tensor(val_total_nll / 10000.0).to(device)
torch.distributed.all_reduce(val_token_num_pt)
torch.distributed.all_reduce(val_total_nll_pt)
val_token_num = val_token_num_pt.item()
val_total_nll = val_total_nll_pt.item()
val_nll = val_total_nll / (val_token_num / 10000.0)
if args.local_rank == 0:
logging.info(
"Eval step {}, time={}s, val nll={}, val ppl={}, #evaluated tokens={}, bleu={}, self_bleu={"
"}, class_acc={}".format(
train_step,
time.time() - eval_start_time,
val_nll,
math.exp(val_nll),
val_token_num,
val_metrics[0],
val_metrics[1],
val_metrics[2],
)
)
# Save the model if the validation loss is the best we've seen so far.
# Always save after eval if save_all is true and not debug
if not args.debug and args.save_all:
name = f"checkpoint_{train_step}.pt"
save_checkpoint(
args,
model,
optimizer,
dis_optimizer,
gen_optimizer,
dataset.vocab,
train_step,
val_nll,
scheduler,
dis_scheduler,
gen_scheduler,
name,
)
# Save last checkpoint if not debug and not save_all
if not args.debug and not args.save_all:
name = "checkpoint_last.pt"
save_checkpoint(
args,
model,
optimizer,
dis_optimizer,
gen_optimizer,
dataset.vocab,
train_step,
val_nll,
scheduler,
dis_scheduler,
gen_scheduler,
name,
)
if not best_val_nll or val_nll < best_val_nll:
best_val_nll = val_nll
if not args.debug:
name = "checkpoint_best.pt"
save_checkpoint(
args,
model,
optimizer,
dis_optimizer,
gen_optimizer,
dataset.vocab,
train_step,
best_val_nll,
scheduler,
dis_scheduler,
gen_scheduler,
name,
)
test_start_time = time.time()
def calculate_test_nll_during_training(test_iter):
# Run on test data and all-reduce token counts / total NLL across ranks
test_token_num, test_total_nll, test_metrics = evaluate(
eval_iter=test_iter, dis_val_iter=None, mode="test"
)
test_token_num_pt = torch.tensor(test_token_num).to(device)
test_total_nll_pt = torch.tensor(test_total_nll / 10000.0).to(device)
torch.distributed.all_reduce(test_token_num_pt)
torch.distributed.all_reduce(test_total_nll_pt)
test_token_num = test_token_num_pt.item()
test_nll = test_total_nll_pt.item() / (test_token_num / 10000.0)
return test_token_num, test_nll, test_metrics
(
test_token_num,
test_nll,
test_metrics,
) = calculate_test_nll_during_training(test_iter)
if args.local_rank == 0:
logging.info(
"Test step {}, time={}s, test nll={}, test ppl={}, #evaluated tokens={}"
" test_bleu={}".format(
train_step,
time.time() - test_start_time,
test_nll,
math.exp(test_nll),
test_token_num,
test_metrics[0],
)
)
# dev-performance based learning rate annealing
if cfg.TRAIN.scheduler == "dev_perf":
scheduler.step(val_nll)
if train_step == cfg.TRAIN.max_step:
logging.info("-" * 100)
logging.info("End of training")
break