recipes/joint_training_vox_populi/cpc/SequentialBuilder.h (17 lines of code) (raw):

/* * Copyright (c) Facebook, Inc. and its affiliates. * * This source code is licensed under the MIT-style license found in the * LICENSE file in the root directory of this source tree. */ #pragma once #include "flashlight/fl/contrib/contrib.h" #include "flashlight/fl/contrib/modules/modules.h" #include "flashlight/fl/flashlight.h" #include "TransformerCPC.h" namespace w2l { namespace cpc { /** * Build a sequential module by parsing a file that * defines the model architecture. */ std::shared_ptr<fl::Sequential> buildSequentialModule( const std::string& archfile, int64_t nFeatures, int64_t nClasses); /** * Utility function for to run forward with pad masking * casting of modules happens to use pad masking for trasnfromer layers * properly. It assumes that model is constructed with * buildSequentialModule. Caveat: it is not supporting resnet block * with a transformer block in it! * TODO remove with landing plugin arch instead of arch files */ fl::Variable forwardSequentialModuleWithPadMask( const fl::Variable& input, std::shared_ptr<fl::Module> ntwrk, const af::array& inputSizes); } // namespace cpc } // namespace w2l