# learning/dataset.py
import numpy as np
import subprocess
import h5py
import ciseau
from os.path import exists, splitext, join
from wikidata_linker_utils.wikidata_ids import load_wikidata_ids
def count_examples(lines, comment, ignore_value, column_indices):
example_length = 0
has_labels = False
found = 0
for line in lines:
if len(line) == 0 or (comment is not None and line.startswith(comment)):
if example_length > 0 and has_labels:
found += 1
example_length = 0
has_labels = False
else:
example_length += 1
if not has_labels:
cols = line.split("\t")
if len(cols) > 1:
if ignore_value is not None:
                        for col_index in column_indices:
                            if col_index is not None and cols[col_index] != ignore_value:
has_labels = True
break
else:
has_labels = True
if example_length > 0 and has_labels:
found += 1
return found
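# Example (hypothetical CoNLL-style lines): two blank-line separated blocks,
# but only the first carries a label other than "O", so one example counts:
#
#     count_examples(["The\tDET", "cat\tNOUN", "", "purrs\tO"],
#                    comment="#", ignore_value="O", column_indices=[1])
#     # -> 1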
def retokenize_example(x, y):
    tokens = ciseau.tokenize(" ".join(x), normalize_ascii=False)
out_y = []
regular_cursor = 0
tokens_length_total = 0
regular_length_total = len(x[regular_cursor]) + 1 if len(x) > 0 else 0
if regular_cursor + 1 == len(x):
regular_length_total -= 1
for i in range(len(tokens)):
tokens_length_total = tokens_length_total + len(tokens[i])
while regular_length_total < tokens_length_total:
regular_cursor += 1
regular_length_total = regular_length_total + len(x[regular_cursor]) + 1
if regular_cursor + 1 == len(x):
regular_length_total -= 1
out_y.append(y[regular_cursor])
    assert regular_cursor + 1 == len(x), "error with %r" % (x,)
return ([tok.rstrip() for tok in tokens], out_y)
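# Example (hypothetical labels): if ciseau splits "world." into "world" and
# ".", the label of the original token is repeated for each piece:
#
#     retokenize_example(["Hello", "world."], [("A",), ("B",)])
#     # -> (["Hello", "world", "."], [("A",), ("B",), ("B",)])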
def convert_lines_to_examples(lines, comment, ignore_value,
column_indices, x_column, empty_column,
retokenize=False):
examples = []
x = []
y = []
for line in lines:
if len(line) == 0 or (comment is not None and line.startswith(comment)):
if len(x) > 0:
if not all(row == empty_column for row in y):
examples.append((x, y))
x = []
y = []
else:
cols = line.split("\t")
x.append(cols[x_column])
if len(cols) == 1:
y.append(empty_column)
else:
if ignore_value is not None:
y.append(
tuple(
cols[col_index] if col_index is not None and cols[col_index] != ignore_value else None
for col_index in column_indices
)
)
else:
y.append(
tuple(
cols[col_index] if col_index is not None else None
for col_index in column_indices
)
)
if len(x) > 0 and not all(row == empty_column for row in y):
examples.append((x, y))
if retokenize:
examples = [retokenize_example(x, y) for x, y in examples]
return examples
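# Example (hypothetical two-column lines): labels matching ignore_value become
# None, and blocks whose labels are all empty are dropped:
#
#     convert_lines_to_examples(["The\tO", "EU\tORG", ""], comment="#",
#                               ignore_value="O", column_indices=[1],
#                               x_column=0, empty_column=(None,))
#     # -> [(["The", "EU"], [(None,), ("ORG",)])]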
def load_tsv(path, x_column, y_columns, objective_names, comment, ignore_value,
retokenize):
""""
Deprecated method for loading a tsv file as a training/test set for a model.
Arguments:
----------
path: str, location of tsv file
x_column: int
y_columns: list<dict>, objectives in this file along with their column.
(e.g. `y_columns=[{"objective": "POS", "column": 2}, ...])`)
    objective_names: list<str>, names of all objectives the model expects;
        objectives absent from this file are filled with None.
    comment: str, prefix marking lines that should be skipped.
    ignore_value: str, label value that should be treated as missing.
    retokenize: bool, run the tokenizer on x again.
Returns
-------
list<tuple> : examples loaded into memory
Note: can use a lot of memory since entire file is loaded.
"""
objective2column = {col['objective']: col['column'] for col in y_columns}
column_indices = [objective2column.get(name, None) for name in objective_names]
empty_column = tuple(None for _ in objective_names)
if all(col_index is None for col_index in column_indices):
return []
with open(path, "rt") as fin:
lines = fin.read().splitlines()
return convert_lines_to_examples(lines,
ignore_value=ignore_value,
empty_column=empty_column,
x_column=x_column,
column_indices=column_indices,
comment=comment,
retokenize=retokenize)
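# Example usage (hypothetical file and objective names):
#
#     examples = load_tsv("train.tsv", x_column=0,
#                         y_columns=[{"objective": "POS", "column": 1}],
#                         objective_names=["POS", "NER"],
#                         comment="#", ignore_value="O", retokenize=False)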
class RandomizableDataset(object):
def set_rng(self, rng):
self.rng = rng
def set_randomize(self, randomize):
self.randomize = randomize
def set_ignore_y(self, ignore):
self.ignore_y = ignore
class TSVDataset(RandomizableDataset):
_fhandle = None
_fhandle_position = 0
_examples = None
_example_indices = None
_example_index = 0
_eof = False
ignore_y = False
def __init__(self, path, x_column, y_columns, objective_names, comment, ignore_value,
retokenize=False, chunksize=50000000, randomize=False, rng=None):
""""
Arguments:
----------
path: str, location of tsv file
x_column: int
y_columns: list<dict>, objectives in this file along with their column.
(e.g. `y_columns=[{"objective": "POS", "column": 2}, ...])`)
        objective_names: list<str>, names of all objectives the model
            expects; objectives absent from this file are filled with None.
        comment: str, prefix marking lines that should be skipped.
        ignore_value: str, label value that should be treated as missing.
        chunksize: int, number of bytes to read from the file at a time.
        rng: numpy RandomState used when shuffling examples.
        retokenize: bool, run the tokenizer on x again.
"""
self.path = path
self.randomize = randomize
self.x_column = x_column
self.y_columns = y_columns
self.objective_names = objective_names
self.comment = comment
self.ignore_value = ignore_value
self.retokenize = retokenize
self.chunksize = chunksize
if rng is None:
rng = np.random.RandomState(0)
self.rng = rng
# column picking setup:
objective2column = {col['objective']: col['column'] for col in y_columns}
self.column_indices = [objective2column.get(name, None) for name in objective_names]
self.empty_column = tuple(None for _ in objective_names)
if all(col_index is None for col_index in self.column_indices):
self.length = 0
else:
self._compute_length()
def _signature(self):
try:
file_sha1sum = subprocess.check_output(
["sha1sum", self.path], universal_newlines=True
).split(" ")[0]
except FileNotFoundError:
file_sha1sum = subprocess.check_output(
["shasum", self.path], universal_newlines=True
).split(" ")[0]
sorted_cols = list(
map(
str,
sorted(
[col for col in self.column_indices if col is not None]
)
)
)
return "-".join([file_sha1sum] + sorted_cols)
def _compute_length(self):
length_file = (
splitext(self.path)[0] +
"-length-" +
self._signature() + ".txt"
)
if exists(length_file):
with open(length_file, "rt") as fin:
total = int(fin.read())
else:
total = 0
while True:
total += self._count_examples()
if self._eof:
break
with open(length_file, "wt") as fout:
fout.write(str(total) + "\n")
self.length = total
def __len__(self):
return self.length
def close(self):
if self._fhandle is not None:
self._fhandle.close()
self._fhandle = None
self._fhandle_position = 0
self._eof = False
self._examples = None
self._example_indices = None
def __del__(self):
self.close()
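    # Reads chunksize-byte blocks until the buffer ends at a blank line
    # (b"\n\n", an example boundary) or EOF, so no example is ever split
    # across two reads.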
def _read_file_until_newline(self):
if self._fhandle is None:
self._fhandle = open(self.path, "rb")
if self._eof:
self._fhandle_position = 0
self._fhandle.seek(0)
self._eof = False
read_chunk = None
while True:
new_read_chunk = self._fhandle.read(self.chunksize)
if read_chunk is None:
read_chunk = new_read_chunk
else:
read_chunk += new_read_chunk
if len(new_read_chunk) < self.chunksize:
del new_read_chunk
self._fhandle_position += len(read_chunk)
self._eof = True
break
else:
del new_read_chunk
newline_pos = read_chunk.rfind(b"\n\n")
if newline_pos != -1:
# move to last line end position (so that we don't get
# half an example.)
self._fhandle.seek(self._fhandle_position + newline_pos + 2)
self._fhandle_position += newline_pos + 2
read_chunk = read_chunk[:newline_pos]
break
return read_chunk
def _count_examples(self):
read_chunk = self._read_file_until_newline()
return count_examples(
read_chunk.decode("utf-8").splitlines(),
ignore_value=self.ignore_value,
column_indices=self.column_indices,
comment=self.comment
)
def _load_examples(self):
read_chunk = self._read_file_until_newline()
if self._examples is not None:
del self._examples
self._examples = convert_lines_to_examples(
read_chunk.decode("utf-8").splitlines(),
ignore_value=self.ignore_value,
empty_column=self.empty_column,
x_column=self.x_column,
column_indices=self.column_indices,
comment=self.comment,
retokenize=self.retokenize
)
self._example_indices = np.arange(len(self._examples))
if self.randomize:
# access loaded data randomly:
self.rng.shuffle(self._example_indices)
self._example_index = 0
def __getitem__(self, index):
"""Retrieve the next example (index is ignored)"""
if index >= self.length:
raise StopIteration()
if self._example_indices is None or self._example_index == len(self._example_indices):
self._load_examples()
while len(self._examples) == 0:
self._load_examples()
if len(self._examples) > 0:
break
if self._eof:
raise StopIteration()
ex = self._examples[self._example_indices[self._example_index]]
self._example_index += 1
return ex
def set_randomize(self, randomize):
if randomize != self.randomize:
self.randomize = randomize
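# Example usage (hypothetical file; examples stream in file order and the
# index passed to __getitem__ only bounds iteration):
#
#     dataset = TSVDataset("train.tsv", x_column=0,
#                          y_columns=[{"objective": "POS", "column": 1}],
#                          objective_names=["POS"],
#                          comment="#", ignore_value="O")
#     for i in range(len(dataset)):
#         x, y = dataset[i]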
class OracleClassification(object):
def __init__(self, classes, classification, path):
self.classes = classes
self.classification = classification
self.path = path
self.contains_other = self.classes[-1] == "other"
def classify(self, index):
return self.classification[index]
def load_oracle_classification(path):
with open(join(path, "classes.txt"), "rt", encoding="UTF-8") as fin:
classes = fin.read().splitlines()
classification = np.load(join(path, "classification.npy"))
return OracleClassification(classes, classification, path)
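# A classification directory is expected to hold two files (as read above):
#     classes.txt          one class name per line ("other" last, if present)
#     classification.npy   integer array mapping an entity index to its class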
class ClassificationHandler(object):
def __init__(self, wikidata_path, classification_path):
self.classification_path = classification_path
_, self.name2index = load_wikidata_ids(wikidata_path, verbose=False)
self.classifiers = {}
def get_classifier(self, name):
if name not in self.classifiers:
self.classifiers[name] = load_oracle_classification(
join(self.classification_path, name)
)
return self.classifiers[name]
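# Example usage (hypothetical paths and classification name):
#
#     handler = ClassificationHandler("/data/wikidata", "/data/classifications")
#     classifier = handler.get_classifier("type_classification")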
class H5Dataset(RandomizableDataset):
handle_open = False
ignore_y = False
_max_generated_example = 0
_min_generated_example = 0
def __init__(self, path, x_column, y_columns, objective_names,
classifications, ignore_value, randomize=False, rng=None):
self.x_column = str(x_column)
self.y_columns = y_columns
self.ignore_value = ignore_value
self.objective_names = objective_names
self.randomize = randomize
if rng is None:
rng = np.random.RandomState(0)
self.rng = rng
self._classifications = classifications
self.handle = h5py.File(path, "r")
self.path = path
self.handle_open = True
self.length = len(self.handle[self.x_column])
self.chunksize = self.handle[self.x_column].chunks[0]
self._example_indices = None
objective2column = {
col['objective']: (
str(col['column']),
self._classifications.get_classifier(col['classification'])
) for col in y_columns
}
if self.ignore_value is not None:
for _, classifier in objective2column.values():
if self.ignore_value in classifier.classes:
classifier.classes[classifier.classes.index(self.ignore_value)] = None
self.column2col_indices = {}
for col_idx, name in enumerate(self.objective_names):
if name not in objective2column:
continue
column, classifier = objective2column[name]
if column not in self.column2col_indices:
self.column2col_indices[column] = [(classifier, col_idx)]
else:
self.column2col_indices[column].append((classifier, col_idx))
def close(self):
if self.handle_open:
self.handle.close()
self.handle_open = False
def __del__(self):
self.close()
def __len__(self):
return self.length
    def _build_examples(self, index):
        # each row stores its tokens joined by newlines; split them back out:
        x = [x_chunk.split("\n")
             for x_chunk in self.handle[self.x_column][index:index + self.chunksize]]
        # start fully unlabeled: one None slot per objective per token:
        y = [[[None for k in range(len(self.objective_names))]
              for j in range(len(x[i]))]
             for i in range(len(x))]
        if not self.ignore_y:
            for handle_column, col_content in self.column2col_indices.items():
                # map entity names to indices, keeping None for empty labels:
                col_ids = [[self._classifications.name2index[name] if name != "" else None
                            for name in y_chunk.split("\n")]
                           for y_chunk in self.handle[handle_column][index:index + self.chunksize]]
                for i in range(len(col_ids)):
                    for j, idx in enumerate(col_ids[i]):
                        if idx is not None:
                            # fill the slot of every objective fed by this column:
                            for classifier, k in col_content:
                                y[i][j][k] = classifier.classify(idx)
        return x, y
def set_randomize(self, randomize):
if self.randomize != randomize:
self.randomize = randomize
if self._max_generated_example != self._min_generated_example:
self.xorder = np.arange(self._min_generated_example, self._max_generated_example)
self.rng.shuffle(self.xorder)
def __getitem__(self, index):
if index >= len(self):
raise StopIteration()
if self.randomize:
if self._example_indices is None or index == 0:
self._example_indices = np.arange(0, len(self), self.chunksize)
self.rng.shuffle(self._example_indices)
# transformed index:
index = (self._example_indices[index // self.chunksize] + (index % self.chunksize)) % len(self)
if index < self._min_generated_example or index >= self._max_generated_example:
self.x, self.y = self._build_examples(index)
# store bounds of generated data:
self._min_generated_example = index
self._max_generated_example = index + len(self.x)
if self.randomize:
self.xorder = np.arange(self._min_generated_example, self._max_generated_example)
self.rng.shuffle(self.xorder)
if self.randomize:
index = self.xorder[index - self._min_generated_example]
return self.x[index - self._min_generated_example], self.y[index - self._min_generated_example]
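# Example usage (hypothetical HDF5 file with one string dataset per column,
# tokens joined by "\n" within each row, as assumed by _build_examples):
#
#     dataset = H5Dataset("train.h5", x_column=0,
#                         y_columns=[{"objective": "type", "column": 1,
#                                     "classification": "type_classification"}],
#                         objective_names=["type"],
#                         classifications=handler, ignore_value=None)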
class CombinedDataset(object):
_which_dataset = None
_dataset_counters = None
def set_rng(self, rng):
self.rng = rng
for dataset in self.datasets:
dataset.rng = rng
def set_randomize(self, randomize):
self.randomize = randomize
for dataset in self.datasets:
dataset.set_randomize(randomize)
def set_ignore_y(self, ignore):
for dataset in self.datasets:
dataset.set_ignore_y(ignore)
def close(self):
for dataset in self.datasets:
dataset.close()
def _build_which_dataset(self):
self._which_dataset = np.empty(self.length, dtype=np.int16)
self._dataset_counters = np.zeros(len(self.datasets), dtype=np.int64)
offset = 0
for index, dataset in enumerate(self.datasets):
            # each dataset contributes exactly len(dataset) slots, so every
            # example is visited once per epoch:
self._which_dataset[offset:offset + len(dataset)] = index
offset += len(dataset)
def __getitem__(self, index):
if index == 0:
if self.randomize:
                # visit the datasets in a random interleaved order:
self.rng.shuffle(self._which_dataset)
self._dataset_counters[:] = 0
which = self._which_dataset[index]
idx = self._dataset_counters[which]
self._dataset_counters[which] += 1
return self.datasets[which][idx]
def __init__(self, datasets, rng=None, randomize=False):
self.datasets = datasets
if rng is None:
rng = np.random.RandomState(0)
self.set_rng(rng)
self.set_randomize(randomize)
self.length = sum(len(dataset) for dataset in datasets)
self._build_which_dataset()
def __len__(self):
return self.length
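# Example usage: iterate over several datasets as one epoch, visiting every
# example exactly once (hypothetical `tsv_dataset` and `h5_dataset`):
#
#     combined = CombinedDataset([tsv_dataset, h5_dataset], randomize=True)
#     for i in range(len(combined)):
#         x, y = combined[i]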