def __init__()

in fairseq/models/wav2vec/wav2vec.py [0:0]


    def __init__(self, cfg: Wav2VecConfig):
        """Build the model components described by ``cfg``.

        Submodules are created in this order: convolutional feature
        extractor, optional vector quantizer, feature aggregator,
        prediction model, two dropout layers, and an optional feature
        projection.

        Args:
            cfg: Model configuration. Fields read here: ``prediction_steps``,
                ``offset``, ``activation``, ``conv_feature_layers``,
                ``log_compression``, ``skip_connections_feat``,
                ``residual_scale``, ``non_affine_group_norm``, ``vq_type``
                (and the ``vq_*`` / ``combine_groups`` options), ``aggregator``
                (and its ``conv_aggregator_layers`` / ``gru_dim`` options),
                ``dropout``, ``dropout_features``, ``dropout_agg``,
                ``num_negatives``, ``cross_sample_negatives``,
                ``sample_distance``, ``balanced_classes``, ``infonce`` and
                ``project_features``.

        Raises:
            Exception: If ``cfg.activation``, ``cfg.aggregator`` or
                ``cfg.project_features`` is not one of the supported values.
            AssertionError: If ``cfg.vq_type`` is not ``"gumbel"``,
                ``"kmeans"``, ``"none"`` or ``None``.
        """
        super().__init__()

        self.prediction_steps = cfg.prediction_steps
        offset = cfg.offset

        # A single activation module instance is shared by every component
        # that receives it below.
        if cfg.activation == "relu":
            activation = nn.ReLU()
        elif cfg.activation == "gelu":
            activation = nn.GELU()
        else:
            raise Exception("unknown activation " + cfg.activation)

        # NOTE(security): eval() executes arbitrary Python contained in the
        # config string, so the config must come from a trusted source. A
        # literal-only parser would be safer if layer specs are always plain
        # literals -- TODO confirm before replacing.
        feature_enc_layers = eval(cfg.conv_feature_layers)
        self.feature_extractor = ConvFeatureExtractionModel(
            conv_layers=feature_enc_layers,
            dropout=0.0,
            log_compression=cfg.log_compression,
            skip_connections=cfg.skip_connections_feat,
            residual_scale=cfg.residual_scale,
            non_affine_group_norm=cfg.non_affine_group_norm,
            activation=activation,
        )
        # First entry of the last layer spec; used below as the feature
        # dimension for the quantizer, aggregator and prediction model.
        embed = feature_enc_layers[-1][0]

        self.vector_quantizer = None
        if cfg.vq_type == "gumbel":
            self.vector_quantizer = GumbelVectorQuantizer(
                dim=embed,
                num_vars=cfg.vq_vars,
                temp=cfg.vq_temp,
                groups=cfg.vq_groups,
                combine_groups=cfg.combine_groups,
                # A non-positive vq_dim falls back to the feature dimension.
                vq_dim=cfg.vq_dim if cfg.vq_dim > 0 else embed,
                time_first=False,
                activation=activation,
                weight_proj_depth=cfg.vq_depth,
                weight_proj_factor=2,
            )
        elif cfg.vq_type == "kmeans":
            self.vector_quantizer = KmeansVectorQuantizer(
                dim=embed,
                num_vars=cfg.vq_vars,
                groups=cfg.vq_groups,
                combine_groups=cfg.combine_groups,
                vq_dim=cfg.vq_dim if cfg.vq_dim > 0 else embed,
                time_first=False,
                gamma=cfg.vq_gamma,
            )
        elif not (cfg.vq_type == "none" or cfg.vq_type is None):
            # Explicit raise instead of `assert` so the check is not stripped
            # under `python -O`; the exception type and message are unchanged.
            raise AssertionError("Unknown quantizer type")

        if cfg.offset == "auto":
            # Walk the layer specs, unpacked as (_, k, stride): `jin` becomes
            # the running product of strides and `rin` grows by
            # (k - 1) * jin per layer (starting from the first layer's k).
            # Presumably k is the kernel width, making `rin` the receptive
            # field in input samples -- TODO confirm against the extractor.
            jin = 0
            rin = 0
            for _, k, stride in feature_enc_layers:
                if rin == 0:
                    rin = k
                rin = rin + (k - 1) * jin
                if jin == 0:
                    jin = stride
                else:
                    jin *= stride
            offset = math.ceil(rin / jin)

        offset = int(offset)

        def make_aggregator():
            """Return a new ``(aggregator_module, output_dim)`` pair per ``cfg``."""
            if cfg.aggregator == "cnn":
                # NOTE(security): same eval() caveat as for
                # conv_feature_layers above.
                agg_layers = eval(cfg.conv_aggregator_layers)
                agg_dim = agg_layers[-1][0]
                feature_aggregator = ConvAggegator(
                    conv_layers=agg_layers,
                    embed=embed,
                    dropout=cfg.dropout,
                    skip_connections=cfg.skip_connections_agg,
                    residual_scale=cfg.residual_scale,
                    non_affine_group_norm=cfg.non_affine_group_norm,
                    conv_bias=not cfg.no_conv_bias,
                    zero_pad=cfg.agg_zero_pad,
                    activation=activation,
                )
            elif cfg.aggregator == "gru":
                agg_dim = cfg.gru_dim
                feature_aggregator = nn.Sequential(
                    TransposeLast(),
                    nn.GRU(
                        input_size=embed,
                        hidden_size=agg_dim,
                        num_layers=1,
                        dropout=cfg.dropout,
                    ),
                    # nn.GRU returns a tuple; deconstruct_idx=0 presumably
                    # selects its first element before transposing back.
                    TransposeLast(deconstruct_idx=0),
                )
            else:
                raise Exception("unknown aggregator type " + cfg.aggregator)

            return feature_aggregator, agg_dim

        self.feature_aggregator, agg_dim = make_aggregator()

        self.wav2vec_predictions = Wav2VecPredictionsModel(
            in_dim=agg_dim,
            out_dim=embed,
            prediction_steps=cfg.prediction_steps,
            n_negatives=cfg.num_negatives,
            cross_sample_negatives=cfg.cross_sample_negatives,
            sample_distance=cfg.sample_distance,
            dropout=cfg.dropout,
            offset=offset,
            balanced_classes=cfg.balanced_classes,
            infonce=cfg.infonce,
        )

        self.dropout_feats = nn.Dropout(p=cfg.dropout_features)
        self.dropout_agg = nn.Dropout(p=cfg.dropout_agg)

        if cfg.project_features == "none":
            self.project_features = None
        elif cfg.project_features == "same":
            # Reuses the very same module object as the aggregator.
            self.project_features = self.feature_aggregator
        elif cfg.project_features == "new":
            # Builds a second, independent aggregator instance.
            self.project_features, _ = make_aggregator()
        else:
            # Fail at construction time instead of leaving the attribute
            # undefined, matching the activation/aggregator checks above.
            raise Exception(
                f"unknown project_features type {cfg.project_features!r}"
            )