def train()

in vissl/trainer/trainer_main.py


    def train(self):
        """
        The train workflow. We get the training loop to use (vissl default is
        standard_train_step) but the user can create their own training loop
        and specify the name TRAINER.TRAIN_STEP_NAME

        The training happens:
        1. Execute any hooks at the start of training (mostly resets the variable like
           iteration num phase_num etc)
        2. For each epoch (train or test), run the hooks at the start of an epoch. Mostly
           involves setting things like timer, setting dataloader epoch etc
        3. Execute the training loop (1 training iteration) involving forward, loss, backward,
           optimizer update, metrics collection etc.
        4. At the end of epoch, sync meters and execute hooks at the end of phase. Involves
           things like checkpointing model, logging timers, logging to tensorboard etc
        """
        train_step_fn = get_train_step(self.cfg["TRAINER"]["TRAIN_STEP_NAME"])
        self.task.prepare(pin_memory=self.cfg.DATA.PIN_MEMORY)
        self.task.init_distributed_data_parallel_model()

        # Find which phase, train_phase_idx and local_iteration_num we are starting
        # from, recovering them from the checkpoint if one is available.
        task, phase_idx, iteration_num = self._init_training_state(self.cfg, self.task)

        # Good to go, (re) start training
        task.run_hooks(SSLClassyHookFunctions.on_start.name)

        if is_primary():
            logging.info("Model is:\n {}".format(task.model))
            logging.info("Loss is: {}".format(task.loss))
        logging.info("Starting training....")

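        # task.phases holds one entry per train/test phase; phase_idx is -1 on a
        # fresh start (or recovered from the checkpoint), so this walks every
        # remaining phase.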
        while phase_idx + 1 < len(task.phases):
            self._advance_phase(task)  # advances task.phase_idx
            phase_idx += 1
            iteration_num += 1
            task.local_iteration_num = iteration_num  # 0 on the first phase (iteration_num starts at -1)
            task.run_hooks(SSLClassyHookFunctions.on_phase_start.name)
            while True:
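                # Run train steps until the phase's data iterator is exhausted
                # (signalled by StopIteration below).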
                try:
                    if self.cfg.MODEL.CUDA_CACHE.CLEAR_CUDA_CACHE and (
                        iteration_num % self.cfg.MODEL.CUDA_CACHE.CLEAR_FREQ == 0
                    ):
                        logging.info(
                            f"Emptying CUDA cache at step count: {iteration_num}"
                        )
                        torch.cuda.empty_cache()
                        logging.info("CUDA cache cleared")
                    task = train_step_fn(task)
                    iteration_num += 1
                    task.local_iteration_num = iteration_num
                    # Book-keeping: update the training iteration number (only updated
                    # if it's a training phase).
                    task.iteration += 1 if task.train else 0
                    # Book-keeping. Track how many forward passes have been done,
                    # i.e. how many batches the trainer has seen, regardless of
                    # whether it is a train or test phase.
                    task.batches += 1
                    # Update the batch time, i.e. the time taken by the current iteration.
                    task.batch_time.append(time.time() - task.start_time)
                    task.start_time = time.time()
                    task.run_hooks(SSLClassyHookFunctions.on_step.name)
                except StopIteration:
                    break
                except Exception as e:
                    task.run_hooks(SSLClassyHookFunctions.on_exception.name)
                    raise e
            for meter in task.meters:
                meter.sync_state()
            logging.info("Meters synced")
            barrier()
            task.run_hooks(SSLClassyHookFunctions.on_phase_end.name)

        task.run_hooks(SSLClassyHookFunctions.on_end.name)
        if hasattr(task, "data_iterator"):
            del task.data_iterator
            gc.collect()
        if hasattr(task, "dataloaders"):
            del task.dataloaders
            gc.collect()
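
The train step executed by the inner loop is resolved by name via get_train_step. As a rough sketch of how a user-supplied step could be plugged in, assuming the register_train_step decorator exposed by vissl.trainer.train_steps (the step name "my_train_step", the batch key names, and the simplified body are illustrative, not the real standard_train_step):

from vissl.trainer.train_steps import register_train_step


@register_train_step("my_train_step")
def my_train_step(task):
    # Pull the next batch; the StopIteration raised here on exhaustion is
    # what ends the inner loop in train() above.
    sample = next(task.data_iterator)

    # Forward pass and loss. NOTE: the "input"/"target" keys are an
    # assumption; the actual keys depend on the dataset and collator.
    model_output = task.model(sample["input"])
    loss = task.loss(model_output, sample["target"])

    if task.train:
        # Backward pass and update only happen in training phases.
        # task.where (fraction of training completed) drives Classy Vision
        # parameter schedulers.
        task.optimizer.zero_grad()
        loss.backward()
        task.optimizer.step(where=task.where)
    return task

With such a step registered, setting TRAINER.TRAIN_STEP_NAME to "my_train_step" in the config makes get_train_step return it in place of standard_train_step.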
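
The run_hooks calls above mark where hooks fire: on_start once at the beginning, on_phase_start per phase, on_step after every train step, on_phase_end after meters are synced, on_end at the very end, and on_exception if a step raises. VISSL hooks subclass ClassyHook from Classy Vision; below is a minimal sketch assuming the standard ClassyHook interface (the class name and print logic are invented for illustration):

from classy_vision.hooks.classy_hook import ClassyHook


class PhaseBoundaryLoggingHook(ClassyHook):
    # ClassyHook requires every event to be defined; mark the events this
    # hook ignores as no-ops.
    on_start = ClassyHook._noop
    on_step = ClassyHook._noop
    on_end = ClassyHook._noop

    def on_phase_start(self, task):
        print(f"phase {task.phase_idx} starting (train={task.train})")

    def on_phase_end(self, task):
        print(f"phase {task.phase_idx} done, {task.batches} batches seen so far")

Hooks are attached to the task, and run_hooks dispatches the named event to each attached hook in order.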