in tensorflow/lite/kernels/internal/reference/process_broadcast_shapes.h [38:133]
inline bool ProcessBroadcastShapes(const RuntimeShape& shape0,
                                   const RuntimeShape& shape1,
                                   tflite::ArithmeticParams* params) {
  const int dims_count =
      std::max(shape0.DimensionsCount(), shape1.DimensionsCount());
  params->broadcast_category = BroadcastableOpCategory::kGenericBroadcast;
  RuntimeShape scalar_shape(dims_count, 1);
  auto extended_shape0 = RuntimeShape::ExtendedShape(dims_count, shape0);
  auto extended_shape1 = RuntimeShape::ExtendedShape(dims_count, shape1);
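  // ExtendedShape left-pads the shorter shape with 1s, so both shapes can be
  // compared dimension by dimension over the same dims_count.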
  // Check for "exact" match, implicitly accepting any scalar shapes.
  if (extended_shape0 == extended_shape1) {
    params->broadcast_category = BroadcastableOpCategory::kNonBroadcast;
    return false;
  }
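  // Walk the dimensions from innermost to outermost. The first mismatch
  // decides the category: a unit dimension in shape0 means the first input is
  // the one being broadcast, a unit dimension in shape1 means the second is,
  // and a mismatch where neither side is 1 means the shapes are incompatible.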
  for (int i = dims_count - 1; i >= 0; --i) {
    if (extended_shape0.Dims(i) == extended_shape1.Dims(i)) {
      continue;
    } else if (extended_shape0.Dims(i) == 1) {
      params->broadcast_category =
          BroadcastableOpCategory::kFirstInputBroadcastsFast;
      break;
    } else if (extended_shape1.Dims(i) == 1) {
      params->broadcast_category =
          BroadcastableOpCategory::kSecondInputBroadcastsFast;
      break;
    } else {
      // This case is erroneous: there is a dimension that does not match and
      // is not a broadcast from one shape to the other.
      params->broadcast_category = BroadcastableOpCategory::kGenericBroadcast;
      return true;
    }
  }
  if (params->broadcast_category !=
          BroadcastableOpCategory::kFirstInputBroadcastsFast &&
      params->broadcast_category !=
          BroadcastableOpCategory::kSecondInputBroadcastsFast) {
    // This is unreachable: the exact-match check above already failed, so the
    // loop must have taken one of its else branches, either setting a fast
    // broadcast category or returning.
    TFLITE_DCHECK(false);
    params->broadcast_category = BroadcastableOpCategory::kNonBroadcast;
    return false;
  }
  // From this point it is assumed contractually that corresponding dimensions
  // in shape0 and shape1 are either (a) equal or (b) one or the other is 1.
  const bool swap_inputs = params->broadcast_category ==
                           BroadcastableOpCategory::kSecondInputBroadcastsFast;
  const RuntimeShape* shape_a =
      swap_inputs ? &extended_shape1 : &extended_shape0;
  const RuntimeShape* shape_b =
      swap_inputs ? &extended_shape0 : &extended_shape1;
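  // After the swap, shape_a is always the input that has the unit dimension
  // at the innermost mismatch, so the fivefold decomposition below can handle
  // both fast-broadcast categories identically.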
  int i = dims_count - 1;
  params->broadcast_shape[0] = 1;
  params->broadcast_shape[1] = 1;
  params->broadcast_shape[2] = 1;
  params->broadcast_shape[3] = 1;
  params->broadcast_shape[4] = 1;
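  // The loops below compress the two shapes, from innermost to outermost, into
  // at most five alternating runs whose sizes land in broadcast_shape[4]..[0]:
  //   [4] innermost run where the dimensions match,
  //   [3] run where shape_a has 1s and shape_b supplies the size,
  //   [2] matching run,
  //   [1] run where shape_b has 1s and shape_a supplies the size,
  //   [0] outermost matching run.
  // If the shapes alternate more often than that, we fall back to the generic
  // broadcast at the end of the function.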
  // y_0 is greedy: include dims if both or neither equal 1; in other words,
  // test for equality rather than (shape_a->Dims(i) != 1).
  while (i >= 0 && shape_a->Dims(i) == shape_b->Dims(i)) {
    params->broadcast_shape[4] *= shape_b->Dims(i);
    --i;
  }
  // Here either input_a or input_b has a dim of 1 (if i >= 0). If it is
  // input_b that has the unit dimension, the next two loops are not entered.
  while (i >= 0 && shape_a->Dims(i) == 1) {
    params->broadcast_shape[3] *= shape_b->Dims(i);
    --i;
  }
  while (i >= 0 && shape_a->Dims(i) == shape_b->Dims(i)) {
    params->broadcast_shape[2] *= shape_a->Dims(i);
    --i;
  }
  // Here either input_a or input_b has a dim of 1 (if i >= 0).
  while (i >= 0 && shape_b->Dims(i) == 1) {
    params->broadcast_shape[1] *= shape_a->Dims(i);
    --i;
  }
  while (i >= 0 && shape_a->Dims(i) == shape_b->Dims(i)) {
    params->broadcast_shape[0] *= shape_b->Dims(i);
    --i;
  }
  // The rarer case is when the broadcast dimensions cannot be handled by a
  // fivefold loop; fall back to the generic broadcast implementation.
  if (i >= 0) {
    params->broadcast_category = BroadcastableOpCategory::kGenericBroadcast;
  }
  return true;
}
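
// A minimal usage sketch, not part of the original header: the helper name,
// the shapes, and the values noted in the comments are illustrative
// assumptions. It assumes RuntimeShape and ArithmeticParams (from
// tensorflow/lite/kernels/internal/types.h) are in scope, as they are for the
// function above.
inline void ExampleProcessBroadcastShapesUsage() {
  const RuntimeShape shape0({4, 1, 2});
  const RuntimeShape shape1({4, 3, 1});
  ArithmeticParams params;
  // Returns true: the shapes differ, so some form of broadcasting is needed.
  const bool needs_broadcast = ProcessBroadcastShapes(shape0, shape1, &params);
  // The innermost mismatch has its 1 in shape1, so params.broadcast_category
  // becomes kSecondInputBroadcastsFast, and the two shapes collapse into the
  // fivefold runs {4, 3, 1, 2, 1} stored in params.broadcast_shape.
  (void)needs_broadcast;
}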