// Source: gitee.com/zhaochuninhefei/gmgo@v0.0.31-0.20240209061119-069254a02979/grpc/reflection/serverreflection.go

/*
 *
 * Copyright 2016 gRPC authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

/*
Package reflection implements server reflection service.

The service implemented is defined in:
https://github.com/grpc/grpc/blob/master/src/proto/grpc/reflection/v1alpha/reflection.proto.

To register server reflection on a gRPC server:
	import "gitee.com/zhaochuninhefei/gmgo/grpc/reflection"

	s := grpc.NewServer()
	pb.RegisterYourOwnServer(s, &server{})

	// Register reflection service on gRPC server.
	reflection.Register(s)

	s.Serve(lis)

*/
package reflection // import "gitee.com/zhaochuninhefei/gmgo/grpc/reflection"

import (
	"bytes"
	"compress/gzip"
	"fmt"
	"io"
	"io/ioutil"
	"reflect"
	"sort"
	"sync"

	"gitee.com/zhaochuninhefei/gmgo/grpc"
	"gitee.com/zhaochuninhefei/gmgo/grpc/codes"
	rpb "gitee.com/zhaochuninhefei/gmgo/grpc/reflection/grpc_reflection_v1alpha"
	"gitee.com/zhaochuninhefei/gmgo/grpc/status"
	"github.com/golang/protobuf/proto"
	dpb "github.com/golang/protobuf/protoc-gen-go/descriptor"
)

