github.com/Finschia/finschia-sdk@v0.48.1/server/grpc/gogoreflection/serverreflection.go (about) 1 /* 2 * 3 * Copyright 2016 gRPC authors. 4 * 5 * Licensed under the Apache License, Version 2.0 (the "License"); 6 * you may not use this file except in compliance with the License. 7 * You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 * 17 */ 18 19 /* 20 Package reflection implements server reflection service. 21 22 The service implemented is defined in: 23 https://github.com/grpc/grpc/blob/master/src/proto/grpc/reflection/v1alpha/reflection.proto. 24 25 To register server reflection on a gRPC server: 26 27 import "google.golang.org/grpc/reflection" 28 29 s := grpc.NewServer() 30 pb.RegisterYourOwnServer(s, &server{}) 31 32 // Register reflection service on gRPC server. 33 reflection.Register(s) 34 35 s.Serve(lis) 36 */ 37 package gogoreflection // import "google.golang.org/grpc/reflection" 38 39 import ( 40 "bytes" 41 "compress/gzip" 42 "fmt" 43 "io" 44 "log" 45 "reflect" 46 "sort" 47 "sync" 48 49 //nolint: staticcheck 50 "github.com/golang/protobuf/proto" 51 dpb "github.com/golang/protobuf/protoc-gen-go/descriptor" 52 "google.golang.org/grpc" 53 "google.golang.org/grpc/codes" 54 rpb "google.golang.org/grpc/reflection/grpc_reflection_v1alpha" 55 "google.golang.org/grpc/status" 56 ) 57 58 type serverReflectionServer struct { 59 rpb.UnimplementedServerReflectionServer 60 s *grpc.Server 61 62 initSymbols sync.Once 63 serviceNames []string 64 symbols map[string]*dpb.FileDescriptorProto // map of fully-qualified names to files 65 } 66 67 // Register registers the server reflection service on the given gRPC server. 68 func Register(s *grpc.Server) { 69 rpb.RegisterServerReflectionServer(s, &serverReflectionServer{ 70 s: s, 71 }) 72 } 73 74 // protoMessage is used for type assertion on proto messages. 75 // Generated proto message implements function Descriptor(), but Descriptor() 76 // is not part of interface proto.Message. This interface is needed to 77 // call Descriptor(). 78 type protoMessage interface { 79 Descriptor() ([]byte, []int) 80 } 81 82 func (s *serverReflectionServer) getSymbols() (svcNames []string, symbolIndex map[string]*dpb.FileDescriptorProto) { 83 s.initSymbols.Do(func() { 84 serviceInfo := s.s.GetServiceInfo() 85 86 s.symbols = map[string]*dpb.FileDescriptorProto{} 87 s.serviceNames = make([]string, 0, len(serviceInfo)) 88 processed := map[string]struct{}{} 89 for svc, info := range serviceInfo { 90 s.serviceNames = append(s.serviceNames, svc) 91 fdenc, ok := parseMetadata(info.Metadata) 92 if !ok { 93 continue 94 } 95 fd, err := decodeFileDesc(fdenc) 96 if err != nil { 97 continue 98 } 99 s.processFile(fd, processed) 100 } 101 sort.Strings(s.serviceNames) 102 }) 103 104 return s.serviceNames, s.symbols 105 } 106 107 func (s *serverReflectionServer) processFile(fd *dpb.FileDescriptorProto, processed map[string]struct{}) { 108 filename := fd.GetName() 109 if _, ok := processed[filename]; ok { 110 return 111 } 112 processed[filename] = struct{}{} 113 114 prefix := fd.GetPackage() 115 116 for _, msg := range fd.MessageType { 117 s.processMessage(fd, prefix, msg) 118 } 119 for _, en := range fd.EnumType { 120 s.processEnum(fd, prefix, en) 121 } 122 for _, ext := range fd.Extension { 123 s.processField(fd, prefix, ext) 124 } 125 for _, svc := range fd.Service { 126 svcName := fqn(prefix, svc.GetName()) 127 s.symbols[svcName] = fd 128 for _, meth := range svc.Method { 129 name := fqn(svcName, meth.GetName()) 130 s.symbols[name] = fd 131 } 132 } 133 134 for _, dep := range fd.Dependency { 135 fdenc := getFileDescriptor(dep) 136 fdDep, err := decodeFileDesc(fdenc) 137 if err != nil { 138 continue 139 } 140 s.processFile(fdDep, processed) 141 } 142 } 143 144 func (s *serverReflectionServer) processMessage(fd *dpb.FileDescriptorProto, prefix string, msg *dpb.DescriptorProto) { 145 msgName := fqn(prefix, msg.GetName()) 146 s.symbols[msgName] = fd 147 148 for _, nested := range msg.NestedType { 149 s.processMessage(fd, msgName, nested) 150 } 151 for _, en := range msg.EnumType { 152 s.processEnum(fd, msgName, en) 153 } 154 for _, ext := range msg.Extension { 155 s.processField(fd, msgName, ext) 156 } 157 for _, fld := range msg.Field { 158 s.processField(fd, msgName, fld) 159 } 160 for _, oneof := range msg.OneofDecl { 161 oneofName := fqn(msgName, oneof.GetName()) 162 s.symbols[oneofName] = fd 163 } 164 } 165 166 func (s *serverReflectionServer) processEnum(fd *dpb.FileDescriptorProto, prefix string, en *dpb.EnumDescriptorProto) { 167 enName := fqn(prefix, en.GetName()) 168 s.symbols[enName] = fd 169 170 for _, val := range en.Value { 171 valName := fqn(enName, val.GetName()) 172 s.symbols[valName] = fd 173 } 174 } 175 176 func (s *serverReflectionServer) processField(fd *dpb.FileDescriptorProto, prefix string, fld *dpb.FieldDescriptorProto) { 177 fldName := fqn(prefix, fld.GetName()) 178 s.symbols[fldName] = fd 179 } 180 181 func fqn(prefix, name string) string { 182 if prefix == "" { 183 return name 184 } 185 return prefix + "." + name 186 } 187 188 // fileDescForType gets the file descriptor for the given type. 189 // The given type should be a proto message. 190 func (s *serverReflectionServer) fileDescForType(st reflect.Type) (*dpb.FileDescriptorProto, error) { 191 m, ok := reflect.Zero(reflect.PtrTo(st)).Interface().(protoMessage) 192 if !ok { 193 return nil, fmt.Errorf("failed to create message from type: %v", st) 194 } 195 enc, _ := m.Descriptor() 196 197 return decodeFileDesc(enc) 198 } 199 200 // decodeFileDesc does decompression and unmarshalling on the given 201 // file descriptor byte slice. 202 func decodeFileDesc(enc []byte) (*dpb.FileDescriptorProto, error) { 203 raw, err := decompress(enc) 204 if err != nil { 205 return nil, fmt.Errorf("failed to decompress enc: %v", err) 206 } 207 208 fd := new(dpb.FileDescriptorProto) 209 if err := proto.Unmarshal(raw, fd); err != nil { 210 return nil, fmt.Errorf("bad descriptor: %v", err) 211 } 212 return fd, nil 213 } 214 215 // decompress does gzip decompression. 216 func decompress(b []byte) ([]byte, error) { 217 r, err := gzip.NewReader(bytes.NewReader(b)) 218 if err != nil { 219 return nil, fmt.Errorf("bad gzipped descriptor: %v", err) 220 } 221 out, err := io.ReadAll(r) 222 if err != nil { 223 return nil, fmt.Errorf("bad gzipped descriptor: %v", err) 224 } 225 return out, nil 226 } 227 228 func typeForName(name string) (reflect.Type, error) { 229 pt := getMessageType(name) 230 if pt == nil { 231 return nil, fmt.Errorf("unknown type: %q", name) 232 } 233 st := pt.Elem() 234 235 return st, nil 236 } 237 238 func fileDescContainingExtension(st reflect.Type, ext int32) (*dpb.FileDescriptorProto, error) { 239 m, ok := reflect.Zero(reflect.PtrTo(st)).Interface().(proto.Message) 240 if !ok { 241 return nil, fmt.Errorf("failed to create message from type: %v", st) 242 } 243 244 extDesc := getExtension(ext, m) 245 246 if extDesc == nil { 247 return nil, fmt.Errorf("failed to find registered extension for extension number %v", ext) 248 } 249 250 return decodeFileDesc(getFileDescriptor(extDesc.Filename)) 251 } 252 253 func (s *serverReflectionServer) allExtensionNumbersForType(st reflect.Type) ([]int32, error) { 254 m, ok := reflect.Zero(reflect.PtrTo(st)).Interface().(proto.Message) 255 if !ok { 256 return nil, fmt.Errorf("failed to create message from type: %v", st) 257 } 258 259 out := getExtensionsNumbers(m) 260 return out, nil 261 } 262 263 // fileDescWithDependencies returns a slice of serialized fileDescriptors in 264 // wire format ([]byte). The fileDescriptors will include fd and all the 265 // transitive dependencies of fd with names not in sentFileDescriptors. 266 func fileDescWithDependencies(fd *dpb.FileDescriptorProto, sentFileDescriptors map[string]bool) ([][]byte, error) { 267 r := [][]byte{} 268 queue := []*dpb.FileDescriptorProto{fd} 269 for len(queue) > 0 { 270 currentfd := queue[0] 271 queue = queue[1:] 272 if sent := sentFileDescriptors[currentfd.GetName()]; len(r) == 0 || !sent { 273 sentFileDescriptors[currentfd.GetName()] = true 274 currentfdEncoded, err := proto.Marshal(currentfd) 275 if err != nil { 276 return nil, err 277 } 278 r = append(r, currentfdEncoded) 279 } 280 for _, dep := range currentfd.Dependency { 281 fdenc := getFileDescriptor(dep) 282 fdDep, err := decodeFileDesc(fdenc) 283 if err != nil { 284 continue 285 } 286 queue = append(queue, fdDep) 287 } 288 } 289 return r, nil 290 } 291 292 // fileDescEncodingByFilename finds the file descriptor for given filename, 293 // finds all of its previously unsent transitive dependencies, does marshalling 294 // on them, and returns the marshalled result. 295 func (s *serverReflectionServer) fileDescEncodingByFilename(name string, sentFileDescriptors map[string]bool) ([][]byte, error) { 296 enc := getFileDescriptor(name) 297 if enc == nil { 298 return nil, fmt.Errorf("unknown file: %v", name) 299 } 300 fd, err := decodeFileDesc(enc) 301 if err != nil { 302 return nil, err 303 } 304 return fileDescWithDependencies(fd, sentFileDescriptors) 305 } 306 307 // parseMetadata finds the file descriptor bytes specified meta. 308 // For SupportPackageIsVersion4, m is the name of the proto file, we 309 // call proto.FileDescriptor to get the byte slice. 310 // For SupportPackageIsVersion3, m is a byte slice itself. 311 func parseMetadata(meta interface{}) ([]byte, bool) { 312 // Check if meta is the file name. 313 if fileNameForMeta, ok := meta.(string); ok { 314 return getFileDescriptor(fileNameForMeta), true 315 } 316 317 // Check if meta is the byte slice. 318 if enc, ok := meta.([]byte); ok { 319 return enc, true 320 } 321 322 return nil, false 323 } 324 325 // fileDescEncodingContainingSymbol finds the file descriptor containing the 326 // given symbol, finds all of its previously unsent transitive dependencies, 327 // does marshalling on them, and returns the marshalled result. The given symbol 328 // can be a type, a service or a method. 329 func (s *serverReflectionServer) fileDescEncodingContainingSymbol(name string, sentFileDescriptors map[string]bool) ([][]byte, error) { 330 _, symbols := s.getSymbols() 331 fd := symbols[name] 332 if fd == nil { 333 // Check if it's a type name that was not present in the 334 // transitive dependencies of the registered services. 335 if st, err := typeForName(name); err == nil { 336 fd, err = s.fileDescForType(st) 337 if err != nil { 338 return nil, err 339 } 340 } 341 } 342 343 if fd == nil { 344 return nil, fmt.Errorf("unknown symbol: %v", name) 345 } 346 347 return fileDescWithDependencies(fd, sentFileDescriptors) 348 } 349 350 // fileDescEncodingContainingExtension finds the file descriptor containing 351 // given extension, finds all of its previously unsent transitive dependencies, 352 // does marshalling on them, and returns the marshalled result. 353 func (s *serverReflectionServer) fileDescEncodingContainingExtension(typeName string, extNum int32, sentFileDescriptors map[string]bool) ([][]byte, error) { 354 st, err := typeForName(typeName) 355 if err != nil { 356 return nil, err 357 } 358 fd, err := fileDescContainingExtension(st, extNum) 359 if err != nil { 360 return nil, err 361 } 362 return fileDescWithDependencies(fd, sentFileDescriptors) 363 } 364 365 // allExtensionNumbersForTypeName returns all extension numbers for the given type. 366 func (s *serverReflectionServer) allExtensionNumbersForTypeName(name string) ([]int32, error) { 367 st, err := typeForName(name) 368 if err != nil { 369 return nil, err 370 } 371 extNums, err := s.allExtensionNumbersForType(st) 372 if err != nil { 373 return nil, err 374 } 375 return extNums, nil 376 } 377 378 // ServerReflectionInfo is the reflection service handler. 379 func (s *serverReflectionServer) ServerReflectionInfo(stream rpb.ServerReflection_ServerReflectionInfoServer) error { 380 sentFileDescriptors := make(map[string]bool) 381 for { 382 in, err := stream.Recv() 383 if err == io.EOF { 384 return nil 385 } 386 if err != nil { 387 return err 388 } 389 390 out := &rpb.ServerReflectionResponse{ 391 ValidHost: in.Host, 392 OriginalRequest: in, 393 } 394 switch req := in.MessageRequest.(type) { 395 case *rpb.ServerReflectionRequest_FileByFilename: 396 b, err := s.fileDescEncodingByFilename(req.FileByFilename, sentFileDescriptors) 397 if err != nil { 398 out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{ 399 ErrorResponse: &rpb.ErrorResponse{ 400 ErrorCode: int32(codes.NotFound), 401 ErrorMessage: err.Error(), 402 }, 403 } 404 } else { 405 out.MessageResponse = &rpb.ServerReflectionResponse_FileDescriptorResponse{ 406 FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: b}, 407 } 408 } 409 case *rpb.ServerReflectionRequest_FileContainingSymbol: 410 b, err := s.fileDescEncodingContainingSymbol(req.FileContainingSymbol, sentFileDescriptors) 411 if err != nil { 412 out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{ 413 ErrorResponse: &rpb.ErrorResponse{ 414 ErrorCode: int32(codes.NotFound), 415 ErrorMessage: err.Error(), 416 }, 417 } 418 } else { 419 out.MessageResponse = &rpb.ServerReflectionResponse_FileDescriptorResponse{ 420 FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: b}, 421 } 422 } 423 case *rpb.ServerReflectionRequest_FileContainingExtension: 424 typeName := req.FileContainingExtension.ContainingType 425 extNum := req.FileContainingExtension.ExtensionNumber 426 b, err := s.fileDescEncodingContainingExtension(typeName, extNum, sentFileDescriptors) 427 if err != nil { 428 out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{ 429 ErrorResponse: &rpb.ErrorResponse{ 430 ErrorCode: int32(codes.NotFound), 431 ErrorMessage: err.Error(), 432 }, 433 } 434 } else { 435 out.MessageResponse = &rpb.ServerReflectionResponse_FileDescriptorResponse{ 436 FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: b}, 437 } 438 } 439 case *rpb.ServerReflectionRequest_AllExtensionNumbersOfType: 440 extNums, err := s.allExtensionNumbersForTypeName(req.AllExtensionNumbersOfType) 441 if err != nil { 442 out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{ 443 ErrorResponse: &rpb.ErrorResponse{ 444 ErrorCode: int32(codes.NotFound), 445 ErrorMessage: err.Error(), 446 }, 447 } 448 log.Printf("OH NO: %s", err) 449 } else { 450 out.MessageResponse = &rpb.ServerReflectionResponse_AllExtensionNumbersResponse{ 451 AllExtensionNumbersResponse: &rpb.ExtensionNumberResponse{ 452 BaseTypeName: req.AllExtensionNumbersOfType, 453 ExtensionNumber: extNums, 454 }, 455 } 456 } 457 case *rpb.ServerReflectionRequest_ListServices: 458 svcNames, _ := s.getSymbols() 459 serviceResponses := make([]*rpb.ServiceResponse, len(svcNames)) 460 for i, n := range svcNames { 461 serviceResponses[i] = &rpb.ServiceResponse{ 462 Name: n, 463 } 464 } 465 out.MessageResponse = &rpb.ServerReflectionResponse_ListServicesResponse{ 466 ListServicesResponse: &rpb.ListServiceResponse{ 467 Service: serviceResponses, 468 }, 469 } 470 default: 471 return status.Errorf(codes.InvalidArgument, "invalid MessageRequest: %v", in.MessageRequest) 472 } 473 if err := stream.Send(out); err != nil { 474 return err 475 } 476 } 477 }