in libraries/value/src/LLVMContext.cpp [1265:1419]
/// Emits code that applies `op` element-wise as `destination = destination <op> source` and
/// returns the value holding the result.
///
/// - If `destination` is undefined, a new value with the source's base type and layout is
///   allocated to receive the result.
/// - If both operands are constant, the operation is delegated to the compute context instead of
///   emitting IR.
/// - Otherwise `destination` is first made emittable, and that emittable value is what gets
///   written to and returned. For a pointer-level-0 destination, the returned value holds the
///   newly computed LLVM value.
///
/// @param op The binary operation to apply.
/// @param destination Left-hand operand and result target (may be undefined).
/// @param source Right-hand operand; must be defined.
/// @return The value holding the result.
/// @throws InputException `invalidArgument` in any of these cases: `source` is undefined; either
///         operand's pointer level is not 0 or 1; `modulus` is requested on floating-point data; a
///         logical operation is requested on non-Boolean data. `typeMismatch` if the operand types
///         are not compatible. `sizeMismatch` if their layouts differ.
/// @throws LogicException `notImplemented` for constant Boolean source data; `illegalState` for an
///         unrecognized operation.
Value LLVMContext::BinaryOperationImpl(ValueBinaryOperation op, Value destination, Value source)
{
    if (!source.IsDefined())
    {
        throw InputException(InputExceptionErrors::invalidArgument);
    }
    if (destination.IsDefined())
    {
        if (source.IsConstant() && destination.IsConstant())
        {
            return _computeContext.BinaryOperation(op, destination, source);
        }
    }
    else
    {
        destination = Allocate(source.GetBaseType(), source.GetLayout());
    }
    if (!TypeCompatible(destination, source))
    {
        throw InputException(InputExceptionErrors::typeMismatch);
    }
    if (!((source.PointerLevel() == 0 || source.PointerLevel() == 1) &&
          (destination.PointerLevel() == 0 || destination.PointerLevel() == 1)))
    {
        throw InputException(InputExceptionErrors::invalidArgument);
    }
    if (destination.GetLayout() != source.GetLayout())
    {
        throw InputException(InputExceptionErrors::sizeMismatch);
    }

    auto& fn = GetFunctionEmitter();

    // Convert the destination in this scope and capture it by reference below, so that any result
    // written into it is visible in the value this function returns. A by-value capture would give
    // the lambda its own copy, and updates to that copy would be discarded.
    destination = EnsureEmittable(destination);

    std::visit(
        [this, &destination, &source, op, &fn](auto&& sourceData) {
            using SourceDataType = std::decay_t<decltype(sourceData)>;
            if constexpr (std::is_same_v<Boolean*, SourceDataType>)
            {
                throw LogicException(LogicExceptionErrors::notImplemented);
            }
            else
            {
                const bool isFp = destination.IsFloatingPoint();

                // Select the emitter operator once. Validation errors are raised here, before any
                // code for the element loop is emitted.
                const TypedOperator typedOp = [&]() -> TypedOperator {
                    switch (op)
                    {
                    case ValueBinaryOperation::add:
                        return isFp ? TypedOperator::addFloat : TypedOperator::add;
                    case ValueBinaryOperation::subtract:
                        return isFp ? TypedOperator::subtractFloat : TypedOperator::subtract;
                    case ValueBinaryOperation::multiply:
                        return isFp ? TypedOperator::multiplyFloat : TypedOperator::multiply;
                    case ValueBinaryOperation::divide:
                        // NOTE(review): integer division always uses the signed operator; confirm
                        // this is intended for any unsigned element types.
                        return isFp ? TypedOperator::divideFloat : TypedOperator::divideSigned;
                    case ValueBinaryOperation::modulus:
                        if (isFp)
                        {
                            throw InputException(InputExceptionErrors::invalidArgument);
                        }
                        return TypedOperator::moduloSigned;
                    case ValueBinaryOperation::logicalAnd:
                    case ValueBinaryOperation::logicalOr:
                        if (destination.GetBaseType() != ValueType::Boolean)
                        {
                            throw InputException(InputExceptionErrors::invalidArgument);
                        }
                        return op == ValueBinaryOperation::logicalAnd ? TypedOperator::logicalAnd
                                                                      : TypedOperator::logicalOr;
                    default:
                        throw LogicException(LogicExceptionErrors::illegalState);
                    }
                }();

                // Copy the layout, because `destination` may have its data replaced inside the
                // loops below.
                const auto layout = destination.GetLayout();

                // Pointer level 0: the LLVM value is the scalar itself, held in a register.
                // Pointer level 1: the LLVM value points at memory that must be loaded from or
                // stored to.
                const bool scalarLLVMDestination = destination.PointerLevel() == 0;

                if constexpr (std::is_same_v<Emittable, SourceDataType>)
                {
                    const bool scalarLLVMSource = source.PointerLevel() == 0;
                    ForImpl(
                        layout, [&](std::vector<Scalar> index) {
                            LLVMValue srcValue = nullptr;
                            if (scalarLLVMSource)
                            {
                                srcValue = ToLLVMValue(source);
                            }
                            else
                            {
                                auto offsetSource = source.Offset(detail::CalculateOffset(source.GetLayout(), index));
                                srcValue = fn.Load(ToLLVMValue(offsetSource));
                            }

                            if (scalarLLVMDestination)
                            {
                                // Register destination: the result replaces the destination's value.
                                auto result = fn.Operator(typedOp, ToLLVMValue(destination), srcValue);
                                destination.SetData(Emittable{ result });
                            }
                            else
                            {
                                auto destAddress = ToLLVMValue(destination.Offset(detail::CalculateOffset(layout, index)));
                                auto result = fn.Operator(typedOp, fn.Load(destAddress), srcValue);
                                fn.Store(destAddress, result);
                            }
                        },
                        "");
                }
                else
                {
                    // Constant source data: each element is emitted as an LLVM literal.
                    auto destValue = ToLLVMValue(destination);
                    ConstantForLoop(layout, [&](int offset) {
                        auto srcLiteral = fn.Literal(*(sourceData + offset));
                        if (scalarLLVMDestination)
                        {
                            // A pointer-level-0 destination has no memory to index into, so fold
                            // the literal into the register value directly.
                            destination.SetData(Emittable{ fn.Operator(typedOp, ToLLVMValue(destination), srcLiteral) });
                        }
                        else
                        {
                            auto offsetLiteral = fn.Literal(offset);
                            auto opResult = fn.Operator(typedOp, fn.ValueAt(destValue, offsetLiteral), srcLiteral);
                            fn.SetValueAt(destValue, offsetLiteral, opResult);
                        }
                    });
                }
            }
        },
        source.GetUnderlyingData());
    return destination;
}