in pytorchvideo/models/audio_visual_slowfast.py [0:0]
def create_module(self, fusion_dim_in: int, stage_idx: int) -> nn.Module:
    """
    Creates the module for the given stage
    Args:
        fusion_dim_in (int): input stage dimension
        stage_idx (int): which stage this is

    Returns:
        nn.Module: ``nn.Identity()`` when ``stage_idx`` is greater than
        ``self.max_stage_idx``; otherwise a ``FuseAudioToFastSlow`` built
        from two ``nn.Sequential`` stacks (``block_fast_to_slow`` and
        ``block_audio_to_fastslow``), each made of ``nn.Conv3d`` layers
        optionally followed by ``self.norm`` and ``self.activation``.
    """
    # Stages past the configured maximum get no fusion at all.
    if stage_idx > self.max_stage_idx:
        return nn.Identity()

    # A stride given as a sequence of tuples is per-stage; otherwise the same
    # stride is shared by every stage.
    conv_stride = (
        self.conv_stride[stage_idx]
        if isinstance(self.conv_stride[0], tuple)
        else self.conv_stride
    )
    conv_stride_a = (
        self.conv_stride_a[stage_idx]
        if isinstance(self.conv_stride_a[0], tuple)
        else self.conv_stride_a
    )

    conv_dim_in = fusion_dim_in // self.slowfast_channel_reduction_ratio
    conv_dim_in_a = fusion_dim_in // self.slowfast_audio_reduction_ratio

    # Channel count produced by the Fast->Slow conv. Computed once as an int so
    # the conv and the norm that follows it always agree, even when
    # conv_fusion_channel_ratio is a float.
    fastslow_dim_out = int(conv_dim_in * self.conv_fusion_channel_ratio)

    # ---- Fast -> Slow branch: conv [-> norm] [-> activation] ----
    fastslow_module = [
        nn.Conv3d(
            conv_dim_in,
            fastslow_dim_out,
            kernel_size=self.conv_kernel_size,
            stride=conv_stride,
            padding=[k_size // 2 for k_size in self.conv_kernel_size],
            bias=False,
        )
    ]
    if self.norm is not None:
        fastslow_module.append(
            self.norm(
                num_features=fastslow_dim_out,
                eps=self.norm_eps,
                momentum=self.norm_momentum,
            )
        )
    if self.activation is not None:
        fastslow_module.append(self.activation())

    # Intermediate width of the audio branch: an int is used as an absolute
    # channel count, anything else as a multiplier on the audio input width.
    if isinstance(self.conv_fusion_channel_interm_dim, int):
        afs_fusion_interm_dim = self.conv_fusion_channel_interm_dim
    else:
        afs_fusion_interm_dim = int(
            conv_dim_in_a * self.conv_fusion_channel_interm_dim
        )

    # ---- Audio -> FastSlow branch: conv_num_a x (conv [-> norm] [-> act]) ----
    block_audio_to_fastslow = []
    cur_dim_in = conv_dim_in_a
    last_idx = self.conv_num_a - 1
    for idx in range(self.conv_num_a):
        if idx == last_idx:
            # Only the final conv applies the audio stride; its output width
            # is the fused Fast->Slow width plus the stage input width.
            cur_stride = conv_stride_a
            cur_dim_out = int(
                conv_dim_in * self.conv_fusion_channel_ratio + fusion_dim_in
            )
        else:
            cur_stride = (1, 1, 1)
            cur_dim_out = afs_fusion_interm_dim
        block_audio_to_fastslow.append(
            nn.Conv3d(
                cur_dim_in,
                cur_dim_out,
                kernel_size=self.conv_kernel_size_a,
                stride=cur_stride,
                padding=[k_size // 2 for k_size in self.conv_kernel_size_a],
                bias=False,
            )
        )
        if self.norm is not None:
            block_audio_to_fastslow.append(
                self.norm(
                    num_features=cur_dim_out,
                    eps=self.norm_eps,
                    momentum=self.norm_momentum,
                )
            )
        if self.activation is not None:
            block_audio_to_fastslow.append(self.activation())
        # The next conv in the chain consumes this conv's output.
        cur_dim_in = cur_dim_out

    return FuseAudioToFastSlow(
        block_fast_to_slow=nn.Sequential(*fastslow_module),
        block_audio_to_fastslow=nn.Sequential(*block_audio_to_fastslow),
    )