github.com/hxx258456/ccgo@v0.0.5-0.20230213014102-48b35f46f66f/grpc/internal/transport/handler_server.go (about)

     1  /*
     2   *
     3   * Copyright 2016 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  // This file is the implementation of a gRPC server using HTTP/2 which
    20  // uses the standard Go http2 Server implementation (via the
    21  // http.Handler interface), rather than speaking low-level HTTP/2
    22  // frames itself. It is the implementation of *grpc.Server.ServeHTTP.
    23  
    24  package transport
    25  
    26  import (
    27  	"bytes"
    28  	"context"
    29  	"errors"
    30  	"fmt"
    31  	"io"
    32  	"net"
    33  	"strings"
    34  	"sync"
    35  	"time"
    36  
    37  	http "github.com/hxx258456/ccgo/gmhttp"
    38  
    39  	"github.com/golang/protobuf/proto"
    40  	"github.com/hxx258456/ccgo/grpc/codes"
    41  	"github.com/hxx258456/ccgo/grpc/credentials"
    42  	"github.com/hxx258456/ccgo/grpc/internal/grpcutil"
    43  	"github.com/hxx258456/ccgo/grpc/metadata"
    44  	"github.com/hxx258456/ccgo/grpc/peer"
    45  	"github.com/hxx258456/ccgo/grpc/stats"
    46  	"github.com/hxx258456/ccgo/grpc/status"
    47  	"github.com/hxx258456/ccgo/net/http2"
    48  )
    49  
    50  // NewServerHandlerTransport returns a ServerTransport handling gRPC
    51  // from inside an http.Handler. It requires that the http Server
    52  // supports HTTP/2.
    53  func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request, stats stats.Handler) (ServerTransport, error) {
    54  	if r.ProtoMajor != 2 {
    55  		return nil, errors.New("gRPC requires HTTP/2")
    56  	}
    57  	if r.Method != "POST" {
    58  		return nil, errors.New("invalid gRPC request method")
    59  	}
    60  	contentType := r.Header.Get("Content-Type")
    61  	// TODO: do we assume contentType is lowercase? we did before
    62  	contentSubtype, validContentType := grpcutil.ContentSubtype(contentType)
    63  	if !validContentType {
    64  		return nil, errors.New("invalid gRPC request content-type")
    65  	}
    66  	if _, ok := w.(http.Flusher); !ok {
    67  		return nil, errors.New("gRPC requires a ResponseWriter supporting http.Flusher")
    68  	}
    69  
    70  	st := &serverHandlerTransport{
    71  		rw:             w,
    72  		req:            r,
    73  		closedCh:       make(chan struct{}),
    74  		writes:         make(chan func()),
    75  		contentType:    contentType,
    76  		contentSubtype: contentSubtype,
    77  		stats:          stats,
    78  	}
    79  
    80  	if v := r.Header.Get("grpc-timeout"); v != "" {
    81  		to, err := decodeTimeout(v)
    82  		if err != nil {
    83  			return nil, status.Errorf(codes.Internal, "malformed time-out: %v", err)
    84  		}
    85  		st.timeoutSet = true
    86  		st.timeout = to
    87  	}
    88  
    89  	metakv := []string{"content-type", contentType}
    90  	if r.Host != "" {
    91  		metakv = append(metakv, ":authority", r.Host)
    92  	}
    93  	for k, vv := range r.Header {
    94  		k = strings.ToLower(k)
    95  		if isReservedHeader(k) && !isWhitelistedHeader(k) {
    96  			continue
    97  		}
    98  		for _, v := range vv {
    99  			v, err := decodeMetadataHeader(k, v)
   100  			if err != nil {
   101  				return nil, status.Errorf(codes.Internal, "malformed binary metadata: %v", err)
   102  			}
   103  			metakv = append(metakv, k, v)
   104  		}
   105  	}
   106  	st.headerMD = metadata.Pairs(metakv...)
   107  
   108  	return st, nil
   109  }
   110  
// serverHandlerTransport is an implementation of ServerTransport
// which replies to exactly one gRPC request (exactly one HTTP request),
// using the http.Handler interface (here provided by the gmhttp fork of
// net/http). This http.Handler is guaranteed at this point to be speaking
// over HTTP/2, so it's able to speak valid gRPC.
type serverHandlerTransport struct {
	// rw and req are the single HTTP exchange being served. timeoutSet and
	// timeout hold the decoded grpc-timeout header, if the request had one.
	rw         http.ResponseWriter
	req        *http.Request
	timeoutSet bool
	timeout    time.Duration

	// headerMD is the incoming metadata built from the request headers
	// (plus content-type and, when present, :authority).
	headerMD metadata.MD

	// closedCh is closed exactly once (guarded by closeOnce) by Close.
	// Once closed, do returns ErrConnClosing and runStream returns.
	closeOnce sync.Once
	closedCh  chan struct{} // closed on Close

	// writes is an unbuffered channel of code to run serialized in the
	// ServeHTTP (HandleStreams) goroutine, which receives from it in
	// runStream. Nothing in this file closes writes; instead runStream
	// stops when closedCh is closed (WriteStatus always calls Close).
	writes chan func()

	// block concurrent WriteStatus calls
	// e.g. grpc/(*serverStream).SendMsg/RecvMsg
	writeStatusMu sync.Mutex

	// we just mirror the request content-type
	contentType string
	// we store both contentType and contentSubtype so we don't keep recreating them
	// TODO make sure this is consistent across handler_server and http2_server
	contentSubtype string

	// stats, when non-nil, receives RPC tagging and header/trailer events.
	stats stats.Handler
}
   144  
   145  func (ht *serverHandlerTransport) Close() {
   146  	ht.closeOnce.Do(ht.closeCloseChanOnce)
   147  }
   148  
   149  func (ht *serverHandlerTransport) closeCloseChanOnce() { close(ht.closedCh) }
   150  
   151  func (ht *serverHandlerTransport) RemoteAddr() net.Addr { return strAddr(ht.req.RemoteAddr) }
   152  
   153  // strAddr is a net.Addr backed by either a TCP "ip:port" string, or
   154  // the empty string if unknown.
   155  type strAddr string
   156  
   157  func (a strAddr) Network() string {
   158  	if a != "" {
   159  		// Per the documentation on net/http.Request.RemoteAddr, if this is
   160  		// set, it's set to the IP:port of the peer (hence, TCP):
   161  		// https://golang.org/pkg/net/http/#Request
   162  		//
   163  		// If we want to support Unix sockets later, we can
   164  		// add our own grpc-specific convention within the
   165  		// grpc codebase to set RemoteAddr to a different
   166  		// format, or probably better: we can attach it to the
   167  		// context and use that from serverHandlerTransport.RemoteAddr.
   168  		return "tcp"
   169  	}
   170  	return ""
   171  }
   172  
   173  func (a strAddr) String() string { return string(a) }
   174  
   175  // do runs fn in the ServeHTTP goroutine.
   176  func (ht *serverHandlerTransport) do(fn func()) error {
   177  	select {
   178  	case <-ht.closedCh:
   179  		return ErrConnClosing
   180  	case ht.writes <- fn:
   181  		return nil
   182  	}
   183  }
   184  
   185  func (ht *serverHandlerTransport) WriteStatus(s *Stream, st *status.Status) error {
   186  	ht.writeStatusMu.Lock()
   187  	defer ht.writeStatusMu.Unlock()
   188  
   189  	headersWritten := s.updateHeaderSent()
   190  	err := ht.do(func() {
   191  		if !headersWritten {
   192  			ht.writePendingHeaders(s)
   193  		}
   194  
   195  		// And flush, in case no header or body has been sent yet.
   196  		// This forces a separation of headers and trailers if this is the
   197  		// first call (for example, in end2end tests's TestNoService).
   198  		ht.rw.(http.Flusher).Flush()
   199  
   200  		h := ht.rw.Header()
   201  		h.Set("Grpc-Status", fmt.Sprintf("%d", st.Code()))
   202  		if m := st.Message(); m != "" {
   203  			h.Set("Grpc-Message", encodeGrpcMessage(m))
   204  		}
   205  
   206  		if p := st.Proto(); p != nil && len(p.Details) > 0 {
   207  			stBytes, err := proto.Marshal(p)
   208  			if err != nil {
   209  				// TODO: return error instead, when callers are able to handle it.
   210  				panic(err)
   211  			}
   212  
   213  			h.Set("Grpc-Status-Details-Bin", encodeBinHeader(stBytes))
   214  		}
   215  
   216  		if md := s.Trailer(); len(md) > 0 {
   217  			for k, vv := range md {
   218  				// Clients don't tolerate reading restricted headers after some non restricted ones were sent.
   219  				if isReservedHeader(k) {
   220  					continue
   221  				}
   222  				for _, v := range vv {
   223  					// http2 ResponseWriter mechanism to send undeclared Trailers after
   224  					// the headers have possibly been written.
   225  					h.Add(http2.TrailerPrefix+k, encodeMetadataHeader(k, v))
   226  				}
   227  			}
   228  		}
   229  	})
   230  
   231  	if err == nil { // transport has not been closed
   232  		if ht.stats != nil {
   233  			// Note: The trailer fields are compressed with hpack after this call returns.
   234  			// No WireLength field is set here.
   235  			ht.stats.HandleRPC(s.Context(), &stats.OutTrailer{
   236  				Trailer: s.trailer.Copy(),
   237  			})
   238  		}
   239  	}
   240  	ht.Close()
   241  	return err
   242  }
   243  
// writePendingHeaders sets common and custom headers on the first
// write call (Write, WriteHeader, or WriteStatus).
//
// Common headers are set before the stream's custom headers. In this file it
// is only called from closures passed to do (so it runs on the runStream
// goroutine), and only when s.updateHeaderSent() returned false.
func (ht *serverHandlerTransport) writePendingHeaders(s *Stream) {
	ht.writeCommonHeaders(s)
	ht.writeCustomHeaders(s)
}
   250  
   251  // writeCommonHeaders sets common headers on the first write
   252  // call (Write, WriteHeader, or WriteStatus).
   253  func (ht *serverHandlerTransport) writeCommonHeaders(s *Stream) {
   254  	h := ht.rw.Header()
   255  	h["Date"] = nil // suppress Date to make tests happy; TODO: restore
   256  	h.Set("Content-Type", ht.contentType)
   257  
   258  	// Predeclare trailers we'll set later in WriteStatus (after the body).
   259  	// This is a SHOULD in the HTTP RFC, and the way you add (known)
   260  	// Trailers per the net/http.ResponseWriter contract.
   261  	// See https://golang.org/pkg/net/http/#ResponseWriter
   262  	// and https://golang.org/pkg/net/http/#example_ResponseWriter_trailers
   263  	h.Add("Trailer", "Grpc-Status")
   264  	h.Add("Trailer", "Grpc-Message")
   265  	h.Add("Trailer", "Grpc-Status-Details-Bin")
   266  
   267  	if s.sendCompress != "" {
   268  		h.Set("Grpc-Encoding", s.sendCompress)
   269  	}
   270  }
   271  
   272  // writeCustomHeaders sets custom headers set on the stream via SetHeader
   273  // on the first write call (Write, WriteHeader, or WriteStatus).
   274  func (ht *serverHandlerTransport) writeCustomHeaders(s *Stream) {
   275  	h := ht.rw.Header()
   276  
   277  	s.hdrMu.Lock()
   278  	for k, vv := range s.header {
   279  		if isReservedHeader(k) {
   280  			continue
   281  		}
   282  		for _, v := range vv {
   283  			h.Add(k, encodeMetadataHeader(k, v))
   284  		}
   285  	}
   286  
   287  	s.hdrMu.Unlock()
   288  }
   289  
   290  func (ht *serverHandlerTransport) Write(s *Stream, hdr []byte, data []byte, opts *Options) error {
   291  	headersWritten := s.updateHeaderSent()
   292  	return ht.do(func() {
   293  		if !headersWritten {
   294  			ht.writePendingHeaders(s)
   295  		}
   296  		ht.rw.Write(hdr)
   297  		ht.rw.Write(data)
   298  		ht.rw.(http.Flusher).Flush()
   299  	})
   300  }
   301  
   302  func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error {
   303  	if err := s.SetHeader(md); err != nil {
   304  		return err
   305  	}
   306  
   307  	headersWritten := s.updateHeaderSent()
   308  	err := ht.do(func() {
   309  		if !headersWritten {
   310  			ht.writePendingHeaders(s)
   311  		}
   312  
   313  		ht.rw.WriteHeader(200)
   314  		ht.rw.(http.Flusher).Flush()
   315  	})
   316  
   317  	if err == nil {
   318  		if ht.stats != nil {
   319  			// Note: The header fields are compressed with hpack after this call returns.
   320  			// No WireLength field is set here.
   321  			ht.stats.HandleRPC(s.Context(), &stats.OutHeader{
   322  				Header:      md.Copy(),
   323  				Compression: s.sendCompress,
   324  			})
   325  		}
   326  	}
   327  	return err
   328  }
   329  
// HandleStreams serves the single stream carried by this HTTP request.
//
// It builds the Stream (method from the URL path, incoming metadata from
// headerMD, peer info, optional stats tagging). It then starts a goroutine
// that copies the request body into the stream's receive buffer and passes
// the stream to startStream. After that it runs queued writes on this
// goroutine (runStream) until the transport is closed. Before returning it
// closes the request body and waits for the body-reading goroutine to exit.
// traceCtx is not used by this transport.
func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), traceCtx func(context.Context, string) context.Context) {
	// With this transport type there will be exactly 1 stream: this HTTP request.

	// Derive the stream context from the request context, bounded by the
	// grpc-timeout deadline when one was supplied.
	ctx := ht.req.Context()
	var cancel context.CancelFunc
	if ht.timeoutSet {
		ctx, cancel = context.WithTimeout(ctx, ht.timeout)
	} else {
		ctx, cancel = context.WithCancel(ctx)
	}

	// requestOver is closed below once runStream has returned, i.e. after
	// closedCh has been closed (WriteStatus closes the transport).
	requestOver := make(chan struct{})
	// Cancel the stream context and close the transport as soon as the
	// request is over, the transport is closed, or the client's request
	// context is done, whichever comes first.
	go func() {
		select {
		case <-requestOver:
		case <-ht.closedCh:
		case <-ht.req.Context().Done():
		}
		cancel()
		ht.Close()
	}()

	req := ht.req

	s := &Stream{
		id:             0, // irrelevant
		requestRead:    func(int) {},
		cancel:         cancel,
		buf:            newRecvBuffer(),
		st:             ht,
		method:         req.URL.Path,
		recvCompress:   req.Header.Get("grpc-encoding"),
		contentSubtype: ht.contentSubtype,
	}
	pr := &peer.Peer{
		Addr: ht.RemoteAddr(),
	}
	// Expose TLS connection state to handlers when the request arrived over TLS.
	if req.TLS != nil {
		pr.AuthInfo = credentials.TLSInfo{State: *req.TLS, CommonAuthInfo: credentials.CommonAuthInfo{SecurityLevel: credentials.PrivacyAndIntegrity}}
	}
	ctx = metadata.NewIncomingContext(ctx, ht.headerMD)
	s.ctx = peer.NewContext(ctx, pr)
	if ht.stats != nil {
		s.ctx = ht.stats.TagRPC(s.ctx, &stats.RPCTagInfo{FullMethodName: s.method})
		inHeader := &stats.InHeader{
			FullMethod:  s.method,
			RemoteAddr:  ht.RemoteAddr(),
			Compression: s.recvCompress,
		}
		ht.stats.HandleRPC(s.ctx, inHeader)
	}
	// Flow control is handled by the HTTP server, so the window and
	// buffer-free hooks are no-ops here.
	s.trReader = &transportReader{
		reader:        &recvBufferReader{ctx: s.ctx, ctxDone: s.ctx.Done(), recv: s.buf, freeBuffer: func(*bytes.Buffer) {}},
		windowHandler: func(int) {},
	}

	// readerDone is closed when the Body.Read-ing goroutine exits.
	readerDone := make(chan struct{})
	go func() {
		defer close(readerDone)

		// TODO: minimize garbage, optimize recvBuffer code/ownership
		// NOTE(review): 8196 is presumably meant to be 8 KiB (8192); the
		// value only affects the read chunk size.
		const readSize = 8196
		for buf := make([]byte, readSize); ; {
			n, err := req.Body.Read(buf)
			if n > 0 {
				// buf[:n:n] caps the chunk's capacity at n, so the remaining
				// bytes of the allocation can be reused for the next Read.
				s.buf.put(recvMsg{buffer: bytes.NewBuffer(buf[:n:n])})
				buf = buf[n:]
			}
			if err != nil {
				// Deliver the terminal error (io.EOF on a clean end) to the reader.
				s.buf.put(recvMsg{err: mapRecvMsgError(err)})
				return
			}
			if len(buf) == 0 {
				buf = make([]byte, readSize)
			}
		}
	}()

	// startStream is provided by the *grpc.Server's serveStreams.
	// It starts a goroutine serving s and exits immediately.
	// The goroutine that is started is the one that then calls
	// into ht, calling WriteHeader, Write, WriteStatus, Close, etc.
	startStream(s)

	// Run queued writes on this (ServeHTTP) goroutine until closedCh closes.
	ht.runStream()
	close(requestOver)

	// Wait for reading goroutine to finish.
	req.Body.Close()
	<-readerDone
}
   423  
   424  func (ht *serverHandlerTransport) runStream() {
   425  	for {
   426  		select {
   427  		case fn := <-ht.writes:
   428  			fn()
   429  		case <-ht.closedCh:
   430  			return
   431  		}
   432  	}
   433  }
   434  
   435  func (ht *serverHandlerTransport) IncrMsgSent() {}
   436  
   437  func (ht *serverHandlerTransport) IncrMsgRecv() {}
   438  
   439  func (ht *serverHandlerTransport) Drain() {
   440  	panic("Drain() is not implemented")
   441  }
   442  
   443  // mapRecvMsgError returns the non-nil err into the appropriate
   444  // error value as expected by callers of *grpc.parser.recvMsg.
   445  // In particular, in can only be:
   446  //   * io.EOF
   447  //   * io.ErrUnexpectedEOF
   448  //   * of type transport.ConnectionError
   449  //   * an error from the status package
   450  func mapRecvMsgError(err error) error {
   451  	if err == io.EOF || err == io.ErrUnexpectedEOF {
   452  		return err
   453  	}
   454  	if se, ok := err.(http2.StreamError); ok {
   455  		if code, ok := http2ErrConvTab[se.Code]; ok {
   456  			return status.Error(code, se.Error())
   457  		}
   458  	}
   459  	if strings.Contains(err.Error(), "body closed by handler") {
   460  		return status.Error(codes.Canceled, err.Error())
   461  	}
   462  	return connectionErrorf(true, err, err.Error())
   463  }