recipes/local_prior_match/src/runtime/Eval.h (27 lines of code) (raw):
/**
* Copyright (c) Facebook, Inc. and its affiliates.
* All rights reserved.
*
* This source code is licensed under the MIT-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include <flashlight/flashlight.h>
#include "criterion/criterion.h"
#include "data/W2lDataset.h"
#include "libraries/common/Dictionary.h"
#include "recipes/models/local_prior_match/src/runtime/Logging.h"
#include "runtime/runtime.h"
namespace w2l {
void evalOutput(
const af::array& op,
const af::array& target,
std::map<std::string, fl::EditDistanceMeter>& mtr,
const Dictionary& tgtDict,
std::shared_ptr<SequenceCriterion> criterion);
void evalDataset(
std::shared_ptr<fl::Module> ntwrk,
std::shared_ptr<SequenceCriterion> crit,
std::shared_ptr<W2lDataset> testds,
SSLDatasetMeters& mtrs,
const Dictionary& dict);
void runEval(
std::shared_ptr<fl::Module> network,
std::shared_ptr<SequenceCriterion> criterion,
const std::unordered_map<std::string, std::shared_ptr<W2lDataset>>& ds,
SSLTrainMeters& meters,
const Dictionary& dict);
} // namespace w2l