in lib/Optimizer/GraphOptimizer/GraphOptimizer.cpp [648:1246]
bool SinkCode::run(Function *F, const CompilationContext &cctx) {
LOG_SCOPE(F->getLogContext(), getName());
bool changed = false;
auto &nodes = F->getNodes();
// For each node:
for (auto &N : nodes) {
auto *node = &N;
// Sink Reshape/Transpose below BatchNormalization.
if (auto *BN = dyn_cast<BatchNormalizationNode>(node)) {
// Sink Reshape below BatchNormalization.
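      // i.e. rewrite BatchNorm(Reshape(X)) as Reshape(BatchNorm'(X)), where
      // BatchNorm' operates directly on X's shape with a remapped channel
      // index.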
if (auto *RS = dyn_cast<ReshapeNode>(BN->getInput())) {
auto inDims = RS->getInput().dims();
auto outDims = RS->getResult().dims();
unsigned_t newChannelIdx;
        // Skip sinking if: 1) the input has fewer than 3 dimensions, because
        // we need spatial dimensions in addition to batch and channel, or
        // 2) the input is 5D (3D spatial data), because those reshapes are
        // deliberately introduced to phrase a 3D BatchNormalization as a 2D
        // one.
        if (inDims.size() < 3 || inDims.size() == 5) {
          continue;
        }
        // The Reshape must not change the size of the BatchNorm channel
        // dimension. Only NH[W]C and NCH[W] layouts are allowed.
if (BN->getChannelIdx() == outDims.size() - 1) {
if (inDims[inDims.size() - 1] != outDims[outDims.size() - 1]) {
continue;
}
newChannelIdx = inDims.size() - 1;
} else if (BN->getChannelIdx() == 1) {
// Note: index '1' maps to C in NCH[W] layout.
if (inDims[1] != outDims[1]) {
continue;
}
newChannelIdx = 1;
} else {
continue;
}
// Reshape should not change the batch dimension.
if (inDims[0] != outDims[0]) {
continue;
}
        glow::TypeRef outTy = F->getParent()->uniqueTypeWithNewShape(
            BN->getResult().getType(), RS->getInput().getType());
auto *newBN = F->createBatchNormalization(
BN->getName(), outTy, RS->getInput(), BN->getBias(), BN->getScale(),
BN->getMean(), BN->getVar(), newChannelIdx, BN->getEpsilon(),
BN->getMomentum());
auto *newRS = F->createReshape(RS->getName(), newBN,
RS->getResult().dims(), RS->getLayout());
BN->getResult().replaceAllUsesOfWith(newRS);
changed = true;
continue;
}
// Sink Transpose below batch normalization nodes:
if (auto *TR = dyn_cast<TransposeNode>(BN->getInput())) {
// Figure out where we transposed the channel index for batch
// normalization.
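        // A transpose with shuffle S satisfies out.dims()[i] ==
        // in.dims()[S[i]], so channel index 'idx' of the transposed tensor
        // lives at index S[idx] in the transpose's input.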
unsigned_t idx = BN->getChannelIdx();
unsigned_t newChannelIdx = TR->getShuffle()[idx];
auto bnOutTy = BN->getResult().getType();
auto trInputType = TR->getInput().getType();
glow::TypeRef outTy = F->getParent()->uniqueTypeWithNewShape(
bnOutTy, trInputType->dims());
auto *NewBN = F->createBatchNormalization(
BN->getName(), outTy, TR->getInput(), BN->getBias(), BN->getScale(),
BN->getMean(), BN->getVar(), newChannelIdx, BN->getEpsilon(),
BN->getMomentum());
NewBN->setPredicate(node->getPredicate());
auto *newTR = F->createTranspose(TR->getName(), NewBN, TR->getShuffle(),
TR->getLayout());
newTR->setPredicate(node->getPredicate());
BN->getResult().replaceAllUsesOfWith(newTR);
changed = true;
continue;
}
}
if (auto *RL = dyn_cast<ReluNode>(node)) {
      // Sink Transpose below RELU nodes.
if (auto *TR = dyn_cast<TransposeNode>(RL->getInput())) {
        // Keep the same quantization parameters for the ReLU output, but
        // change the shape to the appropriate value.
auto reluOutTy = F->getParent()->uniqueTypeWithNewShape(
RL->getResult().getType(), TR->getInput().getType());
auto *NRL = F->createRELU(RL->getName(), TR->getInput(), reluOutTy);
NRL->setPredicate(node->getPredicate());
auto *newTR = F->createTranspose(TR->getName(), NRL, TR->getShuffle(),
TR->getLayout());
newTR->setPredicate(node->getPredicate());
RL->getResult().replaceAllUsesOfWith(newTR);
changed = true;
continue;
}
// Sink Clip below RELU nodes.
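      // Relu(Clip(x, lo, hi)) is equivalent to Clip(Relu(x), max(lo, 0), hi)
      // when hi >= 0: the Relu's clamp at zero folds into the clip's lower
      // bound.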
if (ClipNode *CN = dyn_cast<ClipNode>(RL->getInput())) {
assert(!RL->getResult().getType()->isQuantizedType() &&
"Relu(Clip) means Relu should not be quantized.");
ReluNode *newRL = F->createRELU(RL->getName(), CN->getInput());
ClipNode *newCN =
F->createClip(CN->getName(), newRL->getResult(),
std::max(CN->getMin(), 0.0f), CN->getMax());
RL->getResult().replaceAllUsesOfWith(newCN);
changed = true;
continue;
}
}
// Sink Transpose below Clip nodes.
if (auto *CL = dyn_cast<ClipNode>(node)) {
auto *TR = dyn_cast<TransposeNode>(CL->getInput());
if (!TR) {
continue;
}
      // Keep the same quantization parameters for the Clip output, but change
      // the shape to the appropriate value.
auto clipOutTy = F->getParent()->uniqueTypeWithNewShape(
CL->getResult().getType(), TR->getInput().getType());
auto *NCL = F->createClip(CL->getName(), TR->getInput(), clipOutTy,
CL->getMin(), CL->getMax());
NCL->setPredicate(node->getPredicate());
auto *newTR = F->createTranspose(TR->getName(), NCL, TR->getShuffle());
newTR->setPredicate(node->getPredicate());
CL->getResult().replaceAllUsesOfWith(newTR);
changed = true;
continue;
}
// Sink Transpose below LeakyRelu nodes.
if (auto *LR = dyn_cast<LeakyReluNode>(node)) {
auto *TR = dyn_cast<TransposeNode>(LR->getInput());
if (!TR) {
continue;
}
auto newLROutTy = F->getParent()->uniqueTypeWithNewShape(
LR->getResult().getType(), TR->getInput().getType());
auto *newLR = F->createLeakyRELU(LR->getName(), newLROutTy,
TR->getInput(), LR->getAlpha());
newLR->setPredicate(node->getPredicate());
auto *newTR = F->createTranspose(TR->getName(), newLR, TR->getShuffle());
newTR->setPredicate(node->getPredicate());
LR->getResult().replaceAllUsesOfWith(newTR);
changed = true;
continue;
}
// Sink Transpose below PRelu with Splat.
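    // The scalar Splat slope is recreated with the pre-transpose shape, so
    // PRelu(Transpose(X), Splat) becomes Transpose(PRelu(X, Splat')).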
if (auto *PN = dyn_cast<PReluNode>(node)) {
auto *TR = dyn_cast<TransposeNode>(PN->getInput());
if (!TR) {
continue;
}
auto *SN = dyn_cast<SplatNode>(PN->getSlope());
if (!SN) {
continue;
}
auto newSNOutTy = F->getParent()->uniqueTypeWithNewShape(
SN->getResult().getType(), TR->getInput().getType());
auto newPNOutTy = F->getParent()->uniqueTypeWithNewShape(
PN->getResult().getType(), TR->getInput().getType());
auto *newSN = F->createSplat(SN->getName(), newSNOutTy, SN->getValue());
auto *newPN =
F->createPRELU(PN->getName(), TR->getInput(), newSN, newPNOutTy);
auto *newTR = F->createTranspose(TR->getName(), newPN, TR->getShuffle());
newPN->setPredicate(node->getPredicate());
newTR->setPredicate(node->getPredicate());
PN->getResult().replaceAllUsesOfWith(newTR);
changed = true;
continue;
}
// Sink Transpose below Sigmoid nodes.
if (auto *SI = dyn_cast<SigmoidNode>(node)) {
auto *TR = dyn_cast<TransposeNode>(SI->getInput());
if (!TR) {
continue;
}
auto *NSI = F->createSigmoid(SI->getName(), TR->getInput());
NSI->setPredicate(node->getPredicate());
auto *newTR = F->createTranspose(TR->getName(), NSI, TR->getShuffle(),
TR->getLayout());
newTR->setPredicate(node->getPredicate());
SI->getResult().replaceAllUsesOfWith(newTR);
changed = true;
continue;
}
// Sink Transpose below Tile nodes.
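    // The tile axis is remapped through the transpose shuffle, in the same
    // way as the BatchNormalization channel index above.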
if (auto *TN = dyn_cast<TileNode>(node)) {
auto *TR = dyn_cast<TransposeNode>(TN->getInput());
if (!TR) {
continue;
}
auto *newTN = F->createTile(TN->getName(), TR->getInput(), TN->getCount(),
TR->getShuffle()[TN->getAxis()]);
newTN->setPredicate(node->getPredicate());
auto *newTR = F->createTranspose(TR->getName(), newTN, TR->getShuffle(),
TR->getLayout());
newTR->setPredicate(node->getPredicate());
TN->getResult().replaceAllUsesOfWith(newTR);
changed = true;
continue;
}
// Sink Transpose below Pad nodes.
if (auto *padNode = dyn_cast<PadNode>(node)) {
auto *transposeNode = dyn_cast<TransposeNode>(padNode->getInput());
if (!transposeNode) {
continue;
}
      // A transpose with shuffle S maps input dimension S[i] to output
      // dimension i. When the Transpose sinks below the Pad, padding applied
      // to dimension i of the old output must move to dimension S[i] of the
      // transpose's input.
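      // For example, with shuffle {0, 2, 3, 1} (NCHW -> NHWC), padding on
      // output dimension 1 (H in NHWC) moves to input dimension 2 (H in
      // NCHW).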
auto shuffle = transposeNode->getShuffle();
// Shuffle the Pad output type and the padding attribute.
auto outPadType = padNode->getResult().getType();
auto outPadShape = outPadType->dims();
auto pads = padNode->getPads();
size_t numDims = outPadShape.size();
std::vector<dim_t> newOutPadShape(numDims);
std::vector<int> newPads(2 * numDims);
for (size_t i = 0; i < outPadShape.size(); i++) {
newOutPadShape[shuffle[i]] = outPadShape[i];
newPads[shuffle[i]] = pads[i];
newPads[shuffle[i] + numDims] = pads[i + numDims];
}
      // Create the new Pad with the remapped type and pads, then re-apply the
      // Transpose on top.
auto newOutPadType =
F->getParent()->uniqueTypeWithNewShape(outPadType, newOutPadShape);
auto *NewPadNode = F->createPad(
padNode->getName(), transposeNode->getInput(), newOutPadType,
padNode->getMode(), newPads, padNode->getValue());
NewPadNode->setPredicate(node->getPredicate());
auto *newTransposeNode =
F->createTranspose(transposeNode->getName(), NewPadNode, shuffle);
newTransposeNode->setPredicate(node->getPredicate());
padNode->getResult().replaceAllUsesOfWith(newTransposeNode);
changed = true;
continue;
}
// Sink Transpose below Tanh nodes.
if (auto *TN = dyn_cast<TanhNode>(node)) {
auto *TR = dyn_cast<TransposeNode>(TN->getInput());
if (!TR) {
continue;
}
auto *NTN = F->createTanh(TN->getName(), TR->getInput());
NTN->setPredicate(node->getPredicate());
auto *newTR = F->createTranspose(TR->getName(), NTN, TR->getShuffle(),
TR->getLayout());
newTR->setPredicate(node->getPredicate());
TN->getResult().replaceAllUsesOfWith(newTR);
changed = true;
continue;
}
// Remove 'identity' transpose operations.
if (auto *TR = dyn_cast<TransposeNode>(node)) {
auto mask = TR->getShuffle();
if (isIdentityShuffle(mask)) {
TR->getResult().replaceAllUsesOfWith(TR->getInput());
changed = true;
continue;
}
}
// Merge consecutive Transpose operations.
if (auto *TR1 = dyn_cast<TransposeNode>(node)) {
auto *TR2 = dyn_cast<TransposeNode>(TR1->getInput());
if (!TR2) {
continue;
}
auto mask1 = TR1->getShuffle();
auto mask2 = TR2->getShuffle();
assert(mask1.size() == mask2.size() && "Invalid mask size");
llvm::SmallVector<unsigned_t, max_tensor_dimensions> newMask;
newMask.resize(mask2.size());
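      // Composing the two transposes: out.dims()[i] == in.dims()[mask[i]],
      // so newMask[i] = mask2[mask1[i]]. E.g. mask1 = {0, 3, 1, 2}
      // (NHWC -> NCHW) on top of mask2 = {0, 2, 3, 1} (NCHW -> NHWC)
      // composes to the identity {0, 1, 2, 3}.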
for (size_t i = 0, end = mask2.size(); i < end; i++) {
newMask[i] = mask2[mask1[i]];
}
auto *newTR = F->createTranspose("tranpose", TR2->getInput(), newMask);
TR1->getResult().replaceAllUsesOfWith(newTR->getResult());
changed = true;
continue;
}
if (auto *CS = dyn_cast<ChannelShuffleNode>(node)) {
// Sink Transpose below ChannelShuffle.
if (sinkTranposeBelowChannelShuffle(F, CS)) {
changed = true;
continue;
}
}
// Sink Transpose below Arithmetic nodes.
if (node->isArithmetic()) {
TransposeNode *LTR =
dyn_cast<TransposeNode>(node->getNthInput(ArithmeticNode::LHSIdx));
TransposeNode *RTR =
dyn_cast<TransposeNode>(node->getNthInput(ArithmeticNode::RHSIdx));
if (!LTR || !RTR) {
        // If one of the sides is a Splat, it can be viewed as
        // Transpose(Splat'). Similarly, if one of the sides is a Constant, it
        // can be viewed as Transpose(Constant').
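        // E.g. Add(Splat(v), Transpose(X)) becomes
        // Transpose(Add(Splat'(v), X)), where Splat' takes X's shape.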
if (isa<SplatNode>(node->getNthInput(ArithmeticNode::LHSIdx)) && RTR) {
// Build splat' for LHS.
auto *SN =
dyn_cast<SplatNode>(node->getNthInput(ArithmeticNode::LHSIdx));
auto *NS = F->createSplat("splat", RTR->getInput().getType(),
SN->getValue());
LTR = F->createTranspose("transpose", NS, RTR->getShuffle(),
RTR->getLayout());
changed = true;
} else if (isa<SplatNode>(node->getNthInput(ArithmeticNode::RHSIdx)) &&
LTR) {
// Build splat' for RHS.
auto *SN =
dyn_cast<SplatNode>(node->getNthInput(ArithmeticNode::RHSIdx));
auto *NS = F->createSplat("splat", LTR->getInput().getType(),
SN->getValue());
RTR = F->createTranspose("transpose", NS, LTR->getShuffle(),
LTR->getLayout());
changed = true;
} else if (isa<Constant>(node->getNthInput(ArithmeticNode::LHSIdx)) &&
RTR) {
          // Build Constant' for LHS.
auto *C = cast<Constant>(node->getNthInput(ArithmeticNode::LHSIdx));
LTR = insertMatchingTransposeAfterConstant(F, C, RTR);
changed = true;
} else if (isa<Constant>(node->getNthInput(ArithmeticNode::RHSIdx)) &&
LTR) {
          // Build Constant' for RHS.
auto *C = cast<Constant>(node->getNthInput(ArithmeticNode::RHSIdx));
RTR = insertMatchingTransposeAfterConstant(F, C, LTR);
changed = true;
} else {
continue;
}
}
      // The masks of the transposes on both sides must match.
if (LTR->getShuffle() != RTR->getShuffle()) {
continue;
}
Node *newAN = nullptr;
#define ARITHMETIC_CASE(NODE_NAME_) \
case glow::Kinded::Kind::NODE_NAME_##NodeKind: \
newAN = \
F->create##NODE_NAME_(node->getName(), \
F->getParent()->uniqueTypeWithNewShape( \
node->getType(ArithmeticNode::ResultIdx), \
LTR->getInput().getType()), \
LTR->getInput(), RTR->getInput()); \
break;
#define BOOLEAN_OP_CASE(NODE_NAME_) \
case glow::Kinded::Kind::NODE_NAME_##NodeKind: \
newAN = F->create##NODE_NAME_(node->getName(), LTR->getInput(), \
RTR->getInput()); \
break;
switch (node->getKind()) {
ARITHMETIC_CASE(Add);
ARITHMETIC_CASE(Mul);
ARITHMETIC_CASE(Sub);
ARITHMETIC_CASE(Div);
ARITHMETIC_CASE(Fmod);
ARITHMETIC_CASE(Max);
ARITHMETIC_CASE(Min);
ARITHMETIC_CASE(Pow);
BOOLEAN_OP_CASE(CmpLTE);
BOOLEAN_OP_CASE(CmpEQ);
default:
llvm_unreachable("Unhandled node");
}
#undef BOOLEAN_OP_CASE
#undef ARITHMETIC_CASE
newAN->setPredicate(node->getPredicate());
changed = true;
auto *newTR = F->createTranspose(LTR->getName(), newAN, LTR->getShuffle(),
LTR->getLayout());
newTR->setPredicate(node->getPredicate());
node->getNthResult(ArithmeticNode::ResultIdx).replaceAllUsesOfWith(newTR);
}
if (auto *Q = dyn_cast<QuantizeNode>(node)) {
      // Sink TransposeNode below QuantizeNode.
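      // i.e. rewrite Quantize(Transpose(X)) as Transpose(Quantize(X)), with
      // the quantized type rebuilt for X's dimensions.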
if (auto *TR = getTransposeNodeWithAllSameUserKind(Q->getInput())) {
auto newQType = F->getParent()->uniqueTypeWithNewShape(
Q->getResult().getType(), TR->getInput().dims());
auto *newQ = F->createQuantize(Q->getName(), TR->getInput(), newQType);
auto *newTR = F->createTranspose(TR->getName(), newQ, TR->getShuffle());
Q->getResult().replaceAllUsesOfWith(newTR);
changed = true;
continue;
}
// Sink Reshape below Quantize.
if (auto *RN = dyn_cast<ReshapeNode>(Q->getInput())) {
auto newQType = F->getParent()->uniqueTypeWithNewShape(
Q->getResult().getType(), RN->getInput().dims());
auto *newQ = F->createQuantize(Q->getName(), RN->getInput(), newQType);
auto *newRN = F->createReshape(RN->getName(), newQ,
RN->getResult().dims(), RN->getLayout());
Q->getResult().replaceAllUsesOfWith(newRN->getResult());
changed = true;
continue;
}
}
// Sink Reshape below ConvertTo.
if (auto *CN = dyn_cast<ConvertToNode>(node)) {
auto *RN = dyn_cast<ReshapeNode>(CN->getInput());
if (!RN) {
continue;
}
auto *newCN = F->createConvertTo(CN->getName(), RN->getInput(),
CN->getResult().getElementType());
auto *newRN = F->createReshape(RN->getName(), newCN,
RN->getResult().dims(), RN->getLayout());
CN->getResult().replaceAllUsesOfWith(newRN->getResult());
changed = true;
continue;
}
    // Sink TransposeNode below DequantizeNode.
    // If it doesn't work out, it will be sunk again later.
if (auto *D = dyn_cast<DequantizeNode>(node)) {
auto *TR = dyn_cast<TransposeNode>(D->getInput());
if (!TR) {
continue;
}
auto newDType = F->getParent()->uniqueTypeWithNewShape(
D->getResult().getType(), TR->getInput().dims());
auto *newD = F->createDequantize(D->getName(), TR->getInput(), newDType);
auto *newTR = F->createTranspose(TR->getName(), newD, TR->getShuffle());
D->getResult().replaceAllUsesOfWith(newTR);
changed = true;
}
// Sink Transpose below RescaleQuantized.
    // Potentially exposes an opportunity to combine with Convolution.
    // If it doesn't work out, it will be sunk again later.
if (auto *RQ = dyn_cast<RescaleQuantizedNode>(node)) {
auto *TR = dyn_cast<TransposeNode>(RQ->getInput());
if (!TR) {
continue;
}
auto newRQType = F->getParent()->uniqueTypeWithNewShape(
RQ->getResult().getType(), TR->getInput().getType());
auto *newRQ =
F->createRescaleQuantized(RQ->getName(), TR->getInput(), newRQType);
auto *newTR = F->createTranspose(TR->getName(), newRQ, TR->getShuffle(),
TR->getLayout());
RQ->getResult().replaceAllUsesOfWith(newTR);
changed = true;
}
if (auto *CN = dyn_cast<ConcatNode>(node)) {
const Node *firstNode = CN->getInputs().front().getNode();
      // Sink RELU below Concat nodes.
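      // Concat(Relu(A), Relu(B), ...) becomes Relu(Concat(A, B, ...)); valid
      // because Relu is applied elementwise.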
if (firstNode->getKind() == Kinded::Kind::ReluNodeKind) {
llvm::SmallVector<NodeValue, 6> CNInputs;
for (auto &input : CN->getInputs()) {
auto *inputRL = dyn_cast<ReluNode>(input);
if (!inputRL) {
break;
}
CNInputs.push_back(inputRL->getInput());
}
if (CNInputs.size() == CN->getNumInputs()) {
auto *newCN = F->createConcat(CN->getName(), CNInputs, CN->getDim());
newCN->setPredicate(node->getPredicate());
auto name = CN->getNthInput(0).getNode()->getName();
auto *newRL = F->createRELU(name, newCN, CN->getResult().getType());
newRL->setPredicate(node->getPredicate());
CN->getResult().replaceAllUsesOfWith(newRL);
changed = true;
}
continue;
}
// Sink Transpose below concat nodes.
if (firstNode->getKind() == Kinded::Kind::TransposeNodeKind) {
llvm::SmallVector<NodeValue, 6> transVector;
auto inputIter = CN->getInputs().begin();
auto *firstInput = dyn_cast<TransposeNode>(*inputIter);
if (!firstInput) {
continue;
}
transVector.push_back(firstInput->getInput());
auto shuffle = firstInput->getShuffle();
        // If the shuffle masks don't agree or not all inputs are Transposes,
        // bail out.
for (++inputIter; inputIter != CN->getInputs().end(); ++inputIter) {
auto *tTR = dyn_cast<TransposeNode>(*inputIter);
if (!tTR || tTR->getShuffle() != shuffle) {
break;
}
transVector.push_back(tTR->getInput());
}
if (transVector.size() != CN->getNumInputs()) {
continue;
}
        // Remap the concat dimension through the transpose shuffle, as done
        // for the BatchNormalization channel index above.
unsigned_t idx = CN->getDim();
unsigned_t newChannelIdx = shuffle[idx];
auto *newCN =
F->createConcat(CN->getName(), transVector, newChannelIdx);
newCN->setPredicate(node->getPredicate());
auto *newTR = F->createTranspose(firstInput->getName(), newCN,
firstInput->getShuffle(),
firstInput->getLayout());
newTR->setPredicate(node->getPredicate());
CN->getResult().replaceAllUsesOfWith(newTR);
changed = true;
continue;
}
}
} // For all nodes in the graph.
  // Transformations to sink nodes below Slice. Outlined into a separate loop
  // to prevent Transpose/Slice sinking from affecting them.
for (auto &N : nodes) {
auto *node = &N;
// Sink BatchNorm below Slice.
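    // i.e. rewrite Slice(BatchNorm(X)) as BatchNorm(Slice(X)) when the slice
    // leaves the channel dimension intact.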
if (auto *SN = dyn_cast<SliceNode>(node)) {
auto *BN = dyn_cast<BatchNormalizationNode>(SN->getInput());
if (!BN || !BN->hasOneUse()) {
continue;
}
      // Don't sink below a Slice that changes the channel dimension, since
      // BatchNorm's per-channel parameters would no longer line up.
if (SN->getInput().dims()[BN->getChannelIdx()] !=
SN->getResult().dims()[BN->getChannelIdx()]) {
continue;
}
auto newSNType = F->getParent()->uniqueTypeWithNewShape(
BN->getInput().getType(), SN->getResult().dims());
auto *newSN = F->createSlice(SN->getName(), BN->getInput(),
SN->getStart(), newSNType);
auto *newBN = F->createBatchNormalization(
BN->getName(), SN->getResult().getType(), newSN, BN->getBias(),
BN->getScale(), BN->getMean(), BN->getVar(), BN->getChannelIdx(),
BN->getEpsilon(), BN->getMomentum());
SN->getResult().replaceAllUsesOfWith(newBN);
changed = true;
}
}
return changed;
}