StatusOr HloFunctionImporter::ImportInstructionImpl()

in tensorflow/compiler/mlir/xla/hlo_function_importer.cc [450:1426]


StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
    const HloInstruction* instruction,
    const llvm::SmallVectorImpl<mlir::Value>& operands,
    mlir::OpBuilder* func_builder, DynamicShapeHandlingMode mode) {
  const Shape& instruction_shape = instruction->shape();
  const Shape& shape = mode == DynamicShapeHandlingMode::kConvertToStatic
                           ? xla::ShapeUtil::MakeStaticShape(instruction_shape)
                           : instruction_shape;
  TF_ASSIGN_OR_RETURN(auto result_type,
                      ConvertShapeToType<RankedTensorType>(shape, *builder_));
  mlir::Location loc = GenerateInstructionLocation(instruction, func_builder);

  llvm::SmallVector<NamedAttribute, 10> attributes;
  if (instruction->has_sharding()) {
    attributes.push_back(builder_->getNamedAttr(
        kShardingAttr,
        builder_->getStringAttr(
            instruction->sharding().ToProto().SerializeAsString())));
  }

  switch (instruction->opcode()) {
    case HloOpcode::kParameter: {
      return nullptr;
    }
    case HloOpcode::kConstant: {
      const Literal& literal = instruction->literal();
      auto attr = CreateDenseElementsAttrFromLiteral(literal, *builder_);
      if (!attr.ok()) return attr.status();
      mlir::Operation* new_operation =
          func_builder->create<mlir::mhlo::ConstOp>(loc, attr.ValueOrDie());
      for (auto attr : attributes) {
        new_operation->setAttr(attr.getName(), attr.getValue());
      }
      return new_operation;
    }
    case HloOpcode::kIota: {
      return func_builder
          ->create<mlir::mhlo::IotaOp>(
              loc, result_type,
              func_builder->getI64IntegerAttr(
                  Cast<HloIotaInstruction>(instruction)->iota_dimension()))
          .getOperation();
    }
    case HloOpcode::kBroadcast: {
      // Note that the HLO broadcast is more powerful than the XLA broadcast
      // op. BroadcastInDim offers a superset of the HLO op's functionality.
      attributes.push_back(
          builder_->getNamedAttr("broadcast_dimensions",
                                 ConvertDimensions(instruction->dimensions())));
      return func_builder
          ->create<mlir::mhlo::BroadcastInDimOp>(loc, result_type, operands,
                                                 attributes)
          .getOperation();
    }

    case HloOpcode::kBatchNormGrad:
    case HloOpcode::kBatchNormInference:
    case HloOpcode::kBatchNormTraining:
      attributes.push_back(builder_->getNamedAttr(
          "epsilon", builder_->getF32FloatAttr(instruction->epsilon())));
      attributes.push_back(builder_->getNamedAttr(
          "feature_index",
          builder_->getI64IntegerAttr(instruction->feature_index())));
      if (instruction->opcode() == HloOpcode::kBatchNormGrad) {
        // Flatten the return type if they are tuple-typed.
        llvm::SmallVector<Type> flattened_ret_types;
        FlattenTupleType(result_type, flattened_ret_types);

        auto op = func_builder
                      ->create<mlir::mhlo::BatchNormGradOp>(
                          loc, flattened_ret_types, operands, attributes)
                      .getOperation();

        return CreateTupleFromOpResults(func_builder, loc, op, result_type);
      } else if (instruction->opcode() == HloOpcode::kBatchNormInference) {
        return func_builder
            ->create<mlir::mhlo::BatchNormInferenceOp>(loc, result_type,
                                                       operands, attributes)
            .getOperation();
      } else {
        assert(instruction->opcode() == HloOpcode::kBatchNormTraining);

        // Flatten the return type if they are tuple-typed.
        llvm::SmallVector<Type> flattened_ret_types;
        FlattenTupleType(result_type, flattened_ret_types);

        auto op = func_builder
                      ->create<mlir::mhlo::BatchNormTrainingOp>(
                          loc, flattened_ret_types, operands, attributes)
                      .getOperation();

        return CreateTupleFromOpResults(func_builder, loc, op, result_type);
      }

    case HloOpcode::kDot: {
      attributes.push_back(builder_->getNamedAttr(
          "precision_config",
          ConvertPrecisionConfig(&instruction->precision_config(), builder_)));

      // Consider consolidating DotOps together.
      if (DotIsDefault(instruction)) {
        return func_builder
            ->create<mlir::mhlo::DotOp>(loc, result_type, operands, attributes)
            .getOperation();
      }

      attributes.push_back(builder_->getNamedAttr(
          "dot_dimension_numbers",
          ConvertDotDimensionNumbers(instruction->dot_dimension_numbers(),
                                     builder_)));
      return func_builder
          ->create<mlir::mhlo::DotGeneralOp>(loc, result_type, operands,
                                             attributes)
          .getOperation();
    }
    case HloOpcode::kCall: {
      TF_ASSIGN_OR_RETURN(FuncOp function,
                          ImportAsFunc(*instruction->to_apply()));
      mlir::Operation* new_operation =
          func_builder->create<mlir::CallOp>(loc, function, operands);
      return new_operation;
    }
    case HloOpcode::kCollectivePermute: {
      attributes.push_back(ConvertSourceTargetPairs(
          instruction->source_target_pairs(), builder_));
      return func_builder
          ->create<mlir::mhlo::CollectivePermuteOp>(loc, result_type, operands,
                                                    attributes)
          .getOperation();
    }
    case HloOpcode::kCustomCall: {
      auto custom_call = Cast<HloCustomCallInstruction>(instruction);
      const auto& called_computations = custom_call->called_computations();
      if (!called_computations.empty()) {
        llvm::SmallVector<mlir::Attribute> callees;
        callees.reserve(called_computations.size());
        for (HloComputation* callee : called_computations) {
          TF_ASSIGN_OR_RETURN(FuncOp function, ImportAsFunc(*callee));
          callees.push_back(mlir::FlatSymbolRefAttr::get(builder_->getContext(),
                                                         function.getName()));
        }
        attributes.push_back(builder_->getNamedAttr(
            "called_computations",
            mlir::ArrayAttr::get(builder_->getContext(), callees)));
      }
      if (custom_call->layout_constrained()) {
        TF_ASSIGN_OR_RETURN(
            mlir::ArrayAttr operand_layouts,
            ExtractLayoutsFromShapes(custom_call->operand_shapes_with_layout(),
                                     builder_));
        attributes.push_back(
            builder_->getNamedAttr("operand_layouts", operand_layouts));
        mlir::ArrayAttr result_layouts;
        if (custom_call->shape().IsTuple()) {
          TF_ASSIGN_OR_RETURN(
              result_layouts,
              ExtractLayoutsFromTuple(custom_call->shape(), builder_));
        } else {
          TF_ASSIGN_OR_RETURN(
              result_layouts,
              ExtractLayoutsFromShapes({custom_call->shape()}, builder_));
        }
        attributes.push_back(
            builder_->getNamedAttr("result_layouts", result_layouts));
      }

      TF_ASSIGN_OR_RETURN(
          auto mlir_api_version,
          ConvertCustomCallApiVersion(custom_call->api_version()));
      attributes.push_back(builder_->getNamedAttr(
          "call_target_name",
          builder_->getStringAttr(custom_call->custom_call_target())));
      attributes.push_back(builder_->getNamedAttr(
          "has_side_effect",
          builder_->getBoolAttr(custom_call->custom_call_has_side_effect())));
      attributes.push_back(builder_->getNamedAttr(
          "backend_config",
          builder_->getStringAttr(custom_call->raw_backend_config_string())));
      attributes.push_back(builder_->getNamedAttr(
          "api_version", mlir::mhlo::CustomCallApiVersionAttr::get(
                             builder_->getContext(), mlir_api_version)));
      return func_builder
          ->create<mlir::mhlo::CustomCallOp>(loc, result_type, operands,
                                             attributes)
          .getOperation();
    }
    case HloOpcode::kCompare: {
      auto compare = Cast<HloCompareInstruction>(instruction);
      attributes.push_back(ConvertComparisonDirection(compare->direction()));
      auto default_type = Comparison::DefaultComparisonType(
          compare->operand(0)->shape().element_type());
      if (compare->type() != default_type)
        attributes.push_back(ConvertComparisonType(compare->type()));
      return func_builder
          ->create<mlir::mhlo::CompareOp>(loc, result_type, operands,
                                          attributes)
          .getOperation();
    }
    case HloOpcode::kCholesky: {
      attributes.push_back(builder_->getNamedAttr(
          "lower",
          builder_->getBoolAttr(instruction->cholesky_options().lower())));
      return func_builder
          ->create<mlir::mhlo::CholeskyOp>(loc, result_type, operands,
                                           attributes)
          .getOperation();
    }
    case HloOpcode::kGather: {
      auto gather_instruction = Cast<HloGatherInstruction>(instruction);
      attributes.push_back(builder_->getNamedAttr(
          "dimension_numbers",
          ConvertGatherDimensionNumbers(
              gather_instruction->gather_dimension_numbers(), builder_)));

      std::vector<int64_t> slice_sizes(
          gather_instruction->gather_slice_sizes().begin(),
          gather_instruction->gather_slice_sizes().end());
      attributes.push_back(
          builder_->getNamedAttr("slice_sizes", Convert(slice_sizes)));
      attributes.push_back(builder_->getNamedAttr(
          "indices_are_sorted",
          builder_->getBoolAttr(gather_instruction->indices_are_sorted())));

      return func_builder
          ->create<mlir::mhlo::GatherOp>(loc, result_type, operands, attributes)
          .getOperation();
    }
    case HloOpcode::kDynamicSlice: {
      std::vector<int64_t> slice_sizes(
          instruction->dynamic_slice_sizes().begin(),
          instruction->dynamic_slice_sizes().end());
      return func_builder
          ->create<mlir::mhlo::DynamicSliceOp>(
              loc, result_type, operands[0],
              makeArrayRef(operands).drop_front(), Convert(slice_sizes))
          .getOperation();
    }
    case HloOpcode::kDynamicUpdateSlice: {
      return func_builder
          ->create<mlir::mhlo::DynamicUpdateSliceOp>(
              loc, result_type, operands[0], operands[1],
              llvm::ArrayRef<Value>(operands.begin() + 2, operands.end()))
          .getOperation();
    }
    case HloOpcode::kInfeed: {
      if (IsNestedTupleInData(result_type)) {
        llvm_unreachable(
            "Importing xla::kInfeed with nested tuple shape not supported");
      }

      attributes.push_back(builder_->getNamedAttr(
          "infeed_config",
          mlir::StringAttr::get(builder_->getContext(),
                                instruction->infeed_config())));

      llvm::SmallVector<mlir::Attribute> flattened_attr;
      TF_RETURN_IF_ERROR(
          ConvertShapeToMlirLayout(instruction->shape(), flattened_attr));
      attributes.push_back(builder_->getNamedAttr(
          "layout", builder_->getArrayAttr(makeArrayRef(flattened_attr))));

      // Flatten the return-type if they are tuple-typed.
      llvm::SmallVector<Type> flattened_ret_types;
      FlattenTupleType(result_type, flattened_ret_types);

      auto op = func_builder->create<mlir::mhlo::InfeedOp>(
          loc, flattened_ret_types, operands, attributes);

      return CreateTupleFromOpResults(func_builder, loc, op.getOperation(),
                                      result_type);
    }
    case HloOpcode::kOutfeed: {
      attributes.push_back(builder_->getNamedAttr(
          "outfeed_config",
          mlir::StringAttr::get(builder_->getContext(),
                                instruction->outfeed_config())));

      assert(operands.size() == 2 && "Expected 2 operands for HLO Infeed");

      // In case operands[0] is a tuple, flatten it.
      llvm::SmallVector<Value> flattened_operands;
      FlattenTupleValue(func_builder, loc, operands[0], flattened_operands);
      flattened_operands.push_back(operands[1]);

      return func_builder
          ->create<mlir::mhlo::OutfeedOp>(loc, result_type, flattened_operands,
                                          attributes)
          .getOperation();
    }
    case HloOpcode::kPad: {
      const auto& padding_config = instruction->padding_config();
      llvm::SmallVector<int64_t, 4> edge_padding_low;
      llvm::SmallVector<int64_t, 4> edge_padding_high;
      llvm::SmallVector<int64_t, 4> interior_padding;
      edge_padding_low.reserve(padding_config.dimensions_size());
      edge_padding_high.reserve(padding_config.dimensions_size());
      interior_padding.reserve(padding_config.dimensions_size());

      for (const auto& dimension : padding_config.dimensions()) {
        edge_padding_low.push_back(dimension.edge_padding_low());
        edge_padding_high.push_back(dimension.edge_padding_high());
        interior_padding.push_back(dimension.interior_padding());
      }

      return func_builder
          ->create<mlir::mhlo::PadOp>(loc, result_type, operands[0],
                                      operands[1], Convert(edge_padding_low),
                                      Convert(edge_padding_high),
                                      Convert(interior_padding))
          .getOperation();
    }
    case HloOpcode::kScatter: {
      auto scatter = Cast<HloScatterInstruction>(instruction);
      attributes.push_back(builder_->getNamedAttr(
          "scatter_dimension_numbers",
          ConvertScatterDimensionNumbers(scatter->scatter_dimension_numbers(),
                                         builder_)));
      attributes.push_back(builder_->getNamedAttr(
          "indices_are_sorted",
          builder_->getBoolAttr(scatter->indices_are_sorted())));
      attributes.push_back(builder_->getNamedAttr(
          "unique_indices", builder_->getBoolAttr(scatter->unique_indices())));

      auto scatter_op = func_builder->create<mlir::mhlo::ScatterOp>(
          loc, result_type, operands, attributes);
      TF_RETURN_IF_ERROR(ImportAsRegion(*scatter->to_apply(),
                                        &scatter_op.update_computation()));
      return scatter_op.getOperation();
    }
    case HloOpcode::kSelectAndScatter: {
      auto select_scatter = Cast<HloSelectAndScatterInstruction>(instruction);
      llvm::SmallVector<int64_t, 4> window_strides, window_dimensions;
      llvm::SmallVector<int64_t, 8> padding;
      for (const auto& dim : select_scatter->window().dimensions()) {
        window_strides.push_back(dim.stride());
        window_dimensions.push_back(dim.size());
        padding.push_back(dim.padding_low());
        padding.push_back(dim.padding_high());
      }
      attributes.push_back(
          builder_->getNamedAttr("window_strides", Convert(window_strides)));
      attributes.push_back(builder_->getNamedAttr("window_dimensions",
                                                  Convert(window_dimensions)));
      attributes.push_back(ConvertPadding(padding));
      auto select_scatter_op =
          func_builder->create<mlir::mhlo::SelectAndScatterOp>(
              loc, result_type, operands, attributes);
      TF_RETURN_IF_ERROR(ImportAsRegion(*select_scatter->select(),
                                        &select_scatter_op.select()));
      TF_RETURN_IF_ERROR(ImportAsRegion(*select_scatter->scatter(),
                                        &select_scatter_op.scatter()));
      return select_scatter_op.getOperation();
    }
    case HloOpcode::kSetDimensionSize: {
      attributes.push_back(builder_->getNamedAttr(
          "dimension", builder_->getI64IntegerAttr(instruction->dimension())));
      return func_builder
          ->create<mlir::mhlo::SetDimensionSizeOp>(loc, result_type, operands,
                                                   attributes)
          .getOperation();
    }
    case HloOpcode::kSlice: {
      return func_builder
          ->create<mlir::mhlo::SliceOp>(
              loc, result_type, operands[0],
              ConvertDimensions(instruction->slice_starts()),
              ConvertDimensions(instruction->slice_limits()),
              ConvertDimensions(instruction->slice_strides()))
          .getOperation();
    }
    case HloOpcode::kSort: {
      auto sort_instruction = Cast<HloSortInstruction>(instruction);

      llvm::SmallVector<Type, 4> return_types = {result_type};
      if (mlir::TupleType tuple_ty = result_type.dyn_cast<mlir::TupleType>()) {
        return_types = llvm::to_vector<6>(tuple_ty.getTypes());
      }

      auto sort_op = func_builder->create<mlir::mhlo::SortOp>(
          loc, return_types, operands,
          builder_->getI64IntegerAttr(sort_instruction->sort_dimension()),
          builder_->getBoolAttr(sort_instruction->is_stable()));
      TF_RETURN_IF_ERROR(
          ImportAsRegion(*sort_instruction->to_apply(), &sort_op.comparator()));

      // Check if the output needs to be tupled.
      if (return_types.size() == 1 && return_types.front() == result_type) {
        return sort_op.getOperation();
      }

      return func_builder
          ->create<mlir::mhlo::TupleOp>(loc, result_type, sort_op.getResults())
          .getOperation();
    }
    case HloOpcode::kConditional: {
      llvm::SmallVector<Type, 4> rets;

      // Flatten the tuple-typed operands.
      llvm::SmallVector<Value> flattened_operands;
      for (auto& operand : operands)
        FlattenTupleValue(func_builder, loc, operand, flattened_operands);

      // If/Case Op has a single operand; we collect the other operands to
      // replace the corresponding block arguments.
      llvm::ArrayRef<Value> implicit_operands(flattened_operands.begin() + 1,
                                              flattened_operands.end());

      mlir::Type pred_or_index_type =
          operands[0].getType().cast<mlir::TensorType>().getElementType();
      // It is a predicated conditional if first argument is a boolean and
      // should be mapped to If op.
      if (pred_or_index_type.isInteger(1)) {
        TF_RETURN_IF_ERROR(GetMlirTypes(
            {instruction->true_computation()->root_instruction()}, &rets));

        // Flatten the return-type.
        llvm::SmallVector<Type> flattened_ret_types;
        assert(rets.size() == 1);
        FlattenTupleType(rets[0], flattened_ret_types);

        auto op = func_builder->create<mlir::mhlo::IfOp>(
            loc, flattened_ret_types, flattened_operands[0], attributes);
        TF_RETURN_IF_ERROR(ImportAsRegion(*instruction->true_computation(),
                                          &op.true_branch(),
                                          /*flatten_region_arg_tuple=*/true));
        TF_RETURN_IF_ERROR(ImportAsRegion(*instruction->false_computation(),
                                          &op.false_branch(),
                                          /*flatten_region_arg_tuple=*/true));

        // Replace the uses of block-arguments of the IfOp with the
        // implicit_operands.
        ReplaceBlockArgumentsWithImplicitOperands(op.getOperation(),
                                                  implicit_operands);

        return CreateTupleFromOpResults(func_builder, loc, op.getOperation(),
                                        rets[0]);
      }

      // Otherwise, it is an indexed conditional and should be mapped to Case
      // op.
      TF_RETURN_IF_ERROR(GetMlirTypes(
          {instruction->branch_computation(0)->root_instruction()}, &rets));

      // Flatten the return-type.
      llvm::SmallVector<Type> flattened_ret_types;
      assert(rets.size() == 1);
      FlattenTupleType(rets[0], flattened_ret_types);

      int num_branches = instruction->branch_count();
      auto op = func_builder->create<mlir::mhlo::CaseOp>(
          loc, flattened_ret_types, flattened_operands[0], attributes,
          num_branches);
      for (const auto& index_and_computation :
           llvm::enumerate(instruction->branch_computations())) {
        auto index = index_and_computation.index();
        HloComputation* computation = index_and_computation.value();
        TF_RETURN_IF_ERROR(ImportAsRegion(*computation, &op.branches()[index],
                                          /*flatten_region_arg_tuple=*/true));
      }

      // Replace the uses of block-arguments of the CaseOp with the
      // implicit_operands.
      ReplaceBlockArgumentsWithImplicitOperands(op.getOperation(),
                                                implicit_operands);

      return CreateTupleFromOpResults(func_builder, loc, op.getOperation(),
                                      rets[0]);
    }
    case HloOpcode::kConcatenate: {
      // TODO(b/132057942): Support taking an uint64_t instead of an
      // IntegerAttr for concatenate dimension.
      return func_builder
          ->create<mlir::mhlo::ConcatenateOp>(
              loc, result_type, operands,
              builder_->getI64IntegerAttr(instruction->concatenate_dimension()))
          .getOperation();
    }
    case HloOpcode::kAllGather: {
      auto all_gather = Cast<HloAllGatherInstruction>(instruction);
      attributes.push_back(builder_->getNamedAttr(
          "all_gather_dim",
          builder_->getI64IntegerAttr(all_gather->all_gather_dimension())));
      attributes.push_back(
          ConvertReplicaGroups(all_gather->replica_groups(), builder_));
      if (all_gather->channel_id().has_value())
        attributes.push_back(
            ConvertChannelHandle(all_gather->channel_id().value()));
      return func_builder
          ->create<mlir::mhlo::AllGatherOp>(loc, result_type, operands,
                                            attributes)
          .getOperation();
    }
    case HloOpcode::kAllReduce: {
      auto all_reduce = Cast<HloAllReduceInstruction>(instruction);
      attributes.push_back(
          ConvertReplicaGroups(all_reduce->replica_groups(), builder_));
      if (all_reduce->channel_id().has_value())
        attributes.push_back(
            ConvertChannelHandle(all_reduce->channel_id().value()));
      auto all_reduce_op = func_builder->create<mlir::mhlo::AllReduceOp>(
          loc, result_type, operands, attributes);
      TF_RETURN_IF_ERROR(ImportAsRegion(*all_reduce->to_apply(),
                                        &all_reduce_op.computation()));
      return all_reduce_op.getOperation();
    }
    case HloOpcode::kAllToAll: {
      // TODO(b/207152612): all-to-all HLO can either have pre-split operands
      // (and returns a tuple) or a single operand that is split across
      // `split_dimension` into the number of replicas in a group. Only the
      // latter case (array all-to-all) is supported in importer right now and
      // the former (tuple all-to-all) is not supported yet.
      auto all_to_all = Cast<HloAllToAllInstruction>(instruction);
      if (all_to_all->shape().IsTuple())
        return tensorflow::errors::Unimplemented(
            "Importing tuple all-to-all HLO is not supported yet");

      // Check invariants of array all-to-all. This is a sanity check and is
      // verified by the HLO verifier.
      if (!all_to_all->split_dimension().has_value() || operands.size() != 1 ||
          all_to_all->replica_groups().empty())
        return tensorflow::errors::InvalidArgument(
            "Array all-to-all should have a split dimension, one operand and "
            "non-empty replica groups");

      auto replica_groups_attr =
          ConvertReplicaGroups(all_to_all->replica_groups(), builder_)
              .getValue()
              .cast<DenseIntElementsAttr>();
      uint64_t split_dim = all_to_all->split_dimension().value();
      uint64_t concat_dim = split_dim;
      uint64_t split_count = all_to_all->replica_groups()[0].replica_ids_size();

      return func_builder
          ->create<mlir::mhlo::AllToAllOp>(loc, result_type, operands[0],
                                           split_dim, concat_dim, split_count,
                                           replica_groups_attr)
          .getOperation();
    }
    case HloOpcode::kReduce: {
      // Operands in the first half are reduction inputs and the remaining
      // operands are corresponding initial values.
      size_t num_inputs = operands.size() / 2;
      llvm::SmallVector<Type, 4> return_types = {result_type};
      if (mlir::TupleType tuple_ty = result_type.dyn_cast<mlir::TupleType>()) {
        return_types = llvm::to_vector<6>(tuple_ty.getTypes());
      }

      auto reduce = func_builder->create<mlir::mhlo::ReduceOp>(
          loc, return_types,
          llvm::makeArrayRef(operands).take_front(num_inputs),
          llvm::makeArrayRef(operands).drop_front(num_inputs),
          ConvertDimensions(instruction->dimensions()));
      TF_RETURN_IF_ERROR(
          ImportAsRegion(*instruction->to_apply(), &reduce.body()));

      // Check if the output needs to be tupled.
      if (return_types.size() == 1 && return_types.front() == result_type) {
        return reduce.getOperation();
      }

      return func_builder
          ->create<mlir::mhlo::TupleOp>(loc, result_type, reduce.getResults())
          .getOperation();
    }
    case HloOpcode::kReverse: {
      return func_builder
          ->create<mlir::mhlo::ReverseOp>(
              loc, result_type, operands[0],
              ConvertDimensions(instruction->dimensions()))
          .getOperation();
    }
    case HloOpcode::kRng: {
      auto shape = func_builder->create<mlir::mhlo::ConstOp>(
          loc, Convert(result_type.cast<RankedTensorType>().getShape()));
      switch (instruction->random_distribution()) {
        case xla::RNG_UNIFORM:
          return func_builder
              ->create<mlir::mhlo::RngUniformOp>(loc, result_type, operands[0],
                                                 operands[1], shape)
              .getOperation();

        case xla::RNG_NORMAL:
          return func_builder
              ->create<mlir::mhlo::RngNormalOp>(loc, result_type, operands[0],
                                                operands[1], shape)
              .getOperation();

        default:
          return tensorflow::errors::InvalidArgument(absl::StrCat(
              "Unsupported distribution: ",
              RandomDistributionToString(instruction->random_distribution())));
      }
    }
    case HloOpcode::kRngBitGenerator: {
      auto rng_op = Cast<HloRngBitGeneratorInstruction>(instruction);

      // Flatten the return type if they are tuple-typed.
      llvm::SmallVector<Type> flattened_ret_types;
      FlattenTupleType(result_type, flattened_ret_types);

      auto op = func_builder->create<mlir::mhlo::RngBitGeneratorOp>(
          loc, flattened_ret_types,
          func_builder->getI32IntegerAttr(rng_op->algorithm()), operands[0]);

      return CreateTupleFromOpResults(func_builder, loc, op.getOperation(),
                                      result_type);
    }
    case HloOpcode::kWhile: {
      llvm::SmallVector<Value> flattened_operands;
      llvm::SmallVector<Type> flattened_operand_types;
      FlattenTupleType(operands[0].getType(), flattened_operand_types);
      FlattenTupleValue(func_builder, loc, operands[0], flattened_operands);

      auto op = func_builder->create<mlir::mhlo::WhileOp>(
          loc, flattened_operand_types, flattened_operands);

      TF_RETURN_IF_ERROR(ImportAsRegion(*instruction->while_condition(),
                                        &op.cond(),
                                        /*flatten_region_arg_tuple=*/true));
      TF_RETURN_IF_ERROR(ImportAsRegion(*instruction->while_body(), &op.body(),
                                        /*flatten_region_arg_tuple=*/true));
      return CreateTupleFromOpResults(func_builder, loc, op.getOperation(),
                                      operands[0].getType());
    }
    case HloOpcode::kGetTupleElement: {
      attributes.push_back(builder_->getNamedAttr(
          "index", builder_->getIntegerAttr(builder_->getIntegerType(32),
                                            instruction->tuple_index())));
      return func_builder
          ->create<mlir::mhlo::GetTupleElementOp>(loc, result_type, operands,
                                                  attributes)
          .getOperation();
    };
    case HloOpcode::kGetDimensionSize: {
      attributes.push_back(builder_->getNamedAttr(
          "dimension", builder_->getI64IntegerAttr(instruction->dimension())));
      return func_builder
          ->create<mlir::mhlo::GetDimensionSizeOp>(loc, result_type, operands,
                                                   attributes)
          .getOperation();
    };
    case HloOpcode::kTranspose: {
      attributes.push_back(builder_->getNamedAttr(
          "permutation", ConvertDimensions(instruction->dimensions())));
      return func_builder
          ->create<mlir::mhlo::TransposeOp>(loc, result_type, operands,
                                            attributes)
          .getOperation();
    }
    case HloOpcode::kTriangularSolve: {
      attributes.push_back(builder_->getNamedAttr(
          "left_side",
          builder_->getBoolAttr(
              instruction->triangular_solve_options().left_side())));
      attributes.push_back(builder_->getNamedAttr(
          "lower", builder_->getBoolAttr(
                       instruction->triangular_solve_options().lower())));
      attributes.push_back(builder_->getNamedAttr(
          "unit_diagonal",
          builder_->getBoolAttr(
              instruction->triangular_solve_options().unit_diagonal())));
      auto transpose_a =
          builder_->getStringAttr(TriangularSolveOptions::Transpose_Name(
              instruction->triangular_solve_options().transpose_a()));
      attributes.push_back(builder_->getNamedAttr("transpose_a", transpose_a));
      return func_builder
          ->create<mlir::mhlo::TriangularSolveOp>(loc, result_type, operands,
                                                  attributes)
          .getOperation();
    }
    case HloOpcode::kReduceScatter: {
      auto reduce_scatter = Cast<HloReduceScatterInstruction>(instruction);
      attributes.push_back(builder_->getNamedAttr(
          "scatter_dimension",
          builder_->getI64IntegerAttr(reduce_scatter->scatter_dimension())));
      attributes.push_back(
          ConvertReplicaGroups(reduce_scatter->replica_groups(), builder_));
      if (reduce_scatter->channel_id().has_value())
        attributes.push_back(
            ConvertChannelHandle(reduce_scatter->channel_id().value()));
      auto reduce_scatter_op =
          func_builder->create<mlir::mhlo::ReduceScatterOp>(
              loc, result_type, operands, attributes);
      TF_RETURN_IF_ERROR(ImportAsRegion(*reduce_scatter->to_apply(),
                                        &reduce_scatter_op.computation()));

      return reduce_scatter_op.getOperation();
    }
    case HloOpcode::kReduceWindow: {
      llvm::SmallVector<Type, 4> return_types = {result_type};
      if (mlir::TupleType tuple_ty = result_type.dyn_cast<mlir::TupleType>()) {
        return_types = llvm::to_vector<6>(tuple_ty.getTypes());
      }
      llvm::SmallVector<int64_t, 4> sizes, strides, base_dilations,
          win_dilations;
      llvm::SmallVector<int64_t, 8> padding;
      for (const auto& dim : instruction->window().dimensions()) {
        sizes.push_back(dim.size());
        strides.push_back(dim.stride());
        base_dilations.push_back(dim.base_dilation());
        win_dilations.push_back(dim.window_dilation());
        padding.push_back(dim.padding_low());
        padding.push_back(dim.padding_high());
      }
      attributes.push_back(builder_->getNamedAttr("window_dimensions",
                                                  ConvertDimensions(sizes)));
      attributes.push_back(
          builder_->getNamedAttr("window_strides", ConvertDimensions(strides)));
      attributes.push_back(builder_->getNamedAttr(
          "base_dilations", ConvertDimensions(base_dilations)));
      attributes.push_back(builder_->getNamedAttr(
          "window_dilations", ConvertDimensions(win_dilations)));
      attributes.push_back(ConvertPadding(padding));
      auto reduce = func_builder->create<mlir::mhlo::ReduceWindowOp>(
          loc, return_types, operands, attributes);
      TF_RETURN_IF_ERROR(
          ImportAsRegion(*instruction->to_apply(), &reduce.body()));

      // Check if the output needs to be tupled.
      if (return_types.size() == 1 && return_types.front() == result_type) {
        return reduce.getOperation();
      }

      return func_builder
          ->create<mlir::mhlo::TupleOp>(loc, result_type, reduce.getResults())
          .getOperation();
    }
    case HloOpcode::kMap: {
      auto op = func_builder->create<mlir::mhlo::MapOp>(
          loc, result_type, operands,
          ConvertDimensions(instruction->dimensions()));
      TF_RETURN_IF_ERROR(
          ImportAsRegion(*instruction->to_apply(), &op.computation()));
      return op.getOperation();
    }
    case HloOpcode::kConvolution: {
      llvm::SmallVector<int64_t, 4> strides, lhs_dilations, rhs_dilations;
      llvm::SmallVector<int64_t, 8> paddings;
      for (const auto& dim : instruction->window().dimensions()) {
        strides.push_back(dim.stride());
        lhs_dilations.push_back(dim.base_dilation());
        rhs_dilations.push_back(dim.window_dilation());
        paddings.push_back(dim.padding_low());
        paddings.push_back(dim.padding_high());
      }

      attributes.push_back(
          builder_->getNamedAttr("window_strides", Convert(strides)));
      attributes.push_back(ConvertPadding(paddings));
      attributes.push_back(
          builder_->getNamedAttr("lhs_dilation", Convert(lhs_dilations)));
      attributes.push_back(
          builder_->getNamedAttr("rhs_dilation", Convert(rhs_dilations)));
      attributes.push_back(builder_->getNamedAttr(
          "dimension_numbers",
          ConvertConvDimensionNumbers(
              instruction->convolution_dimension_numbers(), builder_)));
      attributes.push_back(builder_->getNamedAttr(
          "feature_group_count",
          builder_->getI64IntegerAttr(instruction->feature_group_count())));
      attributes.push_back(builder_->getNamedAttr(
          "batch_group_count",
          builder_->getI64IntegerAttr(instruction->batch_group_count())));
      attributes.push_back(builder_->getNamedAttr(
          "precision_config",
          ConvertPrecisionConfig(&instruction->precision_config(), builder_)));

      return func_builder
          ->create<mlir::mhlo::ConvOp>(loc, result_type, operands, attributes)
          .getOperation();
    }

    case HloOpcode::kFft: {
      auto fft_type =
          builder_->getStringAttr(FftType_Name(instruction->fft_type()));

      std::vector<int64_t> fft_length(instruction->fft_length().begin(),
                                      instruction->fft_length().end());

      attributes.push_back(builder_->getNamedAttr("fft_type", fft_type));
      attributes.push_back(
          builder_->getNamedAttr("fft_length", Convert(fft_length)));
      return func_builder
          ->create<mlir::mhlo::FftOp>(loc, result_type, operands, attributes)
          .getOperation();
    }

    case HloOpcode::kAdd: {
      // HLO add ops on PRED elements are actually boolean or, but MHLO dialect
      // AddOps on i1 are just addition with overflow; so, we have to implement
      // the special behavior of HLO add ops on PRED here by creating an
      // mhlo::OrOp instead.
      if (instruction->shape().element_type() == PRED) {
        return func_builder
            ->create<mlir::mhlo::OrOp>(loc, result_type, operands, attributes)
            .getOperation();
      } else {
        return func_builder
            ->create<mlir::mhlo::AddOp>(loc, result_type, operands, attributes)
            .getOperation();
      }
    }
    case HloOpcode::kAfterAll: {
      // HLO AfterAll ops without any token input are used to just create a
      // token. MHLO has a special op CreateToken for this case.
      if (instruction->operands().empty()) {
        return func_builder
            ->create<mlir::mhlo::CreateTokenOp>(loc, result_type, operands,
                                                attributes)
            .getOperation();
      } else {
        return func_builder
            ->create<mlir::mhlo::AfterAllOp>(loc, result_type, operands,
                                             attributes)
            .getOperation();
      }
    }

    case HloOpcode::kConvert: {
      // Convert to boolean is special, it requires a comparison to 0 instead of
      // a truncation to i1, otherwise it is a 1-1 translation.
      auto ranked_type = result_type.dyn_cast<mlir::RankedTensorType>();
      mlir::IntegerType integer_type =
          (ranked_type)
              ? ranked_type.getElementType().dyn_cast<mlir::IntegerType>()
              : nullptr;
      if (!integer_type || integer_type.getWidth() != 1) {
        // Simple case: 1-1 mapping.
        return {func_builder->create<mlir::mhlo::ConvertOp>(
            loc, result_type, operands, attributes)};
      }

      // Return type is boolean, let's use `operand != 0` instead of Convert.
      xla::Shape input_shape = instruction->operand(0)->shape();
      TF_ASSIGN_OR_RETURN(mlir::Type type,
                          ConvertTensorShapeToType<mlir::RankedTensorType>(
                              input_shape, *func_builder));
      auto zero = func_builder->create<mlir::mhlo::ConstOp>(
          loc, func_builder->getZeroAttr(type));
      return {func_builder->create<mlir::mhlo::CompareOp>(
          loc, operands[0], zero, func_builder->getStringAttr("NE"))};
    }
    case HloOpcode::kOptimizationBarrier: {
      llvm::SmallVector<Value> flattened_operands;
      llvm::SmallVector<Type> flattened_operand_types;
      FlattenTupleType(operands[0].getType(), flattened_operand_types);
      FlattenTupleValue(func_builder, loc, operands[0], flattened_operands);

      auto op = func_builder->create<mlir::mhlo::OptimizationBarrierOp>(
          loc, flattened_operand_types, flattened_operands);

      return CreateTupleFromOpResults(func_builder, loc, op.getOperation(),
                                      operands[0].getType());
    }

// Expands to a `case HloOpcode::hlo_op_code:` block that creates an
// `mlir::mhlo::mlir_op` from the enclosing function's `loc`, `result_type`,
// `operands` and `attributes`, and returns the created operation. The
// expansion adds no op-specific attributes of its own; it only forwards
// whatever `attributes` already holds. The macro is #undef'd right after the
// list of cases that use it.
#define NO_ATTRIBUTE_CASE(hlo_op_code, mlir_op)                               \
  case HloOpcode::hlo_op_code: {                                              \
    return func_builder                                                       \
        ->create<mlir::mhlo::mlir_op>(loc, result_type, operands, attributes) \
        .getOperation();                                                      \
  }

      // Broadcast dimensions are never added here because they don't exist as
      // part of the HLO instruction. They are only a convenience in the XLA
      // builder API.
      NO_ATTRIBUTE_CASE(kAbs, AbsOp);
      NO_ATTRIBUTE_CASE(kAnd, AndOp);
      NO_ATTRIBUTE_CASE(kAtan2, Atan2Op);
      NO_ATTRIBUTE_CASE(kBitcastConvert, BitcastConvertOp);
      NO_ATTRIBUTE_CASE(kCbrt, CbrtOp);
      NO_ATTRIBUTE_CASE(kClz, ClzOp);
      NO_ATTRIBUTE_CASE(kCeil, CeilOp);
      NO_ATTRIBUTE_CASE(kClamp, ClampOp);
      NO_ATTRIBUTE_CASE(kComplex, ComplexOp);
      NO_ATTRIBUTE_CASE(kCos, CosOp);
      NO_ATTRIBUTE_CASE(kDivide, DivOp);
      NO_ATTRIBUTE_CASE(kExp, ExpOp);
      NO_ATTRIBUTE_CASE(kExpm1, Expm1Op);
      NO_ATTRIBUTE_CASE(kFloor, FloorOp);
      NO_ATTRIBUTE_CASE(kIsFinite, IsFiniteOp);
      NO_ATTRIBUTE_CASE(kImag, ImagOp);
      NO_ATTRIBUTE_CASE(kLog, LogOp);
      NO_ATTRIBUTE_CASE(kLog1p, Log1pOp);
      NO_ATTRIBUTE_CASE(kMaximum, MaxOp);
      NO_ATTRIBUTE_CASE(kMinimum, MinOp);
      NO_ATTRIBUTE_CASE(kMultiply, MulOp);
      NO_ATTRIBUTE_CASE(kNegate, NegOp);
      NO_ATTRIBUTE_CASE(kNot, NotOp);
      NO_ATTRIBUTE_CASE(kOr, OrOp);
      NO_ATTRIBUTE_CASE(kPopulationCount, PopulationCountOp);
      NO_ATTRIBUTE_CASE(kPower, PowOp);
      NO_ATTRIBUTE_CASE(kReal, RealOp);
      NO_ATTRIBUTE_CASE(kRemainder, RemOp);
      NO_ATTRIBUTE_CASE(kReplicaId, ReplicaIdOp);
      NO_ATTRIBUTE_CASE(kLogistic, LogisticOp);
      // The dimensions attribute is not present on the HLO Reshape
      // instruction. If dimensions are non-default, the XLA builder
      // implements it as a separate transpose.
      NO_ATTRIBUTE_CASE(kReshape, ReshapeOp);
      NO_ATTRIBUTE_CASE(kRoundNearestAfz, RoundOp);
      NO_ATTRIBUTE_CASE(kRsqrt, RsqrtOp);
      NO_ATTRIBUTE_CASE(kSelect, SelectOp);
      NO_ATTRIBUTE_CASE(kShiftLeft, ShiftLeftOp);
      NO_ATTRIBUTE_CASE(kShiftRightArithmetic, ShiftRightArithmeticOp);
      NO_ATTRIBUTE_CASE(kShiftRightLogical, ShiftRightLogicalOp);
      NO_ATTRIBUTE_CASE(kSign, SignOp);
      NO_ATTRIBUTE_CASE(kSin, SinOp);
      NO_ATTRIBUTE_CASE(kSqrt, SqrtOp);
      NO_ATTRIBUTE_CASE(kSubtract, SubOp);
      NO_ATTRIBUTE_CASE(kTanh, TanhOp);
      NO_ATTRIBUTE_CASE(kTuple, TupleOp);
      NO_ATTRIBUTE_CASE(kXor, XorOp);
      // TODO(b/129422361) Copy needs special handling because it is not
      // defined in tensorflow/compiler/xla/client/xla_builder.h. See
      // operation semantics in
      // g3doc/platforms/xla/g3doc/internal/hlo_semantics#copy
      NO_ATTRIBUTE_CASE(kCopy, CopyOp);

#undef NO_ATTRIBUTE_CASE

    case HloOpcode::kFusion: {
      // Flatten the tuple-typed operands.
      llvm::SmallVector<Value> flattened_operands;
      for (auto& operand : operands)
        FlattenTupleValue(func_builder, loc, operand, flattened_operands);

      // Flatten the return type if they are tuple-typed.
      llvm::SmallVector<Type> flattened_ret_types;
      FlattenTupleType(result_type, flattened_ret_types);

      auto fusion = func_builder->create<mlir::mhlo::FusionOp>(
          loc, flattened_ret_types, flattened_operands,
          builder_->getStringAttr(xla::ToString(instruction->fusion_kind())));
      TF_RETURN_IF_ERROR(ImportAsRegion(
          *instruction->fused_instructions_computation(),
          &fusion.fused_computation(), /*flatten_region_arg_tuple=*/true));

      return CreateTupleFromOpResults(func_builder, loc, fusion.getOperation(),
                                      result_type);
    }
    case HloOpcode::kBitcast: {
      auto bitcast = func_builder->create<mlir::mhlo::BitcastOp>(
          loc, result_type, operands, attributes);
      // Store the source and result layout as attributes. Although the MHLO
      // Bitcast operates on tensors, these layouts are relevant as they define
      // the mapping between the elements of the source and result.
      SetLayoutForMlir(bitcast, instruction->shape(), "result_layout");
      SetLayoutForMlir(bitcast, instruction->operand(0)->shape(),
                       "source_layout");
      return bitcast.getOperation();
    }
    case HloOpcode::kReducePrecision: {
      auto op = func_builder->create<mlir::mhlo::ReducePrecisionOp>(
          loc, result_type, operands[0], attributes);
      op.exponent_bitsAttr(func_builder->getIntegerAttr(
          func_builder->getI32Type(), instruction->exponent_bits()));
      op.mantissa_bitsAttr(func_builder->getIntegerAttr(
          func_builder->getI32Type(), instruction->mantissa_bits()));
      return op.getOperation();
    }
    case HloOpcode::kAddDependency:
      // kAddDependency is an arbitrary opcode that we suspect will not be
      // implemented for quite a while; leaving it unhandled (it falls through
      // to the default case below) allows testing the handling of unknown
      // ops. It was selected because it is not mentioned anywhere in the XLA
      // client or in the HLO of our sample models.
    default: {
      mlir::OperationState result(loc, "mhlo.unknown");
      result.addOperands(operands);
      result.addTypes(result_type);
      for (auto attr : attributes) {
        result.attributes.push_back(attr);
      }

      return func_builder->createOperation(result);
    }
  }
}