in main.py [0:0]
def update_args(args):
    """Derive dependent hyper-parameters and normalise *args* in place.

    Fills in settings left at their "unset" sentinel values and rescales the
    batch-related settings when each batch is split into ``split_batch``
    smaller chunks.

    Args:
        args: Mutable attribute namespace (presumably the result of
            ``argparse.ArgumentParser.parse_args`` -- confirm against caller).
            Attributes read: ``head_dim``, ``hid_sz``, ``nheads``,
            ``split_batch``, ``batch_sz``, ``test_batch_sz``, ``nbatches``,
            ``lr_warmup``, ``plot``, ``plot_name``.
            Attributes written: ``head_dim``, ``update_freq``, ``batch_sz``,
            ``test_batch_sz``, ``nbatches``, ``lr_warmup``, ``plot_name``.

    Returns:
        None; *args* is modified in place.

    Raises:
        ValueError: if ``head_dim`` is 0 and ``hid_sz`` is not evenly
            divisible by a positive ``nheads``; or if ``split_batch > 1`` and
            it does not evenly divide ``batch_sz`` or ``test_batch_sz``.
    """
    # head_dim == 0 means "derive the per-head size from the hidden size".
    # Explicit raises (rather than assert) keep this check active under
    # `python -O`, and give a clear message instead of a ZeroDivisionError.
    if args.head_dim == 0:
        if args.nheads <= 0 or args.hid_sz % args.nheads != 0:
            raise ValueError(
                f"hid_sz ({args.hid_sz}) must be divisible by a positive "
                f"nheads ({args.nheads}) when head_dim is 0"
            )
        args.head_dim = args.hid_sz // args.nheads

    # update_freq mirrors split_batch -- presumably the number of chunks
    # processed per optimizer update; confirm against the training loop.
    args.update_freq = args.split_batch
    if args.split_batch > 1:
        # Validate both sizes before mutating anything. Without this check
        # the floor divisions below would silently drop samples.
        if args.batch_sz % args.split_batch != 0:
            raise ValueError(
                f"batch_sz ({args.batch_sz}) must be divisible by "
                f"split_batch ({args.split_batch})"
            )
        if args.test_batch_sz % args.split_batch != 0:
            raise ValueError(
                f"test_batch_sz ({args.test_batch_sz}) must be divisible by "
                f"split_batch ({args.split_batch})"
            )
        args.batch_sz //= args.split_batch
        args.test_batch_sz //= args.split_batch
        # Batch-count-based settings are scaled by the same factor; this
        # presumably keeps them in step with the now-smaller batches.
        args.nbatches *= args.split_batch
        args.lr_warmup *= args.split_batch

    # An empty plot name, when plotting is on, defaults to a timestamp
    # such as "20240131_235959". Relies on a module-level `import time`.
    if args.plot and args.plot_name == "":
        args.plot_name = time.strftime("%Y%m%d_%H%M%S")

    # test_batch_sz == 0 means "reuse the training batch size". This runs
    # after the split above, so it copies the already-divided batch_sz.
    if args.test_batch_sz == 0:
        args.test_batch_sz = args.batch_sz