def get_dataset_from_file()

in src/datatuner/lm/data_loader.py [0:0]


def get_dataset_from_file(tokenizer, filename, task_config, max_data, max_block_size=None):
    """Read dataset from file"""

    def tokenize(obj):
        # recursively tokenize nested strings, dicts, and lists into token id structures of the same shape
        if isinstance(obj, str):
            return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(obj))
        if isinstance(obj, dict):
            return {n: tokenize(o) for n, o in obj.items()}
        return [tokenize(o) for o in obj]

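    # the dataset file is expected to contain a JSON list of example records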
    with open(filename, "r") as f:
        data = json.load(f)

    # get the max size supported by the tokenizer model
    # {'gpt2': 1024, 'gpt2-medium': 1024, 'gpt2-large': 1024, 'distilgpt2': 1024}
    max_tokenizer_size = min(tokenizer.max_model_input_sizes.values())
    if max_block_size is not None:
        max_tokenizer_size = min(max_block_size, max_tokenizer_size)

    if max_data > 0:
        data = data[:max_data]

    truncated_sequences = 0

    output_data = []
    logger.info(f"initial data: {len(data)}")

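    # collect the text fields and count the tokens contributed by special fields,
    # so the sequence length check below can budget for them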
    text_fields = [x for x in task_config["data_shape"] if x["type"] == "text"]
    len_special_fields = 0
    for x in task_config["data_shape"]:
        if x["type"] == "special":
            len_special_fields += len(tokenizer.tokenize(x["id"]))
        elif x["type"] == "special_id":
            len_special_fields += len(x["id"])

    failed_conversions = 0
    for inst_i, inst in enumerate(tqdm(data)):

        # check the inclusion criteria
        if "include" in task_config:
            include = True
            for field, value in task_config["include"].items():
                if field in inst and inst[field] != value:
                    include = False
                    break
            if not include:
                continue

        item = {}

        total_seq_len = 0
        stop = False
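        # tokenize every text field, applying its converter first if one is configured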
        for field in text_fields:
            field_v = inst[field["id"]]

            if "converter" in field:
                try:
                    func = converters[field["converter"]]
                except KeyError:
                    logger.error(f"Unable to get the converter {field['converter']}")
                    raise

                field_v = func(field_v)
                if field_v is None:
                    stop = True
                    break

            item[field["id"]] = tokenize(field_v)

            total_seq_len += len(item[field["id"]])

        if stop:
            failed_conversions += 1
            continue

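        # pass any configured extra fields through untokenized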
        if "extra_fields" in task_config:
            for field in task_config["extra_fields"]:
                item[field] = inst[field]

        # the +1 accounts for the eos token
        if total_seq_len + len_special_fields + 1 > max_tokenizer_size:
            # truncate each text field to leave a 100-token margin below the model / block size limit
            for field in text_fields:
                item[field["id"]] = item[field["id"]][: max_tokenizer_size - 100]
            logger.warning(f"input {inst_i} is longer than the maximum sequence length; truncating")
            truncated_sequences += 1
        output_data.append(item)

    logger.info(
        f"{truncated_sequences} / {len(data)} sequences truncated due to the positional embedding or max block size restriction"
    )
    logger.info(f"{failed_conversions} / {len(data)} sequences removed due to failed conversions")
    logger.info(f"preprocessed data: {len(output_data)}")
    return output_data
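

Usage (illustrative): a minimal sketch of how this function might be called, assuming the module-level names it relies on (json, tqdm, logger, and the converters registry) are defined in data_loader.py. The field ids, special markers, and file path below are hypothetical and not taken from the repository; the dataset file is assumed to be a JSON list of records containing the text field ids named in task_config["data_shape"].

# A minimal sketch, not taken from the repository: the field ids, file path,
# and special markers below are hypothetical.
from transformers import GPT2Tokenizer

from datatuner.lm.data_loader import get_dataset_from_file

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

task_config = {
    "data_shape": [
        {"id": "<data>", "type": "special"},   # marker tokens separating the fields
        {"id": "source", "type": "text"},      # hypothetical text field in each record
        {"id": "<text>", "type": "special"},
        {"id": "target", "type": "text"},      # hypothetical text field in each record
    ]
}

# train.json is expected to hold a JSON list of records such as
# {"source": "...", "target": "..."}
dataset = get_dataset_from_file(
    tokenizer,
    "train.json",        # hypothetical path
    task_config,
    max_data=1000,       # keep only the first 1000 records
    max_block_size=512,  # cap sequences at 512 tokens
)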