def preprocess_md_util()

in src/model.py [0:0]


    def preprocess_md_util(md_embs, embedding_projection, md_dims, md_group_sizes):
        '''Static helper used by the base LSTM model to normalise metadata embeddings.

        Each entry of ``md_embs`` is paired, by its position in the mapping's
        iteration order, with ``md_dims[position]`` and
        ``md_group_sizes[position]``. An entry whose last dimension
        (``.shape[-1]``) already equals ``md_dim * md_group_size`` is kept
        as-is. Any other entry is passed through ``embedding_projection``.

        Args:
            md_embs: mapping of transform name -> array/tensor-like object
                exposing ``.shape``. Iteration order determines which
                ``md_dims`` / ``md_group_sizes`` element is used.
            embedding_projection: callable applied to entries whose width does
                not match the expected flattened size.
            md_dims: indexable sequence of per-entry dimensions.
            md_group_sizes: indexable sequence of per-entry group sizes.

        Returns:
            A new dict with the same keys, in the same order, holding either the
            original object or its projection.

        Raises:
            IndexError: if ``md_dims`` or ``md_group_sizes`` has fewer elements
                than ``md_embs`` (for list-like inputs).
        '''
        def _maybe_project(position, tensor):
            # Look up the per-position sizes before touching the tensor, as the
            # caller-visible lookup order is dims, then group sizes, then shape.
            dim = md_dims[position]
            group = md_group_sizes[position]
            # NOTE(review): the original author noted this check "breaks if
            # md_dim == md_group_size"; the reason is not visible here. Also, a
            # raw embedding whose width happens to equal dim * group is passed
            # through without projection.
            if tensor.shape[-1] != dim * group:
                return embedding_projection(tensor)
            return tensor

        return {
            transform_name: _maybe_project(position, tensor)
            for position, (transform_name, tensor) in enumerate(md_embs.items())
        }