in pytorchvideo/accelerator/deployment/mobile_cpu/transmuter/transmuter_mobile_cpu.py [0:0]
def transmute_Conv3dTemporalKernel1BnAct(input_module: nn.Module):
    """
    Given an input_module, transmutes it into an equivalent Conv3dTemporalKernel1BnAct.
    Returns None if no equivalent Conv3dTemporalKernel1BnAct is found, else returns
    an instance of equivalent Conv3dTemporalKernel1BnAct.

    An nn.Conv3d is transmuted only when its temporal (first) axis is trivial
    (kernel 1, stride 1, padding 0, dilation 1) and its two spatial axes agree on
    kernel size, stride, padding and dilation, since a single scalar per spatial
    setting is forwarded to Conv3dTemporalKernel1BnAct.

    Args:
        input_module (nn.Module): input module to find an equivalent
            Conv3dTemporalKernel1BnAct
    """
    if not isinstance(input_module, nn.Conv3d):
        return None

    kernel_size = input_module.kernel_size
    stride = input_module.stride
    padding = input_module.padding
    dilation = input_module.dilation

    # If the input_module can be replaced by Conv3dPwBnAct, don't use
    # Conv3dTemporalKernel1BnAct.
    if (
        kernel_size == (1, 1, 1)
        and input_module.groups == 1
        and stride == (1, 1, 1)
        and padding == (0, 0, 0)
        and dilation == (1, 1, 1)
    ):
        return None

    # The temporal axis must be a no-op: kernel 1, stride 1, no padding, no dilation.
    temporal_axis_is_trivial = (
        kernel_size[0] == 1
        and stride[0] == 1
        and padding[0] == 0
        and dilation[0] == 1
    )
    if not temporal_axis_is_trivial:
        return None

    # Only index 1 of each spatial setting is forwarded below, so H and W must
    # agree on every one of them. Padding and dilation are checked as well as
    # kernel size and stride; otherwise e.g. padding=(0, 1, 2) would be
    # transmuted with the W-axis value silently dropped.
    spatial_axes_agree = (
        kernel_size[1] == kernel_size[2]
        and stride[1] == stride[2]
        and padding[1] == padding[2]
        and dilation[1] == dilation[2]
    )
    if not spatial_axes_agree:
        return None

    # NOTE(review): input_module.padding_mode is not forwarded either; confirm
    # whether non-"zeros" padding modes should also be rejected here.
    module = Conv3dTemporalKernel1BnAct(
        in_channels=input_module.in_channels,
        out_channels=input_module.out_channels,
        bias=input_module.bias is not None,
        groups=input_module.groups,
        spatial_kernel=kernel_size[1],
        spatial_stride=stride[1],
        spatial_padding=padding[1],
        spatial_dilation=dilation[1],
        # Only the convolution itself is being replaced, so no BN/activation.
        activation="identity",
        use_bn=False,
    )
    # Carry the original convolution's parameters over to the new module.
    module.kernel.conv.load_state_dict(input_module.state_dict())
    return module