data/envs/metaworld/train.py (65 lines of code) (raw):

import argparse
import sys
from typing import Dict, Optional

import gymnasium as gym
import metaworld  # noqa: F401  (import registers Meta-World environments with Gymnasium)
from sample_factory.cfg.arguments import parse_full_cfg, parse_sf_args
from sample_factory.envs.env_utils import register_env
from sample_factory.train import run_rl


def make_custom_env(
    full_env_name: str,
    cfg: Optional[Dict] = None,
    env_config: Optional[Dict] = None,
    render_mode: Optional[str] = None,
) -> gym.Env:
    # Env factory invoked by Sample Factory to create each environment instance.
    return gym.make(full_env_name, render_mode=render_mode)


def override_defaults(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    # Override Sample Factory's default hyperparameters for Meta-World training
    # (small MLP encoder, CPU rollouts, PPO with linear LR decay).
    parser.set_defaults(
        batched_sampling=False,
        device="cpu",
        num_workers=8,
        num_envs_per_worker=8,
        worker_num_splits=2,
        train_for_env_steps=10_000_000,
        encoder_mlp_layers=[64, 64],
        env_frameskip=1,
        nonlinearity="tanh",
        batch_size=1024,
        kl_loss_coeff=0.1,
        use_rnn=False,
        adaptive_stddev=False,
        policy_initialization="torch_default",
        restart_behavior="restart",
        reward_scale=0.1,
        rollout=64,
        max_grad_norm=3.5,
        num_epochs=2,
        num_batches_per_epoch=4,
        ppo_clip_ratio=0.2,
        value_loss_coeff=1.3,
        exploration_loss_coeff=0.0,
        learning_rate=0.00295,
        lr_schedule="linear_decay",
        shuffle_minibatches=False,
        gamma=0.99,
        gae_lambda=0.95,
        with_vtrace=False,
        recurrence=1,
        normalize_input=True,
        normalize_returns=True,
        value_bootstrap=True,
        experiment_summaries_interval=3,
        save_every_sec=15,
        serial_mode=False,
        async_rl=False,
    )
    return parser


def main() -> int:
    # Parse CLI args on top of the overridden defaults, register the env factory,
    # and launch the RL training loop.
    parser, _ = parse_sf_args(argv=None, evaluation=False)
    parser = override_defaults(parser)
    cfg = parse_full_cfg(parser)
    register_env(cfg.env, make_custom_env)
    status = run_rl(cfg)
    return status


if __name__ == "__main__":
    sys.exit(main())
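
Usage sketch (not part of the original file): the script takes standard Sample Factory command-line flags, with the environment id passed via --env and forwarded to gym.make by make_custom_env. The environment id and experiment name below are hypothetical placeholders; the exact id depends on how the installed metaworld version registers its Gymnasium environments.

    # Example invocation (sketch; "Meta-World/MT1" and "mw_mt1" are assumptions):
    #   python data/envs/metaworld/train.py --env=Meta-World/MT1 --experiment=mw_mt1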