def transmute_Conv3d3x3x3DwBnAct()

in pytorchvideo/accelerator/deployment/mobile_cpu/transmuter/transmuter_mobile_cpu.py [0:0]


def transmute_Conv3d3x3x3DwBnAct(input_module: nn.Module):
    """
    Try to re-express ``input_module`` as an equivalent Conv3d3x3x3DwBnAct.

    A module qualifies only if it is an ``nn.Conv3d`` that is:
      * depthwise: in_channels == out_channels == groups,
      * a 3x3x3 kernel with padding (1, 1, 1), dilation (1, 1, 1) and
        padding_mode "zeros",
      * stride 1 along the first (temporal) axis, with equal strides on the
        two remaining axes.

    Args:
        input_module (nn.Module): candidate module to transmute.

    Returns:
        A Conv3d3x3x3DwBnAct configured with identity activation and no BN. Its
        inner conv holds the state dict (weight, plus bias if present) of
        ``input_module``. Returns None when ``input_module`` does not qualify.
    """
    if not isinstance(input_module, nn.Conv3d):
        return None

    channels = input_module.out_channels
    stride = input_module.stride

    # Every output channel must be produced by exactly one input channel.
    depthwise = (
        input_module.in_channels == channels and input_module.groups == channels
    )
    # Kernel/padding/dilation must match the fixed 3x3x3 "same"-size geometry.
    geometry_ok = (
        input_module.kernel_size == (3, 3, 3)
        and input_module.padding == (1, 1, 1)
        and input_module.padding_mode == "zeros"
        and input_module.dilation == (1, 1, 1)
    )
    # Only spatial striding is expressible, and it must be the same on H and W.
    stride_ok = stride[0] == 1 and stride[1] == stride[2]

    if not (depthwise and geometry_ok and stride_ok):
        return None

    transmuted = Conv3d3x3x3DwBnAct(
        in_channels=input_module.in_channels,
        spatial_stride=stride[1],
        bias=input_module.bias is not None,
        activation="identity",
        use_bn=False,
    )
    # Carry the original conv's parameters over into the replacement's conv.
    transmuted.kernel.conv.load_state_dict(input_module.state_dict())
    return transmuted