in mico/dataloader/query_doc_pair.py [0:0]
def __init__(self, train_folder_path=None, test_folder_path=None, is_csv_header=True, val_ratio=0.1, is_get_all_info=False):
    """We load the training and test datasets here and split the training data into non-overlapping training and validation.

    Each ``.csv`` file directly inside ``train_folder_path`` is loaded twice:
    once as a training ``LazyTextDataset`` (which chooses the validation rows
    from ``val_ratio``), and once as a validation ``LazyTextDataset`` restricted
    to that training dataset's ``val_indices``. Each ``.csv`` file directly
    inside ``test_folder_path`` becomes one test ``LazyTextDataset``. Per-file
    datasets are concatenated (``ConcatDataset``) in sorted filename order and
    stored as ``self.train_dataset``, ``self.val_dataset`` and
    ``self.test_dataset``; their sizes are logged at INFO level.

    Parameters
    ----------
    train_folder_path : string
        The path to the folder containing CSV files for training and validation.
        Must be provided despite the ``None`` default.
    test_folder_path : string
        The path to the folder containing CSV files for testing.
        Must be provided despite the ``None`` default.
    is_csv_header : bool
        When reading CSV files as input, set `True` if they have headers, so we can skip the first line.
    val_ratio : float
        How much of the training data will be in the validation dataset.
        The rest will be put in the training dataset.
    is_get_all_info : bool
        Get the id, click, and purchase for final evaluation.
    """

    def _csv_paths(folder):
        # Return full paths of the regular ``*.csv`` files directly inside
        # ``folder``, sorted by name. os.listdir order is arbitrary, so sorting
        # keeps the concatenation order deterministic. Requiring the ".csv"
        # extension and a regular file stops names such as "notes_csv" or a
        # sub-directory called "x.csv" from being passed to LazyTextDataset.
        candidates = (os.path.join(folder, name) for name in sorted(os.listdir(folder)) if name.endswith(".csv"))
        return [path for path in candidates if os.path.isfile(path)]

    self._is_csv_header = is_csv_header

    train_dataset_list = []
    val_dataset_list = []
    for csv_file in _csv_paths(train_folder_path):
        train_dataset = LazyTextDataset(csv_file, val_ratio=val_ratio, is_csv_header=self._is_csv_header, is_get_all_info=is_get_all_info)
        train_dataset_list.append(train_dataset)
        # Same file again, restricted to the indices the training dataset
        # reports as validation rows, so both splits come from this file.
        val_dataset = LazyTextDataset(csv_file, val_indices=train_dataset.val_indices, is_csv_header=self._is_csv_header, is_get_all_info=is_get_all_info)
        val_dataset_list.append(val_dataset)
    self.train_dataset = ConcatDataset(train_dataset_list)
    logging.info('train_dataset sample size: %d', len(self.train_dataset))
    self.val_dataset = ConcatDataset(val_dataset_list)
    logging.info('val_dataset sample size: %d', len(self.val_dataset))

    test_dataset_list = [
        LazyTextDataset(csv_file, is_csv_header=self._is_csv_header, is_get_all_info=is_get_all_info)
        for csv_file in _csv_paths(test_folder_path)
    ]
    self.test_dataset = ConcatDataset(test_dataset_list)
    logging.info('test_dataset sample size: %d', len(self.test_dataset))