github.com/decred/dcrlnd@v0.7.6/lnrpc/websocket_proxy.go (about) 1 // The code in this file is a heavily modified version of 2 // https://github.com/tmc/grpc-websocket-proxy/ 3 4 package lnrpc 5 6 import ( 7 "bufio" 8 "io" 9 "net/http" 10 "net/textproto" 11 "regexp" 12 "strings" 13 "time" 14 15 "github.com/decred/slog" 16 "github.com/gorilla/websocket" 17 "golang.org/x/net/context" 18 ) 19 20 const ( 21 // MethodOverrideParam is the GET query parameter that specifies what 22 // HTTP request method should be used for the forwarded REST request. 23 // This is necessary because the WebSocket API specifies that a 24 // handshake request must always be done through a GET request. 25 MethodOverrideParam = "method" 26 27 // HeaderWebSocketProtocol is the name of the WebSocket protocol 28 // exchange header field that we use to transport additional header 29 // fields. 30 HeaderWebSocketProtocol = "Sec-Websocket-Protocol" 31 32 // WebSocketProtocolDelimiter is the delimiter we use between the 33 // additional header field and its value. We use the plus symbol because 34 // the default delimiters aren't allowed in the protocol names. 35 WebSocketProtocolDelimiter = "+" 36 37 // PingContent is the content of the ping message we send out. This is 38 // an arbitrary non-empty message that has no deeper meaning but should 39 // be sent back by the client in the pong message. 40 PingContent = "are you there?" 41 ) 42 43 var ( 44 // defaultHeadersToForward is a map of all HTTP header fields that are 45 // forwarded by default. The keys must be in the canonical MIME header 46 // format. 47 defaultHeadersToForward = map[string]bool{ 48 "Origin": true, 49 "Referer": true, 50 "Grpc-Metadata-Macaroon": true, 51 } 52 53 // defaultProtocolsToAllow are additional header fields that we allow 54 // to be transported inside of the Sec-Websocket-Protocol field to be 55 // forwarded to the backend. 
56 defaultProtocolsToAllow = map[string]bool{ 57 "Grpc-Metadata-Macaroon": true, 58 } 59 60 // DefaultPingInterval is the default number of seconds to wait between 61 // sending ping requests. 62 DefaultPingInterval = time.Second * 30 63 64 // DefaultPongWait is the maximum duration we wait for a pong response 65 // to a ping we sent before we assume the connection died. 66 DefaultPongWait = time.Second * 5 67 ) 68 69 // NewWebSocketProxy attempts to expose the underlying handler as a response- 70 // streaming WebSocket stream with newline-delimited JSON as the content 71 // encoding. If pingInterval is a non-zero duration, a ping message will be 72 // sent out periodically and a pong response message is expected from the 73 // client. The clientStreamingURIs parameter can hold a list of all patterns 74 // for URIs that are mapped to client-streaming RPC methods. We need to keep 75 // track of those to make sure we initialize the request body correctly for the 76 // underlying grpc-gateway library. 77 func NewWebSocketProxy(h http.Handler, logger slog.Logger, 78 pingInterval, pongWait time.Duration, 79 clientStreamingURIs []*regexp.Regexp) http.Handler { 80 81 p := &WebsocketProxy{ 82 backend: h, 83 logger: logger, 84 upgrader: &websocket.Upgrader{ 85 ReadBufferSize: 1024, 86 WriteBufferSize: 1024, 87 CheckOrigin: func(r *http.Request) bool { 88 return true 89 }, 90 }, 91 clientStreamingURIs: clientStreamingURIs, 92 } 93 94 if pingInterval > 0 && pongWait > 0 { 95 p.pingInterval = pingInterval 96 p.pongWait = pongWait 97 } 98 99 return p 100 } 101 102 // WebsocketProxy provides websocket transport upgrade to compatible endpoints. 103 type WebsocketProxy struct { 104 backend http.Handler 105 logger slog.Logger 106 upgrader *websocket.Upgrader 107 108 // clientStreamingURIs holds a list of all patterns for URIs that are 109 // mapped to client-streaming RPC methods. 
We need to keep track of 110 // those to make sure we initialize the request body correctly for the 111 // underlying grpc-gateway library. 112 clientStreamingURIs []*regexp.Regexp 113 114 pingInterval time.Duration 115 pongWait time.Duration 116 } 117 118 // pingPongEnabled returns true if a ping interval is set to enable sending and 119 // expecting regular ping/pong messages. 120 func (p *WebsocketProxy) pingPongEnabled() bool { 121 return p.pingInterval > 0 && p.pongWait > 0 122 } 123 124 // ServeHTTP handles the incoming HTTP request. If the request is an 125 // "upgradeable" WebSocket request (identified by header fields), then the 126 // WS proxy handles the request. Otherwise the request is passed directly to the 127 // underlying REST proxy. 128 func (p *WebsocketProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { 129 if !websocket.IsWebSocketUpgrade(r) { 130 p.backend.ServeHTTP(w, r) 131 return 132 } 133 p.upgradeToWebSocketProxy(w, r) 134 } 135 136 // upgradeToWebSocketProxy upgrades the incoming request to a WebSocket, reads 137 // one incoming message then streams all responses until either the client or 138 // server quit the connection. 
139 func (p *WebsocketProxy) upgradeToWebSocketProxy(w http.ResponseWriter, 140 r *http.Request) { 141 142 conn, err := p.upgrader.Upgrade(w, r, nil) 143 if err != nil { 144 p.logger.Errorf("error upgrading websocket:", err) 145 return 146 } 147 defer func() { 148 err := conn.Close() 149 if err != nil && !IsClosedConnError(err) { 150 p.logger.Errorf("WS: error closing upgraded conn: %v", 151 err) 152 } 153 }() 154 155 ctx, cancelFn := context.WithCancel(r.Context()) 156 defer cancelFn() 157 158 requestForwarder := newRequestForwardingReader() 159 request, err := http.NewRequestWithContext( 160 ctx, r.Method, r.URL.String(), requestForwarder, 161 ) 162 if err != nil { 163 p.logger.Errorf("WS: error preparing request:", err) 164 return 165 } 166 167 // Allow certain headers to be forwarded, either from source headers 168 // or the special Sec-Websocket-Protocol header field. 169 forwardHeaders(r.Header, request.Header) 170 171 // Also allow the target request method to be overwritten, as all 172 // WebSocket establishment calls MUST be GET requests. 173 if m := r.URL.Query().Get(MethodOverrideParam); m != "" { 174 request.Method = m 175 } 176 177 // Is this a call to a client-streaming RPC method? 178 clientStreaming := false 179 for _, pattern := range p.clientStreamingURIs { 180 if pattern.MatchString(r.URL.Path) { 181 clientStreaming = true 182 } 183 } 184 185 responseForwarder := newResponseForwardingWriter() 186 go func() { 187 <-ctx.Done() 188 responseForwarder.Close() 189 requestForwarder.CloseWriter() 190 }() 191 192 go func() { 193 defer cancelFn() 194 p.backend.ServeHTTP(responseForwarder, request) 195 }() 196 197 // Read loop: Take messages from websocket and write them to the payload 198 // channel. This needs to be its own goroutine because for non-client 199 // streaming RPCs, the requestForwarder.Write() in the second goroutine 200 // will block until the request has fully completed. 
But for the ping/ 201 // pong handler to work, we need to have an active call to 202 // conn.ReadMessage() going on. So we make sure we have such an active 203 // call by starting a second read as soon as the first one has 204 // completed. 205 payloadChannel := make(chan []byte, 1) 206 go func() { 207 defer cancelFn() 208 defer close(payloadChannel) 209 210 for { 211 select { 212 case <-ctx.Done(): 213 return 214 default: 215 } 216 217 _, payload, err := conn.ReadMessage() 218 if err != nil { 219 if IsClosedConnError(err) { 220 p.logger.Tracef("WS: socket "+ 221 "closed: %v", err) 222 return 223 } 224 p.logger.Errorf("error reading message: %v", 225 err) 226 return 227 } 228 229 select { 230 case payloadChannel <- payload: 231 case <-ctx.Done(): 232 return 233 } 234 } 235 }() 236 237 // Forward loop: Take messages from the incoming payload channel and 238 // write them to the http request. 239 go func() { 240 defer cancelFn() 241 for { 242 var payload []byte 243 select { 244 case <-ctx.Done(): 245 return 246 case newPayload, more := <-payloadChannel: 247 if !more { 248 p.logger.Infof("WS: incoming payload " + 249 "chan closed") 250 return 251 } 252 253 payload = newPayload 254 } 255 256 _, err = requestForwarder.Write(payload) 257 if err != nil { 258 p.logger.Errorf("WS: error writing message "+ 259 "to upstream http server: %v", err) 260 return 261 } 262 _, _ = requestForwarder.Write([]byte{'\n'}) 263 264 // The grpc-gateway library uses a different request 265 // reader depending on whether it is a client streaming 266 // RPC or not. For a non-streaming request we need to 267 // close with EOF to signal the request was completed. 268 if !clientStreaming { 269 requestForwarder.CloseWriter() 270 } 271 } 272 }() 273 274 // Ping write loop: Send a ping message regularly if ping/pong is 275 // enabled. 276 if p.pingPongEnabled() { 277 // We'll send out our first ping in pingInterval. 
So the initial 278 // deadline is that interval plus the time we allow for a 279 // response to be sent. 280 initialDeadline := time.Now().Add(p.pingInterval + p.pongWait) 281 _ = conn.SetReadDeadline(initialDeadline) 282 283 // Whenever a pong message comes in, we extend the deadline 284 // until the next read is expected by the interval plus pong 285 // wait time. Since we can never _reach_ any of the deadlines, 286 // we also have to advance the deadline for the next expected 287 // write to happen, in case the next thing we actually write is 288 // the next ping. 289 conn.SetPongHandler(func(appData string) error { 290 nextDeadline := time.Now().Add( 291 p.pingInterval + p.pongWait, 292 ) 293 _ = conn.SetReadDeadline(nextDeadline) 294 _ = conn.SetWriteDeadline(nextDeadline) 295 296 return nil 297 }) 298 go func() { 299 ticker := time.NewTicker(p.pingInterval) 300 defer ticker.Stop() 301 302 for { 303 select { 304 case <-ctx.Done(): 305 p.logger.Debug("WS: ping loop done") 306 return 307 308 case <-ticker.C: 309 // Writing the ping shouldn't take any 310 // longer than we'll wait for a response 311 // in the first place. 312 writeDeadline := time.Now().Add( 313 p.pongWait, 314 ) 315 err := conn.WriteControl( 316 websocket.PingMessage, 317 []byte(PingContent), 318 writeDeadline, 319 ) 320 if err != nil { 321 p.logger.Warnf("WS: could not "+ 322 "send ping message: %v", 323 err) 324 return 325 } 326 } 327 } 328 }() 329 } 330 331 // Write loop: Take messages from the response forwarder and write them 332 // to the WebSocket. 
333 for responseForwarder.Scan() { 334 if len(responseForwarder.Bytes()) == 0 { 335 p.logger.Errorf("WS: empty scan: %v", 336 responseForwarder.Err()) 337 338 continue 339 } 340 341 err = conn.WriteMessage( 342 websocket.TextMessage, responseForwarder.Bytes(), 343 ) 344 if err != nil { 345 p.logger.Errorf("WS: error writing message: %v", err) 346 return 347 } 348 } 349 if err := responseForwarder.Err(); err != nil && !IsClosedConnError(err) { 350 p.logger.Errorf("WS: scanner err: %v", err) 351 } 352 } 353 354 // forwardHeaders forwards certain allowed header fields from the source request 355 // to the target request. Because browsers are limited in what header fields 356 // they can send on the WebSocket setup call, we also allow additional fields to 357 // be transported in the special Sec-Websocket-Protocol field. 358 func forwardHeaders(source, target http.Header) { 359 // Forward allowed header fields directly. 360 for header := range source { 361 headerName := textproto.CanonicalMIMEHeaderKey(header) 362 forward, ok := defaultHeadersToForward[headerName] 363 if ok && forward { 364 target.Set(headerName, source.Get(header)) 365 } 366 } 367 368 // Browser aren't allowed to set custom header fields on WebSocket 369 // requests. We need to allow them to submit the macaroon as a WS 370 // protocol, which is the only allowed header. Set any "protocols" we 371 // declare valid as header fields on the forwarded request. 372 protocol := source.Get(HeaderWebSocketProtocol) 373 for key := range defaultProtocolsToAllow { 374 if strings.HasPrefix(protocol, key) { 375 // The format is "<protocol name>+<value>". We know the 376 // protocol string starts with the name so we only need 377 // to set the value. 378 values := strings.Split( 379 protocol, WebSocketProtocolDelimiter, 380 ) 381 target.Set(key, values[1]) 382 } 383 } 384 } 385 386 // newRequestForwardingReader creates a new request forwarding pipe. 
func newRequestForwardingReader() *requestForwardingReader {
	r, w := io.Pipe()
	return &requestForwardingReader{
		Reader: r,
		Writer: w,
		pipeR:  r,
		pipeW:  w,
	}
}

