in src/feature_extractor.py [0:0]
def get_embeddings(self, texts):
    """
    Extract and pool embeddings for a batch of texts.

    Args:
        texts (list): List of input texts.

    Returns:
        np.ndarray: Pooled and normalized embeddings of shape [batch_size, embed_dim].
    """
    last_hidden_states, attention_masks = [], []
    max_seq_length = 0
    # Extract features for each text and track the longest sequence
    for text in texts:
        hidden_state, attention_mask = self.extract_features(text)
        last_hidden_states.append(hidden_state)
        attention_masks.append(attention_mask)
        max_seq_length = max(max_seq_length, hidden_state.shape[0])
    # Pad last_hidden_states and attention_masks to max_seq_length
    padded_last_hidden_states = []
    padded_attention_masks = []
    for hidden_state, attention_mask in zip(last_hidden_states, attention_masks):
        seq_length, embed_dim = hidden_state.shape
        # Pad hidden_state with zero rows, preserving its dtype
        padded_hidden_state = np.zeros((max_seq_length, embed_dim), dtype=hidden_state.dtype)
        padded_hidden_state[:seq_length, :] = hidden_state
        padded_last_hidden_states.append(padded_hidden_state)
        # Pad attention_mask with zeros so padded positions are ignored during pooling
        padded_attention_mask = np.zeros((max_seq_length,))
        padded_attention_mask[:seq_length] = attention_mask
        padded_attention_masks.append(padded_attention_mask)
    # Stack to create batch tensors
    last_hidden_state = np.stack(padded_last_hidden_states)  # [batch_size, max_seq_length, embed_dim]
    attention_mask = np.stack(padded_attention_masks)        # [batch_size, max_seq_length]
    # Mean-pool over non-padded positions
    return self.mean_pooling(last_hidden_state, attention_mask)
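
The helper methods referenced above, self.extract_features and self.mean_pooling, are not shown in this excerpt. Below is a minimal sketch of what a masked mean pooling with L2 normalization could look like, consistent with the docstring's "pooled and normalized" output; the standalone function form, the eps guard, and the numpy import are assumptions for illustration, not the repository's actual implementation.

import numpy as np

def mean_pooling(last_hidden_state, attention_mask, eps=1e-9):
    """Sketch: masked mean pooling followed by L2 normalization (assumed behavior).

    last_hidden_state: [batch_size, max_seq_length, embed_dim]
    attention_mask:    [batch_size, max_seq_length], 1 for real tokens, 0 for padding
    """
    # Broadcast the mask over the embedding dimension
    mask = attention_mask[:, :, None]                # [batch_size, max_seq_length, 1]
    # Sum embeddings of non-padded tokens only
    summed = (last_hidden_state * mask).sum(axis=1)  # [batch_size, embed_dim]
    # Count non-padded tokens per example, guarding against division by zero
    counts = np.clip(mask.sum(axis=1), eps, None)    # [batch_size, 1]
    pooled = summed / counts
    # L2-normalize each pooled vector
    norms = np.linalg.norm(pooled, axis=1, keepdims=True)
    return pooled / np.clip(norms, eps, None)

With a helper of this shape, get_embeddings(texts) returns a [batch_size, embed_dim] array whose rows have approximately unit L2 norm, so dot products between rows behave like cosine similarities.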