github.com/emcfarlane/larking@v0.0.0-20220605172417-1704b45ee6c3/mux.go (about)

     1  // Copyright 2021 Edward McFarlane. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package larking
     6  
     7  import (
     8  	"bytes"
     9  	"compress/gzip"
    10  	"context"
    11  	"crypto/sha256"
    12  	"fmt"
    13  	"io"
    14  	"io/ioutil"
    15  	"math/rand"
    16  	"net/http"
    17  	"strings"
    18  	"sync"
    19  	"sync/atomic"
    20  	"time"
    21  
    22  	"google.golang.org/genproto/googleapis/api/annotations"
    23  	"google.golang.org/genproto/googleapis/api/httpbody"
    24  	"google.golang.org/grpc"
    25  	"google.golang.org/grpc/codes"
    26  	"google.golang.org/grpc/encoding"
    27  	"google.golang.org/grpc/metadata"
    28  	rpb "google.golang.org/grpc/reflection/grpc_reflection_v1alpha"
    29  	"google.golang.org/grpc/stats"
    30  	"google.golang.org/grpc/status"
    31  	"google.golang.org/protobuf/encoding/protojson"
    32  	"google.golang.org/protobuf/proto"
    33  	"google.golang.org/protobuf/reflect/protodesc"
    34  	"google.golang.org/protobuf/reflect/protoreflect"
    35  	"google.golang.org/protobuf/reflect/protoregistry"
    36  	"google.golang.org/protobuf/types/descriptorpb"
    37  	"google.golang.org/protobuf/types/dynamicpb"
    38  	"nhooyr.io/websocket"
    39  )
    40  
// connList records what was registered for one upstream client connection:
// the handlers built from its reflected services, and a SHA-256 hash of the
// file descriptor bytes they were built from. addConnHandler compares the
// hash to skip rebuilding when the connection's descriptors are unchanged.
//
// RO: clone copies connList values between state snapshots (sharing the
// handlers slice), so a connList must not be mutated once stored.
type connList struct {
	handlers []*handler
	fdHash   []byte
}
    46  
// state is a snapshot of the mux routing table. Request handling loads the
// current snapshot from Mux.state without locking; writers hold Mux.mu,
// clone the current snapshot, modify the clone and store it back.
//
// Fields:
//   - path: routes HTTP requests (path + verb) to method names; rules are
//     added with addRule and removed with delRule.
//   - conns: what was registered for each upstream connection.
//   - handlers: handlers keyed by full gRPC method name
//     ("/<service>/<method>"); several connections may serve the same method.
type state struct {
	path     *path
	conns    map[*grpc.ClientConn]connList
	handlers map[string][]*handler
}
    52  
    53  func (s *state) clone() *state {
    54  	if s == nil {
    55  		return &state{
    56  			path:     newPath(),
    57  			conns:    make(map[*grpc.ClientConn]connList),
    58  			handlers: make(map[string][]*handler),
    59  		}
    60  	}
    61  
    62  	conns := make(map[*grpc.ClientConn]connList)
    63  	for conn, cl := range s.conns {
    64  		conns[conn] = cl
    65  	}
    66  
    67  	handlers := make(map[string][]*handler)
    68  	for method, hds := range s.handlers {
    69  		handlers[method] = hds
    70  	}
    71  
    72  	return &state{
    73  		path:     s.path.clone(),
    74  		conns:    conns,
    75  		handlers: handlers,
    76  	}
    77  }
    78  
