in sat/arguments.py [0:0]
def get_args(args_list=None, parser=None):
    """Parse all the args and derive the runtime configuration.

    Builds (or extends) an ``argparse`` parser with the SAT argument groups and
    DeepSpeed's config arguments, parses ``args_list`` and post-processes the
    result with ``process_config_to_args``. It then derives:

    - the training-schedule defaults (``train_iters`` / ``epochs``);
    - rank, world size, local rank and device from the environment;
    - the precision flags and the DeepSpeed config.

    Finally it calls ``initialize_distributed`` and ``set_random_seed``.

    Args:
        args_list: Argument strings to parse. ``None`` lets argparse read
            ``sys.argv``.
        parser: Optional existing ``argparse.ArgumentParser`` to extend.

    Returns:
        The populated ``argparse.Namespace``.

    Raises:
        ValueError: If LOCAL_RANK and the resolved device disagree outside
            inference mode.
        AssertionError: On conflicting options (both ``train_iters`` and
            ``epochs``, both fp16 and bf16, mismatched ``train_data_weights``).
    """
    if parser is None:
        parser = argparse.ArgumentParser(description="sat")
    else:
        assert isinstance(parser, argparse.ArgumentParser)
    parser = add_model_config_args(parser)
    parser = add_sampling_config_args(parser)
    parser = add_training_args(parser)
    parser = add_training_extra_config_args(parser)
    parser = add_evaluation_args(parser)
    parser = add_data_args(parser)

    # Kept as a function-local import, as in the original layout.
    import deepspeed

    parser = deepspeed.add_config_arguments(parser)
    args = parser.parse_args(args_list)
    args = process_config_to_args(args)

    _check_training_schedule(args)
    _resolve_distributed_env(args)

    if args.train_data_weights is not None:
        assert args.train_data is not None and len(args.train_data_weights) == len(
            args.train_data
        ), "train_data_weights needs exactly one weight per train_data entry."

    override_deepspeed_config = _prepare_deepspeed_config(args)

    assert not (args.fp16 and args.bf16), "cannot specify both fp16 and bf16."
    if args.zero_stage > 0 and not args.fp16 and not args.bf16:
        print_rank0("Automatically set fp16=True to use ZeRO.")
        args.fp16 = True
        args.bf16 = False

    if args.deepspeed:
        args.deepspeed_activation_checkpointing = bool(args.checkpoint_activations)
        if args.deepspeed_config is not None:
            _sync_deepspeed_config(args, override_deepspeed_config)

    # initialize distributed and random seed because it always seems to be necessary.
    initialize_distributed(args)
    if args.mode != "inference":
        # Training only: give each data-parallel rank a different seed.
        args.seed = args.seed + mpu.get_data_parallel_rank()
    set_random_seed(args.seed)
    print_rank0(f"args:\n{pformat(vars(args), indent=2, sort_dicts=True)}")
    return args


def _check_training_schedule(args):
    """Warn about missing training data and resolve ``train_iters`` vs ``epochs``."""
    if not args.train_data:
        print_rank0("No training data specified", level="WARNING")
    assert (args.train_iters is None) or (args.epochs is None), "only one of train_iters and epochs should be set."
    if args.train_iters is None and args.epochs is None:
        args.train_iters = 10000  # default 10k iters
        print_rank0("No train_iters (recommended) or epochs specified, use default 10k iters.", level="WARNING")


def _resolve_distributed_env(args):
    """Fill ``cuda``/``rank``/``world_size``/``local_rank``/``device`` from the environment.

    Raises:
        ValueError: If ``local_rank`` differs from the resolved device outside
            inference mode.
    """
    args.cuda = torch.cuda.is_available()
    args.rank = int(os.getenv("RANK", "0"))
    args.world_size = int(os.getenv("WORLD_SIZE", "1"))
    if args.local_rank is None:
        args.local_rank = int(os.getenv("LOCAL_RANK", "0"))  # torchrun
    if args.device == -1:  # -1: choose the device automatically
        if torch.cuda.device_count() == 0:
            args.device = "cpu"
        else:
            # local_rank is always an assigned value at this point, so it selects
            # the GPU (a former `rank % device_count()` fallback was unreachable).
            args.device = args.local_rank
    # NOTE(review): when device resolves to "cpu", local_rank (presumably an int)
    # never equals it, so non-inference CPU runs always raise here — confirm intended.
    if args.local_rank != args.device and args.mode != "inference":
        raise ValueError(
            "LOCAL_RANK (default 0) and args.device inconsistent. "
            "This can only happens in inference mode. "
            "Please use CUDA_VISIBLE_DEVICES=x for single-GPU training. "
        )
    # Explicit rank guard, presumably because distributed state is not
    # initialized until later in get_args — verify against print_rank0.
    if args.rank == 0:
        print_rank0("using world size: {}".format(args.world_size))


