def update_args()

in main.py [0:0]


def update_args(args):
    """Fill in derived defaults on ``args`` and rescale it for batch splitting.

    ``args`` is modified in place; the function returns ``None``.

    * If ``head_dim`` is 0 it is derived as ``hid_sz // nheads``.
    * If ``test_batch_sz`` is 0 it defaults to ``batch_sz``. This is resolved
      before splitting, so the final value is the per-split training batch
      size.
    * ``update_freq`` is set to ``split_batch``. Presumably it is consumed
      downstream as the number of sub-batches accumulated per update; TODO
      confirm against the training loop.
    * If ``split_batch > 1``, both batch sizes are divided by it, and
      ``nbatches`` and ``lr_warmup`` are multiplied by it. The product
      ``batch_sz * nbatches`` is therefore unchanged.
    * If ``plot`` is truthy and ``plot_name`` is empty, ``plot_name`` becomes
      a local-time timestamp formatted ``%Y%m%d_%H%M%S``.

    All checks run before any attribute is written, so a failed check leaves
    ``args`` untouched.

    Raises:
        AssertionError: if ``head_dim`` is 0 and ``nheads`` is 0 or does not
            divide ``hid_sz``; or if ``split_batch > 1`` does not divide
            ``batch_sz`` or the effective ``test_batch_sz``.
            NOTE: these are ``assert`` statements, so they are skipped when
            Python runs with ``-O``.
    """
    # --- Validation (no mutation yet) ---------------------------------------
    derive_head_dim = args.head_dim == 0
    if derive_head_dim:
        # Short-circuit on nheads != 0 so a zero head count reports a clear
        # error instead of raising ZeroDivisionError from the modulo.
        assert args.nheads != 0 and args.hid_sz % args.nheads == 0, (
            f"hid_sz ({args.hid_sz}) must be divisible by a nonzero "
            f"nheads ({args.nheads}) when head_dim is 0"
        )

    # A test batch size of 0 means "same as the training batch size". Resolve
    # it first so the divisibility check below applies to the real value.
    if args.test_batch_sz == 0:
        test_batch_sz = args.batch_sz
    else:
        test_batch_sz = args.test_batch_sz

    split = args.split_batch
    if split > 1:
        assert args.batch_sz % split == 0, (
            f"batch_sz ({args.batch_sz}) must be divisible by "
            f"split_batch ({split})"
        )
        assert test_batch_sz % split == 0, (
            f"test_batch_sz ({test_batch_sz}) must be divisible by "
            f"split_batch ({split})"
        )

    # --- Apply derived values -------------------------------------------------
    if derive_head_dim:
        args.head_dim = args.hid_sz // args.nheads

    args.test_batch_sz = test_batch_sz
    args.update_freq = split
    if split > 1:
        # Smaller batches, proportionally more of them; warmup is scaled by
        # the same factor as nbatches.
        args.batch_sz //= split
        args.test_batch_sz //= split
        args.nbatches *= split
        args.lr_warmup *= split

    if args.plot and args.plot_name == "":
        args.plot_name = time.strftime("%Y%m%d_%H%M%S")