// muxOptions holds the configuration a Mux is built with. NewMux starts from
// defaultMuxOptions and applies each MuxOption in order.
//
// Fields:
//   - maxReceiveMessageSize: largest request body readAll accepts, in bytes.
//   - maxSendMessageSize: largest payload writeAll writes, in bytes.
//   - connectionTimeout, files, types: not read in this file.
//     NOTE(review): confirm where they are consumed.
//   - unaryInterceptor / streamInterceptor: optional interceptors applied
//     by unary and stream; nil means "call the handler directly".
//   - statsHandler: optional receiver of gRPC stats events for HTTP calls.
type muxOptions struct {
	maxReceiveMessageSize int
	maxSendMessageSize    int
	connectionTimeout     time.Duration
	files                 *protoregistry.Files
	types                 protoregistry.MessageTypeResolver
	unaryInterceptor      grpc.UnaryServerInterceptor
	streamInterceptor     grpc.StreamServerInterceptor
	statsHandler          stats.Handler
}
    89  
    90  func (o *muxOptions) readAll(r io.Reader) ([]byte, error) {
    91  	b, err := ioutil.ReadAll(io.LimitReader(r, int64(o.maxReceiveMessageSize)+1))
    92  	if err != nil {
    93  		return nil, err
    94  	}
    95  	if len(b) > o.maxReceiveMessageSize {
    96  		return nil, fmt.Errorf("max receive message size reached")
    97  	}
    98  	return b, nil
    99  }
   100  func (o *muxOptions) writeAll(dst io.Writer, b []byte) error {
   101  	if len(b) > o.maxSendMessageSize {
   102  		return fmt.Errorf("max send message size reached")
   103  	}
   104  	src := bytes.NewReader(b)
   105  	_, err := io.Copy(dst, src)
   106  	return err
   107  }
   108  
   109  // unary is a nil-safe interceptor unary call.
   110  func (o *muxOptions) unary(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
   111  	if ui := o.unaryInterceptor; ui != nil {
   112  		return ui(ctx, req, info, handler)
   113  	}
   114  	return handler(ctx, req)
   115  }
   116  
   117  // stream is a nil-safe interceptor stream call.
   118  func (o *muxOptions) stream(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
   119  	if si := o.streamInterceptor; si != nil {
   120  		return si(srv, ss, info, handler)
   121  	}
   122  	return handler(srv, ss)
   123  }
   124  
// MuxOption configures optional settings of a Mux. Options are passed to
// NewMux and applied in order on top of the defaults.
type MuxOption func(*muxOptions)
   126  
   127  var defaultMuxOptions = muxOptions{
   128  	maxReceiveMessageSize: defaultServerMaxReceiveMessageSize,
   129  	maxSendMessageSize:    defaultServerMaxSendMessageSize,
   130  	connectionTimeout:     defaultServerConnectionTimeout,
   131  	files:                 protoregistry.GlobalFiles,
   132  	types:                 protoregistry.GlobalTypes,
   133  }
   134  
   135  func UnaryServerInterceptorOption(interceptor grpc.UnaryServerInterceptor) MuxOption {
   136  	return func(opts *muxOptions) { opts.unaryInterceptor = interceptor }
   137  }
   138  
   139  func StreamServerInterceptorOption(interceptor grpc.StreamServerInterceptor) MuxOption {
   140  	return func(opts *muxOptions) { opts.streamInterceptor = interceptor }
   141  }
   142  
   143  func StatsOption(h stats.Handler) MuxOption {
   144  	return func(opts *muxOptions) { opts.statsHandler = h }
   145  }
   146  
// Mux is an http.Handler that routes HTTP and WebSocket requests, using
// google.api.http rules, to gRPC methods served by registered connections.
//
// The routing table is kept as an immutable *state in an atomic.Value so
// requests can read it without locking. Writers (RegisterConn, DropConn)
// serialize on mu, clone the current state, and store the modified clone.
//
// services is not initialized by NewMux and is not used in this file.
// NOTE(review): confirm that whatever writes it allocates the map first.
type Mux struct {
	opts muxOptions
	//events trace.EventLog TODO
	mu    sync.Mutex   // Lock to sync writers
	state atomic.Value // Value of *state

	// services is a list of registered services
	services map[*grpc.ServiceDesc]interface{}
}
   156  
   157  func NewMux(opts ...MuxOption) (*Mux, error) {
   158  	// Apply options.
   159  	var muxOpts = defaultMuxOptions
   160  	for _, opt := range opts {
   161  		opt(&muxOpts)
   162  	}
   163  
   164  	return &Mux{
   165  		opts: muxOpts,
   166  	}, nil
   167  }
   168  
   169  func (m *Mux) RegisterConn(ctx context.Context, cc *grpc.ClientConn) error {
   170  	c := rpb.NewServerReflectionClient(cc)
   171  
   172  	// TODO: watch the stream. When it is recreated refresh the service
   173  	// methods and recreate the mux if needed.
   174  	stream, err := c.ServerReflectionInfo(ctx, grpc.WaitForReady(true))
   175  	if err != nil {
   176  		return err
   177  	}
   178  
   179  	// Load the state for writing.
   180  	m.mu.Lock()
   181  	defer m.mu.Unlock()
   182  	s := m.loadState().clone()
   183  
   184  	if err := s.addConnHandler(cc, stream); err != nil {
   185  		return err
   186  	}
   187  
   188  	m.storeState(s)
   189  
   190  	return stream.CloseSend()
   191  }
   192  
   193  func (m *Mux) DropConn(ctx context.Context, cc *grpc.ClientConn) bool {
   194  	// Load the state for writing.
   195  	m.mu.Lock()
   196  	defer m.mu.Unlock()
   197  	s := m.loadState().clone()
   198  
   199  	return s.removeHandler(cc)
   200  }
   201  
