in c2/tools/train_net.py [0:0]
def Train(args):
    """Build the data-parallel train (and optional test) nets and run training.

    Steps performed:
      1. Resolve the GPU list from ``args.gpus`` (comma-separated ids) or,
         when that is None, from ``args.num_gpus``.
      2. Round ``args.epoch_size`` down to a multiple of the global batch
         size (``args.batch_size`` per GPU times the number of GPUs). This
         mutates ``args.epoch_size`` in place.
      3. Build the training model (weight decay, step learning-rate
         schedule, momentum update) replicated over the GPUs with
         ``data_parallel_model.Parallelize_GPU``.
      4. If ``args.test_data`` is set, build a test model fed by its own
         reader with random mirroring/cropping disabled.
      5. Optionally load weights from ``args.load_model_path`` and, when
         ``args.is_checkpoint`` is set, resume the epoch counter from a
         file name of the form ``*_<epoch>.mdl``.
      6. Call ``RunEpoch`` until ``args.num_epochs`` is reached, saving the
         model after every epoch.

    Args:
        args: parsed command-line namespace providing every option read
            below.
    """
    if args.gpus is not None:
        gpus = [int(x) for x in args.gpus.split(',')]
        num_gpus = len(gpus)
    else:
        # Materialize as a list so it logs readably and behaves exactly like
        # the explicit-list branch (e.g. gpus[0] below).
        gpus = list(range(args.num_gpus))
        num_gpus = args.num_gpus
    log.info("Running on GPUs: {}".format(gpus))

    # args.batch_size is per device; the global batch spans all GPUs.
    # Kept consistent with the distributed trainer.
    total_batch_size = args.batch_size * num_gpus
    batch_per_device = args.batch_size

    # Round down epoch size to closest multiple of batch size across machines
    # so that every epoch runs a whole number of iterations.
    epoch_iters = int(args.epoch_size // total_batch_size)
    args.epoch_size = epoch_iters * total_batch_size
    log.info("Using epoch size: {}".format(args.epoch_size))

    use_cudnn = args.use_cudnn == 1
    cudnn_ws_nbytes_limit = args.cudnn_workspace_limit_mb * 1024 * 1024

    # Modalities the readers must decode, derived from args.input_type.
    # The train and test readers must agree because both nets are built by
    # the same create_model_ops and therefore expect the same inputs.
    get_rgb = (args.input_type == 0 or args.input_type >= 3)
    get_optical_flow = (args.input_type == 1 or args.input_type >= 4)
    get_logmels = (args.input_type >= 2)

    # Create CNNModelHelper object
    train_model = cnn.CNNModelHelper(
        order="NCHW",
        name='{}_train'.format(args.model_name),
        use_cudnn=use_cudnn,
        cudnn_exhaustive_search=True,
        ws_nbytes_limit=cudnn_ws_nbytes_limit,
    )

    # Model building functions
    def create_model_ops(model, loss_scale):
        return model_builder.build_model(
            model=model,
            model_name=args.model_name,
            model_depth=args.model_depth,
            num_labels=args.num_labels,
            batch_size=args.batch_size,
            num_channels=args.num_channels,
            crop_size=args.crop_size,
            # NOTE(review): any non-zero input_type (including the
            # RGB+audio types >= 3) selects clip_length_of here — confirm
            # this is intended for the multi-modal input types.
            clip_length=(
                args.clip_length_of if args.input_type
                else args.clip_length_rgb
            ),
            loss_scale=loss_scale,
            pred_layer_name=args.pred_layer_name,
            multi_label=args.multi_label,
            channel_multiplier=args.channel_multiplier,
            bottleneck_multiplier=args.bottleneck_multiplier,
            use_dropout=args.use_dropout,
            conv1_temporal_stride=args.conv1_temporal_stride,
            conv1_temporal_kernel=args.conv1_temporal_kernel,
            use_pool1=args.use_pool1,
            audio_input_3d=args.audio_input_3d,
            g_blend=args.g_blend,
            audio_weight=args.audio_weight,
            visual_weight=args.visual_weight,
            av_weight=args.av_weight,
        )

    # SGD
    def add_parameter_update_ops(model):
        model.AddWeightDecay(args.weight_decay)
        ITER = model.Iter("ITER")
        # Step size in iterations: step_epoch epochs of
        # epoch_size / total_batch_size iterations each.
        stepsz = args.step_epoch * args.epoch_size / args.batch_size / num_gpus
        LR = model.net.LearningRate(
            [ITER],
            "LR",
            # Base LR is scaled linearly with the number of GPUs.
            base_lr=args.base_learning_rate * num_gpus,
            policy="step",
            stepsize=int(stepsz),
            gamma=args.gamma,
        )
        AddMomentumParameterUpdate(model, LR)

    # Input. Note that the reader must be shared with all GPUS.
    train_reader, train_examples = reader_utils.create_data_reader(
        train_model,
        name="train_reader",
        input_data=args.train_data,
    )
    log.info("Training set has {} examples".format(train_examples))

    def add_video_input(model):
        model_helper.AddVideoInput(
            model,
            train_reader,
            batch_size=batch_per_device,
            length_rgb=args.clip_length_rgb,
            clip_per_video=1,
            random_mirror=True,
            decode_type=0,
            sampling_rate_rgb=args.sampling_rate_rgb,
            scale_h=args.scale_h,
            scale_w=args.scale_w,
            crop_size=args.crop_size,
            video_res_type=args.video_res_type,
            short_edge=min(args.scale_h, args.scale_w),
            num_decode_threads=args.num_decode_threads,
            do_multi_label=args.multi_label,
            num_of_class=args.num_labels,
            random_crop=True,
            input_type=args.input_type,
            length_of=args.clip_length_of,
            sampling_rate_of=args.sampling_rate_of,
            frame_gap_of=args.frame_gap_of,
            do_flow_aggregation=args.do_flow_aggregation,
            flow_data_type=args.flow_data_type,
            get_rgb=get_rgb,
            get_optical_flow=get_optical_flow,
            get_logmels=get_logmels,
            get_video_id=args.get_video_id,
            jitter_scales=[int(n) for n in args.jitter_scales.split(',')],
            use_local_file=args.use_local_file,
        )

    # Create parallelized model
    data_parallel_model.Parallelize_GPU(
        train_model,
        input_builder_fun=add_video_input,
        forward_pass_builder_fun=create_model_ops,
        param_update_builder_fun=add_parameter_update_ops,
        devices=gpus,
        rendezvous=None,
        net_type=('prof_dag' if args.profiling == 1 else 'dag'),
        optimize_gradient_memory=True,
    )

    # Add test model, if specified
    test_model = None
    if args.test_data is not None:
        log.info("----- Create test net ----")
        test_model = cnn.CNNModelHelper(
            order="NCHW",
            name='{}_test'.format(args.model_name),
            use_cudnn=use_cudnn,
            cudnn_exhaustive_search=True,
            # Honour the same user-configured cuDNN workspace limit as the
            # train model.
            ws_nbytes_limit=cudnn_ws_nbytes_limit,
        )
        test_reader, test_examples = reader_utils.create_data_reader(
            test_model,
            name="test_reader",
            input_data=args.test_data,
        )
        log.info("Testing set has {} examples".format(test_examples))

        def test_input_fn(model):
            model_helper.AddVideoInput(
                model,
                test_reader,
                batch_size=batch_per_device,
                length_rgb=args.clip_length_rgb,
                clip_per_video=1,
                decode_type=0,
                random_mirror=False,
                random_crop=False,
                sampling_rate_rgb=args.sampling_rate_rgb,
                scale_h=args.scale_h,
                scale_w=args.scale_w,
                crop_size=args.crop_size,
                video_res_type=args.video_res_type,
                short_edge=min(args.scale_h, args.scale_w),
                num_decode_threads=args.num_decode_threads,
                do_multi_label=args.multi_label,
                num_of_class=args.num_labels,
                input_type=args.input_type,
                length_of=args.clip_length_of,
                sampling_rate_of=args.sampling_rate_of,
                frame_gap_of=args.frame_gap_of,
                do_flow_aggregation=args.do_flow_aggregation,
                flow_data_type=args.flow_data_type,
                # Same modality flags as the training reader, so that
                # multi-modal (audio / RGB+audio / RGB+flow+audio) models get
                # every input their forward pass consumes.
                get_rgb=get_rgb,
                get_optical_flow=get_optical_flow,
                get_logmels=get_logmels,
                get_video_id=args.get_video_id,
                use_local_file=args.use_local_file,
            )

        data_parallel_model.Parallelize_GPU(
            test_model,
            input_builder_fun=test_input_fn,
            forward_pass_builder_fun=create_model_ops,
            param_update_builder_fun=None,
            devices=gpus,
            optimize_gradient_memory=True,
        )
        workspace.RunNetOnce(test_model.param_init_net)
        workspace.CreateNet(test_model.net)

    workspace.RunNetOnce(train_model.param_init_net)
    workspace.CreateNet(train_model.net)

    epoch = 0
    # load the pre-trained model and reset epoch
    if args.load_model_path is not None:
        if args.db_type == 'pickle':
            model_loader.LoadModelFromPickleFile(
                train_model,
                args.load_model_path,
                use_gpu=True,
                root_gpu_id=gpus[0]
            )
        else:
            model_helper.LoadModel(
                args.load_model_path, args.db_type
            )
        # Sync the model params
        data_parallel_model.FinalizeAfterCheckpoint(
            train_model,
            GetCheckpointParams(train_model),
        )

        if args.is_checkpoint:
            # reset epoch. load_model_path should end with *_X.mdl,
            # where X is the epoch number
            last_str = args.load_model_path.split('_')[-1]
            epoch_str = last_str[:-4] if last_str.endswith('.mdl') else ''
            # A non-numeric suffix (e.g. "*_final.mdl") is treated as a
            # format mismatch rather than crashing in int().
            if epoch_str.isdigit():
                epoch = int(epoch_str)
                log.info("Reset epoch to {}".format(epoch))
            else:
                log.warning("The format of load_model_path doesn't match!")

    # Use the resolved GPU count (which follows --gpus when given) so the
    # experiment name agrees with total_batch_size.
    expname = "%s_gpu%d_b%d_L%d_lr%.2f" % (
        args.model_name,
        num_gpus,
        total_batch_size,
        args.num_labels,
        args.base_learning_rate,
    )
    explog = experiment_util.ModelTrainerLog(expname, args)

    # Run the training one epoch a time
    while epoch < args.num_epochs:
        epoch = RunEpoch(
            args,
            epoch,
            train_model,
            test_model,
            total_batch_size,
            1,
            expname,
            explog
        )

        # Save the model for each epoch
        SaveModel(args, train_model, epoch)