in sagemaker/25_pytorch_fsdp_model_parallelism/scripts/run_clm.py [0:0]
from transformers import AutoTokenizer, Trainer


def safe_save_model_for_hf_trainer(trainer: Trainer, tokenizer: AutoTokenizer, output_dir: str):
    """Helper method to save an FSDP-trained model for the HF Trainer."""
    # see: https://github.com/tatsu-lab/stanford_alpaca/issues/65
    from torch.distributed.fsdp import (
        FullyShardedDataParallel as FSDP,
        FullStateDictConfig,
        StateDictType,
    )

    model = trainer.model
    # Consolidate the sharded parameters into a full state dict on rank 0,
    # offloading to CPU so gathering does not exhaust GPU memory.
    save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
        cpu_state_dict = model.state_dict()
    # Only the main process writes the model weights and tokenizer to disk.
    if trainer.args.should_save:
        trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa
        tokenizer.save_pretrained(output_dir)
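
A minimal usage sketch, assuming `model`, `tokenizer`, `train_dataset`, and `training_args` were constructed earlier in run_clm.py (these names are illustrative, as is the SageMaker output path):

# Illustrative wiring; the objects passed to Trainer are assumed to exist
# and are not part of the function above.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
)
trainer.train()
# Gather the FSDP shards and write a regular HF checkpoint that
# from_pretrained can load later; /opt/ml/model is the directory
# SageMaker uploads to S3 when the training job finishes.
safe_save_model_for_hf_trainer(trainer=trainer, tokenizer=tokenizer, output_dir="/opt/ml/model")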