// requestForwardingReader is a wrapper around io.Pipe that embeds both the
// io.Reader and io.Writer interface and can be closed.
type requestForwardingReader struct {
	io.Reader
	io.Writer

	pipeR *io.PipeReader
	pipeW *io.PipeWriter
}

// CloseWriter closes the underlying pipe writer with io.EOF, so the reading
// side sees a regular end of stream once buffered data has been consumed.
func (r *requestForwardingReader) CloseWriter() {
	_ = r.pipeW.CloseWithError(io.EOF)
}

// wsProxyMaxResponseLineSize is the largest single newline-delimited response
// message the WebSocket proxy will scan from the backend. The bufio.Scanner
// default (bufio.MaxScanTokenSize, 64 KiB) is too small for large RPC
// responses. A longer line makes Scan fail with bufio.ErrTooLong and ends the
// stream.
const wsProxyMaxResponseLineSize = 4 * 1024 * 1024

// newResponseForwardingWriter creates a new http.ResponseWriter that intercepts
// what's written to it and presents it through a bufio.Scanner interface.
func newResponseForwardingWriter() *responseForwardingWriter {
	r, w := io.Pipe()

	// Passing a nil buffer lets the scanner start with its usual small
	// allocation and grow it on demand, up to the raised maximum.
	scanner := bufio.NewScanner(r)
	scanner.Buffer(nil, wsProxyMaxResponseLineSize)

	return &responseForwardingWriter{
		Writer:  w,
		Scanner: scanner,
		pipeR:   r,
		pipeW:   w,
		header:  http.Header{},
		closed:  make(chan bool, 1),
	}
}

