github.com/hxx258456/ccgo@v0.0.5-0.20230213014102-48b35f46f66f/grpc/reflection/serverreflection.go (about)

     1  /*
     2   *
     3   * Copyright 2016 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  /*
    20  Package reflection implements server reflection service.
    21  
    22  The service implemented is defined in:
    23  https://github.com/grpc/grpc/blob/master/src/proto/grpc/reflection/v1alpha/reflection.proto.
    24  
    25  To register server reflection on a gRPC server:
    26  	import "github.com/hxx258456/ccgo/grpc/reflection"
    27  
    28  	s := grpc.NewServer()
    29  	pb.RegisterYourOwnServer(s, &server{})
    30  
    31  	// Register reflection service on gRPC server.
    32  	reflection.Register(s)
    33  
    34  	s.Serve(lis)
    35  
    36  */
    37  package reflection // import "github.com/hxx258456/ccgo/grpc/reflection"
    38  
    39  import (
    40  	"bytes"
    41  	"compress/gzip"
    42  	"fmt"
    43  	"io"
    44  	"io/ioutil"
    45  	"reflect"
    46  	"sort"
    47  	"sync"
    48  
    49  	"github.com/golang/protobuf/proto"
    50  	dpb "github.com/golang/protobuf/protoc-gen-go/descriptor"
    51  	grpc "github.com/hxx258456/ccgo/grpc"
    52  	"github.com/hxx258456/ccgo/grpc/codes"
    53  	rpb "github.com/hxx258456/ccgo/grpc/reflection/grpc_reflection_v1alpha"
    54  	"github.com/hxx258456/ccgo/grpc/status"
    55  )
    56  
// GRPCServer is the interface provided by a gRPC server. It is implemented by
// *grpc.Server, but could also be implemented by other concrete types. It acts
// as a registry, for accumulating the services exposed by the server.
//
// Register uses both halves of this interface: ServiceRegistrar to install
// the reflection service itself, and GetServiceInfo to enumerate the services
// (and their proto metadata) that the reflection service describes.
type GRPCServer interface {
	grpc.ServiceRegistrar
	// GetServiceInfo returns the registered services; the map keys are
	// reported verbatim as service names by the ListServices request.
	GetServiceInfo() map[string]grpc.ServiceInfo
}

// Compile-time check that *grpc.Server satisfies GRPCServer.
var _ GRPCServer = (*grpc.Server)(nil)
    66  
// serverReflectionServer implements the v1alpha ServerReflection service for
// the services registered on a GRPCServer.
type serverReflectionServer struct {
	rpb.UnimplementedServerReflectionServer
	// s is the server whose registered services are described.
	s GRPCServer

	// initSymbols makes getSymbols populate serviceNames and symbols once.
	initSymbols  sync.Once
	serviceNames []string                            // sorted names of the services registered on s
	symbols      map[string]*dpb.FileDescriptorProto // map of fully-qualified names to files
}
    75  
    76  // Register registers the server reflection service on the given gRPC server.
    77  func Register(s GRPCServer) {
    78  	rpb.RegisterServerReflectionServer(s, &serverReflectionServer{
    79  		s: s,
    80  	})
    81  }
    82  
