def __init__()

in src/model.py [0:0]
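
The constructor below references several module-level names (nn, device, ATTENTION_MAP) that are defined elsewhere in src/model.py. A minimal sketch of that assumed context, with the placeholder attention class and the "dot_product" key invented purely for illustration (only the four-argument constructor signature is taken from the call sites below):

    import torch
    import torch.nn as nn

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    class PlaceholderAttention(nn.Module):
        # Hypothetical stand-in for the repo's real attention classes; only the
        # constructor signature used by __init__ below is assumed.
        def __init__(self, md_dim, query_dim, group_size, use_null_token):
            super().__init__()

    ATTENTION_MAP = {"dot_product": PlaceholderAttention}  # real keys/classes live in the repo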


    def __init__(self, metadata_constructor_params, dimension_params):

        super().__init__()

        self.md_projection_dim = metadata_constructor_params["md_projection_dim"]
        self.md_dims = metadata_constructor_params["md_dims"]
        self.md_group_sizes = metadata_constructor_params["md_group_sizes"]
        self.context_dim = dimension_params["context_dim"]

        # Attention mechanism used to combine the metadata within each group;
        # a falsy value disables attention
        self.attention_mechanism = metadata_constructor_params["attention_mechanism"]

        if self.attention_mechanism:
            self.query_type = metadata_constructor_params["query_type"]

            assert self.attention_mechanism in ATTENTION_MAP, \
                f"Invalid attention type: {self.attention_mechanism}"
            assert self.query_type in ("word", "hidden"), \
                f"Invalid query type: {self.query_type}"

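            # query_dim is computed from dimension_params by get_query_dim,
            # which is defined elsewhere in this class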
            query_dim = self.get_query_dim(dimension_params)

            self.use_null_token = metadata_constructor_params["use_null_token"]

            # One attention module per metadata group; nn.ModuleList (rather than a
            # plain list) registers the modules' parameters with this module
            self.attention_modules = nn.ModuleList()
            for md_dim, md_group_size in zip(self.md_dims, self.md_group_sizes):
                attention_module = ATTENTION_MAP[self.attention_mechanism](md_dim,
                                                                           query_dim,
                                                                           md_group_size,
                                                                           self.use_null_token).to(device)
                self.attention_modules.append(attention_module)

        # After the attention modules, the resulting metadata embeddings are
        # projected to size md_projection_dim
        self.projection_layers = nn.ModuleList()
        for md_dim in self.md_dims:
            projection = nn.Linear(md_dim, self.md_projection_dim).to(device)
            self.projection_layers.append(projection)

        # The resulting metadata embeddings can now be combined either via another
        # attention mechanism (enabled by the "hierarchical_attention" bool parameter)
        # or via a simple concatenation of the metadata embeddings
        self.use_hierarchical_attention = metadata_constructor_params["hierarchical_attention"]
        if self.use_hierarchical_attention:
            num_attention_groups = len(self.md_dims)
            query_dim = self.get_query_dim(dimension_params)
            # NOTE: we use the same query embedding in the attention module
            # as in the previous attention modules
            self.hierarchical_attention_module = ATTENTION_MAP[self.attention_mechanism](
                self.md_projection_dim, query_dim, num_attention_groups, self.use_null_token
            ).to(device)
            context_projection_input_dim = self.md_projection_dim
        else:
            # If the metadata is not combined hierarchically, it is instead
            # concatenated together. Compute the resulting size of the
            # concatenated embedding
            context_projection_input_dim = 0
            for md_group_size in self.md_group_sizes:
                if self.attention_mechanism:
                    context_projection_input_dim += self.md_projection_dim
                else:
                    context_projection_input_dim += self.md_projection_dim * md_group_size

        # Finally, the metadata embedding (either concatenated or combined via attention)
        # is projected to size context_dim
        self.context_projection = nn.Linear(context_projection_input_dim,
                                            self.context_dim).to(device)
        self.context_normalization = nn.LayerNorm(self.context_dim).to(device)
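
A usage sketch with hypothetical parameter values, illustrating the two dictionaries the constructor expects (the keys are taken from the code above; the values, the "dot_product" mechanism name, and the class name in the final comment are assumptions):

    metadata_constructor_params = {
        "md_projection_dim": 32,               # each metadata group is projected to this size
        "md_dims": [16, 24],                   # raw embedding dim of each metadata group
        "md_group_sizes": [4, 2],              # number of metadata items in each group
        "attention_mechanism": "dot_product",  # key into ATTENTION_MAP; falsy disables attention
        "query_type": "hidden",                # must be "word" or "hidden"
        "use_null_token": True,
        "hierarchical_attention": False,       # False -> concatenate the projected groups
    }
    dimension_params = {"context_dim": 128}    # get_query_dim may read further keys not shown here

    # With attention enabled and no hierarchical attention, the concatenated input to
    # context_projection has size md_projection_dim * len(md_dims) = 32 * 2 = 64.
    # model = MetadataEncoder(metadata_constructor_params, dimension_params)  # class name hypothetical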