// resolver implements protodesc.Resolver.
//
// Lookups are answered from files, which newResolver preloads with the
// google.api annotation/http/httpbody files. File paths that are not yet
// known are fetched over stream, the server reflection stream of the
// connection being registered, and cached in files.
type resolver struct {
	files  protoregistry.Files
	stream rpb.ServerReflection_ServerReflectionInfoClient
}
   207  
   208  func newResolver(stream rpb.ServerReflection_ServerReflectionInfoClient) (*resolver, error) {
   209  	r := &resolver{stream: stream}
   210  
   211  	if err := r.files.RegisterFile(annotations.File_google_api_annotations_proto); err != nil {
   212  		return nil, err
   213  	}
   214  	if err := r.files.RegisterFile(annotations.File_google_api_http_proto); err != nil {
   215  		return nil, err
   216  	}
   217  	if err := r.files.RegisterFile(httpbody.File_google_api_httpbody_proto); err != nil {
   218  		return nil, err
   219  	}
   220  	return r, nil
   221  }
   222  
   223  func (r *resolver) FindFileByPath(path string) (protoreflect.FileDescriptor, error) {
   224  	if fd, err := r.files.FindFileByPath(path); err == nil {
   225  		return fd, nil // found file
   226  	}
   227  
   228  	if err := r.stream.Send(&rpb.ServerReflectionRequest{
   229  		MessageRequest: &rpb.ServerReflectionRequest_FileByFilename{
   230  			FileByFilename: path,
   231  		},
   232  	}); err != nil {
   233  		return nil, err
   234  	}
   235  
   236  	fdr, err := r.stream.Recv()
   237  	if err != nil {
   238  		return nil, err
   239  	}
   240  	fdbs := fdr.GetFileDescriptorResponse().GetFileDescriptorProto()
   241  
   242  	var f protoreflect.FileDescriptor
   243  	for _, fdb := range fdbs {
   244  		fdp := &descriptorpb.FileDescriptorProto{}
   245  		if err := proto.Unmarshal(fdb, fdp); err != nil {
   246  			return nil, err
   247  		}
   248  
   249  		file, err := protodesc.NewFile(fdp, r)
   250  		if err != nil {
   251  			return nil, err
   252  		}
   253  		// TODO: check duplicate file registry
   254  		if err := r.files.RegisterFile(file); err != nil {
   255  			return nil, err
   256  		}
   257  		if file.Path() == path {
   258  			f = file
   259  		}
   260  	}
   261  	if f == nil {
   262  		return nil, fmt.Errorf("missing file descriptor %s", path)
   263  	}
   264  	return f, nil
   265  }
   266  
   267  func (r *resolver) FindDescriptorByName(fullname protoreflect.FullName) (protoreflect.Descriptor, error) {
   268  	return r.files.FindDescriptorByName(fullname)
   269  }
   270  
   271  func (s *state) appendHandler(
   272  	rule *annotations.HttpRule,
   273  	desc protoreflect.MethodDescriptor,
   274  	h *handler,
   275  ) error {
   276  	if err := s.path.addRule(rule, desc, h.method); err != nil {
   277  		return err
   278  	}
   279  	s.handlers[h.method] = append(s.handlers[h.method], h)
   280  	return nil
   281  }
   282  
   283  func (s *state) removeHandler(cc *grpc.ClientConn) bool {
   284  	cl, ok := s.conns[cc]
   285  	if !ok {
   286  		return ok
   287  	}
   288  
   289  	// Drop handlers belonging to the client conn.
   290  	for _, hd := range cl.handlers {
   291  		name := hd.method
   292  
   293  		var hds []*handler
   294  		for _, mhd := range s.handlers[name] {
   295  			// Compare if handler belongs to this connection.
   296  			if mhd != hd {
   297  				hds = append(hds, mhd)
   298  			}
   299  		}
   300  		if len(hds) == 0 {
   301  			delete(s.handlers, name)
   302  			s.path.delRule(name)
   303  		} else {
   304  			s.handlers[name] = hds
   305  		}
   306  	}
   307  	// Drop conn on client conn.
   308  	delete(s.conns, cc)
   309  	return ok
   310  }
   311  
