in agents/nets.py [0:0]
def __init__(self,
             action_size,
             action_layers=1,
             action_hidden_size=256,
             fusion_place='last',
             embedding_dim=None,
             repr_merging_method=None,
             n_regressor_outputs=None,
             embeddor_type=None,
             backbone='resnet18'):
    """Build a ResNet backbone with FiLM-style action fusion.

    Args:
        action_size: Dimensionality of the action vector fed to each
            FiLM conditioning network.
        action_layers: Number of layers in each FiLM network.
        action_hidden_size: Hidden width of each FiLM network.
        fusion_place: Where the action is fused into the backbone; one
            of 'first', 'last', 'all', 'none', 'last_single'.
        embedding_dim: Output width of the optional embeddor head.
            Ignored when ``embeddor_type`` is None; overwritten with the
            backbone width when ``embeddor_type == 'none'``.
        repr_merging_method: How two embeddings are merged before the
            distance regressor: 'mul', 'concat', 'outer', or None
            (no regressor is created).
        n_regressor_outputs: Output size of the distance regressor; only
            used when ``repr_merging_method`` is one of the above.
        embeddor_type: 'linear', 'mlp', 'mlp_2hidden', 'none', or None
            (no embeddor is created).
        backbone: 'resnet18' or 'resnet50'.

    Raises:
        ValueError: If ``backbone``, ``fusion_place`` or
            ``embeddor_type`` is not one of the recognized values.
    """
    super().__init__()
    if backbone == "resnet18":
        net = torchvision.models.resnet18(pretrained=False)
        self._resnet_embedding_dim = 512
        # Input channel count of each residual stage (layer1..layer4).
        n_channels = (64, 64, 128, 256)
    elif backbone == 'resnet50':
        net = torchvision.models.resnet50(pretrained=False)
        self._resnet_embedding_dim = 2048
        n_channels = (64, 256, 512, 1024)
    else:
        raise ValueError("Invalid backbone: " + backbone)
    # Replace the stock 3-channel RGB stem conv: the input is a PHYRE
    # scene one-hot encoded over NUM_COLORS channels.
    conv1 = nn.Conv2d(phyre.NUM_COLORS,
                      64,
                      kernel_size=7,
                      stride=2,
                      padding=3,
                      bias=False)
    # Identity matrix used to one-hot-encode integer color maps; a
    # buffer so it moves with the module but is not a parameter.
    self.register_buffer('embed_weights', torch.eye(phyre.NUM_COLORS))
    self.stem = nn.Sequential(conv1, net.bn1, net.relu, net.maxpool)
    self.stages = nn.ModuleList(
        [net.layer1, net.layer2, net.layer3, net.layer4])

    def build_film(output_size):
        # One FiLM conditioning network per fused backbone stage.
        return FilmActionNetwork(action_size,
                                 output_size,
                                 hidden_size=action_hidden_size,
                                 num_layers=action_layers)

    self.last_network = None
    if fusion_place == 'all':
        self.action_networks = nn.ModuleList(
            [build_film(size) for size in n_channels])
    elif fusion_place == 'last':
        # Save module as attribute so it is registered with the module.
        self._action_network = build_film(n_channels[3])
        self.action_networks = nn.ModuleList(
            [None, None, None, self._action_network])
    elif fusion_place == 'first':
        # Save module as attribute so it is registered with the module.
        self._action_network = build_film(n_channels[0])
        self.action_networks = nn.ModuleList(
            [self._action_network, None, None, None])
    elif fusion_place == 'last_single':
        # Fuse once on the pooled final embedding instead of per stage.
        self.last_network = build_film(self._resnet_embedding_dim)
        self.action_networks = [None, None, None, None]
    elif fusion_place == 'none':
        self.action_networks = [None, None, None, None]
    else:
        # Previously an `assert` (stripped under -O) shadowed an
        # unreachable `raise Exception` here; raise ValueError for
        # consistency with the backbone check above.
        raise ValueError('Unknown fusion place: %s' % fusion_place)
    self.reason = nn.Linear(self._resnet_embedding_dim, 1)
    if embeddor_type == "linear":
        self.embeddor = nn.Linear(self._resnet_embedding_dim, embedding_dim)
    elif embeddor_type == "mlp":
        self.embeddor = nn.Sequential(
            nn.Linear(self._resnet_embedding_dim, embedding_dim), nn.ReLU(),
            nn.Linear(embedding_dim, embedding_dim))
    elif embeddor_type == "mlp_2hidden":
        self.embeddor = nn.Sequential(
            nn.Linear(self._resnet_embedding_dim, embedding_dim), nn.ReLU(),
            nn.Linear(embedding_dim, embedding_dim), nn.ReLU(),
            nn.Linear(embedding_dim, embedding_dim))
    elif embeddor_type == "none":
        # nn.Identity replaces the former `lambda x: x`: same forward
        # behavior, but the module stays picklable and registered.
        self.embeddor = nn.Identity()
        embedding_dim = self._resnet_embedding_dim
    else:
        if embeddor_type is not None:
            raise ValueError(f"unknown embeddor type {embeddor_type}")
        self.embeddor = None
    self.repr_merging_method = repr_merging_method
    self.n_regressor_outputs = n_regressor_outputs
    # Regressor input width depends on how two embeddings are merged:
    # elementwise product keeps dim, concat doubles it, outer squares it.
    if repr_merging_method == "mul":
        self.distance_regressor = nn.Linear(embedding_dim,
                                            n_regressor_outputs)
    elif repr_merging_method == "concat":
        self.distance_regressor = nn.Linear(embedding_dim * 2,
                                            n_regressor_outputs)
    elif repr_merging_method == "outer":
        self.distance_regressor = nn.Linear(embedding_dim**2,
                                            n_regressor_outputs)
    else:
        self.distance_regressor = None