in train.py [0:0]
def make_agent(obs_shape, action_shape, args, device):
    """Build the agent selected by ``args.agent`` and optionally preload its encoders.

    Args:
        obs_shape: Forwarded unchanged to the agent constructor.
        action_shape: Forwarded unchanged to the agent constructor.
        args: Namespace-like object. ``args.agent`` picks the agent class
            ('baseline', 'bisim' or 'deepmdp'); the hyper-parameter attributes
            read below are forwarded as keyword arguments of the same name.
            ``args.bisim_coef`` is only read for the 'bisim' agent. A truthy
            ``args.load_encoder`` is treated as a checkpoint path.
        device: Forwarded to the agent constructor and also used as the
            ``map_location`` when loading the encoder checkpoint.

    Returns:
        The constructed agent. When ``args.load_encoder`` is set, the same
        encoder weights are loaded into both ``agent.actor.encoder`` and
        ``agent.critic.encoder``.

    Raises:
        ValueError: If ``args.agent`` is not one of the supported names.
    """
    agent_name = args.agent
    # Fail early with a clear message instead of hitting an unbound ``agent``
    # further down when the name is misspelled.
    if agent_name not in ('baseline', 'bisim', 'deepmdp'):
        raise ValueError(
            "unknown agent %r; expected one of 'baseline', 'bisim', 'deepmdp'"
            % (agent_name,)
        )

    # Keyword arguments shared by every agent class.
    shared_kwargs = dict(
        obs_shape=obs_shape,
        action_shape=action_shape,
        device=device,
        hidden_dim=args.hidden_dim,
        discount=args.discount,
        init_temperature=args.init_temperature,
        alpha_lr=args.alpha_lr,
        alpha_beta=args.alpha_beta,
        actor_lr=args.actor_lr,
        actor_beta=args.actor_beta,
        actor_log_std_min=args.actor_log_std_min,
        actor_log_std_max=args.actor_log_std_max,
        actor_update_freq=args.actor_update_freq,
        critic_lr=args.critic_lr,
        critic_beta=args.critic_beta,
        critic_tau=args.critic_tau,
        critic_target_update_freq=args.critic_target_update_freq,
        encoder_type=args.encoder_type,
        encoder_feature_dim=args.encoder_feature_dim,
        encoder_lr=args.encoder_lr,
        encoder_tau=args.encoder_tau,
        encoder_stride=args.encoder_stride,
        decoder_type=args.decoder_type,
        decoder_lr=args.decoder_lr,
        decoder_update_freq=args.decoder_update_freq,
        decoder_weight_lambda=args.decoder_weight_lambda,
        transition_model_type=args.transition_model_type,
        num_layers=args.num_layers,
        num_filters=args.num_filters,
    )

    if agent_name == 'baseline':
        agent = BaselineAgent(**shared_kwargs)
    elif agent_name == 'bisim':
        # Only the bisimulation agent takes the extra loss coefficient.
        agent = BisimAgent(bisim_coef=args.bisim_coef, **shared_kwargs)
    else:  # 'deepmdp' -- the name was validated above
        agent = DeepMDPAgent(**shared_kwargs)

    if args.load_encoder:
        prefix = 'encoder.'
        # NOTE(review): torch.load unpickles the file, so only point
        # ``load_encoder`` at checkpoints from a trusted source.
        # map_location lets a checkpoint saved on another device (e.g. GPU)
        # be restored on the device this run actually uses.
        checkpoint = torch.load(args.load_encoder, map_location=device)
        # Keep only the entries stored under the leading "encoder." prefix and
        # strip that prefix so the keys line up with the encoder's own
        # state-dict keys.
        encoder_state = {
            key[len(prefix):]: value
            for key, value in checkpoint.items()
            if key.startswith(prefix)
        }
        agent.actor.encoder.load_state_dict(encoder_state)
        agent.critic.encoder.load_state_dict(encoder_state)

    return agent