def _train()

in rlkit/core/batch_rl_algorithm.py

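The main per-epoch training loop for the batch (offline) RL variant of rlkit's BatchRLAlgorithm. Each epoch first collects evaluation paths (through an explicit policy_fn when the trainer is a Q-learning algorithm), then runs num_train_loops_per_epoch inner loops, each of which optionally collects exploration data (online mode only) and performs num_trains_per_train_loop gradient steps on minibatches sampled from the replay buffer. Throughout, gt refers to the gtimer profiling library, and batch_rl, q_learning_alg, and eval_both are configuration flags on the algorithm instance.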

    def _train(self):
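        # In the online setting, warm up the replay buffer with
        # min_num_steps_before_training exploration steps before any
        # gradient updates; in batch (offline) RL the buffer is
        # pre-populated, so this collection is skipped.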
        if self.min_num_steps_before_training > 0 and not self.batch_rl:
            init_expl_paths = self.expl_data_collector.collect_new_paths(
                self.max_path_length,
                self.min_num_steps_before_training,
                discard_incomplete_paths=False,
            )
            self.replay_buffer.add_paths(init_expl_paths)
            self.expl_data_collector.end_epoch(-1)

        for epoch in gt.timed_for(
                range(self._start_epoch, self.num_epochs),
                save_itrs=True,
        ):
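            # Collect evaluation paths first. Q-learning trainers are
            # evaluated through an explicit policy_fn (with a discrete
            # variant when the action space is discrete); otherwise the
            # eval collector rolls out its own stored policy.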
            if self.q_learning_alg:
                policy_fn = self.policy_fn
                if self.trainer.discrete:
                    policy_fn = self.policy_fn_discrete
                self.eval_data_collector.collect_new_paths(
                    policy_fn,
                    self.max_path_length,
                    self.num_eval_steps_per_epoch,
                    discard_incomplete_paths=True,
                )
            else:
                self.eval_data_collector.collect_new_paths(
                    self.max_path_length,
                    self.num_eval_steps_per_epoch,
                    discard_incomplete_paths=True,
                )
            gt.stamp('evaluation sampling')

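            # Each train loop: optionally gather new data, then take
            # gradient steps on replay-buffer batches.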
            for _ in range(self.num_train_loops_per_epoch):
                if not self.batch_rl:
                    # Sample new paths only if not doing batch rl
                    new_expl_paths = self.expl_data_collector.collect_new_paths(
                        self.max_path_length,
                        self.num_expl_steps_per_train_loop,
                        discard_incomplete_paths=False,
                    )
                    gt.stamp('exploration sampling', unique=False)

                    self.replay_buffer.add_paths(new_expl_paths)
                    gt.stamp('data storing', unique=False)
                elif self.eval_both:
                    # In offline mode, additionally evaluate the current
                    # policy in the exploration environment; these paths
                    # are not added to the replay buffer.
                    policy_fn = self.policy_fn
                    if self.trainer.discrete:
                        policy_fn = self.policy_fn_discrete
                    new_expl_paths = self.expl_data_collector.collect_new_paths(
                        policy_fn,
                        self.max_path_length,
                        self.num_eval_steps_per_epoch,
                        discard_incomplete_paths=True,
                    )

                    gt.stamp('policy fn evaluation', unique=False)

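                # Gradient updates: sample uniform minibatches from the
                # replay buffer and run num_trains_per_train_loop trainer
                # steps on them.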
                self.training_mode(True)
                for _ in range(self.num_trains_per_train_loop):
                    train_data = self.replay_buffer.random_batch(
                        self.batch_size)
                    self.trainer.train(train_data)
                gt.stamp('training', unique=False)
                self.training_mode(False)

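            # _end_epoch handles logging and snapshotting and reports
            # whether the trainer hit NaNs; if so, stop training early.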
            has_nan = self._end_epoch(epoch)
            if has_nan:
                break
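
The inner loop's only contract with the rest of the system is that replay_buffer.random_batch(batch_size) returns a dict of aligned arrays and trainer.train(batch) consumes one such dict. Below is a minimal, self-contained sketch of that contract; ToyReplayBuffer and ToyTrainer are hypothetical stand-ins for illustration, not rlkit classes (rlkit's own simple replay buffer exposes the same random_batch interface):

    import numpy as np

    class ToyReplayBuffer:
        """Hypothetical stand-in, NOT rlkit's implementation: a fixed-size
        ring buffer with uniform random_batch sampling."""
        def __init__(self, obs_dim, action_dim, capacity=10000):
            self._capacity = capacity
            self._size = 0
            self._top = 0
            self._obs = np.zeros((capacity, obs_dim))
            self._actions = np.zeros((capacity, action_dim))
            self._rewards = np.zeros((capacity, 1))

        def add_sample(self, obs, action, reward):
            # Overwrite the oldest entry once the buffer is full.
            self._obs[self._top] = obs
            self._actions[self._top] = action
            self._rewards[self._top] = reward
            self._top = (self._top + 1) % self._capacity
            self._size = min(self._size + 1, self._capacity)

        def random_batch(self, batch_size):
            # Uniform sampling with replacement over stored transitions,
            # returning the dict-of-arrays format trainer.train expects.
            idx = np.random.randint(0, self._size, batch_size)
            return dict(
                observations=self._obs[idx],
                actions=self._actions[idx],
                rewards=self._rewards[idx],
            )

    class ToyTrainer:
        def train(self, batch):
            # A real trainer would take a gradient step on this batch.
            print('batch reward mean:', batch['rewards'].mean())

    # Usage mirroring the inner loop of _train():
    buffer = ToyReplayBuffer(obs_dim=3, action_dim=1)
    for _ in range(100):
        buffer.add_sample(np.random.randn(3), np.random.randn(1), np.random.randn())
    trainer = ToyTrainer()
    for _ in range(5):  # plays the role of num_trains_per_train_loop
        trainer.train(buffer.random_batch(batch_size=32))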