in lib/Backends/NNPI/NNPI.cpp [886:1207]
static void setupBasicParallelizationConfigs(
Function *F, llvm::DenseMap<Node *, size_t> &numChunks,
llvm::DenseMap<Node *, ParallelTransformKind> &parOpts,
int32_t numParallelChunks) {
  // Process nodes in post-order so we always visit a Node's inputs before the
  // Node itself; this lets a Node's parallelization depend on whether its
  // parent was parallelized.
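  // Throughout, ParallelTransformKind::Data splits a node along dim 0 (the
  // batch dimension) and ParallelTransformKind::Model splits along dim 1, so
  // numChunks is clamped to the size of the split dimension in most cases
  // below. The size thresholds are tuned heuristics, not hard requirements.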
GraphPostOrderVisitor visitor(*F);
for (auto *node : visitor.getPostOrder()) {
// Find all FC layers to split
if (auto *FC = llvm::dyn_cast<FullyConnectedNode>(node)) {
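      // FC weights are (inputChannels x outputChannels), so K is the output
      // width and M is the batch size. Wide layers get a Model split across
      // output channels; otherwise, large batches get a Data split.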
size_t K = FC->getWeights().dims()[1];
if (K >= 512) {
parOpts[FC] = ParallelTransformKind::Model;
numChunks[FC] =
std::min((size_t)numParallelChunks, FC->getResult().dims()[1]);
continue;
}
size_t M = FC->getInput().dims()[0];
if (M >= 256) {
parOpts[FC] = ParallelTransformKind::Data;
numChunks[FC] =
std::min((size_t)numParallelChunks, FC->getResult().dims()[0]);
continue;
}
}
    // Relu parallelization.
    // If a Relu follows an FC, mirror the FC split so that the two fuse.
    // Otherwise, use data parallelism.
    if (auto *R = llvm::dyn_cast<ReluNode>(node)) {
      // For Relus that aren't preceded by an FC, use data parallelism if the
      // input was parallelized.
      Node *inputNode = R->getInput().getNode();
      auto *FC = llvm::dyn_cast<FullyConnectedNode>(inputNode);
if (!FC) {
if (numChunks.find(inputNode) != numChunks.end() &&
parOpts.find(inputNode) != parOpts.end()) {
parOpts[R] = ParallelTransformKind::Data;
numChunks[R] =
std::min((size_t)numParallelChunks, R->getResult().dims()[0]);
}
continue;
}
// Otherwise, mirror FC split.
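      // Apply the same K/M thresholds as the FC case above, this time to the
      // Relu's own shape, so the Relu splits exactly like its producer.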
if (R->getInput().dims().size() < 2) {
continue;
}
size_t K = R->getInput().dims()[1];
if (K >= 512) {
parOpts[R] = ParallelTransformKind::Model;
numChunks[R] =
std::min((size_t)numParallelChunks, R->getResult().dims()[1]);
continue;
}
size_t M = R->getInput().dims()[0];
if (M >= 256) {
parOpts[R] = ParallelTransformKind::Data;
numChunks[R] =
std::min((size_t)numParallelChunks, R->getResult().dims()[0]);
continue;
}
}
if (auto *R = llvm::dyn_cast<RescaleQuantizedNode>(node)) {
// For Rescales that are preceded by FC or Relu, mirror their
// parallelization.
Node *inputNode = R->getInput().getNode();
if (!llvm::isa<FullyConnectedNode>(inputNode) &&
!llvm::isa<ReluNode>(inputNode)) {
continue;
}
auto numChunksIt = numChunks.find(inputNode);
auto parOptsIt = parOpts.find(inputNode);
if (numChunksIt == numChunks.end() || parOptsIt == parOpts.end()) {
continue;
}
parOpts[R] = parOptsIt->second;
numChunks[R] = numChunksIt->second;
continue;
}
// Split Gelu layers in data parallel fashion
if (auto *GL = llvm::dyn_cast<GeluNode>(node)) {
size_t M = GL->getInput().dims()[0];
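      // Split only when every one of the numParallelChunks chunks receives at
      // least one batch row; unlike most cases here, the chunk count is not
      // clamped with std::min, so the guard must guarantee this.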
      if (M >= (size_t)numParallelChunks) {
parOpts[GL] = ParallelTransformKind::Data;
numChunks[GL] = numParallelChunks;
continue;
}
}
    // Split Transpose layers in data parallel fashion
if (auto *TP = llvm::dyn_cast<TransposeNode>(node)) {
parOpts[TP] = ParallelTransformKind::Data;
numChunks[TP] =
std::min((size_t)numParallelChunks, TP->getResult().dims()[0]);
continue;
}
// Split Quantize layers in data parallel fashion
if (auto *QN = llvm::dyn_cast<QuantizeNode>(node)) {
parOpts[QN] = ParallelTransformKind::Data;
numChunks[QN] =
std::min((size_t)numParallelChunks, QN->getResult().dims()[0]);
continue;
}
// Split Dequantize layers in data parallel fashion
if (auto *DQN = llvm::dyn_cast<DequantizeNode>(node)) {
parOpts[DQN] = ParallelTransformKind::Data;
numChunks[DQN] =
std::min((size_t)numParallelChunks, DQN->getResult().dims()[0]);
continue;
}
// Split Tile layers
if (auto *TN = llvm::dyn_cast<TileNode>(node)) {
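      // Never split along the axis being tiled: a Tile along the batch axis
      // gets a Model split across dim 1, and a Tile along dim 1 gets a Data
      // split across the batch.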
if (TN->getAxis() == 0) {
if (TN->getInput().dims().size() < 2) {
continue;
}
size_t N = TN->getInput().dims()[1];
if (N < 256) {
continue;
}
parOpts[TN] = ParallelTransformKind::Model;
numChunks[TN] =
std::min((size_t)numParallelChunks, TN->getResult().dims()[1]);
} else if (TN->getAxis() == 1) {
if (TN->getInput().dims().size() < 2) {
continue;
}
size_t M = TN->getInput().dims()[0];
if (M < 256) {
continue;
}
parOpts[TN] = ParallelTransformKind::Data;
numChunks[TN] =
std::min((size_t)numParallelChunks, TN->getResult().dims()[0]);
}
continue;
}
    // Split BatchedReduceAdd layers in data parallel fashion
if (auto *BR = llvm::dyn_cast<BatchedReduceAddNode>(node)) {
size_t N = BR->getResult().dims()[0];
if (N < 64) {
continue;
}
parOpts[BR] = ParallelTransformKind::Data;
numChunks[BR] =
std::min((size_t)numParallelChunks, BR->getResult().dims()[0]);
continue;
}
// Split LayerNorm layers in data parallel fashion
if (auto *LN = llvm::dyn_cast<LayerNormalizationNode>(node)) {
if (LN->getInput().dims().size() < 2) {
continue;
}
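      // Gate the split on the largest non-batch dimension, but still chunk
      // along the batch (dim 0): LayerNorm normalizes within each sample, so
      // a batch split leaves every normalization group intact.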
size_t NIdx = getMaxDimOtherThanBatch(LN->getInput().dims());
size_t N = LN->getInput().dims()[NIdx];
if (N < 1024) {
continue;
}
parOpts[LN] = ParallelTransformKind::Data;
numChunks[LN] =
std::min((size_t)numParallelChunks, LN->getResult().dims()[0]);
continue;
}
// Split BMM layers in data parallel fashion
if (auto *BMM = llvm::dyn_cast<BatchMatMulNode>(node)) {
parOpts[BMM] = ParallelTransformKind::Data;
numChunks[BMM] =
std::min((size_t)numParallelChunks, BMM->getResult().dims()[0]);
continue;
}
// Split MatMul layers in Model parallel fashion
if (auto *MM = llvm::dyn_cast<MatMulNode>(node)) {
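      // MatMul is chunked across the output columns (dim 1), which
      // corresponds to slicing the RHS column-wise while the shared
      // contraction dimension stays whole.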
parOpts[MM] = ParallelTransformKind::Model;
numChunks[MM] =
std::min((size_t)numParallelChunks, MM->getResult().dims()[1]);
continue;
}
// Split Tanh layers in data parallel fashion
if (auto *TH = llvm::dyn_cast<TanhNode>(node)) {
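      // Rank-dependent thresholds: 2-D inputs need a wide inner dimension,
      // while 3-D inputs need enough total inner elements (N * K) to be worth
      // chunking. Either way, the split itself is along the batch (dim 0).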
if (TH->getInput().dims().size() < 2) {
continue;
}
if (TH->getInput().dims().size() == 2) {
size_t N = TH->getInput().dims()[1];
if (N < 1792) {
continue;
}
parOpts[TH] = ParallelTransformKind::Data;
numChunks[TH] =
std::min((size_t)numParallelChunks, TH->getResult().dims()[0]);
continue;
} else if (TH->getInput().dims().size() == 3) {
size_t N = TH->getInput().dims()[1];
size_t K = TH->getInput().dims()[2];
if (N * K < 2048) {
continue;
}
parOpts[TH] = ParallelTransformKind::Data;
numChunks[TH] =
std::min((size_t)numParallelChunks, TH->getResult().dims()[0]);
continue;
}
}
// Split Add layers in data parallel fashion
if (auto *AD = llvm::dyn_cast<AddNode>(node)) {
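      // Same rank-dependent thresholds as the Tanh case above, keyed off the
      // LHS shape.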
if (AD->getLHS().dims().size() < 2) {
continue;
}
if (AD->getLHS().dims().size() == 2) {
size_t N = AD->getLHS().dims()[1];
if (N < 1792) {
continue;
}
parOpts[AD] = ParallelTransformKind::Data;
numChunks[AD] =
std::min((size_t)numParallelChunks, AD->getResult().dims()[0]);
continue;
} else if (AD->getLHS().dims().size() == 3) {
size_t N = AD->getLHS().dims()[1];
size_t K = AD->getLHS().dims()[2];
if (N * K < 2048) {
continue;
}
parOpts[AD] = ParallelTransformKind::Data;
numChunks[AD] =
std::min((size_t)numParallelChunks, AD->getResult().dims()[0]);
continue;
}
}
// Split Swish layers in data parallel fashion
if (auto *SW = llvm::dyn_cast<SwishNode>(node)) {
if (SW->getInput().dims().size() < 2) {
continue;
}
size_t N = SW->getInput().dims()[1];
if (N < 512) {
continue;
}
parOpts[SW] = ParallelTransformKind::Data;
numChunks[SW] =
std::min((size_t)numParallelChunks, SW->getResult().dims()[0]);
continue;
}
// Split Mul layers in data parallel fashion
if (auto *M = llvm::dyn_cast<MulNode>(node)) {
if (M->getLHS().dims().size() < 2) {
continue;
}
size_t N = M->getLHS().dims()[1];
if (N < 512) {
continue;
}
parOpts[M] = ParallelTransformKind::Data;
numChunks[M] =
std::min((size_t)numParallelChunks, M->getResult().dims()[0]);
continue;
}
// Split Sigmoid layers in data parallel fashion
if (auto *S = llvm::dyn_cast<SigmoidNode>(node)) {
if (S->getInput().dims().size() < 2) {
continue;
}
size_t N = S->getInput().dims()[1];
if (N < 512) {
continue;
}
parOpts[S] = ParallelTransformKind::Data;
numChunks[S] =
std::min((size_t)numParallelChunks, S->getResult().dims()[0]);
continue;
}
// Split Softmax layers in data parallel fashion
if (auto *SM = llvm::dyn_cast<SoftMaxNode>(node)) {
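      // Require both a non-trivial softmax width (N) and enough batch rows
      // (M) to spread across chunks before splitting.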
if (SM->getInput().dims().size() < 2) {
continue;
}
size_t M = SM->getInput().dims()[0];
size_t N = SM->getInput().dims()[1];
if (N < 32 || M < 128) {
continue;
}
parOpts[SM] = ParallelTransformKind::Data;
numChunks[SM] =
std::min((size_t)numParallelChunks, SM->getResult().dims()[0]);
continue;
}
    // Clip parallelization.
    // If a Clip follows a parallelized op, mirror its parallelization.
if (auto *C = llvm::dyn_cast<ClipNode>(node)) {
Node *inputNode = C->getInput().getNode();
if (numChunks.find(inputNode) != numChunks.end() &&
parOpts.find(inputNode) != parOpts.end()) {
parOpts[C] = parOpts[inputNode];
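        // Clamp the mirrored chunk count to the dimension actually being
        // split: dim 0 for a Data split, dim 1 for a Model split.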
if (parOpts[C] == ParallelTransformKind::Data) {
numChunks[C] =
std::min((size_t)numChunks[inputNode], C->getResult().dims()[0]);
} else {
numChunks[C] =
std::min((size_t)numChunks[inputNode], C->getResult().dims()[1]);
}
}
continue;
}
}
}
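
// A minimal sketch of how these configs are typically consumed (hypothetical
// call site; assumes Glow's parallelizeOps() helper and the
// ASSIGN_VALUE_OR_RETURN_ERR macro): the two maps produced above are handed
// to parallelizeOps(), which performs the actual graph transformation and
// returns the map of replaced nodes.
//
//   llvm::DenseMap<Node *, size_t> numChunks;
//   llvm::DenseMap<Node *, ParallelTransformKind> parOpts;
//   setupBasicParallelizationConfigs(F, numChunks, parOpts, numParallelChunks);
//   std::unordered_map<Node *, ConcatNode *> replacedMap;
//   ASSIGN_VALUE_OR_RETURN_ERR(
//       replacedMap, parallelizeOps(F, numChunks, parOpts, numParallelChunks));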