in contrib/action_recognition/i3d/train.py [0:0]
def run(*options, cfg=None):
    """Run training and validation of the model.

    Notes:
        Options can be passed in via the options argument and loaded from the cfg file.
        Options loaded from default.py are overridden by options loaded from the cfg file.
        Options passed in through the options argument override options loaded from the cfg file.

    Args:
        *options (str or int, optional): Options used to override what is loaded from the config.
            To see what options are available, consult default.py.
        cfg (str, optional): Location of config file to load. Defaults to None.
    """
    update_config(config, options=options, config_file=cfg)

    print("Training ", config.TRAIN.MODALITY, " model.")
    print("Batch size:", config.TRAIN.BATCH_SIZE, " Gradient accumulation steps:", config.TRAIN.GRAD_ACCUM_STEPS)

    torch.backends.cudnn.benchmark = config.CUDNN.BENCHMARK

    torch.manual_seed(config.SEED)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(config.SEED)
    np.random.seed(seed=config.SEED)

    # Log to tensorboard
    writer = SummaryWriter(log_dir=config.LOG_DIR)
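    # Inspect the logged curves with: tensorboard --logdir <config.LOG_DIR>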
    # Setup dataloaders
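    # The training pipeline applies random crop and horizontal flip augmentation;
    # the validation pipeline below uses a deterministic resize + center crop instead.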
    train_loader = torch.utils.data.DataLoader(
        I3DDataSet(
            data_root=config.DATASET.DIR,
            split=config.DATASET.SPLIT,
            sample_frames=config.TRAIN.SAMPLE_FRAMES,
            modality=config.TRAIN.MODALITY,
            transform=torchvision.transforms.Compose([
                GroupScale(config.TRAIN.RESIZE_MIN),
                GroupRandomCrop(config.TRAIN.INPUT_SIZE),
                GroupRandomHorizontalFlip(),
                GroupNormalize(modality=config.TRAIN.MODALITY),
                Stack(),
            ])
        ),
        batch_size=config.TRAIN.BATCH_SIZE,
        shuffle=True,
        num_workers=config.WORKERS,
        pin_memory=config.PIN_MEMORY
    )
    val_loader = torch.utils.data.DataLoader(
        I3DDataSet(
            data_root=config.DATASET.DIR,
            split=config.DATASET.SPLIT,
            modality=config.TRAIN.MODALITY,
            train_mode=False,
            transform=torchvision.transforms.Compose([
                GroupScale(config.TRAIN.RESIZE_MIN),
                GroupCenterCrop(config.TRAIN.INPUT_SIZE),
                GroupNormalize(modality=config.TRAIN.MODALITY),
                Stack(),
            ]),
        ),
        batch_size=config.TEST.BATCH_SIZE,
        shuffle=False,
        num_workers=config.WORKERS,
        pin_memory=config.PIN_MEMORY
    )
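    # Pinned memory (when enabled) speeds up host-to-GPU copies; WORKERS sets the
    # number of background processes used for data loading.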
    # Setup model
    if config.TRAIN.MODALITY == "RGB":
        channels = 3
        checkpoint = config.MODEL.PRETRAINED_RGB
    elif config.TRAIN.MODALITY == "flow":
        channels = 2
        checkpoint = config.MODEL.PRETRAINED_FLOW
    else:
        raise ValueError("Modality must be RGB or flow")

    i3d_model = InceptionI3d(400, in_channels=channels)
    i3d_model.load_state_dict(torch.load(checkpoint))

    # Replace final FC layer to match dataset
    i3d_model.replace_logits(config.DATASET.NUM_CLASSES)
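    # Note: the 400-way head above matches the Kinetics-400 pretraining classes;
    # replace_logits swaps it for a DATASET.NUM_CLASSES-way classification layer.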
    i3d_model = i3d_model.cuda()

    criterion = torch.nn.CrossEntropyLoss().cuda()

    optimizer = optim.SGD(
        i3d_model.parameters(),
        lr=0.1,
        momentum=0.9,
        weight_decay=1e-7
    )

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        factor=0.1,
        patience=2,
        verbose=True,
        threshold=1e-4,
        min_lr=1e-4
    )
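    # ReduceLROnPlateau cuts the learning rate by `factor` (10x here) whenever the
    # validation loss fails to improve for `patience` evaluation rounds, down to min_lr.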
    # Data-parallel
    devices_lst = list(range(torch.cuda.device_count()))
    print("Devices {}".format(devices_lst))
    if len(devices_lst) > 1:
        i3d_model = torch.nn.DataParallel(i3d_model)

    if not os.path.exists(config.MODEL.CHECKPOINT_DIR):
        os.makedirs(config.MODEL.CHECKPOINT_DIR)
    for epoch in range(config.TRAIN.MAX_EPOCHS):

        train(train_loader,
              i3d_model,
              criterion,
              optimizer,
              epoch,
              writer
        )

        if (epoch + 1) % config.TEST.EVAL_FREQ == 0 or epoch == config.TRAIN.MAX_EPOCHS - 1:
            val_loss = validate(val_loader, i3d_model, criterion, epoch, writer)
            scheduler.step(val_loss)

            # Unwrap DataParallel (if used) so checkpoint keys are not prefixed with "module."
            model_to_save = i3d_model.module if isinstance(i3d_model, torch.nn.DataParallel) else i3d_model
            torch.save(
                model_to_save.state_dict(),
                os.path.join(
                    config.MODEL.CHECKPOINT_DIR,
                    "{}_split{}_epoch{:03d}.pt".format(config.MODEL.NAME, config.DATASET.SPLIT, epoch)
                )
            )

    writer.close()
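# A minimal sketch of reloading a saved checkpoint for evaluation (the path below is
# hypothetical; assumes the same InceptionI3d constructor used in run() above):
#   model = InceptionI3d(config.DATASET.NUM_CLASSES, in_channels=3)
#   model.load_state_dict(torch.load("checkpoints/i3d_rgb_split1_epoch010.pt"))
#   model.eval()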