in models/resnet_mlp.py [0:0]
def _forward_impl(self, x, layer, tsn_mode=False):
if tsn_mode:
batch_size, num_frames, C, H, W = x.shape
x = x.reshape(batch_size * num_frames, C, H, W)
# See note [TorchScript super()]
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
if layer == 5:
if tsn_mode:
# segmental consensus
x = x.reshape(batch_size, num_frames, -1)
x = torch.mean(x, dim=1)
return x
x = self.avgpool(x)
x = torch.flatten(x, 1)
if layer == 6:
if tsn_mode:
# segmental consensus
x = x.reshape(batch_size, num_frames, -1)
x = torch.mean(x, dim=1)
return x
if tsn_mode:
# segmental consensus
num_frames = x.shape[0] // batch_size
x = x.reshape(batch_size, num_frames, -1)
# tsn
x_tsn = torch.mean(x, dim=1)
x_tsn = nn.functional.normalize(self.fc_tsn(x_tsn), p=2, dim=1)
# o3n
if self.order_out:
x_o3n = nn.functional.normalize(self.fc_order(x), p=2, dim=1)
# concat
x_o3n = x_o3n.reshape(batch_size, -1)
out = [x_tsn, x_o3n]
else:
out = x_tsn
else:
out = [nn.functional.normalize(self.fc_inter(x), p=2, dim=1)]
if self.intra_out:
out.append(nn.functional.normalize(self.fc_intra(x), p=2, dim=1))
return out