in tensorflow_datasets/text/glue.py [0:0]
def _generate_examples(self, data_file, split, mrpc_files=None):
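  """Yields `(key, example)` tuples from a GLUE data file.

  Args:
    data_file: Path to the tsv file for this split (not used for the "mrpc"
      config, which is built from its original source files instead).
    split: Name of the split being generated (e.g. "train", "dev", "test").
    mrpc_files: Raw MRPC source files; only used when the builder config is
      "mrpc".

  Yields:
    `(idx, example)` pairs, keyed by the example index.
  """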
  if self.builder_config.name == "mrpc":
    # We have to prepare the MRPC dataset from the original sources ourselves.
    examples = self._generate_example_mrpc_files(
        mrpc_files=mrpc_files, split=split)
    for example in examples:
      yield example["idx"], example
  else:
    process_label = self.builder_config.process_label
    label_classes = self.builder_config.label_classes
    # The train and dev files for CoLA are the only tsv files without a
    # header.
    is_cola_non_test = self.builder_config.name == "cola" and split != "test"
    with tf.io.gfile.GFile(data_file) as f:
      reader = csv.DictReader(f, delimiter="\t", quoting=csv.QUOTE_NONE)
      if is_cola_non_test:
        reader = csv.reader(f, delimiter="\t", quoting=csv.QUOTE_NONE)
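      # When there is no header, csv.reader yields plain lists rather than
      # dicts, so the CoLA columns are remapped to named fields by position
      # in the loop below.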
      for n, row in enumerate(reader):
        if is_cola_non_test:
          row = {
              "sentence": row[3],
              "is_acceptable": row[1],
          }
        example = {
            feat: row[col]
            for feat, col in six.iteritems(self.builder_config.text_features)
        }
        example["idx"] = n
        if self.builder_config.label_column in row:
          label = row[self.builder_config.label_column]
          # For some tasks, the label is represented as 0 and 1 in the tsv
          # files and needs to be cast to integer to work with the feature.
          if label_classes and label not in label_classes:
            label = int(label) if label else None
          example["label"] = process_label(label)
        else:
          example["label"] = process_label(-1)
        # Filter out corrupted rows: the `else` clause of the `for` loop runs
        # only when the loop finishes without hitting `break`, i.e. when none
        # of the example's values is None.
        for value in six.itervalues(example):
          if value is None:
            break
        else:
          yield example["idx"], example
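For context, here is a minimal, self-contained sketch, not part of glue.py: the toy tsv string, helper name, and column names are assumptions chosen for illustration. It mirrors the pattern above of tsv parsing via the csv module plus the for/else filter that silently drops rows with a missing value.

import csv
import io

# Toy tsv content (an assumption for illustration): the second data row has
# an empty label field and should be filtered out.
_TOY_TSV = "sentence\tlabel\na fine sentence\t1\na corrupted row\t\n"

def _toy_generate_examples(tsv_text):
  reader = csv.DictReader(
      io.StringIO(tsv_text), delimiter="\t", quoting=csv.QUOTE_NONE)
  for n, row in enumerate(reader):
    example = {
        "idx": n,
        "sentence": row["sentence"],
        # Cast the label to int; an empty field becomes None.
        "label": int(row["label"]) if row["label"] else None,
    }
    for value in example.values():
      if value is None:
        break  # Corrupted row: at least one field is missing.
    else:
      yield example["idx"], example

# Only the first data row survives; the second has an empty label.
assert list(_toy_generate_examples(_TOY_TSV)) == [
    (0, {"idx": 0, "sentence": "a fine sentence", "label": 1})]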