github.com/decred/dcrlnd@v0.7.6/lnrpc/websocket_proxy.go (about)

     1  // The code in this file is a heavily modified version of
     2  // https://github.com/tmc/grpc-websocket-proxy/
     3  
     4  package lnrpc
     5  
     6  import (
     7  	"bufio"
     8  	"io"
     9  	"net/http"
    10  	"net/textproto"
    11  	"regexp"
    12  	"strings"
    13  	"time"
    14  
    15  	"github.com/decred/slog"
    16  	"github.com/gorilla/websocket"
    17  	"golang.org/x/net/context"
    18  )
    19  
    20  const (
    21  	// MethodOverrideParam is the GET query parameter that specifies what
    22  	// HTTP request method should be used for the forwarded REST request.
    23  	// This is necessary because the WebSocket API specifies that a
    24  	// handshake request must always be done through a GET request.
    25  	MethodOverrideParam = "method"
    26  
    27  	// HeaderWebSocketProtocol is the name of the WebSocket protocol
    28  	// exchange header field that we use to transport additional header
    29  	// fields.
    30  	HeaderWebSocketProtocol = "Sec-Websocket-Protocol"
    31  
    32  	// WebSocketProtocolDelimiter is the delimiter we use between the
    33  	// additional header field and its value. We use the plus symbol because
    34  	// the default delimiters aren't allowed in the protocol names.
    35  	WebSocketProtocolDelimiter = "+"
    36  
    37  	// PingContent is the content of the ping message we send out. This is
    38  	// an arbitrary non-empty message that has no deeper meaning but should
    39  	// be sent back by the client in the pong message.
    40  	PingContent = "are you there?"
    41  )
    42  
    43  var (
    44  	// defaultHeadersToForward is a map of all HTTP header fields that are
    45  	// forwarded by default. The keys must be in the canonical MIME header
    46  	// format.
    47  	defaultHeadersToForward = map[string]bool{
    48  		"Origin":                 true,
    49  		"Referer":                true,
    50  		"Grpc-Metadata-Macaroon": true,
    51  	}
    52  
    53  	// defaultProtocolsToAllow are additional header fields that we allow
    54  	// to be transported inside of the Sec-Websocket-Protocol field to be
    55  	// forwarded to the backend.
    56  	defaultProtocolsToAllow = map[string]bool{
    57  		"Grpc-Metadata-Macaroon": true,
    58  	}
    59  
    60  	// DefaultPingInterval is the default number of seconds to wait between
    61  	// sending ping requests.
    62  	DefaultPingInterval = time.Second * 30
    63  
    64  	// DefaultPongWait is the maximum duration we wait for a pong response
    65  	// to a ping we sent before we assume the connection died.
    66  	DefaultPongWait = time.Second * 5
    67  )
    68  
    69  // NewWebSocketProxy attempts to expose the underlying handler as a response-
    70  // streaming WebSocket stream with newline-delimited JSON as the content
    71  // encoding. If pingInterval is a non-zero duration, a ping message will be
    72  // sent out periodically and a pong response message is expected from the
    73  // client. The clientStreamingURIs parameter can hold a list of all patterns
    74  // for URIs that are mapped to client-streaming RPC methods. We need to keep
    75  // track of those to make sure we initialize the request body correctly for the
    76  // underlying grpc-gateway library.
    77  func NewWebSocketProxy(h http.Handler, logger slog.Logger,
    78  	pingInterval, pongWait time.Duration,
    79  	clientStreamingURIs []*regexp.Regexp) http.Handler {
    80  
    81  	p := &WebsocketProxy{
    82  		backend: h,
    83  		logger:  logger,
    84  		upgrader: &websocket.Upgrader{
    85  			ReadBufferSize:  1024,
    86  			WriteBufferSize: 1024,
    87  			CheckOrigin: func(r *http.Request) bool {
    88  				return true
    89  			},
    90  		},
    91  		clientStreamingURIs: clientStreamingURIs,
    92  	}
    93  
    94  	if pingInterval > 0 && pongWait > 0 {
    95  		p.pingInterval = pingInterval
    96  		p.pongWait = pongWait
    97  	}
    98  
    99  	return p
   100  }
   101  
   102  // WebsocketProxy provides websocket transport upgrade to compatible endpoints.
   103  type WebsocketProxy struct {
   104  	backend  http.Handler
   105  	logger   slog.Logger
   106  	upgrader *websocket.Upgrader
   107  
   108  	// clientStreamingURIs holds a list of all patterns for URIs that are
   109  	// mapped to client-streaming RPC methods. We need to keep track of
   110  	// those to make sure we initialize the request body correctly for the
   111  	// underlying grpc-gateway library.
   112  	clientStreamingURIs []*regexp.Regexp
   113  
   114  	pingInterval time.Duration
   115  	pongWait     time.Duration
   116  }
   117  
   118  // pingPongEnabled returns true if a ping interval is set to enable sending and
   119  // expecting regular ping/pong messages.
   120  func (p *WebsocketProxy) pingPongEnabled() bool {
   121  	return p.pingInterval > 0 && p.pongWait > 0
   122  }
   123  
   124  // ServeHTTP handles the incoming HTTP request. If the request is an
   125  // "upgradeable" WebSocket request (identified by header fields), then the
   126  // WS proxy handles the request. Otherwise the request is passed directly to the
   127  // underlying REST proxy.
   128  func (p *WebsocketProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
   129  	if !websocket.IsWebSocketUpgrade(r) {
   130  		p.backend.ServeHTTP(w, r)
   131  		return
   132  	}
   133  	p.upgradeToWebSocketProxy(w, r)
   134  }
   135  
   136  // upgradeToWebSocketProxy upgrades the incoming request to a WebSocket, reads
   137  // one incoming message then streams all responses until either the client or
   138  // server quit the connection.
   139  func (p *WebsocketProxy) upgradeToWebSocketProxy(w http.ResponseWriter,
   140  	r *http.Request) {
   141  
   142  	conn, err := p.upgrader.Upgrade(w, r, nil)
   143  	if err != nil {
   144  		p.logger.Errorf("error upgrading websocket:", err)
   145  		return
   146  	}
   147  	defer func() {
   148  		err := conn.Close()
   149  		if err != nil && !IsClosedConnError(err) {
   150  			p.logger.Errorf("WS: error closing upgraded conn: %v",
   151  				err)
   152  		}
   153  	}()
   154  
   155  	ctx, cancelFn := context.WithCancel(r.Context())
   156  	defer cancelFn()
   157  
   158  	requestForwarder := newRequestForwardingReader()
   159  	request, err := http.NewRequestWithContext(
   160  		ctx, r.Method, r.URL.String(), requestForwarder,
   161  	)
   162  	if err != nil {
   163  		p.logger.Errorf("WS: error preparing request:", err)
   164  		return
   165  	}
   166  
   167  	// Allow certain headers to be forwarded, either from source headers
   168  	// or the special Sec-Websocket-Protocol header field.
   169  	forwardHeaders(r.Header, request.Header)
   170  
   171  	// Also allow the target request method to be overwritten, as all
   172  	// WebSocket establishment calls MUST be GET requests.
   173  	if m := r.URL.Query().Get(MethodOverrideParam); m != "" {
   174  		request.Method = m
   175  	}
   176  
   177  	// Is this a call to a client-streaming RPC method?
   178  	clientStreaming := false
   179  	for _, pattern := range p.clientStreamingURIs {
   180  		if pattern.MatchString(r.URL.Path) {
   181  			clientStreaming = true
   182  		}
   183  	}
   184  
   185  	responseForwarder := newResponseForwardingWriter()
   186  	go func() {
   187  		<-ctx.Done()
   188  		responseForwarder.Close()
   189  		requestForwarder.CloseWriter()
   190  	}()
   191  
   192  	go func() {
   193  		defer cancelFn()
   194  		p.backend.ServeHTTP(responseForwarder, request)
   195  	}()
   196  
   197  	// Read loop: Take messages from websocket and write them to the payload
   198  	// channel. This needs to be its own goroutine because for non-client
   199  	// streaming RPCs, the requestForwarder.Write() in the second goroutine
   200  	// will block until the request has fully completed. But for the ping/
   201  	// pong handler to work, we need to have an active call to
   202  	// conn.ReadMessage() going on. So we make sure we have such an active
   203  	// call by starting a second read as soon as the first one has
   204  	// completed.
   205  	payloadChannel := make(chan []byte, 1)
   206  	go func() {
   207  		defer cancelFn()
   208  		defer close(payloadChannel)
   209  
   210  		for {
   211  			select {
   212  			case <-ctx.Done():
   213  				return
   214  			default:
   215  			}
   216  
   217  			_, payload, err := conn.ReadMessage()
   218  			if err != nil {
   219  				if IsClosedConnError(err) {
   220  					p.logger.Tracef("WS: socket "+
   221  						"closed: %v", err)
   222  					return
   223  				}
   224  				p.logger.Errorf("error reading message: %v",
   225  					err)
   226  				return
   227  			}
   228  
   229  			select {
   230  			case payloadChannel <- payload:
   231  			case <-ctx.Done():
   232  				return
   233  			}
   234  		}
   235  	}()
   236  
   237  	// Forward loop: Take messages from the incoming payload channel and
   238  	// write them to the http request.
   239  	go func() {
   240  		defer cancelFn()
   241  		for {
   242  			var payload []byte
   243  			select {
   244  			case <-ctx.Done():
   245  				return
   246  			case newPayload, more := <-payloadChannel:
   247  				if !more {
   248  					p.logger.Infof("WS: incoming payload " +
   249  						"chan closed")
   250  					return
   251  				}
   252  
   253  				payload = newPayload
   254  			}
   255  
   256  			_, err = requestForwarder.Write(payload)
   257  			if err != nil {
   258  				p.logger.Errorf("WS: error writing message "+
   259  					"to upstream http server: %v", err)
   260  				return
   261  			}
   262  			_, _ = requestForwarder.Write([]byte{'\n'})
   263  
   264  			// The grpc-gateway library uses a different request
   265  			// reader depending on whether it is a client streaming
   266  			// RPC or not. For a non-streaming request we need to
   267  			// close with EOF to signal the request was completed.
   268  			if !clientStreaming {
   269  				requestForwarder.CloseWriter()
   270  			}
   271  		}
   272  	}()
   273  
   274  	// Ping write loop: Send a ping message regularly if ping/pong is
   275  	// enabled.
   276  	if p.pingPongEnabled() {
   277  		// We'll send out our first ping in pingInterval. So the initial
   278  		// deadline is that interval plus the time we allow for a
   279  		// response to be sent.
   280  		initialDeadline := time.Now().Add(p.pingInterval + p.pongWait)
   281  		_ = conn.SetReadDeadline(initialDeadline)
   282  
   283  		// Whenever a pong message comes in, we extend the deadline
   284  		// until the next read is expected by the interval plus pong
   285  		// wait time. Since we can never _reach_ any of the deadlines,
   286  		// we also have to advance the deadline for the next expected
   287  		// write to happen, in case the next thing we actually write is
   288  		// the next ping.
   289  		conn.SetPongHandler(func(appData string) error {
   290  			nextDeadline := time.Now().Add(
   291  				p.pingInterval + p.pongWait,
   292  			)
   293  			_ = conn.SetReadDeadline(nextDeadline)
   294  			_ = conn.SetWriteDeadline(nextDeadline)
   295  
   296  			return nil
   297  		})
   298  		go func() {
   299  			ticker := time.NewTicker(p.pingInterval)
   300  			defer ticker.Stop()
   301  
   302  			for {
   303  				select {
   304  				case <-ctx.Done():
   305  					p.logger.Debug("WS: ping loop done")
   306  					return
   307  
   308  				case <-ticker.C:
   309  					// Writing the ping shouldn't take any
   310  					// longer than we'll wait for a response
   311  					// in the first place.
   312  					writeDeadline := time.Now().Add(
   313  						p.pongWait,
   314  					)
   315  					err := conn.WriteControl(
   316  						websocket.PingMessage,
   317  						[]byte(PingContent),
   318  						writeDeadline,
   319  					)
   320  					if err != nil {
   321  						p.logger.Warnf("WS: could not "+
   322  							"send ping message: %v",
   323  							err)
   324  						return
   325  					}
   326  				}
   327  			}
   328  		}()
   329  	}
   330  
   331  	// Write loop: Take messages from the response forwarder and write them
   332  	// to the WebSocket.
   333  	for responseForwarder.Scan() {
   334  		if len(responseForwarder.Bytes()) == 0 {
   335  			p.logger.Errorf("WS: empty scan: %v",
   336  				responseForwarder.Err())
   337  
   338  			continue
   339  		}
   340  
   341  		err = conn.WriteMessage(
   342  			websocket.TextMessage, responseForwarder.Bytes(),
   343  		)
   344  		if err != nil {
   345  			p.logger.Errorf("WS: error writing message: %v", err)
   346  			return
   347  		}
   348  	}
   349  	if err := responseForwarder.Err(); err != nil && !IsClosedConnError(err) {
   350  		p.logger.Errorf("WS: scanner err: %v", err)
   351  	}
   352  }
   353  
   354  // forwardHeaders forwards certain allowed header fields from the source request
   355  // to the target request. Because browsers are limited in what header fields
   356  // they can send on the WebSocket setup call, we also allow additional fields to
   357  // be transported in the special Sec-Websocket-Protocol field.
   358  func forwardHeaders(source, target http.Header) {
   359  	// Forward allowed header fields directly.
   360  	for header := range source {
   361  		headerName := textproto.CanonicalMIMEHeaderKey(header)
   362  		forward, ok := defaultHeadersToForward[headerName]
   363  		if ok && forward {
   364  			target.Set(headerName, source.Get(header))
   365  		}
   366  	}
   367  
   368  	// Browser aren't allowed to set custom header fields on WebSocket
   369  	// requests. We need to allow them to submit the macaroon as a WS
   370  	// protocol, which is the only allowed header. Set any "protocols" we
   371  	// declare valid as header fields on the forwarded request.
   372  	protocol := source.Get(HeaderWebSocketProtocol)
   373  	for key := range defaultProtocolsToAllow {
   374  		if strings.HasPrefix(protocol, key) {
   375  			// The format is "<protocol name>+<value>". We know the
   376  			// protocol string starts with the name so we only need
   377  			// to set the value.
   378  			values := strings.Split(
   379  				protocol, WebSocketProtocolDelimiter,
   380  			)
   381  			target.Set(key, values[1])
   382  		}
   383  	}
   384  }
   385  
   386  // newRequestForwardingReader creates a new request forwarding pipe.
   387  func newRequestForwardingReader() *requestForwardingReader {
   388  	r, w := io.Pipe()
   389  	return &requestForwardingReader{
   390  		Reader: r,
   391  		Writer: w,
   392  		pipeR:  r,
   393  		pipeW:  w,
   394  	}
   395  }
   396  
   397  // requestForwardingReader is a wrapper around io.Pipe that embeds both the
   398  // io.Reader and io.Writer interface and can be closed.
   399  type requestForwardingReader struct {
   400  	io.Reader
   401  	io.Writer
   402  
   403  	pipeR *io.PipeReader
   404  	pipeW *io.PipeWriter
   405  }
   406  
   407  // CloseWriter closes the underlying pipe writer.
   408  func (r *requestForwardingReader) CloseWriter() {
   409  	_ = r.pipeW.CloseWithError(io.EOF)
   410  }
   411  
   412  // newResponseForwardingWriter creates a new http.ResponseWriter that intercepts
   413  // what's written to it and presents it through a bufio.Scanner interface.
   414  func newResponseForwardingWriter() *responseForwardingWriter {
   415  	r, w := io.Pipe()
   416  	return &responseForwardingWriter{
   417  		Writer:  w,
   418  		Scanner: bufio.NewScanner(r),
   419  		pipeR:   r,
   420  		pipeW:   w,
   421  		header:  http.Header{},
   422  		closed:  make(chan bool, 1),
   423  	}
   424  }
   425  
   426  // responseForwardingWriter is a type that implements the http.ResponseWriter
   427  // interface but internally forwards what's written to the writer through a pipe
   428  // so it can easily be read again through the bufio.Scanner interface.
   429  type responseForwardingWriter struct {
   430  	io.Writer
   431  	*bufio.Scanner
   432  
   433  	pipeR *io.PipeReader
   434  	pipeW *io.PipeWriter
   435  
   436  	header http.Header
   437  	code   int
   438  	closed chan bool
   439  }
   440  
   441  // Write writes the given bytes to the internal pipe.
   442  //
   443  // NOTE: This is part of the http.ResponseWriter interface.
   444  func (w *responseForwardingWriter) Write(b []byte) (int, error) {
   445  	return w.Writer.Write(b)
   446  }
   447  
   448  // Header returns the HTTP header fields intercepted so far.
   449  //
   450  // NOTE: This is part of the http.ResponseWriter interface.
   451  func (w *responseForwardingWriter) Header() http.Header {
   452  	return w.header
   453  }
   454  
   455  // WriteHeader indicates that the header part of the response is now finished
   456  // and sets the response code.
   457  //
   458  // NOTE: This is part of the http.ResponseWriter interface.
   459  func (w *responseForwardingWriter) WriteHeader(code int) {
   460  	w.code = code
   461  }
   462  
   463  // CloseNotify returns a channel that indicates if a connection was closed.
   464  //
   465  // NOTE: This is part of the http.CloseNotifier interface.
   466  func (w *responseForwardingWriter) CloseNotify() <-chan bool {
   467  	return w.closed
   468  }
   469  
   470  // Flush empties all buffers. We implement this to indicate to our backend that
   471  // we support flushing our content. There is no actual implementation because
   472  // all writes happen immediately, there is no internal buffering.
   473  //
   474  // NOTE: This is part of the http.Flusher interface.
   475  func (w *responseForwardingWriter) Flush() {}
   476  
   477  func (w *responseForwardingWriter) Close() {
   478  	_ = w.pipeR.CloseWithError(io.EOF)
   479  	_ = w.pipeW.CloseWithError(io.EOF)
   480  	w.closed <- true
   481  }
   482  
   483  // IsClosedConnError is a helper function that returns true if the given error
   484  // is an error indicating we are using a closed connection.
   485  func IsClosedConnError(err error) bool {
   486  	if err == nil {
   487  		return false
   488  	}
   489  	if err == http.ErrServerClosed {
   490  		return true
   491  	}
   492  
   493  	str := err.Error()
   494  	if strings.Contains(str, "use of closed network connection") {
   495  		return true
   496  	}
   497  	if strings.Contains(str, "closed pipe") {
   498  		return true
   499  	}
   500  	if strings.Contains(str, "broken pipe") {
   501  		return true
   502  	}
   503  	if strings.Contains(str, "connection reset by peer") {
   504  		return true
   505  	}
   506  	return websocket.IsCloseError(
   507  		err, websocket.CloseNormalClosure, websocket.CloseGoingAway,
   508  	)
   509  }