notebooks/packed_bert/utils/packing/dataset_templates.py (73 lines of code) (raw):

# Copyright (c) 2023 Graphcore Ltd. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from torch.utils.data import Dataset class PackedClassificationDataset(Dataset): def __init__(self, input_ids, attention_mask, token_type_ids, position_ids, labels=None, example_ids=None): self.input_ids = input_ids self.attention_mask = attention_mask self.token_type_ids = token_type_ids self.position_ids = position_ids self.labels = labels self.example_ids = example_ids def __len__(self): return len(self.input_ids) def __getitem__(self, index): input_ids = self.input_ids[index] attention_masks = self.attention_mask[index] token_type_ids = self.token_type_ids[index] position_ids = self.position_ids[index] labels = self.labels[index] if self.labels is not None else None example_ids = self.example_ids[index] if self.example_ids is not None else None sample = { "input_ids": input_ids, "attention_mask": attention_masks, "token_type_ids": token_type_ids, "position_ids": position_ids, } if self.labels is not None: sample["labels"] = labels if self.example_ids is not None: sample["example_ids"] = example_ids return sample class PackedQuestionAnsweringDataset(Dataset): def __init__( self, input_ids, attention_mask, token_type_ids, position_ids, start_positions, end_positions, offset_mapping, example_ids, ): self.input_ids = input_ids self.attention_mask = attention_mask self.token_type_ids = token_type_ids self.position_ids = position_ids self.start_positions = start_positions self.end_positions = end_positions self.offset_mapping = offset_mapping self.example_ids = example_ids def __len__(self): return len(self.input_ids) def __getitem__(self, index): input_ids = self.input_ids[index] attention_masks = self.attention_mask[index] token_type_ids = self.token_type_ids[index] position_ids = self.position_ids[index] start_positions = self.start_positions[index] if self.start_positions is not None else None end_positions = self.end_positions[index] if self.end_positions is not None else None offset_mapping = self.offset_mapping[index] if self.offset_mapping is not None else None example_ids = self.example_ids[index] if self.example_ids is not None else None sample = { "input_ids": input_ids, "attention_mask": attention_masks, "token_type_ids": token_type_ids, "position_ids": position_ids, } if self.start_positions is not None and self.end_positions is not None: sample["start_positions"] = start_positions sample["end_positions"] = end_positions if self.offset_mapping is not None and self.example_ids is not None: sample["offset_mapping"] = offset_mapping sample["example_ids"] = example_ids return sample