banyand/liaison/grpc/stream.go (120 lines of code) (raw):
// Licensed to Apache Software Foundation (ASF) under one or more contributor
// license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright
// ownership. Apache Software Foundation (ASF) licenses this file to you under
// the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package grpc
import (
"context"
"io"
"time"
"github.com/pkg/errors"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/apache/skywalking-banyandb/api/common"
"github.com/apache/skywalking-banyandb/api/data"
streamv1 "github.com/apache/skywalking-banyandb/api/proto/banyandb/stream/v1"
"github.com/apache/skywalking-banyandb/banyand/tsdb"
"github.com/apache/skywalking-banyandb/pkg/accesslog"
"github.com/apache/skywalking-banyandb/pkg/bus"
"github.com/apache/skywalking-banyandb/pkg/logger"
"github.com/apache/skywalking-banyandb/pkg/timestamp"
)
// streamService is the liaison-side implementation of the streamv1
// StreamService gRPC API. It accepts stream writes and queries and forwards
// them to the internal message pipeline.
//
// Embedding UnimplementedStreamServiceServer supplies default handlers for any
// StreamService RPC not implemented on this type. The embedded
// *discoveryService presumably provides navigate, pipeline and log as used by
// the methods below — verify against its definition.
type streamService struct {
	streamv1.UnimplementedStreamServiceServer
	*discoveryService
	// sampled is the sampled logger used to report per-request failures; it is
	// assigned by setLogger.
	sampled *logger.Logger
	// ingestionAccessLog, when non-nil, receives each write request that Write
	// has validated and routed. It is assigned by activeIngestionAccessLog and
	// may be nil if that was never called.
	ingestionAccessLog accesslog.Log
}
// setLogger installs the logger used for per-request error reporting. It stores
// the sampled derivative returned by log.Sampled(10); how the argument 10 is
// interpreted is defined by logger.Sampled.
func (s *streamService) setLogger(log *logger.Logger) {
	sampledLogger := log.Sampled(10)
	s.sampled = sampledLogger
}
// activeIngestionAccessLog creates a file-backed access log under root, using
// the "stream-ingest-%s" name pattern and a 10-minute duration argument
// (presumably the rotation interval — see accesslog.NewFileLog). The result is
// stored in s.ingestionAccessLog. Any error from accesslog.NewFileLog is
// returned unchanged.
func (s *streamService) activeIngestionAccessLog(root string) (err error) {
	s.ingestionAccessLog, err = accesslog.NewFileLog(root, "stream-ingest-%s", 10*time.Minute, s.log)
	return err
}
// Write consumes a client stream of write requests. It validates and routes
// each one, then publishes it to the data.TopicStreamWrite pipeline topic.
//
// Each received request is acknowledged with an empty WriteResponse, whether
// it was published, rejected by validation or routing, or failed to publish.
// Rejections and failures are reported only through the sampled logger.
//
// Write returns:
//   - nil when the client closes its side of the stream (io.EOF);
//   - ctx.Err() when the stream context is done;
//   - the Recv error for any other receive failure, since the stream can no
//     longer be used after that.
func (s *streamService) Write(stream streamv1.StreamService_WriteServer) error {
	// reply acknowledges one request. A failed Send is logged but does not end
	// the loop.
	reply := func(stream streamv1.StreamService_WriteServer, l *logger.Logger) {
		if errResp := stream.Send(&streamv1.WriteResponse{}); errResp != nil {
			l.Err(errResp).Msg("failed to send response")
		}
	}
	ctx := stream.Context()
	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		default:
		}
		writeEntity, err := stream.Recv()
		if errors.Is(err, io.EOF) {
			return nil
		}
		if err != nil {
			// A non-EOF Recv error leaves the stream unusable. Retrying Recv
			// would spin on the same error, and replying would fail as well,
			// so end the RPC with this error. writeEntity is not meaningful
			// here, so it is not logged.
			s.sampled.Error().Err(err).Msg("failed to receive message")
			return err
		}
		// Use the nil-safe getter so a request without an element cannot
		// cause a nil-pointer dereference here.
		if errTime := timestamp.CheckPb(writeEntity.GetElement().GetTimestamp()); errTime != nil {
			s.sampled.Error().Stringer("written", writeEntity).Err(errTime).Msg("the element time is invalid")
			reply(stream, s.sampled)
			continue
		}
		entity, tagValues, shardID, err := s.navigate(writeEntity.GetMetadata(), writeEntity.GetElement().GetTagFamilies())
		if err != nil {
			s.sampled.Error().Err(err).RawJSON("written", logger.Proto(writeEntity)).Msg("failed to navigate to the write target")
			reply(stream, s.sampled)
			continue
		}
		// Access logging is best-effort: a failure is logged and the write
		// still proceeds.
		if s.ingestionAccessLog != nil {
			if errAccessLog := s.ingestionAccessLog.Write(writeEntity); errAccessLog != nil {
				s.sampled.Error().Err(errAccessLog).Msg("failed to write ingestion access log")
			}
		}
		iwr := &streamv1.InternalWriteRequest{
			Request:    writeEntity,
			ShardId:    uint32(shardID),
			SeriesHash: tsdb.HashEntity(entity),
		}
		// The encoded entity values are attached only when debug logging is
		// enabled.
		if s.log.Debug().Enabled() {
			iwr.EntityValues = tagValues.Encode()
		}
		message := bus.NewMessage(bus.MessageID(time.Now().UnixNano()), iwr)
		_, errWritePub := s.pipeline.Publish(data.TopicStreamWrite, message)
		if errWritePub != nil {
			s.sampled.Error().Err(errWritePub).RawJSON("written", logger.Proto(writeEntity)).Msg("failed to send a message")
		}
		reply(stream, s.sampled)
	}
}
// emptyStreamQueryResponse is the shared zero-element reply that Query returns
// when publishing the query reports io.EOF. Its Elements slice is non-nil but
// empty. The value is shared between calls and must not be mutated.
var emptyStreamQueryResponse = &streamv1.QueryResponse{Elements: []*streamv1.Element{}}
// Query validates req, publishes it to the data.TopicStreamQuery pipeline
// topic, waits for the result and converts it into a QueryResponse.
//
// A request without a time range is given timestamp.DefaultTimeRange; req is
// modified in place. An invalid time range yields a codes.InvalidArgument
// status. If publishing fails with io.EOF, an empty response is returned.
//
// Result handling:
//   - []*streamv1.Element becomes the response's elements.
//   - common.Error is returned as an error wrapping errQueryMsg.
//   - A nil result yields an empty response.
//   - Any other result type yields a codes.Internal status rather than a
//     silent nil reply.
func (s *streamService) Query(_ context.Context, req *streamv1.QueryRequest) (*streamv1.QueryResponse, error) {
	if req.GetTimeRange() == nil {
		req.TimeRange = timestamp.DefaultTimeRange
	}
	if err := timestamp.CheckTimeRange(req.GetTimeRange()); err != nil {
		return nil, status.Errorf(codes.InvalidArgument, "%v is invalid :%s", req.GetTimeRange(), err)
	}
	message := bus.NewMessage(bus.MessageID(time.Now().UnixNano()), req)
	feat, errQuery := s.pipeline.Publish(data.TopicStreamQuery, message)
	if errQuery != nil {
		if errors.Is(errQuery, io.EOF) {
			return emptyStreamQueryResponse, nil
		}
		return nil, errQuery
	}
	msg, errFeat := feat.Get()
	if errFeat != nil {
		return nil, errFeat
	}
	// Name the payload "result" so it does not shadow the data package.
	switch result := msg.Data().(type) {
	case []*streamv1.Element:
		return &streamv1.QueryResponse{Elements: result}, nil
	case common.Error:
		return nil, errors.WithMessage(errQueryMsg, result.Msg())
	case nil:
		return emptyStreamQueryResponse, nil
	default:
		return nil, status.Errorf(codes.Internal, "unexpected stream query result type %T", result)
	}
}
// Close closes the ingestion access log and returns its error.
//
// The access log is nil when activeIngestionAccessLog was never called. Write
// already guards for that case, and so does Close: it returns nil instead of
// calling Close on a nil interface, which would panic.
func (s *streamService) Close() error {
	if s.ingestionAccessLog == nil {
		return nil
	}
	return s.ingestionAccessLog.Close()
}