recipes/joint_training_vox_populi/cpc/MTLLoss.cpp (75 lines of code) (raw):

// (c) Facebook, Inc. and its affiliates. Confidential and proprietary. #include <fstream> #include <stdexcept> #include <vector> #include "flashlight/fl/autograd/Variable.h" #include "flashlight/fl/common/Logging.h" #include "flashlight/pkg/speech/runtime/runtime.h" // MTL Loss flags DEFINE_string( mtllossmapping, "", "Path to the MTL loss label mapping. Leave empty to not activate"); DEFINE_double(mtllossweight, 0.5, "Weight given to the MTL Loss"); typedef std::map<std::string, unsigned int> Mapping; namespace asr4real { Mapping loadMapping(const std::string& filename) { std::ifstream file(filename); unsigned int index = 0; Mapping output = Mapping(); if (!file.is_open()) { throw std::invalid_argument("Cannot open " + filename); } std::string line; while (std::getline(file, line)) { output[line] = index; index++; } file.close(); return output; } unsigned int getMapIndexFromFileID( const std::string& fileid, const Mapping& i_map) { size_t pos_delim = fileid.find('#'); if (pos_delim == std::string::npos) { throw std::invalid_argument("Cannot parse " + fileid); } pos_delim++; const std::string token = fileid.substr(pos_delim, fileid.size() - pos_delim); return i_map.at(token); } af::array buildIndexLabels( const std::vector<af::array>& batch, const Mapping& i_map, const int batch_size) { af::array targets_lid_(1, batch_size, 1); for (int bIdx = 0; bIdx < batch_size; bIdx++) { auto filename = fl::pkg::speech::readSampleIds(batch[fl::pkg::speech::kSampleIdx]) .at(bIdx); targets_lid_(0, bIdx, 0) = getMapIndexFromFileID(filename, i_map); } return targets_lid_; } fl::Variable mtl_step( fl::Variable& enc_out, std::shared_ptr<fl::Linear> crit, std::shared_ptr<fl::Dataset> trainset, const Mapping& i_map, const int batchIdx) { const int timedim = 1; const int batchdim = 2; const int featdim = 0; const int batchsz = enc_out.dims(batchdim); const int timesz = enc_out.dims(timedim); const std::vector<af::array>& batch = trainset->get(batchIdx); const fl::Variable target_ids_ = fl::Variable(buildIndexLabels(batch, i_map, batchsz), false); enc_out = fl::reorder(enc_out, featdim, timedim, batchdim); fl::Variable predictions = crit->forward(enc_out); predictions = fl::mean(predictions.as(f32), std::vector<int>{1}).as(f32); predictions = fl::logSoftmax(predictions, 0); fl::Variable loss = fl::categoricalCrossEntropy( predictions.as(f32), target_ids_, fl::ReduceMode::NONE); return fl::reorder(loss, 1, 0, 2); } }; // namespace asr4real