in minihack/agent/polybeast/models/intrinsic.py [0:0]
def __init__(self, observation_shape, num_actions, flags, device):
    """Build the RND target ("rndtgt_*") and predictor ("rndprd_*") networks.

    Every sub-network is created twice with the same architecture: a target
    copy with ``requires_grad_(False)`` and a predictor copy that stays
    trainable. Which sub-networks exist depends on ``self.intrinsic_input``
    ("crop_only", "glyph_only" or "full") and, for "full", on
    ``self.msg_model``.

    Args:
        observation_shape: forwarded unchanged to the parent constructor.
        num_actions: forwarded unchanged to the parent constructor.
        flags: forwarded to the parent constructor; ``glyph_type``,
            ``embedding_dim`` and ``use_index_select`` are also read here.
        device: forwarded to the parent constructor and to GlyphEmbedding.

    Raises:
        NotImplementedError: if ``self.equalize_input_dim`` is truthy, or if
            ``self.intrinsic_input`` is not one of the supported values.
    """
    super(RNDNet, self).__init__(
        observation_shape, num_actions, flags, device
    )
    # NOTE(review): the self.* attributes read below (equalize_input_dim,
    # intrinsic_input, crop_dim, H, W, num_features, k_dim, h_dim, msg_*,
    # extract_*representation, conv2_6_fc) are not set in this method;
    # presumably the parent constructor provides them.

    if self.equalize_input_dim:
        raise NotImplementedError(
            "rnd model does not support equalize_input_dim"
        )
    # Validate up front so an unsupported input type fails before any
    # module is allocated.
    if self.intrinsic_input not in ("crop_only", "glyph_only", "full"):
        raise NotImplementedError(
            "RND input type %s" % self.intrinsic_input
        )

    # number of output filters
    # NOTE(review): presumably has to match the channel count produced by
    # the copied extractors below -- confirm against the parent class.
    Y = 8

    # IMPLEMENTED HERE: RND net using the default feature extractor
    self.rndtgt_embed = GlyphEmbedding(
        flags.glyph_type,
        flags.embedding_dim,
        device,
        flags.use_index_select,
    ).requires_grad_(False)
    self.rndprd_embed = GlyphEmbedding(
        flags.glyph_type,
        flags.embedding_dim,
        device,
        flags.use_index_select,
    )

    # Running total of the flattened feature width fed to the rnd*_fc heads.
    rnd_out_dim = 0

    if self.intrinsic_input in ("crop_only", "full"):
        self.rndtgt_extract_crop_representation = copy.deepcopy(
            self.extract_crop_representation
        ).requires_grad_(False)
        self.rndprd_extract_crop_representation = copy.deepcopy(
            self.extract_crop_representation
        )
        rnd_out_dim += self.crop_dim ** 2 * Y  # crop dim

    if self.intrinsic_input in ("full", "glyph_only"):
        self.rndtgt_extract_representation = copy.deepcopy(
            self.extract_representation
        ).requires_grad_(False)
        self.rndprd_extract_representation = copy.deepcopy(
            self.extract_representation
        )
        rnd_out_dim += self.H * self.W * Y  # glyph dim

    if self.intrinsic_input == "full":
        self.rndtgt_embed_features = nn.Sequential(
            nn.Linear(self.num_features, self.k_dim),
            nn.ELU(),
            nn.Linear(self.k_dim, self.k_dim),
            nn.ELU(),
        ).requires_grad_(False)
        self.rndprd_embed_features = nn.Sequential(
            nn.Linear(self.num_features, self.k_dim),
            nn.ELU(),
            nn.Linear(self.k_dim, self.k_dim),
            nn.ELU(),
        )
        rnd_out_dim += self.k_dim  # feature dim

    if self.intrinsic_input == "full" and self.msg_model != "none":
        # we only implement the lt_cnn msg model for RND for simplicity & speed
        if self.msg_model != "lt_cnn":
            # Lazy %-args: same message, formatted only if it is emitted.
            logging.warning(
                "msg.model set to %s, but RND overriding to lt_cnn for its input--"
                "so the policy and RND are using different models for the messages",
                self.msg_model,
            )

        self.rndtgt_char_lt = nn.Embedding(
            NUM_CHARS, self.msg_edim, padding_idx=PAD_CHAR
        ).requires_grad_(False)
        self.rndprd_char_lt = nn.Embedding(
            NUM_CHARS, self.msg_edim, padding_idx=PAD_CHAR
        )

        # similar to Zhang et al, 2016
        # Character-level Convolutional Networks for Text Classification
        # https://arxiv.org/abs/1509.01626
        # replace one-hot inputs with learned embeddings
        self.rndtgt_conv1 = nn.Conv1d(
            self.msg_edim, self.msg_hdim, kernel_size=7
        ).requires_grad_(False)
        self.rndprd_conv1 = nn.Conv1d(
            self.msg_edim, self.msg_hdim, kernel_size=7
        )

        # remaining convolutions, relus, pools, and a small FC network
        self.rndtgt_conv2_6_fc = copy.deepcopy(
            self.conv2_6_fc
        ).requires_grad_(False)
        self.rndprd_conv2_6_fc = copy.deepcopy(self.conv2_6_fc)
        rnd_out_dim += self.msg_hdim

    self.rndtgt_fc = (
        nn.Sequential(  # matching RND paper making this smaller
            nn.Linear(rnd_out_dim, self.h_dim)
        ).requires_grad_(False)
    )
    self.rndprd_fc = (
        nn.Sequential(  # matching RND paper making this bigger
            nn.Linear(rnd_out_dim, self.h_dim),
            nn.ELU(),
            nn.Linear(self.h_dim, self.h_dim),
            nn.ELU(),
            nn.Linear(self.h_dim, self.h_dim),
        )
    )

    modules_to_init = [
        self.rndtgt_embed,
        self.rndprd_embed,
        self.rndtgt_fc,
        self.rndprd_fc,
    ]

    SQRT_2 = math.sqrt(2)

    def init(p):
        """Orthogonally re-initialise one Conv2d / Linear / Embedding module."""
        if isinstance(p, (nn.Conv2d, nn.Linear)):
            # init method used in paper
            nn.init.orthogonal_(p.weight, SQRT_2)
            # Layers built with bias=False expose bias as None; skip those
            # instead of crashing on attribute access.
            if p.bias is not None:
                nn.init.zeros_(p.bias)
        elif isinstance(p, nn.Embedding):
            nn.init.orthogonal_(p.weight, SQRT_2)

    # manually init all to orthogonal dist
    # NOTE(review): rnd*_char_lt and rnd*_conv1 are never added to this list,
    # and `init` ignores nn.Conv1d, so those modules keep their construction
    # time weights; any Conv1d inside the rnd*_conv2_6_fc copies would
    # likewise keep the values copied from self.conv2_6_fc. Confirm that
    # this is intentional.
    if self.intrinsic_input in ("full", "crop_only"):
        modules_to_init.append(self.rndtgt_extract_crop_representation)
        modules_to_init.append(self.rndprd_extract_crop_representation)
    if self.intrinsic_input in ("full", "glyph_only"):
        modules_to_init.append(self.rndtgt_extract_representation)
        modules_to_init.append(self.rndprd_extract_representation)
    if self.intrinsic_input == "full":
        modules_to_init.append(self.rndtgt_embed_features)
        modules_to_init.append(self.rndprd_embed_features)
        if self.msg_model != "none":
            modules_to_init.append(self.rndtgt_conv2_6_fc)
            modules_to_init.append(self.rndprd_conv2_6_fc)

    # Order is kept stable so that random-number consumption (and therefore
    # seeded initial weights) does not depend on this refactor.
    for m in modules_to_init:
        for p in m.modules():
            init(p)