in svg/dx.py [0:0]
def __init__(self,
             env_name,
             obs_dim, action_dim, action_range,
             horizon, device,
             detach_xt,
             clip_grad_norm,
             xu_enc_hidden_dim, xu_enc_hidden_depth,
             x_dec_hidden_dim, x_dec_hidden_depth,
             rec_type, rec_latent_dim, rec_num_layers,
             lr):
    """Build the dynamics model's sub-modules and its Adam optimizer.

    Creates an MLP encoder mapping ``obs_dim + action_dim`` inputs to
    ``rec_latent_dim``, an optional stacked recurrent unit (LSTM or GRU) of
    width ``rec_latent_dim``, and an MLP decoder mapping ``rec_latent_dim``
    back to ``obs_dim``. How these are chained at run time is defined
    elsewhere in the class (not visible here).

    Args:
        env_name: Environment identifier. ``'gym_petsReacher'`` and
            ``'gym_petsPusher'`` select a fixed set of frozen dimensions
            (``self.freeze_dims``); any other value leaves it ``None``.
        obs_dim: Observation size; encoder input is ``obs_dim + action_dim``
            and decoder output is ``obs_dim``.
        action_dim: Action size (part of the encoder input width).
        action_range: Accepted for interface compatibility; not stored or
            used by this constructor.
        horizon: Stored as ``self.horizon``.
        device: Stored as ``self.device``. Sub-modules are NOT moved to it
            here.
        detach_xt: Stored as ``self.detach_xt``.
        clip_grad_norm: Stored as ``self.clip_grad_norm``.
        xu_enc_hidden_dim: Hidden width passed to ``utils.mlp`` for the
            encoder.
        xu_enc_hidden_depth: Hidden depth passed to ``utils.mlp`` for the
            encoder.
        x_dec_hidden_dim: Hidden width passed to ``utils.mlp`` for the
            decoder.
        x_dec_hidden_depth: Hidden depth passed to ``utils.mlp`` for the
            decoder.
        rec_type: ``'LSTM'`` or ``'GRU'``; only consulted when
            ``rec_num_layers > 0``.
        rec_latent_dim: Latent width shared by the encoder output, the
            decoder input and the recurrent unit's input/hidden size.
        rec_num_layers: Number of stacked recurrent layers. ``0`` disables
            the recurrent unit entirely (``self.rec`` is not created).
        lr: Learning rate for the Adam optimizer over all sub-modules.

    Raises:
        ValueError: If ``rec_num_layers > 0`` and ``rec_type`` is neither
            ``'LSTM'`` nor ``'GRU'``.
    """
    # NOTE(review): self.apply below implies this class is an nn.Module
    # subclass, so super().__init__() sets up module registration.
    super().__init__()
    self.env_name = env_name
    self.obs_dim = obs_dim
    self.action_dim = action_dim
    self.horizon = horizon
    self.device = device
    self.detach_xt = detach_xt
    self.clip_grad_norm = clip_grad_norm

    # Manually freeze the goal locations. These are presumably the
    # observation indices holding the goal for each PETS env (confirm against
    # the env definitions). The tensor is a plain CPU attribute, not a
    # registered buffer, so Module.to() does not move it.
    if env_name == 'gym_petsReacher':
        self.freeze_dims = torch.LongTensor([7, 8, 9])
    elif env_name == 'gym_petsPusher':
        self.freeze_dims = torch.LongTensor([20, 21, 22])
    else:
        self.freeze_dims = None

    self.rec_type = rec_type
    self.rec_num_layers = rec_num_layers
    self.rec_latent_dim = rec_latent_dim

    self.xu_enc = utils.mlp(
        obs_dim + action_dim, xu_enc_hidden_dim, rec_latent_dim,
        xu_enc_hidden_depth)
    self.x_dec = utils.mlp(
        rec_latent_dim, x_dec_hidden_dim, obs_dim, x_dec_hidden_depth)

    # Apply the custom init now, before the recurrent unit exists, so that
    # only the two MLPs are affected. The LSTM/GRU keeps PyTorch's default
    # initialisation.
    self.apply(utils.weight_init)

    mods = [self.xu_enc, self.x_dec]
    if rec_num_layers > 0:
        if rec_type == 'LSTM':
            rec_cls = nn.LSTM
        elif rec_type == 'GRU':
            rec_cls = nn.GRU
        else:
            # Explicit exception instead of `assert False`: asserts are
            # stripped under `python -O`, which would defer the failure to a
            # confusing AttributeError on `self.rec` below.
            raise ValueError(
                "Unknown rec_type {!r}; expected 'LSTM' or 'GRU'".format(
                    rec_type))
        self.rec = rec_cls(
            rec_latent_dim, rec_latent_dim, num_layers=rec_num_layers)
        mods.append(self.rec)

    params = utils.get_params(mods)
    self.opt = torch.optim.Adam(params, lr=lr)