Value LLVMContext::BinaryOperationImpl()

in libraries/value/src/LLVMContext.cpp [1265:1419]


    // Emits code that applies the binary operation `op` element by element, using `destination`
    // as the left-hand operand and `source` as the right-hand operand, and writes each result
    // back to the destination.
    //
    // Parameters:
    //   op          - The operation to apply. add, subtract, multiply, divide, modulus, logicalAnd
    //                 and logicalOr are handled here.
    //   destination - Left-hand operand and result location. If it is undefined, a new value is
    //                 allocated with the base type and layout of `source`.
    //   source      - Right-hand operand. Must be defined.
    //
    // Returns:
    //   When both operands are defined constants, whatever the compute context returns for the
    //   operation. Otherwise, the emittable destination that holds the result.
    //
    // Throws:
    //   InputException(invalidArgument) - `source` is undefined; either operand has a pointer level
    //                                     other than 0 or 1; modulus is requested on a
    //                                     floating-point destination; or a logical operation is
    //                                     requested on a non-Boolean destination.
    //   InputException(typeMismatch)    - TypeCompatible(destination, source) is false.
    //   InputException(sizeMismatch)    - The two layouts differ.
    //   LogicException(notImplemented)  - The source's underlying data is a Boolean*.
    //   LogicException(illegalState)    - `op` is not one of the operations listed above.
    Value LLVMContext::BinaryOperationImpl(ValueBinaryOperation op, Value destination, Value source)
    {
        if (!source.IsDefined())
        {
            throw InputException(InputExceptionErrors::invalidArgument);
        }

        if (destination.IsDefined())
        {
            // Two constants are handed to the compute context; no code is emitted for them here.
            if (source.IsConstant() && destination.IsConstant())
            {
                return _computeContext.BinaryOperation(op, destination, source);
            }
        }
        else
        {
            destination = Allocate(source.GetBaseType(), source.GetLayout());
        }

        if (!TypeCompatible(destination, source))
        {
            throw InputException(InputExceptionErrors::typeMismatch);
        }

        // Only plain values (pointer level 0) and single-level pointers (pointer level 1) are handled.
        const bool sourceLevelOk = source.PointerLevel() == 0 || source.PointerLevel() == 1;
        const bool destinationLevelOk = destination.PointerLevel() == 0 || destination.PointerLevel() == 1;
        if (!(sourceLevelOk && destinationLevelOk))
        {
            throw InputException(InputExceptionErrors::invalidArgument);
        }

        if (destination.GetLayout() != source.GetLayout())
        {
            throw InputException(InputExceptionErrors::sizeMismatch);
        }

        auto& fn = GetFunctionEmitter();

        // The emittable destination is kept in a named local and captured by reference so that
        // updates made inside the visitor (in particular SetData for a pointer-level-0
        // destination) are visible in the value returned to the caller. Capturing a copy inside
        // the visitor would discard those updates when the visitor is destroyed.
        // NOTE(review): presumably EnsureEmittable returns an equivalent value when its argument
        // is already emittable - confirm against its definition.
        auto target = EnsureEmittable(destination);

        std::visit(
            [this, &target, &source, op, &fn](auto&& sourceData) {
                using SourceDataType = std::decay_t<decltype(sourceData)>;
                if constexpr (std::is_same_v<Boolean*, SourceDataType>)
                {
                    throw LogicException(LogicExceptionErrors::notImplemented);
                }
                else
                {
                    // Select the emitter operator once; floating-point destinations use the
                    // *Float variants of the arithmetic operators.
                    // NOTE(review): integer divide and modulus always use the signed operators
                    // here - confirm that this is intended for unsigned element types.
                    auto isFp = target.IsFloatingPoint();
                    std::function<LLVMValue(LLVMValue, LLVMValue)> opFn;
                    switch (op)
                    {
                    case ValueBinaryOperation::add:
                        opFn = [&fn, isFp](auto dst, auto src) {
                            return fn.Operator(isFp ? TypedOperator::addFloat : TypedOperator::add, dst, src);
                        };
                        break;
                    case ValueBinaryOperation::subtract:
                        opFn = [&fn, isFp](auto dst, auto src) {
                            return fn.Operator(isFp ? TypedOperator::subtractFloat : TypedOperator::subtract, dst, src);
                        };
                        break;
                    case ValueBinaryOperation::multiply:
                        opFn = [&fn, isFp](auto dst, auto src) {
                            return fn.Operator(isFp ? TypedOperator::multiplyFloat : TypedOperator::multiply, dst, src);
                        };
                        break;
                    case ValueBinaryOperation::divide:
                        opFn = [&fn, isFp](auto dst, auto src) {
                            return fn.Operator(isFp ? TypedOperator::divideFloat : TypedOperator::divideSigned,
                                               dst,
                                               src);
                        };
                        break;
                    case ValueBinaryOperation::modulus:
                        if (isFp)
                        {
                            throw InputException(InputExceptionErrors::invalidArgument);
                        }
                        opFn = [&fn](auto dst, auto src) { return fn.Operator(TypedOperator::moduloSigned, dst, src); };
                        break;
                    case ValueBinaryOperation::logicalAnd:
                        [[fallthrough]];
                    case ValueBinaryOperation::logicalOr:
                        if (target.GetBaseType() != ValueType::Boolean)
                        {
                            throw InputException(InputExceptionErrors::invalidArgument);
                        }
                        opFn = [&fn, op](auto dst, auto src) {
                            return fn.Operator(op == ValueBinaryOperation::logicalAnd ? TypedOperator::logicalAnd : TypedOperator::logicalOr, dst, src);
                        };
                        break;
                    default:
                        throw LogicException(LogicExceptionErrors::illegalState);
                    }

                    auto& layout = target.GetLayout();

                    if constexpr (std::is_same_v<Emittable, SourceDataType>)
                    {
                        // An operand with pointer level 0 is used directly as an LLVM value; an
                        // operand with pointer level 1 is offset by the current index and loaded.
                        bool scalarLLVMSource = source.PointerLevel() == 0;
                        bool scalarLLVMDestination = target.PointerLevel() == 0;
                        ForImpl(
                            layout, [&](std::vector<Scalar> index) {
                                LLVMValue srcValue = nullptr;
                                if (scalarLLVMSource)
                                {
                                    srcValue = ToLLVMValue(source);
                                }
                                else
                                {
                                    auto offsetSource = source.Offset(detail::CalculateOffset(source.GetLayout(), index));
                                    srcValue = fn.Load(ToLLVMValue(offsetSource));
                                }

                                LLVMValue destValue = nullptr;
                                LLVMValue destValueOffset = nullptr;
                                if (scalarLLVMDestination)
                                {
                                    destValue = ToLLVMValue(target);
                                }
                                else
                                {
                                    destValueOffset = ToLLVMValue(target.Offset(detail::CalculateOffset(target.GetLayout(), index)));
                                    destValue = fn.Load(destValueOffset);
                                }

                                auto result = opFn(destValue, srcValue);
                                if (scalarLLVMDestination)
                                {
                                    // There is no address to store through, so the destination is
                                    // rebound to the LLVM value produced by the operation.
                                    target.SetData(Emittable{ result });
                                }
                                else
                                {
                                    fn.Store(destValueOffset, result);
                                }
                            },
                            "");
                    }
                    else
                    {
                        // The source is host-side data: each element is read while emitting and
                        // passed to the operation as a literal.
                        auto destValue = ToLLVMValue(target);
                        ConstantForLoop(layout, [&](int offset) {
                            auto offsetLiteral = fn.Literal(offset);
                            auto opResult = opFn(fn.ValueAt(destValue, offsetLiteral), fn.Literal(*(sourceData + offset)));
                            fn.SetValueAt(destValue, offsetLiteral, opResult);
                        });
                    }
                }
            },
            source.GetUnderlyingData());

        return target;
    }