in tensorflow/tensorflow/compiler/xla/service/hlo_instruction.cc [1796:1927]
bool HloInstruction::IdenticalSlowPath(
const HloInstruction& other,
const std::function<bool(const HloComputation*, const HloComputation*)>&
eq_computations) const {
// Perform opcode specific checks.
switch (opcode()) {
// The result of these instructions only depend upon their opcode and
// operands.
case HloOpcode::kAbs:
case HloOpcode::kAtan2:
case HloOpcode::kAdd:
case HloOpcode::kBitcast:
case HloOpcode::kBitcastConvert:
case HloOpcode::kCeil:
case HloOpcode::kClamp:
case HloOpcode::kClz:
case HloOpcode::kComplex:
case HloOpcode::kConvert:
case HloOpcode::kCopy:
case HloOpcode::kCopyStart:
case HloOpcode::kCopyDone:
case HloOpcode::kCos:
case HloOpcode::kDivide:
case HloOpcode::kDynamicUpdateSlice:
case HloOpcode::kExp:
case HloOpcode::kExpm1:
case HloOpcode::kFloor:
case HloOpcode::kImag:
case HloOpcode::kIsFinite:
case HloOpcode::kLog:
case HloOpcode::kLog1p:
case HloOpcode::kAnd:
case HloOpcode::kNot:
case HloOpcode::kOr:
case HloOpcode::kXor:
case HloOpcode::kMaximum:
case HloOpcode::kMinimum:
case HloOpcode::kMultiply:
case HloOpcode::kNegate:
case HloOpcode::kPartitionId:
case HloOpcode::kPopulationCount:
case HloOpcode::kPower:
case HloOpcode::kReal:
case HloOpcode::kRemainder:
case HloOpcode::kReshape:
case HloOpcode::kReplicaId:
case HloOpcode::kRoundNearestAfz:
case HloOpcode::kRsqrt:
case HloOpcode::kSelect:
case HloOpcode::kShiftLeft:
case HloOpcode::kShiftRightArithmetic:
case HloOpcode::kShiftRightLogical:
case HloOpcode::kSign:
case HloOpcode::kSin:
case HloOpcode::kSqrt:
case HloOpcode::kSubtract:
case HloOpcode::kTanh:
case HloOpcode::kTuple:
case HloOpcode::kTupleSelect:
return true;
// This opcode has complex or special behavior so just return false.
case HloOpcode::kAfterAll:
case HloOpcode::kAddDependency:
return false;
// Remaining instructions with special values.
case HloOpcode::kCall:
return eq_computations(to_apply(), other.to_apply());
case HloOpcode::kConditional:
for (int j = 0; j < branch_count(); ++j) {
if (!eq_computations(branch_computation(j),
other.branch_computation(j))) {
return false;
}
}
return true;
case HloOpcode::kWhile:
return (eq_computations(while_body(), other.while_body()) &&
eq_computations(while_condition(), other.while_condition()));
// Ops migrated to subclasses should never come to this line.
// TODO(b/80131774): Remove this switch when migration is complete.
case HloOpcode::kBatchNormTraining:
case HloOpcode::kBatchNormInference:
case HloOpcode::kBatchNormGrad:
case HloOpcode::kFft:
case HloOpcode::kCompare:
case HloOpcode::kSend:
case HloOpcode::kSendDone:
case HloOpcode::kRecv:
case HloOpcode::kRecvDone:
case HloOpcode::kReverse:
case HloOpcode::kConcatenate:
case HloOpcode::kReduce:
case HloOpcode::kSort:
case HloOpcode::kTranspose:
case HloOpcode::kBroadcast:
case HloOpcode::kMap:
case HloOpcode::kSlice:
case HloOpcode::kConstant:
case HloOpcode::kIota:
case HloOpcode::kTrace:
case HloOpcode::kFusion:
case HloOpcode::kRng:
case HloOpcode::kRngGetAndUpdateState:
case HloOpcode::kParameter:
case HloOpcode::kGetTupleElement:
case HloOpcode::kReducePrecision:
case HloOpcode::kInfeed:
case HloOpcode::kOutfeed:
case HloOpcode::kAllReduce:
case HloOpcode::kAllToAll:
case HloOpcode::kCollectivePermute:
case HloOpcode::kConvolution:
case HloOpcode::kCustomCall:
case HloOpcode::kReduceWindow:
case HloOpcode::kSelectAndScatter:
case HloOpcode::kPad:
case HloOpcode::kDynamicSlice:
case HloOpcode::kGather:
case HloOpcode::kScatter:
case HloOpcode::kDot:
case HloOpcode::kDomain:
case HloOpcode::kGetDimensionSize:
case HloOpcode::kTriangularSolve:
case HloOpcode::kCholesky:
LOG(FATAL) << "Base class impl called for opcode with subclass: "
<< opcode();
}
return false;
}