def main()

in conv_split_awa_hybrid.py [0:0]


def _hyperparameter_grid(imp_method, tune, default_learning_rate):
    """
    Return the hyper-parameter values to evaluate for one importance method.

    Args:
        imp_method: name of the importance measure / CL method (e.g. 'EWC').
        tune: True when any cross-validation mode is enabled
            (args.online_cross_val or args.cross_validate_mode). A grid of
            values is returned then; otherwise the single fixed setting.
        default_learning_rate: learning rate given on the command line; used
            by S-GEM when not tuning.

    Returns:
        (synap_stgth_list, learning_rate_list): `main` evaluates every pair
        from the Cartesian product of the two lists.

    Raises:
        ValueError: if no settings are defined for `imp_method`.
    """
    # Learning rates swept whenever hyper-parameters are tuned.
    cv_learning_rates = [0.1, 0.03, 0.01, 0.001, 0.0003]

    # Fixed settings used when not tuning:
    # imp_method -> (synap_stgth_list, learning_rate_list)
    fixed_settings = {
        'VAN': ([0], [0.003]),
        'PI': ([10], [0.003]),
        'EWC': ([100], [0.003]),
        'M-EWC': ([100], [0.003]),
        'MAS': ([0.1], [0.001]),
        'RWALK': ([10], [0.003]),     # Check again!
        'S-GEM': ([0], [default_learning_rate]),
        'A-GEM': ([0], [0.003]),
    }
    # Regularization strengths searched when tuning. Methods that have no
    # regularization term only use 0.
    tuned_stgth = {
        'VAN': [0],
        'PI': [0.1, 1, 10],
        'EWC': [0.1, 1, 10, 100],
        'M-EWC': [0.1, 1, 10, 100],
        'MAS': [0.1, 1, 10, 100],
        'RWALK': [0.1, 1, 10, 100],
        'S-GEM': [0],
        'A-GEM': [0],
    }

    if imp_method not in fixed_settings:
        raise ValueError("No hyper-parameter settings defined for importance measure %s!"%(imp_method))
    if tune:
        # Every method, S-GEM included, sweeps the same learning-rate grid.
        # The lists are fresh per call, so one method can never change
        # another method's grid.
        return tuned_stgth[imp_method], list(cv_learning_rates)
    return fixed_settings[imp_method]


def _build_optimizer(optim, lr):
    """
    Create the TensorFlow optimizer named `optim` with learning rate `lr`.

    Raises:
        ValueError: if `optim` is not 'ADAM', 'SGD' or 'MOMENTUM'.
    """
    if optim == 'ADAM':
        return tf.train.AdamOptimizer(learning_rate=lr)
    if optim == 'SGD':
        return tf.train.GradientDescentOptimizer(learning_rate=lr)
    if optim == 'MOMENTUM':
        # Constant learning rate: no polynomial decay schedule is applied.
        return tf.train.MomentumOptimizer(lr, OPT_MOMENTUM)
    raise ValueError("Optimizer %s is undefined!"%(optim))


def _train_and_evaluate(args, imp_method, synap_stgth, lr, num_tasks, datasets, AWA_attr):
    """
    Build a fresh graph for one hyper-parameter setting and train it.

    Returns:
        (runs, task_labels_dataset) exactly as returned by `train_task_sequence`.
    """
    # Reset the default graph
    tf.reset_default_graph()
    graph = tf.Graph()
    with graph.as_default():

        # Set the random seed
        tf.set_random_seed(args.random_seed)

        # Define Input and Output of the model
        x = tf.placeholder(tf.float32, shape=[None, IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS])
        y_ = tf.placeholder(tf.float32, shape=[None, num_tasks*TOTAL_CLASSES])
        attr = tf.placeholder(tf.float32, shape=[num_tasks*TOTAL_CLASSES, ATTR_DIMS])

        if not args.train_single_epoch:
            # Define ops for data augmentation. They are created before the
            # optimizer, keeping the graph's op-creation order (which feeds
            # op-level random seeds) unchanged.
            x_aug = image_scaling(x)
            x_aug = random_crop_and_pad_image(x_aug, IMG_HEIGHT, IMG_WIDTH)

        # Define the optimizer
        opt = _build_optimizer(args.optim, lr)

        # Create the Model / construct the graph
        if args.train_single_epoch:
            # When training using a single epoch then there is no need for data augmentation
            model = Model(x, y_, num_tasks, opt, imp_method, synap_stgth, args.fisher_update_after,
                    args.fisher_ema_decay, network_arch=args.arch, is_ATT_DATASET=True, attr=attr)
        else:
            # Train on augmented images; the raw input `x` is passed as `x_test`
            model = Model(x_aug, y_, num_tasks, opt, imp_method, synap_stgth, args.fisher_update_after,
                    args.fisher_ema_decay, network_arch=args.arch, is_ATT_DATASET=True, x_test=x, attr=attr)

        # Set up tf session and initialize variables.
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True

        time_start = time.time()
        # The `with` block closes the session on exit (also on error).
        with tf.Session(config=config, graph=graph) as sess:
            saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=100)
            runs, task_labels_dataset = train_task_sequence(model, sess, saver, datasets, AWA_attr, CLASSES_PER_TASK, args.cross_validate_mode,
                    args.train_single_epoch, args.do_sampling, args.is_herding, args.mem_size*CLASSES_PER_TASK*num_tasks, args.train_iters,
                    args.batch_size, args.num_runs, args.init_checkpoint, args.online_cross_val, args.random_seed)
        time_spent = time.time() - time_start
        print('Time spent: {}'.format(time_spent))

    return runs, task_labels_dataset


