banyand/queue/pub/pub.go (315 lines of code) (raw):
// Licensed to Apache Software Foundation (ASF) under one or more contributor
// license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright
// ownership. Apache Software Foundation (ASF) licenses this file to you under
// the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
// Package pub implements the queue client.
package pub
import (
"context"
"fmt"
"io"
"sync"
"time"
"github.com/pkg/errors"
"go.uber.org/multierr"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/anypb"
"github.com/apache/skywalking-banyandb/api/common"
"github.com/apache/skywalking-banyandb/api/data"
clusterv1 "github.com/apache/skywalking-banyandb/api/proto/banyandb/cluster/v1"
databasev1 "github.com/apache/skywalking-banyandb/api/proto/banyandb/database/v1"
modelv1 "github.com/apache/skywalking-banyandb/api/proto/banyandb/model/v1"
"github.com/apache/skywalking-banyandb/banyand/metadata"
"github.com/apache/skywalking-banyandb/banyand/metadata/schema"
"github.com/apache/skywalking-banyandb/banyand/queue"
"github.com/apache/skywalking-banyandb/pkg/bus"
"github.com/apache/skywalking-banyandb/pkg/grpchelper"
"github.com/apache/skywalking-banyandb/pkg/logger"
"github.com/apache/skywalking-banyandb/pkg/run"
)
// Compile-time assertions that *pub implements the run framework's
// pre-run, service and configuration interfaces.
var (
	_ run.PreRunner = (*pub)(nil)
	_ run.Service   = (*pub)(nil)
	_ run.Config    = (*pub)(nil)
)
// pub is the queue client. It publishes bus messages to cluster nodes over
// clusterv1 Send streams (see publish and Broadcast), and PreRun registers it
// for schema.KindNode events when a metadata repository is present.
type pub struct {
	// Embedded default for the OnInit part of schema.EventHandler
	// (presumably a no-op — confirm in the schema package).
	schema.UnimplementedOnInitHandler
	// metadata is the schema repository; nil when built by NewWithoutMetadata.
	metadata metadata.Repo
	// handlers holds the per-topic event handlers installed by Register.
	handlers map[bus.Topic]schema.EventHandler
	// log is assigned in PreRun, or up front by NewWithoutMetadata.
	log *logger.Logger
	// registered maps a node name to its node definition. Broadcast only
	// targets names present in both registered and active.
	registered map[string]*databasev1.Node
	// active maps a node name to its connected client. publish looks up the
	// target of each message here by bus.Message.Node().
	active map[string]*client
	// evictable holds evictNode entries whose channel c is closed by
	// GracefulStop.
	evictable map[string]evictNode
	// closer tracks background work (e.g. Broadcast's failover goroutines)
	// and supplies the stop notification returned by Serve.
	closer *run.Closer
	// caCertPath is the --internal-ca-cert flag value.
	caCertPath string
	// mu guards registered, active and evictable.
	mu sync.RWMutex
	// tlsEnabled is the --internal-tls flag value.
	tlsEnabled bool
}
// FlagSet implements run.Config. It exposes the client's internal TLS
// settings: --internal-tls switches TLS on for the internal connections and
// --internal-ca-cert names the CA bundle used to verify the data server.
func (p *pub) FlagSet() *run.FlagSet {
	flags := run.NewFlagSet("queue-client")
	flags.BoolVar(&p.tlsEnabled, "internal-tls", false, "enable internal TLS")
	flags.StringVar(&p.caCertPath, "internal-ca-cert", "", "CA certificate file to verify the internal data server")
	return flags
}
// Validate implements run.Config. It rejects exactly one combination:
// internal TLS enabled without a CA certificate path. Every other setting
// passes.
func (p *pub) Validate() error {
	if !p.tlsEnabled || p.caCertPath != "" {
		return nil
	}
	return fmt.Errorf("TLS is enabled (--internal-tls), but no CA certificate file was provided (--internal-ca-cert is required)")
}
// Register installs handler for topic, replacing any handler previously
// registered for that topic. p.handlers is written without taking p.mu, so
// Register is presumably meant to be called during setup, before concurrent
// use begins — confirm with callers.
func (p *pub) Register(topic bus.Topic, handler schema.EventHandler) {
	p.handlers[topic] = handler
}
// GracefulStop shuts the client down. It closes every evictable node's
// channel, releases and then waits on p.closer, and finally closes all active
// client connections. The whole sequence runs while holding p.mu exclusively.
//
// NOTE(review): CloseThenWait waits for goroutines registered through
// p.closer.AddRunning (e.g. the failover goroutines started by Broadcast)
// while p.mu is still held. If any of them acquires p.mu, this can deadlock —
// confirm against failover's locking.
func (p *pub) GracefulStop() {
	p.mu.Lock()
	defer p.mu.Unlock()
	// Closing each evictNode channel presumably signals the goroutine that
	// watches it to exit — TODO confirm against evictNode's users.
	for i := range p.evictable {
		close(p.evictable[i].c)
	}
	p.evictable = nil
	// Presumably releases the running slot reserved by run.NewCloser(1) in
	// New, so that CloseThenWait does not wait on it.
	p.closer.Done()
	p.closer.CloseThenWait()
	// Connection close errors are deliberately ignored during shutdown.
	for _, c := range p.active {
		_ = c.conn.Close()
	}
	p.active = nil
}
// Serve implements run.Service. It starts no work of its own. It returns
// p.closer's close-notification channel, which GracefulStop presumably
// triggers through CloseThenWait.
func (p *pub) Serve() run.StopNotify {
	return p.closer.CloseNotify()
}
// bypassMatches is the match list Broadcast uses for a nil node-selector
// entry; its single MatchFunc accepts every node.
var bypassMatches = []MatchFunc{bypassMatch}
// bypassMatch accepts every node, whatever its labels.
func bypassMatch(map[string]string) bool {
	return true
}
// Broadcast publishes messages to every active node selected by the
// message's node selectors, bounding each per-node send by timeout.
//
// Node selection:
//   - With no node selectors, every node that is both active and registered
//     is targeted.
//   - Otherwise each selector entry is parsed with ParseLabelSelector, and a
//     node is targeted if any selector in the entry matches its labels.
//   - A nil entry matches every node.
//
// Each targeted node is published to concurrently. The returned slice holds
// the futures of the sends that succeeded. Failed sends are aggregated into
// the returned error, which may accompany a non-empty slice. Failures
// classified by isFailoverError also trigger p.failover in a background
// goroutine tracked by p.closer.
func (p *pub) Broadcast(timeout time.Duration, topic bus.Topic, messages bus.Message) ([]bus.Future, error) {
	var nodes []*databasev1.Node
	p.mu.RLock()
	for k := range p.active {
		if n := p.registered[k]; n != nil {
			nodes = append(nodes, n)
		}
	}
	p.mu.RUnlock()
	if len(nodes) == 0 {
		return nil, errors.New("no active nodes")
	}
	names := make(map[string]struct{})
	if len(messages.NodeSelectors()) == 0 {
		for _, n := range nodes {
			names[n.Metadata.GetName()] = struct{}{}
		}
	} else {
		for _, sel := range messages.NodeSelectors() {
			var matches []MatchFunc
			if sel == nil {
				matches = bypassMatches
			} else {
				for _, s := range sel {
					selector, err := ParseLabelSelector(s)
					if err != nil {
						return nil, fmt.Errorf("failed to parse node selector: %w", err)
					}
					matches = append(matches, selector.Matches)
				}
			}
			for _, n := range nodes {
				for _, m := range matches {
					if m(n.Labels) {
						// GetName is nil-safe, consistent with the no-selector branch.
						names[n.Metadata.GetName()] = struct{}{}
						break
					}
				}
			}
		}
	}
	if l := p.log.Debug(); l.Enabled() {
		l.Msgf("broadcasting message to %s nodes", names)
	}
	if len(names) == 0 {
		return nil, fmt.Errorf("no nodes match the selector %v", messages.NodeSelectors())
	}
	// Buffered to len(names) so that no publishing goroutine blocks on send.
	futureCh := make(chan publishResult, len(names))
	var wg sync.WaitGroup
	for n := range names {
		wg.Add(1)
		go func(n string) {
			defer wg.Done()
			f, err := p.publish(timeout, topic, bus.NewMessageWithNode(messages.ID(), n, messages.Data()))
			futureCh <- publishResult{n: n, f: f, e: err}
		}(n)
	}
	go func() {
		wg.Wait()
		close(futureCh)
	}()
	var futures []bus.Future
	var errs error
	for f := range futureCh {
		if f.e != nil {
			errs = multierr.Append(errs, errors.Wrapf(f.e, "failed to publish message to %s", f.n))
			if isFailoverError(f.e) && p.closer.AddRunning() {
				// Hand the node name and cause to the goroutine explicitly, so it
				// never observes a later iteration's f (before Go 1.22 the range
				// variable is shared across iterations).
				go func(node string, cause error) {
					defer p.closer.Done()
					p.failover(node, common.NewErrorWithStatus(modelv1.Status_STATUS_INTERNAL_ERROR, cause.Error()), topic)
				}(f.n, f.e)
			}
			continue
		}
		futures = append(futures, f.f)
	}
	if errs != nil {
		return futures, fmt.Errorf("broadcast errors: %w", errs)
	}
	return futures, nil
}
// publishResult carries the outcome of publishing to a single node inside
// Broadcast.
type publishResult struct {
	// f is the future returned by publish; Broadcast only uses it when e is nil.
	f bus.Future
	// e is the error returned by publish for this node, if any.
	e error
	// n is the name of the target node.
	n string
}
// publish sends each message to the node named by its Node(). Each send uses
// a new clusterv1 Send stream under its own context, bounded by timeout.
//
// Messages are handled independently. A marshal error, unknown node, stream
// creation error or send error is appended to the returned error, and
// processing continues with the next message. The returned future holds
// exactly one entry (stream, topic, cancel function) per successfully sent
// message. It is returned even when err is non-nil.
func (p *pub) publish(timeout time.Duration, topic bus.Topic, messages ...bus.Message) (bus.Future, error) {
	var err error
	f := &future{}
	handleMessage := func(m bus.Message, err error) error {
		r, errSend := messageToRequest(topic, m)
		if errSend != nil {
			return multierr.Append(err, fmt.Errorf("failed to marshal message[%d]: %w", m.ID(), errSend))
		}
		node := m.Node()
		p.mu.RLock()
		cli, ok := p.active[node]
		p.mu.RUnlock()
		if !ok {
			return multierr.Append(err, fmt.Errorf("failed to get client for node %s", node))
		}
		ctx, cancel := context.WithTimeout(context.Background(), timeout)
		stream, errCreateStream := cli.client.Send(ctx)
		if errCreateStream != nil {
			// Release the context now: nothing will ever consume this entry.
			cancel()
			return multierr.Append(err, fmt.Errorf("failed to get stream for node %s: %w", node, errCreateStream))
		}
		errSend = stream.Send(r)
		if errSend != nil {
			// Cancelling the context also tears down the unusable stream.
			cancel()
			return multierr.Append(err, fmt.Errorf("failed to send message to node %s: %w", node, errSend))
		}
		// Append all three together so that future.Get, which pops them in
		// lockstep, always cancels the context belonging to the stream it read.
		f.clients = append(f.clients, stream)
		f.topics = append(f.topics, topic)
		f.cancelFn = append(f.cancelFn, cancel)
		return err
	}
	for _, m := range messages {
		err = handleMessage(m, err)
	}
	return f, err
}
// Publish sends messages to their target nodes and returns a future for the
// responses. The caller's context is ignored. Each send instead runs under
// its own fixed 15-second timeout.
func (p *pub) Publish(_ context.Context, topic bus.Topic, messages ...bus.Message) (bus.Future, error) {
	const publishTimeout = 15 * time.Second
	// nolint: contextcheck
	return p.publish(publishTimeout, topic, messages...)
}
// New returns a new queue client backed by the given metadata repository.
// metadata may be nil; PreRun then skips registering for node events. All
// bookkeeping maps start empty, and the closer is created with one running
// slot.
func New(metadata metadata.Repo) queue.Client {
	p := &pub{
		metadata:   metadata,
		closer:     run.NewCloser(1),
		handlers:   make(map[bus.Topic]schema.EventHandler),
		registered: make(map[string]*databasev1.Node),
		active:     make(map[string]*client),
		evictable:  make(map[string]evictNode),
	}
	return p
}
// NewWithoutMetadata returns a new queue client that has no metadata
// repository. Its logger is pre-set to the "queue-client" logger.
func NewWithoutMetadata() queue.Client {
	qc := New(nil)
	impl := qc.(*pub)
	impl.log = logger.GetLogger("queue-client")
	return qc
}
// Name returns the fixed service name "queue-client".
func (*pub) Name() string {
	return "queue-client"
}
// PreRun implements run.PreRunner. When a metadata repository is present it
// registers the client as the "queue-client" handler for schema.KindNode
// events. In every case it then installs the "server-queue-pub" logger.
// PreRun never fails.
func (p *pub) PreRun(context.Context) error {
	if repo := p.metadata; repo != nil {
		repo.RegisterHandler("queue-client", schema.KindNode, p)
	}
	p.log = logger.GetLogger("server-queue-pub")
	return nil
}
// messageToRequest converts m into a clusterv1.SendRequest for topic. The
// request carries m's ID and batch-mode flag, and packs m.Data() into the
// body as an anypb.Any.
//
// It returns an error if m.Data() is not a proto.Message or cannot be packed.
func messageToRequest(topic bus.Topic, m bus.Message) (*clusterv1.SendRequest, error) {
	r := &clusterv1.SendRequest{
		Topic:     topic.String(),
		MessageId: uint64(m.ID()),
		BatchMod:  m.BatchModeEnabled(),
	}
	message, ok := m.Data().(proto.Message)
	if !ok {
		return nil, fmt.Errorf("invalid message type %T", m.Data())
	}
	anyMessage, err := anypb.New(message)
	if err != nil {
		// Report the payload's concrete type. %T of m would always print the
		// bus.Message wrapper and say nothing about what failed to marshal.
		return nil, fmt.Errorf("failed to marshal message %T: %w", message, err)
	}
	r.Body = anyMessage
	return r, nil
}
// future is the bus.Future returned by publish. Get consumes clients, topics
// and cancelFn from the front, one entry per call. It assumes index i of each
// slice refers to the same send.
type future struct {
	// clients are the Send streams whose single response is still pending.
	clients []clusterv1.Service_SendClient
	// cancelFn are the cancel functions of the streams' contexts; Get calls
	// the head one after receiving.
	cancelFn []func()
	// topics records the topic of each send; Get uses it to pick the
	// response type from data.TopicResponseMap.
	topics []bus.Topic
}
// Get receives the response of the oldest pending send and decodes it into
// the response type registered for its topic in data.TopicResponseMap.
//
// The head entry is consumed and its context cancelled whether or not
// receiving succeeds. Get returns io.EOF once no entries remain. A non-empty
// response Error becomes the returned error. A response without a body
// yields a message with a nil payload.
func (l *future) Get() (bus.Message, error) {
	if len(l.clients) == 0 {
		return bus.Message{}, io.EOF
	}
	head, topic := l.clients[0], l.topics[0]
	defer func() {
		l.clients, l.topics = l.clients[1:], l.topics[1:]
		l.cancelFn[0]()
		l.cancelFn = l.cancelFn[1:]
	}()
	resp, err := head.Recv()
	switch {
	case err != nil:
		return bus.Message{}, err
	case resp.Error != "":
		return bus.Message{}, errors.New(resp.Error)
	case resp.Body == nil:
		return bus.NewMessage(bus.MessageID(resp.MessageId), nil), nil
	}
	newResponse, known := data.TopicResponseMap[topic]
	if !known {
		return bus.Message{}, fmt.Errorf("invalid topic %s", topic)
	}
	payload := newResponse()
	if err := resp.Body.UnmarshalTo(payload); err != nil {
		return bus.Message{}, err
	}
	return bus.NewMessage(bus.MessageID(resp.MessageId), payload), nil
}
// GetAll drains every pending send through Get and returns the successfully
// decoded responses in send order. Errors from individual sends are
// aggregated with multierr and returned alongside the collected messages.
//
// The loop runs on the number of remaining entries rather than stopping at
// the first io.EOF. gRPC's Recv also reports io.EOF when a stream ends
// without a response, and treating that as "no more entries" would abandon
// the remaining streams and leave their contexts uncancelled. Such a stream
// contributes no message and, as before, no error.
func (l *future) GetAll() ([]bus.Message, error) {
	var globalErr error
	ret := make([]bus.Message, 0, len(l.clients))
	// Get consumes exactly one entry per call, so this loop terminates.
	for len(l.clients) > 0 {
		m, err := l.Get()
		if errors.Is(err, io.EOF) {
			continue
		}
		if err != nil {
			globalErr = multierr.Append(globalErr, err)
			continue
		}
		ret = append(ret, m)
	}
	return ret, globalErr
}
// isFailoverError reports whether err carries a gRPC status whose code is
// Unavailable or DeadlineExceeded. Errors that do not carry a gRPC status
// are never failover errors.
func isFailoverError(err error) bool {
	st, isStatus := status.FromError(err)
	if !isStatus {
		return false
	}
	switch st.Code() {
	case codes.Unavailable, codes.DeadlineExceeded:
		return true
	default:
		return false
	}
}
// getClientTransportCredentials builds the gRPC dial options for connecting
// to data nodes from the --internal-tls and --internal-ca-cert settings, via
// grpchelper.SecureOptions. The meaning of the nil and false arguments is
// defined by that helper — confirm there. Failures are wrapped with context.
func (p *pub) getClientTransportCredentials() ([]grpc.DialOption, error) {
	opts, err := grpchelper.SecureOptions(nil, p.tlsEnabled, false, p.caCertPath)
	if err == nil {
		return opts, nil
	}
	return nil, fmt.Errorf("failed to load TLS config: %w", err)
}