in paq/generation/answer_extractor/span2D_model.py [0:0]
def __init__(self, config):
    """Build the BERT encoder and the span-scoring head on top of it.

    Args:
        config: model configuration. Attributes read here:
            ``hidden_size`` and ``span_output_size`` (required),
            ``output_mlp_sizes`` (optional, default ``None``: no hidden MLP
            layers) and ``max_answer_length`` (optional, default ``30``).
    """
    super().__init__(config)
    self.bert = BertModel(config)

    # Two independent projections of the encoder's hidden states: one for
    # start representations, one for end representations.
    span_size = config.span_output_size
    self.start_outputs = nn.Linear(config.hidden_size, span_size)
    self.end_outputs = nn.Linear(config.hidden_size, span_size)

    # The scoring head takes 2 * span_output_size features (presumably a
    # start and an end projection joined together; forward() is not visible
    # here, so confirm there). Each entry of output_mlp_sizes adds one
    # Linear + ReLU hidden layer before the final width-1 output.
    hidden_sizes = getattr(config, "output_mlp_sizes", None)
    in_features = 2 * span_size
    if not hidden_sizes:
        self.output_mlp = None
    else:
        self.output_mlp = ModuleList()
        for out_features in hidden_sizes:
            self.output_mlp.append(nn.Linear(in_features, out_features))
            self.output_mlp.append(nn.ReLU())
            in_features = out_features
    # One output per span; no activation is applied in this constructor.
    self.span_outputs = nn.Linear(in_features, 1)

    # NOTE(review): presumably caps the span length considered when decoding
    # answers; it is only stored here.
    self.max_answer_length = getattr(config, "max_answer_length", 30)
    self.init_weights()