in optimum/exporters/neuron/__main__.py [0:0]
def main():
    """Run the Neuron exporter from the command line.

    Parses the CLI arguments, resolves the task (inferred from the model when
    the ``task`` argument is ``"auto"``) and the library the model belongs to,
    then selects the input shapes according to that library:

    * ``diffusers``: stable-diffusion input shapes, plus a ``unet`` entry in
      the submodels mapping.
    * ``sentence_transformers``: sentence-transformers input shapes.
    * anything else: the export through the dedicated neuron model classes is
      attempted first; if that call returns a truthy value the function
      returns immediately. Otherwise the legacy export is used with input
      shapes derived from the task.

    All collected settings are finally forwarded to ``main_export``.
    """
    cli = ArgumentParser(f"Hugging Face Optimum {NEURON_COMPILER} exporter")
    parse_args_neuron(cli)
    args = cli.parse_args()

    def optional_arg(name):
        # Tolerate attributes missing from the namespace by defaulting to None.
        return getattr(args, name, None)

    # Only infer the task when the user did not pick one explicitly.
    if args.task == "auto":
        task = infer_task(args.model)
    else:
        task = args.task
    library_name = TasksManager.infer_library_from_model(args.model, cache_dir=args.cache_dir)

    # Only the diffusers branch provides submodels.
    submodels = None
    if library_name == "diffusers":
        input_shapes = normalize_stable_diffusion_input_shapes(args)
        submodels = {"unet": args.unet}
    elif library_name == "sentence_transformers":
        input_shapes = normalize_sentence_transformers_input_shapes(args)
    else:
        # New export mode using dedicated neuron model classes: it receives a
        # copy of every parsed argument as keyword arguments.
        if maybe_export_from_neuron_model_class(**dict(vars(args))):
            return
        # Fallback to legacy export
        input_shapes = get_input_shapes(task, args)

    neuron_cache_disabled = args.disable_neuron_cache
    compiler_kwargs = infer_compiler_kwargs(args)
    optional_outputs = customize_optional_outputs(args)
    optlevel = parse_optlevel(args)

    lora_args = LoRAAdapterArguments(
        model_ids=optional_arg("lora_model_ids"),
        weight_names=optional_arg("lora_weight_names"),
        adapter_names=optional_arg("lora_adapter_names"),
        scales=optional_arg("lora_scales"),
    )
    ip_adapter_args = IPAdapterArguments(
        model_id=optional_arg("ip_adapter_id"),
        subfolder=optional_arg("ip_adapter_subfolder"),
        weight_name=optional_arg("ip_adapter_weight_name"),
        scale=optional_arg("ip_adapter_scale"),
    )

    export_settings = {
        "model_name_or_path": args.model,
        "output": args.output,
        "compiler_kwargs": compiler_kwargs,
        "torch_dtype": args.torch_dtype,
        "tensor_parallel_size": args.tensor_parallel_size,
        "task": task,
        "dynamic_batch_size": args.dynamic_batch_size,
        "atol": args.atol,
        "cache_dir": args.cache_dir,
        "disable_neuron_cache": neuron_cache_disabled,
        "compiler_workdir": args.compiler_workdir,
        "inline_weights_to_neff": args.inline_weights_neff,
        "optlevel": optlevel,
        "trust_remote_code": args.trust_remote_code,
        "subfolder": args.subfolder,
        # The CLI exposes a "disable" switch; the export API takes the opposite flag.
        "do_validation": not args.disable_validation,
        "submodels": submodels,
        "library_name": library_name,
        "controlnet_ids": optional_arg("controlnet_ids"),
        "lora_args": lora_args,
        "ip_adapter_args": ip_adapter_args,
    }
    # The three mappings are unpacked in the call itself (not merged first) so
    # that a keyword supplied twice still raises instead of being overridden.
    main_export(**export_settings, **optional_outputs, **input_shapes)