def ABX()

in eval/eval_ABX.py [0:0]


def _abx_score_from_group_confusion(group_confusion, context_dims):
    """Collapse a sparse group-confusion tensor into a single ABX score.

    Args:
        group_confusion: sparse COO tensor returned by
            ``abx_g.get_abx_scores_dtw_on_group``.
        context_dims: dimension index, or list of dimension indices, that is
            summed out of ``group_confusion`` before the remaining
            dimensions are averaged.

    Returns:
        The resulting score as a Python float.
    """
    n_entries = group_confusion._values().size(0)

    # Same sparsity pattern as the input with every stored value set to 1:
    # summing it counts how many stored entries fall into each remaining
    # cell. torch.sparse_coo_tensor replaces the deprecated legacy
    # torch.sparse.LongTensor constructor; the values stay float.
    entry_counter = torch.sparse_coo_tensor(
        group_confusion._indices(),
        torch.ones(n_entries, dtype=torch.float),
        group_confusion.size())
    divisor_context = torch.sparse.sum(entry_counter,
                                       dim=context_dims).to_dense()

    summed_confusion = torch.sparse.sum(group_confusion,
                                        dim=context_dims).to_dense()
    # reduce_sparse_data is defined elsewhere in this file; it combines the
    # summed values with their per-cell counts.
    context_confusion = reduce_sparse_data(summed_confusion, divisor_context)

    # Dimension 0 is reduced next (presumably the speaker axis -- TODO
    # confirm): for each remaining cell, count the slices along it that
    # received at least one entry.
    divisor_speaker = (divisor_context > 0).sum(dim=0)
    phone_confusion = reduce_sparse_data(context_confusion.sum(dim=0),
                                         divisor_speaker)

    # Average over the cells that were observed at least once.
    return (phone_confusion.sum() / (divisor_speaker > 0).sum()).item()


def ABX(feature_function,
        path_item_file,
        seq_list,
        distance_mode,
        step_feature,
        modes,
        cuda=False,
        max_x_across=5,
        max_size_group=30):
    """Compute the ABX scores for the requested modes.

    Args:
        feature_function: forwarded to ``abx_it.ABXFeatureLoader``.
        path_item_file: forwarded to ``abx_it.ABXFeatureLoader``.
        seq_list: forwarded to ``abx_it.ABXFeatureLoader``.
        distance_mode: name resolved to a distance function through
            ``abx_g.get_distance_function_from_name``.
        step_feature: forwarded to ``abx_it.ABXFeatureLoader``.
        modes: container tested for membership of ``'within'`` and
            ``'across'``; only the modes it contains are evaluated.
        cuda: when true, ``cuda()`` is called on the feature dataset.
        max_x_across: value assigned to the iterator's ``max_x`` attribute
            in ``'across'`` mode.
        max_size_group: passed to the dataset's ``get_iterator`` in both
            modes.

    Returns:
        A dict mapping each evaluated mode name (``'within'`` and/or
        ``'across'``) to its score as a Python float. Empty when ``modes``
        contains neither name.
    """
    # ABX dataset (the final positional flag of the loader is set to True)
    abx_dataset = abx_it.ABXFeatureLoader(path_item_file, seq_list,
                                          feature_function, step_feature,
                                          True)

    if cuda:
        abx_dataset.cuda()

    # Distance function
    distance_function = abx_g.get_distance_function_from_name(distance_mode)

    # Output
    scores = {}

    # ABX within
    if 'within' in modes:
        print("Computing ABX within speakers...")
        abx_iterator = abx_dataset.get_iterator('within', max_size_group)
        group_confusion = abx_g.get_abx_scores_dtw_on_group(
            abx_iterator, distance_function, abx_iterator.symmetric)
        # A single trailing dimension (index 3) is summed out in this mode.
        scores['within'] = _abx_score_from_group_confusion(group_confusion,
                                                           3)
        print(f"...done. ABX within : {scores['within']}")

    # ABX across
    if 'across' in modes:
        print("Computing ABX across speakers...")
        abx_iterator = abx_dataset.get_iterator('across', max_size_group)
        abx_iterator.max_x = max_x_across
        group_confusion = abx_g.get_abx_scores_dtw_on_group(
            abx_iterator, distance_function, abx_iterator.symmetric)
        # Two trailing dimensions (indices 3 and 4) are summed out here.
        scores['across'] = _abx_score_from_group_confusion(group_confusion,
                                                           [3, 4])
        print(f"...done. ABX across : {scores['across']}")

    return scores