github.com/xiaoshude/protoreflect@v1.16.1-0.20220310024924-8c94d7247598/grpcreflect/client.go (about) 1 package grpcreflect 2 3 import ( 4 "bytes" 5 "fmt" 6 "io" 7 "reflect" 8 "runtime" 9 "sync" 10 11 "github.com/golang/protobuf/proto" 12 dpb "github.com/golang/protobuf/protoc-gen-go/descriptor" 13 "golang.org/x/net/context" 14 "google.golang.org/grpc/codes" 15 rpb "google.golang.org/grpc/reflection/grpc_reflection_v1alpha" 16 "google.golang.org/grpc/status" 17 18 "github.com/xiaoshude/protoreflect/desc" 19 "github.com/xiaoshude/protoreflect/internal" 20 ) 21 22 // elementNotFoundError is the error returned by reflective operations where the 23 // server does not recognize a given file name, symbol name, or extension. 24 type elementNotFoundError struct { 25 name string 26 kind elementKind 27 symType symbolType // only used when kind == elementKindSymbol 28 tag int32 // only used when kind == elementKindExtension 29 30 // only errors with a kind of elementKindFile will have a cause, which means 31 // the named file count not be resolved because of a dependency that could 32 // not be found where cause describes the missing dependency 33 cause *elementNotFoundError 34 } 35 36 type elementKind int 37 38 const ( 39 elementKindSymbol elementKind = iota 40 elementKindFile 41 elementKindExtension 42 ) 43 44 type symbolType string 45 46 const ( 47 symbolTypeService = "Service" 48 symbolTypeMessage = "Message" 49 symbolTypeEnum = "Enum" 50 symbolTypeUnknown = "Symbol" 51 ) 52 53 func symbolNotFound(symbol string, symType symbolType, cause *elementNotFoundError) error { 54 return &elementNotFoundError{name: symbol, symType: symType, kind: elementKindSymbol, cause: cause} 55 } 56 57 func extensionNotFound(extendee string, tag int32, cause *elementNotFoundError) error { 58 return &elementNotFoundError{name: extendee, tag: tag, kind: elementKindExtension, cause: cause} 59 } 60 61 func fileNotFound(file string, cause *elementNotFoundError) error { 62 return &elementNotFoundError{name: file, kind: elementKindFile, cause: cause} 63 } 64 65 func (e *elementNotFoundError) Error() string { 66 first := true 67 var b bytes.Buffer 68 for ; e != nil; e = e.cause { 69 if first { 70 first = false 71 } else { 72 fmt.Fprint(&b, "\ncaused by: ") 73 } 74 switch e.kind { 75 case elementKindSymbol: 76 fmt.Fprintf(&b, "%s not found: %s", e.symType, e.name) 77 case elementKindExtension: 78 fmt.Fprintf(&b, "Extension not found: tag %d for %s", e.tag, e.name) 79 default: 80 fmt.Fprintf(&b, "File not found: %s", e.name) 81 } 82 } 83 return b.String() 84 } 85 86 // IsElementNotFoundError determines if the given error indicates that a file 87 // name, symbol name, or extension field was could not be found by the server. 88 func IsElementNotFoundError(err error) bool { 89 _, ok := err.(*elementNotFoundError) 90 return ok 91 } 92 93 // ProtocolError is an error returned when the server sends a response of the 94 // wrong type. 95 type ProtocolError struct { 96 missingType reflect.Type 97 } 98 99 func (p ProtocolError) Error() string { 100 return fmt.Sprintf("Protocol error: response was missing %v", p.missingType) 101 } 102 103 type extDesc struct { 104 extendedMessageName string 105 extensionNumber int32 106 } 107 108 // Client is a client connection to a server for performing reflection calls 109 // and resolving remote symbols. 110 type Client struct { 111 ctx context.Context 112 stub rpb.ServerReflectionClient 113 114 connMu sync.Mutex 115 cancel context.CancelFunc 116 stream rpb.ServerReflection_ServerReflectionInfoClient 117 118 cacheMu sync.RWMutex 119 protosByName map[string]*dpb.FileDescriptorProto 120 filesByName map[string]*desc.FileDescriptor 121 filesBySymbol map[string]*desc.FileDescriptor 122 filesByExtension map[extDesc]*desc.FileDescriptor 123 } 124 125 // NewClient creates a new Client with the given root context and using the 126 // given RPC stub for talking to the server. 127 func NewClient(ctx context.Context, stub rpb.ServerReflectionClient) *Client { 128 cr := &Client{ 129 ctx: ctx, 130 stub: stub, 131 protosByName: map[string]*dpb.FileDescriptorProto{}, 132 filesByName: map[string]*desc.FileDescriptor{}, 133 filesBySymbol: map[string]*desc.FileDescriptor{}, 134 filesByExtension: map[extDesc]*desc.FileDescriptor{}, 135 } 136 // don't leak a grpc stream 137 runtime.SetFinalizer(cr, (*Client).Reset) 138 return cr 139 } 140 141 // FileByFilename asks the server for a file descriptor for the proto file with 142 // the given name. 143 func (cr *Client) FileByFilename(filename string) (*desc.FileDescriptor, error) { 144 // hit the cache first 145 cr.cacheMu.RLock() 146 if fd, ok := cr.filesByName[filename]; ok { 147 cr.cacheMu.RUnlock() 148 return fd, nil 149 } 150 fdp, ok := cr.protosByName[filename] 151 cr.cacheMu.RUnlock() 152 // not there? see if we've downloaded the proto 153 if ok { 154 return cr.descriptorFromProto(fdp) 155 } 156 157 req := &rpb.ServerReflectionRequest{ 158 MessageRequest: &rpb.ServerReflectionRequest_FileByFilename{ 159 FileByFilename: filename, 160 }, 161 } 162 fd, err := cr.getAndCacheFileDescriptors(req, filename, "") 163 if isNotFound(err) { 164 // file not found? see if we can look up via alternate name 165 if alternate, ok := internal.StdFileAliases[filename]; ok { 166 req := &rpb.ServerReflectionRequest{ 167 MessageRequest: &rpb.ServerReflectionRequest_FileByFilename{ 168 FileByFilename: alternate, 169 }, 170 } 171 fd, err = cr.getAndCacheFileDescriptors(req, alternate, filename) 172 if isNotFound(err) { 173 err = fileNotFound(filename, nil) 174 } 175 } else { 176 err = fileNotFound(filename, nil) 177 } 178 } else if e, ok := err.(*elementNotFoundError); ok { 179 err = fileNotFound(filename, e) 180 } 181 return fd, err 182 } 183 184 // FileContainingSymbol asks the server for a file descriptor for the proto file 185 // that declares the given fully-qualified symbol. 186 func (cr *Client) FileContainingSymbol(symbol string) (*desc.FileDescriptor, error) { 187 // hit the cache first 188 cr.cacheMu.RLock() 189 fd, ok := cr.filesBySymbol[symbol] 190 cr.cacheMu.RUnlock() 191 if ok { 192 return fd, nil 193 } 194 195 req := &rpb.ServerReflectionRequest{ 196 MessageRequest: &rpb.ServerReflectionRequest_FileContainingSymbol{ 197 FileContainingSymbol: symbol, 198 }, 199 } 200 fd, err := cr.getAndCacheFileDescriptors(req, "", "") 201 if isNotFound(err) { 202 err = symbolNotFound(symbol, symbolTypeUnknown, nil) 203 } else if e, ok := err.(*elementNotFoundError); ok { 204 err = symbolNotFound(symbol, symbolTypeUnknown, e) 205 } 206 return fd, err 207 } 208 209 // FileContainingExtension asks the server for a file descriptor for the proto 210 // file that declares an extension with the given number for the given 211 // fully-qualified message name. 212 func (cr *Client) FileContainingExtension(extendedMessageName string, extensionNumber int32) (*desc.FileDescriptor, error) { 213 // hit the cache first 214 cr.cacheMu.RLock() 215 fd, ok := cr.filesByExtension[extDesc{extendedMessageName, extensionNumber}] 216 cr.cacheMu.RUnlock() 217 if ok { 218 return fd, nil 219 } 220 221 req := &rpb.ServerReflectionRequest{ 222 MessageRequest: &rpb.ServerReflectionRequest_FileContainingExtension{ 223 FileContainingExtension: &rpb.ExtensionRequest{ 224 ContainingType: extendedMessageName, 225 ExtensionNumber: extensionNumber, 226 }, 227 }, 228 } 229 fd, err := cr.getAndCacheFileDescriptors(req, "", "") 230 if isNotFound(err) { 231 err = extensionNotFound(extendedMessageName, extensionNumber, nil) 232 } else if e, ok := err.(*elementNotFoundError); ok { 233 err = extensionNotFound(extendedMessageName, extensionNumber, e) 234 } 235 return fd, err 236 } 237 238 func (cr *Client) getAndCacheFileDescriptors(req *rpb.ServerReflectionRequest, expectedName, alias string) (*desc.FileDescriptor, error) { 239 resp, err := cr.send(req) 240 if err != nil { 241 return nil, err 242 } 243 244 fdResp := resp.GetFileDescriptorResponse() 245 if fdResp == nil { 246 return nil, &ProtocolError{reflect.TypeOf(fdResp).Elem()} 247 } 248 249 // Response can contain the result file descriptor, but also its transitive 250 // deps. Furthermore, protocol states that subsequent requests do not need 251 // to send transitive deps that have been sent in prior responses. So we 252 // need to cache all file descriptors and then return the first one (which 253 // should be the answer). If we're looking for a file by name, we can be 254 // smarter and make sure to grab one by name instead of just grabbing the 255 // first one. 256 var firstFd *dpb.FileDescriptorProto 257 for _, fdBytes := range fdResp.FileDescriptorProto { 258 fd := &dpb.FileDescriptorProto{} 259 if err = proto.Unmarshal(fdBytes, fd); err != nil { 260 return nil, err 261 } 262 263 if expectedName != "" && alias != "" && expectedName != alias && fd.GetName() == expectedName { 264 // we found a file was aliased, so we need to update the proto to reflect that 265 fd.Name = proto.String(alias) 266 } 267 268 cr.cacheMu.Lock() 269 // see if this file was created and cached concurrently 270 if firstFd == nil { 271 if d, ok := cr.filesByName[fd.GetName()]; ok { 272 cr.cacheMu.Unlock() 273 return d, nil 274 } 275 } 276 // store in cache of raw descriptor protos, but don't overwrite existing protos 277 if existingFd, ok := cr.protosByName[fd.GetName()]; ok { 278 fd = existingFd 279 } else { 280 cr.protosByName[fd.GetName()] = fd 281 } 282 cr.cacheMu.Unlock() 283 if firstFd == nil { 284 firstFd = fd 285 } 286 } 287 if firstFd == nil { 288 return nil, &ProtocolError{reflect.TypeOf(firstFd).Elem()} 289 } 290 291 return cr.descriptorFromProto(firstFd) 292 } 293 294 func (cr *Client) descriptorFromProto(fd *dpb.FileDescriptorProto) (*desc.FileDescriptor, error) { 295 deps := make([]*desc.FileDescriptor, len(fd.GetDependency())) 296 for i, depName := range fd.GetDependency() { 297 if dep, err := cr.FileByFilename(depName); err != nil { 298 return nil, err 299 } else { 300 deps[i] = dep 301 } 302 } 303 d, err := desc.CreateFileDescriptor(fd, deps...) 304 if err != nil { 305 return nil, err 306 } 307 d = cr.cacheFile(d) 308 return d, nil 309 } 310 311 func (cr *Client) cacheFile(fd *desc.FileDescriptor) *desc.FileDescriptor { 312 cr.cacheMu.Lock() 313 defer cr.cacheMu.Unlock() 314 315 // cache file descriptor by name, but don't overwrite existing entry 316 // (existing entry could come from concurrent caller) 317 if existingFd, ok := cr.filesByName[fd.GetName()]; ok { 318 return existingFd 319 } 320 cr.filesByName[fd.GetName()] = fd 321 322 // also cache by symbols and extensions 323 for _, m := range fd.GetMessageTypes() { 324 cr.cacheMessageLocked(fd, m) 325 } 326 for _, e := range fd.GetEnumTypes() { 327 cr.filesBySymbol[e.GetFullyQualifiedName()] = fd 328 for _, v := range e.GetValues() { 329 cr.filesBySymbol[v.GetFullyQualifiedName()] = fd 330 } 331 } 332 for _, e := range fd.GetExtensions() { 333 cr.filesBySymbol[e.GetFullyQualifiedName()] = fd 334 cr.filesByExtension[extDesc{e.GetOwner().GetFullyQualifiedName(), e.GetNumber()}] = fd 335 } 336 for _, s := range fd.GetServices() { 337 cr.filesBySymbol[s.GetFullyQualifiedName()] = fd 338 for _, m := range s.GetMethods() { 339 cr.filesBySymbol[m.GetFullyQualifiedName()] = fd 340 } 341 } 342 343 return fd 344 } 345 346 func (cr *Client) cacheMessageLocked(fd *desc.FileDescriptor, md *desc.MessageDescriptor) { 347 cr.filesBySymbol[md.GetFullyQualifiedName()] = fd 348 for _, f := range md.GetFields() { 349 cr.filesBySymbol[f.GetFullyQualifiedName()] = fd 350 } 351 for _, o := range md.GetOneOfs() { 352 cr.filesBySymbol[o.GetFullyQualifiedName()] = fd 353 } 354 for _, e := range md.GetNestedEnumTypes() { 355 cr.filesBySymbol[e.GetFullyQualifiedName()] = fd 356 for _, v := range e.GetValues() { 357 cr.filesBySymbol[v.GetFullyQualifiedName()] = fd 358 } 359 } 360 for _, e := range md.GetNestedExtensions() { 361 cr.filesBySymbol[e.GetFullyQualifiedName()] = fd 362 cr.filesByExtension[extDesc{e.GetOwner().GetFullyQualifiedName(), e.GetNumber()}] = fd 363 } 364 for _, m := range md.GetNestedMessageTypes() { 365 cr.cacheMessageLocked(fd, m) // recurse 366 } 367 } 368 369 // AllExtensionNumbersForType asks the server for all known extension numbers 370 // for the given fully-qualified message name. 371 func (cr *Client) AllExtensionNumbersForType(extendedMessageName string) ([]int32, error) { 372 req := &rpb.ServerReflectionRequest{ 373 MessageRequest: &rpb.ServerReflectionRequest_AllExtensionNumbersOfType{ 374 AllExtensionNumbersOfType: extendedMessageName, 375 }, 376 } 377 resp, err := cr.send(req) 378 if err != nil { 379 if isNotFound(err) { 380 return nil, symbolNotFound(extendedMessageName, symbolTypeMessage, nil) 381 } 382 return nil, err 383 } 384 385 extResp := resp.GetAllExtensionNumbersResponse() 386 if extResp == nil { 387 return nil, &ProtocolError{reflect.TypeOf(extResp).Elem()} 388 } 389 return extResp.ExtensionNumber, nil 390 } 391 392 // ListServices asks the server for the fully-qualified names of all exposed 393 // services. 394 func (cr *Client) ListServices() ([]string, error) { 395 req := &rpb.ServerReflectionRequest{ 396 MessageRequest: &rpb.ServerReflectionRequest_ListServices{ 397 // proto doesn't indicate any purpose for this value and server impl 398 // doesn't actually use it... 399 ListServices: "*", 400 }, 401 } 402 resp, err := cr.send(req) 403 if err != nil { 404 return nil, err 405 } 406 407 listResp := resp.GetListServicesResponse() 408 if listResp == nil { 409 return nil, &ProtocolError{reflect.TypeOf(listResp).Elem()} 410 } 411 serviceNames := make([]string, len(listResp.Service)) 412 for i, s := range listResp.Service { 413 serviceNames[i] = s.Name 414 } 415 return serviceNames, nil 416 } 417 418 func (cr *Client) send(req *rpb.ServerReflectionRequest) (*rpb.ServerReflectionResponse, error) { 419 // we allow one immediate retry, in case we have a stale stream 420 // (e.g. closed by server) 421 resp, err := cr.doSend(true, req) 422 if err != nil { 423 return nil, err 424 } 425 426 // convert error response messages into errors 427 errResp := resp.GetErrorResponse() 428 if errResp != nil { 429 return nil, status.Errorf(codes.Code(errResp.ErrorCode), "%s", errResp.ErrorMessage) 430 } 431 432 return resp, nil 433 } 434 435 func isNotFound(err error) bool { 436 if err == nil { 437 return false 438 } 439 s, ok := status.FromError(err) 440 return ok && s.Code() == codes.NotFound 441 } 442 443 func (cr *Client) doSend(retry bool, req *rpb.ServerReflectionRequest) (*rpb.ServerReflectionResponse, error) { 444 // TODO: Streams are thread-safe, so we shouldn't need to lock. But without locking, we'll need more machinery 445 // (goroutines and channels) to ensure that responses are correctly correlated with their requests and thus 446 // delivered in correct oder. 447 cr.connMu.Lock() 448 defer cr.connMu.Unlock() 449 return cr.doSendLocked(retry, req) 450 } 451 452 func (cr *Client) doSendLocked(retry bool, req *rpb.ServerReflectionRequest) (*rpb.ServerReflectionResponse, error) { 453 if err := cr.initStreamLocked(); err != nil { 454 return nil, err 455 } 456 457 if err := cr.stream.Send(req); err != nil { 458 if err == io.EOF { 459 // if send returns EOF, must call Recv to get real underlying error 460 _, err = cr.stream.Recv() 461 } 462 cr.resetLocked() 463 if retry { 464 return cr.doSendLocked(false, req) 465 } 466 return nil, err 467 } 468 469 if resp, err := cr.stream.Recv(); err != nil { 470 cr.resetLocked() 471 if retry { 472 return cr.doSendLocked(false, req) 473 } 474 return nil, err 475 } else { 476 return resp, nil 477 } 478 } 479 480 func (cr *Client) initStreamLocked() error { 481 if cr.stream != nil { 482 return nil 483 } 484 var newCtx context.Context 485 newCtx, cr.cancel = context.WithCancel(cr.ctx) 486 var err error 487 cr.stream, err = cr.stub.ServerReflectionInfo(newCtx) 488 return err 489 } 490 491 // Reset ensures that any active stream with the server is closed, releasing any 492 // resources. 493 func (cr *Client) Reset() { 494 cr.connMu.Lock() 495 defer cr.connMu.Unlock() 496 cr.resetLocked() 497 } 498 499 func (cr *Client) resetLocked() { 500 if cr.stream != nil { 501 cr.stream.CloseSend() 502 for { 503 // drain the stream, this covers io.EOF too 504 if _, err := cr.stream.Recv(); err != nil { 505 break 506 } 507 } 508 cr.stream = nil 509 } 510 if cr.cancel != nil { 511 cr.cancel() 512 cr.cancel = nil 513 } 514 } 515 516 // ResolveService asks the server to resolve the given fully-qualified service 517 // name into a service descriptor. 518 func (cr *Client) ResolveService(serviceName string) (*desc.ServiceDescriptor, error) { 519 file, err := cr.FileContainingSymbol(serviceName) 520 if err != nil { 521 return nil, setSymbolType(err, serviceName, symbolTypeService) 522 } 523 d := file.FindSymbol(serviceName) 524 if d == nil { 525 return nil, symbolNotFound(serviceName, symbolTypeService, nil) 526 } 527 if s, ok := d.(*desc.ServiceDescriptor); ok { 528 return s, nil 529 } else { 530 return nil, symbolNotFound(serviceName, symbolTypeService, nil) 531 } 532 } 533 534 // ResolveMessage asks the server to resolve the given fully-qualified message 535 // name into a message descriptor. 536 func (cr *Client) ResolveMessage(messageName string) (*desc.MessageDescriptor, error) { 537 file, err := cr.FileContainingSymbol(messageName) 538 if err != nil { 539 return nil, setSymbolType(err, messageName, symbolTypeMessage) 540 } 541 d := file.FindSymbol(messageName) 542 if d == nil { 543 return nil, symbolNotFound(messageName, symbolTypeMessage, nil) 544 } 545 if s, ok := d.(*desc.MessageDescriptor); ok { 546 return s, nil 547 } else { 548 return nil, symbolNotFound(messageName, symbolTypeMessage, nil) 549 } 550 } 551 552 // ResolveEnum asks the server to resolve the given fully-qualified enum name 553 // into an enum descriptor. 554 func (cr *Client) ResolveEnum(enumName string) (*desc.EnumDescriptor, error) { 555 file, err := cr.FileContainingSymbol(enumName) 556 if err != nil { 557 return nil, setSymbolType(err, enumName, symbolTypeEnum) 558 } 559 d := file.FindSymbol(enumName) 560 if d == nil { 561 return nil, symbolNotFound(enumName, symbolTypeEnum, nil) 562 } 563 if s, ok := d.(*desc.EnumDescriptor); ok { 564 return s, nil 565 } else { 566 return nil, symbolNotFound(enumName, symbolTypeEnum, nil) 567 } 568 } 569 570 func setSymbolType(err error, name string, symType symbolType) error { 571 if e, ok := err.(*elementNotFoundError); ok { 572 if e.kind == elementKindSymbol && e.name == name && e.symType == symbolTypeUnknown { 573 e.symType = symType 574 } 575 } 576 return err 577 } 578 579 // ResolveEnumValues asks the server to resolve the given fully-qualified enum 580 // name into a map of names to numbers that represents the enum's values. 581 func (cr *Client) ResolveEnumValues(enumName string) (map[string]int32, error) { 582 enumDesc, err := cr.ResolveEnum(enumName) 583 if err != nil { 584 return nil, err 585 } 586 vals := map[string]int32{} 587 for _, valDesc := range enumDesc.GetValues() { 588 vals[valDesc.GetName()] = valDesc.GetNumber() 589 } 590 return vals, nil 591 } 592 593 // ResolveExtension asks the server to resolve the given extension number and 594 // fully-qualified message name into a field descriptor. 595 func (cr *Client) ResolveExtension(extendedType string, extensionNumber int32) (*desc.FieldDescriptor, error) { 596 file, err := cr.FileContainingExtension(extendedType, extensionNumber) 597 if err != nil { 598 return nil, err 599 } 600 d := findExtension(extendedType, extensionNumber, fileDescriptorExtensions{file}) 601 if d == nil { 602 return nil, extensionNotFound(extendedType, extensionNumber, nil) 603 } else { 604 return d, nil 605 } 606 } 607 608 func findExtension(extendedType string, extensionNumber int32, scope extensionScope) *desc.FieldDescriptor { 609 // search extensions in this scope 610 for _, ext := range scope.extensions() { 611 if ext.GetNumber() == extensionNumber && ext.GetOwner().GetFullyQualifiedName() == extendedType { 612 return ext 613 } 614 } 615 616 // if not found, search nested scopes 617 for _, nested := range scope.nestedScopes() { 618 ext := findExtension(extendedType, extensionNumber, nested) 619 if ext != nil { 620 return ext 621 } 622 } 623 624 return nil 625 } 626 627 type extensionScope interface { 628 extensions() []*desc.FieldDescriptor 629 nestedScopes() []extensionScope 630 } 631 632 // fileDescriptorExtensions implements extensionHolder interface on top of 633 // FileDescriptorProto 634 type fileDescriptorExtensions struct { 635 proto *desc.FileDescriptor 636 } 637 638 func (fde fileDescriptorExtensions) extensions() []*desc.FieldDescriptor { 639 return fde.proto.GetExtensions() 640 } 641 642 func (fde fileDescriptorExtensions) nestedScopes() []extensionScope { 643 scopes := make([]extensionScope, len(fde.proto.GetMessageTypes())) 644 for i, m := range fde.proto.GetMessageTypes() { 645 scopes[i] = msgDescriptorExtensions{m} 646 } 647 return scopes 648 } 649 650 // msgDescriptorExtensions implements extensionHolder interface on top of 651 // DescriptorProto 652 type msgDescriptorExtensions struct { 653 proto *desc.MessageDescriptor 654 } 655 656 func (mde msgDescriptorExtensions) extensions() []*desc.FieldDescriptor { 657 return mde.proto.GetNestedExtensions() 658 } 659 660 func (mde msgDescriptorExtensions) nestedScopes() []extensionScope { 661 scopes := make([]extensionScope, len(mde.proto.GetNestedMessageTypes())) 662 for i, m := range mde.proto.GetNestedMessageTypes() { 663 scopes[i] = msgDescriptorExtensions{m} 664 } 665 return scopes 666 }