def _train()

in MTRF/algorithms/softlearning/algorithms/phased_sac.py [0:0]


    def _train(self):
        """Return a generator that performs phased RL training.

        The method takes no arguments; the environments, exploration policy,
        samplers and bookkeeping counters are all read from attributes set on
        ``self``.

        Each epoch keeps sampling (and, once ``self.ready_to_train`` is true,
        running training repeats) until at least ``self._epoch_length`` new
        samples have been collected, summed over the ``self._num_goals``
        per-goal samplers. It then gathers training and evaluation rollouts,
        computes metrics and yields a flat diagnostics dict. After the final
        epoch the samplers are terminated, the after-training hook runs, and a
        last dict containing ``'done': True`` plus the most recent epoch's
        diagnostics is yielded.

        Yields:
            dict: Per-epoch diagnostics with keys such as ``'times/<stamp>'``,
                ``'epoch'``, ``'timestep'``, ``'timesteps_total'``,
                ``'train-steps'``, ``'phase_<i>/episode_count_this_epoch'``,
                ``'evaluation_<i>/<metric>'``, ``'training_<i>/<metric>'`` and
                ``'sampler_<i>/<metric>'``; finally ``{'done': True, ...}``.
                If no epochs remain to be run, the final dict is just
                ``{'done': True}``.
        """
        import time

        import gtimer as gt

        training_environment = self._training_environment
        evaluation_environment = self._evaluation_environment

        if not self._training_started:
            self._init_training()

            for goal in range(self._num_goals):
                self._initial_exploration_hook(
                    training_environment, self._initial_exploration_policy, goal)

        self._initialize_samplers()
        self._sample_count = 0

        gt.reset_root()
        gt.rename_root('RLAlgorithm')
        gt.set_def_unique(False)

        print("starting_training")
        self._training_before_hook()

        # Bound before the loop so the cleanup and final ``yield`` below stay
        # valid when the loop runs zero times (self._epoch >= self._n_epochs).
        diagnostics = {}
        evaluation_paths_per_policy = None

        for self._epoch in gt.timed_for(range(self._epoch, self._n_epochs)):
            self._epoch_before_hook()
            gt.stamp('epoch_before_hook')
            start_samples = sum(
                self._samplers[goal]._total_samples
                for goal in range(self._num_goals))
            sample_times = []
            while True:
                samples_now = sum(
                    self._samplers[goal]._total_samples
                    for goal in range(self._num_goals))
                self._timestep = samples_now - start_samples

                # Stop once a full epoch of samples has been collected, but
                # keep sampling until training is allowed to start.
                if (samples_now >= start_samples + self._epoch_length
                        and self.ready_to_train):
                    break

                t0 = time.time()
                self._timestep_before_hook()
                gt.stamp('timestep_before_hook')

                self._do_sampling(timestep=self._total_timestep)
                gt.stamp('sample')
                sample_times.append(time.time() - t0)

                if self.ready_to_train:
                    self._do_training_repeats(timestep=self._total_timestep)
                gt.stamp('train')

                self._timestep_after_hook()
                gt.stamp('timestep_after_hook')

            # np.mean of an empty sequence warns ("Mean of empty slice") and
            # returns nan; report nan without the warning.
            mean_sample_time = (
                np.mean(sample_times) if sample_times else float('nan'))
            print("Average Sample Time: ", mean_sample_time)
            print("Step count", self._sample_count)

            training_paths_per_policy = self._training_paths()
            gt.stamp('training_paths')
            evaluation_paths_per_policy = self._evaluation_paths()
            if self._eval_n_episodes < 1:
                # No real evaluations are run: reuse the fake evaluation paths
                # generated on the first iteration. Metrics computed from them
                # are meaningless, so they are not logged below.
                evaluation_paths_per_policy = self.fake_eval_paths
            gt.stamp('evaluation_paths')

            # Policies that produced no training paths this epoch get the
            # evaluation paths substituted in, so _evaluate_rollouts has data
            # to work on; their training/sampler metrics are then logged as a
            # placeholder value below.
            empty_policies = set()
            for policy_index, paths in enumerate(training_paths_per_policy):
                if len(paths) == 0:
                    empty_policies.add(policy_index)
                    training_paths_per_policy[policy_index] = (
                        evaluation_paths_per_policy[policy_index])

            training_metrics_per_policy = self._evaluate_rollouts(
                training_paths_per_policy, training_environment)
            gt.stamp('training_metrics')

            if evaluation_paths_per_policy:
                evaluation_metrics_per_policy = self._evaluate_rollouts(
                    evaluation_paths_per_policy, evaluation_environment)
                gt.stamp('evaluation_metrics')
            else:
                evaluation_metrics_per_policy = [
                    {} for _ in range(self._num_goals)]

            self._epoch_after_hook(training_paths_per_policy)
            gt.stamp('epoch_after_hook')

            t0 = time.time()

            sampler_diagnostics_per_policy = [
                self._samplers[goal].get_diagnostics()
                for goal in range(self._num_goals)]

            diagnostics = self.get_diagnostics(
                iteration=self._total_timestep,
                batches=self._evaluation_batches(),
                training_paths_per_policy=training_paths_per_policy,
                evaluation_paths_per_policy=evaluation_paths_per_policy)

            time_diagnostics = gt.get_times().stamps.itrs

            print("Basic diagnostics: ", time.time() - t0)
            print("Sample count: ", self._sample_count)

            diagnostics.update(OrderedDict((
                *(
                    (f'times/{key}', time_diagnostics[key][-1])
                    for key in sorted(time_diagnostics.keys())
                ),
                ('epoch', self._epoch),
                ('timestep', self._timestep),
                ('timesteps_total', self._total_timestep),
                ('train-steps', self._num_train_steps),
            )))

            diagnostics.update({
                f"phase_{i}/episode_count_this_epoch": self._num_paths_per_phase[i]
                for i in range(self._num_goals)
            })
            print("Other basic diagnostics: ", time.time() - t0)

            # Value logged for training/sampler metrics of policies that had
            # no real training paths this epoch.
            empty_policy_placeholder = -1000
            per_policy = zip(evaluation_metrics_per_policy,
                             training_metrics_per_policy,
                             sampler_diagnostics_per_policy)
            for i, (evaluation_metrics, training_metrics,
                    sampler_diagnostics) in enumerate(per_policy):
                if self._eval_n_episodes >= 1:
                    # Only log evaluation metrics if they are actually meaningful.
                    diagnostics.update({
                        f'evaluation_{i}/{key}': evaluation_metrics[key]
                        for key in sorted(evaluation_metrics.keys())
                    })
                is_empty = i in empty_policies
                diagnostics.update(OrderedDict((
                    *(
                        (f'training_{i}/{key}',
                         empty_policy_placeholder if is_empty
                         else training_metrics[key])
                        for key in sorted(training_metrics.keys())
                    ),
                    *(
                        (f'sampler_{i}/{key}',
                         empty_policy_placeholder if is_empty
                         else sampler_diagnostics[key])
                        for key in sorted(sampler_diagnostics.keys())
                    ),
                )))

            yield diagnostics
            print("Diagnostic time: ", time.time() - t0)

        for goal in range(self._num_goals):
            self._samplers[goal].terminate()

        self._training_after_hook()

        # Drop the last epoch's evaluation paths so the suspended generator
        # frame does not keep them referenced during the final yield.
        del evaluation_paths_per_policy

        yield {'done': True, **diagnostics}