void dpAndTpSyncModelInputs()

in maga_transformer/cpp/models/GptModel.cc [1256:1406]


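// Synchronizes GptModelInputs across tensor-parallel ranks in two phases:
// first rank 0 broadcasts per-input size/dtype hints so every rank can size
// its buffers, then non-root ranks allocate matching buffers and all ranks
// broadcast the actual contents.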
void dpAndTpSyncModelInputs(GptModelInputs &inputs, rtp_llm::DeviceBase* device) {
    if (device->getDeviceProperties().tp_size <= 1) {
        return;
    }
    const size_t shape_hints_size = GptModelInputIndex::gptModelInputLength;
    auto shape_hints = device->allocateBuffer({rtp_llm::DataType::TYPE_INT32, {shape_hints_size}, rtp_llm::AllocationType::HOST});
    auto shape_hints_ptr = shape_hints->data<int32_t>();
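    // Each rank records its local element counts (plus dtype tags); the
    // broadcast below replaces them with rank 0's values on every other rank.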
    shape_hints_ptr[GptModelInputIndex::comboTokens] = inputs.combo_tokens.get() ? inputs.combo_tokens->size() : 0;
    shape_hints_ptr[GptModelInputIndex::inputLengths] = inputs.input_lengths.get() ? inputs.input_lengths->size() : 0;
    shape_hints_ptr[GptModelInputIndex::sequenceLengths] = inputs.sequence_lengths.get() ? inputs.sequence_lengths->size() : 0;
    shape_hints_ptr[GptModelInputIndex::prefixLengths] = inputs.prefix_lengths.get() ? inputs.prefix_lengths->size() : 0;
    shape_hints_ptr[GptModelInputIndex::maxBlocksPerBatch] = inputs.kv_cache_block_id.get() ? inputs.kv_cache_block_id->shape()[1] : 0;
    shape_hints_ptr[GptModelInputIndex::lmOutputIndexes] = inputs.lm_output_indexes.get() ? inputs.lm_output_indexes->size() : 0;
    shape_hints_ptr[GptModelInputIndex::comboPositionIds] = inputs.combo_position_ids.get() ? inputs.combo_position_ids->size() : 0;
    shape_hints_ptr[GptModelInputIndex::loraIds] = inputs.lora_ids.get() ? inputs.lora_ids->size() : 0;
    shape_hints_ptr[GptModelInputIndex::loraInputLengths] = inputs.lora_input_lengths.get() ? inputs.lora_input_lengths->size() : 0;
    shape_hints_ptr[GptModelInputIndex::textTokensMask] = inputs.text_tokens_mask.get() ? inputs.text_tokens_mask->size() : 0;
    shape_hints_ptr[GptModelInputIndex::mmFeaturesLocs] = inputs.mm_features_locs.get() ? inputs.mm_features_locs->size() : 0;
    shape_hints_ptr[GptModelInputIndex::mmFeaturesNum] = inputs.multimodal_features.has_value() ? inputs.multimodal_features.value().size() : 0;
    shape_hints_ptr[GptModelInputIndex::mmFeaturesSize] = shape_hints_ptr[GptModelInputIndex::mmFeaturesNum] ? inputs.multimodal_features.value()[0]->shape()[1] : 0;
    shape_hints_ptr[GptModelInputIndex::mmFeaturesDtype] = shape_hints_ptr[GptModelInputIndex::mmFeaturesNum] ? (std::uint8_t)inputs.multimodal_features.value()[0]->type() : 0;
    shape_hints_ptr[GptModelInputIndex::needAllLogits] = inputs.need_all_logits;
    shape_hints_ptr[GptModelInputIndex::mtpHiddenStates] = inputs.last_hidden_states.get() ? inputs.last_hidden_states->size() : 0;
    shape_hints_ptr[GptModelInputIndex::mtpHiddenStatesDtype] = shape_hints_ptr[GptModelInputIndex::mtpHiddenStates] ? (std::uint8_t)inputs.last_hidden_states->type() : 0;
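    // Broadcast the hints from rank 0 so all ranks agree on sizes and dtypes.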
    device->broadcast({{shape_hints}, 0});
    device->syncCommunication(false);
    device->syncAndCheck();

    inputs.need_all_logits = shape_hints_ptr[GptModelInputIndex::needAllLogits];

    // multimodal features shape broadcast
    rtp_llm::BufferPtr mm_features_shape;
    int32_t* mm_features_shape_ptr = nullptr;
    const size_t mm_features_num = shape_hints_ptr[GptModelInputIndex::mmFeaturesNum];
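    // Rank 0 publishes each feature's first dimension; on other ranks the
    // loop writes placeholder zeros that the broadcast overwrites.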
    if (mm_features_num) {
        mm_features_shape =
            device->allocateBuffer({rtp_llm::DataType::TYPE_INT32, {mm_features_num}, rtp_llm::AllocationType::HOST});
        mm_features_shape_ptr = mm_features_shape->data<int32_t>();
        for (size_t i = 0; i < mm_features_num; ++i) {
            mm_features_shape_ptr[i] = inputs.multimodal_features.has_value() ? inputs.multimodal_features.value()[i]->shape()[0] : 0;
        }
        device->broadcast({{mm_features_shape}, 0});
        device->syncCommunication(false);
        device->syncAndCheck();
    }

    auto max_blocks = (size_t)shape_hints_ptr[GptModelInputIndex::maxBlocksPerBatch];
    auto combo_position_ids_size = shape_hints_ptr[GptModelInputIndex::comboPositionIds];
    auto text_tokens_mask_size = shape_hints_ptr[GptModelInputIndex::textTokensMask];
    auto mm_features_locs_size = shape_hints_ptr[GptModelInputIndex::mmFeaturesLocs];
    auto hidden_states_size = shape_hints_ptr[GptModelInputIndex::mtpHiddenStates];

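    // Non-root ranks allocate receive buffers shaped by the broadcast hints;
    // rank 0 keeps its original buffers as the broadcast source.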
    if (device->getDeviceProperties().tp_rank) {
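        // One prefix length per context (prefill) request, so this hint also
        // gives the context batch size used for the per-request buffers below.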
        auto context_batch_size = (size_t)shape_hints_ptr[GptModelInputIndex::prefixLengths];

        inputs.combo_tokens = device->allocateBuffer(
            {rtp_llm::DataType::TYPE_INT32, {(size_t)shape_hints_ptr[GptModelInputIndex::comboTokens]}, rtp_llm::AllocationType::HOST});
        inputs.input_lengths = device->allocateBuffer(
            {rtp_llm::DataType::TYPE_INT32, {(size_t)shape_hints_ptr[GptModelInputIndex::inputLengths]}, rtp_llm::AllocationType::HOST});
        inputs.sequence_lengths = device->allocateBuffer(
            {rtp_llm::DataType::TYPE_INT32, {(size_t)shape_hints_ptr[GptModelInputIndex::sequenceLengths]}, rtp_llm::AllocationType::HOST});
        inputs.prefix_lengths = device->allocateBuffer(
            {rtp_llm::DataType::TYPE_INT32, {context_batch_size}, rtp_llm::AllocationType::HOST});
        if (max_blocks != 0) {
            inputs.kv_cache_block_id = device->allocateBuffer(
                    {rtp_llm::DataType::TYPE_INT32,
                    {(size_t)shape_hints_ptr[GptModelInputIndex::inputLengths], max_blocks}, rtp_llm::AllocationType::HOST});
            inputs.cache_keys = device->allocateBuffer(
                    {rtp_llm::DataType::TYPE_INT64, {context_batch_size, max_blocks}, rtp_llm::AllocationType::HOST});
        }
        inputs.request_id = device->allocateBuffer(
            {rtp_llm::DataType::TYPE_INT64, {context_batch_size}, rtp_llm::AllocationType::HOST});
        inputs.request_pd_separation = device->allocateBuffer(
            {rtp_llm::DataType::TYPE_BOOL, {context_batch_size}, rtp_llm::AllocationType::HOST});
        inputs.lm_output_indexes = device->allocateBuffer(
            {rtp_llm::DataType::TYPE_INT32, {(size_t)shape_hints_ptr[GptModelInputIndex::lmOutputIndexes]}, rtp_llm::AllocationType::HOST});
        if (combo_position_ids_size) {
            inputs.combo_position_ids = device->allocateBuffer(
                {rtp_llm::DataType::TYPE_INT32, {(size_t)combo_position_ids_size}, rtp_llm::AllocationType::HOST});
        }
        if (shape_hints_ptr[GptModelInputIndex::loraIds]) {
            inputs.lora_ids = device->allocateBuffer(
                {rtp_llm::DataType::TYPE_INT32, {(size_t)shape_hints_ptr[GptModelInputIndex::loraIds]}, rtp_llm::AllocationType::HOST});
        }
        if (shape_hints_ptr[GptModelInputIndex::loraInputLengths]) {
            inputs.lora_input_lengths = device->allocateBuffer(
                {rtp_llm::DataType::TYPE_INT32, {(size_t)shape_hints_ptr[GptModelInputIndex::loraInputLengths]}, rtp_llm::AllocationType::HOST});
        }
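        // Only the flattened element count of the hidden states was broadcast;
        // recover the [token, hidden] shape from the combo token count.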
        if (hidden_states_size) {
            auto hidden_states_dim0 = (size_t)shape_hints_ptr[GptModelInputIndex::comboTokens];
            RTP_LLM_CHECK(hidden_states_dim0 > 0 && hidden_states_size % hidden_states_dim0 == 0);
            auto hidden_states_dim1 = (size_t)hidden_states_size / hidden_states_dim0;
            inputs.last_hidden_states = device->allocateBuffer(
                {(rtp_llm::DataType)shape_hints_ptr[GptModelInputIndex::mtpHiddenStatesDtype],
                 {hidden_states_dim0, hidden_states_dim1},
                 rtp_llm::AllocationType::DEVICE});
        }
        if (text_tokens_mask_size) {
            inputs.text_tokens_mask = device->allocateBuffer(
                {rtp_llm::DataType::TYPE_INT32, {(size_t)text_tokens_mask_size}, rtp_llm::AllocationType::HOST});
        }
        if (mm_features_locs_size) {
            inputs.mm_features_locs = device->allocateBuffer(
                {rtp_llm::DataType::TYPE_INT32, {(size_t)mm_features_locs_size}, rtp_llm::AllocationType::HOST});
        }
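        // Multimodal features are device tensors; allocate one per feature
        // using the per-feature row counts broadcast earlier.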
        if (mm_features_num) {
            std::vector<rtp_llm::BufferPtr> mm_features;
            for (size_t mm_index = 0; mm_index < mm_features_num; ++mm_index) {
                mm_features.emplace_back(
                    device->allocateBuffer(
                        {(rtp_llm::DataType)shape_hints_ptr[GptModelInputIndex::mmFeaturesDtype],
                         {(size_t)mm_features_shape_ptr[mm_index], (size_t)shape_hints_ptr[GptModelInputIndex::mmFeaturesSize]},
                         rtp_llm::AllocationType::DEVICE}));
            }
            inputs.multimodal_features = std::move(mm_features);
        }
    }

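    // Broadcast the actual tensor contents in one batched call; the buffer
    // list must be assembled in the same order on every rank.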
    std::vector<rtp_llm::BufferPtr> buffers;
    buffers.emplace_back(inputs.combo_tokens);
    buffers.emplace_back(inputs.input_lengths);
    buffers.emplace_back(inputs.sequence_lengths);
    buffers.emplace_back(inputs.prefix_lengths);
    if (max_blocks) {
        buffers.emplace_back(inputs.kv_cache_block_id);
        buffers.emplace_back(inputs.cache_keys);
    }
    buffers.emplace_back(inputs.request_id);
    buffers.emplace_back(inputs.request_pd_separation);
    buffers.emplace_back(inputs.lm_output_indexes);
    if (combo_position_ids_size) {
        buffers.emplace_back(inputs.combo_position_ids);
    }
    buffers.emplace_back(inputs.lora_ids);
    buffers.emplace_back(inputs.lora_input_lengths);
    if (text_tokens_mask_size) {
        buffers.emplace_back(inputs.text_tokens_mask);
    }
    if (mm_features_locs_size) {
        buffers.emplace_back(inputs.mm_features_locs);
    }
    if (mm_features_num) {
        for (auto& mm_feature: inputs.multimodal_features.value()) {
            buffers.emplace_back(mm_feature);
        }
    }
    if (hidden_states_size) {
        buffers.emplace_back(inputs.last_hidden_states);
    }
    device->broadcast({buffers, 0});
    device->syncAndCheck();
}