def get_transformers_pipeline()

in scripts/benchmark_pipelines.py


from argparse import Namespace

# `raw_pipeline` is assumed to be the stock transformers pipeline factory,
# imported under an alias to distinguish it from other pipelines in the script.
from transformers import pipeline as raw_pipeline


def get_transformers_pipeline(args: Namespace):
    # Only a limited set of precisions is accepted by the benchmark.
    if "dtype" in args:
        assert args.dtype in {"float16", "bfloat16", "float32"}

    # Build a plain transformers pipeline; device placement and memory caps
    # are forwarded to from_pretrained through model_kwargs.
    return raw_pipeline(
        model=args.model,
        torch_dtype=args.dtype,
        model_kwargs={
            "device_map": "balanced",
            "max_memory": {0: "20GiB", "cpu": "64GiB"},
        },
    )
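
A minimal sketch of how this helper might be driven from the command line. The flag names (--model, --dtype), their defaults, and the prompt are illustrative assumptions, not taken from the script.

from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument("--model", default="meta-llama/Llama-2-7b-hf")  # hypothetical default
parser.add_argument("--dtype", default="float16")
args = parser.parse_args()

# The pipeline task is inferred from the model's config; for a causal LM this
# yields a text-generation pipeline.
pipe = get_transformers_pipeline(args)
print(pipe("Benchmark prompt:", max_new_tokens=32))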