void RNN::Setup()

in src/model/layer/rnn.cc [29:83]


void RNN::Setup(const Shape& in_sample, const LayerConf &conf) {
  Layer::Setup(in_sample, conf);

  RNNConf rnn_conf = conf.rnn_conf();
  hidden_size_ = rnn_conf.hidden_size();
  CHECK_GT(hidden_size_, 0u);
  num_stacks_ = rnn_conf.num_stacks();
  CHECK_GT(num_stacks_, 0u);
  input_size_ = Product(in_sample);
  CHECK_GT(input_size_, 0u);
  dropout_ = rnn_conf.dropout();  // drop probability
  CHECK_GE(dropout_, 0);

  input_mode_ = ToLowerCase(rnn_conf.input_mode());
  CHECK(input_mode_ == "linear" || input_mode_ == "skip")
      << "Input mode of " << input_mode_ << " is not supported; Please use "
      << "'linear' and 'skip'";

  direction_ = ToLowerCase(rnn_conf.direction());
  if (direction_ == "unidirectional")
    num_directions_ = 1;
  else if (direction_ == "bidirectional")
    num_directions_ = 2;
  else
    LOG(FATAL) << "Direction of " << direction_
      << " is not supported; Please use unidirectional or bidirectional";

  rnn_mode_ = ToLowerCase(rnn_conf.rnn_mode());
  if (rnn_mode_ == "lstm") {
    has_cell_ = true;
  } else if (rnn_mode_ !="relu" && rnn_mode_ != "tanh" && rnn_mode_ != "gru") {
    LOG(FATAL) << "RNN memory unit (mode) of " << rnn_mode_
      << " is not supported Please use 'relu', 'tanh', 'lstm' and 'gru'";
  }
  // mult is the number of parameter sets per stack:
  // 1 for relu/tanh, 4 for lstm, 3 for gru; doubled for bidirectional RNNs
  int mult = 1;
  if (rnn_mode_ == "relu" || rnn_mode_ == "tanh")
    mult *= 1;
  else if (rnn_mode_ == "lstm")
    mult *= 4;
  else if (rnn_mode_ == "gru")
    mult *= 3;
  if (direction_ == "bidirectional")
    mult *= 2;

  size_t weight_size = 0;
  for (size_t i = 0; i < num_stacks_; i++) {
    // Per parameter set and stack: a hidden x input weight matrix, a
    // hidden x hidden recurrence matrix, and two bias vectors (the "+ 2").
    // The first stack reads the input sample; later stacks read the
    // previous stack's hidden state.
    size_t dim = hidden_size_ * (in_sample[0] + hidden_size_ + 2);
    if (i > 0)
      dim = hidden_size_ * (hidden_size_ + hidden_size_ + 2);
    weight_size += mult * dim;
  }
  weight_.Resize(Shape{weight_size});
}
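
For a concrete sense of the weight-size formula, the snippet below mirrors the computation outside of SINGA. It is a minimal standalone sketch; the bidirectional LSTM configuration (hidden_size = 64, input size = 32, num_stacks = 2) is illustrative and not taken from the source.

#include <cstddef>
#include <iostream>

int main() {
  // Illustrative values only; mirrors the formula in RNN::Setup above.
  const size_t hidden_size = 64, input_size = 32, num_stacks = 2;

  size_t mult = 4;  // 4 parameter sets for "lstm" (1 for relu/tanh, 3 for gru)
  mult *= 2;        // doubled for "bidirectional"

  size_t weight_size = 0;
  for (size_t i = 0; i < num_stacks; i++) {
    size_t in = (i == 0) ? input_size : hidden_size;  // input width of this stack
    weight_size += mult * hidden_size * (in + hidden_size + 2);
  }
  std::cout << "weight_size = " << weight_size << "\n";  // prints 116736
  return 0;
}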