// protoMessage is used for type assertion on proto messages.
// Generated proto messages implement Descriptor(), but Descriptor() is not
// part of the proto.Message interface. This interface lets the code call
// Descriptor() on a message obtained through reflection.
type protoMessage interface {
	// Descriptor returns the gzip-compressed FileDescriptorProto of the file
	// declaring the message (it is fed to decodeFileDesc), plus an []int that
	// is presumably the message's index path within that file — unused here.
	Descriptor() ([]byte, []int)
}
    90  
    91  func (s *serverReflectionServer) getSymbols() (svcNames []string, symbolIndex map[string]*dpb.FileDescriptorProto) {
    92  	s.initSymbols.Do(func() {
    93  		serviceInfo := s.s.GetServiceInfo()
    94  
    95  		s.symbols = map[string]*dpb.FileDescriptorProto{}
    96  		s.serviceNames = make([]string, 0, len(serviceInfo))
    97  		processed := map[string]struct{}{}
    98  		for svc, info := range serviceInfo {
    99  			s.serviceNames = append(s.serviceNames, svc)
   100  			fdenc, ok := parseMetadata(info.Metadata)
   101  			if !ok {
   102  				continue
   103  			}
   104  			fd, err := decodeFileDesc(fdenc)
   105  			if err != nil {
   106  				continue
   107  			}
   108  			s.processFile(fd, processed)
   109  		}
   110  		sort.Strings(s.serviceNames)
   111  	})
   112  
   113  	return s.serviceNames, s.symbols
   114  }
   115  
   116  func (s *serverReflectionServer) processFile(fd *dpb.FileDescriptorProto, processed map[string]struct{}) {
   117  	filename := fd.GetName()
   118  	if _, ok := processed[filename]; ok {
   119  		return
   120  	}
   121  	processed[filename] = struct{}{}
   122  
   123  	prefix := fd.GetPackage()
   124  
   125  	for _, msg := range fd.MessageType {
   126  		s.processMessage(fd, prefix, msg)
   127  	}
   128  	for _, en := range fd.EnumType {
   129  		s.processEnum(fd, prefix, en)
   130  	}
   131  	for _, ext := range fd.Extension {
   132  		s.processField(fd, prefix, ext)
   133  	}
   134  	for _, svc := range fd.Service {
   135  		svcName := fqn(prefix, svc.GetName())
   136  		s.symbols[svcName] = fd
   137  		for _, meth := range svc.Method {
   138  			name := fqn(svcName, meth.GetName())
   139  			s.symbols[name] = fd
   140  		}
   141  	}
   142  
   143  	for _, dep := range fd.Dependency {
   144  		fdenc := proto.FileDescriptor(dep)
   145  		fdDep, err := decodeFileDesc(fdenc)
   146  		if err != nil {
   147  			continue
   148  		}
   149  		s.processFile(fdDep, processed)
   150  	}
   151  }
   152  
   153  func (s *serverReflectionServer) processMessage(fd *dpb.FileDescriptorProto, prefix string, msg *dpb.DescriptorProto) {
   154  	msgName := fqn(prefix, msg.GetName())
   155  	s.symbols[msgName] = fd
   156  
   157  	for _, nested := range msg.NestedType {
   158  		s.processMessage(fd, msgName, nested)
   159  	}
   160  	for _, en := range msg.EnumType {
   161  		s.processEnum(fd, msgName, en)
   162  	}
   163  	for _, ext := range msg.Extension {
   164  		s.processField(fd, msgName, ext)
   165  	}
   166  	for _, fld := range msg.Field {
   167  		s.processField(fd, msgName, fld)
   168  	}
   169  	for _, oneof := range msg.OneofDecl {
   170  		oneofName := fqn(msgName, oneof.GetName())
   171  		s.symbols[oneofName] = fd
   172  	}
   173  }
   174  
   175  func (s *serverReflectionServer) processEnum(fd *dpb.FileDescriptorProto, prefix string, en *dpb.EnumDescriptorProto) {
   176  	enName := fqn(prefix, en.GetName())
   177  	s.symbols[enName] = fd
   178  
   179  	for _, val := range en.Value {
   180  		valName := fqn(enName, val.GetName())
   181  		s.symbols[valName] = fd
   182  	}
   183  }
   184  
   185  func (s *serverReflectionServer) processField(fd *dpb.FileDescriptorProto, prefix string, fld *dpb.FieldDescriptorProto) {
   186  	fldName := fqn(prefix, fld.GetName())
   187  	s.symbols[fldName] = fd
   188  }
   189  
   190  func fqn(prefix, name string) string {
   191  	if prefix == "" {
   192  		return name
   193  	}
   194  	return prefix + "." + name
   195  }
   196  
   197  // fileDescForType gets the file descriptor for the given type.
   198  // The given type should be a proto message.
   199  func (s *serverReflectionServer) fileDescForType(st reflect.Type) (*dpb.FileDescriptorProto, error) {
   200  	m, ok := reflect.Zero(reflect.PtrTo(st)).Interface().(protoMessage)
   201  	if !ok {
   202  		return nil, fmt.Errorf("failed to create message from type: %v", st)
   203  	}
   204  	enc, _ := m.Descriptor()
   205  
   206  	return decodeFileDesc(enc)
   207  }
   208  
   209  // decodeFileDesc does decompression and unmarshalling on the given
   210  // file descriptor byte slice.
   211  func decodeFileDesc(enc []byte) (*dpb.FileDescriptorProto, error) {
   212  	raw, err := decompress(enc)
   213  	if err != nil {
   214  		return nil, fmt.Errorf("failed to decompress enc: %v", err)
   215  	}
   216  
   217  	fd := new(dpb.FileDescriptorProto)
   218  	if err := proto.Unmarshal(raw, fd); err != nil {
   219  		return nil, fmt.Errorf("bad descriptor: %v", err)
   220  	}
   221  	return fd, nil
   222  }
   223  
   224  // decompress does gzip decompression.
   225  func decompress(b []byte) ([]byte, error) {
   226  	r, err := gzip.NewReader(bytes.NewReader(b))
   227  	if err != nil {
   228  		return nil, fmt.Errorf("bad gzipped descriptor: %v", err)
   229  	}
   230  	out, err := ioutil.ReadAll(r)
   231  	if err != nil {
   232  		return nil, fmt.Errorf("bad gzipped descriptor: %v", err)
   233  	}
   234  	return out, nil
   235  }
   236  
   237  func typeForName(name string) (reflect.Type, error) {
   238  	pt := proto.MessageType(name)
   239  	if pt == nil {
   240  		return nil, fmt.Errorf("unknown type: %q", name)
   241  	}
   242  	st := pt.Elem()
   243  
   244  	return st, nil
   245  }
   246  
   247  func fileDescContainingExtension(st reflect.Type, ext int32) (*dpb.FileDescriptorProto, error) {
   248  	m, ok := reflect.Zero(reflect.PtrTo(st)).Interface().(proto.Message)
   249  	if !ok {
   250  		return nil, fmt.Errorf("failed to create message from type: %v", st)
   251  	}
   252  
   253  	var extDesc *proto.ExtensionDesc
   254  	for id, desc := range proto.RegisteredExtensions(m) {
   255  		if id == ext {
   256  			extDesc = desc
   257  			break
   258  		}
   259  	}
   260  
   261  	if extDesc == nil {
   262  		return nil, fmt.Errorf("failed to find registered extension for extension number %v", ext)
   263  	}
   264  
   265  	return decodeFileDesc(proto.FileDescriptor(extDesc.Filename))
   266  }
   267  
   268  func (s *serverReflectionServer) allExtensionNumbersForType(st reflect.Type) ([]int32, error) {
   269  	m, ok := reflect.Zero(reflect.PtrTo(st)).Interface().(proto.Message)
   270  	if !ok {
   271  		return nil, fmt.Errorf("failed to create message from type: %v", st)
   272  	}
   273  
   274  	exts := proto.RegisteredExtensions(m)
   275  	out := make([]int32, 0, len(exts))
   276  	for id := range exts {
   277  		out = append(out, id)
   278  	}
   279  	return out, nil
   280  }
   281  
   282  // fileDescWithDependencies returns a slice of serialized fileDescriptors in
   283  // wire format ([]byte). The fileDescriptors will include fd and all the
   284  // transitive dependencies of fd with names not in sentFileDescriptors.
   285  func fileDescWithDependencies(fd *dpb.FileDescriptorProto, sentFileDescriptors map[string]bool) ([][]byte, error) {
   286  	r := [][]byte{}
   287  	queue := []*dpb.FileDescriptorProto{fd}
   288  	for len(queue) > 0 {
   289  		currentfd := queue[0]
   290  		queue = queue[1:]
   291  		if sent := sentFileDescriptors[currentfd.GetName()]; len(r) == 0 || !sent {
   292  			sentFileDescriptors[currentfd.GetName()] = true
   293  			currentfdEncoded, err := proto.Marshal(currentfd)
   294  			if err != nil {
   295  				return nil, err
   296  			}
   297  			r = append(r, currentfdEncoded)
   298  		}
   299  		for _, dep := range currentfd.Dependency {
   300  			fdenc := proto.FileDescriptor(dep)
   301  			fdDep, err := decodeFileDesc(fdenc)
   302  			if err != nil {
   303  				continue
   304  			}
   305  			queue = append(queue, fdDep)
   306  		}
   307  	}
   308  	return r, nil
   309  }
   310  
   311  // fileDescEncodingByFilename finds the file descriptor for given filename,
   312  // finds all of its previously unsent transitive dependencies, does marshalling
   313  // on them, and returns the marshalled result.
   314  func (s *serverReflectionServer) fileDescEncodingByFilename(name string, sentFileDescriptors map[string]bool) ([][]byte, error) {
   315  	enc := proto.FileDescriptor(name)
   316  	if enc == nil {
   317  		return nil, fmt.Errorf("unknown file: %v", name)
   318  	}
   319  	fd, err := decodeFileDesc(enc)
   320  	if err != nil {
   321  		return nil, err
   322  	}
   323  	return fileDescWithDependencies(fd, sentFileDescriptors)
   324  }
   325  
   326  // parseMetadata finds the file descriptor bytes specified meta.
   327  // For SupportPackageIsVersion4, m is the name of the proto file, we
   328  // call proto.FileDescriptor to get the byte slice.
   329  // For SupportPackageIsVersion3, m is a byte slice itself.
   330  func parseMetadata(meta interface{}) ([]byte, bool) {
   331  	// Check if meta is the file name.
   332  	if fileNameForMeta, ok := meta.(string); ok {
   333  		return proto.FileDescriptor(fileNameForMeta), true
   334  	}
   335  
   336  	// Check if meta is the byte slice.
   337  	if enc, ok := meta.([]byte); ok {
   338  		return enc, true
   339  	}
   340  
   341  	return nil, false
   342  }
   343  
   344  // fileDescEncodingContainingSymbol finds the file descriptor containing the
   345  // given symbol, finds all of its previously unsent transitive dependencies,
   346  // does marshalling on them, and returns the marshalled result. The given symbol
   347  // can be a type, a service or a method.
   348  func (s *serverReflectionServer) fileDescEncodingContainingSymbol(name string, sentFileDescriptors map[string]bool) ([][]byte, error) {
   349  	_, symbols := s.getSymbols()
   350  	fd := symbols[name]
   351  	if fd == nil {
   352  		// Check if it's a type name that was not present in the
   353  		// transitive dependencies of the registered services.
   354  		if st, err := typeForName(name); err == nil {
   355  			fd, err = s.fileDescForType(st)
   356  			if err != nil {
   357  				return nil, err
   358  			}
   359  		}
   360  	}
   361  
   362  	if fd == nil {
   363  		return nil, fmt.Errorf("unknown symbol: %v", name)
   364  	}
   365  
   366  	return fileDescWithDependencies(fd, sentFileDescriptors)
   367  }
   368  
   369  // fileDescEncodingContainingExtension finds the file descriptor containing
   370  // given extension, finds all of its previously unsent transitive dependencies,
   371  // does marshalling on them, and returns the marshalled result.
   372  func (s *serverReflectionServer) fileDescEncodingContainingExtension(typeName string, extNum int32, sentFileDescriptors map[string]bool) ([][]byte, error) {
   373  	st, err := typeForName(typeName)
   374  	if err != nil {
   375  		return nil, err
   376  	}
   377  	fd, err := fileDescContainingExtension(st, extNum)
   378  	if err != nil {
   379  		return nil, err
   380  	}
   381  	return fileDescWithDependencies(fd, sentFileDescriptors)
   382  }
   383  
   384  // allExtensionNumbersForTypeName returns all extension numbers for the given type.
   385  func (s *serverReflectionServer) allExtensionNumbersForTypeName(name string) ([]int32, error) {
   386  	st, err := typeForName(name)
   387  	if err != nil {
   388  		return nil, err
   389  	}
   390  	extNums, err := s.allExtensionNumbersForType(st)
   391  	if err != nil {
   392  		return nil, err
   393  	}
   394  	return extNums, nil
   395  }
   396  
   397  // ServerReflectionInfo is the reflection service handler.
   398  func (s *serverReflectionServer) ServerReflectionInfo(stream rpb.ServerReflection_ServerReflectionInfoServer) error {
   399  	sentFileDescriptors := make(map[string]bool)
   400  	for {
   401  		in, err := stream.Recv()
   402  		if err == io.EOF {
   403  			return nil
   404  		}
   405  		if err != nil {
   406  			return err
   407  		}
   408  
   409  		out := &rpb.ServerReflectionResponse{
   410  			ValidHost:       in.Host,
   411  			OriginalRequest: in,
   412  		}
   413  		switch req := in.MessageRequest.(type) {
   414  		case *rpb.ServerReflectionRequest_FileByFilename:
   415  			b, err := s.fileDescEncodingByFilename(req.FileByFilename, sentFileDescriptors)
   416  			if err != nil {
   417  				out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{
   418  					ErrorResponse: &rpb.ErrorResponse{
   419  						ErrorCode:    int32(codes.NotFound),
   420  						ErrorMessage: err.Error(),
   421  					},
   422  				}
   423  			} else {
   424  				out.MessageResponse = &rpb.ServerReflectionResponse_FileDescriptorResponse{
   425  					FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: b},
   426  				}
   427  			}
   428  		case *rpb.ServerReflectionRequest_FileContainingSymbol:
   429  			b, err := s.fileDescEncodingContainingSymbol(req.FileContainingSymbol, sentFileDescriptors)
   430  			if err != nil {
   431  				out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{
   432  					ErrorResponse: &rpb.ErrorResponse{
   433  						ErrorCode:    int32(codes.NotFound),
   434  						ErrorMessage: err.Error(),
   435  					},
   436  				}
   437  			} else {
   438  				out.MessageResponse = &rpb.ServerReflectionResponse_FileDescriptorResponse{
   439  					FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: b},
   440  				}
   441  			}
   442  		case *rpb.ServerReflectionRequest_FileContainingExtension:
   443  			typeName := req.FileContainingExtension.ContainingType
   444  			extNum := req.FileContainingExtension.ExtensionNumber
   445  			b, err := s.fileDescEncodingContainingExtension(typeName, extNum, sentFileDescriptors)
   446  			if err != nil {
   447  				out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{
   448  					ErrorResponse: &rpb.ErrorResponse{
   449  						ErrorCode:    int32(codes.NotFound),
   450  						ErrorMessage: err.Error(),
   451  					},
   452  				}
   453  			} else {
   454  				out.MessageResponse = &rpb.ServerReflectionResponse_FileDescriptorResponse{
   455  					FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: b},
   456  				}
   457  			}
   458  		case *rpb.ServerReflectionRequest_AllExtensionNumbersOfType:
   459  			extNums, err := s.allExtensionNumbersForTypeName(req.AllExtensionNumbersOfType)
   460  			if err != nil {
   461  				out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{
   462  					ErrorResponse: &rpb.ErrorResponse{
   463  						ErrorCode:    int32(codes.NotFound),
   464  						ErrorMessage: err.Error(),
   465  					},
   466  				}
   467  			} else {
   468  				out.MessageResponse = &rpb.ServerReflectionResponse_AllExtensionNumbersResponse{
   469  					AllExtensionNumbersResponse: &rpb.ExtensionNumberResponse{
   470  						BaseTypeName:    req.AllExtensionNumbersOfType,
   471  						ExtensionNumber: extNums,
   472  					},
   473  				}
   474  			}
   475  		case *rpb.ServerReflectionRequest_ListServices:
   476  			svcNames, _ := s.getSymbols()
   477  			serviceResponses := make([]*rpb.ServiceResponse, len(svcNames))
   478  			for i, n := range svcNames {
   479  				serviceResponses[i] = &rpb.ServiceResponse{
   480  					Name: n,
   481  				}
   482  			}
   483  			out.MessageResponse = &rpb.ServerReflectionResponse_ListServicesResponse{
   484  				ListServicesResponse: &rpb.ListServiceResponse{
   485  					Service: serviceResponses,
   486  				},
   487  			}
   488  		default:
   489  			return status.Errorf(codes.InvalidArgument, "invalid MessageRequest: %v", in.MessageRequest)
   490  		}
   491  
   492  		if err := stream.Send(out); err != nil {
   493  			return err
   494  		}
   495  	}
   496  }