def add_checkpointing_args()

in chatlearn/tools/megatron_to_hf.py [0:0]


def add_checkpointing_args(parser):
    """Register the checkpoint-conversion command-line options on ``parser``.

    Args:
        parser: Object exposing argparse's ``add_argument`` (for example an
            ``argparse.ArgumentParser`` or an argument group). It is mutated
            in place.

    Returns:
        The same ``parser`` object, so calls can be chained.

    Registered options (resulting ``argparse.Namespace`` attribute in brackets):
        --load_path (required)            [load_path]
        --save_path (required)            [save_path]
        --vocab_dir                       [vocab_dir, default None]
        --model_type {llama}              [model_type, default "llama"]
        --print-checkpoint-structure /
        --print_checkpoint_structure      [print_checkpoint_structure, flag]
        --target_params_dtype             [target_params_dtype, default "fp32"]
        --max_shard_size                  [max_shard_size, default "10GB"]
        --megatron_path                   [megatron_path, default None]
    """
    parser.add_argument(
        "--load_path",
        type=str,
        required=True,
        help="Path to the checkpoint to convert.",
    )
    parser.add_argument(
        "--save_path",
        type=str,
        required=True,
        help="Path to the converted checkpoint.",
    )
    parser.add_argument(
        "--vocab_dir",
        type=str,
        help="Vocab dir.",
    )
    parser.add_argument(
        "--model_type",
        type=str,
        choices=['llama'],
        default="llama",
        help="model type.",
    )
    # The hyphenated spelling is kept first so argparse still derives
    # dest="print_checkpoint_structure" from it. The underscore alias matches
    # the naming convention used by every other option in this parser.
    parser.add_argument(
        "--print-checkpoint-structure",
        "--print_checkpoint_structure",
        action="store_true",
        help="Print the structure of the checkpoint.",
    )
    # NOTE(review): the help below says this option is only used for the
    # Transformers -> Megatron direction. Confirm whether the Megatron -> HF
    # path in this tool reads it at all.
    parser.add_argument(
        "--target_params_dtype",
        type=str,
        default="fp32",
        help=(
            "The dtype of the converted checkpoint. "
            "Only used when converting a Transformers checkpoint to a Megatron checkpoint."
        ),
    )
    parser.add_argument(
        "--max_shard_size",
        type=str,
        default="10GB",
        help=(
            "The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size "
            "lower than this size. If expressed as a string, needs to be digits followed by a unit (like `5MB`). "
            "Only used when converting a Megatron checkpoint to a Transformers checkpoint."
        ),
    )
    parser.add_argument(
        "--megatron_path",
        type=str,
        default=None,
        help=(
            "Path to Megatron-LM"
        ),
    )
    return parser