in tools/create_examples_from_transformers.py [0:0]
def _swap_import(file_path, import_pattern, aws_key=None):
    """Replace a transformers import in *file_path* with its AWS counterpart.

    Removes the import matched by *import_pattern* from the file, then inserts
    the replacement snippet from ``AWS_CODE`` right after where the removed
    import block ended, and writes the file back in place.

    Args:
        file_path: Path of the Python source file to patch.
        import_pattern: Regex pattern handed to ``remove_import``.
        aws_key: Explicit ``AWS_CODE`` key to use. When ``None``, the class
            name captured by ``remove_import`` is used as the key.
    """
    with open(file_path, "r") as fp:
        file_content = fp.read()
    imported_cls, processed_content, import_end_index = remove_import(import_pattern, file_content)
    code = generate_new_import_code(AWS_CODE[aws_key if aws_key is not None else imported_cls])
    processed_content = insert_code_at_position(f"\n{code}\n", processed_content, import_end_index)
    with open(file_path, "w") as fp:
        fp.write(processed_content)


def main():
    """Download the supported transformers examples and patch them to use AWS classes.

    For every example script, the Trainer / HfArgumentParser / TrainingArguments
    imports are swapped for their AWS equivalents, torch is stripped from the
    requirements files, and the result is run through ruff.
    """
    args = parse_args()
    examples = args.examples
    if examples == "all":
        examples = SUPPORTED_EXAMPLES
    download_examples_from_transformers(
        examples,
        args.dest,
        predicate=keep_only_examples_with_trainer_and_requirements_predicate,
        version=args.version,
    )
    for example_dir in args.dest.iterdir():
        if example_dir.is_file():
            continue
        for file_path in example_dir.iterdir():
            # run_generation.py does not use a Trainer; nothing to patch.
            if file_path.name == "run_generation.py":
                continue
            if "run" in file_path.name and file_path.suffix == ".py":
                # The QA examples define their Trainer subclass in a sibling
                # file, so the Trainer import lives there, not in the runner.
                if file_path.name == "run_qa.py":
                    trainer_file_path = file_path.parent / "trainer_qa.py"
                elif file_path.name == "run_seq2seq_qa.py":
                    trainer_file_path = file_path.parent / "trainer_seq2seq_qa.py"
                else:
                    trainer_file_path = file_path
                print(f"Processing {file_path}")
                # Trainer and TrainingArguments keys come from the matched
                # import; HfArgumentParser always maps to the same key.
                _swap_import(trainer_file_path, TRAINER_IMPORT_PATTERN)
                _swap_import(file_path, HF_ARGUMENT_PARSER_IMPORT_PATTERN, aws_key="HfArgumentParser")
                _swap_import(file_path, TRAINING_ARGUMENTS_IMPORT_PATTERN)
            elif file_path.name == "requirements.txt":
                with open(file_path, "r") as fp:
                    file_content = fp.read()
                # torch is provided by the runtime image, so it is removed
                # from the example requirements.
                processed_content = re.sub(TORCH_REQUIREMENT_PATTERN, "", file_content)
                if file_path.parent.name == "image-classification":
                    processed_content += "\nscikit-learn"
                with open(file_path, "w") as fp:
                    fp.write(processed_content)
    # Linting and styling.
    subprocess.run(["ruff", f"{args.dest}", "--fix"])