def fit()

in train/model.py [0:0]
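
This excerpt assumes the following module-level imports in train/model.py; the exact import path for the repo-local `metric` module is an assumption based on how the name is used below:

    import os
    import time
    import logging

    import torch

    from . import metric  # repo-local metrics module (assumed path)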


    def fit(self, train_iter, optimizer, lr_scheduler,
            eval_iter=None,
            metrics=metric.Accuracy(topk=1),
            epoch_start=0,
            epoch_end=10000,
            precise_bn=False,
            precise_bn_steps=500,
            epoch_div_factor=None,  # resolved at call time (see "setup iterator" below)
            **kwargs):

        """
        checking
        """
        if kwargs:
            logging.warning("Unknown kwargs: {}".format(kwargs))

        assert torch.cuda.is_available(), "only the GPU version is supported"

        """
        setup iterator
        """
        precise_bn_steps = precise_bn_steps if precise_bn else 0
        epoch_freeze_step = int(round(0.2 * precise_bn_steps))
        if epoch_div_factor is None:
            # resolve at call time: a default argument would be evaluated at class
            # definition time, before torch.distributed can be initialized
            epoch_div_factor = (torch.distributed.get_world_size()
                                if torch.distributed.is_available() and torch.distributed.is_initialized()
                                else 1)
        epoch_train_steps = len(train_iter.batch_sampler) // epoch_div_factor
        epoch_eval_steps = (len(eval_iter.batch_sampler) // epoch_div_factor
                            if eval_iter is not None else 0)
        if (len(train_iter.batch_sampler) - epoch_train_steps) > precise_bn_steps:
            # train iter is long enough: append the precise-bn steps after training
            epoch_term_steps = epoch_train_steps + precise_bn_steps
        else:
            # otherwise reserve the tail of the training steps for precise-bn
            epoch_term_steps = epoch_train_steps
            epoch_train_steps = epoch_train_steps - precise_bn_steps
            assert epoch_train_steps > 0, \
                "not enough training steps: {} left after reserving {} precise-bn steps".format(
                    epoch_train_steps, precise_bn_steps)
            logging.warning(">> using the last {} iterations to compute the precise batchnorm.".format(precise_bn_steps))

        """
        start the main loop
        """
        for i_epoch in range(epoch_start, epoch_end):

            self.callback_kwargs['epoch'] = i_epoch
            epoch_start_time = time.time()

            ###########
            # 1] TRAINING
            ###########
            metrics.reset()
            self.net.train()
            sum_batch_inst = 0
            sum_batch_elapse = 0.
            sum_update_elapse = 0.
            batch_start_time = time.time()
            logging.info("Start epoch {:d}, iter stride {:d}, train steps {:d}, eval steps: {:d}".format( \
                        i_epoch, epoch_div_factor, epoch_train_steps, epoch_eval_steps))

            for i_batch, (data, target) in enumerate(train_iter):

                if i_batch >= epoch_term_steps:
                    break

                if precise_bn and i_batch == epoch_train_steps:
                    logging.info("Compute precise batchnorm: {} to {}.".format(epoch_train_steps, epoch_term_steps))
                    # TODO: find a better way to sync running_mean / running_var across workers
                    self.save_checkpoint(epoch=i_epoch+1, optimizer_state=optimizer.state_dict())
                    while not os.path.exists(self.get_checkpoint_path(epoch=i_epoch+1)):
                        time.sleep(1)
                    time.sleep(5)
                    self.load_checkpoint(epoch=i_epoch+1)
                    metrics.reset()

                self.callback_kwargs['batch'] = i_batch
                update_start_time = time.time()

                # [forward] making next step
                outputs, losses = self.forward(data, target)

                # [backward]
                if i_batch < epoch_train_steps:
                    optimizer.zero_grad()
                    for loss in losses:
                        loss.backward()
                    self.adjust_learning_rate(optimizer=optimizer, lr=lr_scheduler.update())
                    optimizer.step()
                elif i_batch < (epoch_term_steps - epoch_freeze_step):
                    # precise bn (stage 1)
                    optimizer.zero_grad()
                    for loss in losses:
                        loss.backward()
                    self.adjust_learning_rate(optimizer=optimizer, lr=lr_scheduler.get_lr())
                    # `visiable` is the custom optimizer's repo-specific argument,
                    # assumed to restrict this update to the precise-bn parameters
                    optimizer.step(visiable=["precise.bn"])
                else:
                    # precise bn (stage 2): no parameter update; the forward
                    # pass alone refreshes the running mean / var
                    pass

                self.callback_kwargs['lr'] = lr_scheduler.get_lr()

                # [evaluation] update train metric
                metrics.update([output.detach().cpu() for output in outputs],
                               target.cpu(),
                               [loss.detach().cpu() for loss in losses])

                # timing each batch
                sum_batch_elapse += time.time() - batch_start_time
                sum_update_elapse += time.time() - update_start_time
                sum_batch_inst += 1

                if (i_batch % self.step_callback_freq) == 0:
                    name_value_prefix = 'tr-' if i_batch < epoch_train_steps else 'bn-'
                    self.callback_kwargs['namevals'] = metrics.get_name_value(prefix=name_value_prefix)
                    metrics.reset()
                    # speed monitor
                    self.callback_kwargs['batch_elapse'] = sum_batch_elapse / sum_batch_inst
                    self.callback_kwargs['update_elapse'] = sum_update_elapse / sum_batch_inst
                    sum_update_elapse = 0
                    sum_batch_elapse = 0
                    sum_batch_inst = 0
                    # callbacks
                    self.step_end_callback()

                # save checkpoint in case of unexpected interrupt
                if (i_batch % 500) == 0 and i_batch < epoch_train_steps:
                    self.callback_kwargs['epoch_elapse'] = time.time() - epoch_start_time
                    self.callback_kwargs['optimizer_dict'] = optimizer.state_dict()
                    self.epoch_end_callback()

                # end of current train iter
                batch_start_time = time.time()


            ###########
            # 2] END OF EPOCH
            ###########
            self.callback_kwargs['epoch_elapse'] = time.time() - epoch_start_time
            self.callback_kwargs['optimizer_dict'] = optimizer.state_dict()
            self.epoch_end_callback()


            ###########
            # 3] Evaluation
            ###########
            if (eval_iter is not None) \
                and ((i_epoch+1) % max(1, int(self.save_checkpoint_freq/2))) == 0:
                logging.info("Start evaluating epoch {:d}:".format(i_epoch))

                metrics.reset()
                self.net.eval()
                sum_batch_elapse = 0.
                sum_batch_inst = 0
                sum_forward_elapse = 0.
                with torch.no_grad():
                    batch_start_time = time.time()
                    for i_batch, (data, target) in enumerate(eval_iter):

                        forward_start_time = time.time()

                        outputs, losses = self.forward(data, target)

                        metrics.update([output.detach().cpu() for output in outputs],
                                       target.cpu(),
                                       [loss.detach().cpu() for loss in losses])

                        sum_forward_elapse += time.time() - forward_start_time
                        sum_batch_elapse += time.time() - batch_start_time
                        batch_start_time = time.time()
                        sum_batch_inst += 1

                        if i_batch >= epoch_eval_steps:
                            break

                    # evaluation callbacks
                    self.callback_kwargs['batch'] = sum_batch_inst
                    self.callback_kwargs['batch_elapse'] = sum_batch_elapse / sum_batch_inst
                    self.callback_kwargs['update_elapse'] = sum_forward_elapse / sum_batch_inst
                    self.callback_kwargs['namevals'] = metrics.get_name_value(prefix='ts-')
                    self.step_end_callback()

        logging.info("Optimization done!")