// maga_transformer/cpp/devices/rocm_impl/ROCmActOp.cc
#include "maga_transformer/cpp/devices/rocm_impl/ROCmDevice.h"
#include "maga_transformer/cpp/kernels/activation_kernels.h"
#include "maga_transformer/cpp/utils/compiler_config.h"
#include "maga_transformer/cpp/kernels/activation_kernels.h"
#include "maga_transformer/cpp/cuda/Dispatch.h"
using namespace std;
namespace rtp_llm {
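
// ARGS_DISPATCH expands to a single invokeGenericActivation<Atype> call on elements
// of type Dtype. Only out, bias, gate, gate_bias, and act_scale are forwarded; the
// remaining nullptr / 0 arguments are kernel parameters this op leaves unused.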
#define ARGS_DISPATCH(Atype, Dtype, out, bias, gate, gate_bias, m, n, act_scale, stream)          \
    do {                                                                                          \
        invokeGenericActivation<Atype>((Dtype*)out,                                              \
                                       (const Dtype*)bias,                                       \
                                       (const Dtype*)gate,                                       \
                                       (const Dtype*)gate_bias,                                  \
                                       (const int*)nullptr,                                      \
                                       (const Dtype*)nullptr,                                    \
                                       (int)m,                                                   \
                                       (int)n,                                                   \
                                       0,                                                        \
                                       (const float*)nullptr,                                    \
                                       (const float*)nullptr,                                    \
                                       (const Dtype*)act_scale,                                  \
                                       stream);                                                  \
    } while (0)
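
// ATYPE_DISPATCH emits one switch case per compute type. Geglu and Swiglu reuse the
// Gelu/Silu functors: the gating itself is handled through the gate and gate_bias
// pointers rather than by a separate activation functor. The trailing `continue`
// targets the do/while (0) in DTYPE_DISPATCH, so a matched case exits both the
// switch and the loop.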
#define ATYPE_DISPATCH(Dtype, cpp_type, Atype, ...)                   \
    case Dtype:                                                       \
        if (Atype == ActivationType::Silu) {                          \
            ARGS_DISPATCH(SiluActivation, cpp_type, __VA_ARGS__);     \
        } else if (Atype == ActivationType::Gelu) {                   \
            ARGS_DISPATCH(GeluActivation, cpp_type, __VA_ARGS__);     \
        } else if (Atype == ActivationType::Geglu) {                  \
            ARGS_DISPATCH(GeluActivation, cpp_type, __VA_ARGS__);     \
        } else if (Atype == ActivationType::Swiglu) {                 \
            ARGS_DISPATCH(SiluActivation, cpp_type, __VA_ARGS__);     \
        } else if (Atype == ActivationType::Identity) {               \
            ARGS_DISPATCH(IdentityActivation, cpp_type, __VA_ARGS__); \
        } else {                                                      \
            throw OpException(OpErrorType::ERROR_UNIMPLEMENTED);      \
        }                                                             \
        continue;
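
// DTYPE_DISPATCH selects the element type at runtime. DISPATCH_FOR_EACH_COMPUTE_TYPE
// stamps out one ATYPE_DISPATCH case per supported compute type; anything else hits
// the default case and throws.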
#define DTYPE_DISPATCH(Dtype, ...)                                        \
    do {                                                                  \
        switch (Dtype) {                                                  \
            DISPATCH_FOR_EACH_COMPUTE_TYPE(ATYPE_DISPATCH, __VA_ARGS__);  \
            default:                                                      \
                throw OpException(OpErrorType::ERROR_UNIMPLEMENTED);      \
        }                                                                 \
    } while (0)
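
// Applies params.atype to params.states in place and returns the same buffer.
// Sigmoid takes a dedicated kernel path; every other supported activation goes
// through the generic dispatch above.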
BufferPtr ROCmDevice::activation(const ActivationParams& params) {
    auto states = params.states;
    ROCM_CHECK_VALUE(states != nullptr, "states should not be nullptr in activation");
    const auto data_type = states->type();
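    // Sigmoid is handled by a dedicated in-place kernel that accepts none of the
    // optional inputs, so reject them explicitly.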
    if (params.atype == ActivationType::Sigmoid) {
        RUNTIME_ASSERT_OP_ARG(!params.bias, "Sigmoid does not support bias");
        RUNTIME_ASSERT_OP_ARG(!params.gate, "Sigmoid does not support gate");
        RUNTIME_ASSERT_OP_ARG(!params.gate_bias, "Sigmoid does not support gate_bias");
        RUNTIME_ASSERT_OP_ARG(!params.act_scale, "Sigmoid does not support act_scale");
        DISPATCH_CUDA_FUNCTION_DATA_TYPE(
            data_type, invokeSigmoid,
            states->data(), states->size(), 1.0f, stream_
        );
        return states;
    }
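    // Generic path: states is an [m, n] buffer activated in place; the optional
    // bias/gate/gate_bias/act_scale buffers are forwarded to the kernel as raw
    // pointers when present.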
    RUNTIME_ASSERT_OP_ARG(states->shape().size() == 2,
                          "activation states must be 2D, but got %zu", states->shape().size());
    const size_t m = states->shape()[0];
    const size_t n = states->shape()[1];
    auto bias      = params.bias ? params.bias.value().get().data() : nullptr;
    auto gate      = params.gate ? params.gate.value().get().data() : nullptr;
    auto gate_bias = params.gate_bias ? params.gate_bias.value().get().data() : nullptr;
    auto act_scale = params.act_scale ? params.act_scale.value().get().data() : nullptr;
    DTYPE_DISPATCH(
        data_type, params.atype,
        states->data(), bias,
        gate, gate_bias, m, n,
        act_scale,
        stream_
    );
    return states;
}
} // namespace rtp_llm