def find_location()

in src/infer_location.py [0:0]


    def find_location(self, sequence, verbose=False):
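        """Run ONNX NER inference over `sequence` and return the detected
        city and state as {'city': ..., 'state': ...} (either may be None)."""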
        inputs = self.tokenizer(sequence,
                                return_tensors="np",  # ONNX requires inputs in NumPy format
                                padding="max_length",  # Pad to max length
                                truncation=True,       # Truncate if the text is too long
                                max_length=64)
        input_feed = {
            'input_ids': inputs['input_ids'].astype(np.int64),
            'attention_mask': inputs['attention_mask'].astype(np.int64),
        }

        # Run inference with the ONNX model
        outputs = self.ort_session.run(None, input_feed)
        logits = outputs[0]  # First output is assumed to hold per-token logits, shape (1, seq_len, num_labels)
        # Softmax over the label dimension; subtracting the row max avoids overflow in np.exp
        exp_logits = np.exp(logits - np.max(logits, axis=-1, keepdims=True))
        probabilities = exp_logits / np.sum(exp_logits, axis=-1, keepdims=True)
        
        predicted_ids = np.argmax(logits, axis=-1)
        predicted_probs = np.max(probabilities, axis=-1)
        
        # Minimum softmax probability for a token's predicted label to be accepted
        threshold = 0.5
        
        label_map = self.label_map
        
        tokens = self.tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])

        # Collected tokens per entity type
        city_entities = []
        state_entities = []
        city_state_entities = []
        entity_lists = {
            "CITY": city_entities,
            "STATE": state_entities,
            "CITYSTATE": city_state_entities,
        }

        for token, predicted_id, prob in zip(tokens, predicted_ids[0], predicted_probs[0]):
            if token in ("[CLS]", "[SEP]", "[PAD]"):
                continue
            label = label_map[predicted_id]
            if verbose:
                print(f"{token}\t{label}\t{prob:.3f}")
            if prob <= threshold:
                continue
            entity_type = label.split("-", 1)[-1]  # "B-CITY" / "I-CITY" -> "CITY"
            if entity_type in entity_lists:
                entities = entity_lists[entity_type]
                # WordPiece continuation tokens ("##...") extend the previous token
                if token.startswith("##") and entities:
                    entities[-1] += token[2:]
                else:
                    entities.append(token)

        # Combine city_state entities and split into city and state if necessary
        if city_state_entities:
            city_state_str = " ".join(city_state_entities)
            city_state_split = city_state_str.split(",")  # Split on comma to separate city and state
            city_res = city_state_split[0].strip() or None
            state_res = city_state_split[1].strip() if len(city_state_split) > 1 else None
        else:
            # If no city_state entities, use detected city and state entities separately
            city_res = " ".join(city_entities).strip() if city_entities else None
            state_res = " ".join(state_entities).strip() if state_entities else None

        # Return the detected city and state as separate components
        return {
            'city': city_res,
            'state': state_res
        }
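
For context, here is a minimal wiring sketch showing how the attributes the method reads (self.tokenizer, self.ort_session, self.label_map) might be set up and how the method is called. The class name LocationInference, the model path, the base checkpoint, and the exact label set are illustrative assumptions, not taken from the repo:

    import numpy as np  # find_location above relies on numpy imported as np at module top
    import onnxruntime
    from transformers import AutoTokenizer

    class LocationInference:
        def __init__(self, model_path="location_ner.onnx"):  # hypothetical path
            # The tokenizer must match the checkpoint the ONNX model was exported from
            self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
            self.ort_session = onnxruntime.InferenceSession(model_path)
            # Maps the model's class ids to BIO tags; this particular label set is an assumption
            self.label_map = {
                0: "O",
                1: "B-CITY", 2: "I-CITY",
                3: "B-STATE", 4: "I-STATE",
                5: "B-CITYSTATE", 6: "I-CITYSTATE",
            }

        # find_location(self, sequence, verbose=False) as defined above

    locator = LocationInference()
    result = locator.find_location("cheap flights to portland , oregon")
    print(result)  # e.g. {'city': 'portland', 'state': 'oregon'} if the model tags those tokens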