kernels/fmha_utils.h (57 lines of code) (raw):

/****************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright * notice, this list of conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * ******************************************************************************/ #pragma once #include <assert.h> #include <stdio.h> #include <stdlib.h> #include <cuda_runtime_api.h> #include <cuda_fp16.h> #include <cuda_bf16.h> //////////////////////////////////////////////////////////////////////////////////////////////////// #define FMHA_CHECK_CUDA( call ) \ do { \ cudaError_t status_ = call; \ if( status_ != cudaSuccess ) { \ fprintf( stderr, \ "CUDA error (%s:%d): %s\n", \ __FILE__, \ __LINE__, \ cudaGetErrorString( status_ ) ); \ exit( 1 ); \ } \ } while( 0 ) //////////////////////////////////////////////////////////////////////////////////////////////////// enum Data_type { DATA_TYPE_FP16, DATA_TYPE_BF16, DATA_TYPE_FP32, DATA_TYPE_INT32, DATA_TYPE_INT8 }; //////////////////////////////////////////////////////////////////////////////////////////////////// static inline void set_alpha( uint32_t &alpha, float norm, Data_type dtype ) { if( dtype == DATA_TYPE_FP16 ) { half x = __float2half_rn( norm ); uint16_t h = reinterpret_cast<const uint16_t &>( x ); ushort2 h2 = { h, h }; alpha = reinterpret_cast<const uint32_t &>( h2 ); } else if( dtype == DATA_TYPE_BF16 ) { __nv_bfloat16 x = __float2bfloat16( norm ); uint16_t h = reinterpret_cast<const uint16_t &>( x ); ushort2 h2 = { h, h }; alpha = reinterpret_cast<const uint32_t &>( h2 ); } else if( dtype == DATA_TYPE_FP32 ) { alpha = reinterpret_cast<const uint32_t &>( norm ); } else if( dtype == DATA_TYPE_INT32 ) { int32_t inorm = static_cast<int32_t>( norm ); alpha = reinterpret_cast<const uint32_t &>( inorm ); } else { assert( false ); } } //////////////////////////////////////////////////////////////////////////////////////////////////// static inline size_t get_size_in_bytes( size_t n, Data_type dtype ) { switch( dtype ) { case DATA_TYPE_FP32: return n * 4; case DATA_TYPE_FP16: return n * 2; case DATA_TYPE_BF16: return n * 2; case DATA_TYPE_INT32: return n * 4; case DATA_TYPE_INT8: return n; default: assert( false ); return 0; } } ////////////////////////////////////////////////////////////////////////////////////////////////////