TensorBase FullyConnectedLayer::Invoke()

in tensorflow_fold/llgtm/layers.cc [132:178]


TensorBase FullyConnectedLayer::Invoke(Graph* g, InputList inputs,
                                       DeviceID /*device*/) {
  DCHECK_EQ(inputs.size(), 1);
  auto& x = inputs[0].as<float>();
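  // The input is expected to be a single rank-2 tensor of shape
  // [batch_size, input_size].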
  DCHECK_EQ(x.rank(), 2);

  if (!initialized()) {
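    // First invocation: the input width is now known, so lazily create and
    // initialize this layer's parameters.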
    int input_size = x.dimension(1);

    // Scale weights according to the number of inputs to maintain constant
    // variance.  Also scale by an activation-dependent factor, empirically
    // derived from: https://arxiv.org/pdf/1412.6558v3.pdf.
    float stddev = 1.0f / sqrtf(input_size);
    switch (activation_) {
      case kLinear:
        break;
      case kRelu:
        stddev *= sqrtf(2.0f);
        break;
      case kSigmoid:   // TODO(delesley): what should this be?
        break;
      case kTanh:
        stddev *= 1.15f;
        break;
    }
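    // Example: with input_size = 256 and kRelu, stddev = sqrtf(2.0f / 256.0f)
    // ~= 0.088, i.e. weights are drawn from N(0, 2 / fan_in), which keeps the
    // post-ReLU activation variance roughly constant across layers.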

    Dimensions wdims = Dimensions(input_size, num_hidden());
    weights_ = name_space()->NewVariable<float>("weights", wdims,
        NormalRandomInitializer<float>(/*mean=*/ 0.0f, stddev));

    Dimensions bdims = Dimensions(1, num_hidden());
    bias_ = name_space()->NewVariable<float>("bias", bdims,
                                             ZerosInitializer<float>());
  }

  auto weights = g->Variable(weights_);
  auto bias = g->Variable(bias_);
  auto xm = g->Matmul(x, weights);
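  // xm has shape [batch_size, num_hidden]; broadcast the [1, num_hidden]
  // bias across the batch dimension before adding.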
  auto xmb = g->Add(xm, g->Broadcast(bias, xm.dimensions()));

  switch (activation_) {
    case kLinear:  return xmb;
    case kRelu:    return g->Relu(xmb);
    case kSigmoid: return g->Sigmoid(xmb);
    case kTanh:    return g->Tanh(xmb);
  }
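  // All Activation values are handled above, so control cannot fall through
  // to the end of the function.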
}
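
For reference, here is a minimal self-contained C++ sketch of the same
initialization rule outside the llgtm API. The enum and function names are
illustrative only, not part of the library:

#include <cmath>
#include <random>
#include <vector>

enum Activation { kLinear, kRelu, kSigmoid, kTanh };

// Draws num_inputs * num_hidden weights from N(0, stddev^2), where
// stddev = gain / sqrt(num_inputs), mirroring the scaling in Invoke() above.
std::vector<float> InitWeights(int num_inputs, int num_hidden,
                               Activation act, unsigned seed = 0) {
  float stddev = 1.0f / std::sqrt(static_cast<float>(num_inputs));
  if (act == kRelu) stddev *= std::sqrt(2.0f);  // "He" gain for ReLU.
  if (act == kTanh) stddev *= 1.15f;            // Empirical gain for tanh.
  std::mt19937 gen(seed);
  std::normal_distribution<float> dist(0.0f, stddev);
  std::vector<float> weights(num_inputs * num_hidden);
  for (float& w : weights) w = dist(gen);
  return weights;
}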