in timm/models/_efficientnet_builder.py [0:0]
def __call__(self, in_chs, model_block_args):
    """ Build the blocks
    Args:
        in_chs: Number of input-channels passed to first block
        model_block_args: A list of lists, outer list defines stages, inner
            list contains the per-block argument dicts (already-decoded block
            configuration(s)); each dict must at least provide a 'stride' key.
            NOTE: these dicts are mutated in place ('stride', 'dilation', 's2d'
            and 'exp_ratio' may be rewritten, 's2d' may be popped).
    Return:
        List of block stacks (each stack wrapped in nn.Sequential)

    Side effects:
        Sets self.in_chs and appends one feature-info dict to self.features for
        every point where the output stride is about to change (and for the
        final stack).
    """
    _log_info_if('Building model trunk with %d stages...' % len(model_block_args), self.verbose)
    self.in_chs = in_chs
    total_block_count = sum(len(stack_args) for stack_args in model_block_args)
    total_block_idx = 0
    # Stride bookkeeping starts after the stem; the initial value of 2 presumably
    # reflects a stride-2 stem. NOTE(review): confirm against the stem config.
    current_stride = 2
    current_dilation = 1
    stages = []
    if model_block_args[0][0]['stride'] > 1:
        # if the first block starts with a stride, we need to extract first level feat from stem
        feature_info = dict(module='bn1', num_chs=in_chs, stage=0, reduction=current_stride)
        self.features.append(feature_info)

    # outer list of block_args defines the stacks
    # space2depth state machine:
    #   0 -> no s2d region active
    #   1 -> the block currently being built opens the s2d region
    #   2 -> inside the s2d region; the next stride-2 block closes it
    space2depth = 0
    for stack_idx, stack_args in enumerate(model_block_args):
        _log_info_if('Stack: {}'.format(stack_idx), self.verbose)
        assert isinstance(stack_args, list)

        blocks = []
        # each stack (stage of blocks) contains a list of block arguments
        for block_idx, block_args in enumerate(stack_args):
            last_block = block_idx + 1 == len(stack_args)
            _log_info_if(' Block: {}'.format(block_idx), self.verbose)

            assert block_args['stride'] in (1, 2)
            if block_idx >= 1:
                # only the first block in any stack can have a stride > 1
                block_args['stride'] = 1

            # The 's2d' flag is only consumed (popped) while no s2d region is active;
            # the `not space2depth` short-circuit skips the pop otherwise.
            if not space2depth and block_args.pop('s2d', False):
                assert block_args['stride'] == 1
                space2depth = 1

            if space2depth > 0:
                # FIXME s2d is a WIP
                if space2depth == 2 and block_args['stride'] == 2:
                    # To end the s2d region, the stride-2 block is turned into a stride-1
                    # block and its expansion ratio is corrected relative to its input.
                    # NOTE(review): only exp_ratio is rescaled here; se ratio is not adjusted.
                    block_args['stride'] = 1
                    block_args['exp_ratio'] /= 4
                    space2depth = 0
                else:
                    block_args['s2d'] = space2depth
                # NOTE(review): the block that opens the region has stride 1 (asserted
                # above), so current_stride is not changed inside the s2d region; if the
                # s2d block itself downsamples, feature 'reduction' values recorded in
                # that region may be understated -- confirm.

            # Record a feature at the end of a stack when the next stack starts by
            # downsampling, or when this is the final stack.
            next_stack_idx = stack_idx + 1
            extract_features = last_block and (
                next_stack_idx >= len(model_block_args)
                or model_block_args[next_stack_idx][0]['stride'] > 1
            )

            # Convert stride into dilation once the requested output_stride would be
            # exceeded. The converting block itself still uses the current dilation;
            # the increased dilation applies to the blocks that follow it.
            next_dilation = current_dilation
            if block_args['stride'] > 1:
                next_output_stride = current_stride * block_args['stride']
                if next_output_stride > self.output_stride:
                    next_dilation = current_dilation * block_args['stride']
                    block_args['stride'] = 1
                    _log_info_if('  Converting stride to dilation to maintain output_stride=={}'.format(
                        self.output_stride), self.verbose)
                else:
                    current_stride = next_output_stride
            block_args['dilation'] = current_dilation
            current_dilation = next_dilation

            # create the block
            block = self._make_block(block_args, total_block_idx, total_block_count)
            blocks.append(block)

            if space2depth == 1:
                # the opening s2d block has been built; subsequent blocks are inside the region
                space2depth = 2

            # stash feature module name and channel info for model feature extraction
            if extract_features:
                feature_info = dict(
                    stage=stack_idx + 1,
                    reduction=current_stride,
                    **block.feature_info(self.feature_location),
                )
                leaf_name = feature_info.get('module', '')
                if leaf_name:
                    # qualify the block-local module name with its path in the model
                    feature_info['module'] = '.'.join([f'blocks.{stack_idx}.{block_idx}', leaf_name])
                else:
                    # no specific leaf module: the feature is the output of the whole stack
                    assert last_block
                    feature_info['module'] = f'blocks.{stack_idx}'
                self.features.append(feature_info)

            total_block_idx += 1  # incr global block idx (across all stacks)
        stages.append(nn.Sequential(*blocks))
    return stages