def train()

in training/image_classification.py
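
Runs the end-to-end supervised image-classification training loop: it sets up logging and (optionally) distributed training, builds the train/validation data loaders (with `mask` applied to the training set), constructs the model, then alternates one training epoch with a validation evaluation until `params.epochs` epochs have completed.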


import json

import torch.nn as nn

# Assumed project-internal helpers, importable from the surrounding package:
# create_logger, init_distributed_mode, init_signal_handler, get_dataset,
# build_model, Trainer, Evaluator.


def train(params, mask):
    # Create logger and print params
    logger = create_logger(params)
    # initialize the multi-GPU / multi-node training
    init_distributed_mode(params)

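    # SLURM jobs install a signal handler (e.g. to checkpoint and requeue cleanly on preemption)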
    if params.is_slurm_job:
        init_signal_handler()

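    # labelled data: training loader (with `mask` applied) and validation loader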
    trainloader, n_data = get_dataset(params=params, is_train=True, mask=mask)
    validloader, _ = get_dataset(params=params, is_train=False)

    model = build_model(params)
    model.cuda()

    if params.multi_gpu:
        if params.private:
            raise NotImplementedError('Distributed training not implemented with privacy')
        else:
            logger.info('Using multi-GPU training')
            model = nn.parallel.DistributedDataParallel(
                model,
                device_ids=[params.local_rank],
                output_device=params.local_rank,
                broadcast_buffers=True,
            )

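    # the Trainer owns the optimization state (epoch counter, per-step updates,
    # running statistics, checkpointing); reload_checkpoint resumes a previous
    # run if a checkpoint exists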
    trainer = Trainer(model, params, n_data=n_data)
    trainer.reload_checkpoint()

    evaluator = Evaluator(model, params)

    # evaluation only: score the validation set, log the results, and return
    if params.eval_only:
        scores = evaluator.run_all_evals(evals=['classif'], data_loader=validloader)
        for k, v in scores.items():
            logger.info('%s -> %.6f' % (k, v))
        logger.info("__log__:%s" % json.dumps(scores))
        return model

    # training
    for epoch in range(trainer.epoch, params.epochs):

        # update epoch / sampler / learning rate
        trainer.epoch = epoch
        # let a DistributedSampler reshuffle differently at each epoch
        # (no-op when the sampler has no set_epoch method)
        if hasattr(trainloader.sampler, 'set_epoch'):
            trainloader.sampler.set_epoch(epoch)
        logger.info("============ Starting epoch %i ... ============" % trainer.epoch)

        # train
        for (idx, images, targets) in trainloader:
            trainer.classif_step(idx, images, targets)
            trainer.end_step()

        logger.info("============ End of epoch %i ============" % trainer.epoch)

        # evaluate classification accuracy
        scores = evaluator.run_all_evals(evals=['classif'], data_loader=validloader)
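        # merge in training-time statistics collected by the trainer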
        for name, val in trainer.get_scores().items():
            scores[name] = val

        # print / JSON log
        for k, v in scores.items():
            logger.info('%s -> %.6f' % (k, v))
        logger.info("__log__:%s" % json.dumps(scores))

        # end of epoch
        trainer.end_epoch(scores)

    return model
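
A minimal call sketch, assuming a single-GPU, non-SLURM run. Only the `params` fields read directly inside train() are shown; the data/model/logging helpers will read additional fields (dataset paths, batch size, etc.), and the `mask` value here is purely illustrative:

from types import SimpleNamespace

params = SimpleNamespace(
    is_slurm_job=False,  # not running under SLURM
    multi_gpu=False,     # single process, single GPU
    private=False,       # no privacy-specific training path
    local_rank=0,
    epochs=90,           # illustrative value
    eval_only=False,     # set True to only evaluate a reloaded checkpoint
)

# hypothetical: `mask` might be e.g. a boolean per-sample array selecting
# the training subset; its exact type is defined by get_dataset
model = train(params, mask=None)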