in main.py
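# Note: this excerpt assumes the usual imports at the top of main.py: the
# standard-library modules copy, math, random and time; torch and torch.optim
# as optim; the repo-local modules used below (data, distributed, checkpoint,
# feedback, expire_span, compressive, transformer_seq, adaptive_span); and the
# Logger class plus the update_args/train helpers defined or imported earlier
# in the file.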
def main(args):
    args = copy.deepcopy(args)
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    update_args(args)
    distributed.init(args)
    args.device = torch.device("cuda" if use_cuda else "cpu")
    logger = Logger(args)
    logger.print(f"PyTorch version: {torch.__version__}")
    logger.print(f"PyTorch CUDA version: {torch.version.cuda}")
    logger.print(str(args))
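    # distributed.init(args) is expected to set up the process group and fill
    # in args.rank, which is checked at the end of each epoch so that only the
    # master process logs and writes checkpoints.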
    # load data
    train_data, val_data, test_data, corpus = data.get_data(args, logger, args.data_eos)
    if len(args.data_omit_labels) > 0:
        args.data_omit_label_idx = [
            corpus.dictionary.word2idx[w] for w in args.data_omit_labels
        ]
    else:
        args.data_omit_label_idx = None
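    # The next block picks the model variant: args.feedback, args.expire_span
    # and args.compress are checked in that order (the first one set wins);
    # the default is the plain TransformerSeq baseline.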
    # create a model
    if args.feedback:
        model = feedback.FeedbackTransformer(args)
    elif args.expire_span:
        model = expire_span.ExpireSpan(args)
    elif args.compress:
        model = compressive.CompressiveTransformer(args)
    else:
        model = transformer_seq.TransformerSeq(args)
    model.to(args.device)
    # count params
    nparameters = 0
    params = []
    for param in model.parameters():
        if param.requires_grad:
            nparameters += param.numel()
            params.append(param)
    logger.print("nparameters={:.2f}M".format(nparameters / 1e6))
    # OPTIM param
    if args.optim == "sgd":
        optimizer = optim.SGD(params, lr=args.lr, momentum=args.momentum)
    elif args.optim == "adam":
        optimizer = optim.Adam(params, lr=args.lr)
    else:
        raise RuntimeError("unknown optimizer: {}".format(args.optim))
    if args.lr_decay:
        # will do warm-up manually later
        scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, args.nepochs * args.nbatches
        )
    elif args.lr_warmup > 0:
        scheduler = optim.lr_scheduler.LambdaLR(
            optimizer, lambda ep: min(1, ep / args.lr_warmup)
        )
    else:
        scheduler = None
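    # LR schedule: with args.lr_decay, cosine annealing over the whole run
    # (T_max = args.nepochs * args.nbatches suggests the scheduler is stepped
    # once per batch, with warm-up handled manually inside train()); otherwise
    # an optional linear warm-up that scales the LR by
    # min(1, step / args.lr_warmup).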
    model = distributed.wrap_model(args, model)
    ep_init = checkpoint.load(args, model, optimizer, logger, scheduler)
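    # checkpoint.load returns the epoch index to resume from, which seeds the
    # training loop below.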
    # pos: data sampling positions for the train/val/test streams;
    # training starts from a random offset, eval streams start at 0
    pos = [0 for _ in range(3)]
    if isinstance(train_data, tuple):
        pos[0] = random.randrange(train_data[0].size(1) - args.mem_sz)
    else:
        pos[0] = random.randrange(train_data.size(1) - args.mem_sz)
    hid_cache = [
        model.module.init_hid_cache(args.batch_sz),
        model.module.init_hid_cache(args.test_batch_sz),
        model.module.init_hid_cache(args.test_batch_sz),
    ]
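    # Three hidden-state caches, one per data stream (train/val/test), so each
    # stream's memory of past segments persists across batches; training uses
    # args.batch_sz while evaluation uses args.test_batch_sz.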
    if args.full_test:
        # perform evaluation only
        with torch.no_grad():
            stat_val, pos[1], hid_cache[1] = train(
                args,
                model,
                optimizer,
                scheduler,
                val_data,
                test_only=True,
                train_pos=pos[1],
                h_cache=hid_cache[1],
                corpus=corpus,
            )
            stat_test, pos[2], hid_cache[2] = train(
                args,
                model,
                optimizer,
                scheduler,
                test_data,
                test_only=True,
                train_pos=pos[2],
                h_cache=hid_cache[2],
                corpus=corpus,
            )
        gpu_mem = torch.cuda.max_memory_allocated() / 1024 ** 3
        stat_test, stat_val, gpu_mem = distributed.collect_stat(
            args, stat_test, stat_val, gpu_mem
        )
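        # Metric conversion: the reported loss is an average negative
        # log-likelihood in nats, so bits-per-character is loss / ln(2) and
        # word-level perplexity is exp(loss); classification-style runs report
        # an error rate instead.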
        if args.data_type == "char":
            if "err" in stat_val:
                logger.print("val err: {:.3f}%".format(stat_val["err"] * 100))
                logger.print("test err: {:.3f}%".format(stat_test["err"] * 100))
            else:
                logger.print(
                    "val: {:.3f}bpc".format(stat_val["loss"] / math.log(2))
                )
                logger.print(
                    "test: {:.3f}bpc".format(stat_test["loss"] / math.log(2))
                )
        else:
            logger.print("val: {:.3f}ppl".format(math.exp(stat_val["loss"])))
            logger.print("test: {:.3f}ppl".format(math.exp(stat_test["loss"])))
        logger.print(f"gpu_mem: {gpu_mem:.1f}gb")
        return
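    # Main training loop: each train() call is expected to run args.nbatches
    # batches, so `elapsed` below is milliseconds per batch.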
    for ep in range(ep_init, args.nepochs):
        t_sta = time.time()
        args.ep = ep
        stat_train, pos[0], hid_cache[0] = train(
            args,
            model,
            optimizer,
            scheduler,
            train_data,
            train_pos=pos[0],
            h_cache=hid_cache[0],
            corpus=corpus,
        )
        elapsed = 1000 * (time.time() - t_sta) / args.nbatches
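        # Validation: with args.full_valid the whole validation stream is
        # re-evaluated each epoch and the cached position/hidden state are left
        # untouched; otherwise evaluation continues from where it left off,
        # giving a cheaper rolling estimate.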
        with torch.no_grad():
            if args.full_valid:
                stat_val, _, _ = train(
                    args,
                    model,
                    optimizer,
                    scheduler,
                    val_data,
                    test_only=True,
                    train_pos=pos[1],
                    h_cache=hid_cache[1],
                    corpus=corpus,
                )
            else:
                stat_val, pos[1], hid_cache[1] = train(
                    args,
                    model,
                    optimizer,
                    scheduler,
                    val_data,
                    test_only=True,
                    train_pos=pos[1],
                    h_cache=hid_cache[1],
                    corpus=corpus,
                )
        gpu_mem = torch.cuda.max_memory_allocated() / 1024 ** 3
        torch.cuda.reset_max_memory_allocated()
        stat_train, stat_val, gpu_mem = distributed.collect_stat(
            args, stat_train, stat_val, gpu_mem
        )
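        # distributed.collect_stat is assumed to aggregate the statistics (and
        # peak GPU memory) across workers before the master process logs them.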
        if args.rank == 0:
            # only the master process does logging, plotting and checkpointing
            if args.lr_decay:
                logger.log("compute/lr", optimizer.param_groups[0]["lr"])
            if args.adapt_span:
                adaptive_span.log(args, model, logger, stat_train)
            if args.expire_span:
                expire_span.log(args, model, logger, stat_train)
            if args.feedback:
                feedback.log(args, model, logger, stat_train)
            logger.step(args, stat_train, stat_val, elapsed, gpu_mem)
            checkpoint.save(args, model, optimizer, logger, scheduler)
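# The command-line parsing and the entry point that build args and call main()
# are presumably defined elsewhere in main.py and are not part of this excerpt.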