recipes/streaming_convnets/inference/inference/module/nn/TDSBlock.cpp (69 lines of code) (raw):

/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include "inference/module/nn/TDSBlock.h"

#include <cassert>
#include <sstream>
#include <stdexcept>
#include <utility>

#include "inference/module/nn/Relu.h"
#include "inference/module/nn/Residual.h"

namespace w2l {
namespace streaming {

/**
 * Builds a TDS block by appending four stages to this Sequential:
 *   1. Residual( conv -> Relu )
 *   2. layernorm1
 *   3. Residual( linear1 -> Relu -> linear2 )
 *   4. layernorm2
 *
 * Both Relu instances use `reluDataType`; both Residual wrappers use
 * `residualDataType`.
 *
 * The module parameters are sink arguments: each one is taken by value and
 * moved into the graph, avoiding an extra shared_ptr refcount round-trip.
 *
 * @throws std::invalid_argument if any module pointer is null, or if either
 *         data type is DataType::UNINITIALIZED.
 */
TDSBlock::TDSBlock(
    std::shared_ptr<Conv1d> conv,
    std::shared_ptr<LayerNorm> layernorm1,
    std::shared_ptr<Linear> linear1,
    std::shared_ptr<Linear> linear2,
    std::shared_ptr<LayerNorm> layernorm2,
    DataType reluDataType,
    DataType residualDataType)
    : reluDataType_(reluDataType), residualDataType_(residualDataType) {
  if (!conv) {
    throw std::invalid_argument(
        "TDSBlock::TDSBlock() is called with null conv.");
  }
  if (!layernorm1) {
    throw std::invalid_argument(
        "TDSBlock::TDSBlock() is called with null layernorm1.");
  }
  if (!linear1) {
    throw std::invalid_argument(
        "TDSBlock::TDSBlock() is called with null linear1.");
  }
  if (!linear2) {
    throw std::invalid_argument(
        "TDSBlock::TDSBlock() is called with null linear2.");
  }
  if (!layernorm2) {
    throw std::invalid_argument(
        "TDSBlock::TDSBlock() is called with null layernorm2.");
  }
  if (reluDataType == DataType::UNINITIALIZED) {
    throw std::invalid_argument(
        "TDSBlock::TDSBlock() is called with UNINITIALIZED reluDataType.");
  }
  if (residualDataType == DataType::UNINITIALIZED) {
    throw std::invalid_argument(
        "TDSBlock::TDSBlock() is called with UNINITIALIZED residualDataType.");
  }

  // Stage 1: residual connection around conv followed by Relu.
  auto convSeq = std::make_shared<Sequential>();
  convSeq->add(std::move(conv));
  convSeq->add(std::make_shared<Relu>(reluDataType_));
  add(std::make_shared<Residual>(std::move(convSeq), residualDataType_));

  // Stage 2: first layer norm.
  add(std::move(layernorm1));

  // Stage 3: residual connection around linear1 -> Relu -> linear2.
  auto linearSeq = std::make_shared<Sequential>();
  linearSeq->add(std::move(linear1));
  linearSeq->add(std::make_shared<Relu>(reluDataType_));
  linearSeq->add(std::move(linear2));
  add(std::make_shared<Residual>(std::move(linearSeq), residualDataType_));

  // Stage 4: second layer norm.
  add(std::move(layernorm2));
}

// Creates an empty block with both data types set to UNINITIALIZED and no
// stages added. NOTE(review): presumably used by serialization, which fills
// the members in afterwards; confirm against the header.
TDSBlock::TDSBlock()
    : reluDataType_(DataType::UNINITIALIZED),
      residualDataType_(DataType::UNINITIALIZED) {}

// Returns the Sequential base's debug string wrapped in a "TDSBlock: { ... }"
// envelope.
std::string TDSBlock::debugString() const {
  std::ostringstream ss;
  ss << "TDSBlock: { \n";
  ss << Sequential::debugString() << "\n";
  ss << "}";
  return ss.str();
}

} // namespace streaming
} // namespace w2l