pipeline/common/marian.py (30 lines of code) (raw):
"""
Common utilities related to working with Marian.
"""
from pathlib import Path
from typing import Any, Union

import yaml
def get_combined_config(config_path: Path, extra_marian_args: list[str]) -> dict[str, Any]:
    """
    Frequently we combine a Marian yml config with extra marian args when running
    training. To get the final value, add both here.

    The yml config at `config_path` is loaded first, then the parsed
    `extra_marian_args` are layered on top, so an extra arg overrides a yml
    value that uses the same key.

    An empty yml file is treated as an empty config.
    """
    # Use a context manager so the file handle is closed deterministically
    # rather than whenever the garbage collector gets to it.
    with config_path.open(encoding="utf-8") as config_file:
        yaml_config = yaml.safe_load(config_file)

    # safe_load returns None for an empty document, which can't be unpacked with **.
    if yaml_config is None:
        yaml_config = {}

    return {
        **yaml_config,
        **marian_args_to_dict(extra_marian_args),
    }
def marian_args_to_dict(extra_marian_args: list[str]) -> dict[str, Union[str, bool, list[str]]]:
    """
    Converts a list of marian command line args to the dict format.

    The resulting value for each `--key` depends on how many values follow it:
      - no value:        `--quiet` becomes {"quiet": True}
      - one value:       `--precision float16` becomes {"precision": "float16"}
      - multiple values: `--devices 0 1` becomes {"devices": ["0", "1"]}

    The `--key=value` form is also supported, e.g. `--precision=float16` becomes
    {"precision": "float16"}. A single leading "--" separator is ignored. If a key
    is repeated, the last occurrence wins. The input list is not modified.

    Raises a ValueError if a value is found before any `--key`.
    """
    decoder_config: dict[str, Union[str, bool, list[str]]] = {}

    # Skip the conventional "--" that separates these args from a parent command's.
    if extra_marian_args and extra_marian_args[0] == "--":
        extra_marian_args = extra_marian_args[1:]

    previous_key = None
    for arg in extra_marian_args:
        if arg.startswith("--"):
            # Split off an inline value so that "--key=value" is not stored
            # under the bogus key "key=value".
            previous_key, has_inline_value, inline_value = arg[2:].partition("=")
            # A key is a boolean flag until a value shows up for it.
            decoder_config[previous_key] = inline_value if has_inline_value else True
            continue

        if not previous_key:
            raise ValueError(
                f"Expected to have a previous key when converting marian args to a dict: {extra_marian_args}"
            )

        prev_value = decoder_config.get(previous_key)
        if prev_value is True:
            # First value for this key, replace the flag placeholder.
            decoder_config[previous_key] = arg
        elif isinstance(prev_value, list):
            # Third or later value, keep growing the list.
            prev_value.append(arg)
        else:
            # Second value, promote the single string to a list.
            decoder_config[previous_key] = [prev_value, arg]

    return decoder_config