in libraries/value/src/ComputeContext.cpp [1126:1238]
/// Applies a binary operation element-wise, in place: destination[i] = destination[i] <op> source[i].
///
/// If `destination` is not a valid value, a new value with the base type and layout of `source`
/// is allocated and used as the left-hand operand.
///
/// Supported operations depend on the element type:
///   - non-Boolean element types: add, subtract, multiply, divide;
///   - integral element types additionally: modulus;
///   - Boolean element types: logicalAnd, logicalOr.
///
/// NOTE: integral divide/modulus by a zero element is undefined behavior in C++ and is not
/// checked here.
///
/// @param op          The operation to apply.
/// @param destination Left-hand operand; the result is written into the storage it refers to.
/// @param source      Right-hand operand; must be valid, type-compatible with `destination`, and
///                    have the same layout as `destination`.
/// @return            `destination` (or the newly allocated value) holding the result.
/// @throws InputException invalidArgument if `source` is not valid, typeMismatch if the types are
///                        incompatible, sizeMismatch if the layouts differ.
/// @throws LogicException illegalState if `op` is not supported for the element type.
Value ComputeContext::BinaryOperationImpl(ValueBinaryOperation op, Value destination, Value source)
{
    if (!ValidateValue(source))
    {
        throw InputException(InputExceptionErrors::invalidArgument);
    }
    if (!ValidateValue(destination))
    {
        destination = Allocate(source.GetBaseType(), source.GetLayout());
    }
    if (!TypeCompatible(destination, source))
    {
        throw InputException(InputExceptionErrors::typeMismatch);
    }
    if (destination.GetLayout() != source.GetLayout())
    {
        throw InputException(InputExceptionErrors::sizeMismatch);
    }

    std::visit(
        VariantVisitor{
            // Emittable data has nothing to operate on here; it is left untouched.
            // NOTE(review): presumably unreachable in a ComputeContext -- confirm.
            [](Emittable) {},
            [&destination, &source, op](auto&& destinationData) {
                using DestinationDataType =
                    std::remove_pointer_t<std::decay_t<decltype(destinationData)>>;

                // Applies `opFn` to every active element. `opFn` is passed as a concrete lambda
                // (instead of being stored in a std::function) so the per-element call can be
                // inlined rather than going through type-erased dispatch.
                auto applyElementwise = [&](auto opFn) {
                    // The explicit return type converts the result back to the element type
                    // (e.g. after integer promotion of small integral types, or bool -> Boolean).
                    auto elementOp = [&opFn](DestinationDataType dst,
                                             DestinationDataType src) -> DestinationDataType {
                        return opFn(dst, src);
                    };

                    auto sourceData = std::get<DestinationDataType*>(source.GetUnderlyingData());
                    const auto& sourceLayout = source.GetLayout();
                    const auto& destinationLayout = destination.GetLayout();

                    if (sourceLayout.IsContiguous() && destinationLayout.IsContiguous())
                    {
                        auto numElements = destinationLayout.NumElements();
                        std::transform(destinationData,
                                       destinationData + numElements,
                                       sourceData,
                                       destinationData,
                                       elementOp);
                        return;
                    }

                    // The do/while loop below always runs its body at least once, so an empty
                    // layout must be skipped explicitly to avoid touching a nonexistent element.
                    if (sourceLayout.NumElements() == 0)
                    {
                        return;
                    }

                    // Walk every active element and map it to its offset in each buffer.
                    auto maxCoordinate = sourceLayout.GetActiveSize().ToVector();
                    decltype(maxCoordinate) coordinate(maxCoordinate.size());
                    do
                    {
                        auto logicalCoordinates = sourceLayout.GetLogicalCoordinates(coordinate);
                        auto sourceOffset = sourceLayout.GetLogicalEntryOffset(logicalCoordinates);
                        auto destinationOffset =
                            destinationLayout.GetLogicalEntryOffset(logicalCoordinates);
                        auto& target = *(destinationData + destinationOffset);
                        target = elementOp(target, *(sourceData + sourceOffset));
                    } while (IncrementMemoryCoordinate(coordinate, maxCoordinate));
                };

                if constexpr (!std::is_same_v<DestinationDataType, Boolean>)
                {
                    switch (op)
                    {
                    case ValueBinaryOperation::add:
                        applyElementwise([](auto dst, auto src) { return dst + src; });
                        break;
                    case ValueBinaryOperation::subtract:
                        applyElementwise([](auto dst, auto src) { return dst - src; });
                        break;
                    case ValueBinaryOperation::multiply:
                        applyElementwise([](auto dst, auto src) { return dst * src; });
                        break;
                    case ValueBinaryOperation::divide:
                        applyElementwise([](auto dst, auto src) { return dst / src; });
                        break;
                    case ValueBinaryOperation::modulus:
                        // Modulus is only defined for integral element types.
                        if constexpr (std::is_integral_v<DestinationDataType>)
                        {
                            applyElementwise([](auto dst, auto src) { return dst % src; });
                        }
                        else
                        {
                            throw LogicException(LogicExceptionErrors::illegalState);
                        }
                        break;
                    default:
                        throw LogicException(LogicExceptionErrors::illegalState);
                    }
                }
                else
                {
                    // Boolean elements only support the logical operations.
                    switch (op)
                    {
                    case ValueBinaryOperation::logicalAnd:
                        applyElementwise([](auto dst, auto src) { return dst && src; });
                        break;
                    case ValueBinaryOperation::logicalOr:
                        applyElementwise([](auto dst, auto src) { return dst || src; });
                        break;
                    default:
                        throw LogicException(LogicExceptionErrors::illegalState);
                    }
                }
            } },
        destination.GetUnderlyingData());

    return destination;
}