include/fbgemm/FbgemmFP16.h (40 lines of code) (raw):

/* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. */ #pragma once // WARNING: this is a legacy fp16 fbgemm implementation and will soon be // upgraded to match with new fbgemm interface. #include <cpuinfo.h> #include <cassert> #include <cstdlib> #include <memory> #include <stdexcept> #include <vector> #include "./FbgemmPackMatrixB.h" #include "./Types.h" #include "./Utils.h" namespace fbgemm { template <> struct TypeConverter<float16> { float16 operator()(float src) const { constexpr float FP16_MAX = 65504.f; const float fp16 = std::max(-FP16_MAX, std::min(src, FP16_MAX)); return cpu_float2half_rn(fp16); } }; using PackedGemmMatrixFP16 = PackedGemmMatrixB<float16>; template <typename T> FBGEMM_API void cblas_gemm_compute( const matrix_op_t transa, const int m, const float* A, const PackedGemmMatrixB<T>& Bp, const float beta, float* C, int thread_id = 0, int num_threads = 1); extern template void cblas_gemm_compute<float16>( const matrix_op_t transa, const int m, const float* A, const PackedGemmMatrixFP16& Bp, const float beta, float* C, int thread_id, int num_threads); }; // namespace fbgemm