in easy_rec/python/core/sampler.py [0:0]
def build(data_config):
  """Create the sampler selected by the ``sampler`` oneof of ``data_config``.

  Args:
    data_config: data config message. This function reads its ``sampler``
      oneof, ``input_fields`` and ``batch_size``.

  Returns:
    None when no member of the ``sampler`` oneof is set; otherwise the object
    returned by the matching sampler class's ``instance(...)`` call.

  Raises:
    ValueError: if the selected sampler type is not handled here.
    KeyError: if a name in ``sampler_config.attr_fields`` does not match the
      ``input_name`` of any entry in ``data_config.input_fields``.
  """
  import logging

  # WhichOneof returns None when no member of the oneof is set.
  sampler_type = data_config.WhichOneof('sampler')
  if sampler_type is None:
    return None
  logging.info('sampler_type = %s', sampler_type)

  supported_types = ('negative_sampler', 'negative_sampler_in_memory',
                     'negative_sampler_v2', 'hard_negative_sampler',
                     'hard_negative_sampler_v2')
  # Reject unsupported types before any side effect (e.g. the DS delimiter).
  if sampler_type not in supported_types:
    raise ValueError('Unknown sampler %s' % sampler_type)

  sampler_config = getattr(data_config, sampler_type)
  if ds_util.is_on_ds():
    gl.set_field_delimiter(sampler_config.field_delimiter)

  # Every supported sampler references its attribute columns by input_name.
  input_fields = {f.input_name: f for f in data_config.input_fields}
  missing = [
      name for name in sampler_config.attr_fields if name not in input_fields
  ]
  if missing:
    raise KeyError('sampler attr_fields not found in input_fields: %s' %
                   ', '.join(missing))
  attr_fields = [input_fields[name] for name in sampler_config.attr_fields]

  # Keyword arguments passed identically to every sampler class.
  common_kwargs = dict(
      fields=attr_fields,
      num_sample=sampler_config.num_sample,
      batch_size=data_config.batch_size,
      attr_delimiter=sampler_config.attr_delimiter,
      num_eval_sample=sampler_config.num_eval_sample)

  if sampler_type == 'negative_sampler':
    input_path = process_multi_file_input_path(sampler_config.input_path)
    return NegativeSampler.instance(data_path=input_path, **common_kwargs)

  if sampler_type == 'negative_sampler_in_memory':
    input_path = process_multi_file_input_path(sampler_config.input_path)
    return NegativeSamplerInMemory.instance(
        data_path=input_path, **common_kwargs)

  # The remaining samplers all read separate user and item node tables.
  user_input_path = process_multi_file_input_path(
      sampler_config.user_input_path)
  item_input_path = process_multi_file_input_path(
      sampler_config.item_input_path)

  if sampler_type == 'negative_sampler_v2':
    pos_edge_input_path = process_multi_file_input_path(
        sampler_config.pos_edge_input_path)
    return NegativeSamplerV2.instance(
        user_data_path=user_input_path,
        item_data_path=item_input_path,
        edge_data_path=pos_edge_input_path,
        **common_kwargs)

  if sampler_type == 'hard_negative_sampler':
    hard_neg_edge_input_path = process_multi_file_input_path(
        sampler_config.hard_neg_edge_input_path)
    return HardNegativeSampler.instance(
        user_data_path=user_input_path,
        item_data_path=item_input_path,
        hard_neg_edge_data_path=hard_neg_edge_input_path,
        num_hard_sample=sampler_config.num_hard_sample,
        **common_kwargs)

  # sampler_type == 'hard_negative_sampler_v2' (validated above).
  pos_edge_input_path = process_multi_file_input_path(
      sampler_config.pos_edge_input_path)
  hard_neg_edge_input_path = process_multi_file_input_path(
      sampler_config.hard_neg_edge_input_path)
  return HardNegativeSamplerV2.instance(
      user_data_path=user_input_path,
      item_data_path=item_input_path,
      edge_data_path=pos_edge_input_path,
      hard_neg_edge_data_path=hard_neg_edge_input_path,
      num_hard_sample=sampler_config.num_hard_sample,
      **common_kwargs)