in mobile_cv/arch/fbnet_v2/fbnet_fpn.py [0:0]
def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]:
    """Run every resolution's stages and return one feature map per resolution.

    The forward pass is generated programmatically so that it is independent
    of the number of spatial resolutions. The only requirement is that there
    are 5 stages per spatial resolution. Because some stages depend on the
    outputs of other stages (including the previous resolution's last stage),
    stages are evaluated in a fixed order: resolution by resolution, and within
    a resolution from stage 0 to stage 4.

    For resolution ``j``:
      * stages 0 and 1 consume ``x[2 * j]`` and ``x[2 * j + 1]``;
      * ``self.stage_combiners[j]`` receives a list holding the non-``None``
        outputs of stages 0 and 1 and, for ``j > 0``, the last-slot output of
        resolution ``j - 1``;
      * stage 2 consumes the combined result;
      * stages 3 and 4, when not ``None``, each consume stage 2's output.

    Args:
        x: input feature maps; resolution ``j`` reads ``x[2 * j]`` and
            ``x[2 * j + 1]``.

    Returns:
        The second-to-last stage output (stage 3 when there are 5 stages) of
        each resolution, in resolution order, or in reverse order when
        ``self.combiner_path == "low_res"``. An entry is ``None`` when that
        stage is ``None``.
    """
    # data[j][i] holds the output of stage i at resolution j; it stays None
    # until computed, or permanently if that stage is absent.
    data = [
        [None] * self.num_stages_per_resolution for _ in range(self.num_resolutions)
    ]
    for j in range(self.num_resolutions):
        row = data[j]
        for i in range(self.num_stages_per_resolution):
            if i == 0 or i == 1:
                row[i] = self.stages[j][i](x[j * 2 + i])
            elif i == 2:
                # Only resolutions after the first have a predecessor to feed
                # in. Check j explicitly instead of relying on data[-1][-1]
                # (a wrap-around into a not-yet-computed slot) being None.
                prev = data[j - 1][-1] if j > 0 else None
                inputs = [t for t in (row[0], row[1], prev) if t is not None]
                combined_result = self.stage_combiners[j](inputs)
                row[i] = self.stages[j][i](combined_result)
            elif i == 3 or i == 4:
                # Stages 3/4 are optional; a None stage leaves its slot None.
                if self.stages[j][i] is not None:
                    row[i] = self.stages[j][i](row[2])
    output = [row[-2] for row in data]
    if self.combiner_path == "low_res":
        output = output[::-1]
    # pyre-fixme[7]: Expected `List[torch.Tensor]` but got `List[None]`.
    return output