// recipes/joint_training_vox_populi/cpc/MTLLoss.h
// (c) Facebook, Inc. and its affiliates. Confidential and proprietary.
#pragma once
#include <map>
#include <memory>
#include <string>
#include <vector>

#include <gflags/gflags.h>

#include "flashlight/fl/contrib/modules/modules.h"
#include "flashlight/fl/nn/modules/Linear.h"
#include "flashlight/pkg/speech/common/Defines.h"
#include "flashlight/pkg/speech/criterion/SequenceCriterion.h"
/// Associates a string key with an unsigned integer label index.
using Mapping = std::map<std::string, unsigned int>;
namespace asr4real {
/**
 * Build a Mapping from the file located at `filename`.
 *
 * NOTE(review): the expected file format and the behaviour on a missing or
 * malformed file are defined by the implementation and are not visible from
 * this header - confirm against the .cpp before relying on them.
 *
 * @param filename : path of the file to read
 *
 * @return The Mapping read from `filename`
 */
Mapping loadMapping(const std::string& filename);
/**
 * Run one step of the MTL loss. A step consists of:
 *  1. fetching from the dataset the labels that belong to batch `batchIdx`;
 *  2. applying the label classifier to the input features;
 *  3. returning the categoricalCrossEntropy loss computed from 1. and 2.
 *
 * @param enc_out  : feature tensor of dimension H x Time x Batch
 * @param crit     : linear classifier of dimension H x Nlabels
 * @param trainset : input dataset
 * @param i_map    : mapping from file id to integer file label
 * @param batchIdx : batch number
 *
 * @return The loss, of dimension Batch x 1 x 1
 */
fl::Variable mtl_step(
    fl::Variable& enc_out,
    std::shared_ptr<fl::Linear> crit,
    std::shared_ptr<fl::Dataset> trainset,
    const Mapping& i_map,
    const int batchIdx);
/**
 * Resolve the label number that corresponds to a file id.
 * File ids are expected to follow the pattern fileID = "baseID#{label}".
 *
 * NOTE(review): what happens when the id does not follow this pattern, or
 * when its label is absent from `i_map`, is decided by the implementation
 * and cannot be told from this header.
 *
 * @param fileid : file id to resolve
 * @param i_map  : mapping consulted to obtain the label number
 *
 * @return The label number associated with `fileid`
 */
unsigned int getMapIndexFromFileID(
    const std::string& fileid,
    const Mapping& i_map);
/**
 * Extract the ID labels of every sample in a batch.
 * Labels are defined as follows: file_id = "baseID#{label}".
 *
 * @param batch      : batch extracted from a fl::Dataset. Each sample id
 *                     must be readable from fl::app::asr::kSampleIdx
 *                     (NOTE(review): namespace as written by the original
 *                     author - confirm it matches the Defines.h in use).
 * @param i_map      : mapping from label to label number
 * @param batch_size : size of the second dimension of the returned array
 *
 * @return An array X of shape (1, batch_size, 1) where
 *         X[0, a, 0] = label_number(batch.at(a))
 */
af::array buildIndexLabels(
    const std::vector<af::array>& batch,
    const Mapping& i_map,
    const int batch_size);
} // namespace asr4real