pkg/admin/util/reflection/reflection.go (244 lines of code) (raw):

/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package reflection import ( "bytes" "context" "fmt" "io" "strings" "sync" "time" ) import ( "github.com/fullstorydev/grpcurl" "github.com/jhump/protoreflect/desc" "github.com/jhump/protoreflect/grpcreflect" "github.com/pkg/errors" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" "google.golang.org/grpc/keepalive" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" ) type RPCReflection interface { // SetUserAgent sets the User-Agent header to be sent in each request. SetUserAgent(userAgent string) // SetConnectTimeout sets the timeout for dialing a target. SetConnectTimeout(timeout time.Duration) // SetKeepaliveTime sets the keepalive time for grpc connection. SetKeepaliveTime(timeout time.Duration) // SetAdditionalHeaders sets the additional headers to be sent in both reflection request and rpc request. SetAdditionalHeaders(headers map[string]string) // SetReflectionHeaders sets the additional headers to be sent in only reflection request. SetReflectionHeaders(headers map[string]string) // SetRPCHeaders sets the additional headers to be sent in only rpc request. SetRPCHeaders(headers map[string]string) // Dail to the target, you should call this method before send reflection request and rpc request Dail(ctx context.Context) error // Close the connection. Close() // ListServices returns all services in the target. ListServices() ([]string, error) // ListMethods returns all methods in the service. ListMethods(service string) ([]string, error) // Invoke invokes the method with input. Invoke(ctx context.Context, methodName, input string) (response string, err error) // TemplateString returns the template string of the message. TemplateString(messageName string) (string, error) // DescribeString returns the description string of the message. DescribeString(symbol string) (string, error) // Descriptor returns the desc.Descriptor. Descriptor(symbol string) (desc.Descriptor, error) // InputAndOutputType returns the input and output type of the method. InputAndOutputType(methodName string) (string, string, error) } type rpcReflection struct { target string // remote address userAgent string connectTimeout time.Duration keepaliveTime time.Duration // additionalHeaders will be included in both reflection request and rpc request additionalHeaders map[string]string // reflectionHeaders will be included in only the reflection request reflectionHeaders map[string]string // rpcHeaders will be included in only the rpc request rpcHeaders map[string]string mu sync.Mutex // to protect descSource and clientConn refClient *grpcreflect.Client clientConn *grpc.ClientConn descSource grpcurl.DescriptorSource } func NewRPCReflection(target string) RPCReflection { r := &rpcReflection{ target: target, } return r } func (r *rpcReflection) SetUserAgent(userAgent string) { r.userAgent = userAgent } func (r *rpcReflection) SetConnectTimeout(timeout time.Duration) { r.connectTimeout = timeout } func (r *rpcReflection) SetKeepaliveTime(timeout time.Duration) { r.keepaliveTime = timeout } func (r *rpcReflection) SetAdditionalHeaders(headers map[string]string) { r.additionalHeaders = headers } func (r *rpcReflection) SetReflectionHeaders(headers map[string]string) { r.reflectionHeaders = headers } func (r *rpcReflection) SetRPCHeaders(headers map[string]string) { r.rpcHeaders = headers } // Dail to the target, you should call this method before send reflection request and rpc request func (r *rpcReflection) Dail(ctx context.Context) error { r.mu.Lock() defer r.mu.Unlock() if r.clientConn != nil && r.descSource != nil { return nil } md := grpcurl.MetadataFromHeaders(append(headerMapToStrings(r.additionalHeaders), headerMapToStrings(r.reflectionHeaders)...)) refCtx := metadata.NewOutgoingContext(ctx, md) dail := func(ctx context.Context) (*grpc.ClientConn, error) { // add time out and keep alive dialTime := 10 * time.Second if r.connectTimeout > 0 { dialTime = r.connectTimeout } ctx, cancel := context.WithTimeout(ctx, dialTime) defer cancel() var opts []grpc.DialOption if r.keepaliveTime > 0 { timeout := r.keepaliveTime opts = append(opts, grpc.WithKeepaliveParams(keepalive.ClientParameters{ Time: timeout, Timeout: timeout, })) } var creds credentials.TransportCredentials // add user agent opts = append(opts, grpc.WithUserAgent(r.userAgent)) network := "tcp" cc, err := grpcurl.BlockingDial(ctx, network, r.target, creds, opts...) if err != nil { return nil, errors.Wrapf(err, "Failed to dial target host %s", r.target) } return cc, nil } cc, err := dail(refCtx) if err != nil { return err } r.clientConn = cc reflectionClient := grpcreflect.NewClientAuto(refCtx, cc) r.refClient = reflectionClient reflectionSource := grpcurl.DescriptorSourceFromServer(ctx, reflectionClient) r.descSource = reflectionSource return nil } func (r *rpcReflection) Close() { r.mu.Lock() defer r.mu.Unlock() if r.refClient != nil { r.refClient.Reset() r.refClient = nil } if r.clientConn != nil { r.clientConn.Close() r.clientConn = nil } } func (r *rpcReflection) ListServices() ([]string, error) { svcs, err := grpcurl.ListServices(r.descSource) if err != nil { return nil, errors.Wrap(err, "Failed to list services") } return svcs, nil } func (r *rpcReflection) ListMethods(service string) ([]string, error) { methods, err := grpcurl.ListMethods(r.descSource, service) if err != nil { return nil, errors.Wrapf(err, "Failed to list methods for service %s", service) } return methods, nil } func headerMapToStrings(headerMap map[string]string) []string { var headers []string for k, v := range headerMap { headers = append(headers, fmt.Sprintf("%v: %v", k, v)) } return headers } func (r *rpcReflection) Invoke(ctx context.Context, methodName, input string) (response string, err error) { // Invoke an RPC cc := r.clientConn // input string, request message var in io.Reader in = strings.NewReader(input) options := grpcurl.FormatOptions{ EmitJSONDefaultFields: true, IncludeTextSeparator: true, AllowUnknownFields: true, } rf, formatter, err := grpcurl.RequestParserAndFormatter(grpcurl.FormatJSON, r.descSource, in, options) if err != nil { return "", errors.Wrapf(err, "Failed to construct request parser and formatter for %v", err) } // invoke output := bytes.NewBuffer(nil) h := &grpcurl.DefaultEventHandler{ Out: output, Formatter: formatter, VerbosityLevel: 0, // no verbose } addlHeaders := headerMapToStrings(r.additionalHeaders) rpcHeaders := headerMapToStrings(r.rpcHeaders) err = grpcurl.InvokeRPC(ctx, r.descSource, cc, methodName, append(addlHeaders, rpcHeaders...), h, rf.Next) if err != nil { if errStatus, ok := status.FromError(err); ok { h.Status = errStatus } return "", err } if h.Status.Code() != codes.OK { // failed to invoke formattedStatus, err := formatter(h.Status.Proto()) if err != nil { return "", nil } return "", errors.New(formattedStatus) } // success invoke return output.String(), nil } func (r *rpcReflection) TemplateString(messageName string) (string, error) { dsc, err := r.Descriptor(messageName) if err != nil { return "", err } msgDesc, ok := dsc.(*desc.MessageDescriptor) if !ok { return "", errors.New("not a message") } // for messages, also show a template in JSON, to make it easier to // create a request to invoke an RPC tmpl := grpcurl.MakeTemplate(msgDesc) options := grpcurl.FormatOptions{EmitJSONDefaultFields: true} _, formatter, err := grpcurl.RequestParserAndFormatter(grpcurl.FormatJSON, r.descSource, nil, options) if err != nil { return "", errors.Wrapf(err, "Failed to construct formatter, err=%v", err) } template, err := formatter(tmpl) if err != nil { return "", errors.Wrapf(err, "Failed to print template for message %s", messageName) } return template, nil } func (r *rpcReflection) DescribeString(symbol string) (string, error) { dsc, err := r.Descriptor(symbol) if err != nil { return "", err } txt, err := grpcurl.GetDescriptorText(dsc, r.descSource) if err != nil { return "", errors.Wrapf(err, "Failed to describe symbol %q", symbol) } return txt, nil } func (r *rpcReflection) Descriptor(symbol string) (desc.Descriptor, error) { if symbol[0] == '.' { symbol = symbol[1:] } dsc, err := r.descSource.FindSymbol(symbol) if err != nil { return nil, errors.Wrapf(err, "Failed to resolve symbol %q", symbol) } return dsc, nil } func (r *rpcReflection) InputAndOutputType(methodName string) (string, string, error) { // get descriptor of method descriptor, err := r.Descriptor(methodName) if err != nil { return "", "", err } methodDesc, ok := descriptor.(*desc.MethodDescriptor) if !ok { return "", "", fmt.Errorf("%s is not a method", methodName) } // get input and output descriptor inputDesc := methodDesc.GetInputType() outputDesc := methodDesc.GetOutputType() return inputDesc.GetFullyQualifiedName(), outputDesc.GetFullyQualifiedName(), nil }