in libraries/value/src/VectorOperations.cpp [43:186]
namespace
{
    // Portable fallback for Dot(): sums v1[i] * v2[i] over every index of v1
    // into a freshly allocated scalar of the vectors' element type. Used for
    // element types without a BLAS routine, for contexts where BLAS is not
    // enabled, and for contexts other than ComputeContext / LLVMContext.
    // NOTE(review): the sum starts from whatever Allocate() produces; this
    // assumes Allocate() yields a zero-initialized scalar in every context --
    // confirm against the context implementations.
    Scalar DotDefaultImpl(Vector v1, Vector v2)
    {
        Scalar result = Allocate(v1.GetType(), ScalarLayout);
        For(v1, [&](auto index) {
            result += v1[index] * v2[index];
        });
        return result;
    }

    // Dot product routed through a CBLAS *dot routine when the active context
    // allows it, falling back to DotDefaultImpl otherwise.
    //
    // ElementType      : C++ element type matching v1/v2 (float for
    //                    "cblas_sdot", double for "cblas_ddot").
    // v1, v2           : vectors already validated by Dot() to have equal size
    //                    and equal element type.
    // blasFunctionName : name of the CBLAS routine to declare and call.
    //
    // Dispatch order:
    //   1. ComputeContext: call math::Blas::Dot through a defined wrapper when
    //      built with USE_BLAS, otherwise use the fallback loop.
    //   2. LLVMContext: emit a call to the BLAS routine when the module's
    //      compiler options have `useBlas` set, otherwise emit the fallback loop.
    //   3. Any other context: fallback loop.
    template <typename ElementType>
    Scalar DotWithBlas(Vector v1, Vector v2, const char* blasFunctionName)
    {
        // Arguments shared by both BLAS call sites: element count and the
        // per-element stride (cumulative increment of dimension 0) of each vector.
        const int size = static_cast<int>(v1.Size());
        const int incX = static_cast<int>(v1.GetValue().GetLayout().GetCumulativeIncrement(0));
        const int incY = static_cast<int>(v2.GetValue().GetLayout().GetCumulativeIncrement(0));

        auto fn = DeclareFunction(blasFunctionName)
                      .Returns(Value({ v1.GetType(), 0 }, ScalarLayout))
                      .Parameters(
                          Value({ ValueType::Int32, 0 }, ScalarLayout), /*n*/
                          v1, /*x*/
                          Value({ ValueType::Int32, 0 }, ScalarLayout), /*incx*/
                          v2, /*y*/
                          Value({ ValueType::Int32, 0 }, ScalarLayout)); /*incy*/

        auto result = InvokeForContext<ComputeContext>([&] {
#ifdef USE_BLAS
            auto wrapper = fn.Define([](Scalar n, Vector x, Scalar incx, Vector y, Scalar incy) -> Scalar {
                return math::Blas::Dot(n.Get<int>(), x.GetValue().Get<ElementType*>(), incx.Get<int>(), y.GetValue().Get<ElementType*>(), incy.Get<int>());
            });
            return wrapper(size, v1, incX, v2, incY);
#else
            return DotDefaultImpl(v1, v2);
#endif
        });
        if (result)
        {
            return *result;
        }

        result = InvokeForContext<LLVMContext>([&](LLVMContext& context) -> Scalar {
            if (!context.GetModuleEmitter().GetCompilerOptions().useBlas)
            {
                return DotDefaultImpl(v1, v2);
            }
            // Name decoration is disabled, presumably so the emitted call
            // targets the plain C symbol exported by the BLAS library.
            auto returnValue = fn.Decorated(false)
                                   .Call(Scalar{ size }, v1, Scalar{ incX }, v2, Scalar{ incY });
            return *returnValue;
        });
        if (result)
        {
            return *result;
        }

        return DotDefaultImpl(v1, v2);
    }
} // namespace

// Computes the dot product sum_i v1[i] * v2[i].
//
// Float and Double vectors are routed through CBLAS (cblas_sdot / cblas_ddot)
// when the active context permits it; every other element type, and every
// context where BLAS is unavailable, uses an element-by-element loop.
//
// Parameters:
//   v1, v2 : input vectors; must have the same size and the same element type.
// Returns:
//   A Scalar of the vectors' element type holding the dot product.
// Throws:
//   InputException(sizeMismatch) if the vector sizes differ.
//   InputException(typeMismatch) if the element types differ.
Scalar Dot(Vector v1, Vector v2)
{
    if (v1.Size() != v2.Size())
    {
        throw InputException(InputExceptionErrors::sizeMismatch);
    }
    if (v1.GetType() != v2.GetType())
    {
        throw InputException(InputExceptionErrors::typeMismatch);
    }

    switch (v1.GetType())
    {
    case ValueType::Float:
        return DotWithBlas<float>(v1, v2, "cblas_sdot");
    case ValueType::Double:
        return DotWithBlas<double>(v1, v2, "cblas_ddot");
    default:
        return DotDefaultImpl(v1, v2);
    }
}