in train.py [0:0]
def main():
    utils.setup_default_logging()
    args, args_text = _parse_args()
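
    # import any extra device backend modules listed in args.device_modules (e.g. out-of-tree accelerator support)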
    if args.device_modules:
        for module in args.device_modules:
            importlib.import_module(module)

    if torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.benchmark = True
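
    # resolve prefetcher / grad accumulation settings and set up the (possibly distributed) training device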
    args.prefetcher = not args.no_prefetcher
    args.grad_accum_steps = max(1, args.grad_accum_steps)
    device = utils.init_distributed_device(args)
    if args.distributed:
        _logger.info(
            'Training in distributed mode with multiple processes, 1 device per process. '
            f'Process {args.rank}, total {args.world_size}, device {args.device}.')
    else:
        _logger.info(f'Training with a single process on 1 device ({args.device}).')
    assert args.rank >= 0

    model_dtype = None
    if args.model_dtype:
        assert args.model_dtype in ('float32', 'float16', 'bfloat16')
        model_dtype = getattr(torch, args.model_dtype)
        if model_dtype == torch.float16:
            _logger.warning('float16 is not recommended for training; bfloat16 is the recommended half precision dtype.')

    # resolve AMP arguments based on PyTorch / Apex availability
    use_amp = None
    amp_dtype = torch.float16
    if args.amp:
        assert model_dtype is None or model_dtype == torch.float32, 'float32 model dtype must be used with AMP'
        if args.amp_impl == 'apex':
            assert has_apex, 'AMP impl specified as APEX but APEX is not installed.'
            use_amp = 'apex'
            assert args.amp_dtype == 'float16'
        else:
            use_amp = 'native'
            assert args.amp_dtype in ('float16', 'bfloat16')
        if args.amp_dtype == 'bfloat16':
            amp_dtype = torch.bfloat16

    utils.random_seed(args.seed, args.rank)

    if args.fuser:
        utils.set_jit_fuser(args.fuser)
    if args.fast_norm:
        set_fast_norm()
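
    # resolve input channel count: an explicit args.in_chans wins, else the first dim of args.input_size, else 3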
    in_chans = 3
    if args.in_chans is not None:
        in_chans = args.in_chans
    elif args.input_size is not None:
        in_chans = args.input_size[0]

    factory_kwargs = {}
    if args.pretrained_path:
        # merge with pretrained_cfg of model, 'file' has priority over 'url' and 'hf_hub'.
        factory_kwargs['pretrained_cfg_overlay'] = dict(
            file=args.pretrained_path,
            num_classes=-1,  # force head adaptation
        )
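
    # create the model; pretrained weights and/or an initial checkpoint are loaded here when requested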
    model = create_model(
        args.model,
        pretrained=args.pretrained,
        in_chans=in_chans,
        num_classes=args.num_classes,
        drop_rate=args.drop,
        drop_path_rate=args.drop_path,
        drop_block_rate=args.drop_block,
        global_pool=args.gp,
        bn_momentum=args.bn_momentum,
        bn_eps=args.bn_eps,
        scriptable=args.torchscript,
        checkpoint_path=args.initial_checkpoint,
        **factory_kwargs,
        **args.model_kwargs,
    )
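
    # optional classifier head re-initialization: scale the existing head weights or set a constant bias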
    if args.head_init_scale is not None:
        with torch.no_grad():
            model.get_classifier().weight.mul_(args.head_init_scale)
            model.get_classifier().bias.mul_(args.head_init_scale)
    if args.head_init_bias is not None:
        nn.init.constant_(model.get_classifier().bias, args.head_init_bias)

    if args.num_classes is None:
        assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
        args.num_classes = model.num_classes  # FIXME handle model default vs config num_classes more elegantly

    if args.grad_checkpointing:
        model.set_grad_checkpointing(enable=True)

    if utils.is_primary(args):
        _logger.info(
            f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()])}')

    data_config = resolve_data_config(vars(args), model=model, verbose=utils.is_primary(args))

    # setup augmentation batch splits for contrastive loss or split bn
    num_aug_splits = 0
    if args.aug_splits > 0:
        assert args.aug_splits > 1, 'A split of 1 makes no sense'
        num_aug_splits = args.aug_splits

    # enable split bn (separate bn stats per batch-portion)
    if args.split_bn:
        assert num_aug_splits > 1 or args.resplit
        model = convert_splitbn_model(model, max(num_aug_splits, 2))

    # move model to GPU, enable channels last layout if set
    model.to(device=device, dtype=model_dtype)  # FIXME move model device & dtype into create_model
    if args.channels_last:
        model.to(memory_format=torch.channels_last)

    # setup synchronized BatchNorm for distributed training
    if args.distributed and args.sync_bn:
        args.dist_bn = ''  # disable dist_bn when sync BN active
        assert not args.split_bn
        if has_apex and use_amp == 'apex':
            # Apex SyncBN used with Apex AMP
            # WARNING this won't currently work with models using BatchNormAct2d
            model = convert_syncbn_model(model)
        else:
            model = convert_sync_batchnorm(model)
        if utils.is_primary(args):
            _logger.info(
                'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')

    if args.torchscript:
        assert not args.torchcompile
        assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
        assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
        model = torch.jit.script(model)
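
    # derive the learning rate from args.lr_base and the effective global batch size when args.lr is not set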
    if not args.lr:
        global_batch_size = args.batch_size * args.world_size * args.grad_accum_steps
        batch_ratio = global_batch_size / args.lr_base_size
        if not args.lr_base_scale:
            on = args.opt.lower()
            args.lr_base_scale = 'sqrt' if any([o in on for o in ('ada', 'lamb')]) else 'linear'
        if args.lr_base_scale == 'sqrt':
            batch_ratio = batch_ratio ** 0.5
        args.lr = args.lr_base * batch_ratio
        if utils.is_primary(args):
            _logger.info(
                f'Learning rate ({args.lr}) calculated from base learning rate ({args.lr_base}) '
                f'and effective global batch size ({global_batch_size}) with {args.lr_base_scale} scaling.')

    optimizer = create_optimizer_v2(
        model,
        **optimizer_kwargs(cfg=args),
        **args.opt_kwargs,
    )
    if utils.is_primary(args):
        defaults = copy.deepcopy(optimizer.defaults)
        defaults['weight_decay'] = args.weight_decay  # this isn't stored in optimizer.defaults
        defaults = ', '.join([f'{k}: {v}' for k, v in defaults.items()])
        _logger.info(
            f'Created {type(optimizer).__name__} ({args.opt}) optimizer: {defaults}'
        )

    # setup automatic mixed-precision (AMP) loss scaling and op casting
    amp_autocast = suppress  # do nothing
    loss_scaler = None
    if use_amp == 'apex':
        assert device.type == 'cuda'
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        loss_scaler = ApexScaler()
        if utils.is_primary(args):
            _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
    elif use_amp == 'native':
        amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
        if device.type in ('cuda',) and amp_dtype == torch.float16:
            # loss scaler only used for float16 (half) dtype, bfloat16 does not need it
            loss_scaler = NativeScaler(device=device.type)
        if utils.is_primary(args):
            _logger.info('Using native Torch AMP. Training in mixed precision.')
    else:
        if utils.is_primary(args):
            _logger.info(f'AMP not enabled. Training in {model_dtype or torch.float32}.')

    # optionally resume from a checkpoint
    resume_epoch = None
    if args.resume:
        resume_epoch = resume_checkpoint(
            model,
            args.resume,
            optimizer=None if args.no_resume_opt else optimizer,
            loss_scaler=None if args.no_resume_opt else loss_scaler,
            log_info=utils.is_primary(args),
        )

    # setup exponential moving average of model weights, SWA could be used here too
    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before DDP wrapper
        model_ema = utils.ModelEmaV3(
            model,
            decay=args.model_ema_decay,
            use_warmup=args.model_ema_warmup,
            device='cpu' if args.model_ema_force_cpu else None,
        )
        if args.resume:
            load_checkpoint(model_ema.module, args.resume, use_ema=True)
        if args.torchcompile:
            model_ema = torch.compile(model_ema, backend=args.torchcompile)

    # setup distributed training
    if args.distributed:
        if has_apex and use_amp == 'apex':
            # Apex DDP preferred unless native amp is activated
            if utils.is_primary(args):
                _logger.info("Using NVIDIA APEX DistributedDataParallel.")
            model = ApexDDP(model, delay_allreduce=True)
        else:
            if utils.is_primary(args):
                _logger.info("Using native Torch DistributedDataParallel.")
            model = NativeDDP(model, device_ids=[device], broadcast_buffers=not args.no_ddp_bb)
        # NOTE: EMA model does not need to be wrapped by DDP

    if args.torchcompile:
        # torch compile should be done after DDP
        assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.'
        model = torch.compile(model, backend=args.torchcompile, mode=args.torchcompile_mode)

    # create the train and eval datasets
    if args.data and not args.data_dir:
        args.data_dir = args.data
    if args.input_img_mode is None:
        input_img_mode = 'RGB' if data_config['input_size'][0] == 3 else 'L'
    else:
        input_img_mode = args.input_img_mode

    dataset_train = create_dataset(
        args.dataset,
        root=args.data_dir,
        split=args.train_split,
        is_training=True,
        class_map=args.class_map,
        download=args.dataset_download,
        batch_size=args.batch_size,
        seed=args.seed,
        repeats=args.epoch_repeats,
        input_img_mode=input_img_mode,
        input_key=args.input_key,
        target_key=args.target_key,
        num_samples=args.train_num_samples,
        trust_remote_code=args.dataset_trust_remote_code,
    )
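
    # create the eval dataset only when a validation split is specified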
    dataset_eval = None
    if args.val_split:
        dataset_eval = create_dataset(
            args.dataset,
            root=args.data_dir,
            split=args.val_split,
            is_training=False,
            class_map=args.class_map,
            download=args.dataset_download,
            batch_size=args.batch_size,
            input_img_mode=input_img_mode,
            input_key=args.input_key,
            target_key=args.target_key,
            num_samples=args.val_num_samples,
            trust_remote_code=args.dataset_trust_remote_code,
        )

    # create data loaders w/ augmentation pipeline
    train_interpolation = args.train_interpolation
    if args.no_aug or not train_interpolation:
        train_interpolation = data_config['interpolation']

    # Check if we should use the NaFlex scheduled loader
    common_loader_kwargs = dict(
        mean=data_config['mean'],
        std=data_config['std'],
        pin_memory=args.pin_mem,
        img_dtype=model_dtype or torch.float32,
        device=device,
        distributed=args.distributed,
        use_prefetcher=args.prefetcher,
    )

    train_loader_kwargs = dict(
        batch_size=args.batch_size,
        is_training=True,
        no_aug=args.no_aug,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
        re_split=args.resplit,
        train_crop_mode=args.train_crop_mode,
        scale=args.scale,
        ratio=args.ratio,
        hflip=args.hflip,
        vflip=args.vflip,
        color_jitter=args.color_jitter,
        color_jitter_prob=args.color_jitter_prob,
        grayscale_prob=args.grayscale_prob,
        gaussian_blur_prob=args.gaussian_blur_prob,
        auto_augment=args.aa,
        num_aug_repeats=args.aug_repeats,
        num_aug_splits=num_aug_splits,
        interpolation=train_interpolation,
        num_workers=args.workers,
        worker_seeding=args.worker_seeding,
    )
    mixup_fn = None
    mixup_args = {}
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_args = dict(
            mixup_alpha=args.mixup,
            cutmix_alpha=args.cutmix,
            cutmix_minmax=args.cutmix_minmax,
            prob=args.mixup_prob,
            switch_prob=args.mixup_switch_prob,
            mode=args.mixup_mode,
            label_smoothing=args.smoothing,
            num_classes=args.num_classes,
        )
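
    # NaFlex loader path: variable sequence length / patch size training for NaFlexVit-style models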
    naflex_mode = False
    model_patch_size = None
    if args.naflex_loader:
        if utils.is_primary(args):
            _logger.info('Using NaFlex loader')
        assert num_aug_splits <= 1, 'Augmentation splits not supported in NaFlex mode'

        naflex_mixup_fn = None
        if mixup_active:
            from timm.data import NaFlexMixup
            mixup_args.pop('mode')  # not supported
            mixup_args.pop('cutmix_minmax')  # not supported
            naflex_mixup_fn = NaFlexMixup(**mixup_args)

        # Extract model's patch size for NaFlex mode
        if hasattr(model, 'embeds') and hasattr(model.embeds, 'patch_size'):
            # NaFlexVit models have embeds.patch_size
            model_patch_size = model.embeds.patch_size
        else:
            # Fallback to default
            model_patch_size = (16, 16)
            if utils.is_primary(args):
                _logger.warning(f'Could not determine model patch size, using default: {model_patch_size}')

        # Configure patch sizes for NaFlex loader
        patch_loader_kwargs = {}
        if args.naflex_patch_sizes:
            # Variable patch size mode
            patch_loader_kwargs['patch_size_choices'] = args.naflex_patch_sizes
            if args.naflex_patch_size_probs:
                if len(args.naflex_patch_size_probs) != len(args.naflex_patch_sizes):
                    parser.error('--naflex-patch-size-probs must have same length as --naflex-patch-sizes')
                patch_loader_kwargs['patch_size_choice_probs'] = args.naflex_patch_size_probs
            if utils.is_primary(args):
                _logger.info(f'Using variable patch sizes: {args.naflex_patch_sizes}')
        else:
            # Single patch size mode - use model's patch size
            patch_loader_kwargs['patch_size'] = model_patch_size
            if utils.is_primary(args):
                _logger.info(f'Using model patch size: {model_patch_size}')

        naflex_mode = True
        loader_train = create_naflex_loader(
            dataset=dataset_train,
            train_seq_lens=args.naflex_train_seq_lens,
            mixup_fn=naflex_mixup_fn,
            rank=args.rank,
            world_size=args.world_size,
            **patch_loader_kwargs,
            **common_loader_kwargs,
            **train_loader_kwargs,
        )
    else:
        # setup mixup / cutmix
        collate_fn = None
        if mixup_active:
            if args.prefetcher:
                assert not num_aug_splits  # collate conflict (need to support de-interleaving in collate mixup)
                collate_fn = FastCollateMixup(**mixup_args)
            else:
                mixup_fn = Mixup(**mixup_args)

        # wrap dataset in AugMix helper
        if num_aug_splits > 1:
            dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)

        # Use standard loader
        loader_train = create_loader(
            dataset_train,
            input_size=data_config['input_size'],
            collate_fn=collate_fn,
            use_multi_epochs_loader=args.use_multi_epochs_loader,
            **common_loader_kwargs,
            **train_loader_kwargs,
        )
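
    # build the eval loader (NaFlex or standard) when a validation split is specified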
    loader_eval = None
    if args.val_split:
        assert dataset_eval is not None
        eval_workers = args.workers
        if args.distributed and ('tfds' in args.dataset or 'wds' in args.dataset):
            # FIXME reduces validation padding issues when using TFDS, WDS w/ workers and distributed training
            eval_workers = min(2, args.workers)

        eval_loader_kwargs = dict(
            batch_size=args.validation_batch_size or args.batch_size,
            is_training=False,
            interpolation=data_config['interpolation'],
            num_workers=eval_workers,
            crop_pct=data_config['crop_pct'],
        )

        if args.naflex_loader:
            # Use largest sequence length for validation
            loader_eval = create_naflex_loader(
                dataset=dataset_eval,
                patch_size=model_patch_size,  # Use model's native patch size (already determined above)
                max_seq_len=args.naflex_max_seq_len,
                **common_loader_kwargs,
                **eval_loader_kwargs,
            )
        else:
            # Use standard loader
            loader_eval = create_loader(
                dataset_eval,
                input_size=data_config['input_size'],
                **common_loader_kwargs,
                **eval_loader_kwargs,
            )

    # setup loss function
    if args.jsd_loss:
        assert num_aug_splits > 1  # JSD only valid with aug splits set
        train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing)
    elif mixup_active:
        # smoothing is handled with mixup target transform which outputs sparse, soft targets
        if args.bce_loss:
            train_loss_fn = BinaryCrossEntropy(
                target_threshold=args.bce_target_thresh,
                sum_classes=args.bce_sum,
                pos_weight=args.bce_pos_weight,
            )
        else:
            train_loss_fn = SoftTargetCrossEntropy()
    elif args.smoothing:
        if args.bce_loss:
            train_loss_fn = BinaryCrossEntropy(
                smoothing=args.smoothing,
                target_threshold=args.bce_target_thresh,
                sum_classes=args.bce_sum,
                pos_weight=args.bce_pos_weight,
            )
        else:
            train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        train_loss_fn = nn.CrossEntropyLoss()
    train_loss_fn = train_loss_fn.to(device=device)
    validate_loss_fn = nn.CrossEntropyLoss().to(device=device)

    # setup checkpoint saver and eval metric tracking
    eval_metric = args.eval_metric if loader_eval is not None else 'loss'
    decreasing_metric = eval_metric == 'loss'
    best_metric = None
    best_epoch = None
    saver = None
    output_dir = None
    if utils.is_primary(args):
        if args.experiment:
            exp_name = args.experiment
        else:
            exp_name = '-'.join([
                datetime.now().strftime("%Y%m%d-%H%M%S"),
                safe_model_name(args.model),
                str(data_config['input_size'][-1])
            ])
        output_dir = utils.get_outdir(args.output if args.output else './output/train', exp_name)
        saver = utils.CheckpointSaver(
            model=model,
            optimizer=optimizer,
            args=args,
            model_ema=model_ema,
            amp_scaler=loss_scaler,
            checkpoint_dir=output_dir,
            recovery_dir=output_dir,
            decreasing=decreasing_metric,
            max_history=args.checkpoint_hist,
        )
        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
            f.write(args_text)

        if args.log_wandb:
            if has_wandb:
                assert not args.wandb_resume_id or args.resume
                wandb.init(
                    project=args.wandb_project,
                    name=exp_name,
                    config=args,
                    tags=args.wandb_tags,
                    resume="must" if args.wandb_resume_id else None,
                    id=args.wandb_resume_id if args.wandb_resume_id else None,
                )
            else:
                _logger.warning(
                    "You've requested to log metrics to wandb but package not found. "
                    "Metrics not being logged to wandb, try `pip install wandb`")

    # setup learning rate schedule and starting epoch
    updates_per_epoch = (len(loader_train) + args.grad_accum_steps - 1) // args.grad_accum_steps
    lr_scheduler, num_epochs = create_scheduler_v2(
        optimizer,
        **scheduler_kwargs(args, decreasing_metric=decreasing_metric),
        updates_per_epoch=updates_per_epoch,
    )
    start_epoch = 0
    if args.start_epoch is not None:
        # a specified start_epoch will always override the resume epoch
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
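
    # fast-forward the LR schedule to the starting epoch / update count when resuming mid-run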
    if lr_scheduler is not None and start_epoch > 0:
        if args.sched_on_updates:
            lr_scheduler.step_update(start_epoch * updates_per_epoch)
        else:
            lr_scheduler.step(start_epoch)

    if utils.is_primary(args):
        if args.warmup_prefix:
            sched_explain = '(warmup_epochs + epochs + cooldown_epochs). Warmup added to total when warmup_prefix=True'
        else:
            sched_explain = '(epochs + cooldown_epochs). Warmup within epochs when warmup_prefix=False'
        _logger.info(
            f'Scheduled epochs: {num_epochs} {sched_explain}. '
            f'LR stepped per {"epoch" if lr_scheduler.t_in_epochs else "update"}.')

    results = []
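
    # train / validate loop; a KeyboardInterrupt exits the loop early but still reports the best metric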
    try:
        for epoch in range(start_epoch, num_epochs):
            if hasattr(dataset_train, 'set_epoch'):
                dataset_train.set_epoch(epoch)
            elif args.distributed and hasattr(loader_train.sampler, 'set_epoch'):
                loader_train.sampler.set_epoch(epoch)

            train_metrics = train_one_epoch(
                epoch,
                model,
                loader_train,
                optimizer,
                train_loss_fn,
                args,
                device=device,
                lr_scheduler=lr_scheduler,
                saver=saver,
                output_dir=output_dir,
                amp_autocast=amp_autocast,
                loss_scaler=loss_scaler,
                model_dtype=model_dtype,
                model_ema=model_ema,
                mixup_fn=mixup_fn,
                num_updates_total=num_epochs * updates_per_epoch,
                naflex_mode=naflex_mode,
            )
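
            # synchronize BatchNorm running stats across ranks before validation when dist_bn is enabled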
            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                if utils.is_primary(args):
                    _logger.info("Distributing BatchNorm running means and vars")
                utils.distribute_bn(model, args.world_size, args.dist_bn == 'reduce')

            if loader_eval is not None:
                eval_metrics = validate(
                    model,
                    loader_eval,
                    validate_loss_fn,
                    args,
                    device=device,
                    amp_autocast=amp_autocast,
                    model_dtype=model_dtype,
                )
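
                # also validate the EMA weights when they live on the training device; EMA metrics replace the standard ones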
                if model_ema is not None and not args.model_ema_force_cpu:
                    if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                        utils.distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')

                    ema_eval_metrics = validate(
                        model_ema,
                        loader_eval,
                        validate_loss_fn,
                        args,
                        device=device,
                        amp_autocast=amp_autocast,
                        log_suffix=' (EMA)',
                    )
                    eval_metrics = ema_eval_metrics
            else:
                eval_metrics = None

            if output_dir is not None:
                lrs = [param_group['lr'] for param_group in optimizer.param_groups]
                utils.update_summary(
                    epoch,
                    train_metrics,
                    eval_metrics,
                    filename=os.path.join(output_dir, 'summary.csv'),
                    lr=sum(lrs) / len(lrs),
                    write_header=best_metric is None,
                    log_wandb=args.log_wandb and has_wandb,
                )

            if eval_metrics is not None:
                latest_metric = eval_metrics[eval_metric]
            else:
                latest_metric = train_metrics[eval_metric]

            if saver is not None:
                # save proper checkpoint with eval metric
                best_metric, best_epoch = saver.save_checkpoint(epoch, metric=latest_metric)

            if lr_scheduler is not None:
                # step LR for next epoch
                lr_scheduler.step(epoch + 1, latest_metric)

            latest_results = {
                'epoch': epoch,
                'train': train_metrics,
            }
            if eval_metrics is not None:
                latest_results['validation'] = eval_metrics
            results.append(latest_results)

    except KeyboardInterrupt:
        pass

    if best_metric is not None:
        # log best metric as tracked by checkpoint saver
        _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))

    if utils.is_primary(args):
        # for parsable results display, dump top-10 summaries to avoid excess console spam
        display_results = sorted(
            results,
            key=lambda x: x.get('validation', x.get('train')).get(eval_metric, 0),
            reverse=decreasing_metric,
        )
        print(f'--result\n{json.dumps(display_results[-10:], indent=4)}')