in main.py [0:0]
def update_args(args):
    """Fill in derived settings on ``args`` in place.

    The following attributes are read and, where noted, overwritten:

    * ``head_dim``: when it is 0, it is set to ``hid_sz // nheads``
      (``hid_sz`` must be an exact multiple of ``nheads``).
    * ``update_freq``: always set to ``split_batch``.
    * When ``split_batch > 1``: ``batch_sz`` and ``test_batch_sz`` are
      divided by ``split_batch`` (both must be exact multiples of it), and
      ``nbatches`` and ``lr_warmup`` are multiplied by it.
    * ``plot_name``: when ``plot`` is truthy and ``plot_name`` is the empty
      string, it is set to a ``YYYYmmdd_HHMMSS`` local-time stamp.
    * ``test_batch_sz``: when it is 0 (after the split above), it is set to
      the (already split) ``batch_sz``.

    Args:
        args: A mutable attribute container (e.g. an ``argparse.Namespace``)
            exposing the attributes listed above.

    Returns:
        None. ``args`` is modified in place.

    Raises:
        AssertionError: If ``hid_sz`` is not divisible by ``nheads`` (only
            checked when ``head_dim`` is 0), or if ``batch_sz`` or
            ``test_batch_sz`` is not divisible by ``split_batch`` (only
            checked when ``split_batch > 1``).
    """
    # The divisibility checks below are raised explicitly instead of using
    # ``assert`` so they are still enforced under ``python -O``. The
    # exception type stays AssertionError, so callers observe the same
    # error class as before.
    if args.head_dim == 0:
        if args.hid_sz % args.nheads != 0:
            raise AssertionError(
                "hid_sz ({}) must be divisible by nheads ({})".format(
                    args.hid_sz, args.nheads
                )
            )
        args.head_dim = args.hid_sz // args.nheads

    args.update_freq = args.split_batch
    if args.split_batch > 1:
        # Validate both sizes before mutating anything in this branch.
        for attr_name in ("batch_sz", "test_batch_sz"):
            value = getattr(args, attr_name)
            if value % args.split_batch != 0:
                raise AssertionError(
                    "{} ({}) must be divisible by split_batch ({})".format(
                        attr_name, value, args.split_batch
                    )
                )
        args.batch_sz = args.batch_sz // args.split_batch
        args.test_batch_sz = args.test_batch_sz // args.split_batch
        args.nbatches *= args.split_batch
        args.lr_warmup *= args.split_batch

    if args.plot and args.plot_name == "":
        args.plot_name = time.strftime("%Y%m%d_%H%M%S")

    # Runs after the split on purpose: a zero test batch size stays zero
    # through the division and then picks up the per-split batch size.
    if args.test_batch_sz == 0:
        args.test_batch_sz = args.batch_sz