in pytext/models/tri_tower_classification_model.py [0:0]
def from_config(cls, config: Config, tensorizers: Dict[str, Tensorizer]):
    """Build a tri-tower classification model from its config and tensorizers.

    Three encoders (right, middle, left) are created, each sized to the
    vocabulary of its ``*_tokens`` tensorizer. A tower's width is its
    encoder's ``representation_dim`` plus the ``dim`` of the matching
    ``*_dense`` tensorizer. Those widths size the decoder. The output layer
    class is selected from the type of the configured loss.

    Args:
        config: Model config providing the three encoder configs, the decoder
            config, the output layer config and the shared-embedding settings
            (``use_shared_embedding``, ``vocab_size``, ``hidden_dim``,
            ``padding_idx``).
        tensorizers: Must contain ``"labels"``, ``"right_tokens"``,
            ``"middle_tokens"``, ``"left_tokens"``, ``"right_dense"``,
            ``"middle_dense"`` and ``"left_dense"``.

    Returns:
        A new ``cls`` instance built from the encoders, decoder and output
        layer, plus the shared-embedding settings from ``config``.

    Raises:
        ValueError: If the label vocabulary is empty, or if ``config.decoder``
            is not an ``MLPDecoderTriTower.Config`` or
            ``MLPDecoderNTower.Config``.
    """
    labels = tensorizers["labels"].vocab
    if not labels:
        raise ValueError("Labels were not created, see preceding errors")

    # With use_shared_embedding, one Embedding instance is handed to all three
    # encoders. Presumably the towers then share token weights; confirm that
    # each encoder actually uses the embedding it is given.
    if config.use_shared_embedding:
        token_embedding = torch.nn.Embedding(
            config.vocab_size, config.hidden_dim, padding_idx=config.padding_idx
        )
    else:
        token_embedding = None

    def build_encoder(encoder_config, tokens_key):
        # Create one tower encoder sized to that tower's own token vocabulary.
        vocab = tensorizers[tokens_key].vocab
        return create_module(
            encoder_config,
            token_embedding=token_embedding,
            padding_idx=vocab.get_pad_index(),
            vocab_size=len(vocab),
        )

    # Keep the right -> middle -> left construction order. If encoder
    # construction draws random numbers for weight init, reordering would
    # change the resulting initial weights.
    right_encoder = build_encoder(config.right_encoder, "right_tokens")
    middle_encoder = build_encoder(config.middle_encoder, "middle_tokens")
    left_encoder = build_encoder(config.left_encoder, "left_tokens")

    # Each tower's decoder input width is encoder output + dense feature width
    # (presumably the two are concatenated in forward(); verify there).
    right_dim = right_encoder.representation_dim + tensorizers["right_dense"].dim
    middle_dim = middle_encoder.representation_dim + tensorizers["middle_dense"].dim
    left_dim = left_encoder.representation_dim + tensorizers["left_dense"].dim
    num_labels = len(labels)

    if isinstance(config.decoder, MLPDecoderTriTower.Config):
        decoder = create_module(
            config.decoder,
            right_dim=right_dim,
            middle_dim=middle_dim,
            left_dim=left_dim,
            to_dim=num_labels,
        )
    elif isinstance(config.decoder, MLPDecoderNTower.Config):
        decoder = create_module(
            config.decoder,
            tower_dims=[right_dim, middle_dim, left_dim],
            to_dim=num_labels,
        )
    else:
        # Fail fast instead of building a model whose decoder is None, which
        # would only surface as an obscure error at forward time.
        raise ValueError(
            f"Unsupported decoder config type: {type(config.decoder).__name__}"
        )

    # A falsy label_weights setting (None or empty) means an unweighted loss.
    label_weights = (
        get_label_weights(labels.idx, config.output_layer.label_weights)
        if config.output_layer.label_weights
        else None
    )
    loss = create_loss(config.output_layer.loss, weight=label_weights)

    # The output layer must match the loss: binary, multi-label, or (default)
    # multiclass.
    if isinstance(loss, BinaryCrossEntropyLoss):
        output_layer_cls = BinaryClassificationOutputLayer
    elif isinstance(loss, MultiLabelSoftMarginLoss):
        output_layer_cls = MultiLabelOutputLayer
    else:
        output_layer_cls = MulticlassOutputLayer
    output_layer = output_layer_cls(list(labels), loss)

    return cls(
        right_encoder,
        middle_encoder,
        left_encoder,
        decoder,
        output_layer,
        config.use_shared_embedding,
        config.vocab_size,
        config.hidden_dim,
        config.padding_idx,
    )