in tensorflow/tensorflow/compiler/xla/service/hlo_parser.cc [667:1814]
bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder,
const string& name, LocTy name_loc) {
Shape shape;
HloOpcode opcode;
std::vector<HloInstruction*> operands;
if (!ParseShape(&shape) || !ParseOpcode(&opcode)) {
return false;
}
// Add optional attributes.
std::unordered_map<string, AttrConfig> attrs;
optional<OpSharding> sharding;
attrs["sharding"] = {/*required=*/false, AttrTy::kSharding, &sharding};
optional<ParameterReplication> parameter_replication;
attrs["parameter_replication"] = {/*required=*/false,
AttrTy::kParameterReplication,
¶meter_replication};
optional<std::vector<HloInstruction*>> predecessors;
attrs["control-predecessors"] = {/*required=*/false, AttrTy::kInstructionList,
&predecessors};
optional<OpMetadata> metadata;
attrs["metadata"] = {/*required=*/false, AttrTy::kMetadata, &metadata};
optional<string> backend_config;
attrs["backend_config"] = {/*required=*/false, AttrTy::kString,
&backend_config};
optional<std::vector<int64>> outer_dimension_partitions;
attrs["outer_dimension_partitions"] = {/*required=*/false,
AttrTy::kBracedInt64List,
&outer_dimension_partitions};
HloInstruction* instruction;
switch (opcode) {
case HloOpcode::kParameter: {
int64 parameter_number;
if (!ParseToken(TokKind::kLparen,
"expects '(' before parameter number") ||
!ParseInt64(¶meter_number)) {
return false;
}
if (parameter_number < 0) {
Error(lexer_.GetLoc(), "parameter number must be >= 0");
return false;
}
if (!ParseToken(TokKind::kRparen, "expects ')' after parameter number") ||
!ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(
HloInstruction::CreateParameter(parameter_number, shape, name));
break;
}
case HloOpcode::kConstant: {
Literal literal;
if (!ParseToken(TokKind::kLparen,
"expects '(' before constant literal") ||
!ParseLiteral(&literal, shape) ||
!ParseToken(TokKind::kRparen, "expects ')' after constant literal") ||
!ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(
HloInstruction::CreateConstant(std::move(literal)));
break;
}
case HloOpcode::kIota: {
optional<int64> iota_dimension;
attrs["iota_dimension"] = {/*required=*/true, AttrTy::kInt64,
&iota_dimension};
if (!ParseOperands(&operands, /*expected_size=*/0) ||
!ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(
HloInstruction::CreateIota(shape, *iota_dimension));
break;
}
// Unary ops.
case HloOpcode::kAbs:
case HloOpcode::kRoundNearestAfz:
case HloOpcode::kBitcast:
case HloOpcode::kCeil:
case HloOpcode::kClz:
case HloOpcode::kCopy:
case HloOpcode::kCopyStart:
case HloOpcode::kCopyDone:
case HloOpcode::kCos:
case HloOpcode::kExp:
case HloOpcode::kExpm1:
case HloOpcode::kImag:
case HloOpcode::kIsFinite:
case HloOpcode::kFloor:
case HloOpcode::kLog:
case HloOpcode::kLog1p:
case HloOpcode::kNot:
case HloOpcode::kNegate:
case HloOpcode::kPopulationCount:
case HloOpcode::kReal:
case HloOpcode::kRsqrt:
case HloOpcode::kSign:
case HloOpcode::kSin:
case HloOpcode::kSqrt:
case HloOpcode::kTanh: {
if (!ParseOperands(&operands, /*expected_size=*/1) ||
!ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(
HloInstruction::CreateUnary(shape, opcode, operands[0]));
break;
}
// Binary ops.
case HloOpcode::kAdd:
case HloOpcode::kDivide:
case HloOpcode::kMultiply:
case HloOpcode::kSubtract:
case HloOpcode::kAtan2:
case HloOpcode::kComplex:
case HloOpcode::kMaximum:
case HloOpcode::kMinimum:
case HloOpcode::kPower:
case HloOpcode::kRemainder:
case HloOpcode::kAnd:
case HloOpcode::kOr:
case HloOpcode::kXor:
case HloOpcode::kShiftLeft:
case HloOpcode::kShiftRightArithmetic:
case HloOpcode::kShiftRightLogical: {
if (!ParseOperands(&operands, /*expected_size=*/2) ||
!ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(HloInstruction::CreateBinary(
shape, opcode, operands[0], operands[1]));
break;
}
// Ternary ops.
case HloOpcode::kClamp:
case HloOpcode::kSelect:
case HloOpcode::kTupleSelect: {
if (!ParseOperands(&operands, /*expected_size=*/3) ||
!ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(HloInstruction::CreateTernary(
shape, opcode, operands[0], operands[1], operands[2]));
break;
}
// Other supported ops.
case HloOpcode::kConvert: {
if (!ParseOperands(&operands, /*expected_size=*/1) ||
!ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(
HloInstruction::CreateConvert(shape, operands[0]));
break;
}
case HloOpcode::kBitcastConvert: {
if (!ParseOperands(&operands, /*expected_size=*/1) ||
!ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(
HloInstruction::CreateBitcastConvert(shape, operands[0]));
break;
}
case HloOpcode::kAllReduce: {
optional<std::vector<std::vector<int64>>> tmp_groups;
optional<HloComputation*> to_apply;
optional<std::vector<int64>> replica_group_ids;
optional<int64> channel_id;
attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
&to_apply};
attrs["replica_groups"] = {/*required=*/false,
AttrTy::kBracedInt64ListList, &tmp_groups};
attrs["channel_id"] = {/*required=*/false, AttrTy::kInt64, &channel_id};
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
}
std::vector<ReplicaGroup> replica_groups;
if (tmp_groups) {
replica_groups = CreateReplicaGroups(*tmp_groups);
}
instruction = builder->AddInstruction(HloInstruction::CreateAllReduce(
shape, operands, *to_apply, replica_groups, channel_id));
break;
}
case HloOpcode::kAllToAll: {
optional<std::vector<std::vector<int64>>> tmp_groups;
attrs["replica_groups"] = {/*required=*/false,
AttrTy::kBracedInt64ListList, &tmp_groups};
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
}
std::vector<ReplicaGroup> replica_groups;
if (tmp_groups) {
replica_groups = CreateReplicaGroups(*tmp_groups);
}
instruction = builder->AddInstruction(
HloInstruction::CreateAllToAll(shape, operands, replica_groups));
break;
}
case HloOpcode::kCollectivePermute: {
optional<std::vector<std::vector<int64>>> source_targets;
attrs["source_target_pairs"] = {
/*required=*/true, AttrTy::kBracedInt64ListList, &source_targets};
optional<int64> channel_id;
attrs["channel_id"] = {/*required=*/false, AttrTy::kInt64, &channel_id};
if (!ParseOperands(&operands, /*expected_size=*/1) ||
!ParseAttributes(attrs)) {
return false;
}
std::vector<std::pair<int64, int64>> pairs(source_targets->size());
for (int i = 0; i < pairs.size(); i++) {
if ((*source_targets)[i].size() != 2) {
return TokenError(
"expects 'source_target_pairs=' to be a list of pairs");
}
pairs[i].first = (*source_targets)[i][0];
pairs[i].second = (*source_targets)[i][1];
}
instruction =
builder->AddInstruction(HloInstruction::CreateCollectivePermute(
shape, operands[0], pairs, channel_id));
break;
}
case HloOpcode::kReplicaId: {
if (!ParseOperands(&operands, /*expected_size=*/0) ||
!ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(HloInstruction::CreateReplicaId());
break;
}
case HloOpcode::kPartitionId: {
if (!ParseOperands(&operands, /*expected_size=*/0) ||
!ParseAttributes(attrs)) {
return false;
}
instruction =
builder->AddInstruction(HloInstruction::CreatePartitionId());
break;
}
case HloOpcode::kReshape: {
optional<int64> inferred_dimension;
attrs["inferred_dimension"] = {/*required=*/false, AttrTy::kInt64,
&inferred_dimension};
if (!ParseOperands(&operands, /*expected_size=*/1) ||
!ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(HloInstruction::CreateReshape(
shape, operands[0], inferred_dimension.value_or(-1)));
break;
}
case HloOpcode::kAfterAll: {
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
}
if (operands.empty()) {
instruction = builder->AddInstruction(HloInstruction::CreateToken());
} else {
instruction =
builder->AddInstruction(HloInstruction::CreateAfterAll(operands));
}
break;
}
case HloOpcode::kAddDependency: {
if (!ParseOperands(&operands, /*expected_size=*/2) ||
!ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(
HloInstruction::CreateAddDependency(operands[0], operands[1]));
break;
}
case HloOpcode::kSort: {
optional<std::vector<int64>> dimensions;
attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
&dimensions};
optional<bool> is_stable = false;
attrs["is_stable"] = {/*required=*/false, AttrTy::kBool, &is_stable};
optional<HloComputation*> to_apply;
attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
&to_apply};
if (!ParseOperands(&operands) || !ParseAttributes(attrs) ||
dimensions->size() != 1) {
return false;
}
instruction = builder->AddInstruction(
HloInstruction::CreateSort(shape, dimensions->at(0), operands,
to_apply.value(), is_stable.value()));
break;
}
case HloOpcode::kTuple: {
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
}
instruction =
builder->AddInstruction(HloInstruction::CreateTuple(operands));
break;
}
case HloOpcode::kWhile: {
optional<HloComputation*> condition;
optional<HloComputation*> body;
attrs["condition"] = {/*required=*/true, AttrTy::kHloComputation,
&condition};
attrs["body"] = {/*required=*/true, AttrTy::kHloComputation, &body};
if (!ParseOperands(&operands, /*expected_size=*/1) ||
!ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(HloInstruction::CreateWhile(
shape, *condition, *body, /*init=*/operands[0]));
break;
}
case HloOpcode::kRecv: {
optional<int64> channel_id;
// If the is_host_transfer attribute is not present then default to false.
optional<bool> is_host_transfer = false;
attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool,
&is_host_transfer};
if (!ParseOperands(&operands, /*expected_size=*/1) ||
!ParseAttributes(attrs)) {
return false;
}
// If the is_host_transfer attribute is not present then default to false.
instruction = builder->AddInstruction(HloInstruction::CreateRecv(
shape.tuple_shapes(0), operands[0], *channel_id, *is_host_transfer));
break;
}
case HloOpcode::kRecvDone: {
optional<int64> channel_id;
// If the is_host_transfer attribute is not present then default to false.
optional<bool> is_host_transfer = false;
attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool,
&is_host_transfer};
if (!ParseOperands(&operands, /*expected_size=*/1) ||
!ParseAttributes(attrs)) {
return false;
}
if (channel_id != operands[0]->channel_id()) {
return false;
}
instruction = builder->AddInstruction(
HloInstruction::CreateRecvDone(operands[0], *is_host_transfer));
break;
}
case HloOpcode::kSend: {
optional<int64> channel_id;
// If the is_host_transfer attribute is not present then default to false.
optional<bool> is_host_transfer = false;
attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool,
&is_host_transfer};
if (!ParseOperands(&operands, /*expected_size=*/2) ||
!ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(HloInstruction::CreateSend(
operands[0], operands[1], *channel_id, *is_host_transfer));
break;
}
case HloOpcode::kSendDone: {
optional<int64> channel_id;
// If the is_host_transfer attribute is not present then default to false.
optional<bool> is_host_transfer = false;
attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool,
&is_host_transfer};
if (!ParseOperands(&operands, /*expected_size=*/1) ||
!ParseAttributes(attrs)) {
return false;
}
if (channel_id != operands[0]->channel_id()) {
return false;
}
instruction = builder->AddInstruction(
HloInstruction::CreateSendDone(operands[0], *is_host_transfer));
break;
}
case HloOpcode::kGetTupleElement: {
optional<int64> index;
attrs["index"] = {/*required=*/true, AttrTy::kInt64, &index};
if (!ParseOperands(&operands, /*expected_size=*/1) ||
!ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(
HloInstruction::CreateGetTupleElement(shape, operands[0], *index));
break;
}
case HloOpcode::kCall: {
optional<HloComputation*> to_apply;
attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
&to_apply};
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(
HloInstruction::CreateCall(shape, operands, *to_apply));
break;
}
case HloOpcode::kReduceWindow: {
optional<HloComputation*> reduce_computation;
optional<Window> window;
attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
&reduce_computation};
if (!ParseOperands(&operands, /*expected_size=*/2) ||
!ParseAttributes(attrs)) {
return false;
}
if (!window) {
window.emplace();
}
instruction = builder->AddInstruction(HloInstruction::CreateReduceWindow(
shape, /*operand=*/operands[0], /*init_value=*/operands[1], *window,
*reduce_computation));
break;
}
case HloOpcode::kConvolution: {
optional<Window> window;
optional<ConvolutionDimensionNumbers> dnums;
optional<int64> feature_group_count;
optional<int64> batch_group_count;
attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
attrs["dim_labels"] = {/*required=*/true,
AttrTy::kConvolutionDimensionNumbers, &dnums};
attrs["feature_group_count"] = {/*required=*/false, AttrTy::kInt64,
&feature_group_count};
attrs["batch_group_count"] = {/*required=*/false, AttrTy::kInt64,
&batch_group_count};
optional<std::vector<PrecisionConfig::Precision>> operand_precision;
attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList,
&operand_precision};
if (!ParseOperands(&operands, /*expected_size=*/2) ||
!ParseAttributes(attrs)) {
return false;
}
if (!window) {
window.emplace();
}
if (!feature_group_count) {
feature_group_count = 1;
}
if (!batch_group_count) {
batch_group_count = 1;
}
PrecisionConfig precision_config;
if (operand_precision) {
*precision_config.mutable_operand_precision() = {
operand_precision->begin(), operand_precision->end()};
} else {
precision_config.mutable_operand_precision()->Resize(
operands.size(), PrecisionConfig::DEFAULT);
}
instruction = builder->AddInstruction(HloInstruction::CreateConvolve(
shape, /*lhs=*/operands[0], /*rhs=*/operands[1],
feature_group_count.value(), batch_group_count.value(), *window,
*dnums, precision_config));
break;
}
case HloOpcode::kFft: {
optional<FftType> fft_type;
optional<std::vector<int64>> fft_length;
attrs["fft_type"] = {/*required=*/true, AttrTy::kFftType, &fft_type};
attrs["fft_length"] = {/*required=*/true, AttrTy::kBracedInt64List,
&fft_length};
if (!ParseOperands(&operands, /*expected_size=*/1) ||
!ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(HloInstruction::CreateFft(
shape, operands[0], *fft_type, *fft_length));
break;
}
case HloOpcode::kTriangularSolve: {
TriangularSolveOptions options;
if (!ParseOperands(&operands, /*expected_size=*/2) ||
!ParseAttributesAsProtoMessage(
/*required_attrs=*/std::unordered_set<string>(), &options)) {
return false;
}
instruction =
builder->AddInstruction(HloInstruction::CreateTriangularSolve(
shape, operands[0], operands[1], options));
break;
}
case HloOpcode::kCompare: {
optional<ComparisonDirection> direction;
attrs["direction"] = {/*required=*/true, AttrTy::kComparisonDirection,
&direction};
if (!ParseOperands(&operands, /*expected_size=*/2) ||
!ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(HloInstruction::CreateCompare(
shape, operands[0], operands[1], *direction));
break;
}
case HloOpcode::kCholesky: {
CholeskyOptions options;
if (!ParseOperands(&operands, /*expected_size=*/1) ||
!ParseAttributesAsProtoMessage(
/*required_attrs=*/std::unordered_set<string>(), &options)) {
return false;
}
instruction = builder->AddInstruction(
HloInstruction::CreateCholesky(shape, operands[0], options));
break;
}
case HloOpcode::kBroadcast: {
optional<std::vector<int64>> broadcast_dimensions;
attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
&broadcast_dimensions};
if (!ParseOperands(&operands, /*expected_size=*/1) ||
!ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(HloInstruction::CreateBroadcast(
shape, operands[0], *broadcast_dimensions));
break;
}
case HloOpcode::kConcatenate: {
optional<std::vector<int64>> dimensions;
attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
&dimensions};
if (!ParseOperands(&operands) || !ParseAttributes(attrs) ||
dimensions->size() != 1) {
return false;
}
instruction = builder->AddInstruction(HloInstruction::CreateConcatenate(
shape, operands, dimensions->at(0)));
break;
}
case HloOpcode::kMap: {
optional<HloComputation*> to_apply;
attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
&to_apply};
optional<std::vector<int64>> dimensions;
attrs["dimensions"] = {/*required=*/false, AttrTy::kBracedInt64List,
&dimensions};
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(
HloInstruction::CreateMap(shape, operands, *to_apply));
break;
}
case HloOpcode::kReduce: {
auto loc = lexer_.GetLoc();
optional<HloComputation*> reduce_computation;
attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
&reduce_computation};
optional<std::vector<int64>> dimensions_to_reduce;
attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
&dimensions_to_reduce};
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
}
if (operands.size() % 2) {
return Error(loc, StrCat("expects an even number of operands, but has ",
operands.size(), " operands"));
}
instruction = builder->AddInstruction(HloInstruction::CreateReduce(
shape, /*operands=*/
absl::Span<HloInstruction* const>(operands).subspan(
0, operands.size() / 2),
/*init_values=*/
absl::Span<HloInstruction* const>(operands).subspan(operands.size() /
2),
*dimensions_to_reduce, *reduce_computation));
break;
}
case HloOpcode::kReverse: {
optional<std::vector<int64>> dimensions;
attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
&dimensions};
if (!ParseOperands(&operands, /*expected_size=*/1) ||
!ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(
HloInstruction::CreateReverse(shape, operands[0], *dimensions));
break;
}
case HloOpcode::kSelectAndScatter: {
optional<HloComputation*> select;
attrs["select"] = {/*required=*/true, AttrTy::kHloComputation, &select};
optional<HloComputation*> scatter;
attrs["scatter"] = {/*required=*/true, AttrTy::kHloComputation, &scatter};
optional<Window> window;
attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
if (!ParseOperands(&operands, /*expected_size=*/3) ||
!ParseAttributes(attrs)) {
return false;
}
if (!window) {
window.emplace();
}
instruction =
builder->AddInstruction(HloInstruction::CreateSelectAndScatter(
shape, /*operand=*/operands[0], *select, *window,
/*source=*/operands[1], /*init_value=*/operands[2], *scatter));
break;
}
case HloOpcode::kSlice: {
optional<SliceRanges> slice_ranges;
attrs["slice"] = {/*required=*/true, AttrTy::kSliceRanges, &slice_ranges};
if (!ParseOperands(&operands, /*expected_size=*/1) ||
!ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(HloInstruction::CreateSlice(
shape, operands[0], slice_ranges->starts, slice_ranges->limits,
slice_ranges->strides));
break;
}
case HloOpcode::kDynamicSlice: {
optional<std::vector<int64>> dynamic_slice_sizes;
attrs["dynamic_slice_sizes"] = {
/*required=*/true, AttrTy::kBracedInt64List, &dynamic_slice_sizes};
LocTy loc = lexer_.GetLoc();
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
}
if (operands.empty()) {
return Error(loc, "Expected at least one operand.");
}
if (!(operands.size() == 2 && operands[1]->shape().rank() == 1) &&
operands.size() != 1 + operands[0]->shape().rank()) {
return Error(loc, "Wrong number of operands.");
}
instruction = builder->AddInstruction(HloInstruction::CreateDynamicSlice(
shape, /*operand=*/operands[0],
/*start_indices=*/absl::MakeSpan(operands).subspan(1),
*dynamic_slice_sizes));
break;
}
case HloOpcode::kDynamicUpdateSlice: {
LocTy loc = lexer_.GetLoc();
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
}
if (operands.size() < 2) {
return Error(loc, "Expected at least two operands.");
}
if (!(operands.size() == 3 && operands[2]->shape().rank() == 1) &&
operands.size() != 2 + operands[0]->shape().rank()) {
return Error(loc, "Wrong number of operands.");
}
instruction =
builder->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
shape, /*operand=*/operands[0], /*update=*/operands[1],
/*start_indices=*/absl::MakeSpan(operands).subspan(2)));
break;
}
case HloOpcode::kTranspose: {
optional<std::vector<int64>> dimensions;
attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
&dimensions};
if (!ParseOperands(&operands, /*expected_size=*/1) ||
!ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(
HloInstruction::CreateTranspose(shape, operands[0], *dimensions));
break;
}
case HloOpcode::kBatchNormTraining: {
optional<float> epsilon;
attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon};
optional<int64> feature_index;
attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64,
&feature_index};
if (!ParseOperands(&operands, /*expected_size=*/3) ||
!ParseAttributes(attrs)) {
return false;
}
instruction =
builder->AddInstruction(HloInstruction::CreateBatchNormTraining(
shape, /*operand=*/operands[0], /*scale=*/operands[1],
/*offset=*/operands[2], *epsilon, *feature_index));
break;
}
case HloOpcode::kBatchNormInference: {
optional<float> epsilon;
attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon};
optional<int64> feature_index;
attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64,
&feature_index};
if (!ParseOperands(&operands, /*expected_size=*/5) ||
!ParseAttributes(attrs)) {
return false;
}
instruction =
builder->AddInstruction(HloInstruction::CreateBatchNormInference(
shape, /*operand=*/operands[0], /*scale=*/operands[1],
/*offset=*/operands[2], /*mean=*/operands[3],
/*variance=*/operands[4], *epsilon, *feature_index));
break;
}
case HloOpcode::kBatchNormGrad: {
optional<float> epsilon;
attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon};
optional<int64> feature_index;
attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64,
&feature_index};
if (!ParseOperands(&operands, /*expected_size=*/5) ||
!ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(HloInstruction::CreateBatchNormGrad(
shape, /*operand=*/operands[0], /*scale=*/operands[1],
/*mean=*/operands[2], /*variance=*/operands[3],
/*grad_output=*/operands[4], *epsilon, *feature_index));
break;
}
case HloOpcode::kPad: {
optional<PaddingConfig> padding;
attrs["padding"] = {/*required=*/true, AttrTy::kPaddingConfig, &padding};
if (!ParseOperands(&operands, /*expected_size=*/2) ||
!ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(HloInstruction::CreatePad(
shape, operands[0], /*padding_value=*/operands[1], *padding));
break;
}
case HloOpcode::kFusion: {
optional<HloComputation*> fusion_computation;
attrs["calls"] = {/*required=*/true, AttrTy::kHloComputation,
&fusion_computation};
optional<HloInstruction::FusionKind> fusion_kind;
attrs["kind"] = {/*required=*/true, AttrTy::kFusionKind, &fusion_kind};
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(HloInstruction::CreateFusion(
shape, *fusion_kind, operands, *fusion_computation));
break;
}
case HloOpcode::kInfeed: {
optional<string> config;
attrs["infeed_config"] = {/*required=*/false, AttrTy::kString, &config};
if (!ParseOperands(&operands, /*expected_size=*/1) ||
!ParseAttributes(attrs)) {
return false;
}
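// We need to know the infeed data shape to construct the infeed
// instruction. This is the zero-th element of the tuple-shaped output of
// the infeed instruction. ShapeUtil::GetTupleElementShape will CHECK-fail
// if the shape is not a non-empty tuple, so add a guard here so that an
// error message can be emitted instead of a CHECK failure.
// NOTE(review): the condition below looks like it only rejects non-tuple
// shapes and would let an empty tuple through - confirm against
// ShapeUtil::IsEmptyTuple.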
if (!shape.IsTuple() && !ShapeUtil::IsEmptyTuple(shape)) {
return Error(lexer_.GetLoc(),
"infeed must have a non-empty tuple shape");
}
instruction = builder->AddInstruction(HloInstruction::CreateInfeed(
ShapeUtil::GetTupleElementShape(shape, 0), operands[0],
config ? *config : ""));
break;
}
case HloOpcode::kOutfeed: {
optional<string> config;
attrs["outfeed_config"] = {/*required=*/false, AttrTy::kString, &config};
if (!ParseOperands(&operands, /*expected_size=*/2) ||
!ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(
HloInstruction::CreateOutfeed(operands[0]->shape(), operands[0],
operands[1], config ? *config : ""));
break;
}
case HloOpcode::kRng: {
optional<RandomDistribution> distribution;
attrs["distribution"] = {/*required=*/true, AttrTy::kDistribution,
&distribution};
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(
HloInstruction::CreateRng(shape, *distribution, operands));
break;
}
case HloOpcode::kRngGetAndUpdateState: {
optional<int64> delta;
attrs["delta"] = {/*required=*/true, AttrTy::kInt64, &delta};
if (!ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(
HloInstruction::CreateRngGetAndUpdateState(shape, *delta));
break;
}
case HloOpcode::kReducePrecision: {
optional<int64> exponent_bits;
optional<int64> mantissa_bits;
attrs["exponent_bits"] = {/*required=*/true, AttrTy::kInt64,
&exponent_bits};
attrs["mantissa_bits"] = {/*required=*/true, AttrTy::kInt64,
&mantissa_bits};
if (!ParseOperands(&operands, /*expected_size=*/1) ||
!ParseAttributes(attrs)) {
return false;
}
instruction =
builder->AddInstruction(HloInstruction::CreateReducePrecision(
shape, operands[0], static_cast<int>(*exponent_bits),
static_cast<int>(*mantissa_bits)));
break;
}
case HloOpcode::kConditional: {
optional<HloComputation*> true_computation;
optional<HloComputation*> false_computation;
optional<std::vector<HloComputation*>> branch_computations;
if (!ParseOperands(&operands)) {
return false;
}
if (!ShapeUtil::IsScalar(operands[0]->shape())) {
return Error(lexer_.GetLoc(), "The first operand must be a scalar");
}
const bool branch_index_is_bool =
operands[0]->shape().element_type() == PRED;
if (branch_index_is_bool) {
attrs["true_computation"] = {/*required=*/true, AttrTy::kHloComputation,
&true_computation};
attrs["false_computation"] = {
/*required=*/true, AttrTy::kHloComputation, &false_computation};
} else {
if (operands[0]->shape().element_type() != S32) {
return Error(lexer_.GetLoc(),
"The first operand must be a scalar of PRED or S32");
}
attrs["branch_computations"] = {/*required=*/true,
AttrTy::kBracedHloComputationList,
&branch_computations};
}
if (!ParseAttributes(attrs)) {
return false;
}
if (branch_index_is_bool) {
branch_computations.emplace({*true_computation, *false_computation});
}
if (branch_computations->empty() ||
operands.size() != branch_computations->size() + 1) {
return false;
}
instruction = builder->AddInstruction(HloInstruction::CreateConditional(
shape, /*branch_index=*/operands[0],
absl::MakeSpan(*branch_computations),
absl::MakeSpan(operands).subspan(1)));
break;
}
case HloOpcode::kCustomCall: {
optional<string> custom_call_target;
optional<Window> window;
optional<ConvolutionDimensionNumbers> dnums;
optional<int64> feature_group_count;
optional<int64> batch_group_count;
optional<std::vector<Shape>> operand_layout_constraints;
optional<bool> custom_call_has_side_effect;
attrs["custom_call_target"] = {/*required=*/true, AttrTy::kString,
&custom_call_target};
attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
attrs["dim_labels"] = {/*required=*/false,
AttrTy::kConvolutionDimensionNumbers, &dnums};
attrs["feature_group_count"] = {/*required=*/false, AttrTy::kInt64,
&feature_group_count};
attrs["batch_group_count"] = {/*required=*/false, AttrTy::kInt64,
&batch_group_count};
attrs["operand_layout_constraints"] = {
/*required=*/false, AttrTy::kShapeList, &operand_layout_constraints};
attrs["custom_call_has_side_effect"] = {/*required=*/false, AttrTy::kBool,
&custom_call_has_side_effect};
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
}
if (operand_layout_constraints.has_value()) {
if (!LayoutUtil::HasLayout(shape)) {
return Error(lexer_.GetLoc(),
"Layout must be set on layout-constrained custom call");
}
if (operands.size() != operand_layout_constraints->size()) {
return Error(lexer_.GetLoc(),
StrCat("Expected ", operands.size(),
" operand layout constraints, ",
operand_layout_constraints->size(), " given"));
}
for (int64 i = 0; i < operands.size(); ++i) {
const Shape& operand_shape_with_layout =
(*operand_layout_constraints)[i];
if (!LayoutUtil::HasLayout(operand_shape_with_layout)) {
return Error(lexer_.GetLoc(),
StrCat("Operand layout constraint shape ",
ShapeUtil::HumanStringWithLayout(
operand_shape_with_layout),
" for operand ", i, " does not have a layout"));
}
if (!ShapeUtil::Compatible(operand_shape_with_layout,
operands[i]->shape())) {
return Error(
lexer_.GetLoc(),
StrCat(
"Operand layout constraint shape ",
ShapeUtil::HumanStringWithLayout(operand_shape_with_layout),
" for operand ", i,
" is not compatible with operand shape ",
ShapeUtil::HumanStringWithLayout(operands[i]->shape())));
}
}
instruction = builder->AddInstruction(HloInstruction::CreateCustomCall(
shape, operands, *custom_call_target, *operand_layout_constraints,
backend_config ? *backend_config : ""));
} else {
instruction = builder->AddInstruction(HloInstruction::CreateCustomCall(
shape, operands, *custom_call_target,
backend_config ? *backend_config : ""));
}
auto custom_call_instr = Cast<HloCustomCallInstruction>(instruction);
if (window.has_value()) {
custom_call_instr->set_window(*window);
}
if (dnums.has_value()) {
custom_call_instr->set_convolution_dimension_numbers(*dnums);
}
if (feature_group_count.has_value()) {
custom_call_instr->set_feature_group_count(*feature_group_count);
}
if (batch_group_count.has_value()) {
custom_call_instr->set_batch_group_count(*batch_group_count);
}
if (custom_call_has_side_effect.has_value()) {
custom_call_instr->set_custom_call_has_side_effect(
*custom_call_has_side_effect);
}
break;
}
case HloOpcode::kDot: {
optional<std::vector<int64>> lhs_contracting_dims;
attrs["lhs_contracting_dims"] = {
/*required=*/false, AttrTy::kBracedInt64List, &lhs_contracting_dims};
optional<std::vector<int64>> rhs_contracting_dims;
attrs["rhs_contracting_dims"] = {
/*required=*/false, AttrTy::kBracedInt64List, &rhs_contracting_dims};
optional<std::vector<int64>> lhs_batch_dims;
attrs["lhs_batch_dims"] = {/*required=*/false, AttrTy::kBracedInt64List,
&lhs_batch_dims};
optional<std::vector<int64>> rhs_batch_dims;
attrs["rhs_batch_dims"] = {/*required=*/false, AttrTy::kBracedInt64List,
&rhs_batch_dims};
optional<std::vector<PrecisionConfig::Precision>> operand_precision;
attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList,
&operand_precision};
if (!ParseOperands(&operands, /*expected_size=*/2) ||
!ParseAttributes(attrs)) {
return false;
}
DotDimensionNumbers dnum;
if (lhs_contracting_dims) {
*dnum.mutable_lhs_contracting_dimensions() = {
lhs_contracting_dims->begin(), lhs_contracting_dims->end()};
}
if (rhs_contracting_dims) {
*dnum.mutable_rhs_contracting_dimensions() = {
rhs_contracting_dims->begin(), rhs_contracting_dims->end()};
}
if (lhs_batch_dims) {
*dnum.mutable_lhs_batch_dimensions() = {lhs_batch_dims->begin(),
lhs_batch_dims->end()};
}
if (rhs_batch_dims) {
*dnum.mutable_rhs_batch_dimensions() = {rhs_batch_dims->begin(),
rhs_batch_dims->end()};
}
PrecisionConfig precision_config;
if (operand_precision) {
*precision_config.mutable_operand_precision() = {
operand_precision->begin(), operand_precision->end()};
} else {
precision_config.mutable_operand_precision()->Resize(
operands.size(), PrecisionConfig::DEFAULT);
}
instruction = builder->AddInstruction(HloInstruction::CreateDot(
shape, operands[0], operands[1], dnum, precision_config));
break;
}
case HloOpcode::kGather: {
optional<std::vector<int64>> offset_dims;
attrs["offset_dims"] = {/*required=*/true, AttrTy::kBracedInt64List,
&offset_dims};
optional<std::vector<int64>> collapsed_slice_dims;
attrs["collapsed_slice_dims"] = {
/*required=*/true, AttrTy::kBracedInt64List, &collapsed_slice_dims};
optional<std::vector<int64>> start_index_map;
attrs["start_index_map"] = {/*required=*/true, AttrTy::kBracedInt64List,
&start_index_map};
optional<int64> index_vector_dim;
attrs["index_vector_dim"] = {/*required=*/true, AttrTy::kInt64,
&index_vector_dim};
optional<std::vector<int64>> slice_sizes;
attrs["slice_sizes"] = {/*required=*/true, AttrTy::kBracedInt64List,
&slice_sizes};
optional<bool> indices_are_sorted = false;
attrs["indices_are_sorted"] = {/*required=*/false, AttrTy::kBool,
&indices_are_sorted};
if (!ParseOperands(&operands, /*expected_size=*/2) ||
!ParseAttributes(attrs)) {
return false;
}
GatherDimensionNumbers dim_numbers =
HloGatherInstruction::MakeGatherDimNumbers(
/*offset_dims=*/*offset_dims,
/*collapsed_slice_dims=*/*collapsed_slice_dims,
/*start_index_map=*/*start_index_map,
/*index_vector_dim=*/*index_vector_dim);
instruction = builder->AddInstruction(HloInstruction::CreateGather(
shape, /*operand=*/operands[0], /*start_indices=*/operands[1],
dim_numbers, *slice_sizes, indices_are_sorted.value()));
break;
}
case HloOpcode::kScatter: {
optional<std::vector<int64>> update_window_dims;
attrs["update_window_dims"] = {
/*required=*/true, AttrTy::kBracedInt64List, &update_window_dims};
optional<std::vector<int64>> inserted_window_dims;
attrs["inserted_window_dims"] = {
/*required=*/true, AttrTy::kBracedInt64List, &inserted_window_dims};
optional<std::vector<int64>> scatter_dims_to_operand_dims;
attrs["scatter_dims_to_operand_dims"] = {/*required=*/true,
AttrTy::kBracedInt64List,
&scatter_dims_to_operand_dims};
optional<int64> index_vector_dim;
attrs["index_vector_dim"] = {/*required=*/true, AttrTy::kInt64,
&index_vector_dim};
optional<HloComputation*> update_computation;
attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
&update_computation};
optional<bool> indices_are_sorted = false;
attrs["indices_are_sorted"] = {/*required=*/false, AttrTy::kBool,
&indices_are_sorted};
if (!ParseOperands(&operands, /*expected_size=*/3) ||
!ParseAttributes(attrs)) {
return false;
}
ScatterDimensionNumbers dim_numbers =
HloScatterInstruction::MakeScatterDimNumbers(
/*update_window_dims=*/*update_window_dims,
/*inserted_window_dims=*/*inserted_window_dims,
/*scatter_dims_to_operand_dims=*/*scatter_dims_to_operand_dims,
/*index_vector_dim=*/*index_vector_dim);
instruction = builder->AddInstruction(HloInstruction::CreateScatter(
shape, /*operand=*/operands[0], /*scatter_indices=*/operands[1],
/*updates=*/operands[2], *update_computation, dim_numbers,
indices_are_sorted.value()));
break;
}
case HloOpcode::kDomain: {
DomainData domain;
attrs["domain"] = {/*required=*/true, AttrTy::kDomain, &domain};
if (!ParseOperands(&operands, /*expected_size=*/1) ||
!ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(HloInstruction::CreateDomain(
shape, operands[0], std::move(domain.exit_metadata),
std::move(domain.entry_metadata)));
break;
}
case HloOpcode::kTrace:
return TokenError(StrCat("parsing not yet implemented for op: ",
HloOpcodeString(opcode)));
case HloOpcode::kGetDimensionSize:
optional<std::vector<int64>> dimensions;
attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
&dimensions};
if (!ParseOperands(&operands, /*expected_size=*/1) ||
!ParseAttributes(attrs)) {
return false;
}
instruction =
builder->AddInstruction(HloInstruction::CreateGetDimensionSize(
shape, operands[0], (*dimensions)[0]));
break;
}
instruction->SetAndSanitizeName(name);
if (instruction->name() != name) {
return Error(name_loc,
StrCat("illegal instruction name: ", name,
"; suggest renaming to: ", instruction->name()));
}
// Add shared attributes like metadata to the instruction, if they were seen.
if (sharding) {
instruction->set_sharding(
HloSharding::FromProto(sharding.value()).ValueOrDie());
}
if (parameter_replication) {
int leaf_count = ShapeUtil::GetLeafCount(instruction->shape());
const auto& replicated =
parameter_replication->replicated_at_leaf_buffers();
if (leaf_count != replicated.size()) {
return Error(lexer_.GetLoc(),
StrCat("parameter has ", leaf_count,
" leaf buffers, but parameter_replication has ",
replicated.size(), " elements."));
}
instruction->set_parameter_replicated_at_leaf_buffers(replicated);
}
if (predecessors) {
for (auto* pre : *predecessors) {
Status status = pre->AddControlDependencyTo(instruction);
if (!status.ok()) {
return Error(name_loc, StrCat("error adding control dependency for: ",
name, " status: ", status.ToString()));
}
}
}
if (metadata) {
instruction->set_metadata(*metadata);
}
if (backend_config) {
instruction->set_raw_backend_config_string(std::move(*backend_config));
}
if (outer_dimension_partitions) {
instruction->set_outer_dimension_partitions(*outer_dimension_partitions);
}
return AddInstruction(name, instruction, name_loc);
} // NOLINT(readability/fn_size)