banyand/queue/sub/sub.go (164 lines of code) (raw):
// Licensed to Apache Software Foundation (ASF) under one or more contributor
// license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright
// ownership. Apache Software Foundation (ASF) licenses this file to you under
// the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
// Package sub implements the queue server.
package sub
import (
"fmt"
"io"
"time"
"github.com/pkg/errors"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/anypb"
"github.com/apache/skywalking-banyandb/api/common"
"github.com/apache/skywalking-banyandb/api/data"
clusterv1 "github.com/apache/skywalking-banyandb/api/proto/banyandb/cluster/v1"
"github.com/apache/skywalking-banyandb/pkg/bus"
"github.com/apache/skywalking-banyandb/pkg/logger"
)
// Send serves the clusterv1 Send stream. It keeps reading requests until the
// client closes the stream (io.EOF), Recv fails, or the stream context is done.
//
// The first request carrying a non-empty Topic fixes the topic for every later
// request on the same stream. Each request body is unmarshalled into the request
// type that data.TopicRequestMap registers for that topic. Requests with BatchMod
// set are handed to handleBatch. Any other request is dispatched to the single
// listener subscribed to the topic, and the listener's result is sent back as a
// SendResponse with the same MessageId. Per-request failures are answered
// through reply and do not end the stream.
//
// It returns nil after io.EOF (once handleEOF has run), ctx.Err() when the
// context is done, and the result of handleRecvError for any other Recv error.
// It calls logger.Panicf if more than one listener is subscribed to the topic.
func (s *server) Send(stream clusterv1.Service_SendServer) error {
	ctx := stream.Context()
	var topic *bus.Topic
	var m bus.Message
	var dataCollection []any
	start := time.Now()
	// Stream-level metrics are labelled by topic, so they are only recorded
	// once a valid topic has been resolved.
	defer func() {
		if topic != nil {
			s.metrics.totalFinished.Inc(1, topic.String())
			s.metrics.totalLatency.Inc(time.Since(start).Seconds(), topic.String())
		}
	}()
	for {
		// Non-blocking cancellation check before each Recv.
		select {
		case <-ctx.Done():
			return ctx.Err()
		default:
		}
		writeEntity, err := stream.Recv()
		if errors.Is(err, io.EOF) {
			s.handleEOF(stream, topic, dataCollection, writeEntity)
			return nil
		}
		if err != nil {
			return s.handleRecvError(err)
		}
		s.metrics.totalMsgReceived.Inc(1, writeEntity.Topic)
		// The first request that names a topic fixes it for the whole stream.
		if writeEntity.Topic != "" && topic == nil {
			t, ok := data.TopicMap[writeEntity.Topic]
			if !ok {
				s.reply(stream, writeEntity, nil, "invalid topic")
				continue
			}
			topic = &t
		}
		if topic == nil {
			s.reply(stream, writeEntity, nil, "topic is empty")
			continue
		}
		reqSupplier, ok := data.TopicRequestMap[*topic]
		if !ok {
			s.reply(stream, writeEntity, nil, "unknown topic")
			continue
		}
		req := reqSupplier()
		if errUnmarshal := writeEntity.Body.UnmarshalTo(req); errUnmarshal != nil {
			s.reply(stream, writeEntity, errUnmarshal, "failed to unmarshal message")
			continue
		}
		m = bus.NewMessage(bus.MessageID(writeEntity.MessageId), req)
		if writeEntity.BatchMod {
			s.handleBatch(&dataCollection, writeEntity, &start)
			continue
		}
		s.metrics.totalStarted.Inc(1, writeEntity.Topic)
		listeners := s.getListeners(*topic)
		if len(listeners) == 0 {
			s.reply(stream, writeEntity, nil, "no listener found")
			continue
		}
		if len(listeners) > 1 {
			logger.Panicf("multiple listeners found for topic %s", *topic)
		}
		listener := listeners[0]
		m = listener.Rev(ctx, m)
		// A nil payload means there is nothing to return: acknowledge with a
		// response that carries only the MessageId.
		if m.Data() == nil {
			if errSend := stream.Send(&clusterv1.SendResponse{
				MessageId: writeEntity.MessageId,
			}); errSend != nil {
				s.log.Error().Stringer("request", writeEntity).Err(errSend).Msg("failed to send empty response")
				s.metrics.totalMsgSentErr.Inc(1, writeEntity.Topic)
				continue
			}
			s.metrics.totalMsgSent.Inc(1, writeEntity.Topic)
			continue
		}
		var message proto.Message
		switch d := m.Data().(type) {
		case proto.Message:
			message = d
		case *common.Error:
			// If the stream context is already done, stop instead of replying.
			select {
			case <-ctx.Done():
				s.metrics.totalMsgReceivedErr.Inc(1, writeEntity.Topic)
				return ctx.Err()
			default:
			}
			// Pass d as the error so reply can forward its status, not only its text.
			s.reply(stream, writeEntity, d, d.Error())
			continue
		default:
			s.reply(stream, writeEntity, nil, fmt.Sprintf("invalid response: %T", d))
			continue
		}
		anyMessage, err := anypb.New(message)
		if err != nil {
			s.reply(stream, writeEntity, err, "failed to marshal message")
			continue
		}
		if errSend := stream.Send(&clusterv1.SendResponse{
			MessageId: writeEntity.MessageId,
			Body:      anyMessage,
		}); errSend != nil {
			s.log.Error().Stringer("request", writeEntity).Dur("latency", time.Since(start)).Err(errSend).Msg("failed to send query response")
			s.metrics.totalMsgSentErr.Inc(1, writeEntity.Topic)
			continue
		}
		s.metrics.totalMsgSent.Inc(1, writeEntity.Topic)
	}
}
// Subscribe registers listener for topic while holding the listeners write lock.
// Listeners for a topic are kept in registration order. The first registration
// of a topic also records it in topicMap under topic.String(). Send calls
// logger.Panicf if it finds more than one listener for a topic.
// It always returns nil.
func (s *server) Subscribe(topic bus.Topic, listener bus.MessageListener) error {
	s.listenersLock.Lock()
	defer s.listenersLock.Unlock()
	registered, known := s.listeners[topic]
	if !known {
		// First subscriber for this topic: start a fresh slice and index the topic by name.
		registered = make([]bus.MessageListener, 0)
		s.topicMap[topic.String()] = topic
	}
	s.listeners[topic] = append(registered, listener)
	return nil
}
// getListeners returns the listeners registered for topic, or nil when there
// are none. The map is read under the listeners read lock. The returned slice
// is the one stored in the map, so callers should treat it as read-only.
func (s *server) getListeners(topic bus.Topic) []bus.MessageListener {
	s.listenersLock.RLock()
	found := s.listeners[topic]
	s.listenersLock.RUnlock()
	return found
}
// reply logs a failed request, counts it in totalMsgReceivedErr, and sends an
// error response for writeEntity back on stream.
//
// If err wraps a *common.Error, the response carries that error's text and
// status. Otherwise the response's Error field is message. err may be nil.
// A failure to send the response is logged and counted in totalMsgSentErr;
// it is not returned.
func (s *server) reply(stream clusterv1.Service_SendServer, writeEntity *clusterv1.SendRequest, err error, message string) {
	s.log.Error().Stringer("request", writeEntity).Err(err).Msg(message)
	s.metrics.totalMsgReceivedErr.Inc(1, writeEntity.Topic)
	resp := &clusterv1.SendResponse{
		MessageId: writeEntity.MessageId,
	}
	var ce *common.Error
	if errors.As(err, &ce) {
		resp.Error = ce.Error()
		resp.Status = ce.Status()
	} else {
		resp.Error = message
	}
	// Send resp itself so the status extracted above reaches the client.
	if errResp := stream.Send(resp); errResp != nil {
		s.log.Error().Err(errResp).AnErr("original", err).Stringer("request", writeEntity).Msg("failed to send error response")
		s.metrics.totalMsgSentErr.Inc(1, writeEntity.Topic)
	}
}