# get_strategy(args, project)
#
# in dataflux_pytorch/benchmark/checkpointing/multinode/train.py [0:0]


def get_strategy(args, project):
    """Build the FSDP checkpointing strategy selected by the CLI arguments.

    Reads ``args.strategy`` plus the ``args.load_only``, ``args.save_only``
    and ``args.distributed_filesystem`` flags to decide which strategy class
    to instantiate.

    Args:
        args: Parsed command-line arguments.
        project: GCP project name, forwarded to strategies that access GCS.

    Returns:
        A configured strategy instance (Dataflux, fsspec, boot-disk-load, or
        plain ``FSDPStrategy``).

    Raises:
        ValueError: If ``args.strategy`` (with its flags) matches no known
            strategy.
    """
    # All strategies auto-wrap at the transformer-layer boundary.
    policy = {
        torch.nn.TransformerEncoderLayer, torch.nn.TransformerDecoderLayer
    }
    if args.strategy == DF_FSDP_STRATEGY:
        print("Using DatafluxFSDPStrategy")
        return DatafluxFSDPStrategy(
            project_name=project,
            storage_client=None,
            state_dict_type="sharded",
            use_orig_params=False,
            auto_wrap_policy=policy,
        )
    if args.strategy == FSSPEC_FSDP_STRATEGY:
        print("Using FSSpecFSDPStrategy")
        return FSSpecFSDPStrategy(
            state_dict_type="sharded",
            use_orig_params=False,
            auto_wrap_policy=policy,
        )
    if args.strategy == FSDP_STRATEGY and args.load_only:
        # Message previously said "CustomFSDPStrategy", which did not match
        # the class actually constructed here.
        print("Using LoadFromBootDiskFSDP.")
        return LoadFromBootDiskFSDP(
            project_name=project,
            state_dict_type="sharded",
            use_orig_params=False,
            auto_wrap_policy=policy,
        )
    if (args.strategy == FSDP_STRATEGY
            and args.save_only) or args.distributed_filesystem:
        print("Using FSDPStrategy.")
        return FSDPStrategy(
            state_dict_type="sharded",
            use_orig_params=False,
            auto_wrap_policy=policy,
        )
    # Include the offending value so misconfiguration is easy to diagnose.
    raise ValueError(f"Invalid strategy: {args.strategy!r}.")