in minihack/agent/polybeast/models/intrinsic.py [0:0]
def __init__(self, observation_shape, num_actions, flags, device):
    """Add the RIDE intrinsic-reward modules on top of the base network.

    After running the base-class constructor, this registers:
    forward and inverse dynamics models, a dedicated glyph embedding
    (``ride_embed``), private deep copies of the base feature extractors
    selected by ``self.intrinsic_input``, an MLP over the extra features
    (``"full"`` mode only), and a linear projection ``ride_fc`` from the
    concatenated RIDE representation to ``self.h_dim``.

    The order in which submodules are created is deliberate. Attribute
    assignment order fixes the registration/parameter order, and every
    layer's initialisation draws from the RNG. Reordering would therefore
    change state_dict ordering and the seeded initial weights.

    Args:
        observation_shape: forwarded unchanged to the base-class constructor.
        num_actions: forwarded to the base class and to both dynamics models.
        flags: configuration object; this method reads ``msg.model``,
            ``model``, ``ride.hidden_dim``, ``hidden_dim``, ``glyph_type``,
            ``embedding_dim`` and ``use_index_select``.
        device: forwarded to the base class and to ``GlyphEmbedding``.

    Raises:
        NotImplementedError: if ``flags.msg.model`` is not ``"none"``, or if
            ``self.intrinsic_input`` is not one of ``"crop_only"``,
            ``"glyph_only"`` or ``"full"``.
    """
    super(RIDENet, self).__init__(
        observation_shape, num_actions, flags, device
    )
    if flags.msg.model != "none":
        raise NotImplementedError(
            "model=%s + msg.model=%s" % (flags.model, flags.msg.model)
        )

    # Both dynamics models are built from the same positional arguments.
    dynamics_args = (
        num_actions,
        flags.ride.hidden_dim,
        flags.hidden_dim,
        flags.hidden_dim,
    )
    self.forward_dynamics_model = ForwardDynamicsNet(*dynamics_args)
    self.inverse_dynamics_model = InverseDynamicsNet(*dynamics_args)

    # Number of output filters of the copied conv extractors.
    # NOTE(review): hard-coded here; assumed to match the output channels of
    # the base class's extractors -- confirm against the base class.
    n_out_filters = 8

    # RIDE gets its own glyph embedding, separate from the policy's.
    self.ride_embed = GlyphEmbedding(
        flags.glyph_type,
        flags.embedding_dim,
        device,
        flags.use_index_select,
    )

    mode = self.intrinsic_input
    if mode not in ("crop_only", "glyph_only", "full"):
        raise NotImplementedError("RIDE input type %s" % mode)
    uses_crop = mode in ("crop_only", "full")
    uses_glyph = mode in ("full", "glyph_only")

    # Running width of the concatenated representation fed into ride_fc.
    ride_in_dim = 0
    if uses_crop:
        self.ride_extract_crop_representation = copy.deepcopy(
            self.extract_crop_representation
        )
        ride_in_dim += self.crop_dim ** 2 * n_out_filters
    if uses_glyph:
        self.ride_extract_representation = copy.deepcopy(
            self.extract_representation
        )
        ride_in_dim += self.H * self.W * n_out_filters
    if mode == "full":
        self.ride_embed_features = nn.Sequential(
            nn.Linear(self.num_features, self.k_dim),
            nn.ELU(),
            nn.Linear(self.k_dim, self.k_dim),
            nn.ELU(),
        )
        ride_in_dim += self.k_dim

    # A single linear layer. It stays wrapped in nn.Sequential so that its
    # parameters remain registered under the ``ride_fc.0.*`` names.
    self.ride_fc = nn.Sequential(nn.Linear(ride_in_dim, self.h_dim))

    # Replace the weights the deep copies inherited from the policy's
    # extractors with a fresh initialisation. Only nn.Conv2d layers are
    # reset; any other parametrised layer in the copies keeps the copied
    # values. This is done last on purpose: reset_parameters() draws from the
    # RNG, so doing it earlier would shift the seeded init of the layers above.
    copied_extractors = []
    if uses_crop:
        copied_extractors.append(self.ride_extract_crop_representation)
    if uses_glyph:
        copied_extractors.append(self.ride_extract_representation)
    conv_layers = [
        layer
        for extractor in copied_extractors
        for layer in extractor.modules()
        if isinstance(layer, nn.Conv2d)
    ]
    for conv in conv_layers:
        conv.reset_parameters()