in src/datatuner/lm/data_loader.py [0:0]
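# NOTE: this function relies on module-level names available in data_loader.py:
# `json`, `tqdm`, `logger`, and a `converters` mapping of converter names to callables.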
def get_dataset_from_file(tokenizer, filename, task_config, max_data, max_block_size=None):
    """Read a JSON dataset from `filename` and tokenize it according to `task_config`.

    If `max_data` > 0, only the first `max_data` instances are kept; `max_block_size`,
    if given, further caps the maximum sequence length allowed by the tokenizer model.
    """

    def tokenize(obj):
        # Recursively convert strings, dict values, and list items into token ids.
        if isinstance(obj, str):
            return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(obj))
        if isinstance(obj, dict):
            return dict((n, tokenize(o)) for n, o in obj.items())
        return list(tokenize(o) for o in obj)
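
    # Load the raw instances; the file is expected to hold a JSON array of dicts.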
    with open(filename, "r") as f:
        data = json.load(f)

    # get the max size supported by the tokenizer model
    # {'gpt2': 1024, 'gpt2-medium': 1024, 'gpt2-large': 1024, 'distilgpt2': 1024}
    max_tokenizer_size = min(tokenizer.max_model_input_sizes.values())
    if max_block_size is not None:
        max_tokenizer_size = min(max_block_size, max_tokenizer_size)
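    # max_tokenizer_size is now the effective sequence-length budget used for truncation below.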
    if max_data > 0:
        data = data[:max_data]

    ignored_sequences = 0
    output_data = []
    logger.info(f"initial data: {len(data)}")

    text_fields = [x for x in task_config["data_shape"] if x["type"] == "text"]

    len_special_fields = 0
    for x in task_config["data_shape"]:
        if x["type"] == "special":
            len_special_fields += len(tokenizer.tokenize(x["id"]))
        elif x["type"] == "special_id":
            len_special_fields += len(x["id"])
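
    # Main preprocessing loop: filter, convert, tokenize, and length-check each instance.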
    failed_conversions = 0
    for inst_i, inst in enumerate(tqdm(data)):
        # check the inclusion criteria
        if "include" in task_config:
            include = True
            for field, value in task_config["include"].items():
                if field in inst and inst[field] != value:
                    include = False
                    break
            if not include:
                continue
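
        # Tokenize every text field of the instance, applying an optional converter first.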
        item = {}
        total_seq_len = 0
        stop = False
        for field in text_fields:
            field_v = inst[field["id"]]
            if "converter" in field:
                try:
                    func = converters[field["converter"]]
                except KeyError:
                    logger.error(f"Unable to get the converter {field['converter']}")
                    raise
                field_v = func(field_v)
                if field_v is None:
                    stop = True
                    break
            item[field["id"]] = tokenize(field_v)
            total_seq_len += len(item[field["id"]])
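
        # If a converter returned None for some field, count it as a failed conversion and skip the instance.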
        if stop:
            failed_conversions += 1
            continue

        if "extra_fields" in task_config:
            for field in task_config["extra_fields"]:
                item[field] = inst[field]
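
        # Length check against the effective model limit; over-long instances are truncated rather than dropped.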
        # 1 is for eos token
        if total_seq_len + len_special_fields + 1 > max_tokenizer_size:
            # truncate each text field to max_tokenizer_size - 100 tokens
            for field in text_fields:
                item[field["id"]] = item[field["id"]][: max_tokenizer_size - 100]
            logger.warning(f"input {inst_i} is longer than the maximum sequence length; truncating")
            ignored_sequences += 1
            # continue

        output_data.append(item)
    logger.info(
        "%d / %d sequences truncated due to positional embedding restriction or max block size restriction"
        % (ignored_sequences, len(data))
    )
    logger.info("%d / %d removed due to failed conversions" % (failed_conversions, len(data)))
    logger.info(f"preprocessed data: {len(output_data)}")
    return output_data
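

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). It assumes a
# GPT-2 tokenizer from Hugging Face transformers; the data path and the field
# names in `example_task_config` below are placeholders, not the real DataTuner
# task configuration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import GPT2Tokenizer

    gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    example_task_config = {
        "data_shape": [
            {"id": "<data>", "type": "special"},      # placeholder special marker
            {"id": "source_field", "type": "text"},   # placeholder input field name
            {"id": "<text>", "type": "special"},      # placeholder special marker
            {"id": "target_field", "type": "text"},   # placeholder output field name
        ]
    }
    dataset = get_dataset_from_file(
        gpt2_tokenizer,
        "data/train.json",   # hypothetical path to a JSON array of instances
        example_task_config,
        max_data=0,          # 0 keeps all instances
        max_block_size=512,
    )
    print(f"loaded {len(dataset)} preprocessed instances")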