models/spconv.py:

# Copyright (c) Alibaba, Inc. and its affiliates.
import sys
from pathlib import Path

sys.path.append(Path(__file__).parent.parent.absolute().__str__())  # to run '$ python *.py' files in subdirectories

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from models.common import YOLOv6Head, Conv
from utils.torch_utils import time_synchronized


class SparseTensor(object):
    # Minimal container pairing sparse features with their (batch, y, x) indices
    def __init__(self, x, indices) -> None:
        self.features = x
        self.indices = indices


class SPConv2D1x1(nn.Module):
    # Sparse 2d-convolution layer 1x1
    def __init__(self, m: nn.Conv2d):
        super(SPConv2D1x1, self).__init__()
        assert m.kernel_size == (1, 1)
        self.weight = m.weight.flatten(1)  # (c2, c1)
        self.bias = m.bias

    def forward(self, x, indices=None, to_dense=False):
        if indices is None:
            # x is already a set of sparse features: a plain linear layer suffices
            return F.linear(x, self.weight, self.bias)
        else:
            z = x.permute(0, 2, 3, 1).contiguous()  # channel_last
            bi, yi, xi = indices.T
            y = F.linear(z[bi, yi, xi], self.weight, self.bias)
            if to_dense:
                # assert z.shape[-1] == y.shape[-1]
                z[bi, yi, xi] = y
                y = z.permute(0, 3, 1, 2).contiguous()  # channel_first
            return y


class SPConv2Dkxk(nn.Module):
    # Sparse 2d-convolution layer kxk
    def __init__(self, m: nn.Conv2d):
        super(SPConv2Dkxk, self).__init__()
        self.weight = m.weight
        self.weight_flatten = m.weight.flatten(1)  # (c2, c1*k*k)
        self.bias = m.bias
        self.kernel_size = m.kernel_size
        self.padding = m.padding
        self.stride = m.stride
        assert all([k == p * 2 + 1 for (k, p) in zip(m.kernel_size, m.padding)]), \
            f'Unsupported kernel size {m.kernel_size} and padding {m.padding}'
        assert m.stride == (1, 1), f'Unsupported stride {m.stride}'

    def forward(self, x, indices=None, to_dense=False):
        if indices is None:
            assert self.kernel_size == (1, 1) and self.padding == (0, 0)
            return F.linear(x, self.weight_flatten, self.bias)
        # Gather k*k patches for every pixel, then apply the convolution only
        # at the requested (batch, y, x) locations as a matrix multiply
        bs, c1, ny, nx = x.shape
        unfold = F.unfold(x, self.kernel_size, padding=self.padding, stride=self.stride)
        unfold = unfold.transpose(1, 2).view(bs, ny, nx, c1 * np.prod(self.kernel_size))
        bi, yi, xi = indices.T
        z = F.linear(unfold[bi, yi, xi], self.weight_flatten, self.bias)
        if to_dense:
            # Scatter the sparse outputs back into a dense channel_first map.
            # Like SPConv2D1x1 above, the buffer is initialised from the input
            # and therefore assumes c2 == c1.
            dense = x.permute(0, 2, 3, 1).contiguous()  # channel_last
            dense[bi, yi, xi] = z
            z = dense.permute(0, 3, 1, 2).contiguous()  # channel_first
        return z


class SPConv(nn.Module):
    # Sparse convolution (conv + optional bn + activation)
    def __init__(self, conv: Conv):
        super(SPConv, self).__init__()
        self.conv = SPConv2Dkxk(conv.conv)
        if hasattr(conv, 'bn'):
            raise NotImplementedError  # sparse BatchNorm is not supported yet
            self.bn = make_spbn(conv.bn)  # unreachable until sparse BN exists
        self.act = conv.act

    def forward(self, x, indices=None, to_dense=False):
        x = self.conv(x, indices, to_dense)
        if hasattr(self, 'bn'):
            x = self.bn(x)
        x = self.act(x)
        return x

    def forward_dense(self, x):
        # Plain dense convolution path, used when every pixel is needed
        m = self.conv
        stem = F.conv2d(x, m.weight, m.bias, m.stride, m.padding)
        return self.act(stem)
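
# A hedged helper sketch, not part of the original file: the layers above
# expect `indices` as an (N, 3) integer tensor of (batch, y, x) triplets,
# matching the `bi, yi, xi = indices.T` unpacking. One plausible way to build
# it is from a boolean foreground mask:
def mask_to_indices(mask: torch.Tensor) -> torch.Tensor:
    # mask: (bs, ny, nx) bool tensor marking the pixels to convolve.
    # torch.nonzero returns one row per True element with one column per
    # dimension, which is exactly the (b, y, x) layout expected here.
    return mask.nonzero()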

class SPYOLOv6Head(nn.Module):
    # Sparse variant of the YOLOv6 head: the stem runs densely, while the
    # class/box/objectness branches run only at the given indices
    def __init__(self, dense_head: YOLOv6Head):
        super(SPYOLOv6Head, self).__init__()
        self.na = dense_head.na
        self.nc = dense_head.nc
        self.sp_stem = SPConv(dense_head.stem)
        self.sp_cls_conv = SPConv(dense_head.cls_conv)
        self.sp_reg_conv = SPConv(dense_head.reg_conv)
        self.sp_cls_pred = SPConv2Dkxk(dense_head.cls_pred)
        self.sp_reg_pred = SPConv2Dkxk(dense_head.reg_pred)
        self.sp_obj_pred = SPConv2Dkxk(dense_head.obj_pred)

    def forward(self, x: torch.Tensor, indices: torch.Tensor):
        assert not self.training
        # stem = self.sp_stem(x, indices, to_dense=True)
        stem = self.sp_stem.forward_dense(x)
        cls_feat = self.sp_cls_conv(stem, indices)
        reg_feat = self.sp_reg_conv(stem, indices)
        cls = self.sp_cls_pred(cls_feat).view(-1, self.na, self.nc)
        reg = self.sp_reg_pred(reg_feat).view(-1, self.na, 4)
        obj = self.sp_obj_pred(reg_feat).view(-1, self.na, 1)
        y_sp = torch.cat((reg, obj, cls), -1).view(-1, self.na * (4 + 1 + self.nc))
        return SparseTensor(y_sp, indices)


class SPYOLOv5Head(nn.Module):
    # Sparse variant of the YOLOv5 detection head (a single 1x1 convolution)
    def __init__(self, dense_head: nn.Conv2d):
        super(SPYOLOv5Head, self).__init__()
        self.sp_head = SPConv2D1x1(dense_head)

    def forward(self, x: torch.Tensor, indices: torch.Tensor):
        assert not self.training
        y_sp = self.sp_head(x, indices)
        return SparseTensor(y_sp, indices)


if __name__ == '__main__':
    # Sanity check: an unfold + matmul convolution must match F.conv2d;
    # both paths are timed for comparison
    bs, c1, h, w = 2, 3, 6, 5
    c2 = 4
    k, s, p = 3, 1, 1
    with torch.no_grad():
        inp = torch.rand(bs, c1, h, w).cuda()
        conv = torch.nn.Conv2d(c1, c2, k, s, p, bias=True).cuda()

        # warm-up run so CUDA initialization is excluded from the timing
        unfold = F.unfold(inp * 0.1, k, padding=p, stride=s)
        unfold = unfold.transpose(1, 2).view(bs, h, w, c1 * k * k)
        y = unfold @ conv.weight.flatten(1).T + conv.bias

        t0 = time_synchronized()
        unfold = F.unfold(inp, k, padding=p, stride=s)
        unfold = unfold.transpose(1, 2).view(bs, h, w, c1 * k * k)
        y = unfold @ conv.weight.flatten(1).T + conv.bias
        t1 = time_synchronized()
        y = y.permute(0, 3, 1, 2).contiguous()

        # warm-up run for the dense path
        z = F.conv2d(inp * 0.1, conv.weight, conv.bias, stride=s, padding=p)

        t2 = time_synchronized()
        z = F.conv2d(inp, conv.weight, conv.bias, stride=s, padding=p)
        t3 = time_synchronized()

        print(y.shape, z.shape)
        print(f'Error: {(y - z).abs().sum().item()}. '
              f'Cost: {(t1 - t0) * 1000:.1f}ms v.s. {(t3 - t2) * 1000:.1f}ms')
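
        # A hedged usage sketch, not in the original file: run the sparse 1x1
        # head at a few hand-picked (batch, y, x) locations and compare it
        # with the dense convolution evaluated at the same pixels
        head = torch.nn.Conv2d(c1, c2, 1, 1, 0, bias=True).cuda()
        sp_head = SPYOLOv5Head(head).eval()  # forward asserts eval mode
        indices = torch.tensor([[0, 0, 0], [0, 2, 3], [1, 5, 4]]).cuda()
        out = sp_head(inp, indices)
        ref = head(inp).permute(0, 2, 3, 1)  # channel_last dense reference
        bi, yi, xi = indices.T
        err = (out.features - ref[bi, yi, xi]).abs().max().item()
        print(f'Sparse-head max error: {err:.2e}')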