def _prepare_deepspeed_config(args):
    """Enable DeepSpeed for training and load the bundled config if none was given.

    Returns:
        True if the bundled ``training/deepspeed_zero{zero_stage}.json`` was
        loaded into ``args.deepspeed_config``. The caller should then write args
        into that config. Returns False otherwise.
    """
    # Always bound. Previously the flag existed only in training mode, so an
    # inference run with args.deepspeed truthy and a deepspeed_config set hit a
    # NameError in the sync step.
    override_deepspeed_config = False
    if args.mode != "inference":  # training with deepspeed
        args.deepspeed = True
        if args.deepspeed_config is None:  # not specified
            deepspeed_config_path = os.path.join(
                os.path.dirname(__file__), "training", f"deepspeed_zero{args.zero_stage}.json"
            )
            with open(deepspeed_config_path, encoding="utf-8") as file:
                args.deepspeed_config = json.load(file)
            override_deepspeed_config = True
    return override_deepspeed_config


def _sync_deepspeed_config(args, override_deepspeed_config):
    """Reconcile ``args`` with ``args.deepspeed_config`` in place.

    If ``override_deepspeed_config`` is true (the bundled default was loaded),
    precision, batch size, gradient accumulation, lr and weight decay are
    written from args into the config. Otherwise the user-supplied config wins
    and its values are copied back onto args. The resulting dict is stored in
    ``args.deepspeed_config``.
    """
    deepspeed_config = args.deepspeed_config
    if isinstance(deepspeed_config, str):
        # A path string (e.g. given on the command line): load the JSON so the
        # checks below see a dict rather than doing substring tests on the path.
        with open(deepspeed_config, encoding="utf-8") as file:
            deepspeed_config = json.load(file)
    if override_deepspeed_config:  # bundled default config: args take precedence
        if args.fp16:
            deepspeed_config.setdefault("fp16", {})["enabled"] = True
        elif args.bf16:
            deepspeed_config.setdefault("bf16", {})["enabled"] = True
            deepspeed_config.setdefault("fp16", {})["enabled"] = False
        else:
            deepspeed_config.setdefault("fp16", {})["enabled"] = False
        deepspeed_config["train_micro_batch_size_per_gpu"] = args.batch_size
        deepspeed_config["gradient_accumulation_steps"] = args.gradient_accumulation_steps
        optimizer_params_config = deepspeed_config["optimizer"]["params"]
        optimizer_params_config["lr"] = args.lr
        optimizer_params_config["weight_decay"] = args.weight_decay
    else:  # user-supplied config: override args with its values
        if args.rank == 0:
            print_rank0("Will override arguments with manually specified deepspeed_config!")
        # A section without "enabled" is treated as disabled instead of raising KeyError.
        args.fp16 = bool(deepspeed_config.get("fp16", {}).get("enabled", False))
        args.bf16 = bool(deepspeed_config.get("bf16", {}).get("enabled", False))
        if "train_micro_batch_size_per_gpu" in deepspeed_config:
            args.batch_size = deepspeed_config["train_micro_batch_size_per_gpu"]
        # None when the config does not specify it.
        args.gradient_accumulation_steps = deepspeed_config.get("gradient_accumulation_steps")
        if "optimizer" in deepspeed_config:
            optimizer_params_config = deepspeed_config["optimizer"].get("params", {})
            args.lr = optimizer_params_config.get("lr", args.lr)
            args.weight_decay = optimizer_params_config.get("weight_decay", args.weight_decay)
    args.deepspeed_config = deepspeed_config