include/fbgemm/FbgemmConvert.h (56 lines of code) (raw):
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include <stdexcept>
#include "fbgemm/Types.h"
#include "fbgemm/Utils.h"
namespace fbgemm {
/**
* @brief Type used by the conversion routines below for bfloat16 values;
* it is an alias for uint16_t.
*/
typedef uint16_t bfloat16;
/**
* @brief Transform all entries in a matrix from fp32 to bfloat16: reference
* implementation.
*
* @param src input fp32 entries (read only).
* @param dst output buffer that receives the bfloat16 entries.
* @param size presumably the number of entries to convert, not a byte count
* (TODO confirm against the implementation).
*/
FBGEMM_API void
FloatToBfloat16_ref(const float* src, bfloat16* dst, size_t size);
/**
* @brief Transform all entries in a matrix from bfloat16 to fp32: reference
* implementation.
*
* @param src input bfloat16 entries (read only).
* @param dst output buffer that receives the fp32 entries.
* @param size presumably the number of entries to convert, not a byte count
* (TODO confirm against the implementation).
*/
FBGEMM_API void
Bfloat16ToFloat_ref(const bfloat16* src, float* dst, size_t size);
/**
* @brief Transform all entries in a matrix from fp32 to bfloat16: simd
* implementation.
*
* @param src input fp32 entries (read only).
* @param dst output buffer that receives the bfloat16 entries.
* @param size presumably the number of entries to convert, not a byte count
* (TODO confirm against the implementation).
*/
FBGEMM_API void
FloatToBfloat16_simd(const float* src, bfloat16* dst, size_t size);
/**
* @brief Transform all entries in a matrix from bfloat16 to fp32: simd
* implementation.
*
* @param src input bfloat16 entries (read only).
* @param dst output buffer that receives the fp32 entries.
* @param size presumably the number of entries to convert, not a byte count
* (TODO confirm against the implementation).
*/
FBGEMM_API void
Bfloat16ToFloat_simd(const bfloat16* src, float* dst, size_t size);
/**
* @brief AVX2 implementation to convert fp32 numbers to bf16 numbers.
*
* @param src input fp32 numbers (read only).
* @param dst output buffer that receives the bf16 numbers.
* @param size presumably the number of values to convert, not a byte count
* (TODO confirm against the implementation).
*/
FBGEMM_API void
FloatToBfloat16_avx2(const float* src, bfloat16* dst, size_t size);
/**
* @brief AVX512 implementation to convert fp32 numbers to bf16 numbers.
*
* @param src input fp32 numbers (read only).
* @param dst output buffer that receives the bf16 numbers.
* @param size presumably the number of values to convert, not a byte count
* (TODO confirm against the implementation).
*/
FBGEMM_API void
FloatToBfloat16_avx512(const float* src, bfloat16* dst, size_t size);
/**
* @brief AVX2 implementation to convert bf16 numbers to fp32 numbers.
*
* @param src input bf16 numbers (read only).
* @param dst output buffer that receives the fp32 numbers.
* @param size presumably the number of values to convert, not a byte count
* (TODO confirm against the implementation).
*/
FBGEMM_API void
Bfloat16ToFloat_avx2(const bfloat16* src, float* dst, size_t size);
/**
* @brief AVX512 implementation to convert bf16 numbers to fp32 numbers.
*
* @param src input bf16 numbers (read only).
* @param dst output buffer that receives the fp32 numbers.
* @param size presumably the number of values to convert, not a byte count
* (TODO confirm against the implementation).
*/
FBGEMM_API void
Bfloat16ToFloat_avx512(const bfloat16* src, float* dst, size_t size);
/**
* @brief Transform all entries in a matrix from fp32 to float16: reference
* implementation.
*
* @param src input fp32 entries (read only).
* @param dst output buffer that receives the float16 entries.
* @param size presumably the number of entries to convert, not a byte count
* (TODO confirm against the implementation).
* @param do_clip if true, saturate to the fp16 min and max instead of
* generating infinities. Defaults to false.
*/
FBGEMM_API void FloatToFloat16_ref(
const float* src,
float16* dst,
size_t size,
bool do_clip = false);
/**
* @brief Transform all entries in a matrix from float16 to fp32: reference
* implementation.
*
* @param src input float16 entries (read only).
* @param dst output buffer that receives the fp32 entries.
* @param size presumably the number of entries to convert, not a byte count
* (TODO confirm against the implementation).
*/
FBGEMM_API void Float16ToFloat_ref(const float16* src, float* dst, size_t size);
/**
* @brief Transform all entries in a matrix from fp32 to float16: simd
* implementation.
*
* @param src input fp32 entries (read only).
* @param dst output buffer that receives the float16 entries.
* @param size presumably the number of entries to convert, not a byte count
* (TODO confirm against the implementation).
* @param do_clip if true, saturate to the fp16 min and max instead of
* generating infinities. Defaults to false.
*/
FBGEMM_API void FloatToFloat16_simd(
const float* src,
float16* dst,
size_t size,
bool do_clip = false);
/**
* @brief Transform all entries in a matrix from float16 to fp32: simd
* implementation.
*
* @param src input float16 entries (read only).
* @param dst output buffer that receives the fp32 entries.
* @param size presumably the number of entries to convert, not a byte count
* (TODO confirm against the implementation).
*/
FBGEMM_API void
Float16ToFloat_simd(const float16* src, float* dst, size_t size);
/**
* @brief AVX2 implementation to convert fp32 numbers to fp16 numbers.
*
* @param src input fp32 numbers (read only).
* @param dst output buffer that receives the fp16 numbers.
* @param size presumably the number of values to convert, not a byte count
* (TODO confirm against the implementation).
* @param do_clip defaults to false. NOTE(review): presumably the same switch
* as on FloatToFloat16_ref, i.e. saturate to the fp16 min and max instead of
* generating infinities -- confirm against the implementation.
*/
FBGEMM_API void FloatToFloat16_avx2(
const float* src,
float16* dst,
size_t size,
bool do_clip = false);
/**
* @brief AVX512 implementation to convert fp32 numbers to fp16 numbers.
*
* @param src input fp32 numbers (read only).
* @param dst output buffer that receives the fp16 numbers.
* @param size presumably the number of values to convert, not a byte count
* (TODO confirm against the implementation).
* @param do_clip defaults to false. NOTE(review): presumably the same switch
* as on FloatToFloat16_ref, i.e. saturate to the fp16 min and max instead of
* generating infinities -- confirm against the implementation.
*/
FBGEMM_API void FloatToFloat16_avx512(
const float* src,
float16* dst,
size_t size,
bool do_clip = false);
/**
* @brief AVX2 implementation to convert fp16 numbers to fp32 numbers.
*
* @param src input fp16 numbers (read only).
* @param dst output buffer that receives the fp32 numbers.
* @param size presumably the number of values to convert, not a byte count
* (TODO confirm against the implementation).
*/
FBGEMM_API void
Float16ToFloat_avx2(const float16* src, float* dst, size_t size);
/**
* @brief AVX512 implementation to convert fp16 numbers to fp32 numbers.
*
* @param src input fp16 numbers (read only).
* @param dst output buffer that receives the fp32 numbers.
* @param size presumably the number of values to convert, not a byte count
* (TODO confirm against the implementation).
*/
FBGEMM_API void
Float16ToFloat_avx512(const float16* src, float* dst, size_t size);
/**
* @brief Transform all entries in a matrix from fp32 to float16 and back to
* fp32.
*
* @param input input fp32 entries (read only).
* @param output output buffer that receives the fp32 entries after the round
* trip through float16.
* @param size presumably the number of entries to process, not a byte count
* (TODO confirm against the implementation).
* @param clamp defaults to false. NOTE(review): presumably saturates to the
* fp16 range in the same way as do_clip on FloatToFloat16_ref -- confirm.
* @param clamp_denorms defaults to false. NOTE(review): semantics are not
* documented in this header; the name suggests special handling of values in
* the fp16 denormal range -- confirm against the implementation.
*/
FBGEMM_API void RoundToFloat16(
const float* input,
float* output,
size_t size,
bool clamp = false,
bool clamp_denorms = false);
} // namespace fbgemm