in src/model.py [0:0]
def preprocess_md_util(md_embs, embedding_projection, md_dims, md_group_sizes):
    ''' Static helper to be used by base lstm model.

    Walks ``md_embs`` in iteration order and builds a new dict with the same
    keys. The i-th entry is compared against ``md_dims[i] * md_group_sizes[i]``:
    when the size of its last axis (``shape[-1]``) differs from that product,
    the entry is replaced by ``embedding_projection(entry)``; otherwise it is
    stored unchanged.

    Args:
        md_embs: mapping of transform key -> embedding; each value must
            expose ``.shape`` (presumably a tensor -- confirm with caller).
        embedding_projection: callable applied to embeddings whose last-axis
            size does not match the expected width.
        md_dims: sequence indexed by the position of the entry in ``md_embs``.
        md_group_sizes: sequence indexed the same way as ``md_dims``.

    Returns:
        A new dict mapping each key of ``md_embs`` to either the original
        embedding or its projection.

    NOTE(review): the original author warned that this check "breaks if
    md_dim == md_group_size"; the exact failure mode is not visible from
    this function alone -- confirm before relying on the width test.
    '''
    result = {}
    for position, item in enumerate(md_embs.items()):
        key, embedding = item
        expected_width = md_dims[position] * md_group_sizes[position]
        # Only embeddings whose trailing width is unexpected get projected.
        needs_projection = embedding.shape[-1] != expected_width
        result[key] = embedding_projection(embedding) if needs_projection else embedding
    return result