# Excerpt: __init__ of MultiModalWithPretrainedTextNN
# from text/src/autogluon/text/text_prediction/mx/modules.py


    def __init__(self, text_backbone,
                 num_text_features,
                 num_categorical_features,
                 num_numerical_features,
                 numerical_input_units,
                 num_categories,
                 out_shape,
                 cfg=None,
                 get_embedding=False,
                 prefix=None,
                 params=None):
        """Build a network that fuses text, categorical and numerical features.

        Each modality is projected to ``base_feature_units`` and the resulting
        fields are combined by a :class:`FeatureAggregator`.

        Parameters
        ----------
        text_backbone
            Backbone network for handling the text data
        num_text_features
            Number of text features.
            Each text feature will have (text_token_ids, valid_length)
        num_categorical_features
            Number of categorical features
        num_numerical_features
            Number of numerical features
        numerical_input_units
            The number of units for each numerical column.
            A scalar is broadcast to all numerical columns; ``None`` is treated
            as an empty list (no numerical input).
        num_categories
            The number of categories for each categorical column.
        out_shape
            Shape of the output. A scalar is normalized to a 1-element tuple.
        cfg
            The configuration of the network. Merged on top of the class default
            config returned by ``get_cfg()``.
        get_embedding
            Whether to output the aggregated intermediate embedding from the network
        prefix
            Gluon block parameter-name prefix, forwarded to ``super().__init__``.
        params
            Externally shared Gluon parameters, forwarded to ``super().__init__``.
        """
        super().__init__(prefix=prefix, params=params)
        # Merge user overrides on top of the default configuration.
        self.cfg = cfg = MultiModalWithPretrainedTextNN.get_cfg().clone_merge(cfg)
        # Only CLS-token pooling of the text backbone is supported here.
        assert self.cfg.text_net.pool_type == 'cls'
        base_feature_units = self.cfg.base_feature_units
        if not isinstance(out_shape, (list, tuple)):
            out_shape = (out_shape,)
        self.out_shape = out_shape
        if base_feature_units == -1:
            # -1 means "inherit the text backbone's output width".
            base_feature_units = text_backbone.units
        self.get_embedding = get_embedding
        self.num_text_features = num_text_features
        self.num_categorical_features = num_categorical_features
        self.num_numerical_features = num_numerical_features
        if numerical_input_units is None:
            numerical_input_units = []
        elif not isinstance(numerical_input_units, (list, tuple)):
            # Broadcast a scalar unit count to every numerical column.
            numerical_input_units = [numerical_input_units] * self.num_numerical_features
        self.numerical_input_units = numerical_input_units
        if self.num_numerical_features > 0:
            # Mirror the categorical check below: fail fast with a clear
            # message instead of an IndexError when building the networks.
            assert len(self.numerical_input_units) == self.num_numerical_features
        self.num_categories = num_categories
        if self.num_categorical_features > 0:
            assert len(self.num_categories) == self.num_categorical_features
        weight_initializer = mx.init.create(*cfg.initializer.weight)
        bias_initializer = mx.init.create(*cfg.initializer.bias)
        self.agg_type = cfg.agg_net.agg_type
        if self.agg_type == 'attention_token':
            assert self.num_text_features == 1, \
                'Only supports a single text input if use token-level attention'
        # NOTE: child blocks are created inside name_scope() in a fixed order;
        # Gluon derives parameter names from this order, so do not reorder.
        with self.name_scope():
            self.text_backbone = text_backbone
            if base_feature_units != text_backbone.units:
                # Project backbone features down/up to the shared field width.
                self.text_proj = nn.HybridSequential()
                for i in range(self.num_text_features):
                    with self.text_proj.name_scope():
                        self.text_proj.add(nn.Dense(in_units=text_backbone.units,
                                                    units=base_feature_units,
                                                    use_bias=False,
                                                    weight_initializer=weight_initializer,
                                                    bias_initializer=bias_initializer,
                                                    flatten=False))
            else:
                self.text_proj = None
            if self.num_categorical_features > 0:
                # One embedding-style sub-network per categorical column.
                self.categorical_networks = nn.HybridSequential()
                for i in range(self.num_categorical_features):
                    with self.categorical_networks.name_scope():
                        self.categorical_networks.add(
                            CategoricalFeatureNet(num_class=self.num_categories[i],
                                                  out_units=base_feature_units,
                                                  cfg=cfg.categorical_net))
            else:
                self.categorical_networks = None
            if self.cfg.aggregate_categorical and self.num_categorical_features > 1:
                # Use another dense layer to aggregate the categorical features
                self.categorical_agg = BasicMLP(
                    in_units=base_feature_units * self.num_categorical_features,
                    mid_units=cfg.categorical_agg.mid_units,
                    out_units=base_feature_units,
                    activation=cfg.categorical_agg.activation,
                    dropout=cfg.categorical_agg.dropout,
                    num_layers=cfg.categorical_agg.num_layers,
                    weight_initializer=weight_initializer,
                    bias_initializer=bias_initializer
                )
                if self.cfg.categorical_agg.gated_activation:
                    # Parallel MLP producing the gate for the aggregated output.
                    self.categorical_agg_gate = BasicMLP(
                        in_units=base_feature_units * self.num_categorical_features,
                        mid_units=cfg.categorical_agg.mid_units,
                        out_units=base_feature_units,
                        activation=cfg.categorical_agg.activation,
                        dropout=cfg.categorical_agg.dropout,
                        num_layers=cfg.categorical_agg.num_layers,
                        weight_initializer=weight_initializer,
                        bias_initializer=bias_initializer
                    )
                else:
                    self.categorical_agg_gate = None
            else:
                self.categorical_agg = None
                self.categorical_agg_gate = None

            if self.num_numerical_features > 0:
                # One sub-network per numerical column.
                self.numerical_networks = nn.HybridSequential()
                for i in range(self.num_numerical_features):
                    with self.numerical_networks.name_scope():
                        self.numerical_networks.add(
                            NumericalFeatureNet(input_shape=self.numerical_input_units[i],
                                                out_units=base_feature_units,
                                                cfg=cfg.numerical_net))
            else:
                self.numerical_networks = None
            # self.num_fields is expected to be provided by the class
            # (presumably a property counting text + categorical + numerical
            # fields) — it is not set in this __init__.
            self.agg_layer = FeatureAggregator(num_fields=self.num_fields,
                                               out_shape=out_shape,
                                               in_units=base_feature_units,
                                               cfg=cfg.agg_net,
                                               get_embedding=get_embedding)