// addConnHandler uses the server reflection stream to list the services
// exposed by cc and fetch the file descriptors that contain them. It then
// registers a proxy handler in s for every method with an HTTP rule (see
// processFile).
//
// A SHA-256 hash over the raw descriptor bytes, in the order the server sent
// them, is stored with the connection. If cc is already registered with the
// same hash, nothing changes; otherwise its previous handlers are removed
// before rebuilding.
//
// s is modified in place and may be left partially updated on error (after
// removeHandler or a failing processFile). Callers are expected to work on a
// clone and discard it on error, as RegisterConn does.
func (s *state) addConnHandler(
	cc *grpc.ClientConn,
	stream rpb.ServerReflection_ServerReflectionInfoClient,
) error {
	// TODO: async fetch and mux creation.

	if err := stream.Send(&rpb.ServerReflectionRequest{
		MessageRequest: &rpb.ServerReflectionRequest_ListServices{},
	}); err != nil {
		return err
	}

	r, err := stream.Recv()
	if err != nil {
		return err
	}
	// TODO: check r.GetErrorResponse()?
	// NOTE(review): an error response currently yields an empty service list,
	// so cc is registered with no handlers rather than reported as failing.

	// File descriptors hash for detecting updates. TODO: sort fds?
	h := sha256.New()

	// Keyed by file name, so a file returned for several services is kept
	// once (the last copy wins); the hash still covers every copy received.
	fds := make(map[string]*descriptorpb.FileDescriptorProto)
	for _, svc := range r.GetListServicesResponse().GetService() {
		if err := stream.Send(&rpb.ServerReflectionRequest{
			MessageRequest: &rpb.ServerReflectionRequest_FileContainingSymbol{
				FileContainingSymbol: svc.GetName(),
			},
		}); err != nil {
			return err
		}

		fdr, err := stream.Recv()
		if err != nil {
			return err
		}

		fdbb := fdr.GetFileDescriptorResponse().GetFileDescriptorProto()

		for _, fdb := range fdbb {
			fd := &descriptorpb.FileDescriptorProto{}
			if err := proto.Unmarshal(fdb, fd); err != nil {
				return err
			}
			fds[fd.GetName()] = fd

			if _, err := h.Write(fdb); err != nil {
				return err
			}
		}
	}

	fdHash := h.Sum(nil)

	// Check if previous connection exists.
	if cl, ok := s.conns[cc]; ok {
		if bytes.Equal(cl.fdHash, fdHash) {
			return nil // nothing to do
		}

		// Drop and recreate below.
		s.removeHandler(cc)
	}

	// Imports that are neither preloaded nor already fetched are requested
	// over the same reflection stream by the resolver.
	rslvr, err := newResolver(stream)
	if err != nil {
		return err
	}

	// NOTE(review): fds is a map, so files are processed in random order;
	// if rules conflict, which conflict processFile reports may vary.
	var handlers []*handler
	for _, fd := range fds {
		file, err := protodesc.NewFile(fd, rslvr)
		if err != nil {
			return err
		}

		hs, err := s.processFile(cc, file)
		if err != nil {
			return err
		}
		handlers = append(handlers, hs...)
	}

	// Update methods list.
	s.conns[cc] = connList{
		handlers: handlers,
		fdHash:   fdHash,
	}
	return nil
}
   401  
   402  func createConnHandler(
   403  	cc *grpc.ClientConn,
   404  	sd protoreflect.ServiceDescriptor,
   405  	md protoreflect.MethodDescriptor,
   406  ) *handler {
   407  
   408  	argsDesc := md.Input()
   409  	replyDesc := md.Output()
   410  
   411  	method := fmt.Sprintf("/%s/%s", sd.FullName(), md.Name())
   412  
   413  	isClientStream := md.IsStreamingClient()
   414  	isServerStream := md.IsStreamingServer()
   415  	if isClientStream || isServerStream {
   416  		sd := &grpc.StreamDesc{
   417  			ServerStreams: md.IsStreamingServer(),
   418  			ClientStreams: md.IsStreamingClient(),
   419  		}
   420  		info := &grpc.StreamServerInfo{
   421  			FullMethod:     method,
   422  			IsClientStream: isClientStream,
   423  			IsServerStream: isServerStream,
   424  		}
   425  
   426  		fn := func(_ interface{}, stream grpc.ServerStream) error {
   427  			ctx := stream.Context()
   428  
   429  			args := dynamicpb.NewMessage(argsDesc)
   430  			reply := dynamicpb.NewMessage(replyDesc)
   431  
   432  			if err := stream.RecvMsg(args); err != nil {
   433  				return err
   434  			}
   435  
   436  			if md, ok := metadata.FromIncomingContext(ctx); ok {
   437  				ctx = metadata.NewOutgoingContext(ctx, md)
   438  			}
   439  
   440  			clientStream, err := cc.NewStream(ctx, sd, method)
   441  			if err != nil {
   442  				return err
   443  			}
   444  			if err := clientStream.SendMsg(args); err != nil {
   445  				return err
   446  			}
   447  
   448  			var inErr error
   449  			var wg sync.WaitGroup
   450  			if sd.ClientStreams {
   451  				wg.Add(1)
   452  				go func() {
   453  					for {
   454  						if inErr = stream.RecvMsg(args); inErr != nil {
   455  							break
   456  						}
   457  
   458  						if inErr = clientStream.SendMsg(args); inErr != nil {
   459  							break
   460  						}
   461  					}
   462  					wg.Done()
   463  				}()
   464  			}
   465  			var outErr error
   466  			for {
   467  				if outErr = clientStream.RecvMsg(reply); outErr != nil {
   468  					break
   469  				}
   470  
   471  				if outErr = stream.SendMsg(reply); outErr != nil {
   472  					break
   473  				}
   474  
   475  				if !sd.ServerStreams {
   476  					break
   477  				}
   478  			}
   479  
   480  			if isStreamError(outErr) {
   481  				return outErr
   482  			}
   483  			if sd.ClientStreams {
   484  				wg.Wait()
   485  				if isStreamError(inErr) {
   486  					return inErr
   487  				}
   488  			}
   489  			trailer := clientStream.Trailer()
   490  			stream.SetTrailer(trailer)
   491  			return nil
   492  		}
   493  
   494  		h := func(opts *muxOptions, stream grpc.ServerStream) error {
   495  			return opts.stream(nil, stream, info, fn)
   496  		}
   497  
   498  		return &handler{
   499  			method:     method,
   500  			descriptor: md,
   501  			handler:    h,
   502  		}
   503  	} else {
   504  		info := &grpc.UnaryServerInfo{
   505  			Server:     nil,
   506  			FullMethod: method,
   507  		}
   508  		fn := func(ctx context.Context, args interface{}) (interface{}, error) {
   509  			reply := dynamicpb.NewMessage(replyDesc)
   510  
   511  			if md, ok := metadata.FromIncomingContext(ctx); ok {
   512  				ctx = metadata.NewOutgoingContext(ctx, md)
   513  			}
   514  
   515  			if err := cc.Invoke(ctx, method, args, reply); err != nil {
   516  				return nil, err
   517  			}
   518  			return reply, nil
   519  		}
   520  		h := func(opts *muxOptions, stream grpc.ServerStream) error {
   521  			ctx := stream.Context()
   522  			args := dynamicpb.NewMessage(argsDesc)
   523  
   524  			if err := stream.RecvMsg(args); err != nil {
   525  				return err
   526  			}
   527  
   528  			reply, err := opts.unary(ctx, args, info, fn)
   529  			if err != nil {
   530  				return err
   531  			}
   532  			return stream.SendMsg(reply)
   533  		}
   534  
   535  		return &handler{
   536  			method:     method,
   537  			descriptor: md,
   538  			handler:    h,
   539  		}
   540  	}
   541  }
   542  
   543  func (s *state) processFile(cc *grpc.ClientConn, fd protoreflect.FileDescriptor) ([]*handler, error) {
   544  	var handlers []*handler
   545  
   546  	sds := fd.Services()
   547  	for i := 0; i < sds.Len(); i++ {
   548  		sd := sds.Get(i)
   549  
   550  		mds := sd.Methods()
   551  		for j := 0; j < mds.Len(); j++ {
   552  			md := mds.Get(j)
   553  
   554  			opts := md.Options() // TODO: nil check fails?
   555  
   556  			rule := getExtensionHTTP(opts)
   557  			if rule == nil {
   558  				continue
   559  			}
   560  
   561  			hd := createConnHandler(cc, sd, md)
   562  
   563  			if err := s.appendHandler(rule, md, hd); err != nil {
   564  				return nil, err
   565  			}
   566  			handlers = append(handlers, hd)
   567  		}
   568  	}
   569  	return handlers, nil
   570  }
   571  
   572  func (m *Mux) loadState() *state {
   573  	s, _ := m.state.Load().(*state)
   574  	return s
   575  }
   576  func (m *Mux) storeState(s *state) { m.state.Store(s) }
   577  
   578  func (s *state) pickMethodHandler(name string) (*handler, error) {
   579  	if s != nil {
   580  		hds := s.handlers[name]
   581  		if len(hds) > 0 {
   582  			hd := hds[rand.Intn(len(hds))]
   583  			return hd, nil
   584  		}
   585  	}
   586  	return nil, status.Errorf(codes.Unimplemented, "method %s not implemented", name)
   587  }
   588  
   589  func (s *state) match(route, verb string) (*method, params, error) {
   590  	if s == nil {
   591  		return nil, nil, status.Error(codes.NotFound, "not found")
   592  	}
   593  	return s.path.match(route, verb)
   594  }
   595  
   596  var (
   597  	contentTypeCodec = map[string]encoding.Codec{
   598  		"application/protobuf":     protoCodec{},
   599  		"application/octet-stream": protoCodec{},
   600  		"application/json":         jsonCodec{},
   601  		"":                         jsonCodec{}, // default
   602  	}
   603  )
   604  
