def main()

in DataScience/Experimentation.py


def main(args):
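    """Search for the best VW command line for a contextual-bandit log file.

    Workflow: run args.base_command as a baseline; grid-search cb_type, marginal
    namespaces, learning rate, L1 regularization and power_t; unless args.only_hp is
    set, search quadratic interactions (brute force up to args.q_bruteforce_terms
    terms, then greedily until args.q_greedy_stop rounds pass without improvement)
    and re-run the hyperparameter grid on top of the best interactions; finally print
    the best command found and optionally generate prediction files for each stage's
    best command.
    """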
    try:
        vw_version = check_output(['vw', '--version'], stderr=DEVNULL, universal_newlines=True)
    except Exception:
        print("Error: Vowpal Wabbit executable not found. Please install it and add it to your PATH")
        sys.exit(1)

    # Append the data file to the base command via VW's -d flag (handling a possible trailing space)
    base_command = args.base_command + ('-d ' if args.base_command[-1] == ' ' else ' -d ') + args.file_path
    
    # Shared and Action Features
    shared_features = set(args.shared_namespaces)
    action_features = set(args.action_namespaces)
    marginal_features = set(args.marginal_namespaces)
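    # If any of the namespace sets was left empty, auto-detect them by scanning up to
    # args.auto_lines lines of the log file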
    if not (shared_features and action_features and marginal_features):
        shared_features, action_features, marginal_features = identify_namespaces(args.file_path, args.auto_lines, shared_features, action_features, marginal_features)

    print("\n*********** SETTINGS ******************")
    print('Log file size: {:.3f} MB'.format(os.path.getsize(args.file_path)/(1024**2)))
    print()
    print('Using VW version {}'.format(vw_version.strip()))
    print('Base command: {}'.format(base_command))
    print()
    print('Shared feature namespaces: ' + str(shared_features))
    print('Action feature namespaces: ' + str(action_features))
    print('Marginal feature namespaces: ' + str(marginal_features))
    print()
    print('cb_types: ['+', '.join(args.cb_types)+']')
    print('learning rates: ['+', '.join(map(str,args.learning_rates))+']')
    print('l1 regularization rates: ['+', '.join(map(str,args.regularizations))+']')
    print('power_t rates: ['+', '.join(map(str,args.power_t_rates))+']')
    print()
    print('Hyper-parameters grid size: ',(len(marginal_features)+1)*len(args.cb_types)*len(args.learning_rates)*len(args.regularizations)*len(args.power_t_rates))
    print('Parallel processes: {}'.format(args.n_proc))
    print("***************************************")
    if __name__ == '__main__' and input('Press ENTER to start (type anything else to exit)...') != '':
        sys.exit()

    # Use only first character of namespace for interactions
    shared_features = {x[0] for x in shared_features}
    action_features = {x[0] for x in action_features}
    marginal_features = {x[0] for x in marginal_features}

    t0 = datetime.now()
    best_commands = []

    print('\nRunning the base command...')
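    # A .cache file next to the log speeds up repeated VW runs; warn if one exists
    # but the base command does not ask VW to use it via -c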
    if ' -c ' not in base_command and os.path.exists(args.file_path+'.cache'):
        input('Warning: Cache file found, but not used (-c not in CLI): this is unnecessarily slow. Press ENTER to continue anyway...')
    best_command = run_experiment(Command(base_command))
    best_commands.append(best_command)
    best_commands[-1].name = 'Base'

    # cb_types, marginal, regularization, learning rates, and power_t rates grid search
    command_list = get_hp_command_list(base_command, best_command, args.cb_types, marginal_features, args.learning_rates, args.regularizations, args.power_t_rates)

    print('\nTesting {} different hyperparameters...'.format(len(command_list)))
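    # run_experiment_set is expected to return the tested commands sorted by loss (best first)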
    results = run_experiment_set(command_list, args.n_proc)
    if results[0].loss < best_command.loss:
        best_command = results[0]
        best_commands.append(results[0])
        best_commands[-1].name = 'Hyper1'
    
    if not args.only_hp:
        # TODO: Which namespaces to ignore
        
        # Test all combinations up to q_bruteforce_terms
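        # Candidate interactions are all shared x action namespace pairs, identified by
        # the first character of each namespace (see the truncation above)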
        possible_interactions = set()
        for shared_feature in shared_features:
            for action_feature in action_features:
                interaction = '{0}{1}'.format(shared_feature, action_feature)
                possible_interactions.add(interaction)
        
        command_list = []    
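        # i = 0 re-tests the best command with no extra interactions; larger i tries
        # every combination of that many interaction pairs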
        for i in range(args.q_bruteforce_terms+1):
            for interaction_list in itertools.combinations(possible_interactions, i):
                command = Command(base_command, clone_from=best_command, interaction_list=interaction_list)
                command_list.append(command)
        
        print('\nTesting {} different interactions (brute-force phase)...'.format(len(command_list)))
        results = run_experiment_set(command_list, args.n_proc)
        if results[0].loss < best_command.loss:
            best_command = results[0]
            best_commands.append(results[0])
            best_commands[-1].name = 'Inter-len'+str(len(results[0].interaction_list))
            
        
        # Build greedily on top of the best parameters found above (stop when no improvements for q_greedy_stop consecutive rounds)
        print('\nTesting interactions (greedy phase)...')
        temp_interaction_list = set(best_command.interaction_list)
        rounds_without_improvements = 0
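        # Each round tries adding one more shared x action interaction to the current set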
        while rounds_without_improvements < args.q_greedy_stop:
            command_list = []
            for shared_feature in shared_features:
                for action_feature in action_features:
                    interaction = '{0}{1}'.format(shared_feature, action_feature)
                    if interaction in temp_interaction_list:
                        continue
                    command = Command(base_command, clone_from=best_command, interaction_list=temp_interaction_list.union({interaction}))   # union() keeps temp_interaction_list unchanged
                    command_list.append(command)
            
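            # Every possible interaction is already in the set; nothing left to try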
            if len(command_list) == 0:
                break

            results = run_experiment_set(command_list, args.n_proc)
            if results[0].loss < best_command.loss:
                best_command = results[0]
                best_commands.append(results[0])
                best_commands[-1].name = 'Inter-len'+str(len(results[0].interaction_list))
                rounds_without_improvements = 0
            else:
                rounds_without_improvements += 1
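            # Continue the greedy search from this round's best set, even if it did not
            # beat the overall best command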
            temp_interaction_list = set(results[0].interaction_list)

        # cb_types, marginal, regularization, learning rates, and power_t rates grid search
        command_list = get_hp_command_list(base_command, best_command, args.cb_types, marginal_features, args.learning_rates, args.regularizations, args.power_t_rates)

        print('\nTesting {} different hyperparameters...'.format(len(command_list)))
        results = run_experiment_set(command_list, args.n_proc)
        if results[0].loss < best_command.loss:
            best_command = results[0]
            best_commands.append(results[0])
            best_commands[-1].name = 'Hyper2'

        # TODO: Repeat above process of tuning parameters and interactions until convergence / no more improvements.

    t1 = datetime.now()
    print("\n\n*************************")
    print("Best parameters found after {}:".format((t1-t0)-timedelta(microseconds=(t1-t0).microseconds)))
    best_command.prints()
    print("*************************")

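    # Generate a predictions file for each of the best commands found along the way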
    if args.generate_predictions:
        _ = generate_predictions_files(args.file_path, best_commands, args.n_proc)
        t2 = datetime.now()
        print('Predictions Generation Time:',(t2-t1)-timedelta(microseconds=(t2-t1).microseconds))
        print('Total Elapsed Time:',(t2-t0)-timedelta(microseconds=(t2-t0).microseconds))