protocol/triple/client.go (208 lines of code) (raw):

/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package triple import ( "context" "crypto/tls" "fmt" "net" "net/http" "strings" ) import ( "github.com/dubbogo/gost/log/logger" "github.com/dustin/go-humanize" "golang.org/x/net/http2" ) import ( "dubbo.apache.org/dubbo-go/v3/common" "dubbo.apache.org/dubbo-go/v3/common/constant" "dubbo.apache.org/dubbo-go/v3/config" tri "dubbo.apache.org/dubbo-go/v3/protocol/triple/triple_protocol" ) const ( httpPrefix string = "http://" httpsPrefix string = "https://" ) // clientManager wraps triple clients and is responsible for find concrete triple client to invoke // callUnary, callClientStream, callServerStream, callBidiStream. // A Reference has a clientManager. type clientManager struct { isIDL bool // triple_protocol clients, key is method name triClients map[string]*tri.Client } func (cm *clientManager) getClient(method string) (*tri.Client, error) { triClient, ok := cm.triClients[method] if !ok { return nil, fmt.Errorf("missing triple client for method: %s", method) } return triClient, nil } func (cm *clientManager) callUnary(ctx context.Context, method string, req, resp any) error { triClient, err := cm.getClient(method) if err != nil { return err } triReq := tri.NewRequest(req) triResp := tri.NewResponse(resp) if err := triClient.CallUnary(ctx, triReq, triResp); err != nil { return err } serverAttachments, ok := ctx.Value(constant.AttachmentServerKey).(map[string]interface{}) if !ok { return nil } for k, v := range triResp.Trailer() { if ok := isFilterHeader(k); ok { continue } if len(v) > 0 { serverAttachments[k] = v[0] } } return nil } func (cm *clientManager) callClientStream(ctx context.Context, method string) (any, error) { triClient, err := cm.getClient(method) if err != nil { return nil, err } stream, err := triClient.CallClientStream(ctx) if err != nil { return nil, err } return stream, nil } func (cm *clientManager) callServerStream(ctx context.Context, method string, req any) (any, error) { triClient, err := cm.getClient(method) if err != nil { return nil, err } triReq := tri.NewRequest(req) stream, err := triClient.CallServerStream(ctx, triReq) if err != nil { return nil, err } return stream, nil } func (cm *clientManager) callBidiStream(ctx context.Context, method string) (any, error) { triClient, err := cm.getClient(method) if err != nil { return nil, err } stream, err := triClient.CallBidiStream(ctx) if err != nil { return nil, err } return stream, nil } func (cm *clientManager) close() error { // There is no need to release resources right now. // But we leave this function here for future use. return nil } // newClientManager extracts configurations from url and builds clientManager func newClientManager(url *common.URL) (*clientManager, error) { var cliOpts []tri.ClientOption // set max send and recv msg size maxCallRecvMsgSize := constant.DefaultMaxCallRecvMsgSize if recvMsgSize, err := humanize.ParseBytes(url.GetParam(constant.MaxCallRecvMsgSize, "")); err == nil && recvMsgSize > 0 { maxCallRecvMsgSize = int(recvMsgSize) } cliOpts = append(cliOpts, tri.WithReadMaxBytes(maxCallRecvMsgSize)) maxCallSendMsgSize := constant.DefaultMaxCallSendMsgSize if sendMsgSize, err := humanize.ParseBytes(url.GetParam(constant.MaxCallSendMsgSize, "")); err == nil && sendMsgSize > 0 { maxCallSendMsgSize = int(sendMsgSize) } cliOpts = append(cliOpts, tri.WithSendMaxBytes(maxCallSendMsgSize)) // set keepalive interval and keepalive timeout keepAliveInterval := url.GetParamDuration(constant.KeepAliveInterval, constant.DefaultKeepAliveInterval) keepAliveTimeout := url.GetParamDuration(constant.KeepAliveTimeout, constant.DefaultKeepAliveTimeout) var isIDL bool // set serialization serialization := url.GetParam(constant.SerializationKey, constant.ProtobufSerialization) switch serialization { case constant.ProtobufSerialization: isIDL = true case constant.JSONSerialization: isIDL = true cliOpts = append(cliOpts, tri.WithProtoJSON()) case constant.Hessian2Serialization: cliOpts = append(cliOpts, tri.WithHessian2()) case constant.MsgpackSerialization: cliOpts = append(cliOpts, tri.WithMsgPack()) default: panic(fmt.Sprintf("Unsupported serialization: %s", serialization)) } // set timeout timeout := url.GetParamDuration(constant.TimeoutKey, "") cliOpts = append(cliOpts, tri.WithTimeout(timeout)) // set service group and version group := url.GetParam(constant.GroupKey, "") version := url.GetParam(constant.VersionKey, "") cliOpts = append(cliOpts, tri.WithGroup(group), tri.WithVersion(version)) // todo(DMwangnima): support opentracing // todo(DMwangnima): support TLS in an ideal way var cfg *tls.Config var tlsFlag bool var err error // handle tls config // TODO: think about a more elegant way to configure tls, // Maybe we can try to create a ClientOptions for unified settings, // after this function becomes bloated. // TODO: Once the global replacement of the config is completed, // replace config with global. if tlsConfig := config.GetRootConfig().TLSConfig; tlsConfig != nil { cfg, err = config.GetClientTlsConfig(&config.TLSConfig{ CACertFile: tlsConfig.CACertFile, TLSCertFile: tlsConfig.TLSCertFile, TLSKeyFile: tlsConfig.TLSKeyFile, TLSServerName: tlsConfig.TLSServerName, }) if err != nil { return nil, err } logger.Infof("TRIPLE clientManager initialized the TLSConfig configuration") tlsFlag = true } var transport http.RoundTripper callType := url.GetParam(constant.CallHTTPTypeKey, constant.CallHTTP2) switch callType { case constant.CallHTTP: transport = &http.Transport{ TLSClientConfig: cfg, } cliOpts = append(cliOpts, tri.WithTriple()) case constant.CallHTTP2: if tlsFlag { transport = &http2.Transport{ TLSClientConfig: cfg, } } else { transport = &http2.Transport{ DialTLSContext: func(_ context.Context, network, addr string, _ *tls.Config) (net.Conn, error) { return net.Dial(network, addr) }, AllowHTTP: true, ReadIdleTimeout: keepAliveInterval, PingTimeout: keepAliveTimeout, } } default: panic(fmt.Sprintf("Unsupported callType: %s", callType)) } httpClient := &http.Client{ Transport: transport, } var baseTriURL string baseTriURL = strings.TrimPrefix(url.Location, httpPrefix) baseTriURL = strings.TrimPrefix(baseTriURL, httpsPrefix) if tlsFlag { baseTriURL = httpsPrefix + baseTriURL } else { baseTriURL = httpPrefix + baseTriURL } triClients := make(map[string]*tri.Client) for _, method := range url.Methods { triURL, err := joinPath(baseTriURL, url.Interface(), method) if err != nil { return nil, fmt.Errorf("JoinPath failed for base %s, interface %s, method %s", baseTriURL, url.Interface(), method) } triClient := tri.NewClient(httpClient, triURL, cliOpts...) triClients[method] = triClient } return &clientManager{ isIDL: isIDL, triClients: triClients, }, nil } func isFilterHeader(key string) bool { if key != "" && key[0] == ':' { return true } switch key { case constant.GrpcHeaderMessage, constant.GrpcHeaderStatus: return true default: return false } }