def main():
    """
    Create the model(s) and run continual-learning training on split AWA.

    Reads the command-line arguments and loads the split AWA dataset together
    with its attribute matrix. Then, for every importance method and every
    (synaptic strength, learning rate) pair to evaluate, it builds a fresh
    TensorFlow graph, trains it with `train_task_sequence`, and writes the
    results under `args.log_dir`.

    Raises:
        ValueError: if the requested architecture, importance measure or
            optimizer is not supported.
    """

    # Get the CL arguments
    args = get_arguments()

    # Initialize the random seed of numpy
    np.random.seed(args.random_seed)

    # Check if the network architecture is valid
    if args.arch not in VALID_ARCHS:
        raise ValueError("Network architecture %s is not supported!"%(args.arch))

    # Check if the method to compute importance is valid
    if args.imp_method not in MODELS:
        raise ValueError("Importance measure %s is undefined!"%(args.imp_method))

    # Check if the optimizer is valid
    if args.optim not in VALID_OPTIMS:
        raise ValueError("Optimizer %s is undefined!"%(args.optim))

    # Create log directories to store the results
    if not os.path.exists(args.log_dir):
        os.makedirs(args.log_dir)
        print('Log directory %s created!'%(args.log_dir))

    # Online cross-validation uses K_FOR_CROSS_VAL tasks; otherwise the
    # remaining NUM_TASKS - K_FOR_CROSS_VAL tasks are used.
    if args.online_cross_val:
        num_tasks = K_FOR_CROSS_VAL
    else:
        num_tasks = NUM_TASKS - K_FOR_CROSS_VAL

    # Load the split AWA dataset
    data_labs = [np.arange(TOTAL_CLASSES)]
    datasets, AWA_attr = construct_split_awa(data_labs, args.data_dir, AWA_TRAIN_LIST, AWA_VAL_LIST, AWA_TEST_LIST, IMG_HEIGHT, IMG_WIDTH, attr_file=AWA_ATTR_LIST)
    # Zero the attribute rows on the side of the split not used by this run:
    # rows from K_FOR_CROSS_VAL*CLASSES_PER_TASK on when cross-validating
    # online, the rows before that otherwise.
    num_cv_classes = K_FOR_CROSS_VAL*CLASSES_PER_TASK
    if args.online_cross_val:
        AWA_attr[num_cv_classes:] = 0
    else:
        AWA_attr[:num_cv_classes] = 0

    print('Attributes: {}'.format(np.sum(AWA_attr, axis=1)))

    # Hyper-parameter grids are used whenever any cross-validation mode is on;
    # cross_validate_mode additionally sweeps over every method in MODELS.
    tune = args.online_cross_val or args.cross_validate_mode
    models_list = MODELS if args.cross_validate_mode else [args.imp_method]

    for imp_method in models_list:
        synap_stgth_list, learning_rate_list = _hyperparameter_grid(imp_method, tune, args.learning_rate)

        for synap_stgth in synap_stgth_list:
            for lr in learning_rate_list:
                # Generate the experiment key and store the meta data in a file
                exper_meta_data = {'ARCH': args.arch,
                    'DATASET': 'SPLIT_AWA',
                    'HYBRID': args.set_hybrid,
                    'NUM_RUNS': args.num_runs,
                    'TRAIN_SINGLE_EPOCH': args.train_single_epoch,
                    'IMP_METHOD': imp_method,
                    'SYNAP_STGTH': synap_stgth,
                    'FISHER_EMA_DECAY': args.fisher_ema_decay,
                    'FISHER_UPDATE_AFTER': args.fisher_update_after,
                    'OPTIM': args.optim,
                    'LR': lr,
                    'BATCH_SIZE': args.batch_size,
                    'EPS_MEMORY': args.do_sampling,
                    'MEM_SIZE': args.mem_size,
                    'IS_HERDING': args.is_herding}
                # NOTE(review): the learning rate is not part of the id and the
                # timestamp has minute resolution, so several learning rates run
                # within the same minute share one id.
                experiment_id = "SPLIT_AWA_HERDING_%r_HYB_%r_%s_%r_%s_%s_%s_%r_%s-"%(args.is_herding, args.set_hybrid, args.arch, args.train_single_epoch, imp_method,
                        str(synap_stgth).replace('.', '_'),
                        str(args.batch_size), args.do_sampling, str(args.mem_size)) + datetime.datetime.now().strftime("%y-%m-%d-%H-%M")
                snapshot_experiment_meta_data(args.log_dir, experiment_id, exper_meta_data)

                runs, task_labels_dataset = _train_and_evaluate(args, imp_method, synap_stgth, lr, num_tasks, datasets, AWA_attr)

                if args.cross_validate_mode:
                    # If cross-validation flag is enabled, store the stuff in a text file
                    cross_validate_dump_file = os.path.join(args.log_dir, 'SPLIT_AWA_HYBRID_%s_%s'%(imp_method, args.optim) + '.txt')
                    with open(cross_validate_dump_file, 'a') as f:
                        f.write('HERDING: {} \t ARCH: {} \t LR:{} \t LAMBDA: {} \t ACC: {}\n'.format(args.is_herding, args.arch, lr, synap_stgth, runs))
                else:
                    # Store all the results in one dictionary to process later
                    exper_acc = dict(mean=runs)
                    exper_labels = dict(labels=task_labels_dataset)
                    # Store the experiment output to a file
                    snapshot_experiment_eval(args.log_dir, experiment_id, exper_acc)
                    snapshot_task_labels(args.log_dir, experiment_id, exper_labels)