in scripts/benchmark_pipelines.py
from argparse import Namespace

# Assumed import: the excerpt does not show where raw_pipeline comes from.
from transformers import pipeline as raw_pipeline


def get_transformers_pipeline(args: Namespace):
    # Only validate the dtype when the caller supplied one.
    if "dtype" in args:
        assert args.dtype in {"float16", "bfloat16", "float32"}
    return raw_pipeline(
        model=args.model,
        # Fall back to the library default if the args carry no dtype.
        torch_dtype=getattr(args, "dtype", None),
        model_kwargs={
            # Cap GPU 0 at 20 GiB and spill the rest to up to 64 GiB of CPU RAM.
            "device_map": "balanced",
            "max_memory": {0: "20GiB", "cpu": "64GiB"},
        },
    )
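
For context, a minimal sketch of how this helper might be driven from the benchmark script's CLI. The flag names and the "gpt2" placeholder model are assumptions for illustration, not taken from the file.

from argparse import ArgumentParser

# Hypothetical driver: flag names mirror the attributes the helper reads
# (args.model, args.dtype); "gpt2" is only a placeholder model id.
parser = ArgumentParser()
parser.add_argument("--model", default="gpt2")
parser.add_argument("--dtype", choices=["float16", "bfloat16", "float32"],
                    default="float32")
args = parser.parse_args()

pipe = get_transformers_pipeline(args)
print(pipe("Benchmarking prompt:", max_new_tokens=16))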