in tensorflow/compiler/mlir/xla/hlo_function_importer.cc [450:1426]
// Imports a single HLO `instruction` as an MLIR operation created through
// `func_builder`, dispatching on the instruction's opcode.
//
// Args:
//   instruction:  The HLO instruction to import.
//   operands:     MLIR values used as the operands of the created operation
//                 (presumably the already-imported operands of `instruction`;
//                 verify against the caller).
//   func_builder: Builder used to create the operation(s) and any helper ops
//                 (constants, tuples) needed by the import.
//   mode:         When kConvertToStatic, the result type is derived from the
//                 static version of the instruction's shape.
//
// Returns:
//   The created operation; nullptr for kParameter; or an error status when a
//   conversion helper fails or the instruction form is not supported. Opcodes
//   without a dedicated case are emitted as a generic "mhlo.unknown" op.
//
// If the instruction has sharding, the serialized sharding proto is collected
// into `attributes`. Only the cases that forward `attributes` to the created
// op attach it; cases using builders without an attribute list (e.g. kIota,
// kSlice, kPad, kCall) do not.
// NOTE(review): confirm whether dropping sharding there is intentional.
StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
    const HloInstruction* instruction,
    const llvm::SmallVectorImpl<mlir::Value>& operands,
    mlir::OpBuilder* func_builder, DynamicShapeHandlingMode mode) {
  const Shape& instruction_shape = instruction->shape();
  const Shape& shape = mode == DynamicShapeHandlingMode::kConvertToStatic
                           ? xla::ShapeUtil::MakeStaticShape(instruction_shape)
                           : instruction_shape;
  TF_ASSIGN_OR_RETURN(auto result_type,
                      ConvertShapeToType<RankedTensorType>(shape, *builder_));
  mlir::Location loc = GenerateInstructionLocation(instruction, func_builder);

  // Attributes shared by most cases below; seeded with the sharding, if any.
  llvm::SmallVector<NamedAttribute, 10> attributes;
  if (instruction->has_sharding()) {
    attributes.push_back(builder_->getNamedAttr(
        kShardingAttr,
        builder_->getStringAttr(
            instruction->sharding().ToProto().SerializeAsString())));
  }

  switch (instruction->opcode()) {
    case HloOpcode::kParameter: {
      // No operation is created for parameters.
      return nullptr;
    }
    case HloOpcode::kConstant: {
      const Literal& literal = instruction->literal();
      auto attr = CreateDenseElementsAttrFromLiteral(literal, *builder_);
      if (!attr.ok()) return attr.status();
      mlir::Operation* new_operation =
          func_builder->create<mlir::mhlo::ConstOp>(loc, attr.ValueOrDie());
      // The ConstOp builder used here takes no attribute list, so the
      // collected attributes are attached after creation.
      for (const auto& named_attr : attributes) {
        new_operation->setAttr(named_attr.getName(), named_attr.getValue());
      }
      return new_operation;
    }
    case HloOpcode::kIota: {
      return func_builder
          ->create<mlir::mhlo::IotaOp>(
              loc, result_type,
              func_builder->getI64IntegerAttr(
                  Cast<HloIotaInstruction>(instruction)->iota_dimension()))
          .getOperation();
    }
    case HloOpcode::kBroadcast: {
      // Note that the HLO broadcast is more powerful than the XLA broadcast
      // op. BroadcastInDim offers a superset of the HLO op's functionality.
      attributes.push_back(
          builder_->getNamedAttr("broadcast_dimensions",
                                 ConvertDimensions(instruction->dimensions())));
      return func_builder
          ->create<mlir::mhlo::BroadcastInDimOp>(loc, result_type, operands,
                                                 attributes)
          .getOperation();
    }
    case HloOpcode::kBatchNormGrad:
    case HloOpcode::kBatchNormInference:
    case HloOpcode::kBatchNormTraining: {
      // The three batch-norm variants share the epsilon/feature_index
      // attributes and differ only in the created op and result tupling.
      attributes.push_back(builder_->getNamedAttr(
          "epsilon", builder_->getF32FloatAttr(instruction->epsilon())));
      attributes.push_back(builder_->getNamedAttr(
          "feature_index",
          builder_->getI64IntegerAttr(instruction->feature_index())));
      if (instruction->opcode() == HloOpcode::kBatchNormGrad) {
        // Flatten the return type if they are tuple-typed.
        llvm::SmallVector<Type> flattened_ret_types;
        FlattenTupleType(result_type, flattened_ret_types);
        auto op = func_builder
                      ->create<mlir::mhlo::BatchNormGradOp>(
                          loc, flattened_ret_types, operands, attributes)
                      .getOperation();
        return CreateTupleFromOpResults(func_builder, loc, op, result_type);
      } else if (instruction->opcode() == HloOpcode::kBatchNormInference) {
        return func_builder
            ->create<mlir::mhlo::BatchNormInferenceOp>(loc, result_type,
                                                       operands, attributes)
            .getOperation();
      } else {
        assert(instruction->opcode() == HloOpcode::kBatchNormTraining);
        // Flatten the return type if they are tuple-typed.
        llvm::SmallVector<Type> flattened_ret_types;
        FlattenTupleType(result_type, flattened_ret_types);
        auto op = func_builder
                      ->create<mlir::mhlo::BatchNormTrainingOp>(
                          loc, flattened_ret_types, operands, attributes)
                      .getOperation();
        return CreateTupleFromOpResults(func_builder, loc, op, result_type);
      }
    }
    case HloOpcode::kDot: {
      attributes.push_back(builder_->getNamedAttr(
          "precision_config",
          ConvertPrecisionConfig(&instruction->precision_config(), builder_)));
      // Consider consolidating DotOps together.
      if (DotIsDefault(instruction)) {
        return func_builder
            ->create<mlir::mhlo::DotOp>(loc, result_type, operands, attributes)
            .getOperation();
      }
      attributes.push_back(builder_->getNamedAttr(
          "dot_dimension_numbers",
          ConvertDotDimensionNumbers(instruction->dot_dimension_numbers(),
                                     builder_)));
      return func_builder
          ->create<mlir::mhlo::DotGeneralOp>(loc, result_type, operands,
                                             attributes)
          .getOperation();
    }
    case HloOpcode::kCall: {
      TF_ASSIGN_OR_RETURN(FuncOp function,
                          ImportAsFunc(*instruction->to_apply()));
      mlir::Operation* new_operation =
          func_builder->create<mlir::CallOp>(loc, function, operands);
      return new_operation;
    }
    case HloOpcode::kCollectivePermute: {
      attributes.push_back(ConvertSourceTargetPairs(
          instruction->source_target_pairs(), builder_));
      return func_builder
          ->create<mlir::mhlo::CollectivePermuteOp>(loc, result_type, operands,
                                                    attributes)
          .getOperation();
    }
    case HloOpcode::kCustomCall: {
      auto custom_call = Cast<HloCustomCallInstruction>(instruction);
      const auto& called_computations = custom_call->called_computations();
      if (!called_computations.empty()) {
        // Import every callee as a function and reference it by symbol.
        llvm::SmallVector<mlir::Attribute> callees;
        callees.reserve(called_computations.size());
        for (HloComputation* callee : called_computations) {
          TF_ASSIGN_OR_RETURN(FuncOp function, ImportAsFunc(*callee));
          callees.push_back(mlir::FlatSymbolRefAttr::get(builder_->getContext(),
                                                         function.getName()));
        }
        attributes.push_back(builder_->getNamedAttr(
            "called_computations",
            mlir::ArrayAttr::get(builder_->getContext(), callees)));
      }
      if (custom_call->layout_constrained()) {
        TF_ASSIGN_OR_RETURN(
            mlir::ArrayAttr operand_layouts,
            ExtractLayoutsFromShapes(custom_call->operand_shapes_with_layout(),
                                     builder_));
        attributes.push_back(
            builder_->getNamedAttr("operand_layouts", operand_layouts));
        mlir::ArrayAttr result_layouts;
        if (custom_call->shape().IsTuple()) {
          TF_ASSIGN_OR_RETURN(
              result_layouts,
              ExtractLayoutsFromTuple(custom_call->shape(), builder_));
        } else {
          TF_ASSIGN_OR_RETURN(
              result_layouts,
              ExtractLayoutsFromShapes({custom_call->shape()}, builder_));
        }
        attributes.push_back(
            builder_->getNamedAttr("result_layouts", result_layouts));
      }
      TF_ASSIGN_OR_RETURN(
          auto mlir_api_version,
          ConvertCustomCallApiVersion(custom_call->api_version()));
      attributes.push_back(builder_->getNamedAttr(
          "call_target_name",
          builder_->getStringAttr(custom_call->custom_call_target())));
      attributes.push_back(builder_->getNamedAttr(
          "has_side_effect",
          builder_->getBoolAttr(custom_call->custom_call_has_side_effect())));
      attributes.push_back(builder_->getNamedAttr(
          "backend_config",
          builder_->getStringAttr(custom_call->raw_backend_config_string())));
      attributes.push_back(builder_->getNamedAttr(
          "api_version", mlir::mhlo::CustomCallApiVersionAttr::get(
                             builder_->getContext(), mlir_api_version)));
      return func_builder
          ->create<mlir::mhlo::CustomCallOp>(loc, result_type, operands,
                                             attributes)
          .getOperation();
    }
    case HloOpcode::kCompare: {
      auto compare = Cast<HloCompareInstruction>(instruction);
      attributes.push_back(ConvertComparisonDirection(compare->direction()));
      // The comparison type is only recorded when it differs from the default
      // for the operand element type.
      auto default_type = Comparison::DefaultComparisonType(
          compare->operand(0)->shape().element_type());
      if (compare->type() != default_type)
        attributes.push_back(ConvertComparisonType(compare->type()));
      return func_builder
          ->create<mlir::mhlo::CompareOp>(loc, result_type, operands,
                                          attributes)
          .getOperation();
    }
    case HloOpcode::kCholesky: {
      attributes.push_back(builder_->getNamedAttr(
          "lower",
          builder_->getBoolAttr(instruction->cholesky_options().lower())));
      return func_builder
          ->create<mlir::mhlo::CholeskyOp>(loc, result_type, operands,
                                           attributes)
          .getOperation();
    }
    case HloOpcode::kGather: {
      auto gather_instruction = Cast<HloGatherInstruction>(instruction);
      attributes.push_back(builder_->getNamedAttr(
          "dimension_numbers",
          ConvertGatherDimensionNumbers(
              gather_instruction->gather_dimension_numbers(), builder_)));
      std::vector<int64_t> slice_sizes(
          gather_instruction->gather_slice_sizes().begin(),
          gather_instruction->gather_slice_sizes().end());
      attributes.push_back(
          builder_->getNamedAttr("slice_sizes", Convert(slice_sizes)));
      attributes.push_back(builder_->getNamedAttr(
          "indices_are_sorted",
          builder_->getBoolAttr(gather_instruction->indices_are_sorted())));
      return func_builder
          ->create<mlir::mhlo::GatherOp>(loc, result_type, operands, attributes)
          .getOperation();
    }
    case HloOpcode::kDynamicSlice: {
      std::vector<int64_t> slice_sizes(
          instruction->dynamic_slice_sizes().begin(),
          instruction->dynamic_slice_sizes().end());
      // operands[0] is passed separately; the remaining operands are passed
      // as a group.
      return func_builder
          ->create<mlir::mhlo::DynamicSliceOp>(
              loc, result_type, operands[0],
              makeArrayRef(operands).drop_front(), Convert(slice_sizes))
          .getOperation();
    }
    case HloOpcode::kDynamicUpdateSlice: {
      // operands[0] and operands[1] are passed separately; the rest are
      // passed as a group.
      return func_builder
          ->create<mlir::mhlo::DynamicUpdateSliceOp>(
              loc, result_type, operands[0], operands[1],
              llvm::ArrayRef<Value>(operands.begin() + 2, operands.end()))
          .getOperation();
    }
    case HloOpcode::kInfeed: {
      // Report unsupported input as an error status instead of hitting
      // llvm_unreachable.
      if (IsNestedTupleInData(result_type)) {
        return tensorflow::errors::Unimplemented(
            "Importing xla::kInfeed with nested tuple shape not supported");
      }
      attributes.push_back(builder_->getNamedAttr(
          "infeed_config",
          mlir::StringAttr::get(builder_->getContext(),
                                instruction->infeed_config())));
      llvm::SmallVector<mlir::Attribute> flattened_attr;
      TF_RETURN_IF_ERROR(
          ConvertShapeToMlirLayout(instruction->shape(), flattened_attr));
      attributes.push_back(builder_->getNamedAttr(
          "layout", builder_->getArrayAttr(makeArrayRef(flattened_attr))));
      // Flatten the return-type if they are tuple-typed.
      llvm::SmallVector<Type> flattened_ret_types;
      FlattenTupleType(result_type, flattened_ret_types);
      auto op = func_builder->create<mlir::mhlo::InfeedOp>(
          loc, flattened_ret_types, operands, attributes);
      return CreateTupleFromOpResults(func_builder, loc, op.getOperation(),
                                      result_type);
    }
    case HloOpcode::kOutfeed: {
      attributes.push_back(builder_->getNamedAttr(
          "outfeed_config",
          mlir::StringAttr::get(builder_->getContext(),
                                instruction->outfeed_config())));
      assert(operands.size() == 2 && "Expected 2 operands for HLO Outfeed");
      // In case operands[0] is a tuple, flatten it.
      llvm::SmallVector<Value> flattened_operands;
      FlattenTupleValue(func_builder, loc, operands[0], flattened_operands);
      flattened_operands.push_back(operands[1]);
      return func_builder
          ->create<mlir::mhlo::OutfeedOp>(loc, result_type, flattened_operands,
                                          attributes)
          .getOperation();
    }
    case HloOpcode::kPad: {
      const auto& padding_config = instruction->padding_config();
      llvm::SmallVector<int64_t, 4> edge_padding_low;
      llvm::SmallVector<int64_t, 4> edge_padding_high;
      llvm::SmallVector<int64_t, 4> interior_padding;
      edge_padding_low.reserve(padding_config.dimensions_size());
      edge_padding_high.reserve(padding_config.dimensions_size());
      interior_padding.reserve(padding_config.dimensions_size());
      for (const auto& dimension : padding_config.dimensions()) {
        edge_padding_low.push_back(dimension.edge_padding_low());
        edge_padding_high.push_back(dimension.edge_padding_high());
        interior_padding.push_back(dimension.interior_padding());
      }
      return func_builder
          ->create<mlir::mhlo::PadOp>(loc, result_type, operands[0],
                                      operands[1], Convert(edge_padding_low),
                                      Convert(edge_padding_high),
                                      Convert(interior_padding))
          .getOperation();
    }
    case HloOpcode::kScatter: {
      auto scatter = Cast<HloScatterInstruction>(instruction);
      attributes.push_back(builder_->getNamedAttr(
          "scatter_dimension_numbers",
          ConvertScatterDimensionNumbers(scatter->scatter_dimension_numbers(),
                                         builder_)));
      attributes.push_back(builder_->getNamedAttr(
          "indices_are_sorted",
          builder_->getBoolAttr(scatter->indices_are_sorted())));
      attributes.push_back(builder_->getNamedAttr(
          "unique_indices", builder_->getBoolAttr(scatter->unique_indices())));
      auto scatter_op = func_builder->create<mlir::mhlo::ScatterOp>(
          loc, result_type, operands, attributes);
      TF_RETURN_IF_ERROR(ImportAsRegion(*scatter->to_apply(),
                                        &scatter_op.update_computation()));
      return scatter_op.getOperation();
    }
    case HloOpcode::kSelectAndScatter: {
      auto select_scatter = Cast<HloSelectAndScatterInstruction>(instruction);
      llvm::SmallVector<int64_t, 4> window_strides, window_dimensions;
      // Padding is stored as consecutive (low, high) pairs, one per window
      // dimension.
      llvm::SmallVector<int64_t, 8> padding;
      for (const auto& dim : select_scatter->window().dimensions()) {
        window_strides.push_back(dim.stride());
        window_dimensions.push_back(dim.size());
        padding.push_back(dim.padding_low());
        padding.push_back(dim.padding_high());
      }
      attributes.push_back(
          builder_->getNamedAttr("window_strides", Convert(window_strides)));
      attributes.push_back(builder_->getNamedAttr("window_dimensions",
                                                  Convert(window_dimensions)));
      attributes.push_back(ConvertPadding(padding));
      auto select_scatter_op =
          func_builder->create<mlir::mhlo::SelectAndScatterOp>(
              loc, result_type, operands, attributes);
      TF_RETURN_IF_ERROR(ImportAsRegion(*select_scatter->select(),
                                        &select_scatter_op.select()));
      TF_RETURN_IF_ERROR(ImportAsRegion(*select_scatter->scatter(),
                                        &select_scatter_op.scatter()));
      return select_scatter_op.getOperation();
    }
    case HloOpcode::kSetDimensionSize: {
      attributes.push_back(builder_->getNamedAttr(
          "dimension", builder_->getI64IntegerAttr(instruction->dimension())));
      return func_builder
          ->create<mlir::mhlo::SetDimensionSizeOp>(loc, result_type, operands,
                                                   attributes)
          .getOperation();
    }
    case HloOpcode::kSlice: {
      return func_builder
          ->create<mlir::mhlo::SliceOp>(
              loc, result_type, operands[0],
              ConvertDimensions(instruction->slice_starts()),
              ConvertDimensions(instruction->slice_limits()),
              ConvertDimensions(instruction->slice_strides()))
          .getOperation();
    }
    case HloOpcode::kSort: {
      auto sort_instruction = Cast<HloSortInstruction>(instruction);
      // A tuple result type is expanded into one op result per element.
      llvm::SmallVector<Type, 4> return_types = {result_type};
      if (mlir::TupleType tuple_ty = result_type.dyn_cast<mlir::TupleType>()) {
        return_types = llvm::to_vector<6>(tuple_ty.getTypes());
      }
      auto sort_op = func_builder->create<mlir::mhlo::SortOp>(
          loc, return_types, operands,
          builder_->getI64IntegerAttr(sort_instruction->sort_dimension()),
          builder_->getBoolAttr(sort_instruction->is_stable()));
      TF_RETURN_IF_ERROR(
          ImportAsRegion(*sort_instruction->to_apply(), &sort_op.comparator()));
      // Check if the output needs to be tupled.
      if (return_types.size() == 1 && return_types.front() == result_type) {
        return sort_op.getOperation();
      }
      return func_builder
          ->create<mlir::mhlo::TupleOp>(loc, result_type, sort_op.getResults())
          .getOperation();
    }
    case HloOpcode::kConditional: {
      llvm::SmallVector<Type, 4> rets;
      // Flatten the tuple-typed operands.
      llvm::SmallVector<Value> flattened_operands;
      for (auto& operand : operands)
        FlattenTupleValue(func_builder, loc, operand, flattened_operands);
      // If/Case Op has a single operand; we collect the other operands to
      // replace the corresponding block arguments.
      llvm::ArrayRef<Value> implicit_operands(flattened_operands.begin() + 1,
                                              flattened_operands.end());
      mlir::Type pred_or_index_type =
          operands[0].getType().cast<mlir::TensorType>().getElementType();
      // It is a predicated conditional if first argument is a boolean and
      // should be mapped to If op.
      if (pred_or_index_type.isInteger(1)) {
        TF_RETURN_IF_ERROR(GetMlirTypes(
            {instruction->true_computation()->root_instruction()}, &rets));
        // Flatten the return-type.
        llvm::SmallVector<Type> flattened_ret_types;
        assert(rets.size() == 1);
        FlattenTupleType(rets[0], flattened_ret_types);
        auto op = func_builder->create<mlir::mhlo::IfOp>(
            loc, flattened_ret_types, flattened_operands[0], attributes);
        TF_RETURN_IF_ERROR(ImportAsRegion(*instruction->true_computation(),
                                          &op.true_branch(),
                                          /*flatten_region_arg_tuple=*/true));
        TF_RETURN_IF_ERROR(ImportAsRegion(*instruction->false_computation(),
                                          &op.false_branch(),
                                          /*flatten_region_arg_tuple=*/true));
        // Replace the uses of block-arguments of the IfOp with the
        // implicit_operands.
        ReplaceBlockArgumentsWithImplicitOperands(op.getOperation(),
                                                  implicit_operands);
        return CreateTupleFromOpResults(func_builder, loc, op.getOperation(),
                                        rets[0]);
      }
      // Otherwise, it is an indexed conditional and should be mapped to Case
      // op.
      TF_RETURN_IF_ERROR(GetMlirTypes(
          {instruction->branch_computation(0)->root_instruction()}, &rets));
      // Flatten the return-type.
      llvm::SmallVector<Type> flattened_ret_types;
      assert(rets.size() == 1);
      FlattenTupleType(rets[0], flattened_ret_types);
      int num_branches = instruction->branch_count();
      auto op = func_builder->create<mlir::mhlo::CaseOp>(
          loc, flattened_ret_types, flattened_operands[0], attributes,
          num_branches);
      for (const auto& index_and_computation :
           llvm::enumerate(instruction->branch_computations())) {
        auto index = index_and_computation.index();
        HloComputation* computation = index_and_computation.value();
        TF_RETURN_IF_ERROR(ImportAsRegion(*computation, &op.branches()[index],
                                          /*flatten_region_arg_tuple=*/true));
      }
      // Replace the uses of block-arguments of the CaseOp with the
      // implicit_operands.
      ReplaceBlockArgumentsWithImplicitOperands(op.getOperation(),
                                                implicit_operands);
      return CreateTupleFromOpResults(func_builder, loc, op.getOperation(),
                                      rets[0]);
    }
    case HloOpcode::kConcatenate: {
      // TODO(b/132057942): Support taking an uint64_t instead of an
      // IntegerAttr for concatenate dimension.
      return func_builder
          ->create<mlir::mhlo::ConcatenateOp>(
              loc, result_type, operands,
              builder_->getI64IntegerAttr(instruction->concatenate_dimension()))
          .getOperation();
    }
    case HloOpcode::kAllGather: {
      auto all_gather = Cast<HloAllGatherInstruction>(instruction);
      attributes.push_back(builder_->getNamedAttr(
          "all_gather_dim",
          builder_->getI64IntegerAttr(all_gather->all_gather_dimension())));
      attributes.push_back(
          ConvertReplicaGroups(all_gather->replica_groups(), builder_));
      if (all_gather->channel_id().has_value())
        attributes.push_back(
            ConvertChannelHandle(all_gather->channel_id().value()));
      return func_builder
          ->create<mlir::mhlo::AllGatherOp>(loc, result_type, operands,
                                            attributes)
          .getOperation();
    }
    case HloOpcode::kAllReduce: {
      auto all_reduce = Cast<HloAllReduceInstruction>(instruction);
      attributes.push_back(
          ConvertReplicaGroups(all_reduce->replica_groups(), builder_));
      if (all_reduce->channel_id().has_value())
        attributes.push_back(
            ConvertChannelHandle(all_reduce->channel_id().value()));
      auto all_reduce_op = func_builder->create<mlir::mhlo::AllReduceOp>(
          loc, result_type, operands, attributes);
      TF_RETURN_IF_ERROR(ImportAsRegion(*all_reduce->to_apply(),
                                        &all_reduce_op.computation()));
      return all_reduce_op.getOperation();
    }
    case HloOpcode::kAllToAll: {
      // TODO(b/207152612): all-to-all HLO can either have pre-split operands
      // (and returns a tuple) or a single operand that is split across
      // `split_dimension` into the number of replicas in a group. Only the
      // latter case (array all-to-all) is supported in importer right now and
      // the former (tuple all-to-all) is not supported yet.
      auto all_to_all = Cast<HloAllToAllInstruction>(instruction);
      if (all_to_all->shape().IsTuple())
        return tensorflow::errors::Unimplemented(
            "Importing tuple all-to-all HLO is not supported yet");
      // Check invariants of array all-to-all. This is a sanity check and is
      // verified by the HLO verifier.
      if (!all_to_all->split_dimension().has_value() || operands.size() != 1 ||
          all_to_all->replica_groups().empty())
        return tensorflow::errors::InvalidArgument(
            "Array all-to-all should have a split dimension, one operand and "
            "non-empty replica groups");
      auto replica_groups_attr =
          ConvertReplicaGroups(all_to_all->replica_groups(), builder_)
              .getValue()
              .cast<DenseIntElementsAttr>();
      // The concat dimension is set equal to the split dimension, and the
      // split count to the size of the first replica group.
      uint64_t split_dim = all_to_all->split_dimension().value();
      uint64_t concat_dim = split_dim;
      uint64_t split_count = all_to_all->replica_groups()[0].replica_ids_size();
      return func_builder
          ->create<mlir::mhlo::AllToAllOp>(loc, result_type, operands[0],
                                           split_dim, concat_dim, split_count,
                                           replica_groups_attr)
          .getOperation();
    }
    case HloOpcode::kReduce: {
      // Operands in the first half are reduction inputs and the remaining
      // operands are corresponding initial values.
      size_t num_inputs = operands.size() / 2;
      llvm::SmallVector<Type, 4> return_types = {result_type};
      if (mlir::TupleType tuple_ty = result_type.dyn_cast<mlir::TupleType>()) {
        return_types = llvm::to_vector<6>(tuple_ty.getTypes());
      }
      auto reduce = func_builder->create<mlir::mhlo::ReduceOp>(
          loc, return_types,
          llvm::makeArrayRef(operands).take_front(num_inputs),
          llvm::makeArrayRef(operands).drop_front(num_inputs),
          ConvertDimensions(instruction->dimensions()));
      TF_RETURN_IF_ERROR(
          ImportAsRegion(*instruction->to_apply(), &reduce.body()));
      // Check if the output needs to be tupled.
      if (return_types.size() == 1 && return_types.front() == result_type) {
        return reduce.getOperation();
      }
      return func_builder
          ->create<mlir::mhlo::TupleOp>(loc, result_type, reduce.getResults())
          .getOperation();
    }
    case HloOpcode::kReverse: {
      return func_builder
          ->create<mlir::mhlo::ReverseOp>(
              loc, result_type, operands[0],
              ConvertDimensions(instruction->dimensions()))
          .getOperation();
    }
    case HloOpcode::kRng: {
      // The result shape is materialized as a constant and passed as an
      // extra operand of the RNG op.
      auto shape_operand = func_builder->create<mlir::mhlo::ConstOp>(
          loc, Convert(result_type.cast<RankedTensorType>().getShape()));
      switch (instruction->random_distribution()) {
        case xla::RNG_UNIFORM:
          return func_builder
              ->create<mlir::mhlo::RngUniformOp>(loc, result_type, operands[0],
                                                 operands[1], shape_operand)
              .getOperation();
        case xla::RNG_NORMAL:
          return func_builder
              ->create<mlir::mhlo::RngNormalOp>(loc, result_type, operands[0],
                                                operands[1], shape_operand)
              .getOperation();
        default:
          return tensorflow::errors::InvalidArgument(absl::StrCat(
              "Unsupported distribution: ",
              RandomDistributionToString(instruction->random_distribution())));
      }
    }
    case HloOpcode::kRngBitGenerator: {
      auto rng_op = Cast<HloRngBitGeneratorInstruction>(instruction);
      // Flatten the return type if they are tuple-typed.
      llvm::SmallVector<Type> flattened_ret_types;
      FlattenTupleType(result_type, flattened_ret_types);
      auto op = func_builder->create<mlir::mhlo::RngBitGeneratorOp>(
          loc, flattened_ret_types,
          func_builder->getI32IntegerAttr(rng_op->algorithm()), operands[0]);
      return CreateTupleFromOpResults(func_builder, loc, op.getOperation(),
                                      result_type);
    }
    case HloOpcode::kWhile: {
      // The WhileOp is created on the flattened form of operands[0]; its
      // results are re-tupled to the type of operands[0].
      llvm::SmallVector<Value> flattened_operands;
      llvm::SmallVector<Type> flattened_operand_types;
      FlattenTupleType(operands[0].getType(), flattened_operand_types);
      FlattenTupleValue(func_builder, loc, operands[0], flattened_operands);
      auto op = func_builder->create<mlir::mhlo::WhileOp>(
          loc, flattened_operand_types, flattened_operands);
      TF_RETURN_IF_ERROR(ImportAsRegion(*instruction->while_condition(),
                                        &op.cond(),
                                        /*flatten_region_arg_tuple=*/true));
      TF_RETURN_IF_ERROR(ImportAsRegion(*instruction->while_body(), &op.body(),
                                        /*flatten_region_arg_tuple=*/true));
      return CreateTupleFromOpResults(func_builder, loc, op.getOperation(),
                                      operands[0].getType());
    }
    case HloOpcode::kGetTupleElement: {
      attributes.push_back(builder_->getNamedAttr(
          "index", builder_->getIntegerAttr(builder_->getIntegerType(32),
                                            instruction->tuple_index())));
      return func_builder
          ->create<mlir::mhlo::GetTupleElementOp>(loc, result_type, operands,
                                                  attributes)
          .getOperation();
    }
    case HloOpcode::kGetDimensionSize: {
      attributes.push_back(builder_->getNamedAttr(
          "dimension", builder_->getI64IntegerAttr(instruction->dimension())));
      return func_builder
          ->create<mlir::mhlo::GetDimensionSizeOp>(loc, result_type, operands,
                                                   attributes)
          .getOperation();
    }
    case HloOpcode::kTranspose: {
      attributes.push_back(builder_->getNamedAttr(
          "permutation", ConvertDimensions(instruction->dimensions())));
      return func_builder
          ->create<mlir::mhlo::TransposeOp>(loc, result_type, operands,
                                            attributes)
          .getOperation();
    }
    case HloOpcode::kTriangularSolve: {
      attributes.push_back(builder_->getNamedAttr(
          "left_side",
          builder_->getBoolAttr(
              instruction->triangular_solve_options().left_side())));
      attributes.push_back(builder_->getNamedAttr(
          "lower", builder_->getBoolAttr(
                       instruction->triangular_solve_options().lower())));
      attributes.push_back(builder_->getNamedAttr(
          "unit_diagonal",
          builder_->getBoolAttr(
              instruction->triangular_solve_options().unit_diagonal())));
      auto transpose_a =
          builder_->getStringAttr(TriangularSolveOptions::Transpose_Name(
              instruction->triangular_solve_options().transpose_a()));
      attributes.push_back(builder_->getNamedAttr("transpose_a", transpose_a));
      return func_builder
          ->create<mlir::mhlo::TriangularSolveOp>(loc, result_type, operands,
                                                  attributes)
          .getOperation();
    }
    case HloOpcode::kReduceScatter: {
      auto reduce_scatter = Cast<HloReduceScatterInstruction>(instruction);
      attributes.push_back(builder_->getNamedAttr(
          "scatter_dimension",
          builder_->getI64IntegerAttr(reduce_scatter->scatter_dimension())));
      attributes.push_back(
          ConvertReplicaGroups(reduce_scatter->replica_groups(), builder_));
      if (reduce_scatter->channel_id().has_value())
        attributes.push_back(
            ConvertChannelHandle(reduce_scatter->channel_id().value()));
      auto reduce_scatter_op =
          func_builder->create<mlir::mhlo::ReduceScatterOp>(
              loc, result_type, operands, attributes);
      TF_RETURN_IF_ERROR(ImportAsRegion(*reduce_scatter->to_apply(),
                                        &reduce_scatter_op.computation()));
      return reduce_scatter_op.getOperation();
    }
    case HloOpcode::kReduceWindow: {
      llvm::SmallVector<Type, 4> return_types = {result_type};
      if (mlir::TupleType tuple_ty = result_type.dyn_cast<mlir::TupleType>()) {
        return_types = llvm::to_vector<6>(tuple_ty.getTypes());
      }
      llvm::SmallVector<int64_t, 4> sizes, strides, base_dilations,
          win_dilations;
      // Padding is stored as consecutive (low, high) pairs, one per window
      // dimension.
      llvm::SmallVector<int64_t, 8> padding;
      for (const auto& dim : instruction->window().dimensions()) {
        sizes.push_back(dim.size());
        strides.push_back(dim.stride());
        base_dilations.push_back(dim.base_dilation());
        win_dilations.push_back(dim.window_dilation());
        padding.push_back(dim.padding_low());
        padding.push_back(dim.padding_high());
      }
      attributes.push_back(builder_->getNamedAttr("window_dimensions",
                                                  ConvertDimensions(sizes)));
      attributes.push_back(
          builder_->getNamedAttr("window_strides", ConvertDimensions(strides)));
      attributes.push_back(builder_->getNamedAttr(
          "base_dilations", ConvertDimensions(base_dilations)));
      attributes.push_back(builder_->getNamedAttr(
          "window_dilations", ConvertDimensions(win_dilations)));
      attributes.push_back(ConvertPadding(padding));
      auto reduce = func_builder->create<mlir::mhlo::ReduceWindowOp>(
          loc, return_types, operands, attributes);
      TF_RETURN_IF_ERROR(
          ImportAsRegion(*instruction->to_apply(), &reduce.body()));
      // Check if the output needs to be tupled.
      if (return_types.size() == 1 && return_types.front() == result_type) {
        return reduce.getOperation();
      }
      return func_builder
          ->create<mlir::mhlo::TupleOp>(loc, result_type, reduce.getResults())
          .getOperation();
    }
    case HloOpcode::kMap: {
      auto op = func_builder->create<mlir::mhlo::MapOp>(
          loc, result_type, operands,
          ConvertDimensions(instruction->dimensions()));
      TF_RETURN_IF_ERROR(
          ImportAsRegion(*instruction->to_apply(), &op.computation()));
      return op.getOperation();
    }
    case HloOpcode::kConvolution: {
      llvm::SmallVector<int64_t, 4> strides, lhs_dilations, rhs_dilations;
      // Padding is stored as consecutive (low, high) pairs, one per window
      // dimension.
      llvm::SmallVector<int64_t, 8> paddings;
      for (const auto& dim : instruction->window().dimensions()) {
        strides.push_back(dim.stride());
        lhs_dilations.push_back(dim.base_dilation());
        rhs_dilations.push_back(dim.window_dilation());
        paddings.push_back(dim.padding_low());
        paddings.push_back(dim.padding_high());
      }
      attributes.push_back(
          builder_->getNamedAttr("window_strides", Convert(strides)));
      attributes.push_back(ConvertPadding(paddings));
      attributes.push_back(
          builder_->getNamedAttr("lhs_dilation", Convert(lhs_dilations)));
      attributes.push_back(
          builder_->getNamedAttr("rhs_dilation", Convert(rhs_dilations)));
      attributes.push_back(builder_->getNamedAttr(
          "dimension_numbers",
          ConvertConvDimensionNumbers(
              instruction->convolution_dimension_numbers(), builder_)));
      attributes.push_back(builder_->getNamedAttr(
          "feature_group_count",
          builder_->getI64IntegerAttr(instruction->feature_group_count())));
      attributes.push_back(builder_->getNamedAttr(
          "batch_group_count",
          builder_->getI64IntegerAttr(instruction->batch_group_count())));
      attributes.push_back(builder_->getNamedAttr(
          "precision_config",
          ConvertPrecisionConfig(&instruction->precision_config(), builder_)));
      return func_builder
          ->create<mlir::mhlo::ConvOp>(loc, result_type, operands, attributes)
          .getOperation();
    }
    case HloOpcode::kFft: {
      auto fft_type =
          builder_->getStringAttr(FftType_Name(instruction->fft_type()));
      std::vector<int64_t> fft_length(instruction->fft_length().begin(),
                                      instruction->fft_length().end());
      attributes.push_back(builder_->getNamedAttr("fft_type", fft_type));
      attributes.push_back(
          builder_->getNamedAttr("fft_length", Convert(fft_length)));
      return func_builder
          ->create<mlir::mhlo::FftOp>(loc, result_type, operands, attributes)
          .getOperation();
    }
    case HloOpcode::kAdd: {
      // HLO add ops on PRED elements are actually boolean or, but MHLO dialect
      // AddOps on i1 are just addition with overflow; so, we have to implement
      // the special behavior of HLO add ops on PRED here by creating an
      // mhlo::OrOp instead.
      if (instruction->shape().element_type() == PRED) {
        return func_builder
            ->create<mlir::mhlo::OrOp>(loc, result_type, operands, attributes)
            .getOperation();
      } else {
        return func_builder
            ->create<mlir::mhlo::AddOp>(loc, result_type, operands, attributes)
            .getOperation();
      }
    }
    case HloOpcode::kAfterAll: {
      // HLO AfterAll ops without any token input are used to just create a
      // token. MHLO has a special op CreateToken for this case.
      if (instruction->operands().empty()) {
        return func_builder
            ->create<mlir::mhlo::CreateTokenOp>(loc, result_type, operands,
                                                attributes)
            .getOperation();
      } else {
        return func_builder
            ->create<mlir::mhlo::AfterAllOp>(loc, result_type, operands,
                                             attributes)
            .getOperation();
      }
    }
    case HloOpcode::kConvert: {
      // Convert to boolean is special, it requires a comparison to 0 instead of
      // a truncation to i1, otherwise it is a 1-1 translation.
      auto ranked_type = result_type.dyn_cast<mlir::RankedTensorType>();
      mlir::IntegerType integer_type =
          (ranked_type)
              ? ranked_type.getElementType().dyn_cast<mlir::IntegerType>()
              : nullptr;
      if (!integer_type || integer_type.getWidth() != 1) {
        // Simple case: 1-1 mapping.
        return {func_builder->create<mlir::mhlo::ConvertOp>(
            loc, result_type, operands, attributes)};
      }
      // Return type is boolean, let's use `operand != 0` instead of Convert.
      xla::Shape input_shape = instruction->operand(0)->shape();
      TF_ASSIGN_OR_RETURN(mlir::Type type,
                          ConvertTensorShapeToType<mlir::RankedTensorType>(
                              input_shape, *func_builder));
      auto zero = func_builder->create<mlir::mhlo::ConstOp>(
          loc, func_builder->getZeroAttr(type));
      return {func_builder->create<mlir::mhlo::CompareOp>(
          loc, operands[0], zero, func_builder->getStringAttr("NE"))};
    }
    case HloOpcode::kOptimizationBarrier: {
      // The op is created on the flattened form of operands[0]; its results
      // are re-tupled to the type of operands[0].
      llvm::SmallVector<Value> flattened_operands;
      llvm::SmallVector<Type> flattened_operand_types;
      FlattenTupleType(operands[0].getType(), flattened_operand_types);
      FlattenTupleValue(func_builder, loc, operands[0], flattened_operands);
      auto op = func_builder->create<mlir::mhlo::OptimizationBarrierOp>(
          loc, flattened_operand_types, flattened_operands);
      return CreateTupleFromOpResults(func_builder, loc, op.getOperation(),
                                      operands[0].getType());
    }
// Expands to a case that maps `hlo_op_code` directly onto `mlir_op`, passing
// through the result type, operands and the collected attributes.
#define NO_ATTRIBUTE_CASE(hlo_op_code, mlir_op)                               \
  case HloOpcode::hlo_op_code: {                                              \
    return func_builder                                                       \
        ->create<mlir::mhlo::mlir_op>(loc, result_type, operands, attributes) \
        .getOperation();                                                      \
  }
      // broadcast dimensions are never added here because they don't exist as
      // part of the HLO instruction. They are only a convenience in the XLA
      // builder API.
      NO_ATTRIBUTE_CASE(kAbs, AbsOp);
      NO_ATTRIBUTE_CASE(kAnd, AndOp);
      NO_ATTRIBUTE_CASE(kAtan2, Atan2Op);
      NO_ATTRIBUTE_CASE(kBitcastConvert, BitcastConvertOp);
      NO_ATTRIBUTE_CASE(kCbrt, CbrtOp);
      NO_ATTRIBUTE_CASE(kClz, ClzOp);
      NO_ATTRIBUTE_CASE(kCeil, CeilOp);
      NO_ATTRIBUTE_CASE(kClamp, ClampOp);
      NO_ATTRIBUTE_CASE(kComplex, ComplexOp);
      NO_ATTRIBUTE_CASE(kCos, CosOp);
      NO_ATTRIBUTE_CASE(kDivide, DivOp);
      NO_ATTRIBUTE_CASE(kExp, ExpOp);
      NO_ATTRIBUTE_CASE(kExpm1, Expm1Op);
      NO_ATTRIBUTE_CASE(kFloor, FloorOp);
      NO_ATTRIBUTE_CASE(kIsFinite, IsFiniteOp);
      NO_ATTRIBUTE_CASE(kImag, ImagOp);
      NO_ATTRIBUTE_CASE(kLog, LogOp);
      NO_ATTRIBUTE_CASE(kLog1p, Log1pOp);
      NO_ATTRIBUTE_CASE(kMaximum, MaxOp);
      NO_ATTRIBUTE_CASE(kMinimum, MinOp);
      NO_ATTRIBUTE_CASE(kMultiply, MulOp);
      NO_ATTRIBUTE_CASE(kNegate, NegOp);
      NO_ATTRIBUTE_CASE(kNot, NotOp);
      NO_ATTRIBUTE_CASE(kOr, OrOp);
      NO_ATTRIBUTE_CASE(kPopulationCount, PopulationCountOp);
      NO_ATTRIBUTE_CASE(kPower, PowOp);
      NO_ATTRIBUTE_CASE(kReal, RealOp);
      NO_ATTRIBUTE_CASE(kRemainder, RemOp);
      NO_ATTRIBUTE_CASE(kReplicaId, ReplicaIdOp);
      NO_ATTRIBUTE_CASE(kLogistic, LogisticOp);
      // The dimensions attribute is not present on the HLO Reshape
      // instruction. If dimensions are non-default, the XLA builder
      // implements it as a separate transpose.
      NO_ATTRIBUTE_CASE(kReshape, ReshapeOp);
      NO_ATTRIBUTE_CASE(kRoundNearestAfz, RoundOp);
      NO_ATTRIBUTE_CASE(kRsqrt, RsqrtOp);
      NO_ATTRIBUTE_CASE(kSelect, SelectOp);
      NO_ATTRIBUTE_CASE(kShiftLeft, ShiftLeftOp);
      NO_ATTRIBUTE_CASE(kShiftRightArithmetic, ShiftRightArithmeticOp);
      NO_ATTRIBUTE_CASE(kShiftRightLogical, ShiftRightLogicalOp);
      NO_ATTRIBUTE_CASE(kSign, SignOp);
      NO_ATTRIBUTE_CASE(kSin, SinOp);
      NO_ATTRIBUTE_CASE(kSqrt, SqrtOp);
      NO_ATTRIBUTE_CASE(kSubtract, SubOp);
      NO_ATTRIBUTE_CASE(kTanh, TanhOp);
      NO_ATTRIBUTE_CASE(kTuple, TupleOp);
      NO_ATTRIBUTE_CASE(kXor, XorOp);
      // TODO(b/129422361) Copy needs special handling because it is not
      // defined in tensorflow/compiler/xla/client/xla_builder.h. See
      // operation semantics in
      // g3doc/platforms/xla/g3doc/internal/hlo_semantics#copy
      NO_ATTRIBUTE_CASE(kCopy, CopyOp);
#undef NO_ATTRIBUTE_CASE
    case HloOpcode::kFusion: {
      // Flatten the tuple-typed operands.
      llvm::SmallVector<Value> flattened_operands;
      for (auto& operand : operands)
        FlattenTupleValue(func_builder, loc, operand, flattened_operands);
      // Flatten the return type if they are tuple-typed.
      llvm::SmallVector<Type> flattened_ret_types;
      FlattenTupleType(result_type, flattened_ret_types);
      auto fusion = func_builder->create<mlir::mhlo::FusionOp>(
          loc, flattened_ret_types, flattened_operands,
          builder_->getStringAttr(xla::ToString(instruction->fusion_kind())));
      TF_RETURN_IF_ERROR(ImportAsRegion(
          *instruction->fused_instructions_computation(),
          &fusion.fused_computation(), /*flatten_region_arg_tuple=*/true));
      return CreateTupleFromOpResults(func_builder, loc, fusion.getOperation(),
                                      result_type);
    }
    case HloOpcode::kBitcast: {
      auto bitcast = func_builder->create<mlir::mhlo::BitcastOp>(
          loc, result_type, operands, attributes);
      // Store the source and result layout as attributes. Although the MHLO
      // Bitcast operates on tensors, these layouts are relevant as they define
      // the mapping between the elements of the source and result.
      SetLayoutForMlir(bitcast, instruction->shape(), "result_layout");
      SetLayoutForMlir(bitcast, instruction->operand(0)->shape(),
                       "source_layout");
      return bitcast.getOperation();
    }
    case HloOpcode::kReducePrecision: {
      auto op = func_builder->create<mlir::mhlo::ReducePrecisionOp>(
          loc, result_type, operands[0], attributes);
      op.exponent_bitsAttr(func_builder->getIntegerAttr(
          func_builder->getI32Type(), instruction->exponent_bits()));
      op.mantissa_bitsAttr(func_builder->getIntegerAttr(
          func_builder->getI32Type(), instruction->mantissa_bits()));
      return op.getOperation();
    }
    case HloOpcode::kAddDependency:
      // Arbitrary op code that I suspect we will not implement for quite a
      // while and allows testing handling of unknown ops. Selected because it
      // is not mentioned in xla client anywhere or in the hlo of our sample
      // models.
    default: {
      // Opcodes without a dedicated case become a generic "mhlo.unknown"
      // operation carrying the operands, result type and attributes.
      mlir::OperationState result(loc, "mhlo.unknown");
      result.addOperands(operands);
      result.addTypes(result_type);
      for (auto attr : attributes) {
        result.attributes.push_back(attr);
      }
      return func_builder->createOperation(result);
    }
  }
}