in timm/models/ghostnet.py [0:0]
def switch_to_deploy(self):
    """Collapse the re-parameterizable training branches into single convs.

    Replaces ``self.primary_conv`` and ``self.cheap_operation`` with
    ``nn.Sequential(fused_conv, activation)`` modules. Each fused conv's
    weight/bias comes from ``self._get_kernel_bias_primary()`` or
    ``self._get_kernel_bias_cheap()`` respectively (presumably the merged
    kernels of all training branches; see those helpers). Its geometry is
    copied from the first conv of the matching ``*_rpr_conv`` branch list.

    Afterwards every parameter of the module is detached in place. The
    ``*_rpr_conv`` / ``*_rpr_scale`` / ``*_rpr_skip`` attributes are deleted,
    and ``self.infer_mode`` is set to ``True``.

    The conversion is one-way, because the training branches are removed.
    Calling it again once ``self.infer_mode`` is true is a no-op.
    """
    if self.infer_mode:
        return

    def _build_fused(rpr_conv, kernel, bias, activation):
        # Mirror the geometry of the first training conv so the fused kernel
        # is applied with the same stride/padding/dilation/groups.
        ref = rpr_conv[0].conv
        fused = nn.Conv2d(
            in_channels=ref.in_channels,
            out_channels=ref.out_channels,
            kernel_size=ref.kernel_size,
            stride=ref.stride,
            padding=ref.padding,
            dilation=ref.dilation,
            groups=ref.groups,
            bias=True,
        )
        # Swapping the tensors in via ``.data`` (rather than copying into the
        # fresh parameters) lets them adopt the fused kernel's device and
        # dtype. The new conv itself is built with the default device/dtype.
        fused.weight.data = kernel
        fused.bias.data = bias
        # An empty ``nn.Sequential`` passes its input through unchanged, so it
        # stands in for "no activation".
        return nn.Sequential(
            fused,
            activation if activation is not None else nn.Sequential(),
        )

    primary_kernel, primary_bias = self._get_kernel_bias_primary()
    self.primary_conv = _build_fused(
        self.primary_rpr_conv, primary_kernel, primary_bias, self.primary_activation,
    )
    cheap_kernel, cheap_bias = self._get_kernel_bias_cheap()
    self.cheap_operation = _build_fused(
        self.cheap_rpr_conv, cheap_kernel, cheap_bias, self.cheap_activation,
    )

    # Freeze the deploy-time weights: an in-place detach stops every
    # parameter from requiring gradients.
    for param in self.parameters():
        param.detach_()

    # Delete the now-unused training branches.
    for name in (
        'primary_rpr_conv', 'primary_rpr_scale', 'primary_rpr_skip',
        'cheap_rpr_conv', 'cheap_rpr_scale', 'cheap_rpr_skip',
    ):
        if hasattr(self, name):
            delattr(self, name)

    self.infer_mode = True