in datasets.py [0:0]
def __init__(self, data_path, split, subsample_what=None, duplicates=None):
    """Load cached MultiNLI BERT features and initialise the base dataset.

    Three cached feature files are read from
    ``<data_path>/multinli/glue_data/MNLI`` and concatenated in a fixed order.
    Their fields are converted to ``torch.long`` tensors and stored on
    ``self``. The metadata CSV path, together with ``split``,
    ``subsample_what``, ``duplicates`` and ``self.transform``, is forwarded
    positionally to the parent ``__init__``.

    Args:
        data_path: Root directory holding the ``multinli`` folder and
            ``metadata_multinli.csv``.
        split: Forwarded unchanged to the base class.
        subsample_what: Forwarded unchanged to the base class.
        duplicates: Forwarded unchanged to the base class.

    Side effects:
        Sets ``features_array``, ``all_input_ids``, ``all_input_masks``,
        ``all_segment_ids``, ``all_label_ids``, ``x_array`` and ``data_type``
        on the instance.
    """
    mnli_dir = os.path.join(data_path, "multinli", "glue_data", "MNLI")
    metadata_csv = os.path.join(data_path, "metadata_multinli.csv")

    # Concatenated in this order: train, dev, then the "-mm" dev cache
    # (presumably the mismatched dev set).
    cache_names = (
        "cached_train_bert-base-uncased_128_mnli",
        "cached_dev_bert-base-uncased_128_mnli",
        "cached_dev_bert-base-uncased_128_mnli-mm",
    )
    self.features_array = []
    for cache_name in cache_names:
        # torch.load unpickles the file, so only point this at trusted caches.
        # NOTE(review): newer torch releases default to weights_only=True,
        # which may reject pickled feature objects; confirm against the
        # pinned torch version.
        self.features_array.extend(torch.load(os.path.join(mnli_dir, cache_name)))

    def long_tensor_of(field):
        # Collect one attribute from every cached feature into an int64 tensor.
        return torch.tensor(
            [getattr(feature, field) for feature in self.features_array],
            dtype=torch.long,
        )

    self.all_input_ids = long_tensor_of("input_ids")
    self.all_input_masks = long_tensor_of("input_mask")
    self.all_segment_ids = long_tensor_of("segment_ids")
    self.all_label_ids = long_tensor_of("label_id")

    # Stack ids, mask and segment ids along a new axis 2. This assumes each
    # per-feature field is an equal-length list, so the tensors above are
    # 2-D (N, seq_len) and x_array becomes (N, seq_len, 3). TODO confirm.
    self.x_array = torch.stack(
        (self.all_input_ids, self.all_input_masks, self.all_segment_ids), dim=2
    )
    self.data_type = "text"

    # The base class takes these positionally. The "" slot is presumably an
    # unused root/image directory for text data; verify against the parent.
    super().__init__(
        split, "", metadata_csv, self.transform, subsample_what, duplicates
    )