github.com/Finschia/finschia-sdk@v0.48.1/server/grpc/gogoreflection/serverreflection.go (about)

     1  /*
     2   *
     3   * Copyright 2016 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
/*
Package gogoreflection implements the gRPC server reflection service.

The service implemented is defined in:
https://github.com/grpc/grpc/blob/master/src/proto/grpc/reflection/v1alpha/reflection.proto.

To register server reflection on a gRPC server:

	import "github.com/Finschia/finschia-sdk/server/grpc/gogoreflection"

	s := grpc.NewServer()
	pb.RegisterYourOwnServer(s, &server{})

	// Register reflection service on gRPC server.
	gogoreflection.Register(s)

	s.Serve(lis)
*/
    37  package gogoreflection // import "google.golang.org/grpc/reflection"
    38  
    39  import (
    40  	"bytes"
    41  	"compress/gzip"
    42  	"fmt"
    43  	"io"
    44  	"log"
    45  	"reflect"
    46  	"sort"
    47  	"sync"
    48  
    49  	//nolint: staticcheck
    50  	"github.com/golang/protobuf/proto"
    51  	dpb "github.com/golang/protobuf/protoc-gen-go/descriptor"
    52  	"google.golang.org/grpc"
    53  	"google.golang.org/grpc/codes"
    54  	rpb "google.golang.org/grpc/reflection/grpc_reflection_v1alpha"
    55  	"google.golang.org/grpc/status"
    56  )
    57  
