Value ComputeContext::BinaryOperationImpl()

in libraries/value/src/ComputeContext.cpp [1126:1238]


    /// @brief Applies an element-wise binary operation in place:
    ///        destination[i] = destination[i] <op> source[i].
    ///
    /// @param op          The operation to apply. add / subtract / multiply / divide are accepted for
    ///                    non-Boolean element types, modulus only for integral element types, and
    ///                    logicalAnd / logicalOr only for Boolean element types.
    /// @param destination Left-hand operand; overwritten with the result. If it does not pass
    ///                    ValidateValue, storage with the source's base type and layout is obtained
    ///                    from Allocate and used as the left-hand operand instead.
    /// @param source      Right-hand operand. Must pass ValidateValue, be type-compatible with
    ///                    destination, and have a layout equal to destination's.
    /// @return The destination value holding the result.
    /// @throws InputException invalidArgument if source fails validation, typeMismatch if the
    ///         operands are not type-compatible, sizeMismatch if their layouts differ.
    /// @throws LogicException illegalState if op is not supported for the element type.
    ///
    /// NOTE(review): integer divide / modulus by a zero source element is undefined behavior;
    /// no check is performed here.
    Value ComputeContext::BinaryOperationImpl(ValueBinaryOperation op, Value destination, Value source)
    {
        if (!ValidateValue(source))
        {
            throw InputException(InputExceptionErrors::invalidArgument);
        }
        if (!ValidateValue(destination))
        {
            destination = Allocate(source.GetBaseType(), source.GetLayout());
        }

        if (!TypeCompatible(destination, source))
        {
            throw InputException(InputExceptionErrors::typeMismatch);
        }

        if (destination.GetLayout() != source.GetLayout())
        {
            throw InputException(InputExceptionErrors::sizeMismatch);
        }

        std::visit(
            VariantVisitor{
                // NOTE(review): Emittable destination data is silently ignored here; presumably
                // ValidateValue has already rejected it for this context — confirm.
                [](Emittable) {},
                [&destination, &source, op](auto&& destinationData) {
                    using DestinationDataType =
                        std::remove_pointer_t<std::decay_t<decltype(destinationData)>>;

                    // Runs `opFn` over every (destination, source) element pair, writing the
                    // result back into destination. The functor is a template parameter rather
                    // than a std::function so the per-element operation can be inlined.
                    auto applyElementwise = [&](auto opFn) {
                        auto sourceData = std::get<DestinationDataType*>(source.GetUnderlyingData());
                        const auto& sourceLayout = source.GetLayout();
                        const auto& destinationLayout = destination.GetLayout();

                        if (sourceLayout.IsContiguous() && destinationLayout.IsContiguous())
                        {
                            const auto numElements = destinationLayout.NumElements();
                            std::transform(destinationData,
                                           destinationData + numElements,
                                           sourceData,
                                           destinationData,
                                           opFn);
                            return;
                        }

                        auto maxCoordinate = sourceLayout.GetActiveSize().ToVector();

                        // The do/while below always visits the all-zero coordinate first, so an
                        // extent of 0 in any dimension would touch an element that does not exist.
                        // Such a layout holds no elements: there is nothing to compute.
                        const bool hasEmptyExtent =
                            std::any_of(maxCoordinate.begin(), maxCoordinate.end(), [](auto extent) { return extent == 0; });
                        if (hasEmptyExtent)
                        {
                            return;
                        }

                        decltype(maxCoordinate) coordinate(maxCoordinate.size());
                        do
                        {
                            // Walk memory order, but address both operands through logical
                            // coordinates so each layout's own offsets are honored.
                            auto logicalCoordinates = sourceLayout.GetLogicalCoordinates(coordinate);
                            auto sourceOffset = sourceLayout.GetLogicalEntryOffset(logicalCoordinates);
                            auto destinationOffset = destinationLayout.GetLogicalEntryOffset(logicalCoordinates);
                            destinationData[destinationOffset] =
                                opFn(destinationData[destinationOffset], sourceData[sourceOffset]);
                        } while (IncrementMemoryCoordinate(coordinate, maxCoordinate));
                    };

                    // Every operation lambda takes and returns DestinationDataType explicitly, so
                    // integral promotions (e.g. char + char -> int) are converted back to the
                    // element type exactly as the previous std::function<T(T, T)> wrapper did.
                    if constexpr (!std::is_same_v<DestinationDataType, Boolean>)
                    {
                        switch (op)
                        {
                        case ValueBinaryOperation::add:
                            applyElementwise([](DestinationDataType dst, DestinationDataType src) -> DestinationDataType {
                                return dst + src;
                            });
                            break;
                        case ValueBinaryOperation::subtract:
                            applyElementwise([](DestinationDataType dst, DestinationDataType src) -> DestinationDataType {
                                return dst - src;
                            });
                            break;
                        case ValueBinaryOperation::multiply:
                            applyElementwise([](DestinationDataType dst, DestinationDataType src) -> DestinationDataType {
                                return dst * src;
                            });
                            break;
                        case ValueBinaryOperation::divide:
                            applyElementwise([](DestinationDataType dst, DestinationDataType src) -> DestinationDataType {
                                return dst / src;
                            });
                            break;
                        default:
                            // Modulus is only defined for integral element types; anything else
                            // reaching here is unsupported.
                            if constexpr (std::is_integral_v<DestinationDataType>)
                            {
                                if (op == ValueBinaryOperation::modulus)
                                {
                                    applyElementwise([](DestinationDataType dst, DestinationDataType src) -> DestinationDataType {
                                        return dst % src;
                                    });
                                    break;
                                }
                            }
                            throw LogicException(LogicExceptionErrors::illegalState);
                        }
                    }
                    else
                    {
                        switch (op)
                        {
                        case ValueBinaryOperation::logicalAnd:
                            applyElementwise([](DestinationDataType dst, DestinationDataType src) -> DestinationDataType {
                                return dst && src;
                            });
                            break;
                        case ValueBinaryOperation::logicalOr:
                            applyElementwise([](DestinationDataType dst, DestinationDataType src) -> DestinationDataType {
                                return dst || src;
                            });
                            break;
                        default:
                            throw LogicException(LogicExceptionErrors::illegalState);
                        }
                    }
                } },
            destination.GetUnderlyingData());

        return destination;
    }