# torchbenchmark/models/soft_actor_critic/nets.py
import math
import numpy as np
import torch
import torch.nn.functional as F
from torch import distributions as pyd
from torch import nn
from . import utils


def weight_init(m):
    """Orthogonal initialization for Linear and (transposed) Conv2d layers."""
    if isinstance(m, nn.Linear):
        nn.init.orthogonal_(m.weight.data)
        m.bias.data.fill_(0.0)
    elif isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        # delta-orthogonal init from https://arxiv.org/pdf/1806.05393.pdf
        assert m.weight.size(2) == m.weight.size(3)
        m.weight.data.fill_(0.0)
        m.bias.data.fill_(0.0)
        mid = m.weight.size(2) // 2
        gain = nn.init.calculate_gain("relu")
        nn.init.orthogonal_(m.weight.data[:, :, mid, mid], gain)
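

# Illustrative usage (editor's sketch, not part of the original file):
# `weight_init` is designed to be passed to `nn.Module.apply`, which invokes
# it on every submodule. The layer sizes below are arbitrary.
def _demo_weight_init():
    net = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 2))
    net.apply(weight_init)  # orthogonal weights, zero biases for both Linears
    return net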


class BigPixelEncoder(nn.Module):
    """Four-layer conv encoder that maps image observations to a compact state.

    The output passes through LayerNorm and tanh, so states lie in (-1, 1).
    """

    def __init__(self, obs_shape, out_dim=50):
        super().__init__()
        channels = obs_shape[0]
        self.conv1 = nn.Conv2d(channels, 32, kernel_size=3, stride=2)
        self.conv2 = nn.Conv2d(32, 32, kernel_size=3, stride=1)
        self.conv3 = nn.Conv2d(32, 32, kernel_size=3, stride=1)
        self.conv4 = nn.Conv2d(32, 32, kernel_size=3, stride=1)

        # track the spatial size through the conv stack to size the fc layer
        output_height, output_width = utils.compute_conv_output(
            obs_shape[1:], kernel_size=(3, 3), stride=(2, 2)
        )
        for _ in range(3):
            output_height, output_width = utils.compute_conv_output(
                (output_height, output_width), kernel_size=(3, 3), stride=(1, 1)
            )

        self.fc = nn.Linear(output_height * output_width * 32, out_dim)
        self.ln = nn.LayerNorm(out_dim)
        self.apply(weight_init)

    def forward(self, obs):
        obs = obs / 255.0  # scale to [0, 1] without mutating the caller's tensor
        x = F.relu(self.conv1(obs))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        x = self.ln(x)
        state = torch.tanh(x)
        return state
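

# Illustrative usage (editor's sketch; assumes 84x84 RGB frames, a common
# shape in pixel-based control benchmarks). The encoder maps images scaled
# in [0, 255] to an `out_dim`-dimensional state in (-1, 1) via the final tanh.
def _demo_big_pixel_encoder():
    obs_shape = (3, 84, 84)
    encoder = BigPixelEncoder(obs_shape, out_dim=50)
    obs = torch.rand(4, *obs_shape) * 255.0  # batch of 4 synthetic frames
    state = encoder(obs)
    assert state.shape == (4, 50)
    return state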


class SmallPixelEncoder(nn.Module):
    """Three-layer conv encoder with the classic DQN filter sizes (8/4, 4/2, 3/1)."""

    def __init__(self, obs_shape, out_dim=50):
        super().__init__()
        channels = obs_shape[0]
        self.conv1 = nn.Conv2d(channels, 32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)

        output_height, output_width = utils.compute_conv_output(
            obs_shape[1:], kernel_size=(8, 8), stride=(4, 4)
        )
        output_height, output_width = utils.compute_conv_output(
            (output_height, output_width), kernel_size=(4, 4), stride=(2, 2)
        )
        output_height, output_width = utils.compute_conv_output(
            (output_height, output_width), kernel_size=(3, 3), stride=(1, 1)
        )

        self.fc = nn.Linear(output_height * output_width * 64, out_dim)
        self.apply(weight_init)

    def forward(self, obs):
        obs = obs / 255.0  # scale to [0, 1] without mutating the caller's tensor
        x = F.relu(self.conv1(obs))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = x.view(x.size(0), -1)
        state = self.fc(x)
        return state


class StochasticActor(nn.Module):
    """Policy head producing a tanh-squashed Gaussian or a rescaled Beta distribution."""

    def __init__(
        self,
        state_space_size,
        act_space_size,
        log_std_low=-10,
        log_std_high=2,
        hidden_size=1024,
        dist_impl="pyd",
    ):
        super().__init__()
        assert dist_impl in ["pyd", "beta"]
        self.fc1 = nn.Linear(state_space_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, 2 * act_space_size)
        self.log_std_low = log_std_low
        self.log_std_high = log_std_high
        self.apply(weight_init)
        self.dist_impl = dist_impl

    def forward(self, state):
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        out = self.fc3(x)
        if self.dist_impl == "pyd":
            mu, log_std = out.chunk(2, dim=1)
            # squash the raw log std into [log_std_low, log_std_high]
            log_std = torch.tanh(log_std)
            log_std = self.log_std_low + 0.5 * (
                self.log_std_high - self.log_std_low
            ) * (log_std + 1)
            std = log_std.exp()
            dist = SquashedNormal(mu, std)
        elif self.dist_impl == "beta":
            # shift both concentration parameters above 1 so the density is unimodal
            out = 1.0 + F.softplus(out)
            alpha, beta = out.chunk(2, dim=1)
            dist = BetaDist(alpha, beta)
        return dist
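

# Illustrative sketch (editor's addition; assumes a standard SAC-style actor
# update, which this file does not define). Both distribution backends support
# reparameterized sampling, so the loss can backpropagate through `rsample`.
# `critic` is any state-action network such as `BigCritic` below; `alpha` is
# the entropy temperature.
def _demo_actor_loss(actor, critic, state, alpha=0.1):
    dist = actor(state)
    action = dist.rsample()  # differentiable sample in (-1, 1)
    log_prob = dist.log_prob(action).sum(-1, keepdim=True)
    q_value = critic(state, action)
    # minimizing this maximizes the entropy-regularized Q-value
    return (alpha * log_prob - q_value).mean()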


class BigCritic(nn.Module):
    """Q-network: maps a (state, action) pair to a scalar value estimate."""

    def __init__(self, state_space_size, act_space_size, hidden_size=1024):
        super().__init__()
        self.fc1 = nn.Linear(state_space_size + act_space_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, 1)
        self.apply(weight_init)

    def forward(self, state, action):
        x = F.relu(self.fc1(torch.cat((state, action), dim=1)))
        x = F.relu(self.fc2(x))
        out = self.fc3(x)
        return out
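

# Illustrative sketch (editor's addition; assumes the clipped double-Q trick
# from TD3/SAC, which is not defined in this file): two critics are trained
# toward a shared target, and the minimum of their estimates is used to
# reduce overestimation bias.
def _demo_clipped_double_q(critic1, critic2, next_state, next_action):
    q1 = critic1(next_state, next_action)
    q2 = critic2(next_state, next_action)
    return torch.min(q1, q2)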


class BaselineActor(nn.Module):
    """Deterministic actor: tanh keeps actions bounded in (-1, 1)."""

    def __init__(self, state_size, action_size, hidden_size=400):
        super().__init__()
        self.fc1 = nn.Linear(state_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.out = nn.Linear(hidden_size, action_size)

    def forward(self, state):
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        act = torch.tanh(self.out(x))
        return act


class BaselineCritic(nn.Module):
    """Q-network with the 400/300 hidden sizes common in DDPG-style baselines."""

    def __init__(self, state_size, action_size):
        super().__init__()
        self.fc1 = nn.Linear(state_size + action_size, 400)
        self.fc2 = nn.Linear(400, 300)
        self.out = nn.Linear(300, 1)

    def forward(self, state, action):
        x = torch.cat((state, action), dim=1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        val = self.out(x)
        return val


class BetaDist(pyd.transformed_distribution.TransformedDistribution):
    """Beta distribution rescaled from its (0, 1) support to (-1, 1)."""

    class _BetaDistTransform(pyd.transforms.Transform):
        domain = pyd.constraints.real
        codomain = pyd.constraints.interval(-1.0, 1.0)

        def __init__(self, cache_size=1):
            super().__init__(cache_size=cache_size)

        def __eq__(self, other):
            # the nested class must be referenced through BetaDist; the bare
            # name is not in scope inside this method at call time
            return isinstance(other, BetaDist._BetaDistTransform)

        def _inverse(self, y):
            return (y.clamp(-0.99, 0.99) + 1.0) / 2.0

        def _call(self, x):
            return (2.0 * x) - 1.0

        def log_abs_det_jacobian(self, x, y):
            # log |dy/dx| for y = 2x - 1 is the constant log(2)
            return torch.Tensor([math.log(2.0)]).to(x.device)

    def __init__(self, alpha, beta):
        self.base_dist = pyd.beta.Beta(alpha, beta)
        transforms = [self._BetaDistTransform()]
        super().__init__(self.base_dist, transforms)

    @property
    def mean(self):
        mu = self.base_dist.mean
        for tr in self.transforms:
            mu = tr(mu)
        return mu
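

# Illustrative usage (editor's sketch, not part of the original file):
# BetaDist rescales a Beta(alpha, beta) sample from (0, 1) to (-1, 1). With
# alpha, beta > 1 the density is unimodal, which is why StochasticActor adds
# 1 to the softplus output above.
def _demo_beta_dist():
    dist = BetaDist(torch.tensor([[2.0]]), torch.tensor([[5.0]]))
    sample = dist.rsample()  # lies in (-1, 1)
    log_prob = dist.log_prob(sample)
    return sample, log_prob

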
"""
Credit for actor distribution code: https://github.com/denisyarats/pytorch_sac/blob/master/agent/actor.py
"""


class TanhTransform(pyd.transforms.Transform):
    """Bijective tanh squashing with a numerically stable log-det Jacobian."""

    domain = pyd.constraints.real
    codomain = pyd.constraints.interval(-1.0, 1.0)
    bijective = True
    sign = +1

    def __init__(self, cache_size=1):
        super().__init__(cache_size=cache_size)

    @staticmethod
    def atanh(x):
        return 0.5 * (x.log1p() - (-x).log1p())

    def __eq__(self, other):
        return isinstance(other, TanhTransform)

    def _call(self, x):
        return x.tanh()

    def _inverse(self, y):
        # clamp to keep atanh finite when y has been pushed to +/-1 by rounding
        return self.atanh(y.clamp(-0.99, 0.99))

    def log_abs_det_jacobian(self, x, y):
        # log(1 - tanh(x)^2) rewritten as 2 * (log 2 - x - softplus(-2x)),
        # which stays stable for large |x|
        return 2.0 * (math.log(2.0) - x - F.softplus(-2.0 * x))


class SquashedNormal(pyd.transformed_distribution.TransformedDistribution):
    """Normal distribution squashed through tanh; samples lie in (-1, 1)."""

    def __init__(self, loc, scale):
        self.loc = loc
        self.scale = scale
        self.base_dist = pyd.Normal(loc, scale)
        transforms = [TanhTransform()]
        super().__init__(self.base_dist, transforms)

    @property
    def mean(self):
        mu = self.loc
        for tr in self.transforms:
            mu = tr(mu)
        return mu
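

# Illustrative check (editor's sketch, not part of the original file): the
# tanh squashing changes the density, so `log_prob` differs from the base
# Normal's by the per-dimension correction 2 * (log 2 - x - softplus(-2x)).
def _demo_squashed_normal():
    dist = SquashedNormal(torch.zeros(2, 3), torch.ones(2, 3))
    action = dist.rsample()  # in (-1, 1), differentiable
    log_prob = dist.log_prob(action).sum(-1)  # per-sample log density
    return action, log_prob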


class GracBaselineActor(nn.Module):
    """Gaussian actor returning a Normal with tanh-squashed mean and softplus std."""

    def __init__(self, obs_size, action_size):
        super().__init__()
        self.fc1 = nn.Linear(obs_size, 400)
        self.fc2 = nn.Linear(400, 300)
        self.fc_mean = nn.Linear(300, action_size)
        self.fc_std = nn.Linear(300, action_size)

    def forward(self, state, stochastic=False):
        # `stochastic` is unused here; it is kept for interface compatibility
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        mean = torch.tanh(self.fc_mean(x))
        std = F.softplus(self.fc_std(x)) + 1e-3  # keep std strictly positive
        dist = pyd.Normal(mean, std)
        return dist


class BaselineDiscreteActor(nn.Module):
    """Categorical policy over a discrete action space."""

    def __init__(self, obs_shape, action_size, hidden_size=300):
        super().__init__()
        self.fc1 = nn.Linear(obs_shape, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.act_p = nn.Linear(hidden_size, action_size)

    def forward(self, state):
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        act_p = F.softmax(self.act_p(x), dim=1)
        dist = pyd.categorical.Categorical(act_p)
        return dist


class BaselineDiscreteCritic(nn.Module):
    """Q-network for discrete actions: one value per action from the state alone."""

    def __init__(self, obs_shape, action_shape, hidden_size=300):
        super().__init__()
        self.fc1 = nn.Linear(obs_shape, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.out = nn.Linear(hidden_size, action_shape)

    def forward(self, state):
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        vals = self.out(x)
        return vals
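

# Illustrative usage (editor's sketch, not part of the original file): pairing
# the discrete actor and critic. The critic returns one value per action, so
# the expected value under the policy is a dot product with the action
# probabilities.
def _demo_discrete_pair():
    obs_size, n_actions = 8, 4
    actor = BaselineDiscreteActor(obs_size, n_actions)
    critic = BaselineDiscreteCritic(obs_size, n_actions)
    state = torch.randn(2, obs_size)
    dist = actor(state)  # Categorical over n_actions
    action = dist.sample()  # shape (2,)
    q_values = critic(state)  # shape (2, n_actions)
    value = (dist.probs * q_values).sum(-1)  # expected Q under the policy
    return action, value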