Scalar Dot()

in libraries/value/src/VectorOperations.cpp [43:186]


    namespace
    {
        // Portable dot-product implementation: accumulates v1[i] * v2[i] over
        // every index of v1 using the value library's For loop. Used for element
        // types without a BLAS routine and whenever BLAS is unavailable/disabled.
        // NOTE(review): relies on Allocate() yielding a zero-initialized scalar
        // for the accumulator — confirm against each EmitterContext.
        Scalar DotDefaultImpl(Vector v1, Vector v2)
        {
            Scalar result = Allocate(v1.GetType(), ScalarLayout);
            For(v1, [&](auto index) {
                result += v1[index] * v2[index];
            });

            return result;
        }

        // Floating-point dot product that routes through a BLAS `?dot` routine
        // when possible, and otherwise falls back to DotDefaultImpl.
        //
        // ElementType      - C++ element type of v1/v2 (float or double); selects
        //                    the math::Blas::Dot overload on the ComputeContext path.
        // blasFunctionName - external BLAS symbol to declare and call
        //                    ("cblas_sdot" or "cblas_ddot").
        //
        // Dispatch order (identical for every element type):
        //   1. ComputeContext: math::Blas::Dot when built with USE_BLAS,
        //      otherwise DotDefaultImpl.
        //   2. LLVMContext: a call to the external BLAS symbol when the module's
        //      compiler options enable BLAS, otherwise DotDefaultImpl.
        //   3. Any other context: DotDefaultImpl.
        template <typename ElementType>
        Scalar DotWithBlas(Vector v1, Vector v2, const char* blasFunctionName)
        {
            // BLAS takes the element count plus the distance (in elements) between
            // consecutive entries of each operand, so strided views work as well.
            const int size = static_cast<int>(v1.Size());
            const int xIncrement = static_cast<int>(v1.GetValue().GetLayout().GetCumulativeIncrement(0));
            const int yIncrement = static_cast<int>(v2.GetValue().GetLayout().GetCumulativeIncrement(0));

            // Signature: <elem> ?dot(int n, const <elem>* x, int incx, const <elem>* y, int incy)
            auto fn = DeclareFunction(blasFunctionName)
                          .Returns(Value({ v1.GetType(), 0 }, ScalarLayout))
                          .Parameters(
                              Value({ ValueType::Int32, 0 }, ScalarLayout), /*n*/
                              v1, /*x*/
                              Value({ ValueType::Int32, 0 }, ScalarLayout), /*incx*/
                              v2, /*y*/
                              Value({ ValueType::Int32, 0 }, ScalarLayout)); /*incy*/

            auto result = InvokeForContext<ComputeContext>([&] {
#ifdef USE_BLAS
                auto wrapper = fn.Define([](Scalar n, Vector x, Scalar incx, Vector y, Scalar incy) -> Scalar {
                    return math::Blas::Dot(n.Get<int>(), x.GetValue().Get<ElementType*>(), incx.Get<int>(), y.GetValue().Get<ElementType*>(), incy.Get<int>());
                });

                return wrapper(size, v1, xIncrement, v2, yIncrement);
#else
                return DotDefaultImpl(v1, v2);
#endif
            });

            if (result)
            {
                return *result;
            }

            result = InvokeForContext<LLVMContext>([&](LLVMContext& context) -> Scalar {
                if (!context.GetModuleEmitter().GetCompilerOptions().useBlas)
                {
                    return DotDefaultImpl(v1, v2);
                }

                // Undecorated so the emitted call binds to the plain C BLAS symbol.
                auto returnValue = fn.Decorated(false)
                                       .Call(
                                           Scalar{ size },
                                           v1,
                                           Scalar{ xIncrement },
                                           v2,
                                           Scalar{ yIncrement });

                return *returnValue;
            });

            if (result)
            {
                return *result;
            }

            return DotDefaultImpl(v1, v2);
        }
    } // namespace

    // Computes the dot product (sum of element-wise products) of two vectors.
    //
    // Float and double vectors are dispatched to BLAS (cblas_sdot / cblas_ddot)
    // when the active context supports it; all other element types use a
    // generic element-wise loop.
    //
    // Throws InputException(sizeMismatch) if the vectors differ in length and
    // InputException(typeMismatch) if their element types differ.
    Scalar Dot(Vector v1, Vector v2)
    {
        if (v1.Size() != v2.Size())
        {
            throw InputException(InputExceptionErrors::sizeMismatch);
        }
        if (v1.GetType() != v2.GetType())
        {
            throw InputException(InputExceptionErrors::typeMismatch);
        }

        const auto type = v1.GetType();
        if (type == ValueType::Float)
        {
            return DotWithBlas<float>(v1, v2, "cblas_sdot");
        }
        if (type == ValueType::Double)
        {
            return DotWithBlas<double>(v1, v2, "cblas_ddot");
        }

        return DotDefaultImpl(v1, v2);
    }