# conv_split_awa_hybrid.py
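# Assumed imports for this excerpt; the constants (VALID_ARCHS, MODELS,
# VALID_OPTIMS, NUM_TASKS, K_FOR_CROSS_VAL, TOTAL_CLASSES, CLASSES_PER_TASK,
# IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS, ATTR_DIMS, OPT_POWER, OPT_MOMENTUM,
# AWA_* lists) and helpers (get_arguments, construct_split_awa, image_scaling,
# random_crop_and_pad_image, Model, train_task_sequence, snapshot_*) are
# defined elsewhere in the repository.
import datetime
import os
import time

import numpy as np
import tensorflow as tf
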
def main():
"""
Create the model and start the training
"""
# Get the CL arguments
args = get_arguments()
# Initialize the random seed of numpy
np.random.seed(args.random_seed)
# Check if the network architecture is valid
if args.arch not in VALID_ARCHS:
raise ValueError("Network architecture %s is not supported!"%(args.arch))
# Check if the method to compute importance is valid
if args.imp_method not in MODELS:
raise ValueError("Importance measure %s is undefined!"%(args.imp_method))
# Check if the optimizer is valid
if args.optim not in VALID_OPTIMS:
raise ValueError("Optimizer %s is undefined!"%(args.optim))
# Create log directories to store the results
    if not os.path.exists(args.log_dir):
        os.makedirs(args.log_dir)
        print('Log directory %s created!'%(args.log_dir))
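    # The first K_FOR_CROSS_VAL tasks are reserved for hyperparameter
    # cross-validation; the remaining tasks form the evaluated task sequence.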
if args.online_cross_val:
num_tasks = K_FOR_CROSS_VAL
else:
num_tasks = NUM_TASKS - K_FOR_CROSS_VAL
# Load the split AWA dataset
data_labs = [np.arange(TOTAL_CLASSES)]
datasets, AWA_attr = construct_split_awa(data_labs, args.data_dir, AWA_TRAIN_LIST, AWA_VAL_LIST, AWA_TEST_LIST, IMG_HEIGHT, IMG_WIDTH, attr_file=AWA_ATTR_LIST)
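    # Mask out the attribute vectors of the classes outside the current split
    # so the hybrid attribute stream only sees the classes being trained on.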
if args.online_cross_val:
AWA_attr[K_FOR_CROSS_VAL*CLASSES_PER_TASK:] = 0
else:
AWA_attr[:K_FOR_CROSS_VAL*CLASSES_PER_TASK] = 0
print('Attributes: {}'.format(np.sum(AWA_attr, axis=1)))
if args.cross_validate_mode:
models_list = MODELS
learning_rate_list = [0.1, 0.03, 0.01, 0.001, 0.0003]
    else:
        models_list = [args.imp_method]
        # Default learning rate; overridden below where a tuned value exists
        learning_rate_list = [args.learning_rate]
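    # Per-method hyperparameter grids: sweep the regularization strength (and,
    # in cross-validation mode, the learning rate); otherwise use the single
    # value picked by cross-validation.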
for imp_method in models_list:
if imp_method == 'VAN':
            synap_stgth_list = [0]
            if not (args.online_cross_val or args.cross_validate_mode):
                learning_rate_list = [0.003]
elif imp_method == 'PI':
if args.online_cross_val or args.cross_validate_mode:
synap_stgth_list = [0.1, 1, 10]
else:
synap_stgth_list = [10]
learning_rate_list = [0.003]
elif imp_method == 'EWC' or imp_method == 'M-EWC':
if args.online_cross_val or args.cross_validate_mode:
synap_stgth_list = [0.1, 1, 10, 100]
else:
synap_stgth_list = [100]
learning_rate_list = [0.003]
elif imp_method == 'MAS':
if args.online_cross_val or args.cross_validate_mode:
synap_stgth_list = [0.1, 1, 10, 100]
else:
synap_stgth_list = [0.1]
learning_rate_list = [0.001]
elif imp_method == 'RWALK':
if args.online_cross_val or args.cross_validate_mode:
synap_stgth_list = [0.1, 1, 10, 100]
else:
                synap_stgth_list = [10] # TODO: re-validate this value
learning_rate_list = [0.003]
        elif imp_method == 'S-GEM':
            synap_stgth_list = [0]
            if not (args.online_cross_val or args.cross_validate_mode):
                learning_rate_list = [args.learning_rate]
elif imp_method == 'A-GEM':
            synap_stgth_list = [0]
            if not (args.online_cross_val or args.cross_validate_mode):
                learning_rate_list = [0.003]
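        # Run one experiment per (regularization strength, learning rate) pair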
for synap_stgth in synap_stgth_list:
for lr in learning_rate_list:
# Generate the experiment key and store the meta data in a file
exper_meta_data = {'ARCH': args.arch,
'DATASET': 'SPLIT_AWA',
'HYBRID': args.set_hybrid,
'NUM_RUNS': args.num_runs,
'TRAIN_SINGLE_EPOCH': args.train_single_epoch,
'IMP_METHOD': imp_method,
'SYNAP_STGTH': synap_stgth,
'FISHER_EMA_DECAY': args.fisher_ema_decay,
'FISHER_UPDATE_AFTER': args.fisher_update_after,
'OPTIM': args.optim,
'LR': lr,
'BATCH_SIZE': args.batch_size,
'EPS_MEMORY': args.do_sampling,
'MEM_SIZE': args.mem_size,
'IS_HERDING': args.is_herding}
experiment_id = "SPLIT_AWA_HERDING_%r_HYB_%r_%s_%r_%s_%s_%s_%r_%s-"%(args.is_herding, args.set_hybrid, args.arch, args.train_single_epoch, imp_method,
str(synap_stgth).replace('.', '_'),
str(args.batch_size), args.do_sampling, str(args.mem_size)) + datetime.datetime.now().strftime("%y-%m-%d-%H-%M")
snapshot_experiment_meta_data(args.log_dir, experiment_id, exper_meta_data)
                # Build each run in a fresh graph so state from earlier
                # (imp_method, lambda, lr) combinations does not leak across runs
                tf.reset_default_graph()
                graph = tf.Graph()
with graph.as_default():
# Set the random seed
tf.set_random_seed(args.random_seed)
# Define Input and Output of the model
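                    # (y_ is one-hot over the joint label space of all tasks; attr
                    # is the per-class attribute matrix used by the hybrid model)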
x = tf.placeholder(tf.float32, shape=[None, IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS])
y_ = tf.placeholder(tf.float32, shape=[None, num_tasks*TOTAL_CLASSES])
attr = tf.placeholder(tf.float32, shape=[num_tasks*TOTAL_CLASSES, ATTR_DIMS])
if not args.train_single_epoch:
# Define ops for data augmentation
x_aug = image_scaling(x)
x_aug = random_crop_and_pad_image(x_aug, IMG_HEIGHT, IMG_WIDTH)
# Define the optimizer
if args.optim == 'ADAM':
opt = tf.train.AdamOptimizer(learning_rate=lr)
elif args.optim == 'SGD':
opt = tf.train.GradientDescentOptimizer(learning_rate=lr)
                    elif args.optim == 'MOMENTUM':
                        # Polynomial decay of the base learning rate; the step
                        # counter is assumed to be advanced by the training loop
                        base_lr = tf.constant(lr)
                        train_step = tf.Variable(0, trainable=False, dtype=tf.float32)
                        learning_rate = tf.scalar_mul(base_lr, tf.pow((1 - train_step / args.train_iters), OPT_POWER))
                        opt = tf.train.MomentumOptimizer(learning_rate, OPT_MOMENTUM)
                    # Create the model and construct the graph
if args.train_single_epoch:
                        # No data augmentation is needed when training for a single epoch
model = Model(x, y_, num_tasks, opt, imp_method, synap_stgth, args.fisher_update_after,
args.fisher_ema_decay, network_arch=args.arch, is_ATT_DATASET=True, attr=attr)
else:
model = Model(x_aug, y_, num_tasks, opt, imp_method, synap_stgth, args.fisher_update_after,
args.fisher_ema_decay, network_arch=args.arch, is_ATT_DATASET=True, x_test=x, attr=attr)
# Set up tf session and initialize variables.
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
time_start = time.time()
with tf.Session(config=config, graph=graph) as sess:
saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=100)
runs, task_labels_dataset = train_task_sequence(model, sess, saver, datasets, AWA_attr, CLASSES_PER_TASK, args.cross_validate_mode,
args.train_single_epoch, args.do_sampling, args.is_herding, args.mem_size*CLASSES_PER_TASK*num_tasks, args.train_iters,
args.batch_size, args.num_runs, args.init_checkpoint, args.online_cross_val, args.random_seed)
time_end = time.time()
time_spent = time_end - time_start
                print('Time spent: {:.2f} s'.format(time_spent))
# Clean up
del model
if args.cross_validate_mode:
                    # If the cross-validation flag is enabled, append the result to a text file
                    cross_validate_dump_file = os.path.join(args.log_dir, 'SPLIT_AWA_HYBRID_%s_%s.txt'%(imp_method, args.optim))
with open(cross_validate_dump_file, 'a') as f:
f.write('HERDING: {} \t ARCH: {} \t LR:{} \t LAMBDA: {} \t ACC: {}\n'.format(args.is_herding, args.arch, lr, synap_stgth, runs))
else:
# Store all the results in one dictionary to process later
exper_acc = dict(mean=runs)
exper_labels = dict(labels=task_labels_dataset)
# Store the experiment output to a file
snapshot_experiment_eval(args.log_dir, experiment_id, exper_acc)
snapshot_task_labels(args.log_dir, experiment_id, exper_labels)
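
# Entry point (assumed; the full script presumably invokes main() this way)
if __name__ == '__main__':
    main()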