in src/setfit/span/modeling.py [0:0]
def predict(self, inputs: Union[str, List[str], Dataset]) -> Union[List[Dict[str, Any]], Dataset]:
    """Predicts aspects & their polarities of the given inputs.

    Example::

        >>> from setfit import AbsaModel
        >>> model = AbsaModel.from_pretrained(
        ...     "tomaarsen/setfit-absa-bge-small-en-v1.5-restaurants-aspect",
        ...     "tomaarsen/setfit-absa-bge-small-en-v1.5-restaurants-polarity",
        ... )
        >>> model.predict("The food and wine are just exquisite.")
        [{'span': 'food', 'polarity': 'positive'}, {'span': 'wine', 'polarity': 'positive'}]

    Example with a dataset of gold aspects::

        >>> from setfit import AbsaModel
        >>> from datasets import load_dataset
        >>> model = AbsaModel.from_pretrained(
        ...     "tomaarsen/setfit-absa-bge-small-en-v1.5-restaurants-aspect",
        ...     "tomaarsen/setfit-absa-bge-small-en-v1.5-restaurants-polarity",
        ... )
        >>> dataset = load_dataset("tomaarsen/setfit-absa-semeval-restaurants", split="train")
        >>> model.predict(dataset)
        Dataset({
            features: ['text', 'span', 'label', 'ordinal', 'pred_polarity'],
            num_rows: 3693
        })

    Args:
        inputs (Union[str, List[str], Dataset]): Either a sentence, a list of sentences,
            or a dataset with columns `text` and `span` and optionally `ordinal`. This dataset
            contains gold aspects, and we only predict the polarities for them.

    Returns:
        Union[List[Dict[str, Any]], Dataset]: If the input was a single sentence, a list of
            dictionaries with keys `span` and `polarity`, one per predicted aspect (an empty
            list if no aspects were predicted). If the input was a list of sentences, a list
            of such lists, one per sentence in the same order. If the input was a dataset,
            a dataset with columns `text`, `span`, `ordinal`, and `pred_polarity`.
    """
    if isinstance(inputs, Dataset):
        # Datasets carry gold aspects, so they are handled by the dedicated dataset path.
        return self.predict_dataset(inputs)

    is_str = isinstance(inputs, str)
    inputs_list = [inputs] if is_str else inputs

    docs, aspects_list = self.aspect_extractor(inputs_list)
    # No candidate aspects in any input: skip the aspect and polarity models entirely.
    # A single-string input must still return a flat list, matching the main path below.
    if not any(aspects_list):
        return [] if is_str else aspects_list

    aspects_list = self.aspect_model(docs, aspects_list)
    # The aspect model may reject every candidate; same early exit and return shape as above.
    if not any(aspects_list):
        return [] if is_str else aspects_list

    polarity_list = self.polarity_model(docs, aspects_list)

    # One list of {"span", "polarity"} dicts per input; each aspect is a slice into its doc.
    outputs = [
        [
            {"span": doc[aspect_slice].text, "polarity": polarity}
            for aspect_slice, polarity in zip(aspects, polarities)
        ]
        for doc, aspects, polarities in zip(docs, aspects_list, polarity_list)
    ]
    return outputs[0] if is_str else outputs