in tensorflow/tensorflow/compiler/xla/service/hlo_instruction.cc [1403:1588]
// Builds a copy of this instruction that produces `shape` and consumes
// `new_operands` in place of the current operands.
//
// Opcodes that have been migrated to HloInstruction subclasses are delegated
// to CloneWithNewOperandsImpl; every other opcode is rebuilt through the
// matching Create* factory. Cases that index into `new_operands` CHECK-fail
// when the operand count does not fit the opcode.
//
// After construction, SetupDerivedInstruction is applied to the clone and the
// clone receives this instruction's parent, outer dimension partitions and raw
// backend config string.
//
// If `context` is non-null, the (this -> clone) mapping is recorded in it, and
// each computation called by the clone whose parent is not context->module()
// is replaced with a deep clone created via that module.
std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
    const Shape& shape, absl::Span<HloInstruction* const> new_operands,
    HloCloneContext* context) const {
  VLOG(3) << "CloneWithNewOperands:\n " << ToString();
  VLOG(3) << " new operands:";
  for (const HloInstruction* new_operand : new_operands) {
    VLOG(3) << " %" << new_operand->name();
  }

  std::unique_ptr<HloInstruction> clone;
  // Explicitly call the factory for the instruction type. This is more robust
  // in the face of code changes than copying fields explicitly. This also
  // properly sets the user fields of the operands.
  switch (opcode_) {
    // Ops migrated to subclasses.
    // TODO(b/80131774): Remove this switch when migration is complete.
    case HloOpcode::kBatchNormTraining:
    case HloOpcode::kBatchNormInference:
    case HloOpcode::kBatchNormGrad:
    case HloOpcode::kFft:
    case HloOpcode::kCompare:
    case HloOpcode::kSend:
    case HloOpcode::kSendDone:
    case HloOpcode::kRecv:
    case HloOpcode::kRecvDone:
    case HloOpcode::kReverse:
    case HloOpcode::kConcatenate:
    case HloOpcode::kReduce:
    case HloOpcode::kTranspose:
    case HloOpcode::kBroadcast:
    case HloOpcode::kReshape:
    case HloOpcode::kMap:
    case HloOpcode::kSlice:
    case HloOpcode::kConstant:
    case HloOpcode::kTrace:
    case HloOpcode::kFusion:
    case HloOpcode::kRng:
    case HloOpcode::kRngGetAndUpdateState:
    case HloOpcode::kParameter:
    case HloOpcode::kGetTupleElement:
    case HloOpcode::kReducePrecision:
    case HloOpcode::kAllReduce:
    case HloOpcode::kAllToAll:
    case HloOpcode::kCollectivePermute:
    case HloOpcode::kInfeed:
    case HloOpcode::kOutfeed:
    case HloOpcode::kConvolution:
    case HloOpcode::kCustomCall:
    case HloOpcode::kReduceWindow:
    case HloOpcode::kSelectAndScatter:
    case HloOpcode::kPad:
    case HloOpcode::kDynamicSlice:
    case HloOpcode::kSort:
    case HloOpcode::kGather:
    case HloOpcode::kScatter:
    case HloOpcode::kIota:
    case HloOpcode::kDot:
    case HloOpcode::kDomain:
    case HloOpcode::kGetDimensionSize:
    case HloOpcode::kTriangularSolve:
    case HloOpcode::kCholesky:
      clone = CloneWithNewOperandsImpl(shape, new_operands, context);
      break;
    // Unary ops.
    case HloOpcode::kAbs:
    case HloOpcode::kRoundNearestAfz:
    case HloOpcode::kBitcast:
    case HloOpcode::kCeil:
    case HloOpcode::kClz:
    case HloOpcode::kCopy:
    case HloOpcode::kCopyStart:
    case HloOpcode::kCopyDone:
    case HloOpcode::kCos:
    case HloOpcode::kExp:
    case HloOpcode::kExpm1:
    case HloOpcode::kImag:
    case HloOpcode::kIsFinite:
    case HloOpcode::kFloor:
    case HloOpcode::kLog:
    case HloOpcode::kLog1p:
    case HloOpcode::kNot:
    case HloOpcode::kNegate:
    case HloOpcode::kPopulationCount:
    case HloOpcode::kReal:
    case HloOpcode::kRsqrt:
    case HloOpcode::kSign:
    case HloOpcode::kSin:
    case HloOpcode::kSqrt:
    case HloOpcode::kTanh:
      CHECK_EQ(new_operands.size(), 1);
      clone = CreateUnary(shape, opcode_, new_operands[0]);
      break;
    // Binary ops.
    case HloOpcode::kAdd:
    case HloOpcode::kAtan2:
    case HloOpcode::kComplex:
    case HloOpcode::kDivide:
    case HloOpcode::kMultiply:
    case HloOpcode::kSubtract:
    case HloOpcode::kMaximum:
    case HloOpcode::kMinimum:
    case HloOpcode::kPower:
    case HloOpcode::kRemainder:
    case HloOpcode::kAnd:
    case HloOpcode::kOr:
    case HloOpcode::kXor:
    case HloOpcode::kShiftLeft:
    case HloOpcode::kShiftRightArithmetic:
    case HloOpcode::kShiftRightLogical:
      CHECK_EQ(new_operands.size(), 2);
      clone = CreateBinary(shape, opcode_, new_operands[0], new_operands[1]);
      break;
    // Ternary ops.
    case HloOpcode::kClamp:
    case HloOpcode::kSelect:
    case HloOpcode::kTupleSelect:
      CHECK_EQ(new_operands.size(), 3);
      clone = CreateTernary(shape, opcode_, new_operands[0], new_operands[1],
                            new_operands[2]);
      break;
    // Other supported ops.
    case HloOpcode::kCall:
      clone = CreateCall(shape, new_operands, to_apply());
      break;
    case HloOpcode::kConvert:
      CHECK_EQ(new_operands.size(), 1);
      clone = CreateConvert(shape, new_operands[0]);
      break;
    case HloOpcode::kBitcastConvert:
      CHECK_EQ(new_operands.size(), 1);
      clone = CreateBitcastConvert(shape, new_operands[0]);
      break;
    case HloOpcode::kDynamicUpdateSlice:
      // The first two operands are passed individually and the remainder is
      // forwarded as a span, so at least two operands must be present before
      // indexing / calling subspan(2).
      CHECK_GE(new_operands.size(), 2);
      clone = CreateDynamicUpdateSlice(shape, new_operands[0], new_operands[1],
                                       new_operands.subspan(2));
      break;
    case HloOpcode::kTuple:
      // CreateTuple takes no shape argument, so the requested shape is
      // assigned to the clone afterwards.
      clone = CreateTuple(new_operands);
      *clone->mutable_shape() = shape;
      break;
    case HloOpcode::kWhile:
      CHECK_EQ(new_operands.size(), 1);
      clone =
          CreateWhile(shape, while_condition(), while_body(), new_operands[0]);
      break;
    case HloOpcode::kConditional:
      // Operand 0 is passed separately; the remaining operands are passed
      // alongside the branch computations, hence branch_count() + 1 in total.
      CHECK_EQ(new_operands.size(), branch_count() + 1);
      clone = CreateConditional(shape, new_operands[0],
                                absl::MakeSpan(branch_computations()),
                                new_operands.subspan(1));
      break;
    case HloOpcode::kAfterAll:
      // With no operands a token is created instead of an after-all.
      if (new_operands.empty()) {
        clone = CreateToken();
      } else {
        clone = CreateAfterAll(new_operands);
      }
      break;
    case HloOpcode::kAddDependency:
      CHECK_EQ(new_operands.size(), 2);
      clone = CreateAddDependency(new_operands[0], new_operands[1]);
      break;
    case HloOpcode::kReplicaId:
      CHECK_EQ(new_operands.size(), 0);
      clone = CreateReplicaId();
      break;
    case HloOpcode::kPartitionId:
      CHECK_EQ(new_operands.size(), 0);
      clone = CreatePartitionId();
      break;
  }

  // SetupDerivedInstruction will setup the precision_config_ field.
  SetupDerivedInstruction(clone.get());
  clone->set_parent(parent_);
  clone->set_outer_dimension_partitions(outer_dimension_partitions_);
  clone->set_raw_backend_config_string(backend_config_);

  if (context != nullptr) {
    context->MapInstruction(this, clone.get());
    // Computations whose parent is not the context's module are deep-cloned
    // through that module; computations already parented there are reused.
    clone->ReplaceCalledComputations([&](HloComputation* callee) {
      return callee->parent() != context->module()
                 ? context->module()->DeepCloneComputation(callee, context)
                 : callee;
    });
  }
  return clone;
}