sat/sgm/models/autoencoder.py [501:544]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
class VideoAutoencodingEngine(AutoencodingEngine):
    def __init__(
        self,
        ckpt_path: Union[None, str] = None,
        ignore_keys: Union[Tuple, list] = (),
        image_video_weights=(1, 1),  # immutable default avoids the shared-mutable-list pitfall
        only_train_decoder=False,
        context_parallel_size=0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.context_parallel_size = context_parallel_size
        # Restore pretrained weights (if given) only after the parent class
        # has built the encoder/decoder modules they map onto.
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

    def log_videos(self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs) -> dict:
        return self.log_images(batch, additional_log_kwargs, **kwargs)

    def get_input(self, batch: dict) -> torch.Tensor:
        if self.context_parallel_size > 0:
            # Lazily create the context-parallel process group on first use.
            if not is_context_parallel_initialized():
                initialize_context_parallel(self.context_parallel_size)

            batch = batch[self.input_key]

            # The first global rank of each context-parallel group holds the
            # data loader's copy; broadcast it to the rest of the group.
            global_src_rank = get_context_parallel_group_rank() * self.context_parallel_size
            torch.distributed.broadcast(batch, src=global_src_rank, group=get_context_parallel_group())

            # Split the temporal axis (dim=2 of an (N, C, T, H, W) batch) so
            # each rank keeps only its own chunk of frames.
            batch = _conv_split(batch, dim=2, kernel_size=1)
            return batch

        return batch[self.input_key]

    def apply_ckpt(self, ckpt: Union[None, str, dict]):
        if ckpt is None:
            return
        # Despite the broader type hint, only string paths are handled here;
        # init_from_ckpt passes its argument straight to torch.load.
        self.init_from_ckpt(ckpt)

    def init_from_ckpt(self, path, ignore_keys=list()):
        sd = torch.load(path, map_location="cpu")["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
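
In the context-parallel branch of get_input above, every rank first receives the full clip via the broadcast, and _conv_split then keeps only that rank's slice of the temporal axis. As a minimal sketch of that chunking, assuming an (N, C, T, H, W) layout and ignoring the kernel-overlap bookkeeping the real repo helper performs for convolutions (the function below is illustrative, not the repo's implementation):

import torch

def naive_temporal_split(x: torch.Tensor, rank: int, world_size: int, dim: int = 2) -> torch.Tensor:
    # One contiguous slice of the T axis per rank; the last rank absorbs
    # any remainder frames. This mirrors the intent of
    # _conv_split(x, dim=2, kernel_size=1) without the overlap handling.
    chunk = x.size(dim) // world_size
    start = rank * chunk
    length = x.size(dim) - start if rank == world_size - 1 else chunk
    return x.narrow(dim, start, length).contiguous()

Because the broadcast leaves every rank with an identical full tensor, slicing locally by rank index produces consistent, non-overlapping chunks across the group without further communication.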



sat/vae_modules/autoencoder.py [526:569]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
class VideoAutoencodingEngine(AutoencodingEngine):
    def __init__(
        self,
        ckpt_path: Union[None, str] = None,
        ignore_keys: Union[Tuple, list] = (),
        image_video_weights=(1, 1),  # immutable default avoids the shared-mutable-list pitfall
        only_train_decoder=False,
        context_parallel_size=0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.context_parallel_size = context_parallel_size
        # Restore pretrained weights (if given) only after the parent class
        # has built the encoder/decoder modules they map onto.
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

    def log_videos(self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs) -> dict:
        return self.log_images(batch, additional_log_kwargs, **kwargs)

    def get_input(self, batch: dict) -> torch.Tensor:
        if self.context_parallel_size > 0:
            # Lazily create the context-parallel process group on first use.
            if not is_context_parallel_initialized():
                initialize_context_parallel(self.context_parallel_size)

            batch = batch[self.input_key]

            # The first global rank of each context-parallel group holds the
            # data loader's copy; broadcast it to the rest of the group.
            global_src_rank = get_context_parallel_group_rank() * self.context_parallel_size
            torch.distributed.broadcast(batch, src=global_src_rank, group=get_context_parallel_group())

            # Split the temporal axis (dim=2 of an (N, C, T, H, W) batch) so
            # each rank keeps only its own chunk of frames.
            batch = _conv_split(batch, dim=2, kernel_size=1)
            return batch

        return batch[self.input_key]

    def apply_ckpt(self, ckpt: Union[None, str, dict]):
        if ckpt is None:
            return
        # Despite the broader type hint, only string paths are handled here;
        # init_from_ckpt passes its argument straight to torch.load.
        self.init_from_ckpt(ckpt)

    def init_from_ckpt(self, path, ignore_keys=list()):
        sd = torch.load(path, map_location="cpu")["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
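
The init_from_ckpt path shown in both copies boils down to: load the checkpoint on CPU, filter keys by prefix, then call load_state_dict with strict=False. A hedged, standalone sketch of that filtering step, where the checkpoint path and the "loss." prefix are placeholders rather than values taken from the repo:

import torch

def filter_state_dict(path: str, ignore_keys=()) -> dict:
    # Same filtering init_from_ckpt performs: load on CPU, then drop any
    # entry whose key starts with one of the ignored prefixes.
    sd = torch.load(path, map_location="cpu")["state_dict"]
    return {k: v for k, v in sd.items() if not any(k.startswith(ik) for ik in ignore_keys)}

# Example (hypothetical checkpoint and prefix):
# sd = filter_state_dict("checkpoints/vae.pt", ignore_keys=("loss.",))
# engine.load_state_dict(sd, strict=False)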



