in fairdiplomacy/models/diplomacy_model/train_sl.py
def main_subproc(rank, world_size, args, train_set, val_set, extra_val_datasets):
    has_gpu = torch.cuda.is_available()
    if has_gpu:
        # distributed training setup
        mp_setup(rank, world_size)
        atexit.register(mp_cleanup)
        torch.cuda.set_device(rank)
    else:
        assert rank == 0 and world_size == 1

    metric_logger = Logger(is_master=rank == 0)
    global_step = 0
    log_scalars = lambda **scalars: metric_logger.log_metrics(
        scalars, step=global_step, sanitize=True
    )
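    # Note: log_scalars closes over global_step, so each call logs against the
    # counter's value at call time (it is incremented in the batch loop below).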
    # load checkpoint if specified
    if args.checkpoint and os.path.isfile(args.checkpoint):
        logger.info("Loading checkpoint at {}".format(args.checkpoint))
        checkpoint = torch.load(args.checkpoint, map_location="cuda:{}".format(rank))
    else:
        checkpoint = None
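    # Note: map_location="cuda:{rank}" pins the checkpoint tensors to this rank's
    # device, so resuming from a checkpoint assumes a GPU is available.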
    logger.info("Init model...")
    net = new_model(args)

    # send model to GPU
    if has_gpu:
        logger.debug("net.cuda({})".format(rank))
        net.cuda(rank)
        logger.debug("net {} DistributedDataParallel".format(rank))
        net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[rank])
        logger.debug("net {} DistributedDataParallel done".format(rank))

    # load from checkpoint if specified
    if checkpoint:
        logger.debug("net.load_state_dict")
        net.load_state_dict(checkpoint["model"], strict=True)
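    # Note: since the state dict is loaded after the DistributedDataParallel wrap,
    # GPU checkpoints are expected to carry "module."-prefixed parameter keys
    # (strict=True will reject a mismatch).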
    # create optimizer, from checkpoint if specified
    policy_loss_fn = torch.nn.CrossEntropyLoss(reduction="none")
    value_loss_fn = torch.nn.MSELoss(reduction="none")
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optim, step_size=1, gamma=args.lr_decay)
    if checkpoint:
        optim.load_state_dict(checkpoint["optim"])

    # load best losses to not immediately overwrite best checkpoints
    best_loss = checkpoint.get("best_loss") if checkpoint else None
    best_p_loss = checkpoint.get("best_p_loss") if checkpoint else None
    best_v_loss = checkpoint.get("best_v_loss") if checkpoint else None
    if has_gpu:
        train_set_sampler = DistributedSampler(train_set)
    else:
        train_set_sampler = RandomSampler(train_set)
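    # Note: DistributedSampler gives each rank a disjoint shard of train_set per
    # epoch; the single-process CPU fallback just shuffles the full dataset.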
    for epoch in range(checkpoint["epoch"] + 1 if checkpoint else 0, args.num_epochs):
        if has_gpu:
            train_set_sampler.set_epoch(epoch)
        batches = torch.tensor(list(iter(train_set_sampler)), dtype=torch.long).split(
            args.batch_size
        )
        for batch_i, batch_idxs in enumerate(batches):
            batch = train_set[batch_idxs]
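            # Note: the sampler's index order is materialized once per epoch and
            # split into fixed-size index tensors; train_set is assumed to support
            # tensor indexing and to return a dict of batched tensors (see
            # batch["y_actions"] below).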
            logger.debug(f"Zero grad {batch_i} ...")

            # check batch is not empty
            if (batch["y_actions"] == EOS_IDX).all():
                logger.warning("Skipping empty epoch {} batch {}".format(epoch, batch_i))
                continue

            # learn
            logger.debug("Starting epoch {} batch {}".format(epoch, batch_i))
            optim.zero_grad()
            policy_losses, value_losses, _, _ = process_batch(
                net,
                batch,
                policy_loss_fn,
                value_loss_fn,
                p_teacher_force=args.teacher_force,
                shuffle_locs=args.shuffle_locs,
            )
            # backward
            p_loss = torch.mean(policy_losses)
            v_loss = torch.mean(value_losses)
            loss = (1 - args.value_loss_weight) * p_loss + args.value_loss_weight * v_loss
            loss.backward()
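            # Note: the total loss is a convex combination of the two heads, e.g.
            # with a hypothetical value_loss_weight of 0.7 it would be
            # 0.3 * p_loss + 0.7 * v_loss.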
            # clip gradients, step
            value_decoder_grad_norm = torch.nn.utils.clip_grad_norm_(
                getattr(net, "module", net).value_decoder.parameters(),
                args.value_decoder_clip_grad_norm,
            )
            grad_norm = torch.nn.utils.clip_grad_norm_(net.parameters(), args.clip_grad_norm)
            optim.step()
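            # Note: getattr(net, "module", net) unwraps the DistributedDataParallel
            # wrapper when present, so the value decoder gets its own clip threshold
            # before the global clip over all parameters.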
            # log diagnostics
            if rank == 0 and batch_i % 10 == 0:
                scalars = dict(
                    epoch=epoch,
                    batch=batch_i,
                    loss=loss,
                    lr=optim.state_dict()["param_groups"][0]["lr"],
                    grad_norm=grad_norm,
                    value_decoder_grad_norm=value_decoder_grad_norm,
                    p_loss=p_loss,
                    v_loss=v_loss,
                )
                log_scalars(**scalars)
                logger.info(
                    "epoch {} batch {} / {}, ".format(epoch, batch_i, len(batches))
                    + " ".join(f"{k}= {v}" for k, v in scalars.items())
                )

            global_step += 1
            if args.epoch_max_batches and batch_i + 1 >= args.epoch_max_batches:
                logger.info("Exiting early due to epoch_max_batches")
                break
        # calculate validation loss/accuracy
        if not args.skip_validation and rank == 0:
            logger.info("Calculating val loss...")
            (
                valid_loss,
                valid_p_loss,
                valid_v_loss,
                valid_p_accuracy,
                valid_v_accuracy,
                split_pcts,
            ) = validate(
                net,
                val_set,
                policy_loss_fn,
                value_loss_fn,
                args.batch_size,
                value_loss_weight=args.value_loss_weight,
            )
            scalars = dict(
                epoch=epoch,
                valid_loss=valid_loss,
                valid_p_loss=valid_p_loss,
                valid_v_loss=valid_v_loss,
                valid_p_accuracy=valid_p_accuracy,
                valid_v_accuracy=valid_v_accuracy,
            )
            for name, extra_val_set in extra_val_datasets.items():
                (
                    scalars[f"valid_{name}/loss"],
                    scalars[f"valid_{name}/p_loss"],
                    scalars[f"valid_{name}/v_loss"],
                    scalars[f"valid_{name}/p_accuracy"],
                    scalars[f"valid_{name}/v_accuracy"],
                    _,
                ) = validate(
                    net,
                    extra_val_set,
                    policy_loss_fn,
                    value_loss_fn,
                    args.batch_size,
                    value_loss_weight=args.value_loss_weight,
                )
            log_scalars(**scalars)
            logger.info("Validation " + " ".join([f"{k}= {v}" for k, v in scalars.items()]))
            for k, v in sorted(split_pcts.items()):
                logger.info(f"val split epoch= {epoch} batch= {batch_i}: {k} = {v}")
            # save model
            if args.checkpoint and rank == 0:
                obj = {
                    "model": net.state_dict(),
                    "optim": optim.state_dict(),
                    "epoch": epoch,
                    "batch_i": batch_i,
                    "valid_p_accuracy": valid_p_accuracy,
                    "args": args,
                    "best_loss": best_loss,
                    "best_p_loss": best_p_loss,
                    "best_v_loss": best_v_loss,
                }
                logger.info("Saving checkpoint to {}".format(args.checkpoint))
                torch.save(obj, args.checkpoint)
                if epoch % 10 == 0:
                    torch.save(obj, args.checkpoint + ".epoch_" + str(epoch))
                if best_loss is None or valid_loss < best_loss:
                    best_loss = valid_loss
                    torch.save(obj, args.checkpoint + ".best")
                if best_p_loss is None or valid_p_loss < best_p_loss:
                    best_p_loss = valid_p_loss
                    torch.save(obj, args.checkpoint + ".bestp")
                if best_v_loss is None or valid_v_loss < best_v_loss:
                    best_v_loss = valid_v_loss
                    torch.save(obj, args.checkpoint + ".bestv")

        lr_scheduler.step()
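
# Illustrative sketch only, not part of the excerpt above: one plausible way a
# per-rank worker like main_subproc gets launched with torch.multiprocessing.spawn.
# The launch() helper and its world-size logic are assumptions; only main_subproc's
# signature is taken from the code above.
import torch
import torch.multiprocessing as mp


def launch(args, train_set, val_set, extra_val_datasets):
    if torch.cuda.is_available():
        # one worker process per GPU; spawn invokes main_subproc(rank, *args)
        world_size = torch.cuda.device_count()
        mp.spawn(
            main_subproc,
            args=(world_size, args, train_set, val_set, extra_val_datasets),
            nprocs=world_size,
            join=True,
        )
    else:
        # CPU fallback matches the assert in main_subproc: rank 0, world size 1
        main_subproc(0, 1, args, train_set, val_set, extra_val_datasets)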