in tensorflow/tensorflow/compiler/xla/service/elemental_ir_emitter.cc [2140:2390]
// Builds an element generator for `hlo`. The result is a callable that takes an
// index into `hlo`'s output shape and emits LLVM IR computing the element at
// that index.
//
// Dispatch by opcode:
//  * Elementwise unary/binary opcodes read the operand element(s) at the same
//    index and delegate to EmitUnaryOp / EmitBinaryOp.
//  * Index-remapping opcodes (reverse, broadcast, slice, bitcast, reshape,
//    copy, transpose) translate the target index into an operand index and
//    forward to the operand's generator.
//  * Select, clamp, concatenate, dynamic-slice, gather, dynamic-update-slice,
//    pad and dot delegate to the matching EmitElemental* helper.
//  * Iota and replica-id compute their value directly from the index / module
//    config without reading operands.
//  * Any other opcode yields a generator that returns an Unimplemented status
//    when it is invoked (not when this function is called).
//
// Lifetime: the returned generators capture `operand_to_generator` by
// reference, so that map must outlive every generator returned from here.
// Operands are looked up with `.at()`, so each operand a generator reads must
// have an entry in the map by the time the generator runs.
llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
    const HloInstruction* hlo,
    const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator) {
  switch (hlo->opcode()) {
    case HloOpcode::kAbs:
    case HloOpcode::kRoundNearestAfz:
    case HloOpcode::kCeil:
    case HloOpcode::kClz:
    case HloOpcode::kConvert:
    case HloOpcode::kBitcastConvert:
    case HloOpcode::kCos:
    case HloOpcode::kExp:
    case HloOpcode::kExpm1:
    case HloOpcode::kFloor:
    case HloOpcode::kImag:
    case HloOpcode::kIsFinite:
    case HloOpcode::kLog:
    case HloOpcode::kLog1p:
    case HloOpcode::kNegate:
    case HloOpcode::kNot:
    case HloOpcode::kPopulationCount:
    case HloOpcode::kReal:
    case HloOpcode::kRsqrt:
    case HloOpcode::kSign:
    case HloOpcode::kSin:
    case HloOpcode::kSqrt:
    case HloOpcode::kTanh:
      // Unary elementwise: same index into the single operand.
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
        TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
                            operand_to_generator.at(hlo->operand(0))(index));
        return EmitUnaryOp(hlo, operand_value);
      };
    case HloOpcode::kAdd:
    case HloOpcode::kAnd:
    case HloOpcode::kAtan2:
    case HloOpcode::kCompare:
    case HloOpcode::kComplex:
    case HloOpcode::kDivide:
    case HloOpcode::kMaximum:
    case HloOpcode::kMinimum:
    case HloOpcode::kMultiply:
    case HloOpcode::kOr:
    case HloOpcode::kXor:
    case HloOpcode::kPower:
    case HloOpcode::kRemainder:
    case HloOpcode::kShiftLeft:
    case HloOpcode::kShiftRightArithmetic:
    case HloOpcode::kShiftRightLogical:
    case HloOpcode::kSubtract:
      // Binary elementwise: the same index is used for both operands.
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
        const HloInstruction* lhs = hlo->operand(0);
        const HloInstruction* rhs = hlo->operand(1);
        TF_ASSIGN_OR_RETURN(llvm::Value * lhs_value,
                            operand_to_generator.at(lhs)(index));
        TF_ASSIGN_OR_RETURN(llvm::Value * rhs_value,
                            operand_to_generator.at(rhs)(index));
        return EmitBinaryOp(hlo, lhs_value, rhs_value);
      };
    case HloOpcode::kSelect:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
        return EmitElementalSelect(hlo, operand_to_generator, index);
      };
    case HloOpcode::kClamp:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
        return EmitElementalClamp(hlo, operand_to_generator, index);
      };
    case HloOpcode::kReducePrecision:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
        TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
                            operand_to_generator.at(hlo->operand(0))(index));
        return EmitReducePrecision(hlo, operand_value);
      };
    case HloOpcode::kConcatenate:
      // Take the index by const reference like every other case; the previous
      // by-value parameter copied the Index on every generated element.
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& target_index) -> StatusOr<llvm::Value*> {
        return EmitElementalConcatenate(hlo, operand_to_generator,
                                        target_index);
      };
    case HloOpcode::kReverse:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& target_index) -> StatusOr<llvm::Value*> {
        const HloInstruction* operand = hlo->operand(0);
        // Mirror each reversed dimension: source[d] = (size(d) - 1) - target[d].
        std::vector<llvm::Value*> source_multi_index = target_index.multidim();
        for (int64 dim : hlo->dimensions()) {
          source_multi_index[dim] = Sub(target_index.GetConstantWithIndexType(
                                            hlo->shape().dimensions(dim) - 1),
                                        target_index[dim]);
        }
        llvm_ir::IrArray::Index source_index(
            source_multi_index, operand->shape(), target_index.GetType());
        return operand_to_generator.at(operand)(source_index);
      };
    case HloOpcode::kBroadcast:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& target_index) -> StatusOr<llvm::Value*> {
        const HloInstruction* operand = hlo->operand(0);
        // The `dimensions` member of the broadcast instruction maps from
        // input dimensions to output dimensions.
        return operand_to_generator.at(operand)(
            target_index.SourceIndexOfBroadcast(hlo->shape(), operand->shape(),
                                                hlo->dimensions(), b_));
      };
    case HloOpcode::kIota:
      // Iota reads no operands; the element value is derived from the index.
      return [this, hlo](
                 const IrArray::Index& target_index) -> StatusOr<llvm::Value*> {
        auto* iota = Cast<HloIotaInstruction>(hlo);
        PrimitiveType element_type = iota->shape().element_type();
        // For rank > 1, project the target index onto a rank-1 shape holding
        // only the iota dimension, so its linear index is the position along
        // that dimension.
        IrArray::Index elem_index =
            iota->shape().rank() > 1
                ? target_index.SourceIndexOfBroadcast(
                      iota->shape(),
                      ShapeUtil::MakeShapeWithDescendingLayout(
                          element_type,
                          {iota->shape().dimensions(iota->iota_dimension())}),
                      {iota->iota_dimension()}, b_)
                : target_index;
        llvm::Value* elem_index_linear = elem_index.linear();
        if (elem_index_linear == nullptr) {
          std::vector<int64> iota_bound = {
              iota->shape().dimensions(iota->iota_dimension())};
          elem_index_linear = elem_index.Linearize(iota_bound, b_);
        }
        // For complex element types the iota value is emitted in the
        // component (real) type and composed into a complex value below.
        Shape component_shape =
            ShapeUtil::ElementIsComplex(iota->shape())
                ? ShapeUtil::ComplexComponentShape(iota->shape())
                : iota->shape();
        PrimitiveType component_element_type = component_shape.element_type();
        llvm::Value* iota_result;
        if (primitive_util::IsIntegralType(component_element_type)) {
          iota_result = b_->CreateIntCast(
              elem_index_linear,
              llvm_ir::PrimitiveTypeToIrType(component_element_type, module_),
              /*isSigned=*/false);
        } else {
          TF_RET_CHECK(
              primitive_util::IsFloatingPointType(component_element_type))
              << component_element_type;
          // BF16 is produced by converting through F32.
          llvm::Type* float_ir_type;
          if (component_element_type == BF16) {
            float_ir_type = llvm_ir::PrimitiveTypeToIrType(F32, module_);
          } else {
            float_ir_type =
                llvm_ir::PrimitiveTypeToIrType(component_element_type, module_);
          }
          llvm::Value* float_val =
              b_->CreateUIToFP(elem_index_linear, float_ir_type);
          if (component_element_type == BF16) {
            TF_ASSIGN_OR_RETURN(iota_result, EmitF32ToBF16(float_val, b_));
          } else {
            iota_result = float_val;
          }
        }
        if (ShapeUtil::ElementIsComplex(iota->shape())) {
          // NOTE(review): a nullptr imaginary part is passed here; presumably
          // EmitComposeComplex treats it as zero — confirm against its body.
          return EmitComposeComplex(iota, iota_result, nullptr);
        } else {
          return iota_result;
        }
      };
    case HloOpcode::kSlice:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
        IrArray::Index sliced_index = index.SourceIndexOfSlice(
            /*operand_shape=*/hlo->operand(0)->shape(),
            /*starts=*/hlo->slice_starts(),
            /*strides=*/hlo->slice_strides(), /*builder=*/b_);
        return operand_to_generator.at(hlo->operand(0))(sliced_index);
      };
    case HloOpcode::kDynamicSlice:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
        return EmitElementalDynamicSlice(hlo, operand_to_generator, index);
      };
    case HloOpcode::kGather:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
        return EmitElementalGather(hlo, operand_to_generator, index);
      };
    case HloOpcode::kDynamicUpdateSlice:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
        return EmitElementalDynamicUpdateSlice(hlo, operand_to_generator,
                                               index);
      };
    case HloOpcode::kBitcast:
      // Element counts are checked eagerly, when the generator is built.
      CHECK_EQ(ShapeUtil::ElementsIn(hlo->shape()),
               ShapeUtil::ElementsIn(hlo->operand(0)->shape()));
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
        const HloInstruction* operand = hlo->operand(0);
        return operand_to_generator.at(operand)(
            index.SourceIndexOfBitcast(hlo->shape(), operand->shape(), b_));
      };
    case HloOpcode::kReshape:
      // Element counts are checked eagerly, when the generator is built.
      CHECK_EQ(ShapeUtil::ElementsIn(hlo->shape()),
               ShapeUtil::ElementsIn(hlo->operand(0)->shape()));
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
        const HloInstruction* operand = hlo->operand(0);
        return operand_to_generator.at(operand)(
            index.SourceIndexOfReshape(hlo->shape(), operand->shape(), b_));
      };
    case HloOpcode::kCopy:
      return [hlo, &operand_to_generator](
                 const IrArray::Index& target_index) -> StatusOr<llvm::Value*> {
        // Rebuild the index from its multidimensional components against the
        // operand's shape; no linear index is carried over (presumably because
        // a copy may change the layout — confirm).
        IrArray::Index source_index(target_index.multidim(),
                                    hlo->operand(0)->shape(),
                                    target_index.GetType());
        return operand_to_generator.at(hlo->operand(0))(source_index);
      };
    case HloOpcode::kTranspose:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& target_index) -> StatusOr<llvm::Value*> {
        return operand_to_generator.at(hlo->operand(0))(
            target_index.SourceIndexOfTranspose(
                hlo->shape(), hlo->operand(0)->shape(), hlo->dimensions(), b_));
      };
    case HloOpcode::kPad:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& padded_index) -> StatusOr<llvm::Value*> {
        return EmitElementalPad(hlo, operand_to_generator, padded_index);
      };
    case HloOpcode::kDot:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& dot_result_index)
                 -> StatusOr<llvm::Value*> {
        return EmitElementalDot(hlo, operand_to_generator, dot_result_index);
      };
    case HloOpcode::kReplicaId:
      // Only the single-replica case is supported: the id is the constant 0.
      return [this, hlo](const IrArray::Index&) -> StatusOr<llvm::Value*> {
        if (hlo_module_config_.replica_count() != 1) {
          return Unimplemented("Replication is not implemented on CPU/GPU.");
        }
        llvm::Type* type = llvm_ir::PrimitiveTypeToIrType(
            hlo->shape().element_type(), module_);
        return llvm::ConstantInt::getNullValue(type);
      };
    default:
      // Defer the failure to generation time so callers that never invoke
      // this generator are unaffected.
      return [hlo](const IrArray::Index&) -> StatusOr<llvm::Value*> {
        return Unimplemented("Unhandled opcode for elemental IR emission: %s",
                             HloOpcodeString(hlo->opcode()));
      };
  }
}