github.com/emcfarlane/larking@v0.0.0-20220605172417-1704b45ee6c3/web.go (about)

     1  // Copyright 2022 Edward McFarlane. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package larking
     6  
     7  // Support for gRPC-web
     8  // https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-WEB.md
     9  // https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md
    10  
    11  import (
    12  	"bytes"
    13  	"encoding/base64"
    14  	"encoding/binary"
    15  	"fmt"
    16  	"io"
    17  	"net/http"
    18  	"strings"
    19  
    20  	"google.golang.org/grpc"
    21  )
    22  
    23  const (
    24  	grpcBase    = "application/grpc"
    25  	grpcWeb     = "application/grpc-web"
    26  	grpcWebText = "application/grpc-web-text"
    27  )
    28  
    29  // isWebRequest checks for gRPC Web headers.
    30  func isWebRequest(r *http.Request) (typ string, enc string, ok bool) {
    31  	ct := r.Header.Get("Content-Type")
    32  	if !strings.HasPrefix(ct, "application/grpc-web") || r.Method != http.MethodPost {
    33  		return typ, enc, false
    34  	}
    35  	typ, enc, ok = strings.Cut(ct, "+")
    36  	if !ok {
    37  		enc = "proto"
    38  	}
    39  	ok = typ == grpcWeb || typ == grpcWebText
    40  	return typ, enc, ok
    41  }
    42  
    43  type webWriter struct {
    44  	w   http.ResponseWriter
    45  	typ string // grpcWEB || grpcWebText
    46  	enc string // proto || json || ...
    47  
    48  	wroteHeader bool
    49  	seenHeaders map[string]bool
    50  
    51  	wroteResp bool
    52  	resp      io.Writer
    53  }
    54  
    55  func newWebWriter(w http.ResponseWriter, typ, enc string) *webWriter {
    56  	var resp io.Writer = w
    57  	if typ == grpcWebText {
    58  		resp = base64.NewEncoder(base64.StdEncoding, resp)
    59  
    60  	}
    61  
    62  	return &webWriter{
    63  		w:    w,
    64  		typ:  typ,
    65  		enc:  enc,
    66  		resp: resp,
    67  	}
    68  }
    69  
    70  func (w *webWriter) seeHeaders() {
    71  	hdr := w.Header()
    72  	hdr.Set("Content-Type", w.typ+"+"+w.enc) // override content-type
    73  
    74  	keys := make(map[string]bool, len(hdr))
    75  	for k := range hdr {
    76  		if strings.HasPrefix(k, http.TrailerPrefix) {
    77  			continue
    78  		}
    79  		keys[k] = true
    80  	}
    81  	w.seenHeaders = keys
    82  	w.wroteHeader = true
    83  }
    84  
    85  func (w *webWriter) Write(b []byte) (int, error) {
    86  	if !w.wroteHeader {
    87  		w.seeHeaders()
    88  	}
    89  	return w.resp.Write(b)
    90  }
    91  
    92  func (w *webWriter) Header() http.Header { return w.w.Header() }
    93  
    94  func (w *webWriter) WriteHeader(code int) {
    95  	w.seeHeaders()
    96  	w.w.WriteHeader(code)
    97  }
    98  
    99  func (w *webWriter) Flush() {
   100  	if w.wroteHeader || w.wroteResp {
   101  		if f, ok := w.w.(http.Flusher); ok {
   102  			f.Flush()
   103  		}
   104  	}
   105  }
   106  
   107  func (w *webWriter) writeTrailer() error {
   108  	hdr := w.Header()
   109  
   110  	tr := make(http.Header, len(hdr)-len(w.seenHeaders)+1)
   111  	for key, val := range hdr {
   112  		if w.seenHeaders[key] {
   113  			continue
   114  		}
   115  		key = strings.TrimPrefix(key, http.TrailerPrefix)
   116  		// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-WEB.md#protocol-differences-vs-grpc-over-http2
   117  		tr[strings.ToLower(key)] = val
   118  	}
   119  
   120  	var buf bytes.Buffer
   121  	if err := tr.Write(&buf); err != nil {
   122  		return err
   123  	}
   124  
   125  	head := []byte{1 << 7, 0, 0, 0, 0} // MSB=1 indicates this is a trailer data frame.
   126  	binary.BigEndian.PutUint32(head[1:5], uint32(buf.Len()))
   127  	if _, err := w.Write(head); err != nil {
   128  		return err
   129  	}
   130  	if _, err := w.Write(buf.Bytes()); err != nil {
   131  		return err
   132  	}
   133  	return nil
   134  }
   135  
   136  func (w *webWriter) flushWithTrailer() {
   137  	// Write trailers only if message has been sent.
   138  	if w.wroteHeader || w.wroteResp {
   139  		if err := w.writeTrailer(); err != nil {
   140  			return // nothing
   141  		}
   142  	}
   143  	w.Flush()
   144  }
   145  
   146  type readCloser struct {
   147  	io.Reader
   148  	io.Closer
   149  }
   150  
   151  func createGRPCWebHandler(gs *grpc.Server) http.HandlerFunc {
   152  	return func(w http.ResponseWriter, r *http.Request) {
   153  		typ, enc, ok := isWebRequest(r)
   154  		if !ok {
   155  			msg := fmt.Sprintf("invalid gRPC-Web content type: %v", r.Header.Get("Content-Type"))
   156  			http.Error(w, msg, http.StatusBadRequest)
   157  			return
   158  		}
   159  		// TODO: Check for websocket request and upgrade.
   160  		if strings.EqualFold(r.Header.Get("Upgrade"), "websocket") {
   161  			http.Error(w, "unimplemented websocket support", http.StatusInternalServerError)
   162  			return
   163  		}
   164  
   165  		r.ProtoMajor = 2
   166  		r.ProtoMinor = 0
   167  
   168  		hdr := r.Header
   169  		hdr.Del("Content-Length")
   170  		hdr.Set("Content-Type", grpcBase+"+"+enc)
   171  
   172  		if typ == grpcWebText {
   173  			body := base64.NewDecoder(base64.StdEncoding, r.Body)
   174  			r.Body = readCloser{body, r.Body}
   175  		}
   176  
   177  		ww := newWebWriter(w, typ, enc)
   178  		gs.ServeHTTP(ww, r)
   179  		ww.flushWithTrailer()
   180  	}
   181  }