in DataScience/Experimentation.py [0:0]
def main(args):
    """Run a staged hyperparameter / interaction search for a Vowpal Wabbit model.

    Stages:
      1. Verify the ``vw`` executable is on PATH.
      2. Resolve shared / action / marginal namespaces (auto-detected from the
         log file for any group the user left empty).
      3. Run the base command, then a grid search over cb_types, marginals,
         learning rates, l1 regularization, and power_t.
      4. Unless ``args.only_hp``: brute-force interaction combinations up to
         ``args.q_bruteforce_terms`` terms, then grow interactions greedily
         until ``args.q_greedy_stop`` rounds pass without improvement, then
         re-run the hyperparameter grid search.
      5. Optionally generate prediction files for every best command found.

    Args:
        args: Parsed CLI namespace. Attributes read here: base_command,
            file_path, shared_namespaces, action_namespaces,
            marginal_namespaces, auto_lines, cb_types, learning_rates,
            regularizations, power_t_rates, n_proc, only_hp,
            q_bruteforce_terms, q_greedy_stop, generate_predictions.

    Side effects: prints progress to stdout, may prompt on stdin, and calls
    ``sys.exit()`` when vw is missing or the user declines to start.
    """
    try:
        vw_version = check_output(['vw','--version'], stderr=DEVNULL, universal_newlines=True)
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt / SystemExit
        # still propagate instead of being reported as a missing executable.
        print("Error: Vowpal Wabbit executable not found. Please install and add it to your path")
        sys.exit()

    # Additional processing of inputs not covered by above.
    # endswith() is safe when base_command is empty; indexing [-1] would raise IndexError.
    base_command = args.base_command + ('-d ' if args.base_command.endswith(' ') else ' -d ') + args.file_path

    # Shared and Action Features
    shared_features = set(args.shared_namespaces)
    action_features = set(args.action_namespaces)
    marginal_features = set(args.marginal_namespaces)
    if not (shared_features and action_features and marginal_features):
        # Auto-detect whichever namespace groups are still empty by scanning the log file.
        shared_features, action_features, marginal_features = identify_namespaces(args.file_path, args.auto_lines, shared_features, action_features, marginal_features)

    print("\n*********** SETTINGS ******************")
    print('Log file size: {:.3f} MB'.format(os.path.getsize(args.file_path)/(1024**2)))
    print()
    print('Using VW version {}'.format(vw_version.strip()))
    print('Base command: {}'.format(base_command))
    print()
    print('Shared feature namespaces: ' + str(shared_features))
    print('Action feature namespaces: ' + str(action_features))
    print('Marginal feature namespaces: ' + str(marginal_features))
    print()
    print('cb_types: ['+', '.join(args.cb_types)+']')
    print('learning rates: ['+', '.join(map(str,args.learning_rates))+']')
    print('l1 regularization rates: ['+', '.join(map(str,args.regularizations))+']')
    print('power_t rates: ['+', '.join(map(str,args.power_t_rates))+']')
    print()
    # +1 on marginals: the grid also tries "no marginal feature".
    print('Hyper-parameters grid size: ',(len(marginal_features)+1)*len(args.cb_types)*len(args.learning_rates)*len(args.regularizations)*len(args.power_t_rates))
    print('Parallel processes: {}'.format(args.n_proc))
    print("***************************************")
    # Only prompt for confirmation when run as a script, not when imported and called.
    if __name__ == '__main__' and input('Press ENTER to start (any other key to exit)...' ) != '':
        sys.exit()

    # Use only first character of namespace for interactions
    shared_features = {x[0] for x in shared_features}
    action_features = {x[0] for x in action_features}
    marginal_features = {x[0] for x in marginal_features}

    t0 = datetime.now()
    best_commands = []

    print('\nRunning the base command...')
    if ' -c ' not in base_command and os.path.exists(args.file_path+'.cache'):
        # Fixed typo ("unnesessarily") in the user-facing warning.
        input('Warning: Cache file found, but not used (-c not in CLI): this is unnecessarily slow. Press to continue anyway...')
    best_command = run_experiment(Command(base_command))
    best_commands.append(best_command)
    best_commands[-1].name = 'Base'

    # cb_types, marginal, regularization, learning rates, and power_t rates grid search
    command_list = get_hp_command_list(base_command, best_command, args.cb_types, marginal_features, args.learning_rates, args.regularizations, args.power_t_rates)
    print('\nTesting {} different hyperparameters...'.format(len(command_list)))
    results = run_experiment_set(command_list, args.n_proc)
    # results is assumed sorted best-first; keep the winner only if it beats the incumbent.
    if results[0].loss < best_command.loss:
        best_command = results[0]
        best_commands.append(results[0])
        best_commands[-1].name = 'Hyper1'

    if not args.only_hp:
        # TODO: Which namespaces to ignore
        # Test all combinations up to q_bruteforce_terms
        possible_interactions = set()
        for shared_feature in shared_features:
            for action_feature in action_features:
                interaction = '{0}{1}'.format(shared_feature, action_feature)
                possible_interactions.add(interaction)
        command_list = []
        # i == 0 re-tests the no-interaction baseline alongside all combinations.
        for i in range(args.q_bruteforce_terms+1):
            for interaction_list in itertools.combinations(possible_interactions, i):
                command = Command(base_command, clone_from=best_command, interaction_list=interaction_list)
                command_list.append(command)
        print('\nTesting {} different interactions (brute-force phase)...'.format(len(command_list)))
        results = run_experiment_set(command_list, args.n_proc)
        if results[0].loss < best_command.loss:
            best_command = results[0]
            best_commands.append(results[0])
            best_commands[-1].name = 'Inter-len'+str(len(results[0].interaction_list))

        # Build greedily on top of the best parameters found above (stop when no improvements for q_greedy_stop consecutive rounds)
        print('\nTesting interactions (greedy phase)...')
        temp_interaction_list = set(best_command.interaction_list)
        rounds_without_improvements = 0
        while rounds_without_improvements < args.q_greedy_stop:
            command_list = []
            for shared_feature in shared_features:
                for action_feature in action_features:
                    interaction = '{0}{1}'.format(shared_feature, action_feature)
                    if interaction in temp_interaction_list:
                        continue
                    command = Command(base_command, clone_from=best_command, interaction_list=temp_interaction_list.union({interaction})) # union() keeps temp_interaction_list unchanged
                    command_list.append(command)
            if len(command_list) == 0:
                # Every possible interaction is already in the set; nothing left to try.
                break
            results = run_experiment_set(command_list, args.n_proc)
            if results[0].loss < best_command.loss:
                best_command = results[0]
                best_commands.append(results[0])
                best_commands[-1].name = 'Inter-len'+str(len(results[0].interaction_list))
                rounds_without_improvements = 0
            else:
                rounds_without_improvements += 1
            # Always advance to this round's best set — even without a global
            # improvement — so the next round explores new candidates instead
            # of retrying the same ones.
            temp_interaction_list = set(results[0].interaction_list)

        # cb_types, marginal, regularization, learning rates, and power_t rates grid search
        command_list = get_hp_command_list(base_command, best_command, args.cb_types, marginal_features, args.learning_rates, args.regularizations, args.power_t_rates)
        print('\nTesting {} different hyperparameters...'.format(len(command_list)))
        results = run_experiment_set(command_list, args.n_proc)
        if results[0].loss < best_command.loss:
            best_command = results[0]
            best_commands.append(results[0])
            best_commands[-1].name = 'Hyper2'

    # TODO: Repeat above process of tuning parameters and interactions until convergence / no more improvements.
    t1 = datetime.now()
    print("\n\n*************************")
    # Subtracting the microseconds component prints a whole-second duration.
    print("Best parameters found after {}:".format((t1-t0)-timedelta(microseconds=(t1-t0).microseconds)))
    best_command.prints()
    print("*************************")

    if args.generate_predictions:
        _ = generate_predictions_files(args.file_path, best_commands, args.n_proc)
        t2 = datetime.now()
        print('Predictions Generation Time:',(t2-t1)-timedelta(microseconds=(t2-t1).microseconds))
        print('Total Elapsed Time:',(t2-t0)-timedelta(microseconds=(t2-t0).microseconds))