gitee.com/zhaochuninhefei/gmgo@v0.0.31-0.20240209061119-069254a02979/grpc/reflection/serverreflection.go (about)

     1  /*
     2   *
     3   * Copyright 2016 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  /*
    20  Package reflection implements server reflection service.
    21  
    22  The service implemented is defined in:
    23  https://github.com/grpc/grpc/blob/master/src/proto/grpc/reflection/v1alpha/reflection.proto.
    24  
    25  To register server reflection on a gRPC server:
    26  	import "gitee.com/zhaochuninhefei/gmgo/grpc/reflection"
    27  
    28  	s := grpc.NewServer()
    29  	pb.RegisterYourOwnServer(s, &server{})
    30  
    31  	// Register reflection service on gRPC server.
    32  	reflection.Register(s)
    33  
    34  	s.Serve(lis)
    35  
    36  */
    37  package reflection // import "gitee.com/zhaochuninhefei/gmgo/grpc/reflection"
    38  
    39  import (
    40  	"bytes"
    41  	"compress/gzip"
    42  	"fmt"
    43  	"io"
    44  	"io/ioutil"
    45  	"reflect"
    46  	"sort"
    47  	"sync"
    48  
    49  	"gitee.com/zhaochuninhefei/gmgo/grpc"
    50  	"gitee.com/zhaochuninhefei/gmgo/grpc/codes"
    51  	rpb "gitee.com/zhaochuninhefei/gmgo/grpc/reflection/grpc_reflection_v1alpha"
    52  	"gitee.com/zhaochuninhefei/gmgo/grpc/status"
    53  	"github.com/golang/protobuf/proto"
    54  	dpb "github.com/golang/protobuf/protoc-gen-go/descriptor"
    55  )
    56  
    57  // GRPCServer is the interface provided by a gRPC server. It is implemented by
    58  // *grpc.Server, but could also be implemented by other concrete types. It acts
    59  // as a registry, for accumulating the services exposed by the server.
    60  type GRPCServer interface {
    61  	grpc.ServiceRegistrar
    62  	GetServiceInfo() map[string]grpc.ServiceInfo
    63  }
    64  
    65  var _ GRPCServer = (*grpc.Server)(nil)
    66  
    67  type serverReflectionServer struct {
    68  	rpb.UnimplementedServerReflectionServer
    69  	s GRPCServer
    70  
    71  	initSymbols  sync.Once
    72  	serviceNames []string
    73  	symbols      map[string]*dpb.FileDescriptorProto // map of fully-qualified names to files
    74  }
    75  
    76  // Register registers the server reflection service on the given gRPC server.
    77  func Register(s GRPCServer) {
    78  	rpb.RegisterServerReflectionServer(s, &serverReflectionServer{
    79  		s: s,
    80  	})
    81  }
    82  
    83  // protoMessage is used for type assertion on proto messages.
    84  // Generated proto message implements function Descriptor(), but Descriptor()
    85  // is not part of interface proto.Message. This interface is needed to
    86  // call Descriptor().
    87  type protoMessage interface {
    88  	Descriptor() ([]byte, []int)
    89  }
    90  
    91  func (s *serverReflectionServer) getSymbols() (svcNames []string, symbolIndex map[string]*dpb.FileDescriptorProto) {
    92  	s.initSymbols.Do(func() {
    93  		serviceInfo := s.s.GetServiceInfo()
    94  
    95  		s.symbols = map[string]*dpb.FileDescriptorProto{}
    96  		s.serviceNames = make([]string, 0, len(serviceInfo))
    97  		processed := map[string]struct{}{}
    98  		for svc, info := range serviceInfo {
    99  			s.serviceNames = append(s.serviceNames, svc)
   100  			fdenc, ok := parseMetadata(info.Metadata)
   101  			if !ok {
   102  				continue
   103  			}
   104  			fd, err := decodeFileDesc(fdenc)
   105  			if err != nil {
   106  				continue
   107  			}
   108  			s.processFile(fd, processed)
   109  		}
   110  		sort.Strings(s.serviceNames)
   111  	})
   112  
   113  	return s.serviceNames, s.symbols
   114  }
   115  
   116  func (s *serverReflectionServer) processFile(fd *dpb.FileDescriptorProto, processed map[string]struct{}) {
   117  	filename := fd.GetName()
   118  	if _, ok := processed[filename]; ok {
   119  		return
   120  	}
   121  	processed[filename] = struct{}{}
   122  
   123  	prefix := fd.GetPackage()
   124  
   125  	for _, msg := range fd.MessageType {
   126  		s.processMessage(fd, prefix, msg)
   127  	}
   128  	for _, en := range fd.EnumType {
   129  		s.processEnum(fd, prefix, en)
   130  	}
   131  	for _, ext := range fd.Extension {
   132  		s.processField(fd, prefix, ext)
   133  	}
   134  	for _, svc := range fd.Service {
   135  		svcName := fqn(prefix, svc.GetName())
   136  		s.symbols[svcName] = fd
   137  		for _, meth := range svc.Method {
   138  			name := fqn(svcName, meth.GetName())
   139  			s.symbols[name] = fd
   140  		}
   141  	}
   142  
   143  	for _, dep := range fd.Dependency {
   144  		//goland:noinspection GoDeprecation
   145  		fdenc := proto.FileDescriptor(dep)
   146  		fdDep, err := decodeFileDesc(fdenc)
   147  		if err != nil {
   148  			continue
   149  		}
   150  		s.processFile(fdDep, processed)
   151  	}
   152  }
   153  
   154  func (s *serverReflectionServer) processMessage(fd *dpb.FileDescriptorProto, prefix string, msg *dpb.DescriptorProto) {
   155  	msgName := fqn(prefix, msg.GetName())
   156  	s.symbols[msgName] = fd
   157  
   158  	for _, nested := range msg.NestedType {
   159  		s.processMessage(fd, msgName, nested)
   160  	}
   161  	for _, en := range msg.EnumType {
   162  		s.processEnum(fd, msgName, en)
   163  	}
   164  	for _, ext := range msg.Extension {
   165  		s.processField(fd, msgName, ext)
   166  	}
   167  	for _, fld := range msg.Field {
   168  		s.processField(fd, msgName, fld)
   169  	}
   170  	for _, oneof := range msg.OneofDecl {
   171  		oneofName := fqn(msgName, oneof.GetName())
   172  		s.symbols[oneofName] = fd
   173  	}
   174  }
   175  
   176  func (s *serverReflectionServer) processEnum(fd *dpb.FileDescriptorProto, prefix string, en *dpb.EnumDescriptorProto) {
   177  	enName := fqn(prefix, en.GetName())
   178  	s.symbols[enName] = fd
   179  
   180  	for _, val := range en.Value {
   181  		valName := fqn(enName, val.GetName())
   182  		s.symbols[valName] = fd
   183  	}
   184  }
   185  
   186  func (s *serverReflectionServer) processField(fd *dpb.FileDescriptorProto, prefix string, fld *dpb.FieldDescriptorProto) {
   187  	fldName := fqn(prefix, fld.GetName())
   188  	s.symbols[fldName] = fd
   189  }
   190  
   191  func fqn(prefix, name string) string {
   192  	if prefix == "" {
   193  		return name
   194  	}
   195  	return prefix + "." + name
   196  }
   197  
   198  // fileDescForType gets the file descriptor for the given type.
   199  // The given type should be a proto message.
   200  func (s *serverReflectionServer) fileDescForType(st reflect.Type) (*dpb.FileDescriptorProto, error) {
   201  	m, ok := reflect.Zero(reflect.PtrTo(st)).Interface().(protoMessage)
   202  	if !ok {
   203  		return nil, fmt.Errorf("failed to create message from type: %v", st)
   204  	}
   205  	enc, _ := m.Descriptor()
   206  
   207  	return decodeFileDesc(enc)
   208  }
   209  
   210  // decodeFileDesc does decompression and unmarshalling on the given
   211  // file descriptor byte slice.
   212  func decodeFileDesc(enc []byte) (*dpb.FileDescriptorProto, error) {
   213  	raw, err := decompress(enc)
   214  	if err != nil {
   215  		return nil, fmt.Errorf("failed to decompress enc: %v", err)
   216  	}
   217  
   218  	fd := new(dpb.FileDescriptorProto)
   219  	if err := proto.Unmarshal(raw, fd); err != nil {
   220  		return nil, fmt.Errorf("bad descriptor: %v", err)
   221  	}
   222  	return fd, nil
   223  }
   224  
   225  // decompress does gzip decompression.
   226  func decompress(b []byte) ([]byte, error) {
   227  	r, err := gzip.NewReader(bytes.NewReader(b))
   228  	if err != nil {
   229  		return nil, fmt.Errorf("bad gzipped descriptor: %v", err)
   230  	}
   231  	out, err := ioutil.ReadAll(r)
   232  	if err != nil {
   233  		return nil, fmt.Errorf("bad gzipped descriptor: %v", err)
   234  	}
   235  	return out, nil
   236  }
   237  
   238  func typeForName(name string) (reflect.Type, error) {
   239  	//goland:noinspection GoDeprecation
   240  	pt := proto.MessageType(name)
   241  	if pt == nil {
   242  		return nil, fmt.Errorf("unknown type: %q", name)
   243  	}
   244  	st := pt.Elem()
   245  
   246  	return st, nil
   247  }
   248  
   249  func fileDescContainingExtension(st reflect.Type, ext int32) (*dpb.FileDescriptorProto, error) {
   250  	m, ok := reflect.Zero(reflect.PtrTo(st)).Interface().(proto.Message)
   251  	if !ok {
   252  		return nil, fmt.Errorf("failed to create message from type: %v", st)
   253  	}
   254  
   255  	var extDesc *proto.ExtensionDesc
   256  	//goland:noinspection GoDeprecation
   257  	for id, desc := range proto.RegisteredExtensions(m) {
   258  		if id == ext {
   259  			extDesc = desc
   260  			break
   261  		}
   262  	}
   263  
   264  	if extDesc == nil {
   265  		return nil, fmt.Errorf("failed to find registered extension for extension number %v", ext)
   266  	}
   267  
   268  	//goland:noinspection GoDeprecation
   269  	return decodeFileDesc(proto.FileDescriptor(extDesc.Filename))
   270  }
   271  
   272  func (s *serverReflectionServer) allExtensionNumbersForType(st reflect.Type) ([]int32, error) {
   273  	m, ok := reflect.Zero(reflect.PtrTo(st)).Interface().(proto.Message)
   274  	if !ok {
   275  		return nil, fmt.Errorf("failed to create message from type: %v", st)
   276  	}
   277  
   278  	//goland:noinspection GoDeprecation
   279  	exts := proto.RegisteredExtensions(m)
   280  	out := make([]int32, 0, len(exts))
   281  	for id := range exts {
   282  		out = append(out, id)
   283  	}
   284  	return out, nil
   285  }
   286  
   287  // fileDescWithDependencies returns a slice of serialized fileDescriptors in
   288  // wire format ([]byte). The fileDescriptors will include fd and all the
   289  // transitive dependencies of fd with names not in sentFileDescriptors.
   290  func fileDescWithDependencies(fd *dpb.FileDescriptorProto, sentFileDescriptors map[string]bool) ([][]byte, error) {
   291  	var r [][]byte
   292  	queue := []*dpb.FileDescriptorProto{fd}
   293  	for len(queue) > 0 {
   294  		currentfd := queue[0]
   295  		queue = queue[1:]
   296  		if sent := sentFileDescriptors[currentfd.GetName()]; len(r) == 0 || !sent {
   297  			sentFileDescriptors[currentfd.GetName()] = true
   298  			currentfdEncoded, err := proto.Marshal(currentfd)
   299  			if err != nil {
   300  				return nil, err
   301  			}
   302  			r = append(r, currentfdEncoded)
   303  		}
   304  		for _, dep := range currentfd.Dependency {
   305  			//goland:noinspection GoDeprecation
   306  			fdenc := proto.FileDescriptor(dep)
   307  			fdDep, err := decodeFileDesc(fdenc)
   308  			if err != nil {
   309  				continue
   310  			}
   311  			queue = append(queue, fdDep)
   312  		}
   313  	}
   314  	return r, nil
   315  }
   316  
   317  // fileDescEncodingByFilename finds the file descriptor for given filename,
   318  // finds all of its previously unsent transitive dependencies, does marshalling
   319  // on them, and returns the marshalled result.
   320  func (s *serverReflectionServer) fileDescEncodingByFilename(name string, sentFileDescriptors map[string]bool) ([][]byte, error) {
   321  	//goland:noinspection GoDeprecation
   322  	enc := proto.FileDescriptor(name)
   323  	if enc == nil {
   324  		return nil, fmt.Errorf("unknown file: %v", name)
   325  	}
   326  	fd, err := decodeFileDesc(enc)
   327  	if err != nil {
   328  		return nil, err
   329  	}
   330  	return fileDescWithDependencies(fd, sentFileDescriptors)
   331  }
   332  
   333  // parseMetadata finds the file descriptor bytes specified meta.
   334  // For SupportPackageIsVersion4, m is the name of the proto file, we
   335  // call proto.FileDescriptor to get the byte slice.
   336  // For SupportPackageIsVersion3, m is a byte slice itself.
   337  func parseMetadata(meta interface{}) ([]byte, bool) {
   338  	// Check if meta is the file name.
   339  	if fileNameForMeta, ok := meta.(string); ok {
   340  		//goland:noinspection GoDeprecation
   341  		return proto.FileDescriptor(fileNameForMeta), true
   342  	}
   343  
   344  	// Check if meta is the byte slice.
   345  	if enc, ok := meta.([]byte); ok {
   346  		return enc, true
   347  	}
   348  
   349  	return nil, false
   350  }
   351  
   352  // fileDescEncodingContainingSymbol finds the file descriptor containing the
   353  // given symbol, finds all of its previously unsent transitive dependencies,
   354  // does marshalling on them, and returns the marshalled result. The given symbol
   355  // can be a type, a service or a method.
   356  func (s *serverReflectionServer) fileDescEncodingContainingSymbol(name string, sentFileDescriptors map[string]bool) ([][]byte, error) {
   357  	_, symbols := s.getSymbols()
   358  	fd := symbols[name]
   359  	if fd == nil {
   360  		// Check if it's a type name that was not present in the
   361  		// transitive dependencies of the registered services.
   362  		if st, err := typeForName(name); err == nil {
   363  			fd, err = s.fileDescForType(st)
   364  			if err != nil {
   365  				return nil, err
   366  			}
   367  		}
   368  	}
   369  
   370  	if fd == nil {
   371  		return nil, fmt.Errorf("unknown symbol: %v", name)
   372  	}
   373  
   374  	return fileDescWithDependencies(fd, sentFileDescriptors)
   375  }
   376  
   377  // fileDescEncodingContainingExtension finds the file descriptor containing
   378  // given extension, finds all of its previously unsent transitive dependencies,
   379  // does marshalling on them, and returns the marshalled result.
   380  func (s *serverReflectionServer) fileDescEncodingContainingExtension(typeName string, extNum int32, sentFileDescriptors map[string]bool) ([][]byte, error) {
   381  	st, err := typeForName(typeName)
   382  	if err != nil {
   383  		return nil, err
   384  	}
   385  	fd, err := fileDescContainingExtension(st, extNum)
   386  	if err != nil {
   387  		return nil, err
   388  	}
   389  	return fileDescWithDependencies(fd, sentFileDescriptors)
   390  }
   391  
   392  // allExtensionNumbersForTypeName returns all extension numbers for the given type.
   393  func (s *serverReflectionServer) allExtensionNumbersForTypeName(name string) ([]int32, error) {
   394  	st, err := typeForName(name)
   395  	if err != nil {
   396  		return nil, err
   397  	}
   398  	extNums, err := s.allExtensionNumbersForType(st)
   399  	if err != nil {
   400  		return nil, err
   401  	}
   402  	return extNums, nil
   403  }
   404  
   405  // ServerReflectionInfo is the reflection service handler.
   406  func (s *serverReflectionServer) ServerReflectionInfo(stream rpb.ServerReflection_ServerReflectionInfoServer) error {
   407  	sentFileDescriptors := make(map[string]bool)
   408  	for {
   409  		in, err := stream.Recv()
   410  		if err == io.EOF {
   411  			return nil
   412  		}
   413  		if err != nil {
   414  			return err
   415  		}
   416  
   417  		out := &rpb.ServerReflectionResponse{
   418  			ValidHost:       in.Host,
   419  			OriginalRequest: in,
   420  		}
   421  		switch req := in.MessageRequest.(type) {
   422  		case *rpb.ServerReflectionRequest_FileByFilename:
   423  			b, err := s.fileDescEncodingByFilename(req.FileByFilename, sentFileDescriptors)
   424  			if err != nil {
   425  				out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{
   426  					ErrorResponse: &rpb.ErrorResponse{
   427  						ErrorCode:    int32(codes.NotFound),
   428  						ErrorMessage: err.Error(),
   429  					},
   430  				}
   431  			} else {
   432  				out.MessageResponse = &rpb.ServerReflectionResponse_FileDescriptorResponse{
   433  					FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: b},
   434  				}
   435  			}
   436  		case *rpb.ServerReflectionRequest_FileContainingSymbol:
   437  			b, err := s.fileDescEncodingContainingSymbol(req.FileContainingSymbol, sentFileDescriptors)
   438  			if err != nil {
   439  				out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{
   440  					ErrorResponse: &rpb.ErrorResponse{
   441  						ErrorCode:    int32(codes.NotFound),
   442  						ErrorMessage: err.Error(),
   443  					},
   444  				}
   445  			} else {
   446  				out.MessageResponse = &rpb.ServerReflectionResponse_FileDescriptorResponse{
   447  					FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: b},
   448  				}
   449  			}
   450  		case *rpb.ServerReflectionRequest_FileContainingExtension:
   451  			typeName := req.FileContainingExtension.ContainingType
   452  			extNum := req.FileContainingExtension.ExtensionNumber
   453  			b, err := s.fileDescEncodingContainingExtension(typeName, extNum, sentFileDescriptors)
   454  			if err != nil {
   455  				out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{
   456  					ErrorResponse: &rpb.ErrorResponse{
   457  						ErrorCode:    int32(codes.NotFound),
   458  						ErrorMessage: err.Error(),
   459  					},
   460  				}
   461  			} else {
   462  				out.MessageResponse = &rpb.ServerReflectionResponse_FileDescriptorResponse{
   463  					FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: b},
   464  				}
   465  			}
   466  		case *rpb.ServerReflectionRequest_AllExtensionNumbersOfType:
   467  			extNums, err := s.allExtensionNumbersForTypeName(req.AllExtensionNumbersOfType)
   468  			if err != nil {
   469  				out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{
   470  					ErrorResponse: &rpb.ErrorResponse{
   471  						ErrorCode:    int32(codes.NotFound),
   472  						ErrorMessage: err.Error(),
   473  					},
   474  				}
   475  			} else {
   476  				out.MessageResponse = &rpb.ServerReflectionResponse_AllExtensionNumbersResponse{
   477  					AllExtensionNumbersResponse: &rpb.ExtensionNumberResponse{
   478  						BaseTypeName:    req.AllExtensionNumbersOfType,
   479  						ExtensionNumber: extNums,
   480  					},
   481  				}
   482  			}
   483  		case *rpb.ServerReflectionRequest_ListServices:
   484  			svcNames, _ := s.getSymbols()
   485  			serviceResponses := make([]*rpb.ServiceResponse, len(svcNames))
   486  			for i, n := range svcNames {
   487  				serviceResponses[i] = &rpb.ServiceResponse{
   488  					Name: n,
   489  				}
   490  			}
   491  			out.MessageResponse = &rpb.ServerReflectionResponse_ListServicesResponse{
   492  				ListServicesResponse: &rpb.ListServiceResponse{
   493  					Service: serviceResponses,
   494  				},
   495  			}
   496  		default:
   497  			return status.Errorf(codes.InvalidArgument, "invalid MessageRequest: %v", in.MessageRequest)
   498  		}
   499  
   500  		if err := stream.Send(out); err != nil {
   501  			return err
   502  		}
   503  	}
   504  }