// maga_transformer/cpp/devices/DeviceBase.cc
#include "maga_transformer/cpp/devices/DeviceBase.h"
#include "ATen/ops/cross_entropy_loss.h"
#include "c10/util/Optional.h"
#include "maga_transformer/cpp/core/TrackerAllocator.h"
#include "maga_transformer/cpp/core/torch_utils/BufferTorchUtils.h"
#include "maga_transformer/cpp/utils/KVCacheUtils.h"
#include "maga_transformer/cpp/utils/ErrorCode.h"
#include "maga_transformer/cpp/utils/RpcErrorCode.h"
#include "maga_transformer/cpp/devices/OpData.h"
#include "torch/extension.h"
#include "torch/types.h"
#include <numeric>
using namespace std;
using namespace rtp_llm;
namespace rtp_llm {
DeviceBase::DeviceBase(const DeviceInitParams& params)
: device_id_(params.device_id)
, init_params_(params)
{
// By default stdout is fully buffered when redirected to a file, so ft_log and autil_log output is not flushed promptly; switch to line buffering manually.
setlinebuf(stdout);
}
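// Builds the BufferManager on top of the device and host allocators and reads the
// ENABLE_DEVICE_PERF environment variable (device perf is enabled when it is set to "ON").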
void DeviceBase::init() {
buffer_manager_.reset(new BufferManager(getAllocator(), getHostAllocator()));
static char* enable_device_perf_env_char = std::getenv("ENABLE_DEVICE_PERF");
enable_device_perf_ = enable_device_perf_env_char != nullptr && std::string(enable_device_perf_env_char) == "ON";
}
void DeviceBase::setTraceMemory(bool trace_memory) {
buffer_manager_->setTraceMemory(trace_memory);
}
std::shared_ptr<rtp_llm::CacheStore> DeviceBase::cacheStore() {
return cache_store_;
}
MemoryStatus DeviceBase::getDeviceMemoryStatus() {
return MemoryStatus();
}
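// Combines the device memory status with the BufferManager's allocation statistics
// (allocated/preserved/max-consumed bytes on device, allocated bytes on host).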
DeviceStatus DeviceBase::getDeviceStatus() {
DeviceStatus status;
status.device_memory_status = getDeviceMemoryStatus();
const auto buffer_status = queryBufferStatus();
status.device_memory_status.allocated_bytes = buffer_status.device_allocated_bytes;
status.device_memory_status.preserved_bytes = buffer_status.device_preserved_bytes;
status.device_memory_status.available_bytes = status.device_memory_status.free_bytes + status.device_memory_status.preserved_bytes;
status.device_memory_status.max_consumed_bytes = buffer_status.device_max_consumed_bytes;
status.host_memory_status.allocated_bytes = buffer_status.host_allocated_bytes;
return status;
}
void DeviceBase::traceMemoryUsage() {
RTP_LLM_LOG_INFO("Device Memory: %s", buffer_manager_->printAllocationRecords(getAllocator()).c_str());
RTP_LLM_LOG_INFO("Host Memory: %s", buffer_manager_->printAllocationRecords(getHostAllocator()).c_str());
return;
}
AllocationType DeviceBase::getMemAllocationType(const MemoryType type) {
return (type == getAllocator()->memoryType()) ? AllocationType::DEVICE : AllocationType::HOST;
}
BufferStatus DeviceBase::queryBufferStatus() {
return buffer_manager_->queryStatus();
}
BufferPtr DeviceBase::allocateBuffer(const BufferParams& params, const BufferHints& hints) {
return buffer_manager_->allocate(params, hints);
}
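// Allocates a new buffer with the same type and shape as `buffer`. For quantized QBuffers,
// kernel, scales and zeros are allocated separately and re-wrapped into a QBuffer.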
BufferPtr DeviceBase::allocateBufferLike(const Buffer& buffer,
const AllocationType atype,
const BufferHints& hints) {
if (buffer.isQBuffer()) {
auto kernel = allocateBufferLike((reinterpret_cast<const QBuffer*>(&buffer)->kernel()),
atype,
hints);
auto scales = allocateBufferLike((reinterpret_cast<const QBuffer*>(&buffer)->scales()),
atype,
hints);
auto zeros = allocateBufferLike((reinterpret_cast<const QBuffer*>(&buffer)->zeros()),
atype,
hints);
return BufferPtr(new QBuffer(std::move(kernel),
std::move(scales),
std::move(zeros)));
}
return allocateBuffer({buffer.type(), buffer.shape(), atype}, hints);
}
void DeviceBase::syncAndCheck() {
return;
}
void DeviceBase::syncDeviceStream(DeviceStream stream) {
return;
}
DevicePrepOutput DeviceBase::prepareModelRun(const DevicePrepParams& params) {
return DevicePrepOutput();
}
void DeviceBase::syncCommunication(bool timeout) {
return;
}
void DeviceBase::syncCommunication(ParallelMode mode, bool timeout) {
return;
}
void DeviceBase::overlappedCommBarrier() {
syncCommunication();
}
DeviceHookPtr DeviceBase::createCommHook() {
throw OpException(OpErrorType::ERROR_UNIMPLEMENTED);
}
void DeviceBase::overlappedComputeBarrier() {
syncCommunication();
}
DeviceEventPtr DeviceBase::createEvent() {
throw OpException(OpErrorType::ERROR_UNIMPLEMENTED);
}
void DeviceBase::setCacheStore(std::shared_ptr<rtp_llm::CacheStore> cache_store) {
cache_store_ = cache_store;
}
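// Writes MTP hidden states of context (prefill) requests into the cache store for
// PD-separated serving. Each request's hidden-state slice is registered as a block named
// "hidden_states" and stored asynchronously; warmup runs and requests without
// pd_separation are skipped.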
void DeviceBase::writeHiddenStatesStore(const WriteMTPHiddenStatesParams& params) {
if (params.warmup) {
RTP_LLM_LOG_DEBUG("is warmup, so ignore writeCacheStore");
return;
}
if (!params.pd_separation || params.context_batch_size == 0) {
RTP_LLM_LOG_DEBUG("pd_separation = %d, context_batch_size = %d, so ignore writeCacheStore",
params.pd_separation, params.context_batch_size);
return;
}
RTP_LLM_LOG_DEBUG("start writeHiddenStatesStore");
auto context_lm_output_indexes = params.lm_output_indexes->slice(params.decoder_batch_size,
params.context_batch_size);
size_t last_token_index = 0;
for (size_t batch_id = 0; batch_id < params.context_batch_size; batch_id++) {
if (*(params.request_pd_separation->dataWithOffset<bool>(batch_id)) == false) {
continue;
}
auto request_id = *(params.request_id->dataWithOffset<int64_t>(batch_id));
auto request_blocks = std::make_shared<RequestBlockBuffer>(std::to_string(request_id), createEvent());
auto lm_index = *(context_lm_output_indexes->dataWithOffset<int32_t>(batch_id));
auto token_num = lm_index + 1 - params.decoder_batch_size - last_token_index;
auto hidden_states = params.hidden_states->slice(last_token_index, token_num);
RTP_LLM_LOG_DEBUG("request_id is %d", request_id);
RTP_LLM_LOG_DEBUG("token_num is %d", token_num);
last_token_index = token_num;
std::shared_ptr<void> hidden_states_addr(hidden_states->data(), [](void* p) { });
request_blocks->addBlock("hidden_states",
hidden_states_addr,
hidden_states->sizeBytes(), true, true);
auto storeCallback = [request_id, params](bool success, CacheStoreErrorCode ec) {
// the captured params copy keeps hidden_states alive until the async store completes
auto keep_life_cycle = params.hidden_states;
(void)keep_life_cycle;
if (!success) {
RTP_LLM_LOG_WARNING("query [%ld], "
"call store kv cache failed, ec is %d, error msg is [%s]",
request_id, ec, ErrorCodeToString(transCacheStoreErrorCode(ec)).c_str());
}
};
cache_store_->store(request_blocks, storeCallback);
}
}
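// Writes the per-layer KV cache blocks of context (prefill) requests into the cache store
// for PD-separated serving. For every block (including reused prefix blocks) the k/v cache
// addresses, plus optional quantization scales, are registered under keys derived from the
// cache key and layer id; MLA KV cache stores only the k side.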
void DeviceBase::writeCacheStore(const WriteCacheParams& params) {
auto& param = params.common;
if (param.warmup) {
RTP_LLM_LOG_DEBUG("is warmup, so ignore writeCacheStore");
return;
}
if (!param.pd_separation || param.context_batch_size == 0) {
RTP_LLM_LOG_DEBUG("pd_separation = %d, context_batch_size = %d, so ignore writeCacheStore",
param.pd_separation, param.context_batch_size);
return;
}
RTP_LLM_CHECK_WITH_INFO(param.cache_store_inputs.has_value()
&& param.cache_store_inputs->host_kv_cache_offset, "failed to get host_kv_cache_offset");
const auto max_blocks_per_batch = param.cache_store_inputs->host_kv_cache_offset->shape()[1];
const auto seq_size_per_block = params.configs.tokens_per_block;
auto offset_addr = param.cache_store_inputs->host_kv_cache_offset->data<int32_t>();
auto k_cache_data = (uint64_t*)param.kv_cache->k_cache_buffer->data();
auto v_cache_data = (uint64_t*)param.kv_cache->v_cache_buffer->data();
auto k_scale_data = (uint64_t*)(param.kv_cache->k_scale_buffer ? param.kv_cache->k_scale_buffer->data() : nullptr);
auto v_scale_data = (uint64_t*)(param.kv_cache->v_scale_buffer ? param.kv_cache->v_scale_buffer->data() : nullptr);
RTP_LLM_CHECK_WITH_INFO(param.context_batch_size == param.request_pd_separation->size(), "size not same");
RTP_LLM_CHECK_WITH_INFO(param.context_batch_size == param.request_id->size(),
"context batch size and request id size is not same");
RTP_LLM_LOG_DEBUG("write cache store, context_batch_size is %ld", param.context_batch_size);
for (size_t batch_id = 0; batch_id < param.context_batch_size; batch_id++) {
if (*(param.request_pd_separation->dataWithOffset<bool>(batch_id)) == false) {
continue;
}
RTP_LLM_CHECK_WITH_INFO(param.cache_store_inputs.has_value()
&& param.cache_store_inputs->prefix_lengths_host
&& param.cache_store_inputs->input_lengths_host,
"failed to get prefix_length_host and input_length_host for cache store");
RTP_LLM_CHECK_WITH_INFO(param.cache_store_inputs->prefix_lengths_host->data<int>()[batch_id] % seq_size_per_block == 0,
"prefix_length %% seq_size_per_block != 0");
int reuse_block_num = param.cache_store_inputs->prefix_lengths_host->data<int>()[batch_id] / seq_size_per_block;
int block_num = (param.cache_store_inputs->input_lengths_host->data<int>()[param.decoder_batch_size + batch_id]
+ seq_size_per_block - 1) / seq_size_per_block;
auto request_id = *(param.request_id->dataWithOffset<int64_t>(batch_id));
auto request_blocks = std::make_shared<RequestBlockBuffer>(std::to_string(request_id), createEvent());
RTP_LLM_LOG_DEBUG("write cache store, blocks num is %ld", block_num + reuse_block_num);
for (size_t index = 0; index < block_num + reuse_block_num; index++) {
auto cache_key = makeCacheKey(param.cache_keys[batch_id * max_blocks_per_batch + index], param.layer_id);
auto block_id = *(offset_addr + (param.decoder_batch_size + batch_id) * max_blocks_per_batch + index);
void* k_addr = (void*)((int8_t*)k_cache_data + block_id * param.k_block_size);
std::shared_ptr<void> k_block_addr(k_addr, [](void* p) { });
request_blocks->addBlock("k_" + cache_key, k_block_addr, param.k_block_size, true, true);
if (k_scale_data) {
void* k_scale_addr = (void*)((int8_t*)k_scale_data + block_id * param.scale_block_size);
std::shared_ptr<void> k_scale_block_addr(k_scale_addr, [](void* p) { });
request_blocks->addBlock("k_scale" + cache_key, k_scale_block_addr, param.scale_block_size, true, true);
}
// MLA KV cache does not store v_cache
if (params.mla_kvcache) {
continue;
}
void* v_addr = (void*)((int8_t*)v_cache_data + block_id * param.v_block_size);
std::shared_ptr<void> v_block_addr(v_addr, [](void* p) { });
request_blocks->addBlock("v_" + cache_key, v_block_addr, param.v_block_size, true, true);
if (v_scale_data) {
void* v_scale_addr = (void*)((int8_t*)v_scale_data + block_id * param.scale_block_size);
std::shared_ptr<void> v_scale_block_addr(v_scale_addr, [](void* p) { });
request_blocks->addBlock("v_scale" + cache_key, v_scale_block_addr, param.scale_block_size, true, true);
}
}
auto storeCallback = [layer_id = param.layer_id, request_id](bool success, CacheStoreErrorCode ec) {
if (!success) {
RTP_LLM_LOG_WARNING("query [%ld], layer id [%d], "
"call store kv cache failed, ec is %d, error msg is [%s]",
request_id, layer_id, ec, ErrorCodeToString(transCacheStoreErrorCode(ec)).c_str());
}
};
cache_store_->store(request_blocks, storeCallback);
}
}
CloneOutput DeviceBase::clone(const CloneParams& params) {
const auto& src = params.input;
auto dst = allocateBufferLike(src, params.alloc_type, params.hints);
copy({*dst, src, params.overlapped});
return dst;
}
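// Generic gather along `dim`: for each entry of the 1-D int32 CPU index buffer, copies the
// corresponding slice of `input` into a newly allocated output buffer, one contiguous chunk
// at a time.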
SelectOutput DeviceBase::select(const SelectParams& params) {
RUNTIME_ASSERT_OP_ARG(params.dim < params.input.shape().size(),
"Select dim %ld out of range with input shape %s.",
params.dim, params.input.debugString().c_str());
RUNTIME_ASSERT_OP_ARG(params.index.shape().size() == 1, "Select index must be 1D.");
RUNTIME_ASSERT_OP_ARG(params.index.type() == DataType::TYPE_INT32, "Select index must be int32.");
RUNTIME_ASSERT_OP_ARG(params.index.where() != MemoryType::MEMORY_GPU, "Select index must be on CPU.");
const auto& src = params.input;
const auto& idx_buf = params.index;
const auto dim = params.dim;
auto selected_shape = src.shape();
selected_shape[dim] = idx_buf.shape()[0];
auto selected = allocateBuffer({src.type(), selected_shape, getMemAllocationType(src.where())});
const int pre_select_size = std::accumulate(
selected_shape.begin(), selected_shape.begin() + dim, 1UL, std::multiplies<size_t>());
const auto post_select_stride = (int32_t)std::accumulate(
selected_shape.begin() + dim + 1, selected_shape.end(), 1UL, std::multiplies<size_t>());
// both src and dst needs to be viewed into 1-d buffer.
auto src_view = src.reshape({src.size()});
auto dst_view = selected->reshape({selected->size()});
for (auto i = 0; i < int(idx_buf.shape()[0]); i++) {
const auto idx = idx_buf.data<int32_t>()[i];
for (auto j = 0; j < pre_select_size; j++) {
const auto src_offset = j * src.shape()[dim] * post_select_stride + idx * post_select_stride;
const auto dst_offset = j * idx_buf.size() * post_select_stride + i * post_select_stride;
copy({dst_view.view(dst_offset, post_select_stride), src_view.view(src_offset, post_select_stride)});
}
}
return selected;
}
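// Concatenates the inputs along dim 0 into a newly allocated buffer after checking that
// shapes (except dim 0) and data types match. A single input is returned as-is.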
ConcatOutput DeviceBase::concat(const ConcatParams& params) {
RUNTIME_ASSERT_OP_ARG(params.dim == 0, "Concat only supports dim 0, but got %lu.", params.dim);
RUNTIME_ASSERT_OP_ARG(params.inputs.size() > 0, "Concat requires at least 1 input.");
if (params.inputs.size() == 1) {
return params.inputs[0];
}
const auto concated_length = std::accumulate(
params.inputs.begin(), params.inputs.end(), 0UL,
[](size_t sum, const BufferPtr& buffer) {
return sum + buffer->shape()[0];
});
auto concated_shape = params.inputs[0]->shape();
concated_shape[0] = concated_length;
const auto type = params.inputs[0]->type();
auto concated = allocateBuffer({
type, concated_shape, getMemAllocationType(params.inputs[0]->where())});
size_t offset = 0;
for (int i = 0; i < int(params.inputs.size()); i++) {
const auto& input = params.inputs[i];
const auto& shape = input->shape();
RUNTIME_ASSERT_OP_ARG(
shape.size() == concated_shape.size(),
"Concat input [%d] shape size %ld does not match concated shape size %lu.",
i, shape.size(), concated_shape.size());
for (int j = 1; j < int(concated_shape.size()); j++) {
RUNTIME_ASSERT_OP_ARG(
shape[j] == concated_shape[j],
"Concat input [%d] shape[%d] %ld does not match concated shape[%d] %ld.",
i, j, shape[j], j, concated_shape[j]);
}
RUNTIME_ASSERT_OP_ARG(
input->type() == type,
"Concat input [%d] type %d does not match concated type %d.",
i, input->type(), type);
copy({concated->view(offset, (int64_t)shape[0]), *input});
offset += shape[0];
}
return concated;
}
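// Base implementation of split: views the input as a torch tensor, splits it with torch's
// split_with_sizes along `dim`, and clones every chunk into its own buffer.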
SplitOutput DeviceBase::split(const SplitParams& params) {
RUNTIME_ASSERT_OP_ARG(params.dim < params.input.dim()
&& std::accumulate(params.split_sizes.begin(), params.split_sizes.end(), 0)
== params.input.shape()[params.dim],
"split params args error, dim [%ld] split_size_sum [%d] input[%s]",
params.dim,
std::accumulate(params.split_sizes.begin(), params.split_sizes.end(), 0),
params.input.debugString().c_str());
RUNTIME_ASSERT_OP_ARG(!params.overlapped, "split base impl does not support overlap");
torch::Tensor input_t = Buffer2torchTensor(params.input, false);
at::IntArrayRef split_sizes((int64_t*)params.split_sizes.data(), params.split_sizes.size());
std::vector<torch::Tensor> outputs_t = input_t.split_with_sizes(split_sizes, params.dim);
assert(params.split_sizes.size() == outputs_t.size());
std::vector<BufferPtr> outputs;
for (int i = 0; i < params.split_sizes.size(); ++i) {
outputs.emplace_back(clone({*torchTensor2Buffer(outputs_t[i].contiguous())}));
}
return {outputs};
}
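// Computes per-token cross-entropy loss between logits and labels via torch (no reduction)
// and returns the result as a float32 buffer.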
LossOutput DeviceBase::loss(const LossParams& params) {
RUNTIME_ASSERT_OP_ARG(params.logits.where() == params.labels.where(),
"logits and labels must be on the same device, but got %d and %d.",
(int)params.logits.where(), (int)params.labels.where());
RUNTIME_ASSERT_OP_ARG(params.logits.shape()[0] == params.labels.shape()[0],
"logits and labels must have the same dim0, but got %d and %d.",
(int)params.logits.shape()[0], (int)params.labels.shape()[0]);
torch::Tensor logits = Buffer2torchTensor(params.logits, false);
torch::Tensor labels = Buffer2torchTensor(params.labels, false).toType(torch::kInt64);
torch::Tensor output;
output = torch::cross_entropy_loss(logits, labels, torch::nullopt, at::Reduction::None).to(torch::TensorOptions(torch::kFloat32));
return clone({*torchTensor2Buffer(output)});
}
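// Builds a per-batch attention mask on CPU from input lengths: lower-triangular when causal,
// with padded positions zeroed. When prefix (reuse) lengths are given, extra columns are
// appended and rolled so reused prefix tokens remain visible.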
MaskOutput DeviceBase::attentionMask(const MaskParams& params) {
const int *input_lengths = params.input_lengths.data<int32_t>();
const int batch_size = params.input_lengths.size();
const int max_input_seq_len = *std::max_element(input_lengths, input_lengths + batch_size);
const auto torch_type = dataTypeToTorchType(params.dtype);
auto tensor_options = torch::TensorOptions(torch::kBool).device(torch::Device(torch::kCPU));
auto attention_mask = torch::ones({(int)max_input_seq_len, (int)max_input_seq_len}, tensor_options);
if (params.is_causal) {
attention_mask = attention_mask.tril();
}
attention_mask = attention_mask.unsqueeze_(0).tile({(int)batch_size, 1, 1}).to(torch_type);
for (int i = 0; i < batch_size; ++i) {
attention_mask[i].slice(0, input_lengths[i], max_input_seq_len) = 0;
if (!params.is_causal) {
attention_mask[i].slice(1, input_lengths[i], max_input_seq_len) = 0;
}
}
if (params.prefix_lengths.size()) {
RTP_LLM_CHECK(int(params.prefix_lengths.size()) == batch_size);
const int *prefix_lengths = params.prefix_lengths.data<int32_t>();
auto max_reuse_length = *std::max_element(prefix_lengths, prefix_lengths + batch_size);
attention_mask = torch::cat({attention_mask, torch::zeros({(int)batch_size, max_input_seq_len, max_reuse_length}).to(torch_type)}, -1);
if (max_reuse_length) {
for (int i = 0; i < batch_size; ++i) {
attention_mask[i] = attention_mask[i].roll({prefix_lengths[i]}, {-1});
attention_mask[i].slice(0, 0, input_lengths[i]).slice(1, 0, prefix_lengths[i]) = 1;
}
}
}
return clone({*torchTensor2Buffer(attention_mask)});
}
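// Scatters multimodal feature rows into the word-embedding buffer at the locations given by
// multimodal_locs, then returns the (in-place updated) embeddings.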
MultimodalEmbeddingOutput DeviceBase::multimodalEmbedding(const MultimodalEmbeddingParams& params) {
RUNTIME_ASSERT_OP_ARG(params.multimodal_locs, "no multimodal input location found");
const auto& embeddings = params.word_embeddings;
const auto& features = params.multimodal_features.value().get();
const auto& multimodal_locs = params.multimodal_locs.value().get();
const auto mm_num = features.size();
RUNTIME_ASSERT_OP_ARG(
embeddings->typeSize() == features[0]->typeSize(),
"type size of embeddings and multimodal features should be equal.");
for (int i = 0; i < mm_num; ++i) {
auto& feature = features[i];
auto loc = multimodal_locs.dataWithOffset<int32_t>(i);
copy({embeddings->view(*loc, feature->shape()[0]), *feature});
}
return embeddings;
}
AllReduceOutput DeviceBase::allReduce(const AllReduceParams& params) {
if (getDeviceProperties().tp_size == 1) {
return AllReduceOutput({params.buffer});
};
throw OpException(OpErrorType::ERROR_UNIMPLEMENTED);
}
void DeviceBase::prepareCommBuffer(const PrepareCommBufferParams& params) {}
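// Allocates device-side counters for MoE expert statistics: per-layer logical-expert hit
// counts and per-layer GPU (EP rank) loads, zero-initialized before use.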
OverallExpertStats DeviceBase::createMoeExpertStates(const ExpertStatsParams& params)
{
OverallExpertStats states;
auto const layer_num = params.layer_num;
auto const logic_expert_num = params.log_exp_num;
auto const physic_expert_num = params.phy_exp_num;
auto const ep_size = params.ep_size;
states.layer_num = layer_num;
states.ep_size = ep_size;
states.log_exp_num = logic_expert_num;
states.phy_exp_num = physic_expert_num;
auto logic_buff = allocateBuffer({DataType::TYPE_INT32, {layer_num, logic_expert_num}, AllocationType::DEVICE}, {"exp_log_cnt"});
auto gpu_load_buff = allocateBuffer({DataType::TYPE_INT32, {layer_num, ep_size}, AllocationType::DEVICE}, {"phy_gpu_load"});
states.stats_buf.log_stats_buf = logic_buff;
states.stats_buf.gpu_loads_buf = gpu_load_buff;
cleanMoeExpertStates(states);
return states;
}
void DeviceBase::cleanMoeExpertStates(const OverallExpertStats& stats)
{
bufMemset(*stats.stats_buf.log_stats_buf, 0);
bufMemset(*stats.stats_buf.gpu_loads_buf, 0);
}
void DeviceBase::updateExpertGpuLoads(const MoeConfigs& moe_conf, const OptionalExpertStats& expert_stats, BufferPtr expert_ids) {
throw OpException(OpErrorType::ERROR_UNIMPLEMENTED);
}
} // namespace rtp_llm