in dataflux_pytorch/benchmark/checkpointing/multinode/train.py [0:0]
def get_strategy(args, project):
    """Select and construct the checkpointing strategy requested by *args*.

    Args:
        args: Parsed CLI arguments; reads ``args.strategy`` plus the
            ``load_only`` / ``save_only`` / ``distributed_filesystem`` flags.
        project: Project name forwarded to strategies that need it
            (Dataflux and boot-disk load strategies).

    Returns:
        A configured strategy instance using a sharded state dict and
        auto-wrapping transformer encoder/decoder layers.

    Raises:
        ValueError: If the flag combination matches no known strategy.
    """
    # Every strategy below shares this FSDP configuration; keep it in one
    # place so the branches cannot drift apart.
    common = dict(
        state_dict_type="sharded",
        use_orig_params=False,
        auto_wrap_policy={
            torch.nn.TransformerEncoderLayer,
            torch.nn.TransformerDecoderLayer,
        },
    )
    if args.strategy == DF_FSDP_STRATEGY:
        print("Using DatafluxFSDPStrategy")
        return DatafluxFSDPStrategy(
            project_name=project,
            storage_client=None,
            **common,
        )
    if args.strategy == FSSPEC_FSDP_STRATEGY:
        print("Using FSSpecFSDPStrategy")
        return FSSpecFSDPStrategy(**common)
    if args.strategy == FSDP_STRATEGY and args.load_only:
        print("Using CustomFSDPStrategy.")
        return LoadFromBootDiskFSDP(
            project_name=project,
            **common,
        )
    if (args.strategy == FSDP_STRATEGY
            and args.save_only) or args.distributed_filesystem:
        print("Using FSDPStrategy.")
        return FSDPStrategy(**common)
    raise ValueError("Invalid strategy.")