def progress()

in torchrec/distributed/train_pipeline.py [0:0]


    def progress(self, dataloader_iter: Iterator[In]) -> Out:
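        """
        Runs one iteration of the pipeline: copies batch i+2 host-to-device on
        the memcpy stream, waits for batch i's sparse data dist, runs
        forward/backward and the optimizer step on batch i, and starts sparse
        data dist for batch i+1, overlapping the stages across CUDA streams.
        """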
        if not self._connected:
            self._connect(dataloader_iter)

        if self._model.training:
            with record_function("## zero_grad ##"):
                self._optimizer.zero_grad()

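        # Prefetch batch i+2: pull it from the dataloader and copy it to the
        # device asynchronously on the memcpy stream.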
        with record_function("## copy_batch_to_gpu ##"):
            with torch.cuda.stream(self._memcpy_stream):
                batch_ip2 = next(dataloader_iter)
                self._batch_ip2 = batch_ip2 = _to_device(
                    batch_ip2, self._device, non_blocking=True
                )
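        # batch_i is the batch to train on this step; batch_ip1 is the next
        # batch, whose sparse data dist is started further below.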
        batch_i = cast(In, self._batch_i)
        batch_ip1 = cast(In, self._batch_ip1)

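        # Make the current (default) stream wait until batch_i's sparse data
        # dist, issued on a previous call (or in _connect), has completed on
        # the data dist stream.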
        with record_function("## wait_for_batch ##"):
            _wait_for_batch(batch_i, self._data_dist_stream)

        # Forward
        with record_function("## forward ##"):
            # if using multiple streams (i.e. CUDA), create an event in the default stream
            # before starting forward pass
            if self._data_dist_stream:
                event = torch.cuda.current_stream().record_event()
            (losses, output) = cast(Tuple[torch.Tensor, Out], self._model(batch_i))

        # Data Distribution
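        # Start the sparse data dist (input redistribution for the pipelined
        # sharded modules) for batch i+1 on its own stream so it overlaps with
        # the dense compute for batch i.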
        with record_function("## sparse_data_dist ##"):
            with torch.cuda.stream(self._data_dist_stream):
                _wait_for_batch(batch_ip1, self._memcpy_stream)
                # Make the data dist stream wait until the event recorded in
                # the default stream has been reached before starting data dist
                if self._data_dist_stream:
                    # pyre-ignore [61]: Local variable `event` is undefined, or not always defined
                    self._data_dist_stream.wait_event(event)
                _start_data_dist(self._pipelined_modules, batch_ip1, self._context)

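        # Backward and the optimizer step run on the default stream while
        # batch i+1's data dist proceeds on the data dist stream.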
        if self._model.training:
            # Backward
            with record_function("## backward ##"):
                torch.sum(losses, dim=0).backward()

            # Update
            with record_function("## optimizer ##"):
                # pyre-fixme[20]: Argument `closure` expected.
                self._optimizer.step()

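        # Rotate the pipeline: batch i+1 becomes the current batch and
        # batch i+2 becomes the next one.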
        self._batch_i = batch_ip1
        self._batch_ip1 = batch_ip2

        return output
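
A typical driver loop looks roughly like the sketch below (assuming this method
belongs to torchrec's TrainPipelineSparseDist; the model, optimizer, dataloader,
and num_batches names are illustrative, not part of this file):

    # Minimal usage sketch: one progress() call per training iteration.
    pipeline = TrainPipelineSparseDist(model, optimizer, device=torch.device("cuda"))
    dataloader_iter = iter(dataloader)
    for _ in range(num_batches):
        # Trains batch i while copying batch i+2 to the device and starting
        # sparse data dist for batch i+1.
        output = pipeline.progress(dataloader_iter)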