in torchbeast/polybeast_learner.py [0:0]
def __init__(self, num_actions, use_lstm=False):
    """Build the convolutional torso and the policy/baseline heads.

    Args:
        num_actions: Number of outputs of the ``policy`` linear head.
        use_lstm: When true, a single-layer ``nn.LSTM`` with 256 hidden units
            is created as ``self.core`` and both heads take its 256-wide
            output; otherwise the heads take the 257-wide FC-plus-reward input.

    The torso has three stages with 16, 32 and 32 output channels; the first
    convolution expects 4 input channels. Each stage stores
    ``Conv3x3 -> MaxPool(3, stride 2, pad 1)`` in ``self.feat_convs`` and two
    ``ReLU -> Conv3x3 -> ReLU -> Conv3x3`` branches in ``self.resnet1`` and
    ``self.resnet2`` (presumably combined residually in ``forward``, which is
    not shown here -- confirm).
    """
    super(Net, self).__init__()
    self.num_actions = num_actions
    self.use_lstm = use_lstm
    # NOTE(review): `convs` is never populated by this constructor; it is
    # kept so the attribute still exists for any code that reads it.
    self.convs = []

    def conv3x3(in_ch, out_ch):
        # Kernel 3, stride 1, padding 1 preserves height and width.
        return nn.Conv2d(
            in_channels=in_ch,
            out_channels=out_ch,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    stages, first_branches, second_branches = [], [], []
    prev_channels = 4
    for channels in (16, 32, 32):
        # Construction order (stage conv, then branch 1, then branch 2) fixes
        # how parameter init draws from the global RNG and the state_dict
        # key layout, so it must not be reordered.
        stages.append(
            nn.Sequential(
                conv3x3(prev_channels, channels),
                nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            )
        )
        prev_channels = channels
        for bucket in (first_branches, second_branches):
            bucket.append(
                nn.Sequential(
                    nn.ReLU(),
                    conv3x3(prev_channels, channels),
                    nn.ReLU(),
                    conv3x3(prev_channels, channels),
                )
            )

    self.feat_convs = nn.ModuleList(stages)
    self.resnet1 = nn.ModuleList(first_branches)
    self.resnet2 = nn.ModuleList(second_branches)

    # 3872 == 32 * 11 * 11: each max-pool maps H to floor((H - 1) / 2) + 1,
    # so an 84x84 frame shrinks 84 -> 42 -> 21 -> 11.
    # NOTE(review): this assumes 84x84 input frames -- confirm with the caller.
    self.fc = nn.Linear(3872, 256)

    # FC output size + last reward.
    head_in = self.fc.out_features + 1
    if use_lstm:
        self.core = nn.LSTM(head_in, 256, num_layers=1)
        head_in = self.core.hidden_size
    self.policy = nn.Linear(head_in, self.num_actions)
    self.baseline = nn.Linear(head_in, 1)