def __init__()

in minihack/agent/polybeast/models/intrinsic.py


    def __init__(self, observation_shape, num_actions, flags, device):
        super(RNDNet, self).__init__(
            observation_shape, num_actions, flags, device
        )

        if self.equalize_input_dim:
            raise NotImplementedError(
                "rnd model does not support equalize_input_dim"
            )

        Y = 8  # number of output filters

        # IMPLEMENTED HERE: RND net using the default feature extractor
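        # the target modules (rndtgt_*) are frozen at their random
        # initialization, while the predictor modules (rndprd_*) stay trainable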
        self.rndtgt_embed = GlyphEmbedding(
            flags.glyph_type,
            flags.embedding_dim,
            device,
            flags.use_index_select,
        ).requires_grad_(False)
        self.rndprd_embed = GlyphEmbedding(
            flags.glyph_type,
            flags.embedding_dim,
            device,
            flags.use_index_select,
        )

        if self.intrinsic_input not in ("crop_only", "glyph_only", "full"):
            raise NotImplementedError(
                "RND input type %s" % self.intrinsic_input
            )

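        # rnd_out_dim accumulates the flattened feature size of each enabled
        # input branch; it becomes the input width of the final rnd*_fc heads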
        rnd_out_dim = 0
        if self.intrinsic_input in ("crop_only", "full"):
            self.rndtgt_extract_crop_representation = copy.deepcopy(
                self.extract_crop_representation
            ).requires_grad_(False)
            self.rndprd_extract_crop_representation = copy.deepcopy(
                self.extract_crop_representation
            )

            rnd_out_dim += self.crop_dim ** 2 * Y  # crop dim

        if self.intrinsic_input in ("full", "glyph_only"):
            self.rndtgt_extract_representation = copy.deepcopy(
                self.extract_representation
            ).requires_grad_(False)
            self.rndprd_extract_representation = copy.deepcopy(
                self.extract_representation
            )
            rnd_out_dim += self.H * self.W * Y  # glyph dim

            if self.intrinsic_input == "full":
                self.rndtgt_embed_features = nn.Sequential(
                    nn.Linear(self.num_features, self.k_dim),
                    nn.ELU(),
                    nn.Linear(self.k_dim, self.k_dim),
                    nn.ELU(),
                ).requires_grad_(False)
                self.rndprd_embed_features = nn.Sequential(
                    nn.Linear(self.num_features, self.k_dim),
                    nn.ELU(),
                    nn.Linear(self.k_dim, self.k_dim),
                    nn.ELU(),
                )
                rnd_out_dim += self.k_dim  # feature dim

        if self.intrinsic_input == "full" and self.msg_model != "none":
            # we only implement the lt_cnn msg model for RND for simplicity & speed
            if self.msg_model != "lt_cnn":
                logging.warning(
                    "msg.model is set to %s, but RND overrides this to lt_cnn "
                    "for its input, so the policy and RND use different models "
                    "for the messages" % self.msg_model
                )

            self.rndtgt_char_lt = nn.Embedding(
                NUM_CHARS, self.msg_edim, padding_idx=PAD_CHAR
            ).requires_grad_(False)
            self.rndprd_char_lt = nn.Embedding(
                NUM_CHARS, self.msg_edim, padding_idx=PAD_CHAR
            )

            # similar to Zhang et al., 2015
            # Character-level Convolutional Networks for Text Classification
            # https://arxiv.org/abs/1509.01626
            # replace one-hot inputs with learned embeddings
            self.rndtgt_conv1 = nn.Conv1d(
                self.msg_edim, self.msg_hdim, kernel_size=7
            ).requires_grad_(False)
            self.rndprd_conv1 = nn.Conv1d(
                self.msg_edim, self.msg_hdim, kernel_size=7
            )

            # remaining convolutions, relus, pools, and a small FC network
            self.rndtgt_conv2_6_fc = copy.deepcopy(
                self.conv2_6_fc
            ).requires_grad_(False)
            self.rndprd_conv2_6_fc = copy.deepcopy(self.conv2_6_fc)
            rnd_out_dim += self.msg_hdim

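        # final heads: both map the concatenated RND features (rnd_out_dim)
        # down to an h_dim embedding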
        self.rndtgt_fc = (
            nn.Sequential(  # following the RND paper, the fixed target head is kept shallow
                nn.Linear(rnd_out_dim, self.h_dim)
            ).requires_grad_(False)
        )
        self.rndprd_fc = (
            nn.Sequential(  # following the RND paper, the trainable predictor head is deeper
                nn.Linear(rnd_out_dim, self.h_dim),
                nn.ELU(),
                nn.Linear(self.h_dim, self.h_dim),
                nn.ELU(),
                nn.Linear(self.h_dim, self.h_dim),
            )
        )

        modules_to_init = [
            self.rndtgt_embed,
            self.rndprd_embed,
            self.rndtgt_fc,
            self.rndprd_fc,
        ]

        SQRT_2 = math.sqrt(2)

        def init(p):
            if isinstance(p, nn.Conv2d) or isinstance(p, nn.Linear):
                # init method used in paper
                nn.init.orthogonal_(p.weight, SQRT_2)
                p.bias.data.zero_()
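            # nn.Embedding has no bias term, so only its weight is re-initialized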
            if isinstance(p, nn.Embedding):
                nn.init.orthogonal_(p.weight, SQRT_2)

        # collect the remaining RND modules, then manually initialize all of
        # their weights orthogonally (as in the RND paper)

        if self.intrinsic_input in ("full", "crop_only"):
            modules_to_init.append(self.rndtgt_extract_crop_representation)
            modules_to_init.append(self.rndprd_extract_crop_representation)
        if self.intrinsic_input in ("full", "glyph_only"):
            modules_to_init.append(self.rndtgt_extract_representation)
            modules_to_init.append(self.rndprd_extract_representation)
        if self.intrinsic_input == "full":
            modules_to_init.append(self.rndtgt_embed_features)
            modules_to_init.append(self.rndprd_embed_features)
            if self.msg_model != "none":
                modules_to_init.append(self.rndtgt_conv2_6_fc)
                modules_to_init.append(self.rndprd_conv2_6_fc)

        for m in modules_to_init:
            for p in m.modules():
                init(p)
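
This constructor only builds the frozen target (rndtgt_*) and trainable predictor (rndprd_*) networks; the forward pass and reward computation live elsewhere in intrinsic.py and are not shown here. As a rough, illustrative sketch of how such a pair is typically used in Random Network Distillation, the intrinsic reward is the predictor's error against the frozen target. The function below is an assumption for illustration, not the repository's API:

import torch.nn.functional as F

def rnd_intrinsic_reward(tgt_out, prd_out):
    # hypothetical helper, not part of intrinsic.py
    # tgt_out: features from the frozen rndtgt_* branch (no gradients flow here)
    # prd_out: features from the trainable rndprd_* branch
    # per-sample mean-squared error; the same quantity is minimized as the
    # predictor's training loss, so familiar observations yield a small bonus
    return F.mse_loss(prd_out, tgt_out.detach(), reduction="none").mean(dim=-1)

Because the rndtgt_* modules are created with requires_grad_(False) and never updated, the prediction error stays large for observations the predictor has rarely fit, which is what makes it usable as an exploration bonus (Burda et al., 2018, Exploration by Random Network Distillation).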