# mico/dataloader/query_doc_pair.py
from torch.utils.data import Dataset, DataLoader, ConcatDataset
import numpy as np
import logging
import os
import csv
import json
class QueryDocumentsPair:
"""An online dataloader which reads CSV with memory efficient caching.
Please use the `get_loaders` method to obtain the dataloader.
For advanced usage, you can obtain the train/val/test dataset by
`data.train_dataset`, `data.val_dataset`, `data.test_dataset` when the instance name is `data`.
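Example
-------
A minimal usage sketch; the folder paths below are placeholders for your own data.
>>> data = QueryDocumentsPair(train_folder_path='path/to/train_csvs',
...                           test_folder_path='path/to/test_csvs')
>>> train_loader, val_loader, test_loader = data.get_loaders(batch_size=32, num_workers=2)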
"""
def __init__(self, train_folder_path=None, test_folder_path=None, is_csv_header=True, val_ratio=0.1, is_get_all_info=False):
"""We load the training and test datasets here and split the training data into non-overlapping training and validation.
Parameters
----------
train_folder_path : string
The path to the folder containing CSV files for training and validation.
test_folder_path : string
The path to the folder containing CSV files for testing.
is_csv_header : bool
When reading CSV files as input, set `True` if they have headers, so we can skip the first line.
val_ratio : float
How much of the training data will be in the validation dataset.
The rest is put in the training dataset.
is_get_all_info : bool
Whether to also return the ID, click, and purchase fields for final evaluation.
"""
self._is_csv_header = is_csv_header
train_files = [os.path.join(train_folder_path, f) for f in sorted(os.listdir(train_folder_path)) if f.endswith("csv")]
train_dataset_list = []
val_dataset_list = []
for csv_file in train_files:
train_dataset = LazyTextDataset(csv_file, val_ratio=val_ratio, is_csv_header=self._is_csv_header, is_get_all_info=is_get_all_info)
train_dataset_list.append(train_dataset)
val_dataset = LazyTextDataset(csv_file, val_indices=train_dataset.val_indices, is_csv_header=self._is_csv_header, is_get_all_info=is_get_all_info)
val_dataset_list.append(val_dataset)
self.train_dataset = ConcatDataset(train_dataset_list)
logging.info('train_dataset sample size: %d', len(self.train_dataset))
self.val_dataset = ConcatDataset(val_dataset_list)
logging.info('val_dataset sample size: %d', len(self.val_dataset))
test_files = [os.path.join(test_folder_path, f) for f in sorted(os.listdir(test_folder_path)) if f.endswith("csv")]
test_dataset_list = [LazyTextDataset(f, is_csv_header=self._is_csv_header, is_get_all_info=is_get_all_info) for f in test_files]
self.test_dataset = ConcatDataset(test_dataset_list)
logging.info('test_dataset sample size: %d', len(self.test_dataset))
def get_loaders(self, batch_size, num_workers, is_shuffle_train=True, is_get_test=True, prefetch_factor=2, pin_memory=False):
"""Get train/val/test loaders for training and testing.
Note
----
Setting `pin_memory=True` or a larger `prefetch_factor` may increase the speed a little,
but costs much more memory.
Parameters
----------
batch_size : int
The batch size for each process on each GPU.
num_workers : int
Setting this larger than 0 spawns that many worker processes for loading text data.
It may increase the speed a little but costs much more memory.
is_shuffle_train : bool
Whether we shuffle the training data.
is_get_test : bool
If you want to make sure the test set is not used, set it to be `False`.
Returns
-------
(train_loader, val_loader, test_loader) :
The three dataloaders are PyTorch DataLoader objects.
`test_loader` is `None` when `is_get_test` is `False`.
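Example
-------
A sketch of iterating over the loaders; with the default `is_get_all_info=False`,
each batch collates into a list of queries and a list of documents.
>>> train_loader, val_loader, _ = data.get_loaders(batch_size=32, num_workers=2, is_get_test=False)
>>> for queries, docs in train_loader:
...     pass  # `queries` and `docs` are lists of strings, one per sample in the batch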
"""
train_loader = DataLoader(self.train_dataset, batch_size=batch_size,
num_workers=num_workers, pin_memory=pin_memory,
shuffle=is_shuffle_train,
prefetch_factor=prefetch_factor)
val_loader = DataLoader(self.val_dataset, batch_size=batch_size,
num_workers=num_workers, shuffle=True, pin_memory=pin_memory,
prefetch_factor=prefetch_factor)
test_loader = DataLoader(self.test_dataset, batch_size=batch_size,
num_workers=num_workers, shuffle=True, pin_memory=pin_memory,
prefetch_factor=prefetch_factor) if is_get_test else None
return train_loader, val_loader, test_loader
def generate_json_for_csv_line_offset(csv_filepath, save_per_line=1):
"""This is a function for caching the offset, so we can efficient find a random line in the CSV file.
We save the offset data using `json` format on disk and `np.array` format in memory.
The first number in the `json` file is the total line number of the CSV. The rest numbers are offsets.
On the disk, if a CSV file is named `example.csv`, the offset file will be saved in the same folder,
and named as `example.csv.offset_per_*.json` where `*` is the `save_per_line` parameter.
Parameters
----------
csv_filepath : string
The path to the CSV file which we are calculating the offset on.
save_per_line : int
If it is larger than 1, we only save an offset every `save_per_line` lines.
This increases the reading time a little, but saves a lot of memory for the offsets.
Returns
-------
offset_idx : int
The total number of lines in the CSV, which is also the sample size of this file.
offset_data : numpy.array
The byte offsets of every `save_per_line`-th line, starting from line 0.
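Example
-------
A sketch assuming `example.csv` exists; the first call builds and caches the offsets,
later calls simply reload the cached `example.csv.offset_per10.json`.
>>> n_lines, offsets = generate_json_for_csv_line_offset('example.csv', save_per_line=10)
>>> with open('example.csv', 'rb') as f:
...     f.seek(int(offsets[2]))  # jump directly to line 20 (0-indexed) without scanning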
"""
if not os.path.isfile(csv_filepath):
raise ValueError("CSV File %s does not exist" % csv_filepath)
offset_json_filepath = csv_filepath + '.offset_per' + str(save_per_line) + '.json'
if os.path.isfile(offset_json_filepath):
# Offset file already exists, just read it
with open(offset_json_filepath, 'r') as f:
offset_data = json.load(f)
return offset_data[0], np.array(offset_data[1:])
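# No cached offset file yet: scan the CSV once, recording the byte offset of
# every `save_per_line`-th line, and prepend the total line count before saving.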
offset_loc = 0
offset_idx = 0
offset_data = [offset_loc]
with open(csv_filepath, 'rb') as csv_file:
for line in csv_file:
offset_idx += 1
offset_loc += len(line)
if offset_idx % save_per_line == 0:
offset_data.append(offset_loc)
offset_data = [offset_idx] + offset_data
with open(offset_json_filepath, 'w+') as f:
json.dump(offset_data, f, indent='\t')
return offset_idx, np.array(offset_data[1:])
class LazyTextDataset(Dataset):
"""An dataset which loads CSV with memory efficient caching.
Here we use the offsets for each line in the CSV files to speed up random fetching a line.
For splitting training and validation data, we use block-wise split which saves some memory.
"""
def __init__(self, filename, val_ratio=0, val_indices=None, is_csv_header=True, is_get_all_info=False):
"""We generate or load the offset data of the CSV file (if it exists).
If we are generating training or validation data, we use blocks to split them and create a mapping
for reading the raw CSV rows as training or validation samples.
Parameters
----------
filename : string
The path to the CSV file.
It should be in the format [query, ID, doc, click, purchase].
val_ratio : float
How much of the training data will be in the validation dataset.
The rest is put in the training dataset.
If this is set, we know that we are generating the training dataset instead of validation or testing.
val_indices : numpy.array
Which blocks of the data are used as validation data (block indices produced by the training split).
If this is set, we know that we are generating the validation dataset instead of training or testing.
is_csv_header : bool
When reading CSV files as input, set `True` if they have headers, so we can skip the first line.
is_get_all_info : bool
Set this to `True` when running evaluation of the model on the test dataset to check its performance.
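Example
-------
A sketch of how a training/validation pair shares one block-wise split;
`part0.csv` is a placeholder file name.
>>> train_ds = LazyTextDataset('part0.csv', val_ratio=0.1)
>>> val_ds = LazyTextDataset('part0.csv', val_indices=train_ds.val_indices)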
"""
self._filename = filename
self.is_get_all_info = is_get_all_info
self.offset_save_per_line = 10 # a trade-off between the loading speed and the memory usage
self._total_size, self.offset_data = generate_json_for_csv_line_offset(filename, save_per_line=self.offset_save_per_line)
self._is_csv_header = int(is_csv_header)
with open(self._filename, 'r') as csv_file:
line = csv_file.readline().strip()
expected_header = 'query, ID, doc, click, purchase'
if line == expected_header:
if not is_csv_header:
logging.info('The first line of the file "%s" \t is \t "%s".', filename, line)
logging.info('It seems like this CSV file has a header. Please set --is_csv_header')
raise ValueError('File "%s" appears to have a CSV header, but is_csv_header is False.' % filename)
else:
if is_csv_header:
logging.info('The first line of the file "%s" \t is \t "%s".', filename, line)
logging.info('It seems like this CSV file does not have a header. Please do not set --is_csv_header')
raise ValueError('File "%s" does not appear to have a CSV header, but is_csv_header is True.' % filename)
self._total_size -= self._is_csv_header # the CSV header is not a sample
self.csv_reader_setting = {'delimiter':",", 'quotechar':'"', 'doublequote':False, 'escapechar':'\\', 'skipinitialspace':True}
self.train_val_subblock_size = 10
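# Block-wise split, a small worked example: with subblock size 10 and a 100-line file,
# there are 10 blocks; val_ratio=0.1 selects 1 block at random as validation, and
# idx_mapping keeps the indices of the remaining blocks for the training dataset.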
if val_ratio != 0: # train dataset
val_size = int(self._total_size * val_ratio)
np.random.seed(0)
total_block_num = self._total_size // self.train_val_subblock_size
val_block_num = val_size // self.train_val_subblock_size
remaining_num = self._total_size % self.train_val_subblock_size
self.val_indices = np.random.choice(range(total_block_num), size=val_block_num, replace=False)
val_indices_set = set(self.val_indices)
self.idx_mapping = []
for index_on_train_file in range(total_block_num + int(remaining_num != 0)):
if index_on_train_file not in val_indices_set:
self.idx_mapping.append(index_on_train_file)
self.idx_mapping = np.array(self.idx_mapping)
self._total_size = (len(self.idx_mapping) - int(remaining_num != 0)) * self.train_val_subblock_size + remaining_num
self.is_test_data = False
elif val_indices is not None: # val dataset
self.idx_mapping = val_indices
self._total_size = len(self.idx_mapping) * self.train_val_subblock_size
self.is_test_data = False
else: # test dataset
self.is_test_data = True
def __getitem__(self, idx):
"""We fetch the sample (from train/val/test dataset) from the CSV file.
If it is training or validation dataset, we use block-wise mapping to find the line number.
Parameters
----------
idx : int
Which sample we are fetching.
Returns
-------
parsed_list : list
If `is_get_all_info` is True, return a list of strings: [query, ID, document, click, purchase].
If `is_get_all_info` is False (the default), return a list of strings: [query, document].
"""
if not self.is_test_data:
block_idx = idx // self.train_val_subblock_size
inblock_idx = idx % self.train_val_subblock_size
block_idx = self.idx_mapping[block_idx]
idx = block_idx * self.train_val_subblock_size + inblock_idx
idx += self._is_csv_header # the CSV header is not a sample
offset_idx = idx // self.offset_save_per_line
offset = self.offset_data[offset_idx]
try:
with open(self._filename, 'r') as csv_file:
csv_file.seek(offset)
for _ in range(1 + idx % self.offset_save_per_line):
line = csv_file.readline()
line = line.replace('\0','')
csv_line = csv.reader([line], **self.csv_reader_setting)
parsed_list = next(csv_line) # [query, ID, document, click, purchase]
if self.is_get_all_info:
return parsed_list
else:
return [parsed_list[0], parsed_list[2]]
except Exception as e: # Report which file and line failed for quick inspection.
error_message = "Something went wrong when reading CSV samples.\n Details (filename, line index): {}, \t {}".format(self._filename, idx)
raise IOError(error_message) from e
def __len__(self):
"""The total number of samples in the dataset.
"""
return self._total_size