sat/arguments.py:

import argparse
import json
import os
import warnings
from datetime import timedelta
from pprint import pformat

import omegaconf
import torch
import torch.distributed
from omegaconf import OmegaConf

from sat import mpu
from sat.arguments import add_data_args, add_evaluation_args, add_training_args, set_random_seed
from sat.helpers import print_rank0

os.environ["NCCL_BLOCKING_WAIT"] = "0"
# os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1"
os.environ["NCCL_IB_GID_INDEX"] = "3"


def add_model_config_args(parser):
    """Model arguments"""
    group = parser.add_argument_group("model", "model configuration")
    group.add_argument("--base", type=str, nargs="*", help="config for input and saving")
    group.add_argument(
        "--model-parallel-size", type=int, default=1, help="size of the model parallel. only use if you are an expert."
    )
    group.add_argument("--force-pretrain", action="store_true")
    group.add_argument("--device", type=int, default=-1)
    group.add_argument("--debug", action="store_true")
    group.add_argument("--log-image", type=bool, default=True)
    return parser


def add_sampling_config_args(parser):
    """Sampling configurations"""
    group = parser.add_argument_group("sampling", "Sampling Configurations")
    group.add_argument("--output-dir", type=str, default="samples")
    group.add_argument("--input-dir", type=str, default=None)
    group.add_argument("--input-type", type=str, default="cli")
    group.add_argument("--input-file", type=str, default="input.txt")
    group.add_argument("--final-size", type=int, default=2048)
    group.add_argument("--sdedit", action="store_true")
    group.add_argument("--grid-num-rows", type=int, default=1)
    group.add_argument("--force-inference", action="store_true")
    group.add_argument("--lcm_steps", type=int, default=None)
    group.add_argument("--sampling-num-frames", type=int, default=32)
    group.add_argument("--sampling-fps", type=int, default=8)
    group.add_argument("--only-save-latents", type=bool, default=False)
    group.add_argument("--only-log-video-latents", type=bool, default=False)
    group.add_argument("--latent-channels", type=int, default=32)
    group.add_argument("--image2video", action="store_true")
    group.add_argument("--no_flow_injection", type=bool, default=False)
    group.add_argument("--flow_path", type=str, default=None)
    group.add_argument("--flow_from_prompt", type=str, default="cli")
    group.add_argument(
        "--point_path",
        nargs="+",
        type=str,
        help="List of point paths, e.g. --point_path path/traj1.txt path/traj2.txt",
        default=None,
    )
    group.add_argument("--vis_traj_features", type=bool, default=False)
    group.add_argument("--img_dir", type=str, default="cli")
    group.add_argument("--num_samples_per_prompt", type=int, default=1)
    return parser


def add_training_extra_config_args(parser):
    """Extra training arguments."""
    group = parser.add_argument_group("train_extra", "Extra training configurations")
    group.add_argument("--sample_flow", type=bool, default=True)
    group.add_argument("--do_fuse_object_features", type=bool, default=False)
    group.add_argument("--use_raft", type=bool, default=True)
    return parser


def get_args(args_list=None, parser=None):
    """Parse all the args."""
    if parser is None:
        parser = argparse.ArgumentParser(description="sat")
    else:
        assert isinstance(parser, argparse.ArgumentParser)
    parser = add_model_config_args(parser)
    parser = add_sampling_config_args(parser)
    parser = add_training_args(parser)
    parser = add_training_extra_config_args(parser)
    parser = add_evaluation_args(parser)
    parser = add_data_args(parser)

    import deepspeed

    parser = deepspeed.add_config_arguments(parser)

    args = parser.parse_args(args_list)
    args = process_config_to_args(args)

    if not args.train_data:
        print_rank0("No training data specified", level="WARNING")

    assert (args.train_iters is None) or (args.epochs is None), "only one of train_iters and epochs should be set."
    if args.train_iters is None and args.epochs is None:
        args.train_iters = 10000  # default 10k iters
        print_rank0("No train_iters (recommended) or epochs specified, use default 10k iters.", level="WARNING")

    args.cuda = torch.cuda.is_available()

    args.rank = int(os.getenv("RANK", "0"))
    args.world_size = int(os.getenv("WORLD_SIZE", "1"))
    if args.local_rank is None:
        args.local_rank = int(os.getenv("LOCAL_RANK", "0"))  # torchrun

    if args.device == -1:
        if torch.cuda.device_count() == 0:
            args.device = "cpu"
        elif args.local_rank is not None:
            args.device = args.local_rank
        else:
            args.device = args.rank % torch.cuda.device_count()

    if args.local_rank != args.device and args.mode != "inference":
        raise ValueError(
            "LOCAL_RANK (default 0) and args.device inconsistent. "
            "This can only happen in inference mode. "
            "Please use CUDA_VISIBLE_DEVICES=x for single-GPU training."
        )

    if args.rank == 0:
        print_rank0("using world size: {}".format(args.world_size))

    if args.train_data_weights is not None:
        assert len(args.train_data_weights) == len(args.train_data)

    if args.mode != "inference":  # training with deepspeed
        args.deepspeed = True
        if args.deepspeed_config is None:  # not specified
            deepspeed_config_path = os.path.join(
                os.path.dirname(__file__), "training", f"deepspeed_zero{args.zero_stage}.json"
            )
            with open(deepspeed_config_path) as file:
                args.deepspeed_config = json.load(file)
            override_deepspeed_config = True
        else:
            override_deepspeed_config = False

    assert not (args.fp16 and args.bf16), "cannot specify both fp16 and bf16."

    if args.zero_stage > 0 and not args.fp16 and not args.bf16:
        print_rank0("Automatically set fp16=True to use ZeRO.")
        args.fp16 = True
        args.bf16 = False

    if args.deepspeed:
        if args.checkpoint_activations:
            args.deepspeed_activation_checkpointing = True
        else:
            args.deepspeed_activation_checkpointing = False
        if args.deepspeed_config is not None:
            deepspeed_config = args.deepspeed_config

        if override_deepspeed_config:  # not specify deepspeed_config, use args
            if args.fp16:
                deepspeed_config["fp16"]["enabled"] = True
            elif args.bf16:
                deepspeed_config["bf16"]["enabled"] = True
                deepspeed_config["fp16"]["enabled"] = False
            else:
                deepspeed_config["fp16"]["enabled"] = False
            deepspeed_config["train_micro_batch_size_per_gpu"] = args.batch_size
            deepspeed_config["gradient_accumulation_steps"] = args.gradient_accumulation_steps
            optimizer_params_config = deepspeed_config["optimizer"]["params"]
            optimizer_params_config["lr"] = args.lr
            optimizer_params_config["weight_decay"] = args.weight_decay
        else:  # override args with values in deepspeed_config
            if args.rank == 0:
                print_rank0("Will override arguments with manually specified deepspeed_config!")
            if "fp16" in deepspeed_config and deepspeed_config["fp16"]["enabled"]:
                args.fp16 = True
            else:
                args.fp16 = False
            if "bf16" in deepspeed_config and deepspeed_config["bf16"]["enabled"]:
                args.bf16 = True
            else:
                args.bf16 = False
            if "train_micro_batch_size_per_gpu" in deepspeed_config:
                args.batch_size = deepspeed_config["train_micro_batch_size_per_gpu"]
            if "gradient_accumulation_steps" in deepspeed_config:
                args.gradient_accumulation_steps = deepspeed_config["gradient_accumulation_steps"]
            else:
                args.gradient_accumulation_steps = None
            if "optimizer" in deepspeed_config:
                optimizer_params_config = deepspeed_config["optimizer"].get("params", {})
                args.lr = optimizer_params_config.get("lr", args.lr)
                args.weight_decay = optimizer_params_config.get("weight_decay", args.weight_decay)
        args.deepspeed_config = deepspeed_config

    # initialize distributed and random seed because it always seems to be necessary.
    initialize_distributed(args)
    if args.mode != "inference":
        args.seed = args.seed + mpu.get_data_parallel_rank()
    set_random_seed(args.seed)

    print_rank0(f"args:\n{pformat(vars(args), indent=2, sort_dicts=True)}")
    return args


def initialize_distributed(args):
    """Initialize torch.distributed."""
    if torch.distributed.is_initialized():
        if mpu.model_parallel_is_initialized():
            if args.model_parallel_size != mpu.get_model_parallel_world_size():
                raise ValueError(
                    "model_parallel_size is inconsistent with prior configuration. "
                    "We currently do not support changing model_parallel_size."
                )
            return False
        else:
            if args.model_parallel_size > 1:
                warnings.warn(
                    "model_parallel_size > 1 but torch.distributed is not initialized via SAT. "
                    "Please carefully make sure the correctness on your own."
                )
            mpu.initialize_model_parallel(args.model_parallel_size)
        return True
    # the automatic assignment of devices has been moved to arguments.py
    if args.device == "cpu":
        pass
    else:
        torch.cuda.set_device(args.device)
    # Call the init process
    init_method = "tcp://"
    args.master_ip = os.getenv("MASTER_ADDR", "localhost")

    if args.world_size == 1:
        from sat.helpers import get_free_port

        default_master_port = str(get_free_port())
    else:
        default_master_port = "6000"
    args.master_port = os.getenv("MASTER_PORT", default_master_port)
    init_method += args.master_ip + ":" + args.master_port
    torch.distributed.init_process_group(
        backend=args.distributed_backend,
        world_size=args.world_size,
        rank=args.rank,
        init_method=init_method,
        timeout=timedelta(seconds=7200000),
    )

    # Set the model-parallel / data-parallel communicators.
    mpu.initialize_model_parallel(args.model_parallel_size)

    # Set vae context parallel group equal to model parallel group
    from sgm.util import initialize_context_parallel, set_context_parallel_group

    if args.model_parallel_size <= 2:
        set_context_parallel_group(args.model_parallel_size, mpu.get_model_parallel_group())
    else:
        initialize_context_parallel(2)
    # mpu.initialize_model_parallel(1)
    # Optional DeepSpeed Activation Checkpointing Features
    if args.deepspeed:
        import deepspeed

        deepspeed.init_distributed(
            dist_backend=args.distributed_backend, world_size=args.world_size, rank=args.rank, init_method=init_method
        )
        # # It seems that it has no negative influence to configure it even without using checkpointing.
        # deepspeed.checkpointing.configure(mpu, deepspeed_config=args.deepspeed_config, num_checkpoints=args.num_layers)
    else:
        # In model-only mode we don't want to init deepspeed, but we still need to init the rng tracker
        # for model parallel, because we save the rng seed by default for dropout.
        try:
            import deepspeed
            from deepspeed.runtime.activation_checkpointing.checkpointing import (
                _CUDA_RNG_STATE_TRACKER,
                _MODEL_PARALLEL_RNG_TRACKER_NAME,
            )

            _CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, 1)  # default seed 1
        except Exception as e:
            from sat.helpers import print_rank0

            print_rank0(str(e), level="DEBUG")
    return True


def process_config_to_args(args):
    """Fetch args from only --base"""
    configs = [OmegaConf.load(cfg) for cfg in args.base]
    config = OmegaConf.merge(*configs)

    args_config = config.pop("args", OmegaConf.create())
    for key in args_config:
        if isinstance(args_config[key], omegaconf.DictConfig) or isinstance(args_config[key], omegaconf.ListConfig):
            arg = OmegaConf.to_object(args_config[key])
        else:
            arg = args_config[key]
        if hasattr(args, key):
            setattr(args, key, arg)

    if "model" in config:
        model_config = config.pop("model", OmegaConf.create())
        args.model_config = model_config
    if "deepspeed" in config:
        deepspeed_config = config.pop("deepspeed", OmegaConf.create())
        args.deepspeed_config = OmegaConf.to_object(deepspeed_config)
    if "data" in config:
        data_config = config.pop("data", OmegaConf.create())
        args.data_config = data_config

    return args
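For context, a minimal sketch (not part of the file above) of how the --base mechanism feeds process_config_to_args: keys under the "args" section of the merged YAML overwrite matching attributes on the parsed namespace, while the "model", "deepspeed", and "data" sections are attached as args.model_config, args.deepspeed_config, and args.data_config. The config file name, config values, and the assumption that this module is importable as "arguments" are illustrative only.

    import argparse

    from omegaconf import OmegaConf

    from arguments import process_config_to_args  # hypothetical import path for the module above

    # Write a tiny config; the keys and values here are made up for illustration.
    OmegaConf.save(
        OmegaConf.create(
            {
                "args": {"mode": "inference", "output_dir": "outputs/demo"},
                "model": {"scale_factor": 0.7},
            }
        ),
        "example.yaml",
    )

    # Mimic a namespace produced by parser.parse_args(); only attributes that already
    # exist on the namespace are overridden by the "args" section of the YAML.
    ns = argparse.Namespace(base=["example.yaml"], mode="pretrain", output_dir="samples")
    ns = process_config_to_args(ns)

    print(ns.mode, ns.output_dir)  # -> inference outputs/demo
    print(OmegaConf.to_object(ns.model_config))  # -> {'scale_factor': 0.7}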