github.com/hxx258456/ccgo@v0.0.5-0.20230213014102-48b35f46f66f/grpc/reflection/serverreflection.go (about) 1 /* 2 * 3 * Copyright 2016 gRPC authors. 4 * 5 * Licensed under the Apache License, Version 2.0 (the "License"); 6 * you may not use this file except in compliance with the License. 7 * You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 * 17 */ 18 19 /* 20 Package reflection implements server reflection service. 21 22 The service implemented is defined in: 23 https://github.com/grpc/grpc/blob/master/src/proto/grpc/reflection/v1alpha/reflection.proto. 24 25 To register server reflection on a gRPC server: 26 import "github.com/hxx258456/ccgo/grpc/reflection" 27 28 s := grpc.NewServer() 29 pb.RegisterYourOwnServer(s, &server{}) 30 31 // Register reflection service on gRPC server. 32 reflection.Register(s) 33 34 s.Serve(lis) 35 36 */ 37 package reflection // import "github.com/hxx258456/ccgo/grpc/reflection" 38 39 import ( 40 "bytes" 41 "compress/gzip" 42 "fmt" 43 "io" 44 "io/ioutil" 45 "reflect" 46 "sort" 47 "sync" 48 49 "github.com/golang/protobuf/proto" 50 dpb "github.com/golang/protobuf/protoc-gen-go/descriptor" 51 grpc "github.com/hxx258456/ccgo/grpc" 52 "github.com/hxx258456/ccgo/grpc/codes" 53 rpb "github.com/hxx258456/ccgo/grpc/reflection/grpc_reflection_v1alpha" 54 "github.com/hxx258456/ccgo/grpc/status" 55 ) 56 57 // GRPCServer is the interface provided by a gRPC server. It is implemented by 58 // *grpc.Server, but could also be implemented by other concrete types. It acts 59 // as a registry, for accumulating the services exposed by the server. 60 type GRPCServer interface { 61 grpc.ServiceRegistrar 62 GetServiceInfo() map[string]grpc.ServiceInfo 63 } 64 65 var _ GRPCServer = (*grpc.Server)(nil) 66 67 type serverReflectionServer struct { 68 rpb.UnimplementedServerReflectionServer 69 s GRPCServer 70 71 initSymbols sync.Once 72 serviceNames []string 73 symbols map[string]*dpb.FileDescriptorProto // map of fully-qualified names to files 74 } 75 76 // Register registers the server reflection service on the given gRPC server. 77 func Register(s GRPCServer) { 78 rpb.RegisterServerReflectionServer(s, &serverReflectionServer{ 79 s: s, 80 }) 81 } 82 83 // protoMessage is used for type assertion on proto messages. 84 // Generated proto message implements function Descriptor(), but Descriptor() 85 // is not part of interface proto.Message. This interface is needed to 86 // call Descriptor(). 87 type protoMessage interface { 88 Descriptor() ([]byte, []int) 89 } 90 91 func (s *serverReflectionServer) getSymbols() (svcNames []string, symbolIndex map[string]*dpb.FileDescriptorProto) { 92 s.initSymbols.Do(func() { 93 serviceInfo := s.s.GetServiceInfo() 94 95 s.symbols = map[string]*dpb.FileDescriptorProto{} 96 s.serviceNames = make([]string, 0, len(serviceInfo)) 97 processed := map[string]struct{}{} 98 for svc, info := range serviceInfo { 99 s.serviceNames = append(s.serviceNames, svc) 100 fdenc, ok := parseMetadata(info.Metadata) 101 if !ok { 102 continue 103 } 104 fd, err := decodeFileDesc(fdenc) 105 if err != nil { 106 continue 107 } 108 s.processFile(fd, processed) 109 } 110 sort.Strings(s.serviceNames) 111 }) 112 113 return s.serviceNames, s.symbols 114 } 115 116 func (s *serverReflectionServer) processFile(fd *dpb.FileDescriptorProto, processed map[string]struct{}) { 117 filename := fd.GetName() 118 if _, ok := processed[filename]; ok { 119 return 120 } 121 processed[filename] = struct{}{} 122 123 prefix := fd.GetPackage() 124 125 for _, msg := range fd.MessageType { 126 s.processMessage(fd, prefix, msg) 127 } 128 for _, en := range fd.EnumType { 129 s.processEnum(fd, prefix, en) 130 } 131 for _, ext := range fd.Extension { 132 s.processField(fd, prefix, ext) 133 } 134 for _, svc := range fd.Service { 135 svcName := fqn(prefix, svc.GetName()) 136 s.symbols[svcName] = fd 137 for _, meth := range svc.Method { 138 name := fqn(svcName, meth.GetName()) 139 s.symbols[name] = fd 140 } 141 } 142 143 for _, dep := range fd.Dependency { 144 fdenc := proto.FileDescriptor(dep) 145 fdDep, err := decodeFileDesc(fdenc) 146 if err != nil { 147 continue 148 } 149 s.processFile(fdDep, processed) 150 } 151 } 152 153 func (s *serverReflectionServer) processMessage(fd *dpb.FileDescriptorProto, prefix string, msg *dpb.DescriptorProto) { 154 msgName := fqn(prefix, msg.GetName()) 155 s.symbols[msgName] = fd 156 157 for _, nested := range msg.NestedType { 158 s.processMessage(fd, msgName, nested) 159 } 160 for _, en := range msg.EnumType { 161 s.processEnum(fd, msgName, en) 162 } 163 for _, ext := range msg.Extension { 164 s.processField(fd, msgName, ext) 165 } 166 for _, fld := range msg.Field { 167 s.processField(fd, msgName, fld) 168 } 169 for _, oneof := range msg.OneofDecl { 170 oneofName := fqn(msgName, oneof.GetName()) 171 s.symbols[oneofName] = fd 172 } 173 } 174 175 func (s *serverReflectionServer) processEnum(fd *dpb.FileDescriptorProto, prefix string, en *dpb.EnumDescriptorProto) { 176 enName := fqn(prefix, en.GetName()) 177 s.symbols[enName] = fd 178 179 for _, val := range en.Value { 180 valName := fqn(enName, val.GetName()) 181 s.symbols[valName] = fd 182 } 183 } 184 185 func (s *serverReflectionServer) processField(fd *dpb.FileDescriptorProto, prefix string, fld *dpb.FieldDescriptorProto) { 186 fldName := fqn(prefix, fld.GetName()) 187 s.symbols[fldName] = fd 188 } 189 190 func fqn(prefix, name string) string { 191 if prefix == "" { 192 return name 193 } 194 return prefix + "." + name 195 } 196 197 // fileDescForType gets the file descriptor for the given type. 198 // The given type should be a proto message. 199 func (s *serverReflectionServer) fileDescForType(st reflect.Type) (*dpb.FileDescriptorProto, error) { 200 m, ok := reflect.Zero(reflect.PtrTo(st)).Interface().(protoMessage) 201 if !ok { 202 return nil, fmt.Errorf("failed to create message from type: %v", st) 203 } 204 enc, _ := m.Descriptor() 205 206 return decodeFileDesc(enc) 207 } 208 209 // decodeFileDesc does decompression and unmarshalling on the given 210 // file descriptor byte slice. 211 func decodeFileDesc(enc []byte) (*dpb.FileDescriptorProto, error) { 212 raw, err := decompress(enc) 213 if err != nil { 214 return nil, fmt.Errorf("failed to decompress enc: %v", err) 215 } 216 217 fd := new(dpb.FileDescriptorProto) 218 if err := proto.Unmarshal(raw, fd); err != nil { 219 return nil, fmt.Errorf("bad descriptor: %v", err) 220 } 221 return fd, nil 222 } 223 224 // decompress does gzip decompression. 225 func decompress(b []byte) ([]byte, error) { 226 r, err := gzip.NewReader(bytes.NewReader(b)) 227 if err != nil { 228 return nil, fmt.Errorf("bad gzipped descriptor: %v", err) 229 } 230 out, err := ioutil.ReadAll(r) 231 if err != nil { 232 return nil, fmt.Errorf("bad gzipped descriptor: %v", err) 233 } 234 return out, nil 235 } 236 237 func typeForName(name string) (reflect.Type, error) { 238 pt := proto.MessageType(name) 239 if pt == nil { 240 return nil, fmt.Errorf("unknown type: %q", name) 241 } 242 st := pt.Elem() 243 244 return st, nil 245 } 246 247 func fileDescContainingExtension(st reflect.Type, ext int32) (*dpb.FileDescriptorProto, error) { 248 m, ok := reflect.Zero(reflect.PtrTo(st)).Interface().(proto.Message) 249 if !ok { 250 return nil, fmt.Errorf("failed to create message from type: %v", st) 251 } 252 253 var extDesc *proto.ExtensionDesc 254 for id, desc := range proto.RegisteredExtensions(m) { 255 if id == ext { 256 extDesc = desc 257 break 258 } 259 } 260 261 if extDesc == nil { 262 return nil, fmt.Errorf("failed to find registered extension for extension number %v", ext) 263 } 264 265 return decodeFileDesc(proto.FileDescriptor(extDesc.Filename)) 266 } 267 268 func (s *serverReflectionServer) allExtensionNumbersForType(st reflect.Type) ([]int32, error) { 269 m, ok := reflect.Zero(reflect.PtrTo(st)).Interface().(proto.Message) 270 if !ok { 271 return nil, fmt.Errorf("failed to create message from type: %v", st) 272 } 273 274 exts := proto.RegisteredExtensions(m) 275 out := make([]int32, 0, len(exts)) 276 for id := range exts { 277 out = append(out, id) 278 } 279 return out, nil 280 } 281 282 // fileDescWithDependencies returns a slice of serialized fileDescriptors in 283 // wire format ([]byte). The fileDescriptors will include fd and all the 284 // transitive dependencies of fd with names not in sentFileDescriptors. 285 func fileDescWithDependencies(fd *dpb.FileDescriptorProto, sentFileDescriptors map[string]bool) ([][]byte, error) { 286 r := [][]byte{} 287 queue := []*dpb.FileDescriptorProto{fd} 288 for len(queue) > 0 { 289 currentfd := queue[0] 290 queue = queue[1:] 291 if sent := sentFileDescriptors[currentfd.GetName()]; len(r) == 0 || !sent { 292 sentFileDescriptors[currentfd.GetName()] = true 293 currentfdEncoded, err := proto.Marshal(currentfd) 294 if err != nil { 295 return nil, err 296 } 297 r = append(r, currentfdEncoded) 298 } 299 for _, dep := range currentfd.Dependency { 300 fdenc := proto.FileDescriptor(dep) 301 fdDep, err := decodeFileDesc(fdenc) 302 if err != nil { 303 continue 304 } 305 queue = append(queue, fdDep) 306 } 307 } 308 return r, nil 309 } 310 311 // fileDescEncodingByFilename finds the file descriptor for given filename, 312 // finds all of its previously unsent transitive dependencies, does marshalling 313 // on them, and returns the marshalled result. 314 func (s *serverReflectionServer) fileDescEncodingByFilename(name string, sentFileDescriptors map[string]bool) ([][]byte, error) { 315 enc := proto.FileDescriptor(name) 316 if enc == nil { 317 return nil, fmt.Errorf("unknown file: %v", name) 318 } 319 fd, err := decodeFileDesc(enc) 320 if err != nil { 321 return nil, err 322 } 323 return fileDescWithDependencies(fd, sentFileDescriptors) 324 } 325 326 // parseMetadata finds the file descriptor bytes specified meta. 327 // For SupportPackageIsVersion4, m is the name of the proto file, we 328 // call proto.FileDescriptor to get the byte slice. 329 // For SupportPackageIsVersion3, m is a byte slice itself. 330 func parseMetadata(meta interface{}) ([]byte, bool) { 331 // Check if meta is the file name. 332 if fileNameForMeta, ok := meta.(string); ok { 333 return proto.FileDescriptor(fileNameForMeta), true 334 } 335 336 // Check if meta is the byte slice. 337 if enc, ok := meta.([]byte); ok { 338 return enc, true 339 } 340 341 return nil, false 342 } 343 344 // fileDescEncodingContainingSymbol finds the file descriptor containing the 345 // given symbol, finds all of its previously unsent transitive dependencies, 346 // does marshalling on them, and returns the marshalled result. The given symbol 347 // can be a type, a service or a method. 348 func (s *serverReflectionServer) fileDescEncodingContainingSymbol(name string, sentFileDescriptors map[string]bool) ([][]byte, error) { 349 _, symbols := s.getSymbols() 350 fd := symbols[name] 351 if fd == nil { 352 // Check if it's a type name that was not present in the 353 // transitive dependencies of the registered services. 354 if st, err := typeForName(name); err == nil { 355 fd, err = s.fileDescForType(st) 356 if err != nil { 357 return nil, err 358 } 359 } 360 } 361 362 if fd == nil { 363 return nil, fmt.Errorf("unknown symbol: %v", name) 364 } 365 366 return fileDescWithDependencies(fd, sentFileDescriptors) 367 } 368 369 // fileDescEncodingContainingExtension finds the file descriptor containing 370 // given extension, finds all of its previously unsent transitive dependencies, 371 // does marshalling on them, and returns the marshalled result. 372 func (s *serverReflectionServer) fileDescEncodingContainingExtension(typeName string, extNum int32, sentFileDescriptors map[string]bool) ([][]byte, error) { 373 st, err := typeForName(typeName) 374 if err != nil { 375 return nil, err 376 } 377 fd, err := fileDescContainingExtension(st, extNum) 378 if err != nil { 379 return nil, err 380 } 381 return fileDescWithDependencies(fd, sentFileDescriptors) 382 } 383 384 // allExtensionNumbersForTypeName returns all extension numbers for the given type. 385 func (s *serverReflectionServer) allExtensionNumbersForTypeName(name string) ([]int32, error) { 386 st, err := typeForName(name) 387 if err != nil { 388 return nil, err 389 } 390 extNums, err := s.allExtensionNumbersForType(st) 391 if err != nil { 392 return nil, err 393 } 394 return extNums, nil 395 } 396 397 // ServerReflectionInfo is the reflection service handler. 398 func (s *serverReflectionServer) ServerReflectionInfo(stream rpb.ServerReflection_ServerReflectionInfoServer) error { 399 sentFileDescriptors := make(map[string]bool) 400 for { 401 in, err := stream.Recv() 402 if err == io.EOF { 403 return nil 404 } 405 if err != nil { 406 return err 407 } 408 409 out := &rpb.ServerReflectionResponse{ 410 ValidHost: in.Host, 411 OriginalRequest: in, 412 } 413 switch req := in.MessageRequest.(type) { 414 case *rpb.ServerReflectionRequest_FileByFilename: 415 b, err := s.fileDescEncodingByFilename(req.FileByFilename, sentFileDescriptors) 416 if err != nil { 417 out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{ 418 ErrorResponse: &rpb.ErrorResponse{ 419 ErrorCode: int32(codes.NotFound), 420 ErrorMessage: err.Error(), 421 }, 422 } 423 } else { 424 out.MessageResponse = &rpb.ServerReflectionResponse_FileDescriptorResponse{ 425 FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: b}, 426 } 427 } 428 case *rpb.ServerReflectionRequest_FileContainingSymbol: 429 b, err := s.fileDescEncodingContainingSymbol(req.FileContainingSymbol, sentFileDescriptors) 430 if err != nil { 431 out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{ 432 ErrorResponse: &rpb.ErrorResponse{ 433 ErrorCode: int32(codes.NotFound), 434 ErrorMessage: err.Error(), 435 }, 436 } 437 } else { 438 out.MessageResponse = &rpb.ServerReflectionResponse_FileDescriptorResponse{ 439 FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: b}, 440 } 441 } 442 case *rpb.ServerReflectionRequest_FileContainingExtension: 443 typeName := req.FileContainingExtension.ContainingType 444 extNum := req.FileContainingExtension.ExtensionNumber 445 b, err := s.fileDescEncodingContainingExtension(typeName, extNum, sentFileDescriptors) 446 if err != nil { 447 out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{ 448 ErrorResponse: &rpb.ErrorResponse{ 449 ErrorCode: int32(codes.NotFound), 450 ErrorMessage: err.Error(), 451 }, 452 } 453 } else { 454 out.MessageResponse = &rpb.ServerReflectionResponse_FileDescriptorResponse{ 455 FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: b}, 456 } 457 } 458 case *rpb.ServerReflectionRequest_AllExtensionNumbersOfType: 459 extNums, err := s.allExtensionNumbersForTypeName(req.AllExtensionNumbersOfType) 460 if err != nil { 461 out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{ 462 ErrorResponse: &rpb.ErrorResponse{ 463 ErrorCode: int32(codes.NotFound), 464 ErrorMessage: err.Error(), 465 }, 466 } 467 } else { 468 out.MessageResponse = &rpb.ServerReflectionResponse_AllExtensionNumbersResponse{ 469 AllExtensionNumbersResponse: &rpb.ExtensionNumberResponse{ 470 BaseTypeName: req.AllExtensionNumbersOfType, 471 ExtensionNumber: extNums, 472 }, 473 } 474 } 475 case *rpb.ServerReflectionRequest_ListServices: 476 svcNames, _ := s.getSymbols() 477 serviceResponses := make([]*rpb.ServiceResponse, len(svcNames)) 478 for i, n := range svcNames { 479 serviceResponses[i] = &rpb.ServiceResponse{ 480 Name: n, 481 } 482 } 483 out.MessageResponse = &rpb.ServerReflectionResponse_ListServicesResponse{ 484 ListServicesResponse: &rpb.ListServiceResponse{ 485 Service: serviceResponses, 486 }, 487 } 488 default: 489 return status.Errorf(codes.InvalidArgument, "invalid MessageRequest: %v", in.MessageRequest) 490 } 491 492 if err := stream.Send(out); err != nil { 493 return err 494 } 495 } 496 }