in src/infer_location.py [0:0]
def find_location(self, sequence, verbose=False, threshold=0.5):
    """Detect a city and/or state mentioned in ``sequence``.

    The text is tokenized with ``self.tokenizer`` (padded/truncated to 64
    tokens), run through ``self.ort_session`` and every token whose best
    label probability exceeds ``threshold`` is collected according to the
    label that ``self.label_map`` assigns to the predicted class id.

    Args:
        sequence: Text to analyse.
        verbose: Currently unused; kept so existing callers keep working.
        threshold: Minimum softmax probability (exclusive) a token's top
            label needs in order to be kept. Defaults to 0.5.

    Returns:
        A dict ``{'city': ..., 'state': ...}``. Each value is a string built
        from the detected tokens, or ``None`` when nothing was detected.
        When any CITYSTATE tokens are found they take precedence: they are
        joined and split on the first comma into city and state.
    """
    inputs = self.tokenizer(
        sequence,
        return_tensors="np",   # ONNX requires inputs in NumPy format
        padding="max_length",  # Pad to max length
        truncation=True,       # Truncate if the text is too long
        max_length=64,
    )
    input_feed = {
        'input_ids': inputs['input_ids'].astype(np.int64),
        'attention_mask': inputs['attention_mask'].astype(np.int64),
    }

    # Run inference with the ONNX model.
    outputs = self.ort_session.run(None, input_feed)
    # Assumes the first output holds per-token logits with the label axis
    # last and a leading batch axis -- TODO confirm against the exported model.
    logits = outputs[0]

    # Numerically stable softmax: subtracting the per-token maximum keeps
    # np.exp from overflowing to inf (which would turn the probabilities
    # into NaN and silently drop confidently predicted tokens).
    shifted = logits - np.max(logits, axis=-1, keepdims=True)
    exp_shifted = np.exp(shifted)
    probabilities = exp_shifted / np.sum(exp_shifted, axis=-1, keepdims=True)

    predicted_ids = np.argmax(logits, axis=-1)
    predicted_probs = np.max(probabilities, axis=-1)

    label_map = self.label_map
    tokens = self.tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])

    # One bucket per entity type; B- and I- labels feed the same bucket.
    city_entities = []
    state_entities = []
    city_state_entities = []
    buckets = {
        "B-CITY": city_entities,
        "I-CITY": city_entities,
        "B-STATE": state_entities,
        "I-STATE": state_entities,
        "B-CITYSTATE": city_state_entities,
        "I-CITYSTATE": city_state_entities,
    }
    special_tokens = ("[CLS]", "[SEP]", "[PAD]")

    for token, predicted_id, prob in zip(tokens, predicted_ids[0], predicted_probs[0]):
        # `not prob > threshold` (rather than `prob <= threshold`) also
        # rejects NaN probabilities.
        if not prob > threshold or token in special_tokens:
            continue
        bucket = buckets.get(label_map[predicted_id])
        if bucket is None:
            continue  # Label is not a location entity.
        if token.startswith("##") and bucket:
            # Subword continuation: glue onto the previous token without "##".
            bucket[-1] += token[2:]
        else:
            bucket.append(token)

    if city_state_entities:
        # Combined "city, state" span: split on comma to separate the parts.
        city_state_split = " ".join(city_state_entities).split(",")
        # `or None` so an empty half is reported as missing, not as ''.
        city_res = city_state_split[0].strip() or None
        state_res = (
            city_state_split[1].strip() or None
            if len(city_state_split) > 1
            else None
        )
    else:
        # No combined span: use the separately detected city and state.
        city_res = " ".join(city_entities).strip() if city_entities else None
        state_res = " ".join(state_entities).strip() if state_entities else None

    return {
        'city': city_res,
        'state': state_res
    }