def train_model()

in train_model.py [0:0]

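# NOTE: this listing shows only the function body; train_model.py itself is assumed
# to provide the module-level imports for the names used below. The exact module
# paths are assumptions, roughly along these lines:
#
#   import logging
#   import torch
#   import torch.backends.cudnn as cudnn
#
#   from . import metric
#   from .model import model
#   from .optimizer import custom_optim
#   from .lr_scheduler import MultiFactorScheduler
#   from .data import iterator_factory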

def train_model(sym_net, model_prefix, dataset, input_conf,
                clip_length=8, train_frame_interval=2, val_frame_interval=2,
                resume_epoch=-1, batch_size=4, save_frequency=1,
                lr_base=0.01, lr_factor=0.1, lr_steps=[400000, 800000],
                end_epoch=1000, distributed=False, fine_tune=False,
                epoch_div_factor=4, precise_bn=False,
                **kwargs):

    assert torch.cuda.is_available(), "Currently, only CUDA (GPU) training is supported"

    # data iterator
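    # derive a per-process seed: offset by the distributed rank (so workers see
    # different shuffles) and by the resume epoch (so restarted runs do not
    # repeat the same ordering)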
    iter_seed = torch.initial_seed() \
                + (torch.distributed.get_rank() * 10 if distributed else 100) \
                + max(0, resume_epoch) * 100
    train_iter, eval_iter = iterator_factory.creat(name=dataset,
                                                   batch_size=batch_size,
                                                   clip_length=clip_length,
                                                   train_interval=train_frame_interval,
                                                   val_interval=val_frame_interval,
                                                   mean=input_conf['mean'],
                                                   std=input_conf['std'],
                                                   seed=iter_seed)
    # model (dynamic)
    net = model(net=sym_net,
                criterion=torch.nn.CrossEntropyLoss().cuda(),
                model_prefix=model_prefix,
                step_callback_freq=50,
                save_checkpoint_freq=save_frequency,
                opt_batch_size=batch_size, # optional
                single_checkpoint=precise_bn, # TODO: use shared filesystem to rsync running mean/var
                )
    # if True:
    #     for name, module in net.net.named_modules():
    #         if name.endswith("bn"): module.momentum = 0.005
    net.net.cuda()

    # config optimization: param buckets indexed [idx_bn][idx_wd] ([[no wd, w/ wd], [no wd, w/ wd]])
    param_base_layers = [[[], []], [[], []]]
    param_new_layers = [[[], []], [[], []]]
    name_freeze_layers, name_base_layers = [], []
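    # idx_wd routes biases away from weight decay; idx_bn separates bias / BN
    # affine params ('.bias', 'bn.weight') from the other weights so they can be
    # tagged for the custom optimizer below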
    for name, param in net.net.named_parameters():
        idx_wd = 0 if name.endswith('.bias') else 1
        idx_bn = 1 if name.endswith(('.bias', 'bn.weight')) else 0
        if fine_tune:
            if not name.startswith('classifier'):
                param_base_layers[idx_bn][idx_wd].append(param)
                name_base_layers.append(name)
            else:
                param_new_layers[idx_bn][idx_wd].append(param)
        else:
            if "conv_m2" in name:
                param_base_layers[idx_bn][idx_wd].append(param)
                name_base_layers.append(name)
            else:
                param_new_layers[idx_bn][idx_wd].append(param)


    if name_freeze_layers:
        out = "[\'" + '\', \''.join(name_freeze_layers) + "\']"
        logging.info("Optimizer:: >> freezing {} params: {}".format(len(name_freeze_layers),
                     out if len(out) < 300 else out[0:150] + " ... " + out[-150:]))
    if name_base_layers:
        out = "[\'" + '\', \''.join(name_base_layers) + "\']"
        logging.info("Optimizer:: >> recuding the learning rate of {} params: {}".format(len(name_base_layers),
                     out if len(out) < 300 else out[0:150] + " ... " + out[-150:]))

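    # wrap for multi-GPU training: DistributedDataParallel across processes when
    # distributed, otherwise single-process DataParallel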
    if distributed:
        net.net = torch.nn.parallel.DistributedDataParallel(net.net).cuda()
    else:
        net.net = torch.nn.DataParallel(net.net).cuda()

    # optimizer = torch.optim.SGD(sym_net.parameters(),
    wd = 0.0001
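    # pretrained/base layers train with half the base learning rate (lr_mult=0.5),
    # newly added layers with the full rate; the 'name': 'precise.bn' entries are
    # presumably consumed by custom_optim.SGD for the precise-BN handling above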
    optimizer = custom_optim.SGD([{'params': param_base_layers[0][0], 'lr_mult': 0.5, 'weight_decay': 0.},
                                  {'params': param_base_layers[0][1], 'lr_mult': 0.5, 'weight_decay': wd},
                                  {'params': param_base_layers[1][0], 'lr_mult': 0.5, 'weight_decay': 0., 'name': 'precise.bn'},  # *.bias
                                  {'params': param_base_layers[1][1], 'lr_mult': 0.5, 'weight_decay': wd, 'name': 'precise.bn'},  # bn.weight
                                  {'params': param_new_layers[0][0],  'lr_mult': 1.0, 'weight_decay': 0.},
                                  {'params': param_new_layers[0][1],  'lr_mult': 1.0, 'weight_decay': wd}, 
                                  {'params': param_new_layers[1][0],  'lr_mult': 1.0, 'weight_decay': 0., 'name': 'precise.bn'},  # *.bias
                                  {'params': param_new_layers[1][1],  'lr_mult': 1.0, 'weight_decay': wd, 'name': 'precise.bn'}], # bn.weight
                                 lr=lr_base,
                                 momentum=0.9,
                                 nesterov=True)

    # resume: model and optimizer
    if resume_epoch < 0:
        epoch_start = 0
        step_counter = 0
    else:
        net.load_checkpoint(epoch=resume_epoch, optimizer=optimizer)
        epoch_start = resume_epoch
        step_counter = epoch_start * int(len(train_iter) / epoch_div_factor)

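    # lr_steps are specified in samples; convert them to optimizer steps by
    # dividing by the effective global batch size (batch_size * num_worker)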
    num_worker = torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
    lr_scheduler = MultiFactorScheduler(base_lr=lr_base,
                                        steps=[int(x/(batch_size*num_worker)) for x in lr_steps],
                                        factor=lr_factor,
                                        step_counter=step_counter)

    metrics = metric.MetricList(metric.Loss(name="loss-ce"),
                                metric.Accuracy(name="top1", topk=1),
                                metric.Accuracy(name="top5", topk=5),)

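    # let cuDNN benchmark convolution algorithms (input shapes stay constant
    # across iterations, so the autotuned choice is reused)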
    cudnn.benchmark = True
    # cudnn.fastest = False
    # cudnn.enabled = False

    net.fit(train_iter=train_iter,
            eval_iter=eval_iter,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            metrics=metrics,
            epoch_start=epoch_start,
            epoch_end=end_epoch,
            epoch_div_factor=epoch_div_factor,
            precise_bn=precise_bn,
            )
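
A minimal, hypothetical usage sketch (not part of train_model.py): the stub
network, dataset name, paths, and normalization values below are placeholders,
and a real call would pass one of the repository's video networks as sym_net.

import torch.nn as nn

class DummyClipNet(nn.Module):
    """Stand-in for a real clip-level network (placeholder only)."""
    def __init__(self, num_classes=101):
        super(DummyClipNet, self).__init__()
        # 3 input features = per-channel mean of an RGB clip
        self.classifier = nn.Linear(3, num_classes)

    def forward(self, x):
        # x: (N, C, T, H, W) -> average over time and space -> (N, C) -> logits
        return self.classifier(x.mean(dim=(2, 3, 4)))

if __name__ == "__main__":
    train_model(sym_net=DummyClipNet(num_classes=101),
                model_prefix='./exps/models/demo',  # checkpoint prefix (placeholder)
                dataset='UCF101',                   # must be a name iterator_factory knows
                input_conf={'mean': [0.485, 0.456, 0.406],
                            'std':  [0.229, 0.224, 0.225]},
                clip_length=8,
                batch_size=4,
                lr_base=0.01,
                end_epoch=90,
                fine_tune=True)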