recipes/self_training/pseudo_labeling/AnalyzeDataset.cpp (40 lines of code) (raw):
/**
* Copyright (c) Facebook, Inc. and its affiliates.
* All rights reserved.
*
* This source code is licensed under the MIT-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <iostream>
#include <string>
#include <vector>
#include "recipes/self_training/pseudo_labeling/Dataset.h"
#include <flashlight/fl/meter/EditDistanceMeter.h>
#include <gflags/gflags.h>
// Command-line flags. Both default to the empty string and are passed
// unmodified to filter::dataset::createTranscriptDictFromFile() in main().
// --infile: path of the lst file holding the pseudo-labeled transcripts.
DEFINE_string(infile, "", "Input path for pseudo-labeled lst file");
// --groundtruthfile: path of the lst file holding the reference transcripts
// that the pseudo-labels are scored against.
DEFINE_string(groundtruthfile, "", "Input path for ground truth lst file");
/**
 * Compares a pseudo-labeled dataset against its ground truth and prints
 * summary statistics to stdout:
 *   - number of samples in each set and their ratio,
 *   - accumulated duration of each set and their ratio,
 *   - word error rate of the pseudo-labels measured against the ground truth.
 *
 * Inputs are taken from the --infile and --groundtruthfile flags.
 *
 * Prediction samples whose id has no entry in the ground truth dictionary are
 * still counted in the sample/duration totals but are excluded from the WER;
 * each one is reported on stderr.
 *
 * @return 0 on completion.
 */
int main(int argc, char** argv) {
  gflags::ParseCommandLineFlags(&argc, &argv, false);

  auto predictionDict =
      filter::dataset::createTranscriptDictFromFile(FLAGS_infile);
  auto groundtruthDict =
      filter::dataset::createTranscriptDictFromFile(FLAGS_groundtruthfile);

  fl::EditDistanceMeter wer;
  size_t predictionDuration{0};
  size_t unmatchedPredictions{0};
  for (const auto& sample : predictionDict) {
    const auto& prediction = sample.second;
    predictionDuration += prediction->getDuration();

    // Look the id up with find() rather than operator[]: operator[] would
    // insert an empty entry for an unknown id, which would then be
    // dereferenced below and would also inflate groundtruthDict.size().
    const auto match = groundtruthDict.find(sample.first);
    if (match == groundtruthDict.end()) {
      ++unmatchedPredictions;
      std::cerr << "Warning: no ground truth for sample '" << sample.first
                << "'; excluded from WER\n";
      continue;
    }
    wer.add(prediction->transcriptWords, match->second->transcriptWords);
  }
  if (unmatchedPredictions > 0) {
    std::cerr << "Warning: " << unmatchedPredictions
              << " prediction sample(s) had no ground truth entry\n";
  }

  size_t groundtruthDuration{0};
  for (const auto& sample : groundtruthDict) {
    groundtruthDuration += sample.second->getDuration();
  }

  // Num samples
  std::cout << "Prediction samples / groundtruth samples = "
            << predictionDict.size() << " / " << groundtruthDict.size()
            << " = "
            << static_cast<float>(predictionDict.size()) /
          static_cast<float>(groundtruthDict.size())
            << '\n';

  // Duration
  // NOTE(review): the raw totals are labelled "(seconds)" but the conversion
  // to hours divides by 60 * 60 * 1000, which implies milliseconds. Confirm
  // the unit returned by getDuration() and correct the label if needed.
  const double unitsPerHour = 60.0 * 60.0 * 1000.0;
  std::cout << "Prediction duration / groundtruth duration = "
            << predictionDuration << " / " << groundtruthDuration
            << " (seconds) = " << predictionDuration / unitsPerHour << " / "
            << groundtruthDuration / unitsPerHour << " (hours) = "
            << static_cast<float>(predictionDuration) /
          static_cast<float>(groundtruthDuration)
            << '\n';

  // WER
  std::cout << "WER is " << wer.value()[0] << '\n';
  return 0;
}