in pytorchvideo/models/resnet.py [0:0]
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Apply branch2a, the parallel branch2b paths, their reduction, then branch2c.

    Every layer is forwarded explicitly. Optional norm/activation layers that
    are ``None`` are skipped. No residual/shortcut addition and no final
    activation happen here: the result of ``conv_c`` (plus ``norm_c`` when
    set) is returned as-is.

    Args:
        x: Input tensor. NOTE(review): presumably a (N, C, T, H, W) video
            tensor, given the Tx1x1 / 1xHxW examples below — confirm
            against the caller.

    Returns:
        The output of ``conv_c``, followed by ``norm_c`` when it is not None.

    Raises:
        ValueError: If ``self.reduce_method`` is neither ``"sum"`` nor
            ``"cat"``. Previously such a value silently skipped branch2b.
    """
    # Branch2a, for example Tx1x1, BN, ReLU. Each stage is optional.
    if self.conv_a is not None:
        x = self.conv_a(x)
    if self.norm_a is not None:
        x = self.norm_a(x)
    if self.act_a is not None:
        x = self.act_a(x)

    # Branch2b, for example 1xHxW, BN, ReLU. Each parallel path consumes the
    # same branch2a output. norm_b / act_b are indexed (not zipped) so that a
    # length mismatch with conv_b still fails loudly instead of being truncated.
    branch_outputs = []
    for ind, conv in enumerate(self.conv_b):
        y = conv(x)
        if self.norm_b[ind] is not None:
            y = self.norm_b[ind](y)
        if self.act_b[ind] is not None:
            y = self.act_b[ind](y)
        branch_outputs.append(y)

    # Merge the parallel paths: element-wise sum, or concatenation along
    # dim 1 (presumably the channel dimension — confirm against the caller).
    if self.reduce_method == "sum":
        x = torch.stack(branch_outputs, dim=0).sum(dim=0, keepdim=False)
    elif self.reduce_method == "cat":
        x = torch.cat(branch_outputs, dim=1)
    else:
        # Without this guard, branch2b would be silently dropped and the
        # branch2a output would be fed to conv_c.
        raise ValueError(
            f"Unknown reduce_method {self.reduce_method!r}; "
            "expected 'sum' or 'cat'."
        )

    # Branch2c, for example 1x1x1, BN. conv_c is always applied.
    x = self.conv_c(x)
    if self.norm_c is not None:
        x = self.norm_c(x)
    return x