in src/model.py [0:0]
def __init__(self, metadata_constructor_params, dimension_params):
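    """Construct the metadata encoder.

    Builds optional per-group attention over the metadata embeddings, projects
    each group to md_projection_dim, combines the groups either hierarchically
    (via another attention module) or by concatenation, and finally projects
    the combined embedding to context_dim followed by layer normalization.
    """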
    super().__init__()
    self.md_projection_dim = metadata_constructor_params["md_projection_dim"]
    self.md_dims = metadata_constructor_params["md_dims"]
    self.md_group_sizes = metadata_constructor_params["md_group_sizes"]
    self.context_dim = dimension_params["context_dim"]
    # Optional attention mechanism used to aggregate each metadata group;
    # if falsy, no per-group attention is applied.
    self.attention_mechanism = metadata_constructor_params["attention_mechanism"]
    if self.attention_mechanism:
        self.query_type = metadata_constructor_params["query_type"]
        assert self.attention_mechanism in ATTENTION_MAP, \
            f"Invalid attention type: {self.attention_mechanism}"
        assert self.query_type in ("word", "hidden"), \
            f"Invalid query type: {self.query_type}"
        query_dim = self.get_query_dim(dimension_params)
        self.use_null_token = metadata_constructor_params["use_null_token"]
        # Use nn.ModuleList so the attention modules are registered as
        # submodules and their parameters are tracked by the parent module.
        self.attention_modules = nn.ModuleList()
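        # Build one attention module per metadata group (one group per entry
        # in md_dims / md_group_sizes).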
        for md_dim, md_group_size in zip(self.md_dims, self.md_group_sizes):
            attention_module = ATTENTION_MAP[self.attention_mechanism](
                md_dim, query_dim, md_group_size, self.use_null_token).to(device)
            self.attention_modules.append(attention_module)
    # After the attention modules, the resulting metadata embeddings are
    # projected to size md_projection_dim.
    self.projection_layers = nn.ModuleList()
    for md_dim in self.md_dims:
        projection = nn.Linear(md_dim, self.md_projection_dim).to(device)
        self.projection_layers.append(projection)
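    # Note: the projection layers are applied whether or not per-group
    # attention is used; the size computation below accounts for both cases.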
    # The per-group metadata embeddings can now be combined either via another
    # attention mechanism (enabled by the "hierarchical_attention" bool
    # parameter) or by simply concatenating them together.
    self.use_hierarchical_attention = metadata_constructor_params["hierarchical_attention"]
    if self.use_hierarchical_attention:
        num_attention_groups = len(self.md_dims)
        query_dim = self.get_query_dim(dimension_params)
        # NOTE: we use the same query embedding in this attention module
        # as in the previous attention modules.
        self.hierarchical_attention_module = ATTENTION_MAP[self.attention_mechanism](
            self.md_projection_dim, query_dim, num_attention_groups,
            self.use_null_token).to(device)
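        # Hierarchical attention yields a single combined embedding, so the
        # input to the final projection has size md_projection_dim.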
        context_projection_input_dim = self.md_projection_dim
    else:
        # If the metadata is not combined hierarchically, all of the metadata
        # embeddings are instead concatenated together. Compute the resulting
        # size of the concatenated embedding: with attention, each group
        # contributes a single md_projection_dim vector; without it, each
        # group contributes md_group_size projected vectors.
        context_projection_input_dim = 0
        for md_group_size in self.md_group_sizes:
            if self.attention_mechanism:
                context_projection_input_dim += self.md_projection_dim
            else:
                context_projection_input_dim += self.md_projection_dim * md_group_size
    # Finally, the metadata embedding (either concatenated or combined via
    # attention) is projected to size context_dim and layer-normalized.
    self.context_projection = nn.Linear(context_projection_input_dim,
                                        self.context_dim).to(device)
    self.context_normalization = nn.LayerNorm(self.context_dim)
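    # Illustrative example of the expected inputs (values are hypothetical and
    # depend on the configuration actually used; the attention key must match
    # an entry in ATTENTION_MAP):
    # metadata_constructor_params = {
    #     "md_projection_dim": 64,
    #     "md_dims": [32, 16],
    #     "md_group_sizes": [4, 2],
    #     "attention_mechanism": "<key in ATTENTION_MAP>",
    #     "query_type": "hidden",            # "word" or "hidden"
    #     "use_null_token": True,
    #     "hierarchical_attention": False,
    # }
    # dimension_params = {"context_dim": 128, ...}  # also consumed by get_query_dim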