in sockeye/train.py [0:0]
def fixed_param_names_from_stragegy(config: model.ModelConfig,
                                    params: C.ParameterDict,
                                    strategy: str) -> List[str]:
    """
    Generate a fixed parameter list given a list of all parameter names and
    a strategy.

    :param config: Model configuration. Only the encoder and decoder layer
                   counts (``config_encoder.num_layers`` /
                   ``config_decoder.num_layers``) are read.
    :param params: Parameters to select from. Only the names obtained by
                   iterating over it are used.
    :param strategy: One of the ``C.FIXED_PARAM_STRATEGY_*`` constants.
    :return: Names from ``params`` that the given strategy marks as fixed,
             in iteration order of ``params``.
    :raises ValueError: If ``strategy`` is not a known fixed parameter
                        strategy (raised even when ``params`` is empty).
    """
    # Number of encoder/decoder layers in model.
    num_encoder_layers = config.config_encoder.num_layers
    num_decoder_layers = config.config_decoder.num_layers

    def layer_prefix(module_prefix: str, layer_index: int) -> str:
        # The trailing '.' keeps e.g. layer 1 from also matching layers 10-19.
        return f'{module_prefix}.layers.{layer_index}.'

    # Resolve the strategy into a name predicate once, so that invalid
    # strategies are rejected immediately and prefixes are not rebuilt per name.
    if strategy == C.FIXED_PARAM_STRATEGY_ALL_EXCEPT_DECODER:
        # Everything except any decoder parameter.
        def is_fixed(name: str) -> bool:
            return not name.startswith(C.DECODER_PREFIX)
    elif strategy == C.FIXED_PARAM_STRATEGY_ALL_EXCEPT_OUTER_LAYERS:
        # Everything except the first and last encoder and decoder layers.
        trainable_prefixes = (layer_prefix(C.ENCODER_PREFIX, 0),
                              layer_prefix(C.ENCODER_PREFIX, num_encoder_layers - 1),
                              layer_prefix(C.DECODER_PREFIX, 0),
                              layer_prefix(C.DECODER_PREFIX, num_decoder_layers - 1))

        def is_fixed(name: str) -> bool:
            return not name.startswith(trainable_prefixes)
    elif strategy == C.FIXED_PARAM_STRATEGY_ALL_EXCEPT_EMBEDDINGS:
        # Everything except any type of learned embedding.
        embedding_prefixes = (C.SOURCE_EMBEDDING_PREFIX, C.TARGET_EMBEDDING_PREFIX)

        def is_fixed(name: str) -> bool:
            return not name.startswith(embedding_prefixes)
    elif strategy == C.FIXED_PARAM_STRATEGY_ALL_EXCEPT_OUTPUT_PROJ:
        # Everything except the target output projection.
        def is_fixed(name: str) -> bool:
            return not name.startswith(C.DEFAULT_OUTPUT_LAYER_PREFIX)
    elif strategy == C.FIXED_PARAM_STRATEGY_ALL_EXCEPT_FEED_FORWARD:
        # Everything except the ff1/ff2 weights and biases of feed-forward blocks.
        ff_suffixes = ("ff.ff1.bias", "ff.ff1.weight", "ff.ff2.bias", "ff.ff2.weight")

        def is_fixed(name: str) -> bool:
            return not name.endswith(ff_suffixes)
    elif strategy == C.FIXED_PARAM_STRATEGY_ENCODER_AND_SOURCE_EMBEDDINGS:
        # The whole encoder plus the source embeddings.
        encoder_and_source_prefixes = (C.ENCODER_PREFIX, C.SOURCE_EMBEDDING_PREFIX)

        def is_fixed(name: str) -> bool:
            return name.startswith(encoder_and_source_prefixes)
    elif strategy == C.FIXED_PARAM_STRATEGY_ENCODER_HALF_AND_SOURCE_EMBEDDINGS:
        # The first half (rounded down) of the encoder layers plus the source
        # embeddings. Other encoder parameters outside these layers stay trainable.
        half_prefixes = (C.SOURCE_EMBEDDING_PREFIX,) + tuple(
            layer_prefix(C.ENCODER_PREFIX, i) for i in range(num_encoder_layers // 2))

        def is_fixed(name: str) -> bool:
            return name.startswith(half_prefixes)
    else:
        raise ValueError("Unknown fixed parameter strategy: %s" % strategy)

    return [name for name in params if is_fixed(name)]