easy_rec/python/input/criteo_binary_reader.py
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
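"""Reader for the Criteo dataset stored as raw binary files.

The module can also be run directly as a small read-throughput benchmark
(see the __main__ block at the bottom of the file).
"""
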
import argparse
import concurrent.futures
import glob
import logging
import os
import queue
import time

import numpy as np


class BinaryDataset:
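  """Batch iterator over pre-processed Criteo binary shards.

  Each shard consists of three files with one record per sample:
  a label file (int32), a dense file (13 x float32) and a category
  file (26 x uint32). The samples can be partitioned across multiple
  workers via global_rank / global_size.
  """
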
def __init__(
self,
label_bins,
dense_bins,
category_bins,
batch_size=1,
drop_last=False,
prefetch=1,
global_rank=0,
global_size=1,
):
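    """Create the reader for one worker.

    Args:
      label_bins: list of label .bin files, one int32 label per sample.
      dense_bins: list of dense .bin files, 13 float32 values per sample.
      category_bins: list of category .bin files, 26 uint32 ids per sample.
      batch_size: number of samples per batch.
      drop_last: if True, drop the trailing partial batch.
      prefetch: number of batches to read ahead on a thread pool.
      global_rank: index of this worker in [0, global_size).
      global_size: total number of workers sharing the data.
    """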
total_sample_num = 0
self._sample_num_arr = []
for label_bin in label_bins:
sample_num = os.path.getsize(label_bin) // 4
total_sample_num += sample_num
self._sample_num_arr.append(sample_num)
    logging.info('total number of samples = %d' % total_sample_num)
self._total_sample_num = total_sample_num
self._batch_size = batch_size
self._compute_global_start_pos(total_sample_num, batch_size, global_rank,
global_size, drop_last)
self._label_file_arr = [None for _ in self._sample_num_arr]
self._dense_file_arr = [None for _ in self._sample_num_arr]
self._category_file_arr = [None for _ in self._sample_num_arr]
for tmp_file_id in range(self._start_file_id, self._end_file_id + 1):
self._label_file_arr[tmp_file_id] = os.open(label_bins[tmp_file_id],
os.O_RDONLY)
self._dense_file_arr[tmp_file_id] = os.open(dense_bins[tmp_file_id],
os.O_RDONLY)
self._category_file_arr[tmp_file_id] = os.open(category_bins[tmp_file_id],
os.O_RDONLY)
self._prefetch = min(prefetch, self._num_entries)
self._prefetch_queue = queue.Queue()
self._executor = concurrent.futures.ThreadPoolExecutor(
max_workers=self._prefetch)
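    # keep a reference to os.close so that __del__ can still close the file
    # descriptors during interpreter shutdown, when module globals may
    # already have been torn down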
self._os_close_func = os.close

  def _compute_global_start_pos(self, total_sample_num, batch_size, global_rank,
global_size, drop_last):
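    """Compute this worker's sample range and per-batch start positions.

    Sets _num_samples, _num_entries and _last_batch_size, locates the
    (_start_file_id, _start_file_pos) of this worker's first sample, and
    fills _start_pos_arr, a [num_entries, 2] array holding the
    (file_id, in-file offset) of the first sample of every batch.
    """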
    # ensure all workers read the same number of samples: when the total is
    # not evenly divisible by global_size, ranks beyond the remainder start
    # one sample early, so every rank still gets avg_sample_num + 1 samples
    # (neighboring ranks then overlap by a single sample)
avg_sample_num = (total_sample_num // global_size)
res_num = (total_sample_num % global_size)
self._num_samples = avg_sample_num
if res_num > 0:
self._num_samples += 1
if global_rank < res_num:
global_start_pos = (avg_sample_num + 1) * global_rank
else:
global_start_pos = avg_sample_num * global_rank + res_num - 1
else:
global_start_pos = avg_sample_num * global_rank
# global_end_pos = global_start_pos + self._num_samples
self._num_entries = self._num_samples // batch_size
self._last_batch_size = batch_size
if not drop_last and (self._num_samples % batch_size != 0):
self._num_entries += 1
self._last_batch_size = self._num_samples % batch_size
logging.info('num_batches = %d num_samples = %d' %
(self._num_entries, self._num_samples))
    # locate the file containing global_start_pos; curr_pos accumulates the
    # sample counts of all files before start_file_id
    start_file_id = 0
    curr_pos = 0
    while curr_pos + self._sample_num_arr[start_file_id] <= global_start_pos:
      curr_pos += self._sample_num_arr[start_file_id]
      start_file_id += 1
self._start_file_id = start_file_id
self._start_file_pos = global_start_pos - curr_pos
logging.info('start_file_id = %d start_file_pos = %d' %
(start_file_id, self._start_file_pos))
    # precompute the (file_id, in-file offset) of the first sample of each batch
self._start_pos_arr = np.zeros([self._num_entries, 2], dtype=np.uint32)
batch_id = 0
tmp_start_pos = self._start_file_pos
while batch_id < self._num_entries:
self._start_pos_arr[batch_id] = (start_file_id, tmp_start_pos)
batch_id += 1
# the last batch
if batch_id == self._num_entries:
tmp_start_pos += self._last_batch_size
while start_file_id < len(
self._sample_num_arr
) and tmp_start_pos > self._sample_num_arr[start_file_id]:
tmp_start_pos -= self._sample_num_arr[start_file_id]
start_file_id += 1
else:
tmp_start_pos += batch_size
while start_file_id < len(
self._sample_num_arr
) and tmp_start_pos >= self._sample_num_arr[start_file_id]:
tmp_start_pos -= self._sample_num_arr[start_file_id]
start_file_id += 1
self._end_file_id = start_file_id
self._end_file_pos = tmp_start_pos
logging.info('end_file_id = %d end_file_pos = %d' %
(self._end_file_id, self._end_file_pos))

  def __del__(self):
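    """Close the file descriptors opened in __init__."""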
for f in self._label_file_arr:
if f is not None:
self._os_close_func(f)
for f in self._dense_file_arr:
if f is not None:
self._os_close_func(f)
for f in self._category_file_arr:
if f is not None:
self._os_close_func(f)

  def __len__(self):
return self._num_entries

  def __getitem__(self, idx):
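    """Return batch idx as a (dense, category, label) tuple of numpy arrays.

    With prefetch > 1, up to `prefetch` batches are read ahead on the thread
    pool; batches are expected to be requested in order, as during iteration.
    """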
if idx >= self._num_entries:
raise IndexError()
if self._prefetch <= 1:
return self._get(idx)
if idx == 0:
for i in range(self._prefetch):
self._prefetch_queue.put(self._executor.submit(self._get, (i)))
if idx < (self._num_entries - self._prefetch):
self._prefetch_queue.put(
self._executor.submit(self._get, (idx + self._prefetch)))
return self._prefetch_queue.get().result()

  def _get(self, idx):
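    """Read one batch directly from the binary files.

    Per sample the files hold 4 bytes (1 int32 label), 52 bytes
    (13 float32 dense values) and 104 bytes (26 uint32 category ids).
    A batch that crosses a file boundary is assembled by concatenating
    the per-file reads.
    """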
    curr_file_id = self._start_pos_arr[idx][0]
    start_read_pos = self._start_pos_arr[idx][1]
    # the final batch may hold fewer samples when drop_last is False
    if idx + 1 == self._num_entries:
      batch_size = self._last_batch_size
    else:
      batch_size = self._batch_size
    end_read_pos = start_read_pos + batch_size
    total_read_num = 0
    label_read_arr = []
    dense_read_arr = []
    cate_read_arr = []
    # a batch may span several files: read file by file until enough samples
    # have been collected
    while total_read_num < batch_size and curr_file_id < len(
        self._sample_num_arr):
tmp_read_num = min(end_read_pos,
self._sample_num_arr[curr_file_id]) - start_read_pos
label_raw_data = os.pread(self._label_file_arr[curr_file_id],
4 * tmp_read_num, start_read_pos * 4)
tmp_lbl_np = np.frombuffer(
label_raw_data, dtype=np.int32).reshape([tmp_read_num, 1])
label_read_arr.append(tmp_lbl_np)
dense_raw_data = os.pread(self._dense_file_arr[curr_file_id],
52 * tmp_read_num, start_read_pos * 52)
part_dense_np = np.frombuffer(
dense_raw_data, dtype=np.float32).reshape([tmp_read_num, 13])
# part_dense_np = np.log(part_dense_np + 3, dtype=np.float32)
dense_read_arr.append(part_dense_np)
category_raw_data = os.pread(self._category_file_arr[curr_file_id],
104 * tmp_read_num, start_read_pos * 104)
part_cate_np = np.frombuffer(
category_raw_data, dtype=np.uint32).reshape([tmp_read_num, 26])
cate_read_arr.append(part_cate_np)
      # move on to the next file; offsets become relative to its start
      end_read_pos -= self._sample_num_arr[curr_file_id]
      curr_file_id += 1
      start_read_pos = 0
      total_read_num += tmp_read_num
if len(label_read_arr) == 1:
label = label_read_arr[0]
else:
label = np.concatenate(label_read_arr, axis=0)
if len(cate_read_arr) == 1:
category = cate_read_arr[0]
else:
category = np.concatenate(cate_read_arr, axis=0)
if len(dense_read_arr) == 1:
dense = dense_read_arr[0]
else:
dense = np.concatenate(dense_read_arr, axis=0)
return dense, category, label


if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--batch_size', type=int, default=1024, help='batch_size')
parser.add_argument(
'--dataset_dir', type=str, default='./', help='dataset_dir')
parser.add_argument('--task_num', type=int, default=1, help='task number')
parser.add_argument('--task_index', type=int, default=0, help='task index')
parser.add_argument(
'--prefetch_size', type=int, default=10, help='prefetch size')
  args = parser.parse_args()

  # configure the root logger at INFO level so the timing logs below are
  # actually emitted (the default WARNING level would swallow them)
  logging.basicConfig(level=logging.INFO)

  batch_size = args.batch_size
  dataset_dir = args.dataset_dir
  logging.info('batch_size = %d' % batch_size)
  logging.info('dataset_dir = %s' % dataset_dir)
label_files = glob.glob(os.path.join(dataset_dir, '*_label.bin'))
dense_files = glob.glob(os.path.join(dataset_dir, '*_dense.bin'))
category_files = glob.glob(os.path.join(dataset_dir, '*_category.bin'))
label_files.sort()
dense_files.sort()
category_files.sort()
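  # sorting keeps the three lists aligned, assuming each shard names its
  # *_label.bin / *_dense.bin / *_category.bin files with a common prefix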
test_dataset = BinaryDataset(
label_files,
dense_files,
category_files,
batch_size=batch_size,
drop_last=False,
prefetch=args.prefetch_size,
global_rank=args.task_index,
global_size=args.task_num,
)
for step, (dense, category, labels) in enumerate(test_dataset):
# if (step % 100 == 0):
# print(step, dense.shape, category.shape, labels.shape)
if step == 0:
logging.info('warmup over!')
start_time = time.time()
if step == 1000:
logging.info('1000 steps time = %.3f' % (time.time() - start_time))
logging.info('total_steps = %d total_time = %.3f' %
(step + 1, time.time() - start_time))
logging.info(
'final step[%d] dense_shape=%s category_shape=%s labels_shape=%s' %
(step, dense.shape, category.shape, labels.shape))