// GRPCServer is the interface provided by a gRPC server. It is implemented by
// *grpc.Server, but could also be implemented by other concrete types. It acts
// as a registry, for accumulating the services exposed by the server.
60 type GRPCServer interface { 61 grpc.ServiceRegistrar 62 GetServiceInfo() map[string]grpc.ServiceInfo 63 } 64 65 var _ GRPCServer = (*grpc.Server)(nil) 66 67 type serverReflectionServer struct { 68 rpb.UnimplementedServerReflectionServer 69 s GRPCServer 70 71 initSymbols sync.Once 72 serviceNames []string 73 symbols map[string]*dpb.FileDescriptorProto // map of fully-qualified names to files 74 } 75 76 // Register registers the server reflection service on the given gRPC server. 77 func Register(s GRPCServer) { 78 rpb.RegisterServerReflectionServer(s, &serverReflectionServer{ 79 s: s, 80 }) 81 } 82 83 // protoMessage is used for type assertion on proto messages. 84 // Generated proto message implements function Descriptor(), but Descriptor() 85 // is not part of interface proto.Message. This interface is needed to 86 // call Descriptor(). 87 type protoMessage interface { 88 Descriptor() ([]byte, []int) 89 } 90 91 func (s *serverReflectionServer) getSymbols() (svcNames []string, symbolIndex map[string]*dpb.FileDescriptorProto) { 92 s.initSymbols.Do(func() { 93 serviceInfo := s.s.GetServiceInfo() 94 95 s.symbols = map[string]*dpb.FileDescriptorProto{} 96 s.serviceNames = make([]string, 0, len(serviceInfo)) 97 processed := map[string]struct{}{} 98 for svc, info := range serviceInfo { 99 s.serviceNames = append(s.serviceNames, svc) 100 fdenc, ok := parseMetadata(info.Metadata) 101 if !ok { 102 continue 103 } 104 fd, err := decodeFileDesc(fdenc) 105 if err != nil { 106 continue 107 } 108 s.processFile(fd, processed) 109 } 110 sort.Strings(s.serviceNames) 111 }) 112 113 return s.serviceNames, s.symbols 114 } 115 116 func (s *serverReflectionServer) processFile(fd *dpb.FileDescriptorProto, processed map[string]struct{}) { 117 filename := fd.GetName() 118 if _, ok := processed[filename]; ok { 119 return 120 } 121 processed[filename] = struct{}{} 122 123 prefix := fd.GetPackage() 124 125 for _, msg := range fd.MessageType { 126 s.processMessage(fd, prefix, msg) 127 } 128 
for _, en := range fd.EnumType { 129 s.processEnum(fd, prefix, en) 130 } 131 for _, ext := range fd.Extension { 132 s.processField(fd, prefix, ext) 133 } 134 for _, svc := range fd.Service { 135 svcName := fqn(prefix, svc.GetName()) 136 s.symbols[svcName] = fd 137 for _, meth := range svc.Method { 138 name := fqn(svcName, meth.GetName()) 139 s.symbols[name] = fd 140 } 141 } 142 143 for _, dep := range fd.Dependency { 144 //goland:noinspection GoDeprecation 145 fdenc := proto.FileDescriptor(dep) 146 fdDep, err := decodeFileDesc(fdenc) 147 if err != nil { 148 continue 149 } 150 s.processFile(fdDep, processed) 151 } 152 } 153 154 func (s *serverReflectionServer) processMessage(fd *dpb.FileDescriptorProto, prefix string, msg *dpb.DescriptorProto) { 155 msgName := fqn(prefix, msg.GetName()) 156 s.symbols[msgName] = fd 157 158 for _, nested := range msg.NestedType { 159 s.processMessage(fd, msgName, nested) 160 } 161 for _, en := range msg.EnumType { 162 s.processEnum(fd, msgName, en) 163 } 164 for _, ext := range msg.Extension { 165 s.processField(fd, msgName, ext) 166 } 167 for _, fld := range msg.Field { 168 s.processField(fd, msgName, fld) 169 } 170 for _, oneof := range msg.OneofDecl { 171 oneofName := fqn(msgName, oneof.GetName()) 172 s.symbols[oneofName] = fd 173 } 174 } 175 176 func (s *serverReflectionServer) processEnum(fd *dpb.FileDescriptorProto, prefix string, en *dpb.EnumDescriptorProto) { 177 enName := fqn(prefix, en.GetName()) 178 s.symbols[enName] = fd 179 180 for _, val := range en.Value { 181 valName := fqn(enName, val.GetName()) 182 s.symbols[valName] = fd 183 } 184 } 185 186 func (s *serverReflectionServer) processField(fd *dpb.FileDescriptorProto, prefix string, fld *dpb.FieldDescriptorProto) { 187 fldName := fqn(prefix, fld.GetName()) 188 s.symbols[fldName] = fd 189 } 190 191 func fqn(prefix, name string) string { 192 if prefix == "" { 193 return name 194 } 195 return prefix + "." 
+ name 196 } 197 198 // fileDescForType gets the file descriptor for the given type. 199 // The given type should be a proto message. 200 func (s *serverReflectionServer) fileDescForType(st reflect.Type) (*dpb.FileDescriptorProto, error) { 201 m, ok := reflect.Zero(reflect.PtrTo(st)).Interface().(protoMessage) 202 if !ok { 203 return nil, fmt.Errorf("failed to create message from type: %v", st) 204 } 205 enc, _ := m.Descriptor() 206 207 return decodeFileDesc(enc) 208 } 209 210 // decodeFileDesc does decompression and unmarshalling on the given 211 // file descriptor byte slice. 212 func decodeFileDesc(enc []byte) (*dpb.FileDescriptorProto, error) { 213 raw, err := decompress(enc) 214 if err != nil { 215 return nil, fmt.Errorf("failed to decompress enc: %v", err) 216 } 217 218 fd := new(dpb.FileDescriptorProto) 219 if err := proto.Unmarshal(raw, fd); err != nil { 220 return nil, fmt.Errorf("bad descriptor: %v", err) 221 } 222 return fd, nil 223 } 224 225 // decompress does gzip decompression. 
226 func decompress(b []byte) ([]byte, error) { 227 r, err := gzip.NewReader(bytes.NewReader(b)) 228 if err != nil { 229 return nil, fmt.Errorf("bad gzipped descriptor: %v", err) 230 } 231 out, err := ioutil.ReadAll(r) 232 if err != nil { 233 return nil, fmt.Errorf("bad gzipped descriptor: %v", err) 234 } 235 return out, nil 236 } 237 238 func typeForName(name string) (reflect.Type, error) { 239 //goland:noinspection GoDeprecation 240 pt := proto.MessageType(name) 241 if pt == nil { 242 return nil, fmt.Errorf("unknown type: %q", name) 243 } 244 st := pt.Elem() 245 246 return st, nil 247 } 248 249 func fileDescContainingExtension(st reflect.Type, ext int32) (*dpb.FileDescriptorProto, error) { 250 m, ok := reflect.Zero(reflect.PtrTo(st)).Interface().(proto.Message) 251 if !ok { 252 return nil, fmt.Errorf("failed to create message from type: %v", st) 253 } 254 255 var extDesc *proto.ExtensionDesc 256 //goland:noinspection GoDeprecation 257 for id, desc := range proto.RegisteredExtensions(m) { 258 if id == ext { 259 extDesc = desc 260 break 261 } 262 } 263 264 if extDesc == nil { 265 return nil, fmt.Errorf("failed to find registered extension for extension number %v", ext) 266 } 267 268 //goland:noinspection GoDeprecation 269 return decodeFileDesc(proto.FileDescriptor(extDesc.Filename)) 270 } 271 272 func (s *serverReflectionServer) allExtensionNumbersForType(st reflect.Type) ([]int32, error) { 273 m, ok := reflect.Zero(reflect.PtrTo(st)).Interface().(proto.Message) 274 if !ok { 275 return nil, fmt.Errorf("failed to create message from type: %v", st) 276 } 277 278 //goland:noinspection GoDeprecation 279 exts := proto.RegisteredExtensions(m) 280 out := make([]int32, 0, len(exts)) 281 for id := range exts { 282 out = append(out, id) 283 } 284 return out, nil 285 } 286 287 // fileDescWithDependencies returns a slice of serialized fileDescriptors in 288 // wire format ([]byte). 
The fileDescriptors will include fd and all the 289 // transitive dependencies of fd with names not in sentFileDescriptors. 290 func fileDescWithDependencies(fd *dpb.FileDescriptorProto, sentFileDescriptors map[string]bool) ([][]byte, error) { 291 var r [][]byte 292 queue := []*dpb.FileDescriptorProto{fd} 293 for len(queue) > 0 { 294 currentfd := queue[0] 295 queue = queue[1:] 296 if sent := sentFileDescriptors[currentfd.GetName()]; len(r) == 0 || !sent { 297 sentFileDescriptors[currentfd.GetName()] = true 298 currentfdEncoded, err := proto.Marshal(currentfd) 299 if err != nil { 300 return nil, err 301 } 302 r = append(r, currentfdEncoded) 303 } 304 for _, dep := range currentfd.Dependency { 305 //goland:noinspection GoDeprecation 306 fdenc := proto.FileDescriptor(dep) 307 fdDep, err := decodeFileDesc(fdenc) 308 if err != nil { 309 continue 310 } 311 queue = append(queue, fdDep) 312 } 313 } 314 return r, nil 315 } 316 317 // fileDescEncodingByFilename finds the file descriptor for given filename, 318 // finds all of its previously unsent transitive dependencies, does marshalling 319 // on them, and returns the marshalled result. 320 func (s *serverReflectionServer) fileDescEncodingByFilename(name string, sentFileDescriptors map[string]bool) ([][]byte, error) { 321 //goland:noinspection GoDeprecation 322 enc := proto.FileDescriptor(name) 323 if enc == nil { 324 return nil, fmt.Errorf("unknown file: %v", name) 325 } 326 fd, err := decodeFileDesc(enc) 327 if err != nil { 328 return nil, err 329 } 330 return fileDescWithDependencies(fd, sentFileDescriptors) 331 } 332 333 // parseMetadata finds the file descriptor bytes specified meta. 334 // For SupportPackageIsVersion4, m is the name of the proto file, we 335 // call proto.FileDescriptor to get the byte slice. 336 // For SupportPackageIsVersion3, m is a byte slice itself. 337 func parseMetadata(meta interface{}) ([]byte, bool) { 338 // Check if meta is the file name. 
339 if fileNameForMeta, ok := meta.(string); ok { 340 //goland:noinspection GoDeprecation 341 return proto.FileDescriptor(fileNameForMeta), true 342 } 343 344 // Check if meta is the byte slice. 345 if enc, ok := meta.([]byte); ok { 346 return enc, true 347 } 348 349 return nil, false 350 } 351 352 // fileDescEncodingContainingSymbol finds the file descriptor containing the 353 // given symbol, finds all of its previously unsent transitive dependencies, 354 // does marshalling on them, and returns the marshalled result. The given symbol 355 // can be a type, a service or a method. 356 func (s *serverReflectionServer) fileDescEncodingContainingSymbol(name string, sentFileDescriptors map[string]bool) ([][]byte, error) { 357 _, symbols := s.getSymbols() 358 fd := symbols[name] 359 if fd == nil { 360 // Check if it's a type name that was not present in the 361 // transitive dependencies of the registered services. 362 if st, err := typeForName(name); err == nil { 363 fd, err = s.fileDescForType(st) 364 if err != nil { 365 return nil, err 366 } 367 } 368 } 369 370 if fd == nil { 371 return nil, fmt.Errorf("unknown symbol: %v", name) 372 } 373 374 return fileDescWithDependencies(fd, sentFileDescriptors) 375 } 376 377 // fileDescEncodingContainingExtension finds the file descriptor containing 378 // given extension, finds all of its previously unsent transitive dependencies, 379 // does marshalling on them, and returns the marshalled result. 380 func (s *serverReflectionServer) fileDescEncodingContainingExtension(typeName string, extNum int32, sentFileDescriptors map[string]bool) ([][]byte, error) { 381 st, err := typeForName(typeName) 382 if err != nil { 383 return nil, err 384 } 385 fd, err := fileDescContainingExtension(st, extNum) 386 if err != nil { 387 return nil, err 388 } 389 return fileDescWithDependencies(fd, sentFileDescriptors) 390 } 391 392 // allExtensionNumbersForTypeName returns all extension numbers for the given type. 
393 func (s *serverReflectionServer) allExtensionNumbersForTypeName(name string) ([]int32, error) { 394 st, err := typeForName(name) 395 if err != nil { 396 return nil, err 397 } 398 extNums, err := s.allExtensionNumbersForType(st) 399 if err != nil { 400 return nil, err 401 } 402 return extNums, nil 403 } 404 405 // ServerReflectionInfo is the reflection service handler. 406 func (s *serverReflectionServer) ServerReflectionInfo(stream rpb.ServerReflection_ServerReflectionInfoServer) error { 407 sentFileDescriptors := make(map[string]bool) 408 for { 409 in, err := stream.Recv() 410 if err == io.EOF { 411 return nil 412 } 413 if err != nil { 414 return err 415 } 416 417 out := &rpb.ServerReflectionResponse{ 418 ValidHost: in.Host, 419 OriginalRequest: in, 420 } 421 switch req := in.MessageRequest.(type) { 422 case *rpb.ServerReflectionRequest_FileByFilename: 423 b, err := s.fileDescEncodingByFilename(req.FileByFilename, sentFileDescriptors) 424 if err != nil { 425 out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{ 426 ErrorResponse: &rpb.ErrorResponse{ 427 ErrorCode: int32(codes.NotFound), 428 ErrorMessage: err.Error(), 429 }, 430 } 431 } else { 432 out.MessageResponse = &rpb.ServerReflectionResponse_FileDescriptorResponse{ 433 FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: b}, 434 } 435 } 436 case *rpb.ServerReflectionRequest_FileContainingSymbol: 437 b, err := s.fileDescEncodingContainingSymbol(req.FileContainingSymbol, sentFileDescriptors) 438 if err != nil { 439 out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{ 440 ErrorResponse: &rpb.ErrorResponse{ 441 ErrorCode: int32(codes.NotFound), 442 ErrorMessage: err.Error(), 443 }, 444 } 445 } else { 446 out.MessageResponse = &rpb.ServerReflectionResponse_FileDescriptorResponse{ 447 FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: b}, 448 } 449 } 450 case *rpb.ServerReflectionRequest_FileContainingExtension: 451 typeName := 
req.FileContainingExtension.ContainingType 452 extNum := req.FileContainingExtension.ExtensionNumber 453 b, err := s.fileDescEncodingContainingExtension(typeName, extNum, sentFileDescriptors) 454 if err != nil { 455 out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{ 456 ErrorResponse: &rpb.ErrorResponse{ 457 ErrorCode: int32(codes.NotFound), 458 ErrorMessage: err.Error(), 459 }, 460 } 461 } else { 462 out.MessageResponse = &rpb.ServerReflectionResponse_FileDescriptorResponse{ 463 FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: b}, 464 } 465 } 466 case *rpb.ServerReflectionRequest_AllExtensionNumbersOfType: 467 extNums, err := s.allExtensionNumbersForTypeName(req.AllExtensionNumbersOfType) 468 if err != nil { 469 out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{ 470 ErrorResponse: &rpb.ErrorResponse{ 471 ErrorCode: int32(codes.NotFound), 472 ErrorMessage: err.Error(), 473 }, 474 } 475 } else { 476 out.MessageResponse = &rpb.ServerReflectionResponse_AllExtensionNumbersResponse{ 477 AllExtensionNumbersResponse: &rpb.ExtensionNumberResponse{ 478 BaseTypeName: req.AllExtensionNumbersOfType, 479 ExtensionNumber: extNums, 480 }, 481 } 482 } 483 case *rpb.ServerReflectionRequest_ListServices: 484 svcNames, _ := s.getSymbols() 485 serviceResponses := make([]*rpb.ServiceResponse, len(svcNames)) 486 for i, n := range svcNames { 487 serviceResponses[i] = &rpb.ServiceResponse{ 488 Name: n, 489 } 490 } 491 out.MessageResponse = &rpb.ServerReflectionResponse_ListServicesResponse{ 492 ListServicesResponse: &rpb.ListServiceResponse{ 493 Service: serviceResponses, 494 }, 495 } 496 default: 497 return status.Errorf(codes.InvalidArgument, "invalid MessageRequest: %v", in.MessageRequest) 498 } 499 500 if err := stream.Send(out); err != nil { 501 return err 502 } 503 } 504 }