def create_module()

in pytorchvideo/models/audio_visual_slowfast.py


    def create_module(self, fusion_dim_in: int, stage_idx: int) -> nn.Module:
        """
        Creates the audio-visual fusion module for the given stage.

        Args:
            fusion_dim_in (int): input channel dimension of this stage.
            stage_idx (int): index of the stage to build the module for; stages
                beyond max_stage_idx get an nn.Identity module instead.
        """
        if stage_idx > self.max_stage_idx:
            return nn.Identity()

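        # Select this stage's stride: conv_stride / conv_stride_a may be a single
        # stride shared by every stage or a per-stage sequence of strides.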
        conv_stride = (
            self.conv_stride[stage_idx]
            if isinstance(self.conv_stride[0], tuple)
            else self.conv_stride
        )
        conv_stride_a = (
            self.conv_stride_a[stage_idx]
            if isinstance(self.conv_stride_a[0], tuple)
            else self.conv_stride_a
        )

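        # Input channel widths of the Fast and audio pathways, derived from
        # fusion_dim_in via the respective channel reduction ratios.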
        conv_dim_in = fusion_dim_in // self.slowfast_channel_reduction_ratio
        conv_dim_in_a = fusion_dim_in // self.slowfast_audio_reduction_ratio
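        # Fast -> Slow lateral fusion branch: a strided 3D convolution, optionally
        # followed by normalization and activation.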
        fastslow_module = []
        fastslow_module.append(
            nn.Conv3d(
                conv_dim_in,
                int(conv_dim_in * self.conv_fusion_channel_ratio),
                kernel_size=self.conv_kernel_size,
                stride=conv_stride,
                padding=[k_size // 2 for k_size in self.conv_kernel_size],
                bias=False,
            )
        )
        if self.norm is not None:
            fastslow_module.append(
                self.norm(
                    num_features=int(conv_dim_in * self.conv_fusion_channel_ratio),
                    eps=self.norm_eps,
                    momentum=self.norm_momentum,
                )
            )
        if self.activation is not None:
            fastslow_module.append(self.activation())

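        # The intermediate width of the audio branch can be given either as an
        # absolute channel count (int) or as a fraction of the audio input width.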
        if isinstance(self.conv_fusion_channel_interm_dim, int):
            afs_fusion_interm_dim = self.conv_fusion_channel_interm_dim
        else:
            afs_fusion_interm_dim = int(
                conv_dim_in_a * self.conv_fusion_channel_interm_dim
            )

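        # Audio -> Fast/Slow fusion branch: a stack of conv_num_a 3D convolutions.
        # Only the final convolution is strided and projects to the fused
        # Fast-to-Slow width plus fusion_dim_in; earlier convolutions keep
        # stride (1, 1, 1) and the intermediate width.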
        block_audio_to_fastslow = []
        cur_dim_in = conv_dim_in_a
        for idx in range(self.conv_num_a):
            if idx == self.conv_num_a - 1:
                cur_stride = conv_stride_a
                cur_dim_out = int(
                    conv_dim_in * self.conv_fusion_channel_ratio + fusion_dim_in
                )
            else:
                cur_stride = (1, 1, 1)
                cur_dim_out = afs_fusion_interm_dim

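            # Convolution for this step, followed by optional norm and activation.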
            block_audio_to_fastslow.append(
                nn.Conv3d(
                    cur_dim_in,
                    cur_dim_out,
                    kernel_size=self.conv_kernel_size_a,
                    stride=cur_stride,
                    padding=[k_size // 2 for k_size in self.conv_kernel_size_a],
                    bias=False,
                )
            )
            if self.norm is not None:
                block_audio_to_fastslow.append(
                    self.norm(
                        num_features=cur_dim_out,
                        eps=self.norm_eps,
                        momentum=self.norm_momentum,
                    )
                )
            if self.activation is not None:
                block_audio_to_fastslow.append(self.activation())
            cur_dim_in = cur_dim_out

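        # Package both branches into a single fusion module.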
        return FuseAudioToFastSlow(
            block_fast_to_slow=nn.Sequential(*fastslow_module),
            block_audio_to_fastslow=nn.Sequential(*block_audio_to_fastslow),
        )
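
To make the channel bookkeeping above concrete, the following standalone sketch reproduces the dimension arithmetic of create_module for one assumed configuration. The numeric values (fusion_dim_in, the reduction ratios, and the fusion ratios) are illustrative assumptions, not the library's defaults, and no pytorchvideo import is needed to run it.

    # Standalone sketch of the channel arithmetic in create_module.
    # All numeric values below are illustrative assumptions, not library defaults.
    fusion_dim_in = 64                     # stage input width entering the fusion
    slowfast_channel_reduction_ratio = 8   # Fast pathway reduction vs. fusion_dim_in
    slowfast_audio_reduction_ratio = 16    # audio pathway reduction vs. fusion_dim_in
    conv_fusion_channel_ratio = 2.0        # widening applied by the Fast -> Slow conv
    conv_fusion_channel_interm_dim = 0.25  # audio intermediate width (a ratio here)

    conv_dim_in = fusion_dim_in // slowfast_channel_reduction_ratio    # 8
    conv_dim_in_a = fusion_dim_in // slowfast_audio_reduction_ratio    # 4

    # Output width of the Fast -> Slow branch.
    fast_to_slow_out = int(conv_dim_in * conv_fusion_channel_ratio)    # 16

    # Output width of the last audio conv: fused Fast width plus fusion_dim_in.
    audio_out = int(conv_dim_in * conv_fusion_channel_ratio + fusion_dim_in)  # 80

    # Intermediate audio width: absolute if an int was given, else a fraction
    # of the audio input width.
    afs_fusion_interm_dim = (
        conv_fusion_channel_interm_dim
        if isinstance(conv_fusion_channel_interm_dim, int)
        else int(conv_dim_in_a * conv_fusion_channel_interm_dim)
    )                                                                   # 1

    print(conv_dim_in, conv_dim_in_a, fast_to_slow_out, audio_out, afs_fusion_interm_dim)

In create_module these widths become, respectively, the out_channels of the Fast -> Slow convolution, the out_channels of the final audio convolution, and the width of the intermediate audio convolutions.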