github.com/emcfarlane/larking@v0.0.0-20220605172417-1704b45ee6c3/mux.go (about) 1 // Copyright 2021 Edward McFarlane. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package larking 6 7 import ( 8 "bytes" 9 "compress/gzip" 10 "context" 11 "crypto/sha256" 12 "fmt" 13 "io" 14 "io/ioutil" 15 "math/rand" 16 "net/http" 17 "strings" 18 "sync" 19 "sync/atomic" 20 "time" 21 22 "google.golang.org/genproto/googleapis/api/annotations" 23 "google.golang.org/genproto/googleapis/api/httpbody" 24 "google.golang.org/grpc" 25 "google.golang.org/grpc/codes" 26 "google.golang.org/grpc/encoding" 27 "google.golang.org/grpc/metadata" 28 rpb "google.golang.org/grpc/reflection/grpc_reflection_v1alpha" 29 "google.golang.org/grpc/stats" 30 "google.golang.org/grpc/status" 31 "google.golang.org/protobuf/encoding/protojson" 32 "google.golang.org/protobuf/proto" 33 "google.golang.org/protobuf/reflect/protodesc" 34 "google.golang.org/protobuf/reflect/protoreflect" 35 "google.golang.org/protobuf/reflect/protoregistry" 36 "google.golang.org/protobuf/types/descriptorpb" 37 "google.golang.org/protobuf/types/dynamicpb" 38 "nhooyr.io/websocket" 39 ) 40 41 // RO 42 type connList struct { 43 handlers []*handler 44 fdHash []byte 45 } 46 47 type state struct { 48 path *path 49 conns map[*grpc.ClientConn]connList 50 handlers map[string][]*handler 51 } 52 53 func (s *state) clone() *state { 54 if s == nil { 55 return &state{ 56 path: newPath(), 57 conns: make(map[*grpc.ClientConn]connList), 58 handlers: make(map[string][]*handler), 59 } 60 } 61 62 conns := make(map[*grpc.ClientConn]connList) 63 for conn, cl := range s.conns { 64 conns[conn] = cl 65 } 66 67 handlers := make(map[string][]*handler) 68 for method, hds := range s.handlers { 69 handlers[method] = hds 70 } 71 72 return &state{ 73 path: s.path.clone(), 74 conns: conns, 75 handlers: handlers, 76 } 77 } 78 79 type muxOptions struct { 80 maxReceiveMessageSize int 81 maxSendMessageSize int 82 connectionTimeout time.Duration 83 files *protoregistry.Files 84 types protoregistry.MessageTypeResolver 85 unaryInterceptor grpc.UnaryServerInterceptor 86 streamInterceptor grpc.StreamServerInterceptor 87 statsHandler stats.Handler 88 } 89 90 func (o *muxOptions) readAll(r io.Reader) ([]byte, error) { 91 b, err := ioutil.ReadAll(io.LimitReader(r, int64(o.maxReceiveMessageSize)+1)) 92 if err != nil { 93 return nil, err 94 } 95 if len(b) > o.maxReceiveMessageSize { 96 return nil, fmt.Errorf("max receive message size reached") 97 } 98 return b, nil 99 } 100 func (o *muxOptions) writeAll(dst io.Writer, b []byte) error { 101 if len(b) > o.maxSendMessageSize { 102 return fmt.Errorf("max send message size reached") 103 } 104 src := bytes.NewReader(b) 105 _, err := io.Copy(dst, src) 106 return err 107 } 108 109 // unary is a nil-safe interceptor unary call. 110 func (o *muxOptions) unary(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { 111 if ui := o.unaryInterceptor; ui != nil { 112 return ui(ctx, req, info, handler) 113 } 114 return handler(ctx, req) 115 } 116 117 // stream is a nil-safe interceptor stream call. 118 func (o *muxOptions) stream(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { 119 if si := o.streamInterceptor; si != nil { 120 return si(srv, ss, info, handler) 121 } 122 return handler(srv, ss) 123 } 124 125 type MuxOption func(*muxOptions) 126 127 var defaultMuxOptions = muxOptions{ 128 maxReceiveMessageSize: defaultServerMaxReceiveMessageSize, 129 maxSendMessageSize: defaultServerMaxSendMessageSize, 130 connectionTimeout: defaultServerConnectionTimeout, 131 files: protoregistry.GlobalFiles, 132 types: protoregistry.GlobalTypes, 133 } 134 135 func UnaryServerInterceptorOption(interceptor grpc.UnaryServerInterceptor) MuxOption { 136 return func(opts *muxOptions) { opts.unaryInterceptor = interceptor } 137 } 138 139 func StreamServerInterceptorOption(interceptor grpc.StreamServerInterceptor) MuxOption { 140 return func(opts *muxOptions) { opts.streamInterceptor = interceptor } 141 } 142 143 func StatsOption(h stats.Handler) MuxOption { 144 return func(opts *muxOptions) { opts.statsHandler = h } 145 } 146 147 type Mux struct { 148 opts muxOptions 149 //events trace.EventLog TODO 150 mu sync.Mutex // Lock to sync writers 151 state atomic.Value // Value of *state 152 153 // services is a list of registered services 154 services map[*grpc.ServiceDesc]interface{} 155 } 156 157 func NewMux(opts ...MuxOption) (*Mux, error) { 158 // Apply options. 159 var muxOpts = defaultMuxOptions 160 for _, opt := range opts { 161 opt(&muxOpts) 162 } 163 164 return &Mux{ 165 opts: muxOpts, 166 }, nil 167 } 168 169 func (m *Mux) RegisterConn(ctx context.Context, cc *grpc.ClientConn) error { 170 c := rpb.NewServerReflectionClient(cc) 171 172 // TODO: watch the stream. When it is recreated refresh the service 173 // methods and recreate the mux if needed. 174 stream, err := c.ServerReflectionInfo(ctx, grpc.WaitForReady(true)) 175 if err != nil { 176 return err 177 } 178 179 // Load the state for writing. 180 m.mu.Lock() 181 defer m.mu.Unlock() 182 s := m.loadState().clone() 183 184 if err := s.addConnHandler(cc, stream); err != nil { 185 return err 186 } 187 188 m.storeState(s) 189 190 return stream.CloseSend() 191 } 192 193 func (m *Mux) DropConn(ctx context.Context, cc *grpc.ClientConn) bool { 194 // Load the state for writing. 195 m.mu.Lock() 196 defer m.mu.Unlock() 197 s := m.loadState().clone() 198 199 return s.removeHandler(cc) 200 } 201 202 // resolver implements protodesc.Resolver. 203 type resolver struct { 204 files protoregistry.Files 205 stream rpb.ServerReflection_ServerReflectionInfoClient 206 } 207 208 func newResolver(stream rpb.ServerReflection_ServerReflectionInfoClient) (*resolver, error) { 209 r := &resolver{stream: stream} 210 211 if err := r.files.RegisterFile(annotations.File_google_api_annotations_proto); err != nil { 212 return nil, err 213 } 214 if err := r.files.RegisterFile(annotations.File_google_api_http_proto); err != nil { 215 return nil, err 216 } 217 if err := r.files.RegisterFile(httpbody.File_google_api_httpbody_proto); err != nil { 218 return nil, err 219 } 220 return r, nil 221 } 222 223 func (r *resolver) FindFileByPath(path string) (protoreflect.FileDescriptor, error) { 224 if fd, err := r.files.FindFileByPath(path); err == nil { 225 return fd, nil // found file 226 } 227 228 if err := r.stream.Send(&rpb.ServerReflectionRequest{ 229 MessageRequest: &rpb.ServerReflectionRequest_FileByFilename{ 230 FileByFilename: path, 231 }, 232 }); err != nil { 233 return nil, err 234 } 235 236 fdr, err := r.stream.Recv() 237 if err != nil { 238 return nil, err 239 } 240 fdbs := fdr.GetFileDescriptorResponse().GetFileDescriptorProto() 241 242 var f protoreflect.FileDescriptor 243 for _, fdb := range fdbs { 244 fdp := &descriptorpb.FileDescriptorProto{} 245 if err := proto.Unmarshal(fdb, fdp); err != nil { 246 return nil, err 247 } 248 249 file, err := protodesc.NewFile(fdp, r) 250 if err != nil { 251 return nil, err 252 } 253 // TODO: check duplicate file registry 254 if err := r.files.RegisterFile(file); err != nil { 255 return nil, err 256 } 257 if file.Path() == path { 258 f = file 259 } 260 } 261 if f == nil { 262 return nil, fmt.Errorf("missing file descriptor %s", path) 263 } 264 return f, nil 265 } 266 267 func (r *resolver) FindDescriptorByName(fullname protoreflect.FullName) (protoreflect.Descriptor, error) { 268 return r.files.FindDescriptorByName(fullname) 269 } 270 271 func (s *state) appendHandler( 272 rule *annotations.HttpRule, 273 desc protoreflect.MethodDescriptor, 274 h *handler, 275 ) error { 276 if err := s.path.addRule(rule, desc, h.method); err != nil { 277 return err 278 } 279 s.handlers[h.method] = append(s.handlers[h.method], h) 280 return nil 281 } 282 283 func (s *state) removeHandler(cc *grpc.ClientConn) bool { 284 cl, ok := s.conns[cc] 285 if !ok { 286 return ok 287 } 288 289 // Drop handlers belonging to the client conn. 290 for _, hd := range cl.handlers { 291 name := hd.method 292 293 var hds []*handler 294 for _, mhd := range s.handlers[name] { 295 // Compare if handler belongs to this connection. 296 if mhd != hd { 297 hds = append(hds, mhd) 298 } 299 } 300 if len(hds) == 0 { 301 delete(s.handlers, name) 302 s.path.delRule(name) 303 } else { 304 s.handlers[name] = hds 305 } 306 } 307 // Drop conn on client conn. 308 delete(s.conns, cc) 309 return ok 310 } 311 312 func (s *state) addConnHandler( 313 cc *grpc.ClientConn, 314 stream rpb.ServerReflection_ServerReflectionInfoClient, 315 ) error { 316 // TODO: async fetch and mux creation. 317 318 if err := stream.Send(&rpb.ServerReflectionRequest{ 319 MessageRequest: &rpb.ServerReflectionRequest_ListServices{}, 320 }); err != nil { 321 return err 322 } 323 324 r, err := stream.Recv() 325 if err != nil { 326 return err 327 } 328 // TODO: check r.GetErrorResponse()? 329 330 // File descriptors hash for detecting updates. TODO: sort fds? 331 h := sha256.New() 332 333 fds := make(map[string]*descriptorpb.FileDescriptorProto) 334 for _, svc := range r.GetListServicesResponse().GetService() { 335 if err := stream.Send(&rpb.ServerReflectionRequest{ 336 MessageRequest: &rpb.ServerReflectionRequest_FileContainingSymbol{ 337 FileContainingSymbol: svc.GetName(), 338 }, 339 }); err != nil { 340 return err 341 } 342 343 fdr, err := stream.Recv() 344 if err != nil { 345 return err 346 } 347 348 fdbb := fdr.GetFileDescriptorResponse().GetFileDescriptorProto() 349 350 for _, fdb := range fdbb { 351 fd := &descriptorpb.FileDescriptorProto{} 352 if err := proto.Unmarshal(fdb, fd); err != nil { 353 return err 354 } 355 fds[fd.GetName()] = fd 356 357 if _, err := h.Write(fdb); err != nil { 358 return err 359 } 360 } 361 } 362 363 fdHash := h.Sum(nil) 364 365 // Check if previous connection exists. 366 if cl, ok := s.conns[cc]; ok { 367 if bytes.Equal(cl.fdHash, fdHash) { 368 return nil // nothing to do 369 } 370 371 // Drop and recreate below. 372 s.removeHandler(cc) 373 } 374 375 rslvr, err := newResolver(stream) 376 if err != nil { 377 return err 378 } 379 380 var handlers []*handler 381 for _, fd := range fds { 382 file, err := protodesc.NewFile(fd, rslvr) 383 if err != nil { 384 return err 385 } 386 387 hs, err := s.processFile(cc, file) 388 if err != nil { 389 return err 390 } 391 handlers = append(handlers, hs...) 392 } 393 394 // Update methods list. 395 s.conns[cc] = connList{ 396 handlers: handlers, 397 fdHash: fdHash, 398 } 399 return nil 400 } 401 402 func createConnHandler( 403 cc *grpc.ClientConn, 404 sd protoreflect.ServiceDescriptor, 405 md protoreflect.MethodDescriptor, 406 ) *handler { 407 408 argsDesc := md.Input() 409 replyDesc := md.Output() 410 411 method := fmt.Sprintf("/%s/%s", sd.FullName(), md.Name()) 412 413 isClientStream := md.IsStreamingClient() 414 isServerStream := md.IsStreamingServer() 415 if isClientStream || isServerStream { 416 sd := &grpc.StreamDesc{ 417 ServerStreams: md.IsStreamingServer(), 418 ClientStreams: md.IsStreamingClient(), 419 } 420 info := &grpc.StreamServerInfo{ 421 FullMethod: method, 422 IsClientStream: isClientStream, 423 IsServerStream: isServerStream, 424 } 425 426 fn := func(_ interface{}, stream grpc.ServerStream) error { 427 ctx := stream.Context() 428 429 args := dynamicpb.NewMessage(argsDesc) 430 reply := dynamicpb.NewMessage(replyDesc) 431 432 if err := stream.RecvMsg(args); err != nil { 433 return err 434 } 435 436 if md, ok := metadata.FromIncomingContext(ctx); ok { 437 ctx = metadata.NewOutgoingContext(ctx, md) 438 } 439 440 clientStream, err := cc.NewStream(ctx, sd, method) 441 if err != nil { 442 return err 443 } 444 if err := clientStream.SendMsg(args); err != nil { 445 return err 446 } 447 448 var inErr error 449 var wg sync.WaitGroup 450 if sd.ClientStreams { 451 wg.Add(1) 452 go func() { 453 for { 454 if inErr = stream.RecvMsg(args); inErr != nil { 455 break 456 } 457 458 if inErr = clientStream.SendMsg(args); inErr != nil { 459 break 460 } 461 } 462 wg.Done() 463 }() 464 } 465 var outErr error 466 for { 467 if outErr = clientStream.RecvMsg(reply); outErr != nil { 468 break 469 } 470 471 if outErr = stream.SendMsg(reply); outErr != nil { 472 break 473 } 474 475 if !sd.ServerStreams { 476 break 477 } 478 } 479 480 if isStreamError(outErr) { 481 return outErr 482 } 483 if sd.ClientStreams { 484 wg.Wait() 485 if isStreamError(inErr) { 486 return inErr 487 } 488 } 489 trailer := clientStream.Trailer() 490 stream.SetTrailer(trailer) 491 return nil 492 } 493 494 h := func(opts *muxOptions, stream grpc.ServerStream) error { 495 return opts.stream(nil, stream, info, fn) 496 } 497 498 return &handler{ 499 method: method, 500 descriptor: md, 501 handler: h, 502 } 503 } else { 504 info := &grpc.UnaryServerInfo{ 505 Server: nil, 506 FullMethod: method, 507 } 508 fn := func(ctx context.Context, args interface{}) (interface{}, error) { 509 reply := dynamicpb.NewMessage(replyDesc) 510 511 if md, ok := metadata.FromIncomingContext(ctx); ok { 512 ctx = metadata.NewOutgoingContext(ctx, md) 513 } 514 515 if err := cc.Invoke(ctx, method, args, reply); err != nil { 516 return nil, err 517 } 518 return reply, nil 519 } 520 h := func(opts *muxOptions, stream grpc.ServerStream) error { 521 ctx := stream.Context() 522 args := dynamicpb.NewMessage(argsDesc) 523 524 if err := stream.RecvMsg(args); err != nil { 525 return err 526 } 527 528 reply, err := opts.unary(ctx, args, info, fn) 529 if err != nil { 530 return err 531 } 532 return stream.SendMsg(reply) 533 } 534 535 return &handler{ 536 method: method, 537 descriptor: md, 538 handler: h, 539 } 540 } 541 } 542 543 func (s *state) processFile(cc *grpc.ClientConn, fd protoreflect.FileDescriptor) ([]*handler, error) { 544 var handlers []*handler 545 546 sds := fd.Services() 547 for i := 0; i < sds.Len(); i++ { 548 sd := sds.Get(i) 549 550 mds := sd.Methods() 551 for j := 0; j < mds.Len(); j++ { 552 md := mds.Get(j) 553 554 opts := md.Options() // TODO: nil check fails? 555 556 rule := getExtensionHTTP(opts) 557 if rule == nil { 558 continue 559 } 560 561 hd := createConnHandler(cc, sd, md) 562 563 if err := s.appendHandler(rule, md, hd); err != nil { 564 return nil, err 565 } 566 handlers = append(handlers, hd) 567 } 568 } 569 return handlers, nil 570 } 571 572 func (m *Mux) loadState() *state { 573 s, _ := m.state.Load().(*state) 574 return s 575 } 576 func (m *Mux) storeState(s *state) { m.state.Store(s) } 577 578 func (s *state) pickMethodHandler(name string) (*handler, error) { 579 if s != nil { 580 hds := s.handlers[name] 581 if len(hds) > 0 { 582 hd := hds[rand.Intn(len(hds))] 583 return hd, nil 584 } 585 } 586 return nil, status.Errorf(codes.Unimplemented, "method %s not implemented", name) 587 } 588 589 func (s *state) match(route, verb string) (*method, params, error) { 590 if s == nil { 591 return nil, nil, status.Error(codes.NotFound, "not found") 592 } 593 return s.path.match(route, verb) 594 } 595 596 var ( 597 contentTypeCodec = map[string]encoding.Codec{ 598 "application/protobuf": protoCodec{}, 599 "application/octet-stream": protoCodec{}, 600 "application/json": jsonCodec{}, 601 "": jsonCodec{}, // default 602 } 603 ) 604 605 type streamHTTP struct { 606 ctx context.Context 607 w http.ResponseWriter 608 r *http.Request 609 method *method 610 params params 611 recvN int 612 sendN int 613 614 sentHeader bool 615 header metadata.MD 616 trailer metadata.MD 617 618 opts muxOptions 619 } 620 621 func (s *streamHTTP) SetHeader(md metadata.MD) error { 622 if s.sentHeader { 623 return fmt.Errorf("already sent headers") 624 } 625 s.header = metadata.Join(s.header, md) 626 return nil 627 } 628 func (s *streamHTTP) SendHeader(md metadata.MD) error { 629 if s.sentHeader { 630 return fmt.Errorf("already sent headers") 631 } 632 s.header = metadata.Join(s.header, md) 633 setOutgoingHeader(s.w.Header(), s.header) 634 // don't write the header code, wait for the body. 635 s.sentHeader = true 636 637 if sh := s.opts.statsHandler; sh != nil { 638 sh.HandleRPC(s.ctx, &stats.OutHeader{ 639 Header: s.header.Copy(), 640 Compression: s.r.Header.Get("Accept-Encoding"), 641 }) 642 } 643 return nil 644 } 645 646 func (s *streamHTTP) SetTrailer(md metadata.MD) { 647 s.trailer = metadata.Join(s.trailer, md) 648 } 649 650 func (s *streamHTTP) Context() context.Context { 651 sts := &serverTransportStream{s, s.method.name} 652 return grpc.NewContextWithServerTransportStream(s.ctx, sts) 653 } 654 655 func (s *streamHTTP) SendMsg(m interface{}) error { 656 s.sendN += 1 657 reply := m.(proto.Message) 658 659 accept := s.r.Header.Get("Accept") 660 acceptEncoding := s.r.Header.Get("Accept-Encoding") 661 662 if fRsp, ok := s.w.(http.Flusher); ok { 663 defer fRsp.Flush() 664 } 665 666 setOutgoingHeader(s.w.Header(), s.header, s.trailer) 667 668 var resp io.Writer = s.w 669 switch acceptEncoding { 670 case "gzip": 671 s.w.Header().Set("Content-Encoding", "gzip") 672 gRsp := gzip.NewWriter(s.w) 673 defer gRsp.Close() 674 resp = gRsp 675 } 676 677 cur := reply.ProtoReflect() 678 for _, fd := range s.method.resp { 679 cur = cur.Mutable(fd).Message() 680 } 681 682 msg := cur.Interface() 683 684 if cur.Descriptor().FullName() == "google.api.HttpBody" { 685 fds := cur.Descriptor().Fields() 686 fdContentType := fds.ByName(protoreflect.Name("content_type")) 687 fdData := fds.ByName(protoreflect.Name("data")) 688 pContentType := cur.Get(fdContentType) 689 pData := cur.Get(fdData) 690 691 s.w.Header().Set("Content-Type", pContentType.String()) 692 // TODO different non-message size? 693 if err := s.opts.writeAll(resp, pData.Bytes()); err != nil { 694 return err 695 } 696 return nil 697 } 698 699 if accept == "" { 700 accept = "application/json" 701 } 702 703 codec, ok := contentTypeCodec[accept] 704 if !ok { 705 return fmt.Errorf("unknown accept encoding: %s", accept) 706 } 707 b, err := codec.Marshal(msg) 708 if err != nil { 709 return err 710 } 711 s.w.Header().Set("Content-Type", accept) 712 if err := s.opts.writeAll(resp, b); err != nil { 713 return err 714 } 715 if stats := s.opts.statsHandler; stats != nil { 716 // TODO: raw payload stats. 717 stats.HandleRPC(s.ctx, outPayload(false, m, b, b, time.Now())) 718 } 719 return nil 720 } 721 722 func (s *streamHTTP) decodeRequestArgs(args proto.Message) error { 723 contentType := s.r.Header.Get("Content-Type") 724 contentEncoding := s.r.Header.Get("Content-Encoding") 725 726 var body io.ReadCloser 727 switch contentEncoding { 728 case "gzip": 729 var err error 730 body, err = gzip.NewReader(s.r.Body) 731 if err != nil { 732 return err 733 } 734 735 default: 736 body = s.r.Body 737 } 738 defer body.Close() 739 740 b, err := s.opts.readAll(body) 741 if err != nil { 742 return err 743 } 744 745 cur := args.ProtoReflect() 746 for _, fd := range s.method.body { 747 cur = cur.Mutable(fd).Message() 748 } 749 fullname := cur.Descriptor().FullName() 750 751 msg := cur.Interface() 752 753 if fullname == "google.api.HttpBody" { 754 rfl := msg.ProtoReflect() 755 fds := rfl.Descriptor().Fields() 756 fdContentType := fds.ByName(protoreflect.Name("content_type")) 757 fdData := fds.ByName(protoreflect.Name("data")) 758 rfl.Set(fdContentType, protoreflect.ValueOfString(contentType)) 759 rfl.Set(fdData, protoreflect.ValueOfBytes(b)) 760 // TODO: extensions? 761 return nil 762 } 763 764 if contentType == "" { 765 contentType = "application/json" 766 } 767 768 codec, ok := contentTypeCodec[contentType] 769 if !ok { 770 return fmt.Errorf("unknown content-type encoding: %s", contentType) 771 } 772 if err := codec.Unmarshal(b, msg); err != nil { 773 return err 774 } 775 if stats := s.opts.statsHandler; stats != nil { 776 // TODO: raw payload stats. 777 stats.HandleRPC(s.ctx, inPayload(false, msg, b, b, time.Now())) 778 } 779 return nil 780 } 781 782 func (s *streamHTTP) RecvMsg(m interface{}) error { 783 s.recvN += 1 784 args := m.(proto.Message) 785 786 // TODO: fix the body marshalling 787 if s.method.hasBody { 788 // TODO: handler should decide what to select on? 789 if err := s.decodeRequestArgs(args); err != nil { 790 return err 791 } 792 } 793 if s.recvN == 1 { 794 if err := s.params.set(args); err != nil { 795 return err 796 } 797 } 798 return nil 799 } 800 801 func isWebsocketRequest(r *http.Request) bool { 802 for _, header := range r.Header["Upgrade"] { 803 if header == "websocket" { 804 return true 805 } 806 } 807 return false 808 } 809 810 func encError(w http.ResponseWriter, err error) { 811 s, _ := status.FromError(err) 812 w.Header().Set("Content-Type", "application/json") 813 w.WriteHeader(HTTPStatusCode(s.Code())) 814 815 b, err := protojson.Marshal(s.Proto()) 816 if err != nil { 817 panic(err) // ... 818 } 819 w.Write(b) //nolint 820 } 821 822 func (m *Mux) serveHTTP(w http.ResponseWriter, r *http.Request) error { 823 ctx, mdata := newIncomingContext(r.Context(), r.Header) 824 825 s := m.loadState() 826 isWebsocket := isWebsocketRequest(r) 827 828 verb := r.Method 829 if isWebsocket { 830 verb = kindWebsocket 831 } 832 833 method, params, err := s.match(r.URL.Path, verb) 834 if err != nil { 835 return err 836 } 837 838 queryParams, err := method.parseQueryParams(r.URL.Query()) 839 if err != nil { 840 return err 841 } 842 params = append(params, queryParams...) 843 844 hd, err := s.pickMethodHandler(method.name) 845 if err != nil { 846 return err 847 } 848 849 // Handle stats. 850 beginTime := time.Now() 851 if sh := m.opts.statsHandler; sh != nil { 852 ctx = sh.TagRPC(ctx, &stats.RPCTagInfo{ 853 FullMethodName: hd.method, 854 FailFast: false, // TODO 855 }) 856 857 sh.HandleRPC(ctx, &stats.InHeader{ 858 FullMethod: method.name, 859 RemoteAddr: strAddr(r.RemoteAddr), 860 Compression: r.Header.Get("Content-Encoding"), 861 Header: metadata.MD(mdata).Copy(), 862 }) 863 864 sh.HandleRPC(ctx, &stats.Begin{ 865 Client: false, 866 BeginTime: beginTime, 867 FailFast: false, // TODO 868 IsClientStream: hd.descriptor.IsStreamingClient(), 869 IsServerStream: hd.descriptor.IsStreamingServer(), 870 IsTransparentRetryAttempt: false, // TODO 871 }) 872 } 873 874 if isWebsocket { 875 c, err := websocket.Accept(w, r, &websocket.AcceptOptions{}) 876 if err != nil { 877 return err 878 } 879 880 stream := &streamWS{ 881 ctx: ctx, 882 conn: c, 883 method: method, 884 params: params, 885 } 886 herr := hd.handler(&m.opts, stream) 887 888 if herr != nil { 889 s, _ := status.FromError(herr) 890 // TODO: limit message size. 891 c.Close(WSStatusCode(s.Code()), s.Message()) // TODO 892 } else { 893 c.Close(websocket.StatusNormalClosure, "OK") // TODO 894 } 895 896 // Handle stats. 897 if sh := m.opts.statsHandler; sh != nil { 898 endTime := time.Now() 899 sh.HandleRPC(ctx, &stats.End{ 900 Client: false, 901 BeginTime: beginTime, 902 EndTime: endTime, 903 Error: err, 904 }) 905 } 906 return nil 907 } 908 909 stream := &streamHTTP{ 910 ctx: ctx, 911 w: w, r: r, 912 method: method, 913 params: params, 914 opts: m.opts, 915 } 916 herr := hd.handler(&m.opts, stream) 917 // Handle stats. 918 if sh := m.opts.statsHandler; sh != nil { 919 endTime := time.Now() 920 921 // Try to send Trailers, might not be respected. 922 setOutgoingHeader(w.Header(), stream.trailer) 923 sh.HandleRPC(ctx, &stats.OutTrailer{ 924 Trailer: stream.trailer.Copy(), 925 }) 926 927 sh.HandleRPC(ctx, &stats.End{ 928 Client: false, 929 BeginTime: beginTime, 930 EndTime: endTime, 931 Error: err, 932 }) 933 } 934 return herr 935 } 936 937 func (m *Mux) ServeHTTP(w http.ResponseWriter, r *http.Request) { 938 if !strings.HasPrefix(r.URL.Path, "/") { 939 r.URL.Path = "/" + r.URL.Path 940 } 941 if err := m.serveHTTP(w, r); err != nil { 942 encError(w, err) 943 } 944 }