in fc_permute_mnist.py [0:0]
def main():
    """
    Create the model and start the training.

    Steps, in order:
      1. Parse the command-line arguments and validate the requested
         architecture, importance method and optimizer.
      2. Create the log directory (if missing) and snapshot the experiment
         meta data under a freshly generated experiment id.
      3. Build the TensorFlow graph (placeholders, optimizer, model) and run
         `train_task_sequence` inside a session.
      4. Optionally append a cross-validation summary line to a text file and
         always snapshot the evaluation results.

    Raises:
        ValueError: if the architecture, importance method or optimizer given
            on the command line is not supported.
    """
    # Get the CL arguments
    args = get_arguments()

    # Check if the network architecture is valid
    if args.arch not in VALID_ARCHS:
        raise ValueError("Network architecture %s is not supported!"%(args.arch))

    # Check if the method to compute importance is valid
    if args.imp_method not in MODELS:
        raise ValueError("Importance measure %s is undefined!"%(args.imp_method))

    # Check if the optimizer is valid
    if args.optim not in VALID_OPTIMS:
        raise ValueError("Optimizer %s is undefined!"%(args.optim))

    # Create log directories to store the results.
    # The message is printed only after the directory actually exists.
    if not os.path.exists(args.log_dir):
        os.makedirs(args.log_dir)
        print('Log directory %s created!'%(args.log_dir))

    # Generate the experiment key and store the meta data in a file
    exper_meta_data = {'DATASET': 'PERMUTE_MNIST',
                       'NUM_RUNS': args.num_runs,
                       'TRAIN_SINGLE_EPOCH': args.train_single_epoch,
                       'IMP_METHOD': args.imp_method,
                       'SYNAP_STGTH': args.synap_stgth,
                       'FISHER_EMA_DECAY': args.fisher_ema_decay,
                       'FISHER_UPDATE_AFTER': args.fisher_update_after,
                       'OPTIM': args.optim,
                       'LR': args.learning_rate,
                       'BATCH_SIZE': args.batch_size,
                       'MEM_SIZE': args.mem_size}
    # NOTE(review): `%r` applied to str(args.batch_size) embeds quote
    # characters in the id; kept as-is so existing result-file names stay stable.
    experiment_id = "PERMUTE_MNIST_HERDING_%s_%s_%s_%s_%r_%s-"%(args.arch, args.train_single_epoch, args.imp_method, str(args.synap_stgth).replace('.', '_'),
            str(args.batch_size), str(args.mem_size)) + datetime.datetime.now().strftime("%y-%m-%d-%H-%M")
    snapshot_experiment_meta_data(args.log_dir, experiment_id, exper_meta_data)

    # Get the subset of data depending on training or cross-validation mode
    if args.online_cross_val:
        num_tasks = K_FOR_CROSS_VAL
    else:
        num_tasks = NUM_TASKS - K_FOR_CROSS_VAL

    # Reset the default graph and build everything in a fresh one
    tf.reset_default_graph()
    graph = tf.Graph()
    with graph.as_default():

        # Set the random seed
        tf.set_random_seed(args.random_seed)

        # Define Input and Output of the model
        x = tf.placeholder(tf.float32, shape=[None, INPUT_FEATURE_SIZE])
        if args.imp_method == 'PNN':
            # PNN gets one label placeholder per task
            y_ = []
            for i in range(num_tasks):
                y_.append(tf.placeholder(tf.float32, shape=[None, TOTAL_CLASSES]))
        else:
            y_ = tf.placeholder(tf.float32, shape=[None, TOTAL_CLASSES])

        # Define the optimizer
        if args.optim == 'ADAM':
            opt = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
        elif args.optim == 'SGD':
            opt = tf.train.GradientDescentOptimizer(learning_rate=args.learning_rate)
        elif args.optim == 'MOMENTUM':
            # Uses the constant base learning rate; no decay schedule is applied.
            opt = tf.train.MomentumOptimizer(args.learning_rate, OPT_MOMENTUM)
        else:
            # An optimizer listed in VALID_OPTIMS but not constructed above
            # would otherwise leave `opt` unbound.
            raise ValueError("Optimizer %s is undefined!"%(args.optim))

        # Create the Model/ construct the graph
        model = Model(x, y_, num_tasks, opt, args.imp_method, args.synap_stgth, args.fisher_update_after,
                args.fisher_ema_decay, network_arch=args.arch)

        # Set up tf session and initialize variables.
        if USE_GPU:
            config = tf.ConfigProto()
            config.gpu_options.allow_growth = True
        else:
            # Hide all GPUs from the session
            config = tf.ConfigProto(
                    device_count = {'GPU': 0}
                    )

        time_start = time.time()
        # The context manager closes the session on exit
        with tf.Session(config=config, graph=graph) as sess:
            runs = train_task_sequence(model, sess, args)
        time_end = time.time()
        time_spent = time_end - time_start

    # Store all the results in one dictionary to process later
    exper_acc = dict(mean=runs)

    # If cross-validation flag is enabled, store the stuff in a text file
    if args.cross_validate_mode:
        # presumably `runs` is a numpy array with the run index on axis 0 -- TODO confirm
        acc_mean = runs.mean(0)
        cross_validate_dump_file = args.log_dir + '/' + 'PERMUTE_MNIST_%s_%s'%(args.imp_method, args.optim) + '.txt'
        # Append, so successive invocations accumulate in the same file
        with open(cross_validate_dump_file, 'a') as f:
            if MULTI_TASK:
                f.write('GPU:{} \t ARCH: {} \t LR:{} \t LAMBDA: {} \t ACC: {}\n'.format(USE_GPU, args.arch, args.learning_rate,
                    args.synap_stgth, acc_mean[-1, :].mean()))
            else:
                f.write('GPU: {} \t ARCH: {} \t LR:{} \t LAMBDA: {} \t ACC: {} \t Fgt: {} \t Time: {}\n'.format(USE_GPU, args.arch, args.learning_rate,
                    args.synap_stgth, acc_mean[-1, :].mean(), compute_fgt(acc_mean), str(time_spent)))

    # Store the experiment output to a file
    snapshot_experiment_eval(args.log_dir, experiment_id, exper_acc)