def main()

in train-distillation.py


def main(params):

    # initialize the multi-GPU / multi-node training
    init_distributed_mode(params)

    # initialize the experiment / load data
    logger = initialize_exp(params)

    # Seed
    torch.manual_seed(params.seed)
    torch.cuda.manual_seed_all(params.seed)

    # initialize SLURM signal handler for time limit / pre-emption
    if params.is_slurm_job:
        init_signal_handler()

    # data loaders / samplers
    populate_dataset(params)
    train_data_loader, train_sampler, _ = get_data_loader(
        img_size=params.img_size,
        crop_size=params.crop_size,
        shuffle=True,
        batch_size=params.batch_size,
        num_classes=params.num_classes,
        nb_workers=params.nb_workers,
        distributed_sampler=params.multi_gpu,
        dataset=params.dataset,
        data_path=params.train_path,
        transform=params.train_transform,
        split='valid' if params.debug_train else 'train',
        seed=params.seed
    )

    valid_data_loader, _, _ = get_data_loader(
        img_size=params.img_size,
        crop_size=params.crop_size,
        shuffle=False,
        batch_size=params.batch_size,
        num_classes=params.num_classes,
        nb_workers=params.nb_workers,
        distributed_sampler=False,
        dataset=params.dataset,
        transform='center',
        split='valid',
        seed=params.seed
    )

    # build model / cuda
    logger.info("Building %s model ..." % params.architecture)
    model = build_model(params)
    model.cuda()

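    # build the teacher model and load its pretrained checkpoint
    # (strip the "module." prefix that DistributedDataParallel adds to parameter names)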
    teacher_model = build_model(params)
    teacher_model.cuda()
    ckpt = torch.load(params.teacher_ckpt)
    state_dict = {k.replace("module.", ""): v for k, v in ckpt['model'].items()}
    missing_keys, unexpected_keys = teacher_model.load_state_dict(state_dict, strict=False)
    print("Missing keys: ", missing_keys)
    print("Unexpected keys: ", unexpected_keys)

    # distributed  # TODO: check this https://github.com/NVIDIA/apex/blob/master/examples/imagenet/main.py#L142
    if params.multi_gpu:
        logger.info("Using nn.parallel.DistributedDataParallel ...")
        model = nn.parallel.DistributedDataParallel(model, device_ids=[params.local_rank], output_device=params.local_rank, broadcast_buffers=True)

    # build trainer / reload potential checkpoints / build evaluator
    trainer = Trainer(model=model, params=params, teacher_model=teacher_model)
    trainer.reload_checkpoint()
    evaluator = Evaluator(trainer, params)

    # evaluation
    if params.eval_only:
        scores = evaluator.run_all_evals(trainer, evals=['classif', 'recognition'], data_loader=valid_data_loader)

        for k, v in scores.items():
            logger.info('%s -> %.6f' % (k, v))
        logger.info("__log__:%s" % json.dumps(scores))
        exit()

    # training
    for epoch in range(trainer.epoch, params.epochs):

        # update epoch / sampler / learning rate
        trainer.epoch = epoch
        logger.info("============ Starting epoch %i ... ============" % trainer.epoch)
        if params.multi_gpu:
            train_sampler.set_epoch(epoch)

        # update learning rate
        trainer.update_learning_rate()

        # train: the labels yielded by the data loader are not used here,
        # distillation_step is called on the images only
        for images, _ in train_data_loader:
            trainer.distillation_step(images, None)
            trainer.iter()

        logger.info("============ End of epoch %i ============" % trainer.epoch)

        # evaluate classification accuracy
        scores = evaluator.run_all_evals(trainer, evals=['classif'], data_loader=valid_data_loader)

        for name, val in trainer.get_scores().items():
            scores[name] = val

        # print / JSON log
        for k, v in scores.items():
            logger.info('%s -> %.6f' % (k, v))
        if params.is_master:
            logger.info("__log__:%s" % json.dumps(scores))

        # end of epoch
        trainer.save_best_model(scores)
        trainer.save_periodic()
        trainer.end_epoch(scores)
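
For reference, a minimal sketch of how a main(params) function like this is typically launched (the argument parser below is a hypothetical illustration, not taken from train-distillation.py, which exposes many more options such as img_size, batch_size or teacher_ckpt):

if __name__ == '__main__':
    import argparse

    # hypothetical parser, for illustration only; the real script defines
    # all the fields read from params above (img_size, crop_size, epochs, ...)
    parser = argparse.ArgumentParser(description="knowledge distillation training")
    parser.add_argument("--teacher_ckpt", type=str, required=True)
    parser.add_argument("--seed", type=int, default=0)
    params = parser.parse_args()

    main(params)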