// Excerpt: bool SinkCode::run(Function *, const CompilationContext &)
// from lib/Optimizer/GraphOptimizer/GraphOptimizer.cpp, lines 648-1246.

/// Code-sinking pass: pushes shape-manipulating nodes (Transpose, Reshape)
/// below compute nodes (BatchNorm, ReLU, Clip, arithmetic ops, Quantize,
/// Concat, ...) so that consecutive shape operations become adjacent and can
/// later be merged or eliminated by other graph optimizations. Each rewrite
/// builds replacement nodes and redirects all users of the old result; dead
/// originals are left for DCE.
/// \p F is the function being optimized. \p cctx is not used by this pass.
/// \returns true iff the graph was changed.
bool SinkCode::run(Function *F, const CompilationContext &cctx) {
  LOG_SCOPE(F->getLogContext(), getName());
  bool changed = false;
  auto &nodes = F->getNodes();
  // For each node:
  for (auto &N : nodes) {
    auto *node = &N;

    // Sink Reshape/Transpose below BatchNormalization.
    if (auto *BN = dyn_cast<BatchNormalizationNode>(node)) {

      // Sink Reshape below BatchNormalization.
      if (auto *RS = dyn_cast<ReshapeNode>(BN->getInput())) {
        auto inDims = RS->getInput().dims();
        auto outDims = RS->getResult().dims();
        unsigned_t newChannelIdx;

        // Skip sinking if: 1) the input was less than 3 dimensions,
        // because we need spatial dimensions in addition to batch
        // and channel or 2) if it is 3D data because the reshapes
        // are deliberately introduced to phrase 3D BatchNormalization
        // as a 2D one.
        if (RS->getInput().dims().size() < 3 ||
            RS->getInput().dims().size() == 5) {
          continue;
        }

        // Reshape should not change the BatchNorm ChannelIdx dimensions.
        // Only NH[W]C and NCH[W] are allowed.
        if (BN->getChannelIdx() == outDims.size() - 1) {
          // Channel-last: the channel must remain the last dimension of the
          // pre-reshape input as well.
          if (inDims[inDims.size() - 1] != outDims[outDims.size() - 1]) {
            continue;
          }
          newChannelIdx = inDims.size() - 1;
        } else if (BN->getChannelIdx() == 1) {
          // Note: index '1' maps to C in NCH[W] layout.
          if (inDims[1] != outDims[1]) {
            continue;
          }
          newChannelIdx = 1;
        } else {
          continue;
        }

        // Reshape should not change the batch dimension.
        if (inDims[0] != outDims[0]) {
          continue;
        }

        // Build a BatchNorm operating on the pre-reshape shape (keeping the
        // original quantization parameters), then re-apply the reshape on top.
        auto bnOutTy = F->getParent()->uniqueTypeWithNewShape(
            BN->getResult().getType(), RS->getInput().getType());
        auto rsInputType = RS->getInput().getType();
        glow::TypeRef outTy = F->getParent()->uniqueTypeWithNewShape(
            bnOutTy, rsInputType->dims());
        auto *newBN = F->createBatchNormalization(
            BN->getName(), outTy, RS->getInput(), BN->getBias(), BN->getScale(),
            BN->getMean(), BN->getVar(), newChannelIdx, BN->getEpsilon(),
            BN->getMomentum());
        // NOTE(review): unlike the Transpose case below, the replacement nodes
        // here do not inherit node->getPredicate() -- confirm intentional.
        auto *newRS = F->createReshape(RS->getName(), newBN,
                                       RS->getResult().dims(), RS->getLayout());
        BN->getResult().replaceAllUsesOfWith(newRS);
        changed = true;
        continue;
      }

      // Sink Transpose below batch normalization nodes:
      if (auto *TR = dyn_cast<TransposeNode>(BN->getInput())) {

        // Figure out where we transposed the channel index for batch
        // normalization. The shuffle maps output dimension i to input
        // dimension shuffle[i], so the pre-transpose channel is shuffle[idx].
        unsigned_t idx = BN->getChannelIdx();
        unsigned_t newChannelIdx = TR->getShuffle()[idx];

        auto bnOutTy = BN->getResult().getType();
        auto trInputType = TR->getInput().getType();
        glow::TypeRef outTy = F->getParent()->uniqueTypeWithNewShape(
            bnOutTy, trInputType->dims());

        auto *NewBN = F->createBatchNormalization(
            BN->getName(), outTy, TR->getInput(), BN->getBias(), BN->getScale(),
            BN->getMean(), BN->getVar(), newChannelIdx, BN->getEpsilon(),
            BN->getMomentum());
        NewBN->setPredicate(node->getPredicate());
        auto *newTR = F->createTranspose(TR->getName(), NewBN, TR->getShuffle(),
                                         TR->getLayout());
        newTR->setPredicate(node->getPredicate());

        BN->getResult().replaceAllUsesOfWith(newTR);
        changed = true;
        continue;
      }
    }

    if (auto *RL = dyn_cast<ReluNode>(node)) {
      // Sink Transpose below batch RELU nodes.
      if (auto *TR = dyn_cast<TransposeNode>(RL->getInput())) {
        // Keep the same quantization parameters for ReLU output, but
        // change the shape to appropriate value.
        auto reluOutTy = F->getParent()->uniqueTypeWithNewShape(
            RL->getResult().getType(), TR->getInput().getType());
        auto *NRL = F->createRELU(RL->getName(), TR->getInput(), reluOutTy);
        NRL->setPredicate(node->getPredicate());
        auto *newTR = F->createTranspose(TR->getName(), NRL, TR->getShuffle(),
                                         TR->getLayout());
        newTR->setPredicate(node->getPredicate());
        RL->getResult().replaceAllUsesOfWith(newTR);
        changed = true;
        continue;
      }

      // Sink Clip below RELU nodes.
      if (ClipNode *CN = dyn_cast<ClipNode>(RL->getInput())) {
        assert(!RL->getResult().getType()->isQuantizedType() &&
               "Relu(Clip) means Relu should not be quantized.");
        // Rebuild as Clip(Relu(x)) with the lower clip bound clamped to >= 0,
        // since the ReLU already forces the result to be non-negative.
        ReluNode *newRL = F->createRELU(RL->getName(), CN->getInput());
        ClipNode *newCN =
            F->createClip(CN->getName(), newRL->getResult(),
                          std::max(CN->getMin(), 0.0f), CN->getMax());
        RL->getResult().replaceAllUsesOfWith(newCN);
        changed = true;
        continue;
      }
    }

    // Sink Transpose below Clip nodes.
    if (auto *CL = dyn_cast<ClipNode>(node)) {
      auto *TR = dyn_cast<TransposeNode>(CL->getInput());

      if (!TR) {
        continue;
      }

      // Keep the same quantization parameters for Clip output, but
      // change the shape to appropriate value.
      auto clipOutTy = F->getParent()->uniqueTypeWithNewShape(
          CL->getResult().getType(), TR->getInput().getType());
      auto *NCL = F->createClip(CL->getName(), TR->getInput(), clipOutTy,
                                CL->getMin(), CL->getMax());
      NCL->setPredicate(node->getPredicate());
      auto *newTR = F->createTranspose(TR->getName(), NCL, TR->getShuffle());
      newTR->setPredicate(node->getPredicate());
      CL->getResult().replaceAllUsesOfWith(newTR);
      changed = true;
      continue;
    }

    // Sink Transpose below LeakyRelu nodes.
    if (auto *LR = dyn_cast<LeakyReluNode>(node)) {
      auto *TR = dyn_cast<TransposeNode>(LR->getInput());
      if (!TR) {
        continue;
      }
      auto newLROutTy = F->getParent()->uniqueTypeWithNewShape(
          LR->getResult().getType(), TR->getInput().getType());
      auto *newLR = F->createLeakyRELU(LR->getName(), newLROutTy,
                                       TR->getInput(), LR->getAlpha());
      newLR->setPredicate(node->getPredicate());
      auto *newTR = F->createTranspose(TR->getName(), newLR, TR->getShuffle());
      newTR->setPredicate(node->getPredicate());
      LR->getResult().replaceAllUsesOfWith(newTR);
      changed = true;
      continue;
    }

    // Sink Transpose below PRelu with Splat.
    if (auto *PN = dyn_cast<PReluNode>(node)) {
      auto *TR = dyn_cast<TransposeNode>(PN->getInput());
      if (!TR) {
        continue;
      }
      // Only handled when the slope is a Splat: a splat is invariant under
      // transposition, so it can simply be re-created with the new shape.
      auto *SN = dyn_cast<SplatNode>(PN->getSlope());
      if (!SN) {
        continue;
      }
      auto newSNOutTy = F->getParent()->uniqueTypeWithNewShape(
          SN->getResult().getType(), TR->getInput().getType());
      auto newPNOutTy = F->getParent()->uniqueTypeWithNewShape(
          PN->getResult().getType(), TR->getInput().getType());
      auto *newSN = F->createSplat(SN->getName(), newSNOutTy, SN->getValue());
      auto *newPN =
          F->createPRELU(PN->getName(), TR->getInput(), newSN, newPNOutTy);
      auto *newTR = F->createTranspose(TR->getName(), newPN, TR->getShuffle());
      newPN->setPredicate(node->getPredicate());
      newTR->setPredicate(node->getPredicate());
      PN->getResult().replaceAllUsesOfWith(newTR);
      changed = true;
      continue;
    }

    // Sink Transpose below Sigmoid nodes.
    if (auto *SI = dyn_cast<SigmoidNode>(node)) {
      auto *TR = dyn_cast<TransposeNode>(SI->getInput());

      if (!TR) {
        continue;
      }

      auto *NSI = F->createSigmoid(SI->getName(), TR->getInput());
      NSI->setPredicate(node->getPredicate());
      auto *newTR = F->createTranspose(TR->getName(), NSI, TR->getShuffle(),
                                       TR->getLayout());
      newTR->setPredicate(node->getPredicate());
      SI->getResult().replaceAllUsesOfWith(newTR);
      changed = true;
      continue;
    }

    // Sink Transpose below Tile nodes.
    if (auto *TN = dyn_cast<TileNode>(node)) {
      auto *TR = dyn_cast<TransposeNode>(TN->getInput());

      if (!TR) {
        continue;
      }

      // The tile axis must be remapped to its pre-transpose position.
      auto *newTN = F->createTile(TN->getName(), TR->getInput(), TN->getCount(),
                                  TR->getShuffle()[TN->getAxis()]);
      newTN->setPredicate(node->getPredicate());
      auto *newTR = F->createTranspose(TR->getName(), newTN, TR->getShuffle(),
                                       TR->getLayout());
      newTR->setPredicate(node->getPredicate());
      TN->getResult().replaceAllUsesOfWith(newTR);
      changed = true;
      continue;
    }

    // Sink Transpose below Pad nodes.
    if (auto *padNode = dyn_cast<PadNode>(node)) {
      auto *transposeNode = dyn_cast<TransposeNode>(padNode->getInput());

      if (!transposeNode) {
        continue;
      }

      // The transpose shuffle specifies the source dimension.
      // When sinking Transpose below Pad, shuffle describes the target
      // dimension.
      auto shuffle = transposeNode->getShuffle();

      // Shuffle the Pad output type and the padding attribute.
      // The pads vector stores the leading pads in its first numDims entries
      // and the trailing pads in its last numDims entries (hence the
      // '+ numDims' indexing below).
      auto outPadType = padNode->getResult().getType();
      auto outPadShape = outPadType->dims();
      auto pads = padNode->getPads();
      size_t numDims = outPadShape.size();
      std::vector<dim_t> newOutPadShape(numDims);
      std::vector<int> newPads(2 * numDims);
      for (size_t i = 0; i < outPadShape.size(); i++) {
        // Post-transpose dimension i lives at pre-transpose position
        // shuffle[i]; move the shape and both pad amounts accordingly.
        newOutPadShape[shuffle[i]] = outPadShape[i];
        newPads[shuffle[i]] = pads[i];
        newPads[shuffle[i] + numDims] = pads[i + numDims];
      }

      // New pad
      auto newOutPadType =
          F->getParent()->uniqueTypeWithNewShape(outPadType, newOutPadShape);
      auto *NewPadNode = F->createPad(
          padNode->getName(), transposeNode->getInput(), newOutPadType,
          padNode->getMode(), newPads, padNode->getValue());
      NewPadNode->setPredicate(node->getPredicate());
      auto *newTransposeNode =
          F->createTranspose(transposeNode->getName(), NewPadNode, shuffle);
      newTransposeNode->setPredicate(node->getPredicate());
      padNode->getResult().replaceAllUsesOfWith(newTransposeNode);
      changed = true;
      continue;
    }

    // Sink Transpose below Tanh nodes.
    if (auto *TN = dyn_cast<TanhNode>(node)) {
      auto *TR = dyn_cast<TransposeNode>(TN->getInput());

      if (!TR) {
        continue;
      }

      auto *NTN = F->createTanh(TN->getName(), TR->getInput());
      NTN->setPredicate(node->getPredicate());
      auto *newTR = F->createTranspose(TR->getName(), NTN, TR->getShuffle(),
                                       TR->getLayout());
      newTR->setPredicate(node->getPredicate());
      TN->getResult().replaceAllUsesOfWith(newTR);
      changed = true;
      continue;
    }

    // Remove 'identity' transpose operations.
    if (auto *TR = dyn_cast<TransposeNode>(node)) {
      auto mask = TR->getShuffle();

      if (isIdentityShuffle(mask)) {
        TR->getResult().replaceAllUsesOfWith(TR->getInput());
        changed = true;
        continue;
      }
    }

    // Merge consecutive Transpose operations.
    if (auto *TR1 = dyn_cast<TransposeNode>(node)) {
      auto *TR2 = dyn_cast<TransposeNode>(TR1->getInput());

      if (!TR2) {
        continue;
      }

      auto mask1 = TR1->getShuffle();
      auto mask2 = TR2->getShuffle();
      assert(mask1.size() == mask2.size() && "Invalid mask size");

      llvm::SmallVector<unsigned_t, max_tensor_dimensions> newMask;
      newMask.resize(mask2.size());

      // Compose the two shuffles: TR1's output dim i reads TR2's output dim
      // mask1[i], which in turn reads the original input dim mask2[mask1[i]].
      for (size_t i = 0, end = mask2.size(); i < end; i++) {
        newMask[i] = mask2[mask1[i]];
      }

      // NOTE(review): "tranpose" typo in the node name is preserved as-is.
      auto *newTR = F->createTranspose("tranpose", TR2->getInput(), newMask);
      TR1->getResult().replaceAllUsesOfWith(newTR->getResult());
      changed = true;
      continue;
    }

    if (auto *CS = dyn_cast<ChannelShuffleNode>(node)) {
      // Sink Transpose below ChannelShuffle.
      if (sinkTranposeBelowChannelShuffle(F, CS)) {
        changed = true;
        continue;
      }
    }

    // Sink Transpose below Arithmetic nodes.
    if (node->isArithmetic()) {
      TransposeNode *LTR =
          dyn_cast<TransposeNode>(node->getNthInput(ArithmeticNode::LHSIdx));
      TransposeNode *RTR =
          dyn_cast<TransposeNode>(node->getNthInput(ArithmeticNode::RHSIdx));

      if (!LTR || !RTR) {
        // If one of the sides is a splat, it can be seen as
        // transpose (splat'). Similarly, if one of the sides is a Constant,
        // it can be seen as tranpose (Constant').
        // In either case we synthesize a matching Transpose for the missing
        // side so that both LTR and RTR are non-null below.
        if (isa<SplatNode>(node->getNthInput(ArithmeticNode::LHSIdx)) && RTR) {
          // Build splat' for LHS.
          auto *SN =
              dyn_cast<SplatNode>(node->getNthInput(ArithmeticNode::LHSIdx));
          auto *NS = F->createSplat("splat", RTR->getInput().getType(),
                                    SN->getValue());
          LTR = F->createTranspose("transpose", NS, RTR->getShuffle(),
                                   RTR->getLayout());
          changed = true;
        } else if (isa<SplatNode>(node->getNthInput(ArithmeticNode::RHSIdx)) &&
                   LTR) {
          // Build splat' for RHS.
          auto *SN =
              dyn_cast<SplatNode>(node->getNthInput(ArithmeticNode::RHSIdx));
          auto *NS = F->createSplat("splat", LTR->getInput().getType(),
                                    SN->getValue());
          RTR = F->createTranspose("transpose", NS, LTR->getShuffle(),
                                   LTR->getLayout());
          changed = true;
        } else if (isa<Constant>(node->getNthInput(ArithmeticNode::LHSIdx)) &&
                   RTR) {
          // Build Constant' for for LHS.
          auto *C = cast<Constant>(node->getNthInput(ArithmeticNode::LHSIdx));
          LTR = insertMatchingTransposeAfterConstant(F, C, RTR);
          changed = true;
        } else if (isa<Constant>(node->getNthInput(ArithmeticNode::RHSIdx)) &&
                   LTR) {
          // Build Constant' for for RHS.
          auto *C = cast<Constant>(node->getNthInput(ArithmeticNode::RHSIdx));
          RTR = insertMatchingTransposeAfterConstant(F, C, LTR);
          changed = true;
        } else {
          continue;
        }
      }
      // The masks of the transposes on both sizes must match.
      if (LTR->getShuffle() != RTR->getShuffle()) {
        continue;
      }

      Node *newAN = nullptr;

// Arithmetic ops get a result type reshaped to the pre-transpose dims;
// boolean comparison ops derive their (boolean) result type themselves.
#define ARITHMETIC_CASE(NODE_NAME_)                                            \
  case glow::Kinded::Kind::NODE_NAME_##NodeKind:                               \
    newAN =                                                                    \
        F->create##NODE_NAME_(node->getName(),                                 \
                              F->getParent()->uniqueTypeWithNewShape(          \
                                  node->getType(ArithmeticNode::ResultIdx),    \
                                  LTR->getInput().getType()),                  \
                              LTR->getInput(), RTR->getInput());               \
    break;

#define BOOLEAN_OP_CASE(NODE_NAME_)                                            \
  case glow::Kinded::Kind::NODE_NAME_##NodeKind:                               \
    newAN = F->create##NODE_NAME_(node->getName(), LTR->getInput(),            \
                                  RTR->getInput());                            \
    break;

      switch (node->getKind()) {
        ARITHMETIC_CASE(Add);
        ARITHMETIC_CASE(Mul);
        ARITHMETIC_CASE(Sub);
        ARITHMETIC_CASE(Div);
        ARITHMETIC_CASE(Fmod);
        ARITHMETIC_CASE(Max);
        ARITHMETIC_CASE(Min);
        ARITHMETIC_CASE(Pow);
        BOOLEAN_OP_CASE(CmpLTE);
        BOOLEAN_OP_CASE(CmpEQ);
      default:
        llvm_unreachable("Unhandled node");
      }
#undef BOOLEAN_OP_CASE
#undef ARITHMETIC_CASE

      newAN->setPredicate(node->getPredicate());
      changed = true;
      auto *newTR = F->createTranspose(LTR->getName(), newAN, LTR->getShuffle(),
                                       LTR->getLayout());
      newTR->setPredicate(node->getPredicate());
      node->getNthResult(ArithmeticNode::ResultIdx).replaceAllUsesOfWith(newTR);
    }

    if (auto *Q = dyn_cast<QuantizeNode>(node)) {
      // Sink TransposeNode below QuantizedNode.
      // getTransposeNodeWithAllSameUserKind is a file-local helper; presumably
      // it only returns the Transpose when all its users are Quantize nodes,
      // so the sink does not duplicate work -- confirm against its definition.
      if (auto *TR = getTransposeNodeWithAllSameUserKind(Q->getInput())) {
        auto newQType = F->getParent()->uniqueTypeWithNewShape(
            Q->getResult().getType(), TR->getInput().dims());
        auto *newQ = F->createQuantize(Q->getName(), TR->getInput(), newQType);
        auto *newTR = F->createTranspose(TR->getName(), newQ, TR->getShuffle());
        Q->getResult().replaceAllUsesOfWith(newTR);
        changed = true;
        continue;
      }

      // Sink Reshape below Quantize.
      if (auto *RN = dyn_cast<ReshapeNode>(Q->getInput())) {
        auto newQType = F->getParent()->uniqueTypeWithNewShape(
            Q->getResult().getType(), RN->getInput().dims());
        auto *newQ = F->createQuantize(Q->getName(), RN->getInput(), newQType);
        auto *newRN = F->createReshape(RN->getName(), newQ,
                                       RN->getResult().dims(), RN->getLayout());
        Q->getResult().replaceAllUsesOfWith(newRN->getResult());
        changed = true;
        continue;
      }
    }

    // Sink Reshape below ConvertTo.
    if (auto *CN = dyn_cast<ConvertToNode>(node)) {
      auto *RN = dyn_cast<ReshapeNode>(CN->getInput());
      if (!RN) {
        continue;
      }
      auto *newCN = F->createConvertTo(CN->getName(), RN->getInput(),
                                       CN->getResult().getElementType());
      auto *newRN = F->createReshape(RN->getName(), newCN,
                                     RN->getResult().dims(), RN->getLayout());
      CN->getResult().replaceAllUsesOfWith(newRN->getResult());
      changed = true;
      continue;
    }

    // Sink TransposeNode below DequantizedNode.
    // If it doesn't work out it will be re-sinked later.
    if (auto *D = dyn_cast<DequantizeNode>(node)) {
      auto *TR = dyn_cast<TransposeNode>(D->getInput());
      if (!TR) {
        continue;
      }

      auto newDType = F->getParent()->uniqueTypeWithNewShape(
          D->getResult().getType(), TR->getInput().dims());
      auto *newD = F->createDequantize(D->getName(), TR->getInput(), newDType);
      auto *newTR = F->createTranspose(TR->getName(), newD, TR->getShuffle());
      D->getResult().replaceAllUsesOfWith(newTR);
      changed = true;
      // No 'continue' needed: a Dequantize node cannot match any of the
      // remaining checks in this loop body.
    }

    // Sink Transpose below RescaleQuantized.
    // Potentially exposes opportunity to be combined up with Convolution.
    // If it doesn't work out it will be re-sinked later.
    if (auto *RQ = dyn_cast<RescaleQuantizedNode>(node)) {
      auto *TR = dyn_cast<TransposeNode>(RQ->getInput());
      if (!TR) {
        continue;
      }

      auto newRQType = F->getParent()->uniqueTypeWithNewShape(
          RQ->getResult().getType(), TR->getInput().getType());
      auto *newRQ =
          F->createRescaleQuantized(RQ->getName(), TR->getInput(), newRQType);
      auto *newTR = F->createTranspose(TR->getName(), newRQ, TR->getShuffle(),
                                       TR->getLayout());
      RQ->getResult().replaceAllUsesOfWith(newTR);
      changed = true;
    }

    if (auto *CN = dyn_cast<ConcatNode>(node)) {
      const Node *firstNode = CN->getInputs().front().getNode();
      // Sink RELU below batch concat nodes.
      if (firstNode->getKind() == Kinded::Kind::ReluNodeKind) {
        llvm::SmallVector<NodeValue, 6> CNInputs;
        for (auto &input : CN->getInputs()) {
          auto *inputRL = dyn_cast<ReluNode>(input);
          if (!inputRL) {
            // Not all inputs are ReLU; abort (size check below fails).
            break;
          }
          CNInputs.push_back(inputRL->getInput());
        }

        if (CNInputs.size() == CN->getNumInputs()) {
          // All inputs were ReLUs: concat the raw inputs and apply one ReLU.
          auto *newCN = F->createConcat(CN->getName(), CNInputs, CN->getDim());
          newCN->setPredicate(node->getPredicate());
          auto name = CN->getNthInput(0).getNode()->getName();
          auto *newRL = F->createRELU(name, newCN, CN->getResult().getType());
          newRL->setPredicate(node->getPredicate());
          CN->getResult().replaceAllUsesOfWith(newRL);
          changed = true;
        }
        continue;
      }

      // Sink Transpose below concat nodes.
      if (firstNode->getKind() == Kinded::Kind::TransposeNodeKind) {
        llvm::SmallVector<NodeValue, 6> transVector;
        auto inputIter = CN->getInputs().begin();
        auto *firstInput = dyn_cast<TransposeNode>(*inputIter);
        if (!firstInput) {
          continue;
        }

        transVector.push_back(firstInput->getInput());
        auto shuffle = firstInput->getShuffle();
        // If the shuffle masks don't agree or not all inputs are Transpose then
        // bail out.
        for (++inputIter; inputIter != CN->getInputs().end(); ++inputIter) {
          auto *tTR = dyn_cast<TransposeNode>(*inputIter);
          if (!tTR || tTR->getShuffle() != shuffle) {
            break;
          }
          transVector.push_back(tTR->getInput());
        }

        if (transVector.size() != CN->getNumInputs()) {
          continue;
        }

        // Figure out where we transposed the channel index for batch
        // normalization. (Here: remap the concat dimension to its
        // pre-transpose position, same mapping as the BatchNorm case.)
        unsigned_t idx = CN->getDim();
        unsigned_t newChannelIdx = shuffle[idx];

        auto *newCN =
            F->createConcat(CN->getName(), transVector, newChannelIdx);
        newCN->setPredicate(node->getPredicate());
        auto *newTR = F->createTranspose(firstInput->getName(), newCN,
                                         firstInput->getShuffle(),
                                         firstInput->getLayout());
        newTR->setPredicate(node->getPredicate());
        CN->getResult().replaceAllUsesOfWith(newTR);
        changed = true;
        continue;
      }
    }
  } // For all nodes in the graph.

  // Transformations to sink nodes below Slice. Outlined into a separate loop to
  // prevent Transpose/Slice sinking to affect them.
  for (auto &N : nodes) {
    auto *node = &N;
    // Sink BatchNorm below Slice.
    if (auto *SN = dyn_cast<SliceNode>(node)) {
      auto *BN = dyn_cast<BatchNormalizationNode>(SN->getInput());
      // Require a single use so the BatchNorm on the full tensor becomes dead
      // after the rewrite (otherwise BN would be computed twice).
      if (!BN || !BN->hasOneUse()) {
        continue;
      }

      // Don't support sinking below Slice which affects depth.
      if (SN->getInput().dims()[BN->getChannelIdx()] !=
          SN->getResult().dims()[BN->getChannelIdx()]) {
        continue;
      }

      // Slice the BatchNorm's input first, then normalize the smaller tensor.
      auto newSNType = F->getParent()->uniqueTypeWithNewShape(
          BN->getInput().getType(), SN->getResult().dims());
      auto *newSN = F->createSlice(SN->getName(), BN->getInput(),
                                   SN->getStart(), newSNType);
      auto *newBN = F->createBatchNormalization(
          BN->getName(), SN->getResult().getType(), newSN, BN->getBias(),
          BN->getScale(), BN->getMean(), BN->getVar(), BN->getChannelIdx(),
          BN->getEpsilon(), BN->getMomentum());
      SN->getResult().replaceAllUsesOfWith(newBN);
      changed = true;
    }
  }

  return changed;
}