in mobile_cv/arch/fbnet_v2/fbnet_fpn.py [0:0]
def _verify_init(self, input_channels: List[int], arch_def: Dict) -> None:
"""Check that arch_def is a valid FBNetFPN definition
Valid definition properties:
* input_channels specified
* stages specified
* num_resolutions = len(input_channels) // 2
* stage combiners = num_resolutions
* num_stages = 4 * num_resolutions + num_resolutions - 1
* inputs to combiners have specific number of channels
* add: channels should be the same
* mul: channels should be the same or 1
Note that this function sets the following parameters:
* input_channels, num_resolutions, num_stages_per_resolution (by
default = 4), blocks, blocks_out_dims, combiner_path
"""
# TODO: do we want to put "input_channels" into arch_def
assert isinstance(input_channels, list) and len(input_channels) > 0
self.input_channels = input_channels
self.num_resolutions = len(self.input_channels) // 2
# pyre-fixme[16]: `FBNetFPNBuilder` has no attribute
# `num_stages_per_resolution`.
self.num_stages_per_resolution = 5
assert "stages" in arch_def
# self.num_stages = len(arch_def["stages"])
self.num_stages = self.num_resolutions * self.num_stages_per_resolution - 1
assert (
# num_stages == num_resolutions * num_stages_per_resolution - 1
len(arch_def["stages"])
>= self.num_stages
), (
f"FBNet FPN requires 4 stages per spatial resolution and in "
f"total {self.num_resolutions - 1} stages for cross-resolution connections"
)
default_combiners = ["add"] * self.num_resolutions
stage_combiners = arch_def.get("stage_combiners", default_combiners)
assert isinstance(stage_combiners, list)
assert len(stage_combiners) == self.num_resolutions
# iterate over stages and check that the inputs to the combiners
# have the same number of channels or 1
self.blocks = fbnet_builder.unify_arch_def_blocks(arch_def["stages"])
self.blocks_out_dims = fbnet_builder.get_stages_dim_out(self.blocks)
# pyre-fixme[16]: `FBNetFPNBuilder` has no attribute
# `stage_combiner_num_inputs`.
self.stage_combiner_num_inputs = [0] * self.num_resolutions
for i in range(self.num_resolutions):
input_chdepths = []
inputA_stage = i * self.num_stages_per_resolution
if self.blocks[inputA_stage]["block_op"] != "noop":
input_chdepths.append(self.blocks_out_dims[inputA_stage])
inputB_stage = i * self.num_stages_per_resolution + 1
if self.blocks[inputB_stage]["block_op"] != "noop":
input_chdepths.append(self.blocks_out_dims[inputB_stage])
if i > 0:
inputC_stage = i * self.num_stages_per_resolution - 1
if self.blocks[inputC_stage]["block_op"] != "noop":
input_chdepths.append(self.blocks_out_dims[inputC_stage])
stage_combiner = stage_combiners[i]
if stage_combiner == "add":
assert (
len(set(input_chdepths)) == 1
), f"Trying to add features with different chls: {input_chdepths}"
self.stage_combiner_num_inputs[i] = len(input_chdepths)
else:
raise NotImplementedError
# pyre-fixme[16]: `FBNetFPNBuilder` has no attribute `output_channels`.
self.output_channels = []
for i in range(self.num_resolutions):
self.output_channels.append(self.blocks_out_dims[i * 5 + 3])
self.combiner_path = arch_def.get("combiner_path", "high_res")
# always order the output from high res to low res
if self.combiner_path == "low_res":
self.output_channels = self.output_channels[::-1]