google.golang.org/grpc@v1.72.2/internal/transport/handler_server.go (about)

     1  /*
     2   *
     3   * Copyright 2016 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  // This file is the implementation of a gRPC server using HTTP/2 which
    20  // uses the standard Go http2 Server implementation (via the
    21  // http.Handler interface), rather than speaking low-level HTTP/2
    22  // frames itself. It is the implementation of *grpc.Server.ServeHTTP.
    23  
    24  package transport
    25  
    26  import (
    27  	"context"
    28  	"errors"
    29  	"fmt"
    30  	"io"
    31  	"net"
    32  	"net/http"
    33  	"strings"
    34  	"sync"
    35  	"time"
    36  
    37  	"golang.org/x/net/http2"
    38  	"google.golang.org/grpc/codes"
    39  	"google.golang.org/grpc/credentials"
    40  	"google.golang.org/grpc/internal/grpclog"
    41  	"google.golang.org/grpc/internal/grpcutil"
    42  	"google.golang.org/grpc/mem"
    43  	"google.golang.org/grpc/metadata"
    44  	"google.golang.org/grpc/peer"
    45  	"google.golang.org/grpc/stats"
    46  	"google.golang.org/grpc/status"
    47  	"google.golang.org/protobuf/proto"
    48  )
    49  
    50  // NewServerHandlerTransport returns a ServerTransport handling gRPC from
    51  // inside an http.Handler, or writes an HTTP error to w and returns an error.
    52  // It requires that the http Server supports HTTP/2.
    53  func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request, stats []stats.Handler, bufferPool mem.BufferPool) (ServerTransport, error) {
    54  	if r.Method != http.MethodPost {
    55  		w.Header().Set("Allow", http.MethodPost)
    56  		msg := fmt.Sprintf("invalid gRPC request method %q", r.Method)
    57  		http.Error(w, msg, http.StatusMethodNotAllowed)
    58  		return nil, errors.New(msg)
    59  	}
    60  	contentType := r.Header.Get("Content-Type")
    61  	// TODO: do we assume contentType is lowercase? we did before
    62  	contentSubtype, validContentType := grpcutil.ContentSubtype(contentType)
    63  	if !validContentType {
    64  		msg := fmt.Sprintf("invalid gRPC request content-type %q", contentType)
    65  		http.Error(w, msg, http.StatusUnsupportedMediaType)
    66  		return nil, errors.New(msg)
    67  	}
    68  	if r.ProtoMajor != 2 {
    69  		msg := "gRPC requires HTTP/2"
    70  		http.Error(w, msg, http.StatusHTTPVersionNotSupported)
    71  		return nil, errors.New(msg)
    72  	}
    73  	if _, ok := w.(http.Flusher); !ok {
    74  		msg := "gRPC requires a ResponseWriter supporting http.Flusher"
    75  		http.Error(w, msg, http.StatusInternalServerError)
    76  		return nil, errors.New(msg)
    77  	}
    78  
    79  	var localAddr net.Addr
    80  	if la := r.Context().Value(http.LocalAddrContextKey); la != nil {
    81  		localAddr, _ = la.(net.Addr)
    82  	}
    83  	var authInfo credentials.AuthInfo
    84  	if r.TLS != nil {
    85  		authInfo = credentials.TLSInfo{State: *r.TLS, CommonAuthInfo: credentials.CommonAuthInfo{SecurityLevel: credentials.PrivacyAndIntegrity}}
    86  	}
    87  	p := peer.Peer{
    88  		Addr:      strAddr(r.RemoteAddr),
    89  		LocalAddr: localAddr,
    90  		AuthInfo:  authInfo,
    91  	}
    92  	st := &serverHandlerTransport{
    93  		rw:             w,
    94  		req:            r,
    95  		closedCh:       make(chan struct{}),
    96  		writes:         make(chan func()),
    97  		peer:           p,
    98  		contentType:    contentType,
    99  		contentSubtype: contentSubtype,
   100  		stats:          stats,
   101  		bufferPool:     bufferPool,
   102  	}
   103  	st.logger = prefixLoggerForServerHandlerTransport(st)
   104  
   105  	if v := r.Header.Get("grpc-timeout"); v != "" {
   106  		to, err := decodeTimeout(v)
   107  		if err != nil {
   108  			msg := fmt.Sprintf("malformed grpc-timeout: %v", err)
   109  			http.Error(w, msg, http.StatusBadRequest)
   110  			return nil, status.Error(codes.Internal, msg)
   111  		}
   112  		st.timeoutSet = true
   113  		st.timeout = to
   114  	}
   115  
   116  	metakv := []string{"content-type", contentType}
   117  	if r.Host != "" {
   118  		metakv = append(metakv, ":authority", r.Host)
   119  	}
   120  	for k, vv := range r.Header {
   121  		k = strings.ToLower(k)
   122  		if isReservedHeader(k) && !isWhitelistedHeader(k) {
   123  			continue
   124  		}
   125  		for _, v := range vv {
   126  			v, err := decodeMetadataHeader(k, v)
   127  			if err != nil {
   128  				msg := fmt.Sprintf("malformed binary metadata %q in header %q: %v", v, k, err)
   129  				http.Error(w, msg, http.StatusBadRequest)
   130  				return nil, status.Error(codes.Internal, msg)
   131  			}
   132  			metakv = append(metakv, k, v)
   133  		}
   134  	}
   135  	st.headerMD = metadata.Pairs(metakv...)
   136  
   137  	return st, nil
   138  }
   139  
   140  // serverHandlerTransport is an implementation of ServerTransport
   141  // which replies to exactly one gRPC request (exactly one HTTP request),
   142  // using the net/http.Handler interface. This http.Handler is guaranteed
   143  // at this point to be speaking over HTTP/2, so it's able to speak valid
   144  // gRPC.
   145  type serverHandlerTransport struct {
   146  	rw         http.ResponseWriter
   147  	req        *http.Request
   148  	timeoutSet bool
   149  	timeout    time.Duration
   150  
   151  	headerMD metadata.MD
   152  
   153  	peer peer.Peer
   154  
   155  	closeOnce sync.Once
   156  	closedCh  chan struct{} // closed on Close
   157  
   158  	// writes is a channel of code to run serialized in the
   159  	// ServeHTTP (HandleStreams) goroutine. The channel is closed
   160  	// when WriteStatus is called.
   161  	writes chan func()
   162  
   163  	// block concurrent WriteStatus calls
   164  	// e.g. grpc/(*serverStream).SendMsg/RecvMsg
   165  	writeStatusMu sync.Mutex
   166  
   167  	// we just mirror the request content-type
   168  	contentType string
   169  	// we store both contentType and contentSubtype so we don't keep recreating them
   170  	// TODO make sure this is consistent across handler_server and http2_server
   171  	contentSubtype string
   172  
   173  	stats  []stats.Handler
   174  	logger *grpclog.PrefixLogger
   175  
   176  	bufferPool mem.BufferPool
   177  }
   178  
   179  func (ht *serverHandlerTransport) Close(err error) {
   180  	ht.closeOnce.Do(func() {
   181  		if ht.logger.V(logLevel) {
   182  			ht.logger.Infof("Closing: %v", err)
   183  		}
   184  		close(ht.closedCh)
   185  	})
   186  }
   187  
   188  func (ht *serverHandlerTransport) Peer() *peer.Peer {
   189  	return &peer.Peer{
   190  		Addr:      ht.peer.Addr,
   191  		LocalAddr: ht.peer.LocalAddr,
   192  		AuthInfo:  ht.peer.AuthInfo,
   193  	}
   194  }
   195  
   196  // strAddr is a net.Addr backed by either a TCP "ip:port" string, or
   197  // the empty string if unknown.
   198  type strAddr string
   199  
   200  func (a strAddr) Network() string {
   201  	if a != "" {
   202  		// Per the documentation on net/http.Request.RemoteAddr, if this is
   203  		// set, it's set to the IP:port of the peer (hence, TCP):
   204  		// https://golang.org/pkg/net/http/#Request
   205  		//
   206  		// If we want to support Unix sockets later, we can
   207  		// add our own grpc-specific convention within the
   208  		// grpc codebase to set RemoteAddr to a different
   209  		// format, or probably better: we can attach it to the
   210  		// context and use that from serverHandlerTransport.RemoteAddr.
   211  		return "tcp"
   212  	}
   213  	return ""
   214  }
   215  
   216  func (a strAddr) String() string { return string(a) }
   217  
   218  // do runs fn in the ServeHTTP goroutine.
   219  func (ht *serverHandlerTransport) do(fn func()) error {
   220  	select {
   221  	case <-ht.closedCh:
   222  		return ErrConnClosing
   223  	case ht.writes <- fn:
   224  		return nil
   225  	}
   226  }
   227  
   228  func (ht *serverHandlerTransport) writeStatus(s *ServerStream, st *status.Status) error {
   229  	ht.writeStatusMu.Lock()
   230  	defer ht.writeStatusMu.Unlock()
   231  
   232  	headersWritten := s.updateHeaderSent()
   233  	err := ht.do(func() {
   234  		if !headersWritten {
   235  			ht.writePendingHeaders(s)
   236  		}
   237  
   238  		// And flush, in case no header or body has been sent yet.
   239  		// This forces a separation of headers and trailers if this is the
   240  		// first call (for example, in end2end tests's TestNoService).
   241  		ht.rw.(http.Flusher).Flush()
   242  
   243  		h := ht.rw.Header()
   244  		h.Set("Grpc-Status", fmt.Sprintf("%d", st.Code()))
   245  		if m := st.Message(); m != "" {
   246  			h.Set("Grpc-Message", encodeGrpcMessage(m))
   247  		}
   248  
   249  		s.hdrMu.Lock()
   250  		defer s.hdrMu.Unlock()
   251  		if p := st.Proto(); p != nil && len(p.Details) > 0 {
   252  			delete(s.trailer, grpcStatusDetailsBinHeader)
   253  			stBytes, err := proto.Marshal(p)
   254  			if err != nil {
   255  				// TODO: return error instead, when callers are able to handle it.
   256  				panic(err)
   257  			}
   258  
   259  			h.Set(grpcStatusDetailsBinHeader, encodeBinHeader(stBytes))
   260  		}
   261  
   262  		if len(s.trailer) > 0 {
   263  			for k, vv := range s.trailer {
   264  				// Clients don't tolerate reading restricted headers after some non restricted ones were sent.
   265  				if isReservedHeader(k) {
   266  					continue
   267  				}
   268  				for _, v := range vv {
   269  					// http2 ResponseWriter mechanism to send undeclared Trailers after
   270  					// the headers have possibly been written.
   271  					h.Add(http2.TrailerPrefix+k, encodeMetadataHeader(k, v))
   272  				}
   273  			}
   274  		}
   275  	})
   276  
   277  	if err == nil { // transport has not been closed
   278  		// Note: The trailer fields are compressed with hpack after this call returns.
   279  		// No WireLength field is set here.
   280  		for _, sh := range ht.stats {
   281  			sh.HandleRPC(s.Context(), &stats.OutTrailer{
   282  				Trailer: s.trailer.Copy(),
   283  			})
   284  		}
   285  	}
   286  	ht.Close(errors.New("finished writing status"))
   287  	return err
   288  }
   289  
   290  // writePendingHeaders sets common and custom headers on the first
   291  // write call (Write, WriteHeader, or WriteStatus)
   292  func (ht *serverHandlerTransport) writePendingHeaders(s *ServerStream) {
   293  	ht.writeCommonHeaders(s)
   294  	ht.writeCustomHeaders(s)
   295  }
   296  
   297  // writeCommonHeaders sets common headers on the first write
   298  // call (Write, WriteHeader, or WriteStatus).
   299  func (ht *serverHandlerTransport) writeCommonHeaders(s *ServerStream) {
   300  	h := ht.rw.Header()
   301  	h["Date"] = nil // suppress Date to make tests happy; TODO: restore
   302  	h.Set("Content-Type", ht.contentType)
   303  
   304  	// Predeclare trailers we'll set later in WriteStatus (after the body).
   305  	// This is a SHOULD in the HTTP RFC, and the way you add (known)
   306  	// Trailers per the net/http.ResponseWriter contract.
   307  	// See https://golang.org/pkg/net/http/#ResponseWriter
   308  	// and https://golang.org/pkg/net/http/#example_ResponseWriter_trailers
   309  	h.Add("Trailer", "Grpc-Status")
   310  	h.Add("Trailer", "Grpc-Message")
   311  	h.Add("Trailer", "Grpc-Status-Details-Bin")
   312  
   313  	if s.sendCompress != "" {
   314  		h.Set("Grpc-Encoding", s.sendCompress)
   315  	}
   316  }
   317  
   318  // writeCustomHeaders sets custom headers set on the stream via SetHeader
   319  // on the first write call (Write, WriteHeader, or WriteStatus)
   320  func (ht *serverHandlerTransport) writeCustomHeaders(s *ServerStream) {
   321  	h := ht.rw.Header()
   322  
   323  	s.hdrMu.Lock()
   324  	for k, vv := range s.header {
   325  		if isReservedHeader(k) {
   326  			continue
   327  		}
   328  		for _, v := range vv {
   329  			h.Add(k, encodeMetadataHeader(k, v))
   330  		}
   331  	}
   332  
   333  	s.hdrMu.Unlock()
   334  }
   335  
   336  func (ht *serverHandlerTransport) write(s *ServerStream, hdr []byte, data mem.BufferSlice, _ *WriteOptions) error {
   337  	// Always take a reference because otherwise there is no guarantee the data will
   338  	// be available after this function returns. This is what callers to Write
   339  	// expect.
   340  	data.Ref()
   341  	headersWritten := s.updateHeaderSent()
   342  	err := ht.do(func() {
   343  		defer data.Free()
   344  		if !headersWritten {
   345  			ht.writePendingHeaders(s)
   346  		}
   347  		ht.rw.Write(hdr)
   348  		for _, b := range data {
   349  			_, _ = ht.rw.Write(b.ReadOnlyData())
   350  		}
   351  		ht.rw.(http.Flusher).Flush()
   352  	})
   353  	if err != nil {
   354  		data.Free()
   355  		return err
   356  	}
   357  	return nil
   358  }
   359  
   360  func (ht *serverHandlerTransport) writeHeader(s *ServerStream, md metadata.MD) error {
   361  	if err := s.SetHeader(md); err != nil {
   362  		return err
   363  	}
   364  
   365  	headersWritten := s.updateHeaderSent()
   366  	err := ht.do(func() {
   367  		if !headersWritten {
   368  			ht.writePendingHeaders(s)
   369  		}
   370  
   371  		ht.rw.WriteHeader(200)
   372  		ht.rw.(http.Flusher).Flush()
   373  	})
   374  
   375  	if err == nil {
   376  		for _, sh := range ht.stats {
   377  			// Note: The header fields are compressed with hpack after this call returns.
   378  			// No WireLength field is set here.
   379  			sh.HandleRPC(s.Context(), &stats.OutHeader{
   380  				Header:      md.Copy(),
   381  				Compression: s.sendCompress,
   382  			})
   383  		}
   384  	}
   385  	return err
   386  }
   387  
   388  func (ht *serverHandlerTransport) HandleStreams(ctx context.Context, startStream func(*ServerStream)) {
   389  	// With this transport type there will be exactly 1 stream: this HTTP request.
   390  	var cancel context.CancelFunc
   391  	if ht.timeoutSet {
   392  		ctx, cancel = context.WithTimeout(ctx, ht.timeout)
   393  	} else {
   394  		ctx, cancel = context.WithCancel(ctx)
   395  	}
   396  
   397  	// requestOver is closed when the status has been written via WriteStatus.
   398  	requestOver := make(chan struct{})
   399  	go func() {
   400  		select {
   401  		case <-requestOver:
   402  		case <-ht.closedCh:
   403  		case <-ht.req.Context().Done():
   404  		}
   405  		cancel()
   406  		ht.Close(errors.New("request is done processing"))
   407  	}()
   408  
   409  	ctx = metadata.NewIncomingContext(ctx, ht.headerMD)
   410  	req := ht.req
   411  	s := &ServerStream{
   412  		Stream: &Stream{
   413  			id:             0, // irrelevant
   414  			ctx:            ctx,
   415  			requestRead:    func(int) {},
   416  			buf:            newRecvBuffer(),
   417  			method:         req.URL.Path,
   418  			recvCompress:   req.Header.Get("grpc-encoding"),
   419  			contentSubtype: ht.contentSubtype,
   420  		},
   421  		cancel:           cancel,
   422  		st:               ht,
   423  		headerWireLength: 0, // won't have access to header wire length until golang/go#18997.
   424  	}
   425  	s.trReader = &transportReader{
   426  		reader:        &recvBufferReader{ctx: s.ctx, ctxDone: s.ctx.Done(), recv: s.buf},
   427  		windowHandler: func(int) {},
   428  	}
   429  
   430  	// readerDone is closed when the Body.Read-ing goroutine exits.
   431  	readerDone := make(chan struct{})
   432  	go func() {
   433  		defer close(readerDone)
   434  
   435  		for {
   436  			buf := ht.bufferPool.Get(http2MaxFrameLen)
   437  			n, err := req.Body.Read(*buf)
   438  			if n > 0 {
   439  				*buf = (*buf)[:n]
   440  				s.buf.put(recvMsg{buffer: mem.NewBuffer(buf, ht.bufferPool)})
   441  			} else {
   442  				ht.bufferPool.Put(buf)
   443  			}
   444  			if err != nil {
   445  				s.buf.put(recvMsg{err: mapRecvMsgError(err)})
   446  				return
   447  			}
   448  		}
   449  	}()
   450  
   451  	// startStream is provided by the *grpc.Server's serveStreams.
   452  	// It starts a goroutine serving s and exits immediately.
   453  	// The goroutine that is started is the one that then calls
   454  	// into ht, calling WriteHeader, Write, WriteStatus, Close, etc.
   455  	startStream(s)
   456  
   457  	ht.runStream()
   458  	close(requestOver)
   459  
   460  	// Wait for reading goroutine to finish.
   461  	req.Body.Close()
   462  	<-readerDone
   463  }
   464  
   465  func (ht *serverHandlerTransport) runStream() {
   466  	for {
   467  		select {
   468  		case fn := <-ht.writes:
   469  			fn()
   470  		case <-ht.closedCh:
   471  			return
   472  		}
   473  	}
   474  }
   475  
   476  func (ht *serverHandlerTransport) incrMsgRecv() {}
   477  
   478  func (ht *serverHandlerTransport) Drain(string) {
   479  	panic("Drain() is not implemented")
   480  }
   481  
   482  // mapRecvMsgError returns the non-nil err into the appropriate
   483  // error value as expected by callers of *grpc.parser.recvMsg.
   484  // In particular, in can only be:
   485  //   - io.EOF
   486  //   - io.ErrUnexpectedEOF
   487  //   - of type transport.ConnectionError
   488  //   - an error from the status package
   489  func mapRecvMsgError(err error) error {
   490  	if err == io.EOF || err == io.ErrUnexpectedEOF {
   491  		return err
   492  	}
   493  	if se, ok := err.(http2.StreamError); ok {
   494  		if code, ok := http2ErrConvTab[se.Code]; ok {
   495  			return status.Error(code, se.Error())
   496  		}
   497  	}
   498  	if strings.Contains(err.Error(), "body closed by handler") {
   499  		return status.Error(codes.Canceled, err.Error())
   500  	}
   501  	return connectionErrorf(true, err, "%s", err.Error())
   502  }