static void setupBasicParallelizationConfigs()

in lib/Backends/NNPI/NNPI.cpp [886:1207]


static void setupBasicParallelizationConfigs(
    Function *F, llvm::DenseMap<Node *, size_t> &numChunks,
    llvm::DenseMap<Node *, ParallelTransformKind> &parOpts,
    int32_t numParallelChunks) {
  // Process nodes in post order so we always visit a node's inputs before the
  // node itself; this lets parallelization decisions depend on whether a
  // parent was already parallelized.
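  // (E.g., an FC is visited before any Relu/Clip that consumes it, so those
  // consumers can mirror the FC's split below.)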
  GraphPostOrderVisitor visitor(*F);
  for (auto *node : visitor.getPostOrder()) {
    // Find all FC layers to split
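    // Prefer a model-parallel split of the output dimension when the output
    // is wide (K >= 512); otherwise fall back to a data-parallel split of the
    // batch dimension when the batch is large (M >= 256).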
    if (auto *FC = llvm::dyn_cast<FullyConnectedNode>(node)) {
      size_t K = FC->getWeights().dims()[1];
      if (K >= 512) {
        parOpts[FC] = ParallelTransformKind::Model;
        numChunks[FC] =
            std::min((size_t)numParallelChunks, FC->getResult().dims()[1]);
        continue;
      }
      size_t M = FC->getInput().dims()[0];
      if (M >= 256) {
        parOpts[FC] = ParallelTransformKind::Data;
        numChunks[FC] =
            std::min((size_t)numParallelChunks, FC->getResult().dims()[0]);
        continue;
      }
    }

    // Relu parallelization.
    // If a Relu follows an FC, mirror the FC's split so that the two fuse.
    // Otherwise, use data parallelism.
    if (auto *R = llvm::dyn_cast<ReluNode>(node)) {
      // For Relus that aren't preceded by an FC, use data parallelism if the
      // input was parallelized.
      Node *inputNode = R->getInput().getNode();
      auto *FC = llvm::dyn_cast<FullyConnectedNode>(inputNode);
      if (!FC) {
        if (numChunks.find(inputNode) != numChunks.end() &&
            parOpts.find(inputNode) != parOpts.end()) {
          parOpts[R] = ParallelTransformKind::Data;
          numChunks[R] =
              std::min((size_t)numParallelChunks, R->getResult().dims()[0]);
        }
        continue;
      }

      // Otherwise, mirror FC split.
      if (R->getInput().dims().size() < 2) {
        continue;
      }
      size_t K = R->getInput().dims()[1];
      if (K >= 512) {
        parOpts[R] = ParallelTransformKind::Model;
        numChunks[R] =
            std::min((size_t)numParallelChunks, R->getResult().dims()[1]);
        continue;
      }
      size_t M = R->getInput().dims()[0];
      if (M >= 256) {
        parOpts[R] = ParallelTransformKind::Data;
        numChunks[R] =
            std::min((size_t)numParallelChunks, R->getResult().dims()[0]);
        continue;
      }
    }

    if (auto *R = llvm::dyn_cast<RescaleQuantizedNode>(node)) {
      // For Rescales that are preceded by FC or Relu, mirror their
      // parallelization.
      Node *inputNode = R->getInput().getNode();
      if (!llvm::isa<FullyConnectedNode>(inputNode) &&
          !llvm::isa<ReluNode>(inputNode)) {
        continue;
      }
      auto numChunksIt = numChunks.find(inputNode);
      auto parOptsIt = parOpts.find(inputNode);
      if (numChunksIt == numChunks.end() || parOptsIt == parOpts.end()) {
        continue;
      }
      parOpts[R] = parOptsIt->second;
      numChunks[R] = numChunksIt->second;
      continue;
    }

    // Split Gelu layers in data parallel fashion
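    // numChunks needs no min() clamp here: the M >= numParallelChunks guard
    // already guarantees at least one row per chunk.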
    if (auto *GL = llvm::dyn_cast<GeluNode>(node)) {
      size_t M = GL->getInput().dims()[0];
      if (M >= (size_t)numParallelChunks) {
        parOpts[GL] = ParallelTransformKind::Data;
        numChunks[GL] = numParallelChunks;
        continue;
      }
    }

    // Split transpose layers in data parallel fashion
    if (auto *TP = llvm::dyn_cast<TransposeNode>(node)) {
      parOpts[TP] = ParallelTransformKind::Data;
      numChunks[TP] =
          std::min((size_t)numParallelChunks, TP->getResult().dims()[0]);
      continue;
    }

    // Split Quantize layers in data parallel fashion
    if (auto *QN = llvm::dyn_cast<QuantizeNode>(node)) {
      parOpts[QN] = ParallelTransformKind::Data;
      numChunks[QN] =
          std::min((size_t)numParallelChunks, QN->getResult().dims()[0]);
      continue;
    }

    // Split Dequantize layers in data parallel fashion
    if (auto *DQN = llvm::dyn_cast<DequantizeNode>(node)) {
      parOpts[DQN] = ParallelTransformKind::Data;
      numChunks[DQN] =
          std::min((size_t)numParallelChunks, DQN->getResult().dims()[0]);
      continue;
    }

    // Split Tile layers
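    // Split along the axis that is not being tiled: a Tile along axis 0 is
    // model-parallelized along dim 1, while a Tile along axis 1 is
    // data-parallelized along the batch (dim 0).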
    if (auto *TN = llvm::dyn_cast<TileNode>(node)) {
      if (TN->getAxis() == 0) {
        if (TN->getInput().dims().size() < 2) {
          continue;
        }
        size_t N = TN->getInput().dims()[1];
        if (N < 256) {
          continue;
        }
        parOpts[TN] = ParallelTransformKind::Model;
        numChunks[TN] =
            std::min((size_t)numParallelChunks, TN->getResult().dims()[1]);
      } else if (TN->getAxis() == 1) {
        if (TN->getInput().dims().size() < 2) {
          continue;
        }
        size_t M = TN->getInput().dims()[0];
        if (M < 256) {
          continue;
        }
        parOpts[TN] = ParallelTransformKind::Data;
        numChunks[TN] =
            std::min((size_t)numParallelChunks, TN->getResult().dims()[0]);
      }
      continue;
    }

    // Split BatchedReduceAdd layers
    if (auto *BR = llvm::dyn_cast<BatchedReduceAddNode>(node)) {
      size_t N = BR->getResult().dims()[0];
      if (N < 64) {
        continue;
      }
      parOpts[BR] = ParallelTransformKind::Data;
      numChunks[BR] =
          std::min((size_t)numParallelChunks, BR->getResult().dims()[0]);
      continue;
    }

    // Split LayerNorm layers in data parallel fashion
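    // Gate on the largest non-batch dimension: only split when it is >= 1024.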
    if (auto *LN = llvm::dyn_cast<LayerNormalizationNode>(node)) {
      if (LN->getInput().dims().size() < 2) {
        continue;
      }
      size_t NIdx = getMaxDimOtherThanBatch(LN->getInput().dims());
      size_t N = LN->getInput().dims()[NIdx];
      if (N < 1024) {
        continue;
      }
      parOpts[LN] = ParallelTransformKind::Data;
      numChunks[LN] =
          std::min((size_t)numParallelChunks, LN->getResult().dims()[0]);
      continue;
    }

    // Split BMM layers in data parallel fashion
    if (auto *BMM = llvm::dyn_cast<BatchMatMulNode>(node)) {
      parOpts[BMM] = ParallelTransformKind::Data;
      numChunks[BMM] =
          std::min((size_t)numParallelChunks, BMM->getResult().dims()[0]);
      continue;
    }

    // Split MatMul layers in Model parallel fashion
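    // Model parallelism splits along the result's second dimension, so clamp
    // the chunk count by it.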
    if (auto *MM = llvm::dyn_cast<MatMulNode>(node)) {
      parOpts[MM] = ParallelTransformKind::Model;
      numChunks[MM] =
          std::min((size_t)numParallelChunks, MM->getResult().dims()[1]);
      continue;
    }

    // Split Tanh layers in data parallel fashion
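    // Only split large activations: 2D inputs need dim 1 >= 1792, and 3D
    // inputs need dims[1] * dims[2] >= 2048.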
    if (auto *TH = llvm::dyn_cast<TanhNode>(node)) {
      if (TH->getInput().dims().size() < 2) {
        continue;
      }
      if (TH->getInput().dims().size() == 2) {
        size_t N = TH->getInput().dims()[1];
        if (N < 1792) {
          continue;
        }
        parOpts[TH] = ParallelTransformKind::Data;
        numChunks[TH] =
            std::min((size_t)numParallelChunks, TH->getResult().dims()[0]);
        continue;
      } else if (TH->getInput().dims().size() == 3) {
        size_t N = TH->getInput().dims()[1];
        size_t K = TH->getInput().dims()[2];
        if (N * K < 2048) {
          continue;
        }
        parOpts[TH] = ParallelTransformKind::Data;
        numChunks[TH] =
            std::min((size_t)numParallelChunks, TH->getResult().dims()[0]);
        continue;
      }
    }

    // Split Add layers in data parallel fashion
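    // Uses the same size thresholds as Tanh above.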
    if (auto *AD = llvm::dyn_cast<AddNode>(node)) {
      if (AD->getLHS().dims().size() < 2) {
        continue;
      }
      if (AD->getLHS().dims().size() == 2) {
        size_t N = AD->getLHS().dims()[1];
        if (N < 1792) {
          continue;
        }
        parOpts[AD] = ParallelTransformKind::Data;
        numChunks[AD] =
            std::min((size_t)numParallelChunks, AD->getResult().dims()[0]);
        continue;
      } else if (AD->getLHS().dims().size() == 3) {
        size_t N = AD->getLHS().dims()[1];
        size_t K = AD->getLHS().dims()[2];
        if (N * K < 2048) {
          continue;
        }
        parOpts[AD] = ParallelTransformKind::Data;
        numChunks[AD] =
            std::min((size_t)numParallelChunks, AD->getResult().dims()[0]);
        continue;
      }
    }

    // Split Swish layers in data parallel fashion
    if (auto *SW = llvm::dyn_cast<SwishNode>(node)) {
      if (SW->getInput().dims().size() < 2) {
        continue;
      }
      size_t N = SW->getInput().dims()[1];
      if (N < 512) {
        continue;
      }
      parOpts[SW] = ParallelTransformKind::Data;
      numChunks[SW] =
          std::min((size_t)numParallelChunks, SW->getResult().dims()[0]);
      continue;
    }

    // Split Mul layers in data parallel fashion
    if (auto *M = llvm::dyn_cast<MulNode>(node)) {
      if (M->getLHS().dims().size() < 2) {
        continue;
      }
      size_t N = M->getLHS().dims()[1];
      if (N < 512) {
        continue;
      }
      parOpts[M] = ParallelTransformKind::Data;
      numChunks[M] =
          std::min((size_t)numParallelChunks, M->getResult().dims()[0]);
      continue;
    }

    // Split Sigmoid layers in data parallel fashion
    if (auto *S = llvm::dyn_cast<SigmoidNode>(node)) {
      if (S->getInput().dims().size() < 2) {
        continue;
      }
      size_t N = S->getInput().dims()[1];
      if (N < 512) {
        continue;
      }
      parOpts[S] = ParallelTransformKind::Data;
      numChunks[S] =
          std::min((size_t)numParallelChunks, S->getResult().dims()[0]);
      continue;
    }

    // Split Softmax layers in data parallel fashion
    if (auto *SM = llvm::dyn_cast<SoftMaxNode>(node)) {
      if (SM->getInput().dims().size() < 2) {
        continue;
      }
      size_t M = SM->getInput().dims()[0];
      size_t N = SM->getInput().dims()[1];
      if (N < 32 || M < 128) {
        continue;
      }
      parOpts[SM] = ParallelTransformKind::Data;
      numChunks[SM] =
          std::min((size_t)numParallelChunks, SM->getResult().dims()[0]);
      continue;
    }

    // Clip parallelization.
    // If a Clip follows a parallel op, mirror that.
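    // The chunk count is clamped by the dimension being split: dim 0 for
    // data parallelism, dim 1 for model parallelism.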
    if (auto *C = llvm::dyn_cast<ClipNode>(node)) {
      Node *inputNode = C->getInput().getNode();
      if (numChunks.find(inputNode) != numChunks.end() &&
          parOpts.find(inputNode) != parOpts.end()) {
        parOpts[C] = parOpts[inputNode];
        if (parOpts[C] == ParallelTransformKind::Data) {
          numChunks[C] =
              std::min((size_t)numChunks[inputNode], C->getResult().dims()[0]);
        } else {
          numChunks[C] =
              std::min((size_t)numChunks[inputNode], C->getResult().dims()[1]);
        }
      }
      continue;
    }
  }
}
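
These maps only record a plan; a caller then applies them with Glow's
parallelizeOps graph transform. Below is a minimal sketch of such a caller.
The wrapper name applyBasicParallelization is hypothetical, and the
parallelizeOps signature and the error-handling macros
(ASSIGN_VALUE_OR_RETURN_ERR, RETURN_ERR_IF_NOT from glow/Support/Error.h) are
assumed to match current Glow; check your tree before copying.

static Error applyBasicParallelization(Function *F,
                                       int32_t numParallelChunks) {
  llvm::DenseMap<Node *, size_t> numChunks;
  llvm::DenseMap<Node *, ParallelTransformKind> parOpts;
  setupBasicParallelizationConfigs(F, numChunks, parOpts, numParallelChunks);

  // parallelizeOps splits each planned node into numChunks[node] pieces along
  // the dimension implied by parOpts[node] and concatenates the partial
  // results back together. (Assumed signature; see
  // glow/Optimizer/GraphOptimizer/GraphOptimizer.h.)
  std::unordered_map<Node *, ConcatNode *> replacedMap;
  ASSIGN_VALUE_OR_RETURN_ERR(
      replacedMap, parallelizeOps(F, numChunks, parOpts, numParallelChunks));

  // Every node we planned for should have been split.
  RETURN_ERR_IF_NOT(numChunks.size() == replacedMap.size(),
                    "Expected all planned nodes to be parallelized");
  return Error::success();
}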