github.com/grpc-ecosystem/grpc-gateway/v2@v2.19.1/runtime/handler.go (about)

     1  package runtime
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"io"
     7  	"net/http"
     8  	"net/textproto"
     9  	"strings"
    10  
    11  	"google.golang.org/genproto/googleapis/api/httpbody"
    12  	"google.golang.org/grpc/codes"
    13  	"google.golang.org/grpc/grpclog"
    14  	"google.golang.org/grpc/status"
    15  	"google.golang.org/protobuf/proto"
    16  )
    17  
    18  // ForwardResponseStream forwards the stream from gRPC server to REST client.
    19  func ForwardResponseStream(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, req *http.Request, recv func() (proto.Message, error), opts ...func(context.Context, http.ResponseWriter, proto.Message) error) {
    20  	f, ok := w.(http.Flusher)
    21  	if !ok {
    22  		grpclog.Infof("Flush not supported in %T", w)
    23  		http.Error(w, "unexpected type of web server", http.StatusInternalServerError)
    24  		return
    25  	}
    26  
    27  	md, ok := ServerMetadataFromContext(ctx)
    28  	if !ok {
    29  		grpclog.Infof("Failed to extract ServerMetadata from context")
    30  		http.Error(w, "unexpected error", http.StatusInternalServerError)
    31  		return
    32  	}
    33  	handleForwardResponseServerMetadata(w, mux, md)
    34  
    35  	w.Header().Set("Transfer-Encoding", "chunked")
    36  	if err := handleForwardResponseOptions(ctx, w, nil, opts); err != nil {
    37  		HTTPError(ctx, mux, marshaler, w, req, err)
    38  		return
    39  	}
    40  
    41  	var delimiter []byte
    42  	if d, ok := marshaler.(Delimited); ok {
    43  		delimiter = d.Delimiter()
    44  	} else {
    45  		delimiter = []byte("\n")
    46  	}
    47  
    48  	var wroteHeader bool
    49  	for {
    50  		resp, err := recv()
    51  		if errors.Is(err, io.EOF) {
    52  			return
    53  		}
    54  		if err != nil {
    55  			handleForwardResponseStreamError(ctx, wroteHeader, marshaler, w, req, mux, err, delimiter)
    56  			return
    57  		}
    58  		if err := handleForwardResponseOptions(ctx, w, resp, opts); err != nil {
    59  			handleForwardResponseStreamError(ctx, wroteHeader, marshaler, w, req, mux, err, delimiter)
    60  			return
    61  		}
    62  
    63  		if !wroteHeader {
    64  			w.Header().Set("Content-Type", marshaler.ContentType(resp))
    65  		}
    66  
    67  		var buf []byte
    68  		httpBody, isHTTPBody := resp.(*httpbody.HttpBody)
    69  		switch {
    70  		case resp == nil:
    71  			buf, err = marshaler.Marshal(errorChunk(status.New(codes.Internal, "empty response")))
    72  		case isHTTPBody:
    73  			buf = httpBody.GetData()
    74  		default:
    75  			result := map[string]interface{}{"result": resp}
    76  			if rb, ok := resp.(responseBody); ok {
    77  				result["result"] = rb.XXX_ResponseBody()
    78  			}
    79  
    80  			buf, err = marshaler.Marshal(result)
    81  		}
    82  
    83  		if err != nil {
    84  			grpclog.Infof("Failed to marshal response chunk: %v", err)
    85  			handleForwardResponseStreamError(ctx, wroteHeader, marshaler, w, req, mux, err, delimiter)
    86  			return
    87  		}
    88  		if _, err := w.Write(buf); err != nil {
    89  			grpclog.Infof("Failed to send response chunk: %v", err)
    90  			return
    91  		}
    92  		wroteHeader = true
    93  		if _, err := w.Write(delimiter); err != nil {
    94  			grpclog.Infof("Failed to send delimiter chunk: %v", err)
    95  			return
    96  		}
    97  		f.Flush()
    98  	}
    99  }
   100  
   101  func handleForwardResponseServerMetadata(w http.ResponseWriter, mux *ServeMux, md ServerMetadata) {
   102  	for k, vs := range md.HeaderMD {
   103  		if h, ok := mux.outgoingHeaderMatcher(k); ok {
   104  			for _, v := range vs {
   105  				w.Header().Add(h, v)
   106  			}
   107  		}
   108  	}
   109  }
   110  
   111  func handleForwardResponseTrailerHeader(w http.ResponseWriter, mux *ServeMux, md ServerMetadata) {
   112  	for k := range md.TrailerMD {
   113  		if h, ok := mux.outgoingTrailerMatcher(k); ok {
   114  			w.Header().Add("Trailer", textproto.CanonicalMIMEHeaderKey(h))
   115  		}
   116  	}
   117  }
   118  
   119  func handleForwardResponseTrailer(w http.ResponseWriter, mux *ServeMux, md ServerMetadata) {
   120  	for k, vs := range md.TrailerMD {
   121  		if h, ok := mux.outgoingTrailerMatcher(k); ok {
   122  			for _, v := range vs {
   123  				w.Header().Add(h, v)
   124  			}
   125  		}
   126  	}
   127  }
   128  
   129  // responseBody interface contains method for getting field for marshaling to the response body
   130  // this method is generated for response struct from the value of `response_body` in the `google.api.HttpRule`
   131  type responseBody interface {
   132  	XXX_ResponseBody() interface{}
   133  }
   134  
   135  // ForwardResponseMessage forwards the message "resp" from gRPC server to REST client.
   136  func ForwardResponseMessage(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, req *http.Request, resp proto.Message, opts ...func(context.Context, http.ResponseWriter, proto.Message) error) {
   137  	md, ok := ServerMetadataFromContext(ctx)
   138  	if !ok {
   139  		grpclog.Infof("Failed to extract ServerMetadata from context")
   140  	}
   141  
   142  	handleForwardResponseServerMetadata(w, mux, md)
   143  
   144  	// RFC 7230 https://tools.ietf.org/html/rfc7230#section-4.1.2
   145  	// Unless the request includes a TE header field indicating "trailers"
   146  	// is acceptable, as described in Section 4.3, a server SHOULD NOT
   147  	// generate trailer fields that it believes are necessary for the user
   148  	// agent to receive.
   149  	doForwardTrailers := requestAcceptsTrailers(req)
   150  
   151  	if doForwardTrailers {
   152  		handleForwardResponseTrailerHeader(w, mux, md)
   153  		w.Header().Set("Transfer-Encoding", "chunked")
   154  	}
   155  
   156  	contentType := marshaler.ContentType(resp)
   157  	w.Header().Set("Content-Type", contentType)
   158  
   159  	if err := handleForwardResponseOptions(ctx, w, resp, opts); err != nil {
   160  		HTTPError(ctx, mux, marshaler, w, req, err)
   161  		return
   162  	}
   163  	var buf []byte
   164  	var err error
   165  	if rb, ok := resp.(responseBody); ok {
   166  		buf, err = marshaler.Marshal(rb.XXX_ResponseBody())
   167  	} else {
   168  		buf, err = marshaler.Marshal(resp)
   169  	}
   170  	if err != nil {
   171  		grpclog.Infof("Marshal error: %v", err)
   172  		HTTPError(ctx, mux, marshaler, w, req, err)
   173  		return
   174  	}
   175  
   176  	if _, err = w.Write(buf); err != nil {
   177  		grpclog.Infof("Failed to write response: %v", err)
   178  	}
   179  
   180  	if doForwardTrailers {
   181  		handleForwardResponseTrailer(w, mux, md)
   182  	}
   183  }
   184  
   185  func requestAcceptsTrailers(req *http.Request) bool {
   186  	te := req.Header.Get("TE")
   187  	return strings.Contains(strings.ToLower(te), "trailers")
   188  }
   189  
   190  func handleForwardResponseOptions(ctx context.Context, w http.ResponseWriter, resp proto.Message, opts []func(context.Context, http.ResponseWriter, proto.Message) error) error {
   191  	if len(opts) == 0 {
   192  		return nil
   193  	}
   194  	for _, opt := range opts {
   195  		if err := opt(ctx, w, resp); err != nil {
   196  			grpclog.Infof("Error handling ForwardResponseOptions: %v", err)
   197  			return err
   198  		}
   199  	}
   200  	return nil
   201  }
   202  
   203  func handleForwardResponseStreamError(ctx context.Context, wroteHeader bool, marshaler Marshaler, w http.ResponseWriter, req *http.Request, mux *ServeMux, err error, delimiter []byte) {
   204  	st := mux.streamErrorHandler(ctx, err)
   205  	msg := errorChunk(st)
   206  	if !wroteHeader {
   207  		w.Header().Set("Content-Type", marshaler.ContentType(msg))
   208  		w.WriteHeader(HTTPStatusFromCode(st.Code()))
   209  	}
   210  	buf, err := marshaler.Marshal(msg)
   211  	if err != nil {
   212  		grpclog.Infof("Failed to marshal an error: %v", err)
   213  		return
   214  	}
   215  	if _, err := w.Write(buf); err != nil {
   216  		grpclog.Infof("Failed to notify error to client: %v", err)
   217  		return
   218  	}
   219  	if _, err := w.Write(delimiter); err != nil {
   220  		grpclog.Infof("Failed to send delimiter chunk: %v", err)
   221  		return
   222  	}
   223  }
   224  
   225  func errorChunk(st *status.Status) map[string]proto.Message {
   226  	return map[string]proto.Message{"error": st.Proto()}
   227  }