in siammot/modelling/backbone/dla.py [0:0]
def forward(self, x, residual=None):
    """Run the multi-branch bottleneck and add the shortcut connection.

    The input goes through ``conv1 -> bn1 -> relu`` and is then split along
    dim 1 according to ``self.width``. Each ``(conv, bn)`` pair in
    ``self.convs`` / ``self.bns`` processes one chunk. Unless ``self.is_first``
    is set, every branch after the first receives its chunk summed with the
    previous branch's output. When ``self.scale > 1`` the last chunk skips the
    convolutions and is appended as-is (or after ``self.pool`` when
    ``self.is_first`` is set). The branch outputs are concatenated along
    dim 1, passed through ``conv3 -> bn3``, the shortcut is added, and a
    final ``relu`` is applied.

    NOTE(review): dim 1 is presumably the channel axis of an NCHW tensor --
    confirm against the callers.

    Args:
        x: Input tensor.
        residual: Optional shortcut tensor; ``x`` itself is used when None.

    Returns:
        The activated sum of the transformed features and the shortcut.
    """
    shortcut = x if residual is None else residual

    stem = self.relu(self.bn1(self.conv1(x)))
    chunks = torch.split(stem, self.width, 1)

    branch_outputs = []
    previous = None
    for idx, (conv, bn) in enumerate(zip(self.convs, self.bns)):
        if idx == 0 or self.is_first:
            # First branch (or a block flagged is_first): no hierarchical carry-over.
            branch_input = chunks[idx]
        else:
            # Feed the previous branch's output forward into this branch.
            branch_input = previous + chunks[idx]
        previous = self.relu(bn(conv(branch_input)))
        branch_outputs.append(previous)

    if self.scale > 1:
        # The last chunk bypasses the conv branches.
        if self.is_first:
            leftover = self.pool(chunks[-1])
        else:
            leftover = chunks[-1]
        branch_outputs.append(leftover)

    fused = self.bn3(self.conv3(torch.cat(branch_outputs, 1)))
    # Kept as an in-place add on the bn3 output, exactly like the original.
    fused += shortcut
    return self.relu(fused)