torchbenchmark/models/drq/drq.py:

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
import math

from . import utils


class Encoder(nn.Module):
    """Convolutional encoder for image-based observations."""
    def __init__(self, obs_shape, feature_dim):
        super().__init__()

        assert len(obs_shape) == 3
        self.num_layers = 4
        self.num_filters = 32
        self.output_dim = 35
        self.output_logits = False
        self.feature_dim = feature_dim

        self.convs = nn.ModuleList([
            nn.Conv2d(obs_shape[0], self.num_filters, 3, stride=2),
            nn.Conv2d(self.num_filters, self.num_filters, 3, stride=1),
            nn.Conv2d(self.num_filters, self.num_filters, 3, stride=1),
            nn.Conv2d(self.num_filters, self.num_filters, 3, stride=1)
        ])

        self.head = nn.Sequential(
            nn.Linear(self.num_filters * 35 * 35, self.feature_dim),
            nn.LayerNorm(self.feature_dim))

        self.outputs = dict()

    def forward_conv(self, obs):
        obs = obs / 255.
        self.outputs['obs'] = obs

        conv = torch.relu(self.convs[0](obs))
        self.outputs['conv1'] = conv

        for i in range(1, self.num_layers):
            conv = torch.relu(self.convs[i](conv))
            self.outputs['conv%s' % (i + 1)] = conv

        h = conv.view(conv.size(0), -1)
        return h

    def forward(self, obs, detach=False):
        h = self.forward_conv(obs)

        if detach:
            h = h.detach()

        out = self.head(h)
        if not self.output_logits:
            out = torch.tanh(out)

        self.outputs['out'] = out

        return out

    def copy_conv_weights_from(self, source):
        """Tie convolutional layers"""
        for i in range(self.num_layers):
            utils.tie_weights(src=source.convs[i], trg=self.convs[i])

    def log(self, logger, step):
        pass


class Actor(nn.Module):
    """torch.distributions implementation of a diagonal Gaussian policy."""
    def __init__(self, encoder_cfg, action_shape, hidden_dim, hidden_depth,
                 log_std_bounds):
        super().__init__()

        self.encoder = Encoder(*encoder_cfg)

        self.log_std_bounds = log_std_bounds
        self.trunk = utils.mlp(self.encoder.feature_dim, hidden_dim,
                               2 * action_shape[0], hidden_depth)

        self.outputs = dict()
        self.apply(utils.weight_init)

    def forward(self, obs, detach_encoder=False):
        obs = self.encoder(obs, detach=detach_encoder)

        mu, log_std = self.trunk(obs).chunk(2, dim=-1)

        # constrain log_std inside [log_std_min, log_std_max]
        log_std = torch.tanh(log_std)
        log_std_min, log_std_max = self.log_std_bounds
        log_std = log_std_min + 0.5 * (log_std_max - log_std_min) * (log_std + 1)
        std = log_std.exp()

        self.outputs['mu'] = mu
        self.outputs['std'] = std

        dist = utils.SquashedNormal(mu, std)
        return dist

    def log(self, logger, step):
        pass


class Critic(nn.Module):
    """Critic network, employs double Q-learning."""
    def __init__(self, encoder_cfg, action_shape, hidden_dim, hidden_depth):
        super().__init__()

        self.encoder = Encoder(*encoder_cfg)

        self.Q1 = utils.mlp(self.encoder.feature_dim + action_shape[0],
                            hidden_dim, 1, hidden_depth)
        self.Q2 = utils.mlp(self.encoder.feature_dim + action_shape[0],
                            hidden_dim, 1, hidden_depth)

        self.outputs = dict()
        self.apply(utils.weight_init)

    def forward(self, obs, action, detach_encoder=False):
        assert obs.size(0) == action.size(0)
        obs = self.encoder(obs, detach=detach_encoder)

        obs_action = torch.cat([obs, action], dim=-1)
        q1 = self.Q1(obs_action)
        q2 = self.Q2(obs_action)

        self.outputs['q1'] = q1
        self.outputs['q2'] = q2

        return q1, q2

    def log(self, logger, step):
        pass


class DRQAgent(object):
    """Data regularized Q: actor-critic method for learning from pixels."""
    def __init__(self, cfg, device, obs_shape, action_shape, action_range):
        self.action_range = action_range
        self.device = torch.device(device)
        self.discount = cfg.discount
        self.critic_tau = cfg.critic_tau
        self.actor_update_frequency = cfg.actor_update_frequency
        self.critic_target_update_frequency = cfg.critic_target_update_frequency
        self.batch_size = cfg.batch_size

        encoder_cfg = (obs_shape, cfg.feature_dim)

        self.actor = Actor(encoder_cfg=encoder_cfg,
                           action_shape=action_shape,
                           hidden_dim=cfg.hidden_dim,
                           hidden_depth=cfg.hidden_depth,
                           log_std_bounds=cfg.log_std_bounds).to(self.device)

        self.critic = Critic(encoder_cfg=encoder_cfg,
                             action_shape=action_shape,
                             hidden_dim=cfg.hidden_dim,
                             hidden_depth=cfg.hidden_depth).to(self.device)

        self.critic_target = Critic(encoder_cfg=encoder_cfg,
                                    action_shape=action_shape,
                                    hidden_dim=cfg.hidden_dim,
                                    hidden_depth=cfg.hidden_depth).to(self.device)
        self.critic_target.load_state_dict(self.critic.state_dict())

        # tie conv layers between actor and critic
        self.actor.encoder.copy_conv_weights_from(self.critic.encoder)

        self.log_alpha = torch.tensor(np.log(cfg.init_temperature)).to(device)
        self.log_alpha.requires_grad = True
        # set target entropy to -|A|
        self.target_entropy = -action_shape[0]

        # optimizers
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),
                                                lr=cfg.lr)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(),
                                                 lr=cfg.lr)
        self.log_alpha_optimizer = torch.optim.Adam([self.log_alpha],
                                                    lr=cfg.lr)

        self.train()
        self.critic_target.train()

    def train(self, training=True):
        self.training = training
        self.actor.train(training)
        self.critic.train(training)

    @property
    def alpha(self):
        return self.log_alpha.exp()

    def act(self, obs, sample=False):
        obs = torch.FloatTensor(obs).to(self.device)
        obs = obs.unsqueeze(0)
        dist = self.actor(obs)
        action = dist.sample() if sample else dist.mean
        action = action.clamp(*self.action_range)
        assert action.ndim == 2 and action.shape[0] == 1
        return utils.to_np(action[0])

    def update_critic(self, obs, obs_aug, action, reward, next_obs,
                      next_obs_aug, not_done, logger, step):
        with torch.no_grad():
            # target from the original next observation
            dist = self.actor(next_obs)
            next_action = dist.rsample()
            log_prob = dist.log_prob(next_action).sum(-1, keepdim=True)
            target_Q1, target_Q2 = self.critic_target(next_obs, next_action)
            target_V = torch.min(target_Q1,
                                 target_Q2) - self.alpha.detach() * log_prob
            target_Q = reward + (not_done * self.discount * target_V)

            # target from the augmented next observation
            dist_aug = self.actor(next_obs_aug)
            next_action_aug = dist_aug.rsample()
            log_prob_aug = dist_aug.log_prob(next_action_aug).sum(-1,
                                                                  keepdim=True)
            target_Q1, target_Q2 = self.critic_target(next_obs_aug,
                                                      next_action_aug)
            target_V = torch.min(
                target_Q1, target_Q2) - self.alpha.detach() * log_prob_aug
            target_Q_aug = reward + (not_done * self.discount * target_V)

            # DrQ: average the targets computed from the original and
            # augmented observations
            target_Q = (target_Q + target_Q_aug) / 2

        # get current Q estimates
        current_Q1, current_Q2 = self.critic(obs, action)
        critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(
            current_Q2, target_Q)

        # DrQ: also regress the Q values of the augmented observation
        # toward the same averaged target
        Q1_aug, Q2_aug = self.critic(obs_aug, action)
        critic_loss += F.mse_loss(Q1_aug, target_Q) + F.mse_loss(
            Q2_aug, target_Q)

        # logger.log('train_critic/loss', critic_loss, step)

        # Optimize the critic
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        self.critic.log(logger, step)

    def update_actor_and_alpha(self, obs, logger, step):
        # detach conv filters, so we don't update them with the actor loss
        dist = self.actor(obs, detach_encoder=True)
        action = dist.rsample()
        log_prob = dist.log_prob(action).sum(-1, keepdim=True)
        # detach conv filters, so we don't update them with the actor loss
        actor_Q1, actor_Q2 = self.critic(obs, action, detach_encoder=True)

        actor_Q = torch.min(actor_Q1, actor_Q2)

        actor_loss = (self.alpha.detach() * log_prob - actor_Q).mean()

        # optimize the actor
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        self.actor.log(logger, step)

        # optimize the temperature alpha toward the target entropy
        self.log_alpha_optimizer.zero_grad()
        alpha_loss = (self.alpha *
                      (-log_prob - self.target_entropy).detach()).mean()
        alpha_loss.backward()
        self.log_alpha_optimizer.step()

    def update(self, replay_buffer, logger, step):
        obs, action, reward, next_obs, not_done, obs_aug, next_obs_aug = replay_buffer.sample(
            self.batch_size)

        self.update_critic(obs, obs_aug, action, reward, next_obs,
                           next_obs_aug, not_done, logger, step)

        if step % self.actor_update_frequency == 0:
            self.update_actor_and_alpha(obs, logger, step)

        if step % self.critic_target_update_frequency == 0:
            utils.soft_update_params(self.critic, self.critic_target,
                                     self.critic_tau)
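

# --- Illustrative usage sketch (not part of the original file) ---
# A minimal example of how DRQAgent might be constructed and queried for an
# action. It assumes the sibling `utils` module provides mlp, weight_init,
# SquashedNormal, tie_weights, to_np and soft_update_params as in the upstream
# drq code this benchmark is based on. The config values, observation shape,
# and action shape below are hypothetical placeholders, not the benchmark's
# actual hyperparameters.
if __name__ == "__main__":
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        discount=0.99,
        critic_tau=0.01,
        actor_update_frequency=2,
        critic_target_update_frequency=2,
        batch_size=128,
        feature_dim=50,
        hidden_dim=1024,
        hidden_depth=2,
        log_std_bounds=[-10, 2],
        init_temperature=0.1,
        lr=1e-3,
    )

    # 84x84 pixel observations (e.g. 3 stacked RGB frames) shrink to 35x35
    # after the four conv layers, matching the Encoder's linear head.
    obs_shape = (9, 84, 84)
    action_shape = (6,)          # hypothetical 6-dim continuous action space
    action_range = (-1.0, 1.0)

    agent = DRQAgent(cfg, "cpu", obs_shape, action_shape, action_range)

    # Query a deterministic action for a single random observation.
    obs = np.random.randint(0, 256, size=obs_shape).astype(np.float32)
    action = agent.act(obs, sample=False)
    print(action.shape)  # -> (6,)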