in tensorflow/tensorflow/compiler/xla/service/instruction_fusion.cc [59:185]
/*static*/ bool InstructionFusion::IsExpensive(
const HloInstruction& instruction) {
namespace m = match;
switch (instruction.opcode()) {
// Cheap instructions.
case HloOpcode::kAdd:
case HloOpcode::kAnd:
case HloOpcode::kBitcast:
case HloOpcode::kBitcastConvert:
case HloOpcode::kBroadcast:
case HloOpcode::kCeil:
case HloOpcode::kClamp:
case HloOpcode::kClz:
case HloOpcode::kCompare:
case HloOpcode::kComplex:
case HloOpcode::kConcatenate:
case HloOpcode::kConstant:
case HloOpcode::kConvert:
case HloOpcode::kCopy:
case HloOpcode::kCopyDone:
case HloOpcode::kCopyStart:
case HloOpcode::kDynamicSlice:
case HloOpcode::kDynamicUpdateSlice:
case HloOpcode::kFloor:
case HloOpcode::kGetTupleElement:
case HloOpcode::kImag:
case HloOpcode::kInfeed:
case HloOpcode::kIota:
case HloOpcode::kIsFinite:
case HloOpcode::kMaximum:
case HloOpcode::kMinimum:
case HloOpcode::kMultiply:
case HloOpcode::kNegate:
case HloOpcode::kNot:
case HloOpcode::kOr:
case HloOpcode::kXor:
case HloOpcode::kOutfeed:
case HloOpcode::kPad:
case HloOpcode::kPartitionId:
case HloOpcode::kPopulationCount:
case HloOpcode::kReal:
case HloOpcode::kReducePrecision:
case HloOpcode::kReplicaId:
case HloOpcode::kReshape:
case HloOpcode::kReverse:
case HloOpcode::kRoundNearestAfz:
case HloOpcode::kSelect:
case HloOpcode::kShiftLeft:
case HloOpcode::kShiftRightArithmetic:
case HloOpcode::kShiftRightLogical:
case HloOpcode::kSlice:
case HloOpcode::kSubtract:
case HloOpcode::kTranspose:
case HloOpcode::kTuple:
case HloOpcode::kTupleSelect:
return false;
// Cheap instructions for reals, but expensive for complex.
case HloOpcode::kAbs:
case HloOpcode::kCos:
case HloOpcode::kSign:
case HloOpcode::kSin:
return ShapeUtil::ElementIsComplex(instruction.shape());
// We say that integer div/mod by a constant is cheap because it gets
// compiled down to multiplies and shifts, and we consider those to be
// cheap.
case HloOpcode::kDivide:
case HloOpcode::kRemainder:
return !ShapeUtil::ElementIsIntegral(instruction.shape()) ||
!Match(instruction.operand(1),
m::AnyOf<const HloInstruction>(
m::ConstantEffectiveScalar(),
m::Broadcast(m::ConstantEffectiveScalar())));
// Expensive instructions or unusual instructions for which fusion is
// nonsensical.
case HloOpcode::kAddDependency:
case HloOpcode::kAfterAll:
case HloOpcode::kAtan2:
case HloOpcode::kBatchNormGrad:
case HloOpcode::kBatchNormInference:
case HloOpcode::kBatchNormTraining:
case HloOpcode::kCall:
case HloOpcode::kCholesky:
case HloOpcode::kConditional:
case HloOpcode::kConvolution:
case HloOpcode::kAllReduce:
case HloOpcode::kAllToAll:
case HloOpcode::kCollectivePermute:
case HloOpcode::kCustomCall:
case HloOpcode::kDomain:
case HloOpcode::kDot:
case HloOpcode::kExp:
case HloOpcode::kExpm1:
case HloOpcode::kFft:
case HloOpcode::kFusion:
case HloOpcode::kGather:
case HloOpcode::kLog:
case HloOpcode::kLog1p:
case HloOpcode::kMap:
case HloOpcode::kParameter:
case HloOpcode::kPower:
case HloOpcode::kRecv:
case HloOpcode::kRecvDone:
case HloOpcode::kReduce:
case HloOpcode::kReduceWindow:
case HloOpcode::kRng:
case HloOpcode::kRngGetAndUpdateState:
case HloOpcode::kRsqrt:
case HloOpcode::kScatter:
case HloOpcode::kSelectAndScatter:
case HloOpcode::kSend:
case HloOpcode::kSendDone:
case HloOpcode::kSort:
case HloOpcode::kSqrt:
case HloOpcode::kTanh:
case HloOpcode::kTrace:
case HloOpcode::kTriangularSolve:
case HloOpcode::kWhile:
case HloOpcode::kGetDimensionSize:
return true;
}
return false;
}