# train.py
import argparse
import os

# config, train, init_dist, destroy_dist, is_dist and is_master come from the
# repository's own modules and are imported at the top of the full train.py.
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--lr_mp', type=float, help='Learning rate for the mapping network')
    parser.add_argument('--lr_backbones', type=float, help='Learning rate for the backbones')
    parser.add_argument('--vlm_checkpoint_path', type=str, help='Path to the VLM checkpoint for loading or saving')
    # Note: `type=bool` does not behave as expected with argparse (any non-empty string,
    # including "False", parses as True), so boolean options use explicit actions.
    # BooleanOptionalAction (Python 3.9+) keeps the default at None so the
    # "only override when passed" checks below still work.
    parser.add_argument('--compile', action=argparse.BooleanOptionalAction, default=None,
                        help='Use torch.compile to optimize the model')
    parser.add_argument('--log_wandb', action='store_true', default=None,
                        help='Log to wandb (note: only --no_log_wandb is consumed below)')
    parser.add_argument('--resume_from_vlm_checkpoint', action='store_true',
                        help='Resume training from the VLM checkpoint specified by --vlm_checkpoint_path (or the default if not provided)')
    parser.add_argument('--no_log_wandb', action='store_true', help='Do not log to wandb')
    args = parser.parse_args()
    vlm_cfg = config.VLMConfig()
    train_cfg = config.TrainConfig()

    # Override the default configs only with values explicitly passed on the CLI.
    if args.lr_mp is not None:
        train_cfg.lr_mp = args.lr_mp
    if args.lr_backbones is not None:
        train_cfg.lr_backbones = args.lr_backbones
    if args.vlm_checkpoint_path is not None:
        vlm_cfg.vlm_checkpoint_path = args.vlm_checkpoint_path
    if args.compile is not None:
        train_cfg.compile = args.compile
    if args.no_log_wandb:
        train_cfg.log_wandb = False
    if args.resume_from_vlm_checkpoint and args.vlm_checkpoint_path is not None:
        train_cfg.resume_from_vlm_checkpoint = True
        # When resuming a full VLM, we don't need to load individual backbone
        # weights from their original sources.
        vlm_cfg.vlm_load_backbone_weights = False
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
init_dist()
if is_master():
print("--- VLM Config ---")
print(vlm_cfg)
print("--- Train Config ---")
print(train_cfg)
train(train_cfg, vlm_cfg)
if is_dist():
destroy_dist()
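For reference, below is a minimal sketch of the two config objects that main() mutates, assuming they are dataclasses defined in config.py. The field names are exactly the ones touched above; the default values (and anything else VLMConfig/TrainConfig actually contain) are placeholders, not the repository's real defaults.

# config.py (illustrative sketch, not the actual file)
from dataclasses import dataclass

@dataclass
class VLMConfig:
    vlm_checkpoint_path: str = "checkpoints/vlm.pt"  # placeholder path
    vlm_load_backbone_weights: bool = True           # set to False when resuming a full VLM

@dataclass
class TrainConfig:
    lr_mp: float = 1e-3                       # placeholder LR for the mapping network
    lr_backbones: float = 1e-4                # placeholder LR for the backbones
    compile: bool = False                     # whether to wrap the model in torch.compile
    log_wandb: bool = True                    # turned off by --no_log_wandb
    resume_from_vlm_checkpoint: bool = False  # set when resuming and a checkpoint path is given

With configs like these, a single-process run is just `python train.py ...`, while a multi-GPU run is typically launched with torchrun, which sets the RANK and WORLD_SIZE environment variables that main() checks before calling init_dist().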