bool HloParser::ParseInstructionRhs()

in tensorflow/tensorflow/compiler/xla/service/hlo_parser.cc [667:1814]


bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder,
                                    const string& name, LocTy name_loc) {
  Shape shape;
  HloOpcode opcode;
  std::vector<HloInstruction*> operands;

  if (!ParseShape(&shape) || !ParseOpcode(&opcode)) {
    return false;
  }

  // Add optional attributes.
  std::unordered_map<string, AttrConfig> attrs;
  optional<OpSharding> sharding;
  attrs["sharding"] = {/*required=*/false, AttrTy::kSharding, &sharding};
  optional<ParameterReplication> parameter_replication;
  attrs["parameter_replication"] = {/*required=*/false,
                                    AttrTy::kParameterReplication,
                                    &parameter_replication};
  optional<std::vector<HloInstruction*>> predecessors;
  attrs["control-predecessors"] = {/*required=*/false, AttrTy::kInstructionList,
                                   &predecessors};
  optional<OpMetadata> metadata;
  attrs["metadata"] = {/*required=*/false, AttrTy::kMetadata, &metadata};

  optional<string> backend_config;
  attrs["backend_config"] = {/*required=*/false, AttrTy::kString,
                             &backend_config};
  optional<std::vector<int64>> outer_dimension_partitions;
  attrs["outer_dimension_partitions"] = {/*required=*/false,
                                         AttrTy::kBracedInt64List,
                                         &outer_dimension_partitions};

  HloInstruction* instruction;
  switch (opcode) {
    case HloOpcode::kParameter: {
      int64 parameter_number;
      if (!ParseToken(TokKind::kLparen,
                      "expects '(' before parameter number") ||
          !ParseInt64(&parameter_number)) {
        return false;
      }
      if (parameter_number < 0) {
        Error(lexer_.GetLoc(), "parameter number must be >= 0");
        return false;
      }
      if (!ParseToken(TokKind::kRparen, "expects ')' after parameter number") ||
          !ParseAttributes(attrs)) {
        return false;
      }
      instruction = builder->AddInstruction(
          HloInstruction::CreateParameter(parameter_number, shape, name));
      break;
    }
    case HloOpcode::kConstant: {
      Literal literal;
      if (!ParseToken(TokKind::kLparen,
                      "expects '(' before constant literal") ||
          !ParseLiteral(&literal, shape) ||
          !ParseToken(TokKind::kRparen, "expects ')' after constant literal") ||
          !ParseAttributes(attrs)) {
        return false;
      }
      instruction = builder->AddInstruction(
          HloInstruction::CreateConstant(std::move(literal)));
      break;
    }
    case HloOpcode::kIota: {
      optional<int64> iota_dimension;
      attrs["iota_dimension"] = {/*required=*/true, AttrTy::kInt64,
                                 &iota_dimension};
      if (!ParseOperands(&operands, /*expected_size=*/0) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      instruction = builder->AddInstruction(
          HloInstruction::CreateIota(shape, *iota_dimension));
      break;
    }
    // Unary ops.
    case HloOpcode::kAbs:
    case HloOpcode::kRoundNearestAfz:
    case HloOpcode::kBitcast:
    case HloOpcode::kCeil:
    case HloOpcode::kClz:
    case HloOpcode::kCopy:
    case HloOpcode::kCopyStart:
    case HloOpcode::kCopyDone:
    case HloOpcode::kCos:
    case HloOpcode::kExp:
    case HloOpcode::kExpm1:
    case HloOpcode::kImag:
    case HloOpcode::kIsFinite:
    case HloOpcode::kFloor:
    case HloOpcode::kLog:
    case HloOpcode::kLog1p:
    case HloOpcode::kNot:
    case HloOpcode::kNegate:
    case HloOpcode::kPopulationCount:
    case HloOpcode::kReal:
    case HloOpcode::kRsqrt:
    case HloOpcode::kSign:
    case HloOpcode::kSin:
    case HloOpcode::kSqrt:
    case HloOpcode::kTanh: {
      if (!ParseOperands(&operands, /*expected_size=*/1) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      instruction = builder->AddInstruction(
          HloInstruction::CreateUnary(shape, opcode, operands[0]));
      break;
    }
    // Binary ops.
    case HloOpcode::kAdd:
    case HloOpcode::kDivide:
    case HloOpcode::kMultiply:
    case HloOpcode::kSubtract:
    case HloOpcode::kAtan2:
    case HloOpcode::kComplex:
    case HloOpcode::kMaximum:
    case HloOpcode::kMinimum:
    case HloOpcode::kPower:
    case HloOpcode::kRemainder:
    case HloOpcode::kAnd:
    case HloOpcode::kOr:
    case HloOpcode::kXor:
    case HloOpcode::kShiftLeft:
    case HloOpcode::kShiftRightArithmetic:
    case HloOpcode::kShiftRightLogical: {
      if (!ParseOperands(&operands, /*expected_size=*/2) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      instruction = builder->AddInstruction(HloInstruction::CreateBinary(
          shape, opcode, operands[0], operands[1]));
      break;
    }
    // Ternary ops.
    case HloOpcode::kClamp:
    case HloOpcode::kSelect:
    case HloOpcode::kTupleSelect: {
      if (!ParseOperands(&operands, /*expected_size=*/3) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      instruction = builder->AddInstruction(HloInstruction::CreateTernary(
          shape, opcode, operands[0], operands[1], operands[2]));
      break;
    }
    // Other supported ops.
    case HloOpcode::kConvert: {
      if (!ParseOperands(&operands, /*expected_size=*/1) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      instruction = builder->AddInstruction(
          HloInstruction::CreateConvert(shape, operands[0]));
      break;
    }
    case HloOpcode::kBitcastConvert: {
      if (!ParseOperands(&operands, /*expected_size=*/1) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      instruction = builder->AddInstruction(
          HloInstruction::CreateBitcastConvert(shape, operands[0]));
      break;
    }
    case HloOpcode::kAllReduce: {
      optional<std::vector<std::vector<int64>>> tmp_groups;
      optional<HloComputation*> to_apply;
      optional<std::vector<int64>> replica_group_ids;
      optional<int64> channel_id;
      attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
                           &to_apply};
      attrs["replica_groups"] = {/*required=*/false,
                                 AttrTy::kBracedInt64ListList, &tmp_groups};
      attrs["channel_id"] = {/*required=*/false, AttrTy::kInt64, &channel_id};
      if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
        return false;
      }
      std::vector<ReplicaGroup> replica_groups;
      if (tmp_groups) {
        replica_groups = CreateReplicaGroups(*tmp_groups);
      }
      instruction = builder->AddInstruction(HloInstruction::CreateAllReduce(
          shape, operands, *to_apply, replica_groups, channel_id));
      break;
    }
    case HloOpcode::kAllToAll: {
      optional<std::vector<std::vector<int64>>> tmp_groups;
      attrs["replica_groups"] = {/*required=*/false,
                                 AttrTy::kBracedInt64ListList, &tmp_groups};
      if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
        return false;
      }
      std::vector<ReplicaGroup> replica_groups;
      if (tmp_groups) {
        replica_groups = CreateReplicaGroups(*tmp_groups);
      }
      instruction = builder->AddInstruction(
          HloInstruction::CreateAllToAll(shape, operands, replica_groups));
      break;
    }
    case HloOpcode::kCollectivePermute: {
      optional<std::vector<std::vector<int64>>> source_targets;
      attrs["source_target_pairs"] = {
          /*required=*/true, AttrTy::kBracedInt64ListList, &source_targets};
      optional<int64> channel_id;
      attrs["channel_id"] = {/*required=*/false, AttrTy::kInt64, &channel_id};
      if (!ParseOperands(&operands, /*expected_size=*/1) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      std::vector<std::pair<int64, int64>> pairs(source_targets->size());
      for (int i = 0; i < pairs.size(); i++) {
        if ((*source_targets)[i].size() != 2) {
          return TokenError(
              "expects 'source_target_pairs=' to be a list of pairs");
        }
        pairs[i].first = (*source_targets)[i][0];
        pairs[i].second = (*source_targets)[i][1];
      }
      instruction =
          builder->AddInstruction(HloInstruction::CreateCollectivePermute(
              shape, operands[0], pairs, channel_id));
      break;
    }
    case HloOpcode::kReplicaId: {
      if (!ParseOperands(&operands, /*expected_size=*/0) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      instruction = builder->AddInstruction(HloInstruction::CreateReplicaId());
      break;
    }
    case HloOpcode::kPartitionId: {
      if (!ParseOperands(&operands, /*expected_size=*/0) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      instruction =
          builder->AddInstruction(HloInstruction::CreatePartitionId());
      break;
    }
    case HloOpcode::kReshape: {
      optional<int64> inferred_dimension;
      attrs["inferred_dimension"] = {/*required=*/false, AttrTy::kInt64,
                                     &inferred_dimension};
      if (!ParseOperands(&operands, /*expected_size=*/1) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      instruction = builder->AddInstruction(HloInstruction::CreateReshape(
          shape, operands[0], inferred_dimension.value_or(-1)));
      break;
    }
    case HloOpcode::kAfterAll: {
      if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
        return false;
      }
      if (operands.empty()) {
        instruction = builder->AddInstruction(HloInstruction::CreateToken());
      } else {
        instruction =
            builder->AddInstruction(HloInstruction::CreateAfterAll(operands));
      }
      break;
    }
    case HloOpcode::kAddDependency: {
      if (!ParseOperands(&operands, /*expected_size=*/2) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      instruction = builder->AddInstruction(
          HloInstruction::CreateAddDependency(operands[0], operands[1]));
      break;
    }
    case HloOpcode::kSort: {
      optional<std::vector<int64>> dimensions;
      attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
                             &dimensions};
      optional<bool> is_stable = false;
      attrs["is_stable"] = {/*required=*/false, AttrTy::kBool, &is_stable};
      optional<HloComputation*> to_apply;
      attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
                           &to_apply};
      if (!ParseOperands(&operands) || !ParseAttributes(attrs) ||
          dimensions->size() != 1) {
        return false;
      }
      instruction = builder->AddInstruction(
          HloInstruction::CreateSort(shape, dimensions->at(0), operands,
                                     to_apply.value(), is_stable.value()));
      break;
    }
    case HloOpcode::kTuple: {
      if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
        return false;
      }
      instruction =
          builder->AddInstruction(HloInstruction::CreateTuple(operands));
      break;
    }
    case HloOpcode::kWhile: {
      optional<HloComputation*> condition;
      optional<HloComputation*> body;
      attrs["condition"] = {/*required=*/true, AttrTy::kHloComputation,
                            &condition};
      attrs["body"] = {/*required=*/true, AttrTy::kHloComputation, &body};
      if (!ParseOperands(&operands, /*expected_size=*/1) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      instruction = builder->AddInstruction(HloInstruction::CreateWhile(
          shape, *condition, *body, /*init=*/operands[0]));
      break;
    }
    case HloOpcode::kRecv: {
      optional<int64> channel_id;
      // If the is_host_transfer attribute is not present then default to false.
      optional<bool> is_host_transfer = false;
      attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
      attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool,
                                   &is_host_transfer};
      if (!ParseOperands(&operands, /*expected_size=*/1) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      // If the is_host_transfer attribute is not present then default to false.
      instruction = builder->AddInstruction(HloInstruction::CreateRecv(
          shape.tuple_shapes(0), operands[0], *channel_id, *is_host_transfer));
      break;
    }
    case HloOpcode::kRecvDone: {
      optional<int64> channel_id;
      // If the is_host_transfer attribute is not present then default to false.
      optional<bool> is_host_transfer = false;
      attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
      attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool,
                                   &is_host_transfer};
      if (!ParseOperands(&operands, /*expected_size=*/1) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      if (channel_id != operands[0]->channel_id()) {
        return false;
      }
      instruction = builder->AddInstruction(
          HloInstruction::CreateRecvDone(operands[0], *is_host_transfer));
      break;
    }
    case HloOpcode::kSend: {
      optional<int64> channel_id;
      // If the is_host_transfer attribute is not present then default to false.
      optional<bool> is_host_transfer = false;
      attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
      attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool,
                                   &is_host_transfer};
      if (!ParseOperands(&operands, /*expected_size=*/2) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      instruction = builder->AddInstruction(HloInstruction::CreateSend(
          operands[0], operands[1], *channel_id, *is_host_transfer));
      break;
    }
    case HloOpcode::kSendDone: {
      optional<int64> channel_id;
      // If the is_host_transfer attribute is not present then default to false.
      optional<bool> is_host_transfer = false;
      attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
      attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool,
                                   &is_host_transfer};
      if (!ParseOperands(&operands, /*expected_size=*/1) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      if (channel_id != operands[0]->channel_id()) {
        return false;
      }
      instruction = builder->AddInstruction(
          HloInstruction::CreateSendDone(operands[0], *is_host_transfer));
      break;
    }
    case HloOpcode::kGetTupleElement: {
      optional<int64> index;
      attrs["index"] = {/*required=*/true, AttrTy::kInt64, &index};
      if (!ParseOperands(&operands, /*expected_size=*/1) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      instruction = builder->AddInstruction(
          HloInstruction::CreateGetTupleElement(shape, operands[0], *index));
      break;
    }
    case HloOpcode::kCall: {
      optional<HloComputation*> to_apply;
      attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
                           &to_apply};
      if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
        return false;
      }
      instruction = builder->AddInstruction(
          HloInstruction::CreateCall(shape, operands, *to_apply));
      break;
    }
    case HloOpcode::kReduceWindow: {
      optional<HloComputation*> reduce_computation;
      optional<Window> window;
      attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
      attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
                           &reduce_computation};
      if (!ParseOperands(&operands, /*expected_size=*/2) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      if (!window) {
        window.emplace();
      }
      instruction = builder->AddInstruction(HloInstruction::CreateReduceWindow(
          shape, /*operand=*/operands[0], /*init_value=*/operands[1], *window,
          *reduce_computation));
      break;
    }
    case HloOpcode::kConvolution: {
      optional<Window> window;
      optional<ConvolutionDimensionNumbers> dnums;
      optional<int64> feature_group_count;
      optional<int64> batch_group_count;
      attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
      attrs["dim_labels"] = {/*required=*/true,
                             AttrTy::kConvolutionDimensionNumbers, &dnums};
      attrs["feature_group_count"] = {/*required=*/false, AttrTy::kInt64,
                                      &feature_group_count};
      attrs["batch_group_count"] = {/*required=*/false, AttrTy::kInt64,
                                    &batch_group_count};
      optional<std::vector<PrecisionConfig::Precision>> operand_precision;
      attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList,
                                    &operand_precision};
      if (!ParseOperands(&operands, /*expected_size=*/2) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      if (!window) {
        window.emplace();
      }
      if (!feature_group_count) {
        feature_group_count = 1;
      }
      if (!batch_group_count) {
        batch_group_count = 1;
      }
      PrecisionConfig precision_config;
      if (operand_precision) {
        *precision_config.mutable_operand_precision() = {
            operand_precision->begin(), operand_precision->end()};
      } else {
        precision_config.mutable_operand_precision()->Resize(
            operands.size(), PrecisionConfig::DEFAULT);
      }
      instruction = builder->AddInstruction(HloInstruction::CreateConvolve(
          shape, /*lhs=*/operands[0], /*rhs=*/operands[1],
          feature_group_count.value(), batch_group_count.value(), *window,
          *dnums, precision_config));
      break;
    }
    case HloOpcode::kFft: {
      optional<FftType> fft_type;
      optional<std::vector<int64>> fft_length;
      attrs["fft_type"] = {/*required=*/true, AttrTy::kFftType, &fft_type};
      attrs["fft_length"] = {/*required=*/true, AttrTy::kBracedInt64List,
                             &fft_length};
      if (!ParseOperands(&operands, /*expected_size=*/1) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      instruction = builder->AddInstruction(HloInstruction::CreateFft(
          shape, operands[0], *fft_type, *fft_length));
      break;
    }
    case HloOpcode::kTriangularSolve: {
      TriangularSolveOptions options;
      if (!ParseOperands(&operands, /*expected_size=*/2) ||
          !ParseAttributesAsProtoMessage(
              /*required_attrs=*/std::unordered_set<string>(), &options)) {
        return false;
      }
      instruction =
          builder->AddInstruction(HloInstruction::CreateTriangularSolve(
              shape, operands[0], operands[1], options));
      break;
    }
    case HloOpcode::kCompare: {
      optional<ComparisonDirection> direction;
      attrs["direction"] = {/*required=*/true, AttrTy::kComparisonDirection,
                            &direction};
      if (!ParseOperands(&operands, /*expected_size=*/2) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      instruction = builder->AddInstruction(HloInstruction::CreateCompare(
          shape, operands[0], operands[1], *direction));
      break;
    }
    case HloOpcode::kCholesky: {
      CholeskyOptions options;
      if (!ParseOperands(&operands, /*expected_size=*/1) ||
          !ParseAttributesAsProtoMessage(
              /*required_attrs=*/std::unordered_set<string>(), &options)) {
        return false;
      }
      instruction = builder->AddInstruction(
          HloInstruction::CreateCholesky(shape, operands[0], options));
      break;
    }
    case HloOpcode::kBroadcast: {
      optional<std::vector<int64>> broadcast_dimensions;
      attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
                             &broadcast_dimensions};
      if (!ParseOperands(&operands, /*expected_size=*/1) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      instruction = builder->AddInstruction(HloInstruction::CreateBroadcast(
          shape, operands[0], *broadcast_dimensions));
      break;
    }
    case HloOpcode::kConcatenate: {
      optional<std::vector<int64>> dimensions;
      attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
                             &dimensions};
      if (!ParseOperands(&operands) || !ParseAttributes(attrs) ||
          dimensions->size() != 1) {
        return false;
      }
      instruction = builder->AddInstruction(HloInstruction::CreateConcatenate(
          shape, operands, dimensions->at(0)));
      break;
    }
    case HloOpcode::kMap: {
      optional<HloComputation*> to_apply;
      attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
                           &to_apply};
      optional<std::vector<int64>> dimensions;
      attrs["dimensions"] = {/*required=*/false, AttrTy::kBracedInt64List,
                             &dimensions};
      if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
        return false;
      }
      instruction = builder->AddInstruction(
          HloInstruction::CreateMap(shape, operands, *to_apply));
      break;
    }
    case HloOpcode::kReduce: {
      auto loc = lexer_.GetLoc();

      optional<HloComputation*> reduce_computation;
      attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
                           &reduce_computation};
      optional<std::vector<int64>> dimensions_to_reduce;
      attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
                             &dimensions_to_reduce};
      if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
        return false;
      }
      if (operands.size() % 2) {
        return Error(loc, StrCat("expects an even number of operands, but has ",
                                 operands.size(), " operands"));
      }
      instruction = builder->AddInstruction(HloInstruction::CreateReduce(
          shape, /*operands=*/
          absl::Span<HloInstruction* const>(operands).subspan(
              0, operands.size() / 2),
          /*init_values=*/
          absl::Span<HloInstruction* const>(operands).subspan(operands.size() /
                                                              2),
          *dimensions_to_reduce, *reduce_computation));
      break;
    }
    case HloOpcode::kReverse: {
      optional<std::vector<int64>> dimensions;
      attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
                             &dimensions};
      if (!ParseOperands(&operands, /*expected_size=*/1) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      instruction = builder->AddInstruction(
          HloInstruction::CreateReverse(shape, operands[0], *dimensions));
      break;
    }
    case HloOpcode::kSelectAndScatter: {
      optional<HloComputation*> select;
      attrs["select"] = {/*required=*/true, AttrTy::kHloComputation, &select};
      optional<HloComputation*> scatter;
      attrs["scatter"] = {/*required=*/true, AttrTy::kHloComputation, &scatter};
      optional<Window> window;
      attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
      if (!ParseOperands(&operands, /*expected_size=*/3) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      if (!window) {
        window.emplace();
      }
      instruction =
          builder->AddInstruction(HloInstruction::CreateSelectAndScatter(
              shape, /*operand=*/operands[0], *select, *window,
              /*source=*/operands[1], /*init_value=*/operands[2], *scatter));
      break;
    }
    case HloOpcode::kSlice: {
      optional<SliceRanges> slice_ranges;
      attrs["slice"] = {/*required=*/true, AttrTy::kSliceRanges, &slice_ranges};
      if (!ParseOperands(&operands, /*expected_size=*/1) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      instruction = builder->AddInstruction(HloInstruction::CreateSlice(
          shape, operands[0], slice_ranges->starts, slice_ranges->limits,
          slice_ranges->strides));
      break;
    }
    case HloOpcode::kDynamicSlice: {
      optional<std::vector<int64>> dynamic_slice_sizes;
      attrs["dynamic_slice_sizes"] = {
          /*required=*/true, AttrTy::kBracedInt64List, &dynamic_slice_sizes};
      LocTy loc = lexer_.GetLoc();
      if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
        return false;
      }
      if (operands.empty()) {
        return Error(loc, "Expected at least one operand.");
      }
      if (!(operands.size() == 2 && operands[1]->shape().rank() == 1) &&
          operands.size() != 1 + operands[0]->shape().rank()) {
        return Error(loc, "Wrong number of operands.");
      }
      instruction = builder->AddInstruction(HloInstruction::CreateDynamicSlice(
          shape, /*operand=*/operands[0],
          /*start_indices=*/absl::MakeSpan(operands).subspan(1),
          *dynamic_slice_sizes));
      break;
    }
    case HloOpcode::kDynamicUpdateSlice: {
      LocTy loc = lexer_.GetLoc();
      if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
        return false;
      }
      if (operands.size() < 2) {
        return Error(loc, "Expected at least two operands.");
      }
      if (!(operands.size() == 3 && operands[2]->shape().rank() == 1) &&
          operands.size() != 2 + operands[0]->shape().rank()) {
        return Error(loc, "Wrong number of operands.");
      }
      instruction =
          builder->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
              shape, /*operand=*/operands[0], /*update=*/operands[1],
              /*start_indices=*/absl::MakeSpan(operands).subspan(2)));
      break;
    }
    case HloOpcode::kTranspose: {
      optional<std::vector<int64>> dimensions;
      attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
                             &dimensions};
      if (!ParseOperands(&operands, /*expected_size=*/1) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      instruction = builder->AddInstruction(
          HloInstruction::CreateTranspose(shape, operands[0], *dimensions));
      break;
    }
    case HloOpcode::kBatchNormTraining: {
      optional<float> epsilon;
      attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon};
      optional<int64> feature_index;
      attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64,
                                &feature_index};
      if (!ParseOperands(&operands, /*expected_size=*/3) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      instruction =
          builder->AddInstruction(HloInstruction::CreateBatchNormTraining(
              shape, /*operand=*/operands[0], /*scale=*/operands[1],
              /*offset=*/operands[2], *epsilon, *feature_index));
      break;
    }
    case HloOpcode::kBatchNormInference: {
      optional<float> epsilon;
      attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon};
      optional<int64> feature_index;
      attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64,
                                &feature_index};
      if (!ParseOperands(&operands, /*expected_size=*/5) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      instruction =
          builder->AddInstruction(HloInstruction::CreateBatchNormInference(
              shape, /*operand=*/operands[0], /*scale=*/operands[1],
              /*offset=*/operands[2], /*mean=*/operands[3],
              /*variance=*/operands[4], *epsilon, *feature_index));
      break;
    }
    case HloOpcode::kBatchNormGrad: {
      optional<float> epsilon;
      attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon};
      optional<int64> feature_index;
      attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64,
                                &feature_index};
      if (!ParseOperands(&operands, /*expected_size=*/5) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      instruction = builder->AddInstruction(HloInstruction::CreateBatchNormGrad(
          shape, /*operand=*/operands[0], /*scale=*/operands[1],
          /*mean=*/operands[2], /*variance=*/operands[3],
          /*grad_output=*/operands[4], *epsilon, *feature_index));
      break;
    }
    case HloOpcode::kPad: {
      optional<PaddingConfig> padding;
      attrs["padding"] = {/*required=*/true, AttrTy::kPaddingConfig, &padding};
      if (!ParseOperands(&operands, /*expected_size=*/2) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      instruction = builder->AddInstruction(HloInstruction::CreatePad(
          shape, operands[0], /*padding_value=*/operands[1], *padding));
      break;
    }
    case HloOpcode::kFusion: {
      optional<HloComputation*> fusion_computation;
      attrs["calls"] = {/*required=*/true, AttrTy::kHloComputation,
                        &fusion_computation};
      optional<HloInstruction::FusionKind> fusion_kind;
      attrs["kind"] = {/*required=*/true, AttrTy::kFusionKind, &fusion_kind};
      if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
        return false;
      }
      instruction = builder->AddInstruction(HloInstruction::CreateFusion(
          shape, *fusion_kind, operands, *fusion_computation));
      break;
    }
    case HloOpcode::kInfeed: {
      optional<string> config;
      attrs["infeed_config"] = {/*required=*/false, AttrTy::kString, &config};
      if (!ParseOperands(&operands, /*expected_size=*/1) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      // We need to know the infeed data shape to construct the infeed
      // instruction. This is the zero-th element of the tuple-shaped output of
      // the infeed instruction. ShapeUtil::GetTupleElementShape will check fail
      // if the shape is not a non-empty tuple, so add guard so an error message
      // can be emitted instead of a check fail
      if (!shape.IsTuple() && !ShapeUtil::IsEmptyTuple(shape)) {
        return Error(lexer_.GetLoc(),
                     "infeed must have a non-empty tuple shape");
      }
      instruction = builder->AddInstruction(HloInstruction::CreateInfeed(
          ShapeUtil::GetTupleElementShape(shape, 0), operands[0],
          config ? *config : ""));
      break;
    }
    case HloOpcode::kOutfeed: {
      optional<string> config;
      attrs["outfeed_config"] = {/*required=*/false, AttrTy::kString, &config};
      if (!ParseOperands(&operands, /*expected_size=*/2) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      instruction = builder->AddInstruction(
          HloInstruction::CreateOutfeed(operands[0]->shape(), operands[0],
                                        operands[1], config ? *config : ""));
      break;
    }
    case HloOpcode::kRng: {
      optional<RandomDistribution> distribution;
      attrs["distribution"] = {/*required=*/true, AttrTy::kDistribution,
                               &distribution};
      if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
        return false;
      }
      instruction = builder->AddInstruction(
          HloInstruction::CreateRng(shape, *distribution, operands));
      break;
    }
    case HloOpcode::kRngGetAndUpdateState: {
      optional<int64> delta;
      attrs["delta"] = {/*required=*/true, AttrTy::kInt64, &delta};
      if (!ParseAttributes(attrs)) {
        return false;
      }
      instruction = builder->AddInstruction(
          HloInstruction::CreateRngGetAndUpdateState(shape, *delta));
      break;
    }
    case HloOpcode::kReducePrecision: {
      // One operand plus two required bit-width attributes. The widths are
      // parsed as int64 and narrowed to int for the factory call.
      optional<int64> exponent_bits;
      attrs["exponent_bits"] = {/*required=*/true, AttrTy::kInt64,
                                &exponent_bits};
      optional<int64> mantissa_bits;
      attrs["mantissa_bits"] = {/*required=*/true, AttrTy::kInt64,
                                &mantissa_bits};
      if (!(ParseOperands(&operands, /*expected_size=*/1) &&
            ParseAttributes(attrs))) {
        return false;
      }
      const int exponent_width = static_cast<int>(*exponent_bits);
      const int mantissa_width = static_cast<int>(*mantissa_bits);
      instruction =
          builder->AddInstruction(HloInstruction::CreateReducePrecision(
              shape, operands[0], exponent_width, mantissa_width));
      break;
    }
    case HloOpcode::kConditional: {
      // The first operand is the branch index.
      // - A PRED index uses the `true_computation`/`false_computation`
      //   attributes.
      // - An S32 index uses the `branch_computations` list.
      // The remaining operands are the branch arguments, one per branch
      // computation.
      optional<HloComputation*> true_computation;
      optional<HloComputation*> false_computation;
      optional<std::vector<HloComputation*>> branch_computations;
      if (!ParseOperands(&operands)) {
        return false;
      }
      // Check for an empty operand list before reading operands[0]. Without
      // this, input such as `conditional()` would index past the end of the
      // vector.
      if (operands.empty()) {
        return Error(lexer_.GetLoc(),
                     "conditional expects at least one operand (the branch "
                     "index)");
      }
      if (!ShapeUtil::IsScalar(operands[0]->shape())) {
        return Error(lexer_.GetLoc(), "The first operand must be a scalar");
      }
      const bool branch_index_is_bool =
          operands[0]->shape().element_type() == PRED;
      if (branch_index_is_bool) {
        attrs["true_computation"] = {/*required=*/true, AttrTy::kHloComputation,
                                     &true_computation};
        attrs["false_computation"] = {
            /*required=*/true, AttrTy::kHloComputation, &false_computation};
      } else {
        if (operands[0]->shape().element_type() != S32) {
          return Error(lexer_.GetLoc(),
                       "The first operand must be a scalar of PRED or S32");
        }
        attrs["branch_computations"] = {/*required=*/true,
                                        AttrTy::kBracedHloComputationList,
                                        &branch_computations};
      }
      if (!ParseAttributes(attrs)) {
        return false;
      }
      if (branch_index_is_bool) {
        // Normalize the PRED form into a two-entry branch list.
        branch_computations.emplace({*true_computation, *false_computation});
      }
      // Report malformed branch/operand counts instead of returning false
      // with no diagnostic.
      if (branch_computations->empty()) {
        return Error(lexer_.GetLoc(),
                     "conditional expects at least one branch computation");
      }
      if (operands.size() != branch_computations->size() + 1) {
        return Error(lexer_.GetLoc(),
                     StrCat("conditional expects ",
                            branch_computations->size() + 1,
                            " operands (the branch index plus one per branch "
                            "computation), but got ",
                            operands.size()));
      }
      instruction = builder->AddInstruction(HloInstruction::CreateConditional(
          shape, /*branch_index=*/operands[0],
          absl::MakeSpan(*branch_computations),
          absl::MakeSpan(operands).subspan(1)));
      break;
    }
    case HloOpcode::kCustomCall: {
      // A custom call requires `custom_call_target`; everything else is
      // optional.
      // - If `operand_layout_constraints` is given, the result shape and every
      //   constraint must carry a layout. Each constraint must also be
      //   compatible with its operand. The layout-constrained factory overload
      //   is then used.
      // - The optional window / dim_labels / group-count / side-effect
      //   attributes are applied to the created instruction afterwards.
      optional<string> custom_call_target;
      attrs["custom_call_target"] = {/*required=*/true, AttrTy::kString,
                                     &custom_call_target};
      optional<Window> window;
      attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
      optional<ConvolutionDimensionNumbers> dnums;
      attrs["dim_labels"] = {/*required=*/false,
                             AttrTy::kConvolutionDimensionNumbers, &dnums};
      optional<int64> feature_group_count;
      attrs["feature_group_count"] = {/*required=*/false, AttrTy::kInt64,
                                      &feature_group_count};
      optional<int64> batch_group_count;
      attrs["batch_group_count"] = {/*required=*/false, AttrTy::kInt64,
                                    &batch_group_count};
      optional<std::vector<Shape>> operand_layout_constraints;
      attrs["operand_layout_constraints"] = {
          /*required=*/false, AttrTy::kShapeList, &operand_layout_constraints};
      optional<bool> custom_call_has_side_effect;
      attrs["custom_call_has_side_effect"] = {/*required=*/false, AttrTy::kBool,
                                              &custom_call_has_side_effect};
      if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
        return false;
      }

      // Copy (not move): the shared backend_config handling after the switch
      // still reads `backend_config`.
      const string config_string = backend_config ? *backend_config : "";

      if (!operand_layout_constraints.has_value()) {
        instruction = builder->AddInstruction(HloInstruction::CreateCustomCall(
            shape, operands, *custom_call_target, config_string));
      } else {
        if (!LayoutUtil::HasLayout(shape)) {
          return Error(lexer_.GetLoc(),
                       "Layout must be set on layout-constrained custom call");
        }
        const std::vector<Shape>& constraints = *operand_layout_constraints;
        if (constraints.size() != operands.size()) {
          return Error(lexer_.GetLoc(),
                       StrCat("Expected ", operands.size(),
                              " operand layout constraints, ",
                              constraints.size(), " given"));
        }
        for (size_t idx = 0; idx < constraints.size(); ++idx) {
          const Shape& constraint = constraints[idx];
          if (!LayoutUtil::HasLayout(constraint)) {
            return Error(lexer_.GetLoc(),
                         StrCat("Operand layout constraint shape ",
                                ShapeUtil::HumanStringWithLayout(constraint),
                                " for operand ", idx,
                                " does not have a layout"));
          }
          const Shape& operand_shape = operands[idx]->shape();
          if (!ShapeUtil::Compatible(constraint, operand_shape)) {
            return Error(
                lexer_.GetLoc(),
                StrCat("Operand layout constraint shape ",
                       ShapeUtil::HumanStringWithLayout(constraint),
                       " for operand ", idx,
                       " is not compatible with operand shape ",
                       ShapeUtil::HumanStringWithLayout(operand_shape)));
          }
        }
        instruction = builder->AddInstruction(HloInstruction::CreateCustomCall(
            shape, operands, *custom_call_target, constraints, config_string));
      }

      auto* const custom_call = Cast<HloCustomCallInstruction>(instruction);
      if (window.has_value()) {
        custom_call->set_window(*window);
      }
      if (dnums.has_value()) {
        custom_call->set_convolution_dimension_numbers(*dnums);
      }
      if (feature_group_count.has_value()) {
        custom_call->set_feature_group_count(*feature_group_count);
      }
      if (batch_group_count.has_value()) {
        custom_call->set_batch_group_count(*batch_group_count);
      }
      // Use has_value() explicitly: for optional<bool>, a plain truth test
      // checks presence, which is easy to misread as the flag's value.
      if (custom_call_has_side_effect.has_value()) {
        custom_call->set_custom_call_has_side_effect(
            *custom_call_has_side_effect);
      }
      break;
    }
    case HloOpcode::kDot: {
      // All attributes are optional.
      // - An omitted dimension list leaves the matching DotDimensionNumbers
      //   field empty.
      // - An omitted `operand_precision` gives every operand
      //   PrecisionConfig::DEFAULT.
      optional<std::vector<int64>> lhs_contracting_dims;
      optional<std::vector<int64>> rhs_contracting_dims;
      optional<std::vector<int64>> lhs_batch_dims;
      optional<std::vector<int64>> rhs_batch_dims;
      optional<std::vector<PrecisionConfig::Precision>> operand_precision;
      attrs["lhs_contracting_dims"] = {
          /*required=*/false, AttrTy::kBracedInt64List, &lhs_contracting_dims};
      attrs["rhs_contracting_dims"] = {
          /*required=*/false, AttrTy::kBracedInt64List, &rhs_contracting_dims};
      attrs["lhs_batch_dims"] = {/*required=*/false, AttrTy::kBracedInt64List,
                                 &lhs_batch_dims};
      attrs["rhs_batch_dims"] = {/*required=*/false, AttrTy::kBracedInt64List,
                                 &rhs_batch_dims};
      attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList,
                                    &operand_precision};

      const bool parsed = ParseOperands(&operands, /*expected_size=*/2) &&
                          ParseAttributes(attrs);
      if (!parsed) {
        return false;
      }

      DotDimensionNumbers dnum;
      if (lhs_contracting_dims.has_value()) {
        for (const int64 dim : *lhs_contracting_dims) {
          dnum.add_lhs_contracting_dimensions(dim);
        }
      }
      if (rhs_contracting_dims.has_value()) {
        for (const int64 dim : *rhs_contracting_dims) {
          dnum.add_rhs_contracting_dimensions(dim);
        }
      }
      if (lhs_batch_dims.has_value()) {
        for (const int64 dim : *lhs_batch_dims) {
          dnum.add_lhs_batch_dimensions(dim);
        }
      }
      if (rhs_batch_dims.has_value()) {
        for (const int64 dim : *rhs_batch_dims) {
          dnum.add_rhs_batch_dimensions(dim);
        }
      }

      PrecisionConfig precision_config;
      if (!operand_precision.has_value()) {
        precision_config.mutable_operand_precision()->Resize(
            operands.size(), PrecisionConfig::DEFAULT);
      } else {
        for (const PrecisionConfig::Precision precision : *operand_precision) {
          precision_config.add_operand_precision(precision);
        }
      }

      HloInstruction* const lhs = operands[0];
      HloInstruction* const rhs = operands[1];
      instruction = builder->AddInstruction(
          HloInstruction::CreateDot(shape, lhs, rhs, dnum, precision_config));
      break;
    }
    case HloOpcode::kGather: {
      // Gather takes the operand and the start indices. All dimension-number
      // attributes and `slice_sizes` are required. `indices_are_sorted` is
      // optional and starts out false.
      optional<std::vector<int64>> offset_dims;
      optional<std::vector<int64>> collapsed_slice_dims;
      optional<std::vector<int64>> start_index_map;
      optional<int64> index_vector_dim;
      optional<std::vector<int64>> slice_sizes;
      optional<bool> indices_are_sorted = false;
      attrs["offset_dims"] = {/*required=*/true, AttrTy::kBracedInt64List,
                              &offset_dims};
      attrs["collapsed_slice_dims"] = {
          /*required=*/true, AttrTy::kBracedInt64List, &collapsed_slice_dims};
      attrs["start_index_map"] = {/*required=*/true, AttrTy::kBracedInt64List,
                                  &start_index_map};
      attrs["index_vector_dim"] = {/*required=*/true, AttrTy::kInt64,
                                   &index_vector_dim};
      attrs["slice_sizes"] = {/*required=*/true, AttrTy::kBracedInt64List,
                              &slice_sizes};
      attrs["indices_are_sorted"] = {/*required=*/false, AttrTy::kBool,
                                     &indices_are_sorted};

      const bool parsed = ParseOperands(&operands, /*expected_size=*/2) &&
                          ParseAttributes(attrs);
      if (!parsed) {
        return false;
      }

      HloInstruction* const gather_operand = operands[0];
      HloInstruction* const start_indices = operands[1];
      const GatherDimensionNumbers gather_dnums =
          HloGatherInstruction::MakeGatherDimNumbers(
              /*offset_dims=*/*offset_dims,
              /*collapsed_slice_dims=*/*collapsed_slice_dims,
              /*start_index_map=*/*start_index_map,
              /*index_vector_dim=*/*index_vector_dim);
      instruction = builder->AddInstruction(HloInstruction::CreateGather(
          shape, gather_operand, start_indices, gather_dnums, *slice_sizes,
          *indices_are_sorted));
      break;
    }
    case HloOpcode::kScatter: {
      // Scatter takes three operands: operand, scatter indices, and updates.
      // Its dimension numbers and the `to_apply` computation are required.
      // `indices_are_sorted` is optional and starts out false.
      optional<std::vector<int64>> update_window_dims;
      optional<std::vector<int64>> inserted_window_dims;
      optional<std::vector<int64>> scatter_dims_to_operand_dims;
      optional<int64> index_vector_dim;
      optional<HloComputation*> update_computation;
      optional<bool> indices_are_sorted = false;
      attrs["update_window_dims"] = {
          /*required=*/true, AttrTy::kBracedInt64List, &update_window_dims};
      attrs["inserted_window_dims"] = {
          /*required=*/true, AttrTy::kBracedInt64List, &inserted_window_dims};
      attrs["scatter_dims_to_operand_dims"] = {/*required=*/true,
                                               AttrTy::kBracedInt64List,
                                               &scatter_dims_to_operand_dims};
      attrs["index_vector_dim"] = {/*required=*/true, AttrTy::kInt64,
                                   &index_vector_dim};
      attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
                           &update_computation};
      attrs["indices_are_sorted"] = {/*required=*/false, AttrTy::kBool,
                                     &indices_are_sorted};

      if (!(ParseOperands(&operands, /*expected_size=*/3) &&
            ParseAttributes(attrs))) {
        return false;
      }

      HloInstruction* const scatter_operand = operands[0];
      HloInstruction* const scatter_indices = operands[1];
      HloInstruction* const updates = operands[2];
      const ScatterDimensionNumbers scatter_dnums =
          HloScatterInstruction::MakeScatterDimNumbers(
              /*update_window_dims=*/*update_window_dims,
              /*inserted_window_dims=*/*inserted_window_dims,
              /*scatter_dims_to_operand_dims=*/*scatter_dims_to_operand_dims,
              /*index_vector_dim=*/*index_vector_dim);
      instruction = builder->AddInstruction(HloInstruction::CreateScatter(
          shape, scatter_operand, scatter_indices, updates,
          *update_computation, scatter_dnums, *indices_are_sorted));
      break;
    }
    case HloOpcode::kDomain: {
      // One operand plus the required `domain` attribute. The attribute's
      // exit and entry metadata are moved into the new instruction.
      DomainData domain;
      attrs["domain"] = {/*required=*/true, AttrTy::kDomain, &domain};
      const bool parsed = ParseOperands(&operands, /*expected_size=*/1) &&
                          ParseAttributes(attrs);
      if (!parsed) {
        return false;
      }
      HloInstruction* const domain_operand = operands[0];
      instruction = builder->AddInstruction(HloInstruction::CreateDomain(
          shape, domain_operand, std::move(domain.exit_metadata),
          std::move(domain.entry_metadata)));
      break;
    }
    case HloOpcode::kTrace:
      return TokenError(StrCat("parsing not yet implemented for op: ",
                               HloOpcodeString(opcode)));
    case HloOpcode::kGetDimensionSize: {
      // The `dimensions` attribute is spelled as a braced list, but the
      // factory takes a single dimension. Require exactly one entry instead
      // of indexing blindly: `dimensions={}` would otherwise read out of
      // bounds, and extra entries would be silently dropped.
      // The case body is braced so the local below is scoped to this case.
      optional<std::vector<int64>> dimensions;
      attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
                             &dimensions};
      if (!ParseOperands(&operands, /*expected_size=*/1) ||
          !ParseAttributes(attrs)) {
        return false;
      }
      if (dimensions->size() != 1) {
        return Error(lexer_.GetLoc(),
                     StrCat("get-dimension-size expects exactly one "
                            "dimension, but got ",
                            dimensions->size()));
      }
      instruction =
          builder->AddInstruction(HloInstruction::CreateGetDimensionSize(
              shape, operands[0], (*dimensions)[0]));
      break;
    }
  }

  // SetAndSanitizeName may rewrite the requested name. Treat any rewrite as a
  // parse error and suggest the sanitized spelling, rather than silently
  // renaming the instruction.
  instruction->SetAndSanitizeName(name);
  if (instruction->name() != name) {
    return Error(name_loc,
                 StrCat("illegal instruction name: ", name,
                        "; suggest renaming to: ", instruction->name()));
  }

  // Add shared attributes like metadata to the instruction, if they were seen.
  if (sharding) {
    // A sharding that HloSharding::FromProto rejects is reported as a parse
    // error. Calling ValueOrDie() directly would CHECK-fail on user-supplied
    // text.
    auto sharding_or = HloSharding::FromProto(*sharding);
    if (!sharding_or.ok()) {
      return Error(name_loc,
                   StrCat("invalid sharding for instruction: ", name,
                          " status: ", sharding_or.status().ToString()));
    }
    instruction->set_sharding(sharding_or.ValueOrDie());
  }
  if (parameter_replication) {
    // One replication flag is expected per leaf buffer of the result shape.
    // Keep the leaf count in the int64 that GetLeafCount returns instead of
    // narrowing it to int.
    const int64 leaf_count = ShapeUtil::GetLeafCount(instruction->shape());
    const auto& replicated =
        parameter_replication->replicated_at_leaf_buffers();
    if (leaf_count != replicated.size()) {
      return Error(lexer_.GetLoc(),
                   StrCat("parameter has ", leaf_count,
                          " leaf buffers, but parameter_replication has ",
                          replicated.size(), " elements."));
    }
    instruction->set_parameter_replicated_at_leaf_buffers(replicated);
  }
  if (predecessors) {
    for (auto* pre : *predecessors) {
      Status status = pre->AddControlDependencyTo(instruction);
      if (!status.ok()) {
        return Error(name_loc, StrCat("error adding control dependency for: ",
                                      name, " status: ", status.ToString()));
      }
    }
  }
  if (metadata) {
    instruction->set_metadata(*metadata);
  }
  if (backend_config) {
    instruction->set_raw_backend_config_string(std::move(*backend_config));
  }
  if (outer_dimension_partitions) {
    instruction->set_outer_dimension_partitions(*outer_dimension_partitions);
  }
  return AddInstruction(name, instruction, name_loc);
}  // NOLINT(readability/fn_size)