in src/sagemaker_sklearn_extension/contrib/taei/nn_utils.py [0:0]
def __init__(self, input_dim, cat_dims, cat_idxs, cat_emb_dim=None):
""" This is an embedding module for an entire set of features
Parameters
----------
input_dim : int
Number of features coming as input (number of columns)
cat_dims : list of int
Number of modalities for each categorial features
If the list is empty, no embeddings will be done
cat_idxs : list of int
Positional index for each categorical features in inputs
cat_emb_dim : int or list of int
Embedding dimension for each categorical features
If int, the same embdeding dimension will be used for all categorical features
"""
    super(EmbeddingGenerator, self).__init__()
    if cat_dims == [] or cat_idxs == []:
        self.skip_embedding = True
        self.post_embed_dim = input_dim
        return
    if cat_emb_dim is None:
        # rule-of-thumb embedding size from the number of modalities, capped at 600
        cat_emb_dim = [min(600, round(1.6 * n_cats ** 0.56)) for n_cats in cat_dims]
    self.skip_embedding = False
    if isinstance(cat_emb_dim, int):
        self.cat_emb_dims = [cat_emb_dim] * len(cat_idxs)
    else:
        self.cat_emb_dims = cat_emb_dim
    # check that an embedding dimension is provided for every categorical feature
    if len(self.cat_emb_dims) != len(cat_dims):
        msg = (
            "cat_emb_dim and cat_dims must be lists of same length, "
            f"got {len(self.cat_emb_dims)} and {len(cat_dims)}"
        )
        raise ValueError(msg)
    self.post_embed_dim = int(input_dim + np.sum(self.cat_emb_dims) - len(self.cat_emb_dims))
    self.embeddings = torch.nn.ModuleList()
    # Sort dims by cat_idx so embeddings line up with the column order of the input
    sorted_idxs = np.argsort(cat_idxs)
    cat_dims = [cat_dims[i] for i in sorted_idxs]
    self.cat_emb_dims = [self.cat_emb_dims[i] for i in sorted_idxs]
    for cat_dim, emb_dim in zip(cat_dims, self.cat_emb_dims):
        self.embeddings.append(torch.nn.Embedding(cat_dim, emb_dim))
    # record continuous indices (True = continuous column, False = categorical)
    self.continuous_idx = torch.ones(input_dim, dtype=torch.bool)
    self.continuous_idx[cat_idxs] = 0
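
A minimal usage sketch (not part of the original file): it assumes EmbeddingGenerator is importable from the module path shown above and only exercises the attributes set in this constructor.

from sagemaker_sklearn_extension.contrib.taei.nn_utils import EmbeddingGenerator

# Three input columns; columns 0 and 2 are categorical with 10 and 4 modalities.
# cat_emb_dim is left as None, so the heuristic picks the embedding sizes:
# min(600, round(1.6 * 10 ** 0.56)) == 6 and min(600, round(1.6 * 4 ** 0.56)) == 3.
embedder = EmbeddingGenerator(input_dim=3, cat_dims=[10, 4], cat_idxs=[0, 2])

print(embedder.post_embed_dim)  # 3 + (6 + 3) - 2 == 10
print(embedder.continuous_idx)  # tensor([False,  True, False])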