in drqv2.py [0:0]
def __init__(self, obs_shape, action_shape, device, lr, feature_dim,
             hidden_dim, critic_target_tau, num_expl_steps,
             update_every_steps, stddev_schedule, stddev_clip, use_tb):
    """Build the agent's networks, optimizers and augmentation.

    Args:
        obs_shape: Forwarded to ``Encoder``.
        action_shape: Forwarded to ``Actor`` and ``Critic``.
        device: Target passed to ``.to(...)`` for every module; also
            stored as ``self.device``.
        lr: Learning rate shared by the three Adam optimizers.
        feature_dim: Forwarded to ``Actor`` and ``Critic``.
        hidden_dim: Forwarded to ``Actor`` and ``Critic``.
        critic_target_tau: Stored only; not used in this method
            (presumably the target-critic update rate -- confirm at
            the use site).
        num_expl_steps: Stored only; not used in this method.
        update_every_steps: Stored only; not used in this method.
        stddev_schedule: Stored only; not used in this method.
        stddev_clip: Stored only; not used in this method.
        use_tb: Stored only; not used in this method.
    """
    # Settings kept on the instance for later use.
    self.device = device
    self.use_tb = use_tb
    self.num_expl_steps = num_expl_steps
    self.update_every_steps = update_every_steps
    self.critic_target_tau = critic_target_tau
    self.stddev_schedule = stddev_schedule
    self.stddev_clip = stddev_clip

    # Networks. The construction order (encoder, actor, critic, target
    # critic) is deliberately left as it was: if these constructors draw
    # random initial weights, reordering them would change the result.
    self.encoder = Encoder(obs_shape).to(device)
    # Actor and both critics take the same four constructor arguments.
    head_args = (self.encoder.repr_dim, action_shape, feature_dim,
                 hidden_dim)
    self.actor = Actor(*head_args).to(device)
    self.critic = Critic(*head_args).to(device)
    self.critic_target = Critic(*head_args).to(device)
    # The target critic starts as an exact copy of the critic's state.
    self.critic_target.load_state_dict(self.critic.state_dict())

    def _adam(module):
        # One Adam optimizer over all of `module`'s parameters.
        return torch.optim.Adam(module.parameters(), lr=lr)

    # The target critic gets no optimizer of its own.
    self.encoder_opt = _adam(self.encoder)
    self.actor_opt = _adam(self.actor)
    self.critic_opt = _adam(self.critic)

    # Data augmentation module, created with a fixed pad of 4.
    self.aug = RandomShiftsAug(pad=4)

    # Put the agent (via its own train(), defined outside this block) and
    # the target critic into training mode.
    self.train()
    self.critic_target.train()