recipes/slimIPL/10h_supervised.cpp (65 lines of code) (raw):

/** * Copyright (c) Facebook, Inc. and its affiliates. * * This source code is licensed under the MIT-style license found in the * LICENSE file in the root directory of this source tree. */ #include <iostream> #include "flashlight/fl/contrib/modules/modules.h" #include "flashlight/fl/flashlight.h" #include "flashlight/fl/nn/modules/modules.h" class MyModel : public fl::Container { public: MyModel(int64_t nFeature, int64_t nLabel) { convFrontend_->add( std::make_shared<fl::View>(af::dim4(-1, 1, nFeature, 0))); // Time x 1 x nFeature x Batch std::vector<int> lnDims = {0, 1, 2}; convFrontend_->add(std::make_shared<fl::LayerNorm>(lnDims)); convFrontend_->add( std::make_shared<fl::Conv2D>(nFeature, 1536, 7, 1, 3, 1, -1, 0, 1, 1)); convFrontend_->add(std::make_shared<fl::GatedLinearUnit>(2)); convFrontend_->add(std::make_shared<fl::Dropout>(0.5)); convFrontend_->add(std::make_shared<fl::Reorder>(2, 0, 3, 1)); // nFeature x Time x Batch x 1 add(convFrontend_); for (int trIdx = 0; trIdx < 36; trIdx++) { auto layer = std::make_shared<fl::Transformer>( 768, 192, 3072, 4, 920, 0.5, 0.5, false, false); transformers_.push_back(layer); add(layer); } linear_ = std::make_shared<fl::Linear>(768, nLabel); add(linear_); } std::vector<fl::Variable> forward( const std::vector<fl::Variable>& input) override { auto out = input[0]; auto xSizes = input[1].array(); // expected input dims T x C x 1 x B int T = out.dims(0), B = out.dims(3); auto inputMaxSize = af::tile(af::max(xSizes), 1, B); af::array inputNotPaddedSize = af::ceil(xSizes * T / inputMaxSize); auto padMask = af::iota(af::dim4(T, 1), af::dim4(1, B)) < af::tile(inputNotPaddedSize, T, 1); out = convFrontend_->forward(out); for (int trIdx = 0; trIdx < transformers_.size(); trIdx++) { out = transformers_[trIdx]->forward({out, fl::noGrad(padMask)}).front(); } out = linear_->forward(out); return {out.as(input[0].type())}; } std::string prettyString() const override { std::ostringstream ss; ss << "Model: "; ss << convFrontend_->prettyString() << "\n"; for (int trIdx = 0; trIdx < 36; trIdx++) { ss << transformers_[trIdx]->prettyString() << "\n"; } ss << linear_->prettyString() << "\n"; return ss.str(); } private: MyModel() = default; std::shared_ptr<fl::Sequential> convFrontend_{ std::make_shared<fl::Sequential>()}; std::vector<std::shared_ptr<fl::Transformer>> transformers_; std::shared_ptr<fl::Linear> linear_; FL_SAVE_LOAD_WITH_BASE(fl::Container, convFrontend_, transformers_, linear_) }; extern "C" fl::Module* createModule(int64_t nFeature, int64_t nLabel) { auto m = std::make_unique<MyModel>(nFeature, nLabel); return m.release(); } CEREAL_REGISTER_TYPE(MyModel)