// serverReflectionServer implements the gRPC ServerReflection service for the
// services registered on a single *grpc.Server.
type serverReflectionServer struct {
	// Embedding supplies default (unimplemented) methods for the service
	// interface; ServerReflectionInfo is implemented in this file.
	rpb.UnimplementedServerReflectionServer
	// s is the server whose registered services are reported to clients.
	s *grpc.Server

	// initSymbols guards the lazy, one-time population of serviceNames and
	// symbols performed by getSymbols. serviceNames holds the sorted names of
	// the services registered on s.
	initSymbols  sync.Once
	serviceNames []string
	symbols      map[string]*dpb.FileDescriptorProto // map of fully-qualified names to files
}
    66  
    67  // Register registers the server reflection service on the given gRPC server.
    68  func Register(s *grpc.Server) {
    69  	rpb.RegisterServerReflectionServer(s, &serverReflectionServer{
    70  		s: s,
    71  	})
    72  }
    73  
    74  // protoMessage is used for type assertion on proto messages.
    75  // Generated proto message implements function Descriptor(), but Descriptor()
    76  // is not part of interface proto.Message. This interface is needed to
    77  // call Descriptor().
    78  type protoMessage interface {
    79  	Descriptor() ([]byte, []int)
    80  }
    81  
    82  func (s *serverReflectionServer) getSymbols() (svcNames []string, symbolIndex map[string]*dpb.FileDescriptorProto) {
    83  	s.initSymbols.Do(func() {
    84  		serviceInfo := s.s.GetServiceInfo()
    85  
    86  		s.symbols = map[string]*dpb.FileDescriptorProto{}
    87  		s.serviceNames = make([]string, 0, len(serviceInfo))
    88  		processed := map[string]struct{}{}
    89  		for svc, info := range serviceInfo {
    90  			s.serviceNames = append(s.serviceNames, svc)
    91  			fdenc, ok := parseMetadata(info.Metadata)
    92  			if !ok {
    93  				continue
    94  			}
    95  			fd, err := decodeFileDesc(fdenc)
    96  			if err != nil {
    97  				continue
    98  			}
    99  			s.processFile(fd, processed)
   100  		}
   101  		sort.Strings(s.serviceNames)
   102  	})
   103  
   104  	return s.serviceNames, s.symbols
   105  }
   106  
   107  func (s *serverReflectionServer) processFile(fd *dpb.FileDescriptorProto, processed map[string]struct{}) {
   108  	filename := fd.GetName()
   109  	if _, ok := processed[filename]; ok {
   110  		return
   111  	}
   112  	processed[filename] = struct{}{}
   113  
   114  	prefix := fd.GetPackage()
   115  
   116  	for _, msg := range fd.MessageType {
   117  		s.processMessage(fd, prefix, msg)
   118  	}
   119  	for _, en := range fd.EnumType {
   120  		s.processEnum(fd, prefix, en)
   121  	}
   122  	for _, ext := range fd.Extension {
   123  		s.processField(fd, prefix, ext)
   124  	}
   125  	for _, svc := range fd.Service {
   126  		svcName := fqn(prefix, svc.GetName())
   127  		s.symbols[svcName] = fd
   128  		for _, meth := range svc.Method {
   129  			name := fqn(svcName, meth.GetName())
   130  			s.symbols[name] = fd
   131  		}
   132  	}
   133  
   134  	for _, dep := range fd.Dependency {
   135  		fdenc := getFileDescriptor(dep)
   136  		fdDep, err := decodeFileDesc(fdenc)
   137  		if err != nil {
   138  			continue
   139  		}
   140  		s.processFile(fdDep, processed)
   141  	}
   142  }
   143  
   144  func (s *serverReflectionServer) processMessage(fd *dpb.FileDescriptorProto, prefix string, msg *dpb.DescriptorProto) {
   145  	msgName := fqn(prefix, msg.GetName())
   146  	s.symbols[msgName] = fd
   147  
   148  	for _, nested := range msg.NestedType {
   149  		s.processMessage(fd, msgName, nested)
   150  	}
   151  	for _, en := range msg.EnumType {
   152  		s.processEnum(fd, msgName, en)
   153  	}
   154  	for _, ext := range msg.Extension {
   155  		s.processField(fd, msgName, ext)
   156  	}
   157  	for _, fld := range msg.Field {
   158  		s.processField(fd, msgName, fld)
   159  	}
   160  	for _, oneof := range msg.OneofDecl {
   161  		oneofName := fqn(msgName, oneof.GetName())
   162  		s.symbols[oneofName] = fd
   163  	}
   164  }
   165  
   166  func (s *serverReflectionServer) processEnum(fd *dpb.FileDescriptorProto, prefix string, en *dpb.EnumDescriptorProto) {
   167  	enName := fqn(prefix, en.GetName())
   168  	s.symbols[enName] = fd
   169  
   170  	for _, val := range en.Value {
   171  		valName := fqn(enName, val.GetName())
   172  		s.symbols[valName] = fd
   173  	}
   174  }
   175  
   176  func (s *serverReflectionServer) processField(fd *dpb.FileDescriptorProto, prefix string, fld *dpb.FieldDescriptorProto) {
   177  	fldName := fqn(prefix, fld.GetName())
   178  	s.symbols[fldName] = fd
   179  }
   180  
   181  func fqn(prefix, name string) string {
   182  	if prefix == "" {
   183  		return name
   184  	}
   185  	return prefix + "." + name
   186  }
   187  
   188  // fileDescForType gets the file descriptor for the given type.
   189  // The given type should be a proto message.
   190  func (s *serverReflectionServer) fileDescForType(st reflect.Type) (*dpb.FileDescriptorProto, error) {
   191  	m, ok := reflect.Zero(reflect.PtrTo(st)).Interface().(protoMessage)
   192  	if !ok {
   193  		return nil, fmt.Errorf("failed to create message from type: %v", st)
   194  	}
   195  	enc, _ := m.Descriptor()
   196  
   197  	return decodeFileDesc(enc)
   198  }
   199  
   200  // decodeFileDesc does decompression and unmarshalling on the given
   201  // file descriptor byte slice.
   202  func decodeFileDesc(enc []byte) (*dpb.FileDescriptorProto, error) {
   203  	raw, err := decompress(enc)
   204  	if err != nil {
   205  		return nil, fmt.Errorf("failed to decompress enc: %v", err)
   206  	}
   207  
   208  	fd := new(dpb.FileDescriptorProto)
   209  	if err := proto.Unmarshal(raw, fd); err != nil {
   210  		return nil, fmt.Errorf("bad descriptor: %v", err)
   211  	}
   212  	return fd, nil
   213  }
   214  
   215  // decompress does gzip decompression.
   216  func decompress(b []byte) ([]byte, error) {
   217  	r, err := gzip.NewReader(bytes.NewReader(b))
   218  	if err != nil {
   219  		return nil, fmt.Errorf("bad gzipped descriptor: %v", err)
   220  	}
   221  	out, err := io.ReadAll(r)
   222  	if err != nil {
   223  		return nil, fmt.Errorf("bad gzipped descriptor: %v", err)
   224  	}
   225  	return out, nil
   226  }
   227  
   228  func typeForName(name string) (reflect.Type, error) {
   229  	pt := getMessageType(name)
   230  	if pt == nil {
   231  		return nil, fmt.Errorf("unknown type: %q", name)
   232  	}
   233  	st := pt.Elem()
   234  
   235  	return st, nil
   236  }
   237  
   238  func fileDescContainingExtension(st reflect.Type, ext int32) (*dpb.FileDescriptorProto, error) {
   239  	m, ok := reflect.Zero(reflect.PtrTo(st)).Interface().(proto.Message)
   240  	if !ok {
   241  		return nil, fmt.Errorf("failed to create message from type: %v", st)
   242  	}
   243  
   244  	extDesc := getExtension(ext, m)
   245  
   246  	if extDesc == nil {
   247  		return nil, fmt.Errorf("failed to find registered extension for extension number %v", ext)
   248  	}
   249  
   250  	return decodeFileDesc(getFileDescriptor(extDesc.Filename))
   251  }
   252  
   253  func (s *serverReflectionServer) allExtensionNumbersForType(st reflect.Type) ([]int32, error) {
   254  	m, ok := reflect.Zero(reflect.PtrTo(st)).Interface().(proto.Message)
   255  	if !ok {
   256  		return nil, fmt.Errorf("failed to create message from type: %v", st)
   257  	}
   258  
   259  	out := getExtensionsNumbers(m)
   260  	return out, nil
   261  }
   262  
   263  // fileDescWithDependencies returns a slice of serialized fileDescriptors in
   264  // wire format ([]byte). The fileDescriptors will include fd and all the
   265  // transitive dependencies of fd with names not in sentFileDescriptors.
   266  func fileDescWithDependencies(fd *dpb.FileDescriptorProto, sentFileDescriptors map[string]bool) ([][]byte, error) {
   267  	r := [][]byte{}
   268  	queue := []*dpb.FileDescriptorProto{fd}
   269  	for len(queue) > 0 {
   270  		currentfd := queue[0]
   271  		queue = queue[1:]
   272  		if sent := sentFileDescriptors[currentfd.GetName()]; len(r) == 0 || !sent {
   273  			sentFileDescriptors[currentfd.GetName()] = true
   274  			currentfdEncoded, err := proto.Marshal(currentfd)
   275  			if err != nil {
   276  				return nil, err
   277  			}
   278  			r = append(r, currentfdEncoded)
   279  		}
   280  		for _, dep := range currentfd.Dependency {
   281  			fdenc := getFileDescriptor(dep)
   282  			fdDep, err := decodeFileDesc(fdenc)
   283  			if err != nil {
   284  				continue
   285  			}
   286  			queue = append(queue, fdDep)
   287  		}
   288  	}
   289  	return r, nil
   290  }
   291  
   292  // fileDescEncodingByFilename finds the file descriptor for given filename,
   293  // finds all of its previously unsent transitive dependencies, does marshalling
   294  // on them, and returns the marshalled result.
   295  func (s *serverReflectionServer) fileDescEncodingByFilename(name string, sentFileDescriptors map[string]bool) ([][]byte, error) {
   296  	enc := getFileDescriptor(name)
   297  	if enc == nil {
   298  		return nil, fmt.Errorf("unknown file: %v", name)
   299  	}
   300  	fd, err := decodeFileDesc(enc)
   301  	if err != nil {
   302  		return nil, err
   303  	}
   304  	return fileDescWithDependencies(fd, sentFileDescriptors)
   305  }
   306  
   307  // parseMetadata finds the file descriptor bytes specified meta.
   308  // For SupportPackageIsVersion4, m is the name of the proto file, we
   309  // call proto.FileDescriptor to get the byte slice.
   310  // For SupportPackageIsVersion3, m is a byte slice itself.
   311  func parseMetadata(meta interface{}) ([]byte, bool) {
   312  	// Check if meta is the file name.
   313  	if fileNameForMeta, ok := meta.(string); ok {
   314  		return getFileDescriptor(fileNameForMeta), true
   315  	}
   316  
   317  	// Check if meta is the byte slice.
   318  	if enc, ok := meta.([]byte); ok {
   319  		return enc, true
   320  	}
   321  
   322  	return nil, false
   323  }
   324  
   325  // fileDescEncodingContainingSymbol finds the file descriptor containing the
   326  // given symbol, finds all of its previously unsent transitive dependencies,
   327  // does marshalling on them, and returns the marshalled result. The given symbol
   328  // can be a type, a service or a method.
   329  func (s *serverReflectionServer) fileDescEncodingContainingSymbol(name string, sentFileDescriptors map[string]bool) ([][]byte, error) {
   330  	_, symbols := s.getSymbols()
   331  	fd := symbols[name]
   332  	if fd == nil {
   333  		// Check if it's a type name that was not present in the
   334  		// transitive dependencies of the registered services.
   335  		if st, err := typeForName(name); err == nil {
   336  			fd, err = s.fileDescForType(st)
   337  			if err != nil {
   338  				return nil, err
   339  			}
   340  		}
   341  	}
   342  
   343  	if fd == nil {
   344  		return nil, fmt.Errorf("unknown symbol: %v", name)
   345  	}
   346  
   347  	return fileDescWithDependencies(fd, sentFileDescriptors)
   348  }
   349  
   350  // fileDescEncodingContainingExtension finds the file descriptor containing
   351  // given extension, finds all of its previously unsent transitive dependencies,
   352  // does marshalling on them, and returns the marshalled result.
   353  func (s *serverReflectionServer) fileDescEncodingContainingExtension(typeName string, extNum int32, sentFileDescriptors map[string]bool) ([][]byte, error) {
   354  	st, err := typeForName(typeName)
   355  	if err != nil {
   356  		return nil, err
   357  	}
   358  	fd, err := fileDescContainingExtension(st, extNum)
   359  	if err != nil {
   360  		return nil, err
   361  	}
   362  	return fileDescWithDependencies(fd, sentFileDescriptors)
   363  }
   364  
   365  // allExtensionNumbersForTypeName returns all extension numbers for the given type.
   366  func (s *serverReflectionServer) allExtensionNumbersForTypeName(name string) ([]int32, error) {
   367  	st, err := typeForName(name)
   368  	if err != nil {
   369  		return nil, err
   370  	}
   371  	extNums, err := s.allExtensionNumbersForType(st)
   372  	if err != nil {
   373  		return nil, err
   374  	}
   375  	return extNums, nil
   376  }
   377  
   378  // ServerReflectionInfo is the reflection service handler.
   379  func (s *serverReflectionServer) ServerReflectionInfo(stream rpb.ServerReflection_ServerReflectionInfoServer) error {
   380  	sentFileDescriptors := make(map[string]bool)
   381  	for {
   382  		in, err := stream.Recv()
   383  		if err == io.EOF {
   384  			return nil
   385  		}
   386  		if err != nil {
   387  			return err
   388  		}
   389  
   390  		out := &rpb.ServerReflectionResponse{
   391  			ValidHost:       in.Host,
   392  			OriginalRequest: in,
   393  		}
   394  		switch req := in.MessageRequest.(type) {
   395  		case *rpb.ServerReflectionRequest_FileByFilename:
   396  			b, err := s.fileDescEncodingByFilename(req.FileByFilename, sentFileDescriptors)
   397  			if err != nil {
   398  				out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{
   399  					ErrorResponse: &rpb.ErrorResponse{
   400  						ErrorCode:    int32(codes.NotFound),
   401  						ErrorMessage: err.Error(),
   402  					},
   403  				}
   404  			} else {
   405  				out.MessageResponse = &rpb.ServerReflectionResponse_FileDescriptorResponse{
   406  					FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: b},
   407  				}
   408  			}
   409  		case *rpb.ServerReflectionRequest_FileContainingSymbol:
   410  			b, err := s.fileDescEncodingContainingSymbol(req.FileContainingSymbol, sentFileDescriptors)
   411  			if err != nil {
   412  				out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{
   413  					ErrorResponse: &rpb.ErrorResponse{
   414  						ErrorCode:    int32(codes.NotFound),
   415  						ErrorMessage: err.Error(),
   416  					},
   417  				}
   418  			} else {
   419  				out.MessageResponse = &rpb.ServerReflectionResponse_FileDescriptorResponse{
   420  					FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: b},
   421  				}
   422  			}
   423  		case *rpb.ServerReflectionRequest_FileContainingExtension:
   424  			typeName := req.FileContainingExtension.ContainingType
   425  			extNum := req.FileContainingExtension.ExtensionNumber
   426  			b, err := s.fileDescEncodingContainingExtension(typeName, extNum, sentFileDescriptors)
   427  			if err != nil {
   428  				out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{
   429  					ErrorResponse: &rpb.ErrorResponse{
   430  						ErrorCode:    int32(codes.NotFound),
   431  						ErrorMessage: err.Error(),
   432  					},
   433  				}
   434  			} else {
   435  				out.MessageResponse = &rpb.ServerReflectionResponse_FileDescriptorResponse{
   436  					FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: b},
   437  				}
   438  			}
   439  		case *rpb.ServerReflectionRequest_AllExtensionNumbersOfType:
   440  			extNums, err := s.allExtensionNumbersForTypeName(req.AllExtensionNumbersOfType)
   441  			if err != nil {
   442  				out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{
   443  					ErrorResponse: &rpb.ErrorResponse{
   444  						ErrorCode:    int32(codes.NotFound),
   445  						ErrorMessage: err.Error(),
   446  					},
   447  				}
   448  				log.Printf("OH NO: %s", err)
   449  			} else {
   450  				out.MessageResponse = &rpb.ServerReflectionResponse_AllExtensionNumbersResponse{
   451  					AllExtensionNumbersResponse: &rpb.ExtensionNumberResponse{
   452  						BaseTypeName:    req.AllExtensionNumbersOfType,
   453  						ExtensionNumber: extNums,
   454  					},
   455  				}
   456  			}
   457  		case *rpb.ServerReflectionRequest_ListServices:
   458  			svcNames, _ := s.getSymbols()
   459  			serviceResponses := make([]*rpb.ServiceResponse, len(svcNames))
   460  			for i, n := range svcNames {
   461  				serviceResponses[i] = &rpb.ServiceResponse{
   462  					Name: n,
   463  				}
   464  			}
   465  			out.MessageResponse = &rpb.ServerReflectionResponse_ListServicesResponse{
   466  				ListServicesResponse: &rpb.ListServiceResponse{
   467  					Service: serviceResponses,
   468  				},
   469  			}
   470  		default:
   471  			return status.Errorf(codes.InvalidArgument, "invalid MessageRequest: %v", in.MessageRequest)
   472  		}
   473  		if err := stream.Send(out); err != nil {
   474  			return err
   475  		}
   476  	}
   477  }