in sat/sgm/modules/autoencoding/magvit2_pytorch.py [0:0]
def encode(self, video: Tensor, quantize=False, cond: Optional[Tensor] = None, video_contains_first_frame=True):
    """Encode a video into the latent space, optionally quantizing.

    Args:
        video: input tensor; the pack/unpack pattern ``"b c * h w"`` and the
            padding applied at ``dim=2`` imply a layout of
            (batch, channels, time, height, width) — time is axis 2.
        quantize: when True, the encoding is passed through ``self.quantizers``
            before being returned.
        cond: optional conditioning tensor of shape ``(batch, self.dim_cond)``;
            required when the model was built with conditionable layers
            (``self.has_cond``).
        video_contains_first_frame: whether the leading frame of ``video`` is a
            true first frame; controls causal time padding and whether the
            first frame may be encoded by its own input convolution.

    Returns:
        The encoded latent tensor (quantized if ``quantize`` is True).
    """
    # Only take the separate-first-frame path when the model is configured for
    # it AND the input actually contains a first frame.
    encode_first_frame_separately = self.separate_first_frame_encoding and video_contains_first_frame

    # whether to pad video or not
    if video_contains_first_frame:
        video_len = video.shape[2]

        # Left-pad the time axis with zeros (presumably so causal temporal
        # convolutions see padding before t=0 — TODO confirm against conv_in).
        video = pad_at_dim(video, (self.time_padding, 0), value=0.0, dim=2)

        # Sizes used by `unpack` below to split the time axis back into
        # (padding | first frame | remaining frames). torch.Size([]) selects a
        # single un-kept frame per einops pack/unpack semantics.
        video_packed_shape = [torch.Size([self.time_padding]), torch.Size([]), torch.Size([video_len - 1])]

    # conditioning, if needed
    # NOTE(review): when self.has_cond is True this assert guarantees `cond`
    # exists, so `cond_kwargs` below is always defined before any layer with
    # has_cond uses it.
    assert (not self.has_cond) or exists(
        cond
    ), "`cond` must be passed into tokenizer forward method since conditionable layers were specified"

    if exists(cond):
        assert cond.shape == (video.shape[0], self.dim_cond)

        # Project the raw conditioning into the encoder's conditioning space.
        cond = self.encoder_cond_in(cond)
        cond_kwargs = dict(cond=cond)

    # initial conv
    # taking into account whether to encode first frame separately
    if encode_first_frame_separately:
        # Split off the time padding (discarded) and the first frame, which
        # gets its own dedicated input convolution.
        pad, first_frame, video = unpack(video, video_packed_shape, "b c * h w")
        first_frame = self.conv_in_first_frame(first_frame)

    video = self.conv_in(video)

    if encode_first_frame_separately:
        # Re-attach the separately-convolved first frame along time, then
        # restore the causal left padding dropped during unpack.
        video, _ = pack([first_frame, video], "b c * h w")
        video = pad_at_dim(video, (self.time_padding, 0), dim=2)

    # encoder layers
    for fn, has_cond in zip(self.encoder_layers, self.has_cond_across_layers):
        # Only conditionable layers receive the cond kwargs.
        layer_kwargs = dict()

        if has_cond:
            layer_kwargs = cond_kwargs

        video = fn(video, **layer_kwargs)

    # `identity` is a pass-through when quantization is not requested.
    maybe_quantize = identity if not quantize else self.quantizers

    return maybe_quantize(video)