def combine()

in point_e/diffusion/sampler.py [0:0]


    def combine(cls, *samplers: "PointCloudSampler") -> "PointCloudSampler":
        """
        Build a single sampler from several samplers.

        Per-stage settings (models, diffusions, num_points,
        model_kwargs_key_filter, guidance_scale, use_karras, karras_steps,
        sigma_min, sigma_max, s_churn) are concatenated in the order the
        samplers are given. Settings shared by the whole sampler (device,
        aux_channels, clip_denoised) must be equal across all inputs and are
        taken from the first one.

        :param samplers: one or more samplers to merge.
        :return: a new ``cls`` instance constructed from the merged settings.
        :raises ValueError: if no samplers are given, or if the shared settings
            differ between samplers.
        """
        if not samplers:
            # Without this guard the failure would be an IndexError on samplers[0].
            raise ValueError("combine() requires at least one sampler")
        first = samplers[0]

        # Explicit checks instead of `assert`: asserts are removed under
        # `python -O`, which would let incompatible samplers merge silently.
        for attr in ("device", "aux_channels", "clip_denoised"):
            expected = getattr(first, attr)
            for other in samplers[1:]:
                actual = getattr(other, attr)
                if actual != expected:
                    raise ValueError(
                        f"cannot combine samplers with different {attr}: "
                        f"{expected!r} != {actual!r}"
                    )

        def concat(attr: str) -> list:
            # Flatten one per-stage list attribute across all samplers, preserving order.
            return [item for sampler in samplers for item in getattr(sampler, attr)]

        return cls(
            device=first.device,
            models=concat("models"),
            diffusions=concat("diffusions"),
            num_points=concat("num_points"),
            aux_channels=first.aux_channels,
            model_kwargs_key_filter=concat("model_kwargs_key_filter"),
            guidance_scale=concat("guidance_scale"),
            clip_denoised=first.clip_denoised,
            use_karras=concat("use_karras"),
            karras_steps=concat("karras_steps"),
            sigma_min=concat("sigma_min"),
            sigma_max=concat("sigma_max"),
            s_churn=concat("s_churn"),
        )