in src/pixparse/app/train.py [0:0]
def main():
    args = parser.parse_args()
    train_cfg: TrainCfg = args.train
    data_cfg: DataCfg = args.data

    device_env = DeviceEnv()
    task, task_cfg = TaskFactory.create_task(
        task_name=train_cfg.task_name,
        task_args=args.task,
        device_env=device_env,
        monitor=None,
    )

    random_seed(train_cfg.seed, rank=device_env.global_rank)
    _logger.info(f"Device env is {device_env}")
    # Derive an experiment name if one wasn't given explicitly.
    if train_cfg.experiment is None:
        model_name_safe = clean_name(task_cfg.model_name)
        date_str = datetime.now().strftime("%Y%m%d-%H%M%S")
        if device_env.world_size > 1:
            # sync date_str from master to all ranks
            date_str = device_env.broadcast_object(date_str)
        experiment = '-'.join([
            date_str,
            f"task_{train_cfg.task_name}",
            f"model_{model_name_safe}",
            f"lr_{task_cfg.opt.learning_rate:.1e}",
            f"b_{data_cfg.train.batch_size}",
            # TODO derive the remaining experiment-name fields from essential hparams
        ])
        train_cfg = replace(train_cfg, experiment=experiment)
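
    # NOTE: resuming from the latest checkpoint appears stubbed out here;
    # resume_latest is hard-coded to False for now.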
    resume_latest = False  # train_cfg.resume == 'latest'
    experiment_path = os.path.join(train_cfg.output_dir, train_cfg.experiment)
    log_path = None
    if device_env.is_primary():
        os.makedirs(experiment_path, exist_ok=True)
        log_path = os.path.join(experiment_path, train_cfg.log_filename)
        if os.path.exists(log_path) and not resume_latest:
            _logger.error(
                f"Error. Experiment {train_cfg.experiment} already exists. "
                "Use --experiment to specify a new experiment name."
            )
            return -1

    # Setup text logger (log_path stays None on non-primary ranks; setup_logging
    # is assumed to fall back to console-only logging in that case)
    setup_logging(log_path)
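
    # Attach the experiment monitor; output_enabled is tied to is_primary() so
    # only rank 0 emits console/W&B/TensorBoard output.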
    task.monitor = Monitor(
        train_cfg.experiment,
        output_dir=experiment_path,
        wandb=train_cfg.wandb,
        wandb_project=train_cfg.wandb_project,
        tensorboard=train_cfg.tensorboard,
        output_enabled=device_env.is_primary(),
    )

    # ----- Model resuming from checkpoint -----
    # FIXME make optional for resume.
    # Task needs to have
    #  -- an attribute OrderedDict state_dict
    #  -- an attribute bool resume
    if train_cfg.resume:
        checkpoint_path = train_cfg.checkpoint_path
        # this replace() is currently a no-op (the same value is written back)
        train_cfg = replace(train_cfg, checkpoint_path=checkpoint_path)
        # FIXME check if path is local or s3?
        if train_cfg.s3_bucket != "":
            _logger.info("s3 bucket specified. Loading checkpoint from s3.")
            checkpoint = load_checkpoint_from_s3(
                train_cfg.s3_bucket, train_cfg.checkpoint_path
            )
        else:
            assert os.path.isfile(checkpoint_path), \
                f"Cannot find checkpoint {checkpoint_path}: File not found"
            checkpoint = torch.load(train_cfg.checkpoint_path)
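
        # A checkpoint may be a bare state_dict (an OrderedDict) or a dict that
        # wraps one under a 'model' key; normalize to a plain state_dict.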
        if isinstance(checkpoint, OrderedDict):
            state_dict = checkpoint
        else:
            state_dict = checkpoint["model"]
        task.state_dict = state_dict
        task.resume = True
    # ------------------------------------------
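
    # Resolve (and create) the directory that training checkpoints are written to.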
    output_checkpoint_dir = train_cfg.output_checkpoint_dir or os.path.join(
        experiment_path, 'checkpoints')
    os.makedirs(output_checkpoint_dir, exist_ok=True)
    train_cfg = replace(train_cfg, output_checkpoint_dir=output_checkpoint_dir)

    if device_env.is_primary():
        _logger.info(task_cfg)
        _logger.info(train_cfg)

    loaders = {}
    assert data_cfg.train is not None or data_cfg.eval is not None, \
        "At least one of data_cfg.train or data_cfg.eval must be set."
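
    # Build the training loader; image/annotation preprocessing comes from the
    # task, and sharding parameters come from the device env.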
    if data_cfg.train is not None:
        loaders['train'] = create_loader(
            data_cfg.train,
            is_train=True,
            collate_fn=task.collate_fn,
            image_preprocess=task.image_preprocess_train,
            anno_preprocess=task.anno_preprocess_train,
            image_fmt=task_cfg.model.image_encoder.image_fmt,
            world_size=device_env.world_size,
            global_rank=device_env.global_rank,
            create_decoder_pipe=create_doc_anno_pipe,  # TODO abstract away type of decoder needed
        )
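
    # FIXME the assert above also passes when only data_cfg.eval is set, in
    # which case loaders['train'] below raises a KeyError.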
    task.train_setup(
        num_batches_per_interval=loaders['train'].num_batches,
    )

    if device_env.is_primary():
        _logger.info(task)
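
    # Hand off to the training loop.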
    train(
        train_cfg,
        task,
        loaders,
    )
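

# Module entry point (assumed; not shown in this excerpt):
if __name__ == '__main__':
    main()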