def __init__()

in minihack/agent/polybeast/models/base.py


    def __init__(self, observation_shape, num_actions, flags, device):
        super(BaseNet, self).__init__()

        self.flags = flags

        self.observation_shape = observation_shape

        self.H = observation_shape[0]
        self.W = observation_shape[1]

        self.num_actions = num_actions
        self.use_lstm = flags.use_lstm

        self.k_dim = flags.embedding_dim
        self.h_dim = flags.hidden_dim

        self.crop_model = flags.crop_model
        self.crop_dim = flags.crop_dim

        self.num_features = NUM_FEATURES

        self.crop = Crop(self.H, self.W, self.crop_dim, self.crop_dim, device)
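        # Crop extracts a crop_dim x crop_dim window of the glyph map
        # around coordinates provided at forward time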

        self.glyph_type = flags.glyph_type
        self.glyph_embedding = GlyphEmbedding(
            flags.glyph_type,
            flags.embedding_dim,
            device,
            flags.use_index_select,
        )

        K = flags.embedding_dim  # number of input filters
        F = 3  # filter dimensions
        S = 1  # stride
        P = 1  # padding
        M = 16  # number of intermediate filters
        self.Y = 8  # number of output filters
        L = flags.layers  # number of convnet layers

        in_channels = [K] + [M] * (L - 1)
        out_channels = [M] * (L - 1) + [self.Y]

        def interleave(xs, ys):
            return [val for pair in zip(xs, ys) for val in pair]
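
        # e.g. interleave([c1, c2], [e1, e2]) == [c1, e1, c2, e2],
        # used to alternate each conv layer with an ELU activation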

        conv_extract = [
            nn.Conv2d(
                in_channels=in_channels[i],
                out_channels=out_channels[i],
                kernel_size=(F, F),
                stride=S,
                padding=P,
            )
            for i in range(L)
        ]

        self.extract_representation = nn.Sequential(
            *interleave(conv_extract, [nn.ELU()] * len(conv_extract))
        )
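        # with F=3, S=1, P=1 every conv preserves the H x W spatial size,
        # so the flattened map representation has H * W * Y features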

        if self.crop_model == "transformer":
            self.extract_crop_representation = TransformerEncoder(
                K,
                N=L,
                heads=8,
                height=self.crop_dim,
                width=self.crop_dim,
                device=device,
            )
        elif self.crop_model == "cnn":
            conv_extract_crop = [
                nn.Conv2d(
                    in_channels=in_channels[i],
                    out_channels=out_channels[i],
                    kernel_size=(F, F),
                    stride=S,
                    padding=P,
                )
                for i in range(L)
            ]

            self.extract_crop_representation = nn.Sequential(
                *interleave(conv_extract_crop, [nn.ELU()] * len(conv_extract_crop))
            )
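
        # the transformer path keeps K channels per crop cell, while the
        # CNN path reduces to Y channels; the out_dim accounting below
        # reflects this difference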

        # MESSAGING MODEL
        if "msg" not in flags:
            self.msg_model = "none"
        else:
            self.msg_model = flags.msg.model
            self.msg_hdim = flags.msg.hidden_dim
            self.msg_edim = flags.msg.embedding_dim
        if self.msg_model in ("gru", "lstm", "lt_cnn"):
            # character-based embeddings
            self.char_lt = nn.Embedding(
                NUM_CHARS, self.msg_edim, padding_idx=PAD_CHAR
            )
        else:
            # forward will set up one-hot inputs for the cnn;
            # no lookup table (lt) needed
            pass

        if self.msg_model.endswith("cnn"):
            # from Zhang et al, 2016
            # Character-level Convolutional Networks for Text Classification
            # https://arxiv.org/abs/1509.01626
            if self.msg_model == "cnn":
                # inputs will be one-hot vectors, as done in paper
                self.conv1 = nn.Conv1d(NUM_CHARS, self.msg_hdim, kernel_size=7)
            elif self.msg_model == "lt_cnn":
                # replace one-hot inputs with learned embeddings
                self.conv1 = nn.Conv1d(
                    self.msg_edim, self.msg_hdim, kernel_size=7
                )
            else:
                raise NotImplementedError("msg.model == %s" % flags.msg.model)

            # remaining convolutions, relus, pools, and a small FC network
            self.conv2_6_fc = nn.Sequential(
                nn.ReLU(),
                nn.MaxPool1d(kernel_size=3, stride=3),
                # conv2
                nn.Conv1d(self.msg_hdim, self.msg_hdim, kernel_size=7),
                nn.ReLU(),
                nn.MaxPool1d(kernel_size=3, stride=3),
                # conv3
                nn.Conv1d(self.msg_hdim, self.msg_hdim, kernel_size=3),
                nn.ReLU(),
                # conv4
                nn.Conv1d(self.msg_hdim, self.msg_hdim, kernel_size=3),
                nn.ReLU(),
                # conv5
                nn.Conv1d(self.msg_hdim, self.msg_hdim, kernel_size=3),
                nn.ReLU(),
                # conv6
                nn.Conv1d(self.msg_hdim, self.msg_hdim, kernel_size=3),
                nn.ReLU(),
                nn.MaxPool1d(kernel_size=3, stride=3),
                # fc receives -- [ B x h_dim x 5 ]
                Flatten(),
                nn.Linear(5 * self.msg_hdim, 2 * self.msg_hdim),
                nn.ReLU(),
                nn.Linear(2 * self.msg_hdim, self.msg_hdim),
            )  # final output -- [ B x h_dim ]
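            # sequence-length check (assuming the standard 256-char message):
            # conv1 (k=7): 256 -> 250, pool/3 -> 83; conv2 (k=7) -> 77,
            # pool/3 -> 25; conv3..conv6 (k=3) -> 23, 21, 19, 17; pool/3 -> 5,
            # matching the [ B x h_dim x 5 ] tensor flattened for the fc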
        elif self.msg_model in ("gru", "lstm"):

            def rnn(flag):
                return nn.LSTM if flag == "lstm" else nn.GRU

            self.char_rnn = rnn(self.msg_model)(
                self.msg_edim,
                self.msg_hdim // 2,
                batch_first=True,
                bidirectional=True,
            )
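            # each direction has msg_hdim // 2 hidden units, so the
            # concatenated bidirectional output is msg_hdim-dimensional
            # (msg_hdim is assumed even)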
        elif self.msg_model != "none":
            raise NotImplementedError("msg.model == %s" % flags.msg.model)

        self.embed_features = nn.Sequential(
            nn.Linear(self.num_features, self.k_dim),
            nn.ReLU(),
            nn.Linear(self.k_dim, self.k_dim),
            nn.ReLU(),
        )
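        # the two-layer MLP above embeds the NUM_FEATURES-dimensional
        # scalar feature vector into k_dim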

        self.equalize_input_dim = flags.equalize_input_dim
        if not self.equalize_input_dim:
            # simply add up the output dimensions of the input featurizers
            # feature / status dim
            out_dim = self.k_dim
            # CNN over full glyph map
            out_dim += self.H * self.W * self.Y
            if self.crop_model == "transformer":
                out_dim += self.crop_dim ** 2 * K
            elif self.crop_model == "cnn":
                out_dim += self.crop_dim ** 2 * self.Y
            # messaging model
            if self.msg_model != "none":
                out_dim += self.msg_hdim
        else:
            # otherwise, project them all to h_dim
            NUM_INPUTS = 4 if self.msg_model != "none" else 3
            project_hdim = flags.equalize_factor * self.h_dim
            out_dim = project_hdim * NUM_INPUTS
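            # e.g. with h_dim = 256 and equalize_factor = 2, each input is
            # projected to 512 dims and out_dim = 512 * NUM_INPUTS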

            # set up linear layers for projections
            self.project_feature_dim = nn.Linear(self.k_dim, project_hdim)
            self.project_glyph_dim = nn.Linear(
                self.H * self.W * self.Y, project_hdim
            )
            c__2 = self.crop_dim ** 2
            if self.crop_model == "transformer":
                self.project_crop_dim = nn.Linear(c__2 * K, project_hdim)
            elif self.crop_model == "cnn":
                self.project_crop_dim = nn.Linear(c__2 * self.Y, project_hdim)
            if self.msg_model != "none":
                self.project_msg_dim = nn.Linear(self.msg_hdim, project_hdim)

        self.fc = nn.Sequential(
            nn.Linear(out_dim, self.h_dim),
            nn.ReLU(),
            nn.Linear(self.h_dim, self.h_dim),
            nn.ReLU(),
        )

        if self.use_lstm:
            self.core = nn.LSTM(self.h_dim, self.h_dim, num_layers=1)

        self.policy = nn.Linear(self.h_dim, self.num_actions)
        self.baseline = nn.Linear(self.h_dim, 1)
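
A minimal construction sketch, not from the repository: the flag names below are exactly those read in __init__, but every value is illustrative, the OmegaConf wrapper is an assumption about how the surrounding agent builds its flags, and the glyph_type and action-count values are hypothetical.

    import torch
    from omegaconf import OmegaConf

    from minihack.agent.polybeast.models.base import BaseNet

    # illustrative flag values, not the repo defaults
    flags = OmegaConf.create(
        {
            "use_lstm": True,
            "embedding_dim": 64,
            "hidden_dim": 256,
            "crop_model": "cnn",
            "crop_dim": 9,
            "layers": 5,
            "use_index_select": True,
            "glyph_type": "all",  # hypothetical value
            "equalize_input_dim": False,
            "equalize_factor": 2,
            "msg": {"model": "lt_cnn", "hidden_dim": 64, "embedding_dim": 32},
        }
    )

    # 21 x 79 is the standard NetHack dungeon map; num_actions is illustrative
    net = BaseNet(
        (21, 79), num_actions=23, flags=flags, device=torch.device("cpu")
    )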