def build()

in easy_rec/python/core/sampler.py [0:0]


def _resolve_attr_fields(data_config, sampler_config):
  """Return the input field configs named by ``sampler_config.attr_fields``.

  Fields are matched on ``input_name`` against ``data_config.input_fields`` and
  returned in ``attr_fields`` order.

  Raises:
    KeyError: listing every ``attr_fields`` entry with no matching input field.
  """
  input_fields = {f.input_name: f for f in data_config.input_fields}
  missing = [
      name for name in sampler_config.attr_fields if name not in input_fields
  ]
  if missing:
    # Same exception type as the plain dict lookup callers may already catch,
    # but the message names all offending fields instead of just the first.
    raise KeyError('sampler attr_fields not found in input_fields: %s' %
                   ', '.join(str(name) for name in missing))
  return [input_fields[name] for name in sampler_config.attr_fields]


def build(data_config):
  """Build the sampler selected by the ``sampler`` oneof of ``data_config``.

  Args:
    data_config: data config message. Its ``sampler`` oneof selects the sampler
      type; the selected sub-message provides the input paths, sample counts,
      ``attr_delimiter`` and ``attr_fields`` (names looked up among
      ``data_config.input_fields`` by ``input_name``).

  Returns:
    The result of ``<SamplerClass>.instance(...)`` for the selected sampler
    type, or None when the ``sampler`` oneof is not set.

  Raises:
    ValueError: if the oneof holds a sampler type not handled here.
    KeyError: if an ``attr_fields`` entry matches no ``input_fields`` entry.
  """
  import logging  # function-local so this block does not rely on file imports

  if not data_config.HasField('sampler'):
    return None
  sampler_type = data_config.WhichOneof('sampler')
  # Report via logging instead of print so library code does not write stdout.
  logging.info('sampler_type = %s', sampler_type)
  sampler_config = getattr(data_config, sampler_type)

  if ds_util.is_on_ds():
    gl.set_field_delimiter(sampler_config.field_delimiter)

  # Each branch only picks the sampler class and maps that class's path
  # kwargs to the config fields holding those paths; the remaining
  # constructor arguments are shared by all sampler types and built below.
  uses_hard_samples = False
  if sampler_type == 'negative_sampler':
    sampler_cls = NegativeSampler
    path_fields = [('data_path', 'input_path')]
  elif sampler_type == 'negative_sampler_in_memory':
    sampler_cls = NegativeSamplerInMemory
    path_fields = [('data_path', 'input_path')]
  elif sampler_type == 'negative_sampler_v2':
    sampler_cls = NegativeSamplerV2
    path_fields = [
        ('user_data_path', 'user_input_path'),
        ('item_data_path', 'item_input_path'),
        ('edge_data_path', 'pos_edge_input_path'),
    ]
  elif sampler_type == 'hard_negative_sampler':
    sampler_cls = HardNegativeSampler
    path_fields = [
        ('user_data_path', 'user_input_path'),
        ('item_data_path', 'item_input_path'),
        ('hard_neg_edge_data_path', 'hard_neg_edge_input_path'),
    ]
    uses_hard_samples = True
  elif sampler_type == 'hard_negative_sampler_v2':
    sampler_cls = HardNegativeSamplerV2
    path_fields = [
        ('user_data_path', 'user_input_path'),
        ('item_data_path', 'item_input_path'),
        ('edge_data_path', 'pos_edge_input_path'),
        ('hard_neg_edge_data_path', 'hard_neg_edge_input_path'),
    ]
    uses_hard_samples = True
  else:
    raise ValueError('Unknown sampler %s' % sampler_type)

  kwargs = dict(
      fields=_resolve_attr_fields(data_config, sampler_config),
      num_sample=sampler_config.num_sample,
      batch_size=data_config.batch_size,
      attr_delimiter=sampler_config.attr_delimiter,
      num_eval_sample=sampler_config.num_eval_sample)
  for kwarg_name, config_field in path_fields:
    kwargs[kwarg_name] = process_multi_file_input_path(
        getattr(sampler_config, config_field))
  # Only the hard-negative samplers take a separate hard-sample count.
  if uses_hard_samples:
    kwargs['num_hard_sample'] = sampler_config.num_hard_sample
  return sampler_cls.instance(**kwargs)