in fairseq/models/wav2vec/wav2vec.py [0:0]
def __init__(self, cfg: Wav2VecConfig):
    super().__init__()

    self.prediction_steps = cfg.prediction_steps
    offset = cfg.offset
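
    # The same activation module is shared by the conv feature extractor, the
    # optional quantizer, and the conv aggregator constructed below.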
    if cfg.activation == "relu":
        activation = nn.ReLU()
    elif cfg.activation == "gelu":
        activation = nn.GELU()
    else:
        raise Exception("unknown activation " + cfg.activation)
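
    # cfg.conv_feature_layers is a Python-literal string of (dim, kernel_size,
    # stride) tuples, e.g. "[(512, 10, 5), (512, 8, 4), ...]"; eval() turns it
    # into the list of conv layer specs.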
    feature_enc_layers = eval(cfg.conv_feature_layers)
    self.feature_extractor = ConvFeatureExtractionModel(
        conv_layers=feature_enc_layers,
        dropout=0.0,
        log_compression=cfg.log_compression,
        skip_connections=cfg.skip_connections_feat,
        residual_scale=cfg.residual_scale,
        non_affine_group_norm=cfg.non_affine_group_norm,
        activation=activation,
    )
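
    # The embedding size used downstream is the number of output channels of
    # the final conv layer.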
    embed = feature_enc_layers[-1][0]

    self.vector_quantizer = None
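    # Optional vector quantization of the extracted features (as in vq-wav2vec):
    # Gumbel-softmax or k-means codebooks; "none"/None keeps continuous features.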
    if cfg.vq_type == "gumbel":
        self.vector_quantizer = GumbelVectorQuantizer(
            dim=embed,
            num_vars=cfg.vq_vars,
            temp=cfg.vq_temp,
            groups=cfg.vq_groups,
            combine_groups=cfg.combine_groups,
            vq_dim=cfg.vq_dim if cfg.vq_dim > 0 else embed,
            time_first=False,
            activation=activation,
            weight_proj_depth=cfg.vq_depth,
            weight_proj_factor=2,
        )
    elif cfg.vq_type == "kmeans":
        self.vector_quantizer = KmeansVectorQuantizer(
            dim=embed,
            num_vars=cfg.vq_vars,
            groups=cfg.vq_groups,
            combine_groups=cfg.combine_groups,
            vq_dim=cfg.vq_dim if cfg.vq_dim > 0 else embed,
            time_first=False,
            gamma=cfg.vq_gamma,
        )
    else:
        assert (
            cfg.vq_type == "none" or cfg.vq_type is None
        ), "Unknown quantizer type"
    if cfg.offset == "auto":
        jin = 0
        rin = 0
        for _, k, stride in feature_enc_layers:
            if rin == 0:
                rin = k
            rin = rin + (k - 1) * jin
            if jin == 0:
                jin = stride
            else:
                jin *= stride
        offset = math.ceil(rin / jin)

    offset = int(offset)
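
    # The aggregator is built through a local factory so a second, independent
    # instance can be created below when cfg.project_features == "new".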
    def make_aggregator():
        if cfg.aggregator == "cnn":
            agg_layers = eval(cfg.conv_aggregator_layers)
            agg_dim = agg_layers[-1][0]
            feature_aggregator = ConvAggegator(
                conv_layers=agg_layers,
                embed=embed,
                dropout=cfg.dropout,
                skip_connections=cfg.skip_connections_agg,
                residual_scale=cfg.residual_scale,
                non_affine_group_norm=cfg.non_affine_group_norm,
                conv_bias=not cfg.no_conv_bias,
                zero_pad=cfg.agg_zero_pad,
                activation=activation,
            )
        elif cfg.aggregator == "gru":
            agg_dim = cfg.gru_dim
            feature_aggregator = nn.Sequential(
                TransposeLast(),
                nn.GRU(
                    input_size=embed,
                    hidden_size=agg_dim,
                    num_layers=1,
                    dropout=cfg.dropout,
                ),
                TransposeLast(deconstruct_idx=0),
            )
        else:
            raise Exception("unknown aggregator type " + cfg.aggregator)
        return feature_aggregator, agg_dim

    self.feature_aggregator, agg_dim = make_aggregator()
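
    # Contrastive prediction head: given the aggregator output, it scores
    # predictions for the next `prediction_steps` feature frames against
    # `num_negatives` sampled negatives (optionally with an InfoNCE loss).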
    self.wav2vec_predictions = Wav2VecPredictionsModel(
        in_dim=agg_dim,
        out_dim=embed,
        prediction_steps=cfg.prediction_steps,
        n_negatives=cfg.num_negatives,
        cross_sample_negatives=cfg.cross_sample_negatives,
        sample_distance=cfg.sample_distance,
        dropout=cfg.dropout,
        offset=offset,
        balanced_classes=cfg.balanced_classes,
        infonce=cfg.infonce,
    )

    self.dropout_feats = nn.Dropout(p=cfg.dropout_features)
    self.dropout_agg = nn.Dropout(p=cfg.dropout_agg)
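
    # Optional projection of the extracted features used as prediction targets:
    # "none" leaves them untouched, "same" reuses the aggregator above, and
    # "new" builds a second aggregator from the same factory.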
    if cfg.project_features == "none":
        self.project_features = None
    elif cfg.project_features == "same":
        self.project_features = self.feature_aggregator
    elif cfg.project_features == "new":
        self.project_features, _ = make_aggregator()