// responseForwardingWriter is a type that implements the http.ResponseWriter
// interface but internally forwards what's written to the writer through a pipe
// so it can easily be read again through the bufio.Scanner interface.
type responseForwardingWriter struct {
	io.Writer
	*bufio.Scanner

	pipeR *io.PipeReader
	pipeW *io.PipeWriter

	header http.Header
	code   int
	closed chan bool
}

// Write writes the given bytes to the internal pipe. It blocks until the
// scanning side has consumed the data or the pipe is closed.
//
// NOTE: This is part of the http.ResponseWriter interface.
func (w *responseForwardingWriter) Write(b []byte) (int, error) {
	return w.Writer.Write(b)
}

// Header returns the HTTP header fields intercepted so far.
//
// NOTE: This is part of the http.ResponseWriter interface.
451 func (w *responseForwardingWriter) Header() http.Header { 452 return w.header 453 } 454 455 // WriteHeader indicates that the header part of the response is now finished 456 // and sets the response code. 457 // 458 // NOTE: This is part of the http.ResponseWriter interface. 459 func (w *responseForwardingWriter) WriteHeader(code int) { 460 w.code = code 461 } 462 463 // CloseNotify returns a channel that indicates if a connection was closed. 464 // 465 // NOTE: This is part of the http.CloseNotifier interface. 466 func (w *responseForwardingWriter) CloseNotify() <-chan bool { 467 return w.closed 468 } 469 470 // Flush empties all buffers. We implement this to indicate to our backend that 471 // we support flushing our content. There is no actual implementation because 472 // all writes happen immediately, there is no internal buffering. 473 // 474 // NOTE: This is part of the http.Flusher interface. 475 func (w *responseForwardingWriter) Flush() {} 476 477 func (w *responseForwardingWriter) Close() { 478 _ = w.pipeR.CloseWithError(io.EOF) 479 _ = w.pipeW.CloseWithError(io.EOF) 480 w.closed <- true 481 } 482 483 // IsClosedConnError is a helper function that returns true if the given error 484 // is an error indicating we are using a closed connection. 485 func IsClosedConnError(err error) bool { 486 if err == nil { 487 return false 488 } 489 if err == http.ErrServerClosed { 490 return true 491 } 492 493 str := err.Error() 494 if strings.Contains(str, "use of closed network connection") { 495 return true 496 } 497 if strings.Contains(str, "closed pipe") { 498 return true 499 } 500 if strings.Contains(str, "broken pipe") { 501 return true 502 } 503 if strings.Contains(str, "connection reset by peer") { 504 return true 505 } 506 return websocket.IsCloseError( 507 err, websocket.CloseNormalClosure, websocket.CloseGoingAway, 508 ) 509 }