# notebooks/packed_bert/pipeline/packed_bert.py
def predict(self, questions, contexts):
    """Run packed-BERT extractive QA over question/context pairs and return
    the predictions together with timing and throughput statistics."""
    # --- Preprocessing: build a dataset, tokenize, and pack the sequences ---
    prep_st = time.time()
    dataset = Dataset.from_dict(
        {
            "id": np.array([str(i) for i in range(len(questions))]).astype("<U32"),
            "question": questions,
            "context": contexts,
        }
    )
    enc_data = preprocess_packed_qa(
        dataset=dataset,
        tokenizer=self.tokenizer,
        question_key="question",
        context_key="context",
        answer_key="answer",
        sequence_length=self.max_seq_length,
        padding=False,
        train=False,
    )
    packed_data_pre = PackedDatasetCreator(
        tokenized_dataset=enc_data,
        max_sequence_length=self.max_seq_length,
        max_sequences_per_pack=self.max_seq_per_pack,
        inference=True,
        pad_to_global_batch_size=True,
        global_batch_size=self.gbs,
        problem_type=self.problem_type,
    ).create()
    # Not the most efficient way: round-trip through a Dataset just to drop
    # the columns the model does not take as inputs.
    packed_data = Dataset.from_list(packed_data_pre)
    packed_data = packed_data.remove_columns(["offset_mapping", "example_ids"])
    packed_data = PackedQuestionAnsweringDataset(
        input_ids=packed_data["input_ids"],
        attention_mask=packed_data["attention_mask"],
        token_type_ids=packed_data["token_type_ids"],
        position_ids=packed_data["position_ids"],
        start_positions=None,
        end_positions=None,
        offset_mapping=None,
        example_ids=None,
    )
    dataloader = prepare_inference_dataloader(
        self.ipu_config, packed_data, self.micro_batch_size, self.dataloader_mode
    )
    outputs = []
    prep_time = time.time() - prep_st

    # --- Inference: run every packed batch through the IPU executor ---
    model_st = time.time()
    for batch in dataloader:
        logits = self.poplar_executor(**batch)
        outputs.append(torch.stack(logits))
    model_time = time.time() - model_st
    tput = len(questions) / model_time

    # --- Postprocessing: unpack the logits back into per-question answers ---
    post_st = time.time()
    outputs = torch.cat(outputs, dim=1).numpy()
    final_preds = postprocess_packed_qa_predictions(dataset, packed_data_pre, outputs)
    formatted_predictions = [{"id": k, "prediction_text": v} for k, v in final_preds.items()]
    post_proc_time = time.time() - post_st
    return {
        "predictions": formatted_predictions,
        "throughput": tput,
        "inference_total_time": model_time,
        "preprocessing_time": prep_time,
        "postprocessing_time": post_proc_time,
    }
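
# Usage sketch (illustrative only): the class name and constructor arguments
# below are assumptions for demonstration, not names confirmed by this file;
# `predict` itself only relies on the attributes referenced above (tokenizer,
# ipu_config, poplar_executor, batching parameters).
#
#     pipeline = PackedBertPipeline(...)  # hypothetical pipeline class
#     result = pipeline.predict(
#         questions=["Who wrote the report?"],
#         contexts=["The report was written by the audit team in 2021."],
#     )
#     print(result["predictions"][0]["prediction_text"])
#     print(f'{result["throughput"]:.1f} questions/s')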