in tools/taobao_prepare.py [0:0]
def _pad_or_truncate(seq, max_len):
    """Left-pad *seq* with 0 up to *max_len*, or keep only its last *max_len* entries."""
    if len(seq) <= max_len:
        return [0] * (max_len - len(seq)) + seq
    # Explicit start index (not seq[-max_len:]) so max_len == 0 yields [].
    return seq[len(seq) - max_len:]


def _format_sample(uid, target_item, target_cate, label, items, cates):
    """Render one sample as a tab-separated line terminated by a newline."""
    fields = [uid, target_item, target_cate, label,
              ",".join(map(str, items)), ",".join(map(str, cates))]
    return "\t".join(map(str, fields)) + "\n"


def gen_dataset(user_df, item_df, item_cnt, feature_size, dataset_pkl,
                max_len_item=None, batch_size=256):
    """Build tab-separated train/test sample lines from per-user histories.

    For each user, the last row of its history is the prediction target and
    the preceding rows form the behaviour sequence (assumes each group's rows
    are time-ordered -- confirm against the caller). With probability 1/2 the
    target is replaced by a random item different from the real last item
    (label 0); otherwise the real last item is kept (label 1).

    Users whose last touch time is strictly greater than the value at index
    ``int(0.7 * n_users)`` of the sorted last touch times go to the test list;
    all others go to the train list.

    Each line holds six tab-separated fields: uid, target item, target
    category, label, comma-joined item sequence, comma-joined category
    sequence, followed by a newline. Both sequences have exactly
    ``max_len_item`` entries (left-padded with 0 or truncated to the most
    recent entries) and end with the target item/category.

    Args:
        user_df: iterable of ``(uid, hist)`` pairs (and supports ``len``);
            ``hist`` needs ``'iid'``, ``'cid'`` and ``'time'`` columns.
            It is iterated twice.
        item_df: object whose ``get_group(item_id)['cid']`` gives the
            item's category.
        item_cnt: negatives are drawn uniformly from ``[0, item_cnt - 1]``.
        feature_size: unused -- it was only a behaviour-tag placeholder that
            never reached the output; kept for interface compatibility.
        dataset_pkl: unused; kept for interface compatibility.
        max_len_item: sequence length; defaults to module-level ``MAX_LEN_ITEM``.
        batch_size: each list is truncated to a multiple of this (default 256).

    Returns:
        ``(train_sample_list, test_sample_list)``; the train list is shuffled.

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """
    if max_len_item is None:
        max_len_item = MAX_LEN_ITEM
    if batch_size <= 0:
        raise ValueError("batch_size must be positive, got %r" % (batch_size,))

    train_sample_list = []
    test_sample_list = []

    # Pass 1: each user's last touch time, used to choose a global split time.
    print(len(user_df))
    user_last_touch_time = [hist['time'].iloc[-1] for _, hist in user_df]
    print("get user last touch time completed")
    if not user_last_touch_time:
        # No users: nothing to split (indexing the sorted list would fail).
        return train_sample_list, test_sample_list
    user_last_touch_time.sort()
    split_time = user_last_touch_time[int(len(user_last_touch_time) * 0.7)]

    # Pass 2: one sample per user.
    for cnt, (uid, hist) in enumerate(user_df, start=1):
        if cnt % 10000 == 0:
            print(cnt)
        item_hist = hist['iid'].tolist()
        cate_hist = hist['cid'].tolist()
        target_item = item_hist[-1]
        target_item_cate = cate_hist[-1]
        label = 1
        is_test = hist['time'].iloc[-1] > split_time

        # Negative sampling with probability 1/2: redraw until the target
        # differs from the real last item. Note this loops forever if no
        # other id exists in [0, item_cnt).
        if random.randint(0, 1) == 1:
            label = 0
            while target_item == item_hist[-1]:
                target_item = random.randint(0, item_cnt - 1)
            # NOTE(review): assumes every id in [0, item_cnt) has a group in
            # item_df; otherwise get_group raises KeyError.
            target_item_cate = item_df.get_group(target_item)['cid'].tolist()[0]

        items = _pad_or_truncate(item_hist[:-1] + [target_item], max_len_item)
        cates = _pad_or_truncate(cate_hist[:-1] + [target_item_cate], max_len_item)
        sample = _format_sample(uid, target_item, target_item_cate, label, items, cates)
        if is_test:
            test_sample_list.append(sample)
        else:
            train_sample_list.append(sample)

    print("length", len(train_sample_list))
    # Keep a whole number of batches. Floor division keeps the slice bound an
    # int; true division gives a float and slicing raises TypeError on Python 3.
    train_sample_list = train_sample_list[:len(train_sample_list) // batch_size * batch_size]
    test_sample_list = test_sample_list[:len(test_sample_list) // batch_size * batch_size]
    random.shuffle(train_sample_list)
    print("length", len(train_sample_list))
    return train_sample_list, test_sample_list