# def main()
#
# in ai-judge/inference.py [0:0]


def main():
    """Score the given midi tracks and pickle the results.

    Steps, in order:

    1. Parse the command-line arguments (``midi_file`` / ``midi_folder`` /
       ``save_path``) via ``parse_args``.
    2. Load the reference object stored in ``bach_pdf.p``.
    3. Load the tracks from a single file or from a folder.
    4. Run every batch through ``get_ks_dist`` and ``get_class_score`` and
       merge the two results with ``combine_score``.
    5. Dump a ``{file name: score}`` dict to ``score_results.p`` inside
       ``save_path``.

    Exits through ``sys.exit`` when no input source is given or when no
    track could be loaded.
    """
    batch_size = 8
    num_workers = 2

    # Parse arguments first so a usage error is reported before any file
    # is read from disk.
    args = parse_args()
    save_path = args.save_path

    # load bach pdf
    # NOTE(review): pickle.load can execute arbitrary code; only run this
    # against a 'bach_pdf.p' that comes from a trusted source.
    with open('bach_pdf.p', 'rb') as pdf_file:
        bach_pdf = pickle.load(pdf_file)

    # load midi files
    if args.midi_file:
        midi, f_names = load_midi(args.midi_file)
    elif args.midi_folder:
        path = args.midi_folder
        print(path)
        midi, f_names = load_midi_folder(path)
        print('Loaded files from folder.')
        print('Files: ', f_names)
    else:
        # Without this branch `midi` would be unbound and the next line
        # would fail with an opaque NameError.
        sys.exit('No midi input given: provide a midi file or a midi folder')
    print(f'Midi file is loaded! There are {len(midi)} tracks')

    # If no track
    if len(midi) == 0:
        print('No midi files to calculate score')
        sys.exit()

    gpu = torch.cuda.is_available()

    # load model
    model = load_model(gpu)
    print('Model is loaded!')

    # transform and merge track
    # A new axis is inserted at position 1 (presumably the channel
    # dimension the model expects -- TODO confirm against load_model).
    midi_ready = np.array(midi, dtype=float)
    midi_ready = torch.tensor(np.expand_dims(midi_ready, axis=1))
    midi_loader = torch.utils.data.DataLoader(
        midi_ready,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False,  # keep batch order aligned with f_names
    )

    final_scores = []

    for midi_set in midi_loader:
        # calculate classification score and KS distance
        # The KS distance is computed from the CPU/numpy copy, taken before
        # the batch is (optionally) moved to the GPU.
        np_midi_set = midi_set.numpy()
        ks_dist = get_ks_dist(np_midi_set, bach_pdf)
        if gpu:
            midi_set = midi_set.to('cuda')
        cls_score = get_class_score(midi_set, model)
        cls_score = cls_score.cpu().detach().numpy()

        # combine score together to output the final score
        combined_score = combine_score(cls_score, ks_dist)
        final_scores.extend(combined_score)

    # Write the final scores to the output pickle file
    output = dict(zip(f_names, final_scores))
    os.makedirs(save_path, exist_ok=True)
    # os.path.join keeps the file inside save_path even when the argument
    # was passed without a trailing path separator.
    result_file = os.path.join(save_path, 'score_results.p')
    with open(result_file, 'wb') as out_file:
        pickle.dump(output, out_file)
    print(f'Calculation is completed! Scores are saved at {save_path}')