easy_rec/python/input/load_parquet.py
import logging
import multiprocessing
import queue

import numpy as np
import pandas as pd


def start_data_proc(task_index,
                    task_num,
                    num_proc,
                    file_que,
                    data_que,
                    proc_start_que,
                    proc_stop_que,
                    batch_size,
                    label_fields,
                    sparse_fea_names,
                    dense_fea_names,
                    dense_fea_cfgs,
                    reserve_fields,
                    drop_remainder,
                    need_pack=True):
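  """Start the parquet data loading worker processes.

  Spawns num_proc daemon processes (spawn context) running load_data_proc;
  each worker takes parquet file paths from file_que and puts batches of
  batch_size samples into data_que.

  Returns:
    The list of started multiprocessing.Process objects.
  """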
  mp_ctxt = multiprocessing.get_context('spawn')
  proc_arr = []
  for proc_id in range(num_proc):
    proc = mp_ctxt.Process(
        target=load_data_proc,
        args=(proc_id, file_que, data_que, proc_start_que, proc_stop_que,
              batch_size, label_fields, sparse_fea_names, dense_fea_names,
              dense_fea_cfgs, reserve_fields, drop_remainder, task_index,
              task_num, need_pack),
        name='task_%d_data_proc_%d' % (task_index, proc_id))
    proc.daemon = True
    proc.start()
    proc_arr.append(proc)
  return proc_arr


def _should_stop(proc_stop_que):
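  """Check whether a stop signal has been posted on proc_stop_que.

  Returns True on a stop signal or when the queue is already closed
  (ValueError / AssertionError), False when the queue is simply empty.
  """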
  try:
    proc_stop_que.get(block=False)
    logging.info('data_proc stop signal received')
    proc_stop_que.close()
    return True
  except queue.Empty:
    return False
  except ValueError:
    return True
  except AssertionError:
    return True


def _add_to_que(data_dict, data_que, proc_stop_que):
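  """Put one batch onto data_que, retrying while the queue is full.

  Returns True once the batch is enqueued, False if a stop signal arrives
  while waiting or if data_que has already been closed.
  """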
  while True:
    try:
      data_que.put(data_dict, timeout=5)
      return True
    except queue.Full:
      logging.warning('data_que is full')
      if _should_stop(proc_stop_que):
        return False
    except ValueError:
      logging.warning('data_que is closed')
      return False
    except AssertionError:
      logging.warning('data_que is closed')
      return False


def _get_one_file(file_que, proc_stop_que):
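  """Fetch the next parquet file path from file_que.

  Returns None when a stop signal is received or a None end-of-input
  sentinel is read from the queue.
  """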
  while True:
    try:
      input_file = file_que.get(timeout=1)
      return input_file
    except queue.Empty:
      pass
    # check the stop signal so the worker does not wait forever
    # on an empty file_que after a stop has been requested
    if _should_stop(proc_stop_que):
      return None


def _pack_sparse_feas(data_dict, sparse_fea_names):
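  """Pack all sparse features into one entry.

  Concatenates the per-feature (lengths, values) pairs and replaces them
  with a single data_dict['sparse_fea'] = (lens, vals) tuple.
  """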
  fea_val_arr = []
  fea_len_arr = []
  for fea_name in sparse_fea_names:
    fea_len_arr.append(data_dict[fea_name][0])
    fea_val_arr.append(data_dict[fea_name][1])
    del data_dict[fea_name]
  fea_lens = np.concatenate(fea_len_arr, axis=0)
  fea_vals = np.concatenate(fea_val_arr, axis=0)
  data_dict['sparse_fea'] = (fea_lens, fea_vals)


def _pack_dense_feas(data_dict, dense_fea_names, dense_fea_cfgs):
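  """Pack all dense features into one matrix.

  Reshapes each feature to [batch_size, raw_input_dim] and concatenates
  them along axis 1 into data_dict['dense_fea'].
  """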
  fea_val_arr = []
  for fea_name, fea_cfg in zip(dense_fea_names, dense_fea_cfgs):
    fea_val_arr.append(data_dict[fea_name].reshape(
        [-1, fea_cfg.raw_input_dim]))
    del data_dict[fea_name]
  fea_vals = np.concatenate(fea_val_arr, axis=1)
  data_dict['dense_fea'] = fea_vals


def _reshape_dense_feas(data_dict, dense_fea_names, dense_fea_cfgs):
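  """Reshape each dense feature to [batch_size, raw_input_dim] in place."""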
  for fea_name, fea_cfg in zip(dense_fea_names, dense_fea_cfgs):
    data_dict[fea_name] = data_dict[fea_name].reshape(
        [-1, fea_cfg.raw_input_dim])


def _load_dense(input_data, field_names, sid, eid, dense_dict):
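  """Copy rows [sid, eid) of the given columns into dense_dict as numpy arrays.

  For columns whose cells are arrays, only the first element of each cell is
  kept.
  """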
  for k in field_names:
    if isinstance(input_data[k][0], np.ndarray):
      np_dtype = type(input_data[k][sid][0])
      dense_dict[k] = np.array([x[0] for x in input_data[k][sid:eid]],
                               dtype=np_dtype)
    else:
      dense_dict[k] = input_data[k][sid:eid].to_numpy()


def _load_and_pad_dense(input_data, field_names, sid, dense_dict,
                        part_dense_dict, part_dense_dict_n, batch_size):
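  """Load the tail rows [sid:] of the given columns and merge them with the
  remainder carried over in part_dense_dict.

  If the merged rows reach batch_size, a full batch is placed in dense_dict
  and the surplus goes into part_dense_dict_n; otherwise everything is kept
  in part_dense_dict_n to be carried over to the next file.
  """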
  for k in field_names:
    if isinstance(input_data[k][0], np.ndarray):
      np_dtype = type(input_data[k][sid][0])
      tmp_lbls = np.array([x[0] for x in input_data[k][sid:]], dtype=np_dtype)
    else:
      tmp_lbls = input_data[k][sid:].to_numpy()
    if part_dense_dict is not None and k in part_dense_dict:
      tmp_lbls = np.concatenate([part_dense_dict[k], tmp_lbls], axis=0)
      if len(tmp_lbls) > batch_size:
        dense_dict[k] = tmp_lbls[:batch_size]
        part_dense_dict_n[k] = tmp_lbls[batch_size:]
      elif len(tmp_lbls) == batch_size:
        dense_dict[k] = tmp_lbls
      else:
        part_dense_dict_n[k] = tmp_lbls
    else:
      part_dense_dict_n[k] = tmp_lbls


def load_data_proc(proc_id, file_que, data_que, proc_start_que, proc_stop_que,
                   batch_size, label_fields, sparse_fea_names, dense_fea_names,
                   dense_fea_cfgs, reserve_fields, drop_remainder, task_index,
                   task_num, need_pack):
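  """Worker loop: read parquet files from file_que and emit batches to data_que.

  The worker first waits for a start signal on proc_start_que. Full batches of
  batch_size samples are emitted per file; leftover rows are carried across
  files in part_data_dict and emitted as a final smaller batch unless
  drop_remainder is set. A None sentinel is put on data_que when the worker
  finishes normally.
  """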
  logging.info('data proc %d start, proc_start_que=%s' %
               (proc_id, proc_start_que.qsize()))
  proc_start_que.get()
  effective_fields = sparse_fea_names + dense_fea_names
  all_fields = effective_fields
  if label_fields is not None:
    all_fields = all_fields + label_fields
  if reserve_fields is not None:
    for tmp in reserve_fields:
      if tmp not in all_fields:
        all_fields.append(tmp)
  logging.info('data proc %d start, file_que.qsize=%d' %
               (proc_id, file_que.qsize()))
  num_files = 0
  part_data_dict = {}
  is_good = True
  total_batch_cnt = 0
  total_sample_cnt = 0
  while is_good:
    if _should_stop(proc_stop_que):
      is_good = False
      break
    input_file = _get_one_file(file_que, proc_stop_que)
    if input_file is None:
      break
    num_files += 1
    input_data = pd.read_parquet(input_file, columns=all_fields)
    data_len = len(input_data[all_fields[0]])
    total_sample_cnt += data_len
    batch_num = int(data_len / batch_size)
    res_num = data_len % batch_size
    sid = 0
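    # emit as many full batches of batch_size samples as this file provides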
    for batch_id in range(batch_num):
      eid = sid + batch_size
      data_dict = {}
      if label_fields is not None and len(label_fields) > 0:
        _load_dense(input_data, label_fields, sid, eid, data_dict)
      if reserve_fields is not None and len(reserve_fields) > 0:
        data_dict['reserve'] = {}
        _load_dense(input_data, reserve_fields, sid, eid, data_dict['reserve'])
      if len(sparse_fea_names) > 0:
        for k in sparse_fea_names:
          val = input_data[k][sid:eid]
          if isinstance(input_data[k][sid], np.ndarray):
            all_lens = np.array([len(x) for x in val], dtype=np.int32)
            all_vals = np.concatenate(val.to_numpy())
          else:
            all_lens = np.ones([len(val)], dtype=np.int32)
            all_vals = val.to_numpy()
          assert np.sum(all_lens) == len(
              all_vals), 'len(all_vals)=%d np.sum(all_lens)=%d' % (
                  len(all_vals), np.sum(all_lens))
          data_dict[k] = (all_lens, all_vals)
      if len(dense_fea_names) > 0:
        _load_dense(input_data, dense_fea_names, sid, eid, data_dict)
      if need_pack:
        if len(sparse_fea_names) > 0:
          _pack_sparse_feas(data_dict, sparse_fea_names)
        if len(dense_fea_names) > 0:
          _pack_dense_feas(data_dict, dense_fea_names, dense_fea_cfgs)
      else:
        if len(dense_fea_names) > 0:
          _reshape_dense_feas(data_dict, dense_fea_names, dense_fea_cfgs)
      # logging.info('task_index=%d sid=%d eid=%d total_len=%d' % (task_index,
      #     sid, eid, len(data_dict['sparse_fea'][1])))
      if not _add_to_que(data_dict, data_que, proc_stop_que):
        logging.info('add to que failed')
        is_good = False
        break
      total_batch_cnt += 1
      sid += batch_size
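    # merge this file's leftover rows with the remainder carried over from
    # previous files, emitting a full batch once enough rows accumulate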
    if res_num > 0 and is_good:
      data_dict = {}
      part_data_dict_n = {}
      if label_fields is not None and len(label_fields) > 0:
        _load_and_pad_dense(input_data, label_fields, sid, data_dict,
                            part_data_dict, part_data_dict_n, batch_size)
      if reserve_fields is not None and len(reserve_fields) > 0:
        data_dict['reserve'] = {}
        part_data_dict_n['reserve'] = {}
        # part_data_dict has no 'reserve' entry before the first remainder
        # is carried over, so use .get to avoid a KeyError
        _load_and_pad_dense(input_data, reserve_fields, sid,
                            data_dict['reserve'],
                            part_data_dict.get('reserve', None),
                            part_data_dict_n['reserve'], batch_size)
      if len(dense_fea_names) > 0:
        _load_and_pad_dense(input_data, dense_fea_names, sid, data_dict,
                            part_data_dict, part_data_dict_n, batch_size)
      if len(sparse_fea_names) > 0:
        for k in sparse_fea_names:
          val = input_data[k][sid:]
          if isinstance(input_data[k][sid], np.ndarray):
            all_lens = np.array([len(x) for x in val], dtype=np.int32)
            all_vals = np.concatenate(val.to_numpy())
          else:
            all_lens = np.ones([len(val)], dtype=np.int32)
            all_vals = val.to_numpy()
          if part_data_dict is not None and k in part_data_dict:
            tmp_lens = np.concatenate([part_data_dict[k][0], all_lens], axis=0)
            tmp_vals = np.concatenate([part_data_dict[k][1], all_vals], axis=0)
            if len(tmp_lens) > batch_size:
              tmp_res_lens = tmp_lens[batch_size:]
              tmp_lens = tmp_lens[:batch_size]
              tmp_num_elems = np.sum(tmp_lens)
              tmp_res_vals = tmp_vals[tmp_num_elems:]
              tmp_vals = tmp_vals[:tmp_num_elems]
              part_data_dict_n[k] = (tmp_res_lens, tmp_res_vals)
              data_dict[k] = (tmp_lens, tmp_vals)
            elif len(tmp_lens) == batch_size:
              data_dict[k] = (tmp_lens, tmp_vals)
            else:
              part_data_dict_n[k] = (tmp_lens, tmp_vals)
          else:
            part_data_dict_n[k] = (all_lens, all_vals)
      if effective_fields[0] in data_dict:
        if need_pack:
          if len(sparse_fea_names) > 0:
            _pack_sparse_feas(data_dict, sparse_fea_names)
          if len(dense_fea_names) > 0:
            _pack_dense_feas(data_dict, dense_fea_names, dense_fea_cfgs)
        else:
          if len(dense_fea_names) > 0:
            _reshape_dense_feas(data_dict, dense_fea_names, dense_fea_cfgs)
        if not _add_to_que(data_dict, data_que, proc_stop_que):
          logging.info('add to que failed')
          is_good = False
          break
        total_batch_cnt += 1
      part_data_dict = part_data_dict_n
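  # after all files are consumed, flush the remaining partial batch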
  if len(part_data_dict) > 0 and is_good:
    batch_len = len(part_data_dict[effective_fields[0]][0])
    if not drop_remainder:
      if need_pack:
        if len(sparse_fea_names) > 0:
          _pack_sparse_feas(part_data_dict, sparse_fea_names)
        if len(dense_fea_names) > 0:
          _pack_dense_feas(part_data_dict, dense_fea_names, dense_fea_cfgs)
      else:
        if len(dense_fea_names) > 0:
          _reshape_dense_feas(part_data_dict, dense_fea_names, dense_fea_cfgs)
      logging.info('remainder batch: %s sample_num=%d' %
                   (','.join(part_data_dict.keys()), batch_len))
      _add_to_que(part_data_dict, data_que, proc_stop_que)
      total_batch_cnt += 1
    else:
      logging.warning('drop remain %d samples as drop_remainder is set' %
                      batch_len)
  if is_good:
    is_good = _add_to_que(None, data_que, proc_stop_que)
  logging.info(
      'data_proc_id[%d]: is_good = %s, total_batch_cnt=%d, total_sample_cnt=%d'
      % (proc_id, is_good, total_batch_cnt, total_sample_cnt))
  data_que.close(wait_send_finish=is_good)
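  # on abnormal exit, drain file_que until the None sentinel before closing it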
  while not is_good:
    try:
      if file_que.get(timeout=1) is None:
        break
    except queue.Empty:
      pass
  file_que.close()
  logging.info('data proc %d done, file_num=%d' % (proc_id, num_files))