in text/src/autogluon/text/text_prediction/mx/modules.py [0:0]
def __init__(self, text_backbone,
             num_text_features,
             num_categorical_features,
             num_numerical_features,
             numerical_input_units,
             num_categories,
             out_shape,
             cfg=None,
             get_embedding=False,
             prefix=None,
             params=None):
    """Build a multimodal network that fuses text, categorical and numerical features.

    Each modality is projected to a common `base_feature_units` width and the
    resulting per-field features are combined by a `FeatureAggregator` head.

    Parameters
    ----------
    text_backbone
        Backbone network for handling the text data.
    num_text_features
        Number of text features.
        Each text feature will have (text_token_ids, valid_length).
    num_categorical_features
        Number of categorical features.
    num_numerical_features
        Number of numerical features.
    numerical_input_units
        The number of units for each numerical column. A scalar is broadcast
        to all numerical columns; None is treated as "no numerical columns".
    num_categories
        The number of categories for each categorical column.
        Must have length == num_categorical_features when the latter is > 0.
    out_shape
        Shape of the output. A scalar is wrapped into a 1-element tuple.
    cfg
        The configuration of the network; merged on top of the class defaults.
    get_embedding
        Whether to output the aggregated intermediate embedding from the network
        instead of (in addition to) the prediction head output.
    prefix
        Gluon Block name prefix, forwarded to the parent constructor.
    params
        Gluon ParameterDict to share, forwarded to the parent constructor.
    """
    super().__init__(prefix=prefix, params=params)
    # Merge the user-supplied cfg on top of the class default configuration.
    self.cfg = cfg = MultiModalWithPretrainedTextNN.get_cfg().clone_merge(cfg)
    # Only [CLS]-token pooling of the text backbone output is supported.
    assert self.cfg.text_net.pool_type == 'cls'
    base_feature_units = self.cfg.base_feature_units
    if not isinstance(out_shape, (list, tuple)):
        out_shape = (out_shape,)
    self.out_shape = out_shape
    # -1 means "inherit the projection width from the text backbone".
    if base_feature_units == -1:
        base_feature_units = text_backbone.units
    self.get_embedding = get_embedding
    self.num_text_features = num_text_features
    self.num_categorical_features = num_categorical_features
    self.num_numerical_features = num_numerical_features
    if numerical_input_units is None:
        numerical_input_units = []
    elif not isinstance(numerical_input_units, (list, tuple)):
        # Broadcast a single scalar width to every numerical column.
        numerical_input_units = [numerical_input_units] * self.num_numerical_features
    self.numerical_input_units = numerical_input_units
    self.num_categories = num_categories
    if self.num_categorical_features > 0:
        assert len(self.num_categories) == self.num_categorical_features
    # Initializers are built from (name, *args) specs in the config.
    weight_initializer = mx.init.create(*cfg.initializer.weight)
    bias_initializer = mx.init.create(*cfg.initializer.bias)
    self.agg_type = cfg.agg_net.agg_type
    if self.agg_type == 'attention_token':
        assert self.num_text_features == 1, \
            'Only supports a single text input if use token-level attention'
    # NOTE(review): child blocks below are created inside name_scope(); their
    # creation order determines Gluon parameter names, so the order must not
    # be changed without invalidating existing checkpoints.
    with self.name_scope():
        self.text_backbone = text_backbone
        if base_feature_units != text_backbone.units:
            # Project each text feature from backbone width to the shared width.
            self.text_proj = nn.HybridSequential()
            for i in range(self.num_text_features):
                with self.text_proj.name_scope():
                    self.text_proj.add(nn.Dense(in_units=text_backbone.units,
                                                units=base_feature_units,
                                                use_bias=False,
                                                weight_initializer=weight_initializer,
                                                bias_initializer=bias_initializer,
                                                flatten=False))
        else:
            # Widths already match; no projection needed.
            self.text_proj = None
        if self.num_categorical_features > 0:
            # One embedding-style sub-network per categorical column.
            self.categorical_networks = nn.HybridSequential()
            for i in range(self.num_categorical_features):
                with self.categorical_networks.name_scope():
                    self.categorical_networks.add(
                        CategoricalFeatureNet(num_class=self.num_categories[i],
                                              out_units=base_feature_units,
                                              cfg=cfg.categorical_net))
        else:
            self.categorical_networks = None
        if self.cfg.aggregate_categorical and self.num_categorical_features > 1:
            # Use another dense layer to aggregate the categorical features
            self.categorical_agg = BasicMLP(
                in_units=base_feature_units * self.num_categorical_features,
                mid_units=cfg.categorical_agg.mid_units,
                out_units=base_feature_units,
                activation=cfg.categorical_agg.activation,
                dropout=cfg.categorical_agg.dropout,
                num_layers=cfg.categorical_agg.num_layers,
                weight_initializer=weight_initializer,
                bias_initializer=bias_initializer
            )
            if self.cfg.categorical_agg.gated_activation:
                # Parallel MLP producing a gate over the aggregated categorical
                # features (same shape as categorical_agg's output).
                self.categorical_agg_gate = BasicMLP(
                    in_units=base_feature_units * self.num_categorical_features,
                    mid_units=cfg.categorical_agg.mid_units,
                    out_units=base_feature_units,
                    activation=cfg.categorical_agg.activation,
                    dropout=cfg.categorical_agg.dropout,
                    num_layers=cfg.categorical_agg.num_layers,
                    weight_initializer=weight_initializer,
                    bias_initializer=bias_initializer
                )
            else:
                self.categorical_agg_gate = None
        else:
            self.categorical_agg = None
            self.categorical_agg_gate = None
        if self.num_numerical_features > 0:
            # One sub-network per numerical column, sized by its input units.
            self.numerical_networks = nn.HybridSequential()
            for i in range(self.num_numerical_features):
                with self.numerical_networks.name_scope():
                    self.numerical_networks.add(
                        NumericalFeatureNet(input_shape=self.numerical_input_units[i],
                                            out_units=base_feature_units,
                                            cfg=cfg.numerical_net))
        else:
            self.numerical_networks = None
        # self.num_fields is not set here — presumably a property defined
        # elsewhere on the class (total field count across modalities); verify.
        self.agg_layer = FeatureAggregator(num_fields=self.num_fields,
                                           out_shape=out_shape,
                                           in_units=base_feature_units,
                                           cfg=cfg.agg_net,
                                           get_embedding=get_embedding)