in maga_transformer/cpp/models/GptModel.cc [1256:1406]
// Synchronizes GptModelInputs across tensor-parallel ranks: TP rank 0
// broadcasts every input buffer so all ranks run the forward pass on
// identical data.
//
// Protocol (all broadcasts rooted at TP rank 0):
//   1. broadcast a fixed-size int32 vector of "shape hints" (element counts
//      and dtypes of each field);
//   2. if multimodal features are present, broadcast their per-feature row
//      counts (widths and dtype already travelled in the shape hints);
//   3. non-root ranks allocate receive buffers sized from the hints, then all
//      ranks broadcast the actual payload buffers in a fixed order.
//
// The buffer order in phase 3 is part of the cross-rank protocol and must be
// identical on every rank — do not reorder.
//
// No-op when tp_size <= 1.
void dpAndTpSyncModelInputs(GptModelInputs &inputs, rtp_llm::DeviceBase* device) {
    if (device->getDeviceProperties().tp_size <= 1) {
        return;
    }
    // ---- Phase 1: pack shape hints. On non-root ranks these values are
    // placeholders and are overwritten by the broadcast below. ----
    const size_t shape_hints_size = GptModelInputIndex::gptModelInputLength;
    auto shape_hints = device->allocateBuffer({rtp_llm::DataType::TYPE_INT32, {shape_hints_size}, rtp_llm::AllocationType::HOST});
    auto shape_hints_ptr = shape_hints->data<int32_t>();
    shape_hints_ptr[GptModelInputIndex::comboTokens]     = inputs.combo_tokens.get() ? inputs.combo_tokens->size() : 0;
    shape_hints_ptr[GptModelInputIndex::inputLengths]    = inputs.input_lengths.get() ? inputs.input_lengths->size() : 0;
    shape_hints_ptr[GptModelInputIndex::sequenceLengths] = inputs.sequence_lengths.get() ? inputs.sequence_lengths->size() : 0;
    shape_hints_ptr[GptModelInputIndex::prefixLengths]   = inputs.prefix_lengths.get() ? inputs.prefix_lengths->size() : 0;
    // Only the per-batch block count (dim 1) is sent; dim 0 is recovered from
    // inputLengths on the receiving side.
    shape_hints_ptr[GptModelInputIndex::maxBlocksPerBatch] = inputs.kv_cache_block_id.get() ? inputs.kv_cache_block_id->shape()[1] : 0;
    shape_hints_ptr[GptModelInputIndex::lmOutputIndexes]  = inputs.lm_output_indexes.get() ? inputs.lm_output_indexes->size() : 0;
    shape_hints_ptr[GptModelInputIndex::comboPositionIds] = inputs.combo_position_ids.get() ? inputs.combo_position_ids->size() : 0;
    shape_hints_ptr[GptModelInputIndex::loraIds]          = inputs.lora_ids.get() ? inputs.lora_ids->size() : 0;
    shape_hints_ptr[GptModelInputIndex::loraInputLengths] = inputs.lora_input_lengths.get() ? inputs.lora_input_lengths->size() : 0;
    shape_hints_ptr[GptModelInputIndex::textTokensMask]   = inputs.text_tokens_mask.get() ? inputs.text_tokens_mask->size() : 0;
    shape_hints_ptr[GptModelInputIndex::mmFeaturesLocs]   = inputs.mm_features_locs.get() ? inputs.mm_features_locs->size() : 0;
    shape_hints_ptr[GptModelInputIndex::mmFeaturesNum]    = inputs.multimodal_features.has_value() ? inputs.multimodal_features.value().size() : 0;
    // Feature width (dim 1) and dtype are taken from feature [0] only —
    // assumes they are uniform across all features.
    shape_hints_ptr[GptModelInputIndex::mmFeaturesSize]  = shape_hints_ptr[GptModelInputIndex::mmFeaturesNum] ? inputs.multimodal_features.value()[0]->shape()[1] : 0;
    shape_hints_ptr[GptModelInputIndex::mmFeaturesDtype] = shape_hints_ptr[GptModelInputIndex::mmFeaturesNum] ? (std::uint8_t)inputs.multimodal_features.value()[0]->type() : 0;
    shape_hints_ptr[GptModelInputIndex::needAllLogits]   = inputs.need_all_logits;
    shape_hints_ptr[GptModelInputIndex::mtpHiddenStates] = inputs.last_hidden_states.get() ? inputs.last_hidden_states->size() : 0;
    shape_hints_ptr[GptModelInputIndex::mtpHiddenStatesDtype] = shape_hints_ptr[GptModelInputIndex::mtpHiddenStates] ? (std::uint8_t)inputs.last_hidden_states->type() : 0;
    device->broadcast({{shape_hints}, 0});
    device->syncCommunication(false);
    device->syncAndCheck();
    // From here on, shape_hints_ptr holds rank 0's values on every rank.
    inputs.need_all_logits = shape_hints_ptr[GptModelInputIndex::needAllLogits];
    // ---- Phase 2: broadcast per-feature row counts for multimodal features
    // (each feature has its own dim-0 size). ----
    rtp_llm::BufferPtr mm_features_shape;
    int32_t* mm_features_shape_ptr = nullptr;
    const size_t mm_features_num = shape_hints_ptr[GptModelInputIndex::mmFeaturesNum];
    if (mm_features_num) {
        mm_features_shape =
            device->allocateBuffer({rtp_llm::DataType::TYPE_INT32, {mm_features_num}, rtp_llm::AllocationType::HOST});
        mm_features_shape_ptr = mm_features_shape->data<int32_t>();
        for (size_t i = 0; i < mm_features_num; ++i) {
            // Non-root ranks have no features yet; their zeros are overwritten
            // by the broadcast.
            mm_features_shape_ptr[i] = inputs.multimodal_features.has_value() ? inputs.multimodal_features.value()[i]->shape()[0] : 0;
        }
        device->broadcast({{mm_features_shape}, 0});
        device->syncCommunication(false);
        device->syncAndCheck();
    }
    const auto max_blocks = (size_t)shape_hints_ptr[GptModelInputIndex::maxBlocksPerBatch];
    const auto combo_position_ids_size = shape_hints_ptr[GptModelInputIndex::comboPositionIds];
    const auto text_tokens_mask_size = shape_hints_ptr[GptModelInputIndex::textTokensMask];
    const auto mm_features_locs_size = shape_hints_ptr[GptModelInputIndex::mmFeaturesLocs];
    const auto hidden_states_size = shape_hints_ptr[GptModelInputIndex::mtpHiddenStates];
    // ---- Phase 3a: non-root ranks allocate receive buffers matching the
    // shapes announced by rank 0. ----
    if (device->getDeviceProperties().tp_rank) {
        auto context_batch_size = (size_t)shape_hints_ptr[GptModelInputIndex::prefixLengths];
        inputs.combo_tokens = device->allocateBuffer(
            {rtp_llm::DataType::TYPE_INT32, {(size_t)shape_hints_ptr[GptModelInputIndex::comboTokens]}, rtp_llm::AllocationType::HOST});
        inputs.input_lengths = device->allocateBuffer(
            {rtp_llm::DataType::TYPE_INT32, {(size_t)shape_hints_ptr[GptModelInputIndex::inputLengths]}, rtp_llm::AllocationType::HOST});
        inputs.sequence_lengths = device->allocateBuffer(
            {rtp_llm::DataType::TYPE_INT32, {(size_t)shape_hints_ptr[GptModelInputIndex::sequenceLengths]}, rtp_llm::AllocationType::HOST});
        inputs.prefix_lengths = device->allocateBuffer(
            {rtp_llm::DataType::TYPE_INT32, {context_batch_size}, rtp_llm::AllocationType::HOST});
        if (max_blocks != 0) {
            inputs.kv_cache_block_id = device->allocateBuffer(
                {rtp_llm::DataType::TYPE_INT32,
                 {(size_t)shape_hints_ptr[GptModelInputIndex::inputLengths], max_blocks}, rtp_llm::AllocationType::HOST});
            inputs.cache_keys = device->allocateBuffer(
                {rtp_llm::DataType::TYPE_INT64, {context_batch_size, max_blocks}, rtp_llm::AllocationType::HOST});
        }
        inputs.request_id = device->allocateBuffer(
            {rtp_llm::DataType::TYPE_INT64, {context_batch_size}, rtp_llm::AllocationType::HOST});
        inputs.request_pd_separation = device->allocateBuffer(
            {rtp_llm::DataType::TYPE_BOOL, {context_batch_size}, rtp_llm::AllocationType::HOST});
        inputs.lm_output_indexes = device->allocateBuffer(
            {rtp_llm::DataType::TYPE_INT32, {(size_t)shape_hints_ptr[GptModelInputIndex::lmOutputIndexes]}, rtp_llm::AllocationType::HOST});
        if (combo_position_ids_size) {
            inputs.combo_position_ids = device->allocateBuffer(
                {rtp_llm::DataType::TYPE_INT32, {(size_t)combo_position_ids_size}, rtp_llm::AllocationType::HOST});
        }
        if (shape_hints_ptr[GptModelInputIndex::loraIds]) {
            inputs.lora_ids = device->allocateBuffer(
                {rtp_llm::DataType::TYPE_INT32, {(size_t)shape_hints_ptr[GptModelInputIndex::loraIds]}, rtp_llm::AllocationType::HOST});
        }
        if (shape_hints_ptr[GptModelInputIndex::loraInputLengths]) {
            inputs.lora_input_lengths = device->allocateBuffer(
                {rtp_llm::DataType::TYPE_INT32, {(size_t)shape_hints_ptr[GptModelInputIndex::loraInputLengths]}, rtp_llm::AllocationType::HOST});
        }
        if (hidden_states_size) {
            // The hidden states travel flattened; dim 0 is the combo token
            // count and dim 1 the hidden width. Validate BEFORE dividing:
            // a zero dim0 would be a division by zero (UB), and a non-zero
            // remainder means the announced shapes are inconsistent.
            auto hidden_states_dim0 = (size_t)shape_hints_ptr[GptModelInputIndex::comboTokens];
            RTP_LLM_CHECK(hidden_states_dim0 != 0);
            RTP_LLM_CHECK((size_t)hidden_states_size % hidden_states_dim0 == 0);
            auto hidden_states_dim1 = (size_t)hidden_states_size / hidden_states_dim0;
            inputs.last_hidden_states = device->allocateBuffer(
                {(rtp_llm::DataType)shape_hints_ptr[GptModelInputIndex::mtpHiddenStatesDtype],
                 {hidden_states_dim0, hidden_states_dim1},
                 rtp_llm::AllocationType::DEVICE});
        }
        if (text_tokens_mask_size) {
            inputs.text_tokens_mask = device->allocateBuffer(
                {rtp_llm::DataType::TYPE_INT32, {(size_t)text_tokens_mask_size}, rtp_llm::AllocationType::HOST});
        }
        if (mm_features_locs_size) {
            inputs.mm_features_locs = device->allocateBuffer(
                {rtp_llm::DataType::TYPE_INT32, {(size_t)mm_features_locs_size}, rtp_llm::AllocationType::HOST});
        }
        if (mm_features_num) {
            std::vector<rtp_llm::BufferPtr> mm_features;
            mm_features.reserve(mm_features_num);
            for (size_t mm_index = 0; mm_index < mm_features_num; ++mm_index) {
                mm_features.emplace_back(
                    device->allocateBuffer(
                        {(rtp_llm::DataType)shape_hints_ptr[GptModelInputIndex::mmFeaturesDtype],
                         {(size_t)mm_features_shape_ptr[mm_index], (size_t)shape_hints_ptr[GptModelInputIndex::mmFeaturesSize]},
                         rtp_llm::AllocationType::DEVICE}));
            }
            inputs.multimodal_features = std::move(mm_features);
        }
    }
    // ---- Phase 3b: broadcast the payload. The buffer list must be built in
    // exactly the same order on every rank. ----
    std::vector<rtp_llm::BufferPtr> buffers;
    buffers.emplace_back(inputs.combo_tokens);
    buffers.emplace_back(inputs.input_lengths);
    buffers.emplace_back(inputs.sequence_lengths);
    buffers.emplace_back(inputs.prefix_lengths);
    if (max_blocks) {
        buffers.emplace_back(inputs.kv_cache_block_id);
        buffers.emplace_back(inputs.cache_keys);
    }
    buffers.emplace_back(inputs.request_id);
    buffers.emplace_back(inputs.request_pd_separation);
    buffers.emplace_back(inputs.lm_output_indexes);
    if (combo_position_ids_size) {
        buffers.emplace_back(inputs.combo_position_ids);
    }
    // NOTE(review): lora_ids / lora_input_lengths are appended unconditionally,
    // so they may be null when their hint sizes are 0 — presumably
    // device->broadcast tolerates null/empty entries; verify if this path is
    // ever hit without LoRA.
    buffers.emplace_back(inputs.lora_ids);
    buffers.emplace_back(inputs.lora_input_lengths);
    if (text_tokens_mask_size) {
        buffers.emplace_back(inputs.text_tokens_mask);
    }
    if (mm_features_locs_size) {
        buffers.emplace_back(inputs.mm_features_locs);
    }
    if (mm_features_num) {
        for (auto& mm_feature: inputs.multimodal_features.value()) {
            buffers.emplace_back(mm_feature);
        }
    }
    if (hidden_states_size) {
        buffers.emplace_back(inputs.last_hidden_states);
    }
    device->broadcast({buffers, 0});
    device->syncAndCheck();
}