def __init__()

in src/sagemaker_huggingface_inference_toolkit/diffusers_utils.py [0:0]


    def __init__(self, model_dir: str, device: str = None):  # needs "cuda" for GPU
        """Load a diffusion pipeline from ``model_dir`` and store it on ``self.pipeline``.

        Args:
            model_dir: Passed as the first argument to
                ``DiffusionPipeline.from_pretrained``.
            device: Target device. Only the exact string ``"cuda"`` switches to
                a half-precision dtype and enables the multi-GPU path; any other
                value (including the default ``None``) loads in ``float32`` and
                is forwarded unchanged to ``.to(device)``.
        """
        self.pipeline = None

        # Full precision unless the caller explicitly asked for "cuda".
        dtype = torch.float32
        if device == "cuda":
            # Prefer bfloat16 when the helper reports support, otherwise float16.
            dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float16
            if torch.cuda.device_count() > 1:
                # More than one CUDA device: placement is delegated to
                # from_pretrained via device_map="balanced" rather than .to(device).
                self.pipeline = DiffusionPipeline.from_pretrained(
                    model_dir, torch_dtype=dtype, device_map="balanced"
                )

        # Compare against the None sentinel explicitly so the fallback load never
        # depends on the truthiness of an already-loaded pipeline object.
        if self.pipeline is None:
            self.pipeline = DiffusionPipeline.from_pretrained(model_dir, torch_dtype=dtype).to(device)