// streamHTTP adapts a single HTTP request/response pair to the
// grpc.ServerStream interface used by handlers.
//
// Fields:
//   - ctx: call context (incoming metadata and stats tags).
//   - w, r: the underlying HTTP response and request.
//   - method, params: the matched route and its path/query parameters.
//   - recvN, sendN: number of RecvMsg/SendMsg calls so far; params are
//     applied only on the first RecvMsg.
//   - sentHeader, header, trailer: gRPC header/trailer metadata staged for
//     the HTTP response headers.
//   - opts: copy of the Mux options (size limits, stats handler).
type streamHTTP struct {
	ctx    context.Context
	w      http.ResponseWriter
	r      *http.Request
	method *method
	params params
	recvN  int
	sendN  int

	sentHeader bool
	header     metadata.MD
	trailer    metadata.MD

	opts muxOptions
}
   620  
   621  func (s *streamHTTP) SetHeader(md metadata.MD) error {
   622  	if s.sentHeader {
   623  		return fmt.Errorf("already sent headers")
   624  	}
   625  	s.header = metadata.Join(s.header, md)
   626  	return nil
   627  }
   628  func (s *streamHTTP) SendHeader(md metadata.MD) error {
   629  	if s.sentHeader {
   630  		return fmt.Errorf("already sent headers")
   631  	}
   632  	s.header = metadata.Join(s.header, md)
   633  	setOutgoingHeader(s.w.Header(), s.header)
   634  	// don't write the header code, wait for the body.
   635  	s.sentHeader = true
   636  
   637  	if sh := s.opts.statsHandler; sh != nil {
   638  		sh.HandleRPC(s.ctx, &stats.OutHeader{
   639  			Header:      s.header.Copy(),
   640  			Compression: s.r.Header.Get("Accept-Encoding"),
   641  		})
   642  	}
   643  	return nil
   644  }
   645  
   646  func (s *streamHTTP) SetTrailer(md metadata.MD) {
   647  	s.trailer = metadata.Join(s.trailer, md)
   648  }
   649  
   650  func (s *streamHTTP) Context() context.Context {
   651  	sts := &serverTransportStream{s, s.method.name}
   652  	return grpc.NewContextWithServerTransportStream(s.ctx, sts)
   653  }
   654  
   655  func (s *streamHTTP) SendMsg(m interface{}) error {
   656  	s.sendN += 1
   657  	reply := m.(proto.Message)
   658  
   659  	accept := s.r.Header.Get("Accept")
   660  	acceptEncoding := s.r.Header.Get("Accept-Encoding")
   661  
   662  	if fRsp, ok := s.w.(http.Flusher); ok {
   663  		defer fRsp.Flush()
   664  	}
   665  
   666  	setOutgoingHeader(s.w.Header(), s.header, s.trailer)
   667  
   668  	var resp io.Writer = s.w
   669  	switch acceptEncoding {
   670  	case "gzip":
   671  		s.w.Header().Set("Content-Encoding", "gzip")
   672  		gRsp := gzip.NewWriter(s.w)
   673  		defer gRsp.Close()
   674  		resp = gRsp
   675  	}
   676  
   677  	cur := reply.ProtoReflect()
   678  	for _, fd := range s.method.resp {
   679  		cur = cur.Mutable(fd).Message()
   680  	}
   681  
   682  	msg := cur.Interface()
   683  
   684  	if cur.Descriptor().FullName() == "google.api.HttpBody" {
   685  		fds := cur.Descriptor().Fields()
   686  		fdContentType := fds.ByName(protoreflect.Name("content_type"))
   687  		fdData := fds.ByName(protoreflect.Name("data"))
   688  		pContentType := cur.Get(fdContentType)
   689  		pData := cur.Get(fdData)
   690  
   691  		s.w.Header().Set("Content-Type", pContentType.String())
   692  		// TODO different non-message size?
   693  		if err := s.opts.writeAll(resp, pData.Bytes()); err != nil {
   694  			return err
   695  		}
   696  		return nil
   697  	}
   698  
   699  	if accept == "" {
   700  		accept = "application/json"
   701  	}
   702  
   703  	codec, ok := contentTypeCodec[accept]
   704  	if !ok {
   705  		return fmt.Errorf("unknown accept encoding: %s", accept)
   706  	}
   707  	b, err := codec.Marshal(msg)
   708  	if err != nil {
   709  		return err
   710  	}
   711  	s.w.Header().Set("Content-Type", accept)
   712  	if err := s.opts.writeAll(resp, b); err != nil {
   713  		return err
   714  	}
   715  	if stats := s.opts.statsHandler; stats != nil {
   716  		// TODO: raw payload stats.
   717  		stats.HandleRPC(s.ctx, outPayload(false, m, b, b, time.Now()))
   718  	}
   719  	return nil
   720  }
   721  
   722  func (s *streamHTTP) decodeRequestArgs(args proto.Message) error {
   723  	contentType := s.r.Header.Get("Content-Type")
   724  	contentEncoding := s.r.Header.Get("Content-Encoding")
   725  
   726  	var body io.ReadCloser
   727  	switch contentEncoding {
   728  	case "gzip":
   729  		var err error
   730  		body, err = gzip.NewReader(s.r.Body)
   731  		if err != nil {
   732  			return err
   733  		}
   734  
   735  	default:
   736  		body = s.r.Body
   737  	}
   738  	defer body.Close()
   739  
   740  	b, err := s.opts.readAll(body)
   741  	if err != nil {
   742  		return err
   743  	}
   744  
   745  	cur := args.ProtoReflect()
   746  	for _, fd := range s.method.body {
   747  		cur = cur.Mutable(fd).Message()
   748  	}
   749  	fullname := cur.Descriptor().FullName()
   750  
   751  	msg := cur.Interface()
   752  
   753  	if fullname == "google.api.HttpBody" {
   754  		rfl := msg.ProtoReflect()
   755  		fds := rfl.Descriptor().Fields()
   756  		fdContentType := fds.ByName(protoreflect.Name("content_type"))
   757  		fdData := fds.ByName(protoreflect.Name("data"))
   758  		rfl.Set(fdContentType, protoreflect.ValueOfString(contentType))
   759  		rfl.Set(fdData, protoreflect.ValueOfBytes(b))
   760  		// TODO: extensions?
   761  		return nil
   762  	}
   763  
   764  	if contentType == "" {
   765  		contentType = "application/json"
   766  	}
   767  
   768  	codec, ok := contentTypeCodec[contentType]
   769  	if !ok {
   770  		return fmt.Errorf("unknown content-type encoding: %s", contentType)
   771  	}
   772  	if err := codec.Unmarshal(b, msg); err != nil {
   773  		return err
   774  	}
   775  	if stats := s.opts.statsHandler; stats != nil {
   776  		// TODO: raw payload stats.
   777  		stats.HandleRPC(s.ctx, inPayload(false, msg, b, b, time.Now()))
   778  	}
   779  	return nil
   780  }
   781  
   782  func (s *streamHTTP) RecvMsg(m interface{}) error {
   783  	s.recvN += 1
   784  	args := m.(proto.Message)
   785  
   786  	// TODO: fix the body marshalling
   787  	if s.method.hasBody {
   788  		// TODO: handler should decide what to select on?
   789  		if err := s.decodeRequestArgs(args); err != nil {
   790  			return err
   791  		}
   792  	}
   793  	if s.recvN == 1 {
   794  		if err := s.params.set(args); err != nil {
   795  			return err
   796  		}
   797  	}
   798  	return nil
   799  }
   800  
   801  func isWebsocketRequest(r *http.Request) bool {
   802  	for _, header := range r.Header["Upgrade"] {
   803  		if header == "websocket" {
   804  			return true
   805  		}
   806  	}
   807  	return false
   808  }
   809  
   810  func encError(w http.ResponseWriter, err error) {
   811  	s, _ := status.FromError(err)
   812  	w.Header().Set("Content-Type", "application/json")
   813  	w.WriteHeader(HTTPStatusCode(s.Code()))
   814  
   815  	b, err := protojson.Marshal(s.Proto())
   816  	if err != nil {
   817  		panic(err) // ...
   818  	}
   819  	w.Write(b) //nolint
   820  }
   821  
   822  func (m *Mux) serveHTTP(w http.ResponseWriter, r *http.Request) error {
   823  	ctx, mdata := newIncomingContext(r.Context(), r.Header)
   824  
   825  	s := m.loadState()
   826  	isWebsocket := isWebsocketRequest(r)
   827  
   828  	verb := r.Method
   829  	if isWebsocket {
   830  		verb = kindWebsocket
   831  	}
   832  
   833  	method, params, err := s.match(r.URL.Path, verb)
   834  	if err != nil {
   835  		return err
   836  	}
   837  
   838  	queryParams, err := method.parseQueryParams(r.URL.Query())
   839  	if err != nil {
   840  		return err
   841  	}
   842  	params = append(params, queryParams...)
   843  
   844  	hd, err := s.pickMethodHandler(method.name)
   845  	if err != nil {
   846  		return err
   847  	}
   848  
   849  	// Handle stats.
   850  	beginTime := time.Now()
   851  	if sh := m.opts.statsHandler; sh != nil {
   852  		ctx = sh.TagRPC(ctx, &stats.RPCTagInfo{
   853  			FullMethodName: hd.method,
   854  			FailFast:       false, // TODO
   855  		})
   856  
   857  		sh.HandleRPC(ctx, &stats.InHeader{
   858  			FullMethod:  method.name,
   859  			RemoteAddr:  strAddr(r.RemoteAddr),
   860  			Compression: r.Header.Get("Content-Encoding"),
   861  			Header:      metadata.MD(mdata).Copy(),
   862  		})
   863  
   864  		sh.HandleRPC(ctx, &stats.Begin{
   865  			Client:                    false,
   866  			BeginTime:                 beginTime,
   867  			FailFast:                  false, // TODO
   868  			IsClientStream:            hd.descriptor.IsStreamingClient(),
   869  			IsServerStream:            hd.descriptor.IsStreamingServer(),
   870  			IsTransparentRetryAttempt: false, // TODO
   871  		})
   872  	}
   873  
   874  	if isWebsocket {
   875  		c, err := websocket.Accept(w, r, &websocket.AcceptOptions{})
   876  		if err != nil {
   877  			return err
   878  		}
   879  
   880  		stream := &streamWS{
   881  			ctx:    ctx,
   882  			conn:   c,
   883  			method: method,
   884  			params: params,
   885  		}
   886  		herr := hd.handler(&m.opts, stream)
   887  
   888  		if herr != nil {
   889  			s, _ := status.FromError(herr)
   890  			// TODO: limit message size.
   891  			c.Close(WSStatusCode(s.Code()), s.Message()) // TODO
   892  		} else {
   893  			c.Close(websocket.StatusNormalClosure, "OK") // TODO
   894  		}
   895  
   896  		// Handle stats.
   897  		if sh := m.opts.statsHandler; sh != nil {
   898  			endTime := time.Now()
   899  			sh.HandleRPC(ctx, &stats.End{
   900  				Client:    false,
   901  				BeginTime: beginTime,
   902  				EndTime:   endTime,
   903  				Error:     err,
   904  			})
   905  		}
   906  		return nil
   907  	}
   908  
   909  	stream := &streamHTTP{
   910  		ctx: ctx,
   911  		w:   w, r: r,
   912  		method: method,
   913  		params: params,
   914  		opts:   m.opts,
   915  	}
   916  	herr := hd.handler(&m.opts, stream)
   917  	// Handle stats.
   918  	if sh := m.opts.statsHandler; sh != nil {
   919  		endTime := time.Now()
   920  
   921  		// Try to send Trailers, might not be respected.
   922  		setOutgoingHeader(w.Header(), stream.trailer)
   923  		sh.HandleRPC(ctx, &stats.OutTrailer{
   924  			Trailer: stream.trailer.Copy(),
   925  		})
   926  
   927  		sh.HandleRPC(ctx, &stats.End{
   928  			Client:    false,
   929  			BeginTime: beginTime,
   930  			EndTime:   endTime,
   931  			Error:     err,
   932  		})
   933  	}
   934  	return herr
   935  }
   936  
   937  func (m *Mux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
   938  	if !strings.HasPrefix(r.URL.Path, "/") {
   939  		r.URL.Path = "/" + r.URL.Path
   940  	}
   941  	if err := m.serveHTTP(w, r); err != nil {
   942  		encError(w, err)
   943  	}
   944  }