in sockeye/train.py [0:0]
def create_model_config(args: argparse.Namespace,
                         source_vocab_sizes: List[int],
                         target_vocab_sizes: List[int],
                         max_seq_len_source: int,
                         max_seq_len_target: int,
                         config_data: data_io.DataConfig) -> model.ModelConfig:
"""
Create a ModelConfig from the argument given in the command line.
:param args: Arguments as returned by argparse.
:param source_vocab_sizes: The size of the source vocabulary (and source factors).
:param target_vocab_sizes: The size of the target vocabulary (and target factors).
:param max_seq_len_source: Maximum source sequence length.
:param max_seq_len_target: Maximum target sequence length.
:param config_data: Data config.
:return: The model configuration.
"""
    num_embed_source, num_embed_target = get_num_embed(args)
    embed_dropout_source, embed_dropout_target = args.embed_dropout
    source_vocab_size, *source_factor_vocab_sizes = source_vocab_sizes
    target_vocab_size, *target_factor_vocab_sizes = target_vocab_sizes

    config_encoder, encoder_num_hidden = create_encoder_config(args, max_seq_len_source, max_seq_len_target,
                                                               num_embed_source)
    config_decoder = create_decoder_config(args, encoder_num_hidden, max_seq_len_source, max_seq_len_target,
                                           num_embed_target)
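
    # Build per-factor embedding configurations if additional source factors are present.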
    source_factor_configs = None
    if len(source_vocab_sizes) > 1:
        source_factors_num_embed = args.source_factors_num_embed
        if not source_factors_num_embed:
            # This happens if the combination method is sum or average. We then
            # set the dimension to num_embed_source for all factors.
            logger.info("Setting all source factor embedding sizes to `num_embed` ('%d')",
                        num_embed_source)
            source_factors_num_embed = [num_embed_source] * len(source_factor_vocab_sizes)
        else:
            # Check each individual factor
            for i, combine in enumerate(args.source_factors_combine):
                if combine in [C.FACTORS_COMBINE_SUM, C.FACTORS_COMBINE_AVERAGE]:
                    logger.info("Setting embedding size of source factor %d to `num_embed` ('%d') for %s",
                                i + 1, num_embed_source,
                                "summing" if combine == C.FACTORS_COMBINE_SUM else "averaging")
                    source_factors_num_embed[i] = num_embed_source

        source_factor_configs = [encoder.FactorConfig(size, dim, combine, share)
                                 for size, dim, combine, share in zip(source_factor_vocab_sizes,
                                                                      source_factors_num_embed,
                                                                      args.source_factors_combine,
                                                                      args.source_factors_share_embedding)]
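
    # The same logic applies to target factors, using the target embedding size as the default.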
    target_factor_configs = None
    if len(target_vocab_sizes) > 1:
        target_factors_num_embed = args.target_factors_num_embed
        if not target_factors_num_embed:
            # This happens if the combination method is sum or average. We then
            # set the dimension to num_embed_target for all factors.
            logger.info("Setting all target factor embedding sizes to `num_embed` ('%d')",
                        num_embed_target)
            target_factors_num_embed = [num_embed_target] * len(target_factor_vocab_sizes)
        else:
            # Check each individual factor
            for i, combine in enumerate(args.target_factors_combine):
                if combine in [C.FACTORS_COMBINE_SUM, C.FACTORS_COMBINE_AVERAGE]:
                    logger.info("Setting embedding size of target factor %d to `num_embed` ('%d') for %s",
                                i + 1, num_embed_target,
                                "summing" if combine == C.FACTORS_COMBINE_SUM else "averaging")
                    target_factors_num_embed[i] = num_embed_target

        target_factor_configs = [encoder.FactorConfig(size, dim, combine, share)
                                 for size, dim, combine, share in zip(target_factor_vocab_sizes,
                                                                      target_factors_num_embed,
                                                                      args.target_factors_combine,
                                                                      args.target_factors_share_embedding)]

    allow_sparse_grad = args.update_interval == 1  # sparse embedding gradients do not work with grad_req='add'
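
    # Embedding configurations for the primary source and target vocabularies,
    # carrying along any factor configurations built above.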
    config_embed_source = encoder.EmbeddingConfig(vocab_size=source_vocab_size,
                                                  num_embed=num_embed_source,
                                                  dropout=embed_dropout_source,
                                                  factor_configs=source_factor_configs,
                                                  allow_sparse_grad=allow_sparse_grad)
    config_embed_target = encoder.EmbeddingConfig(vocab_size=target_vocab_size,
                                                  num_embed=num_embed_target,
                                                  dropout=embed_dropout_target,
                                                  factor_configs=target_factor_configs,
                                                  allow_sparse_grad=allow_sparse_grad)
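
    # Optional auxiliary length (ratio) prediction task.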
    config_length_task = None
    if args.length_task is not None:
        config_length_task = layers.LengthRatioConfig(num_layers=args.length_task_layers,
                                                      weight=args.length_task_weight)
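
    # Assemble all sub-configurations into the final ModelConfig.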
    model_config = model.ModelConfig(config_data=config_data,
                                     vocab_source_size=source_vocab_size,
                                     vocab_target_size=target_vocab_size,
                                     config_embed_source=config_embed_source,
                                     config_embed_target=config_embed_target,
                                     config_encoder=config_encoder,
                                     config_decoder=config_decoder,
                                     config_length_task=config_length_task,
                                     weight_tying_type=args.weight_tying_type,
                                     lhuc=args.lhuc is not None,
                                     dtype=args.dtype)
    return model_config