in timm/models/_efficientnet_builder.py [0:0]
def _decode_block_str(block_str):
    """ Decode block definition string

    Parses one block definition written in the '_'-separated string notation into
    a dict of block args plus a repeat count.
    E.g. ir_r2_k3_s2_e1_c16_se0.25_noskip

    All args can exist in any order with the exception of the leading string which
    is assumed to indicate the block type.

    leading string - block type, one of: ir, ds, dsa, er, cn, uir, mha, mqa
        (ir = InvertedResidual, ds = DepthwiseSep, dsa = DepthwiseSep with pw act
        and no skip, cn = ConvBnAct)
    r - number of repeat blocks (required)
    k - kernel size (required)
    s - stride (required)
    c - output channels (required)
    e - expansion ratio (required for ir, er, uir)
    se - squeeze/excitation ratio (ir, ds, dsa, er, uir; default 0.)
    n - activation fn ('re', 'r6', 'hs', 'sw', or 'mi'); any other value is ignored
        and the model default activation is used
    a - ir: expansion kernel size (default 1); uir: start dw kernel size (default 0)
    p - ir/ds/dsa/er: pw kernel size (default 1); uir: end dw kernel size (default 0)
    d - ir/ds/dsa: space-to-depth enabled when > 0;
        mha/mqa: key/value dim (required)
    h - number of attention heads (mha, mqa; required)
    v - key/value stride (mha, mqa; default 1)
    cc - number of experts (ir only)
    fc - forced input channels (er only; default 0)
    gs - group size (any block type)
    noskip - disable the skip connection (ir, ds, er, uir, mha, mqa)
    skip - enable the skip connection (cn)

    Args:
        block_str: a string representation of block arguments.
    Returns:
        A tuple of (block args dict, number of repeats).
    Raises:
        ValueError: if the block type is unknown or a numeric option cannot be parsed
        KeyError: if an option required by the block type is missing
    """
    assert isinstance(block_str, str)
    ops = block_str.split('_')
    block_type = ops[0]  # take the block type off the front
    ops = ops[1:]

    # abbreviations accepted by the 'n' (activation) option -> get_act_layer names
    act_abbrevs = {
        're': 'relu',
        'r6': 'relu6',
        'hs': 'hard_swish',
        'sw': 'swish',  # aka SiLU
        'mi': 'mish',
    }

    options = {}
    skip = None
    for op in ops:
        # string options being checked on individual basis, combine if they grow
        if op == 'noskip':
            skip = False  # force no skip connection
        elif op == 'skip':
            skip = True  # force a skip connection
        elif op.startswith('n'):
            # activation fn; an unrecognized abbreviation is skipped, which leaves
            # act_layer as None (model default)
            act_name = act_abbrevs.get(op[1:])
            if act_name is None:
                continue
            options['n'] = get_act_layer(act_name)
        else:
            # all numeric options: key is the text before the first digit, value is
            # the rest, e.g. 'se0.25' -> ('se', '0.25'); tokens with no digit are dropped
            splits = re.split(r'(\d.*)', op)
            if len(splits) >= 2:
                key, value = splits[:2]
                options[key] = value

    # if act_layer is None, the model default (passed to model init) will be used
    act_layer = options.get('n')
    start_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1
    end_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1
    force_in_chs = int(options['fc']) if 'fc' in options else 0  # FIXME hack to deal with in_chs issue in TPU def
    num_repeat = int(options['r'])

    # each type of block has different valid arguments, fill accordingly
    block_args = dict(
        block_type=block_type,
        out_chs=int(options['c']),
        stride=int(options['s']),
        act_layer=act_layer,
    )
    if block_type == 'ir':
        block_args.update(dict(
            dw_kernel_size=_parse_ksize(options['k']),
            exp_kernel_size=start_kernel_size,
            pw_kernel_size=end_kernel_size,
            exp_ratio=float(options['e']),
            se_ratio=float(options.get('se', 0.)),
            noskip=skip is False,
            s2d=int(options.get('d', 0)) > 0,
        ))
        if 'cc' in options:
            block_args['num_experts'] = int(options['cc'])
    elif block_type in ('ds', 'dsa'):
        block_args.update(dict(
            dw_kernel_size=_parse_ksize(options['k']),
            pw_kernel_size=end_kernel_size,
            se_ratio=float(options.get('se', 0.)),
            pw_act=block_type == 'dsa',
            # 'dsa' never has a skip connection regardless of skip/noskip
            noskip=block_type == 'dsa' or skip is False,
            s2d=int(options.get('d', 0)) > 0,
        ))
    elif block_type == 'er':
        block_args.update(dict(
            exp_kernel_size=_parse_ksize(options['k']),
            pw_kernel_size=end_kernel_size,
            exp_ratio=float(options['e']),
            force_in_chs=force_in_chs,
            se_ratio=float(options.get('se', 0.)),
            noskip=skip is False,
        ))
    elif block_type == 'cn':
        block_args.update(dict(
            kernel_size=int(options['k']),
            # cn blocks default to no skip; only an explicit 'skip' enables it
            skip=skip is True,
        ))
    elif block_type == 'uir':
        # override exp / proj kernels for start/end in uir block (0 = no such dw conv)
        start_kernel_size = _parse_ksize(options['a']) if 'a' in options else 0
        end_kernel_size = _parse_ksize(options['p']) if 'p' in options else 0
        block_args.update(dict(
            dw_kernel_size_start=start_kernel_size,  # overload exp ks arg for dw start
            dw_kernel_size_mid=_parse_ksize(options['k']),
            dw_kernel_size_end=end_kernel_size,  # overload pw ks arg for dw end
            exp_ratio=float(options['e']),
            se_ratio=float(options.get('se', 0.)),
            noskip=skip is False,
        ))
    elif block_type in ('mha', 'mqa'):
        # both attention variants take identical options; they differ only in
        # block_type, which is already recorded in block_args
        kv_dim = int(options['d'])
        block_args.update(dict(
            dw_kernel_size=_parse_ksize(options['k']),
            num_heads=int(options['h']),
            key_dim=kv_dim,
            value_dim=kv_dim,
            kv_stride=int(options.get('v', 1)),
            noskip=skip is False,
        ))
    else:
        # explicit raise (not assert) so the check survives `python -O`
        raise ValueError(f'Unknown block type ({block_type})')

    if 'gs' in options:
        block_args['group_size'] = int(options['gs'])
    return block_args, num_repeat