decred.org/dcrdex@v1.0.5/dex/ws/wslink.go (about)

     1  // This code is available on the terms of the project LICENSE.md file,
     2  // also available online at https://blueoakcouncil.org/license/1.0.0.
     3  
     4  package ws
     5  
     6  import (
     7  	"context"
     8  	"encoding/json"
     9  	"errors"
    10  	"fmt"
    11  	"net"
    12  	"net/http"
    13  	"runtime/debug"
    14  	"strings"
    15  	"sync"
    16  	"sync/atomic"
    17  	"time"
    18  
    19  	"decred.org/dcrdex/dex"
    20  	"decred.org/dcrdex/dex/msgjson"
    21  	"github.com/gorilla/websocket"
    22  )
    23  
    24  const (
    25  	// outBufferSize is the size of the WSLink's buffered channel for outgoing
    26  	// messages.
    27  	outBufferSize    = 128
    28  	defaultReadLimit = 8192
    29  	writeWait        = 5 * time.Second
    30  	// ErrPeerDisconnected will be returned if Send or Request is called on a
    31  	// disconnected link.
    32  	ErrPeerDisconnected = dex.ErrorKind("peer disconnected")
    33  
    34  	ErrHandshake = dex.ErrorKind("handshake error")
    35  )
    36  
    37  // websocket.Upgrader is the preferred method of upgrading a request to a
    38  // websocket connection.
    39  var upgrader = websocket.Upgrader{}
    40  
    41  // Connection represents a websocket connection to a remote peer. In practice,
    42  // it is satisfied by *websocket.Conn. For testing, a stub can be used.
    43  type Connection interface {
    44  	Close() error
    45  
    46  	SetReadDeadline(t time.Time) error
    47  	ReadMessage() (int, []byte, error)
    48  	SetReadLimit(limit int64)
    49  
    50  	SetWriteDeadline(t time.Time) error
    51  	WriteMessage(int, []byte) error
    52  	WriteControl(messageType int, data []byte, deadline time.Time) error
    53  }
    54  
// WSLink is the local, per-connection representation of a DEX peer (client or
// server) connection. Create one with NewWSLink, start it with Connect, and
// shut it down with Disconnect.
type WSLink struct {
	// log is the WSLink's logger
	log dex.Logger
	// addr is a string representation of the peer's IP address
	addr string
	// conn is the gorilla websocket.Conn, or a stub for testing.
	conn Connection
	// on is accessed atomically: Connect swaps it 0 -> 1 and stop swaps it
	// 1 -> 0, so shutdown (and closing of the underlying connection) happens
	// at most once per Connect.
	on uint32
	// quit cancels the Context given to the link's goroutines. Set by Connect.
	quit context.CancelFunc
	// stopped is closed by stop just before quit is called, signaling senders
	// and Done callers. Set by Connect; nil before then.
	stopped chan struct{}
	// outChan is used to sequence sent messages. It is buffered with
	// outBufferSize slots by NewWSLink.
	outChan chan *sendData
	// wg tracks the 3 goroutines started by Connect: inHandler (reads),
	// outHandler (writes), and pingHandler (pings). It is used to synchronize
	// cleanup on disconnection and is returned to the caller of Connect.
	wg sync.WaitGroup
	// A master message handler. Any non-nil *msgjson.Error it returns is sent
	// back to the peer by handleMessage.
	handler func(*msgjson.Message) *msgjson.Error
	// pingPeriod is how often to ping the peer.
	pingPeriod time.Duration

	// RawHandler, if non-nil, receives the raw bytes of every incoming
	// message in place of JSON decoding and dispatch to handler.
	RawHandler func([]byte)
}
    84  
    85  type sendData struct {
    86  	data []byte
    87  	ret  chan<- error
    88  }
    89  
    90  // NewWSLink is a constructor for a new WSLink.
    91  func NewWSLink(addr string, conn Connection, pingPeriod time.Duration, handler func(*msgjson.Message) *msgjson.Error, logger dex.Logger) *WSLink {
    92  	return &WSLink{
    93  		addr:       addr,
    94  		log:        logger,
    95  		conn:       conn,
    96  		outChan:    make(chan *sendData, outBufferSize),
    97  		pingPeriod: pingPeriod,
    98  		handler:    handler,
    99  	}
   100  }
   101  
   102  // Send sends the passed Message to the websocket peer. The actual writing of
   103  // the message on the peer's link occurs asynchronously. As such, a nil error
   104  // only indicates that the link is believed to be up and the message was
   105  // successfully marshalled.
   106  func (c *WSLink) Send(msg *msgjson.Message) error {
   107  	return c.send(msg, nil)
   108  }
   109  
   110  // SendRaw sends the passed bytes to the websocket peer. The actual writing of
   111  // the message on the peer's link occurs asynchronously. As such, a nil error
   112  // only indicates that the link is believed to be up.
   113  func (c *WSLink) SendRaw(b []byte) error {
   114  	if c.Off() {
   115  		return ErrPeerDisconnected
   116  	}
   117  	return c.sendRaw(b, nil)
   118  }
   119  
   120  // SendNow is like send, but it waits for the message to be written on the
   121  // peer's link, returning any error from the write.
   122  func (c *WSLink) SendNow(msg *msgjson.Message) error {
   123  	writeErrChan := make(chan error, 1)
   124  	if err := c.send(msg, writeErrChan); err != nil {
   125  		return err
   126  	}
   127  	return <-writeErrChan
   128  }
   129  
   130  // sendRaw sends raw bytes to a peer. Whether or not the peer is connected
   131  // should be checked before calling.
   132  func (c *WSLink) sendRaw(b []byte, writeErr chan<- error) error {
   133  	// NOTE: Without the stopped chan or access to the Context we are now
   134  	// racing after the c.Off check in the caller.
   135  	select {
   136  	case c.outChan <- &sendData{b, writeErr}:
   137  	case <-c.stopped:
   138  		return ErrPeerDisconnected
   139  	}
   140  
   141  	return nil
   142  }
   143  
   144  func (c *WSLink) send(msg *msgjson.Message, writeErr chan<- error) error {
   145  	if c.Off() {
   146  		return ErrPeerDisconnected
   147  	}
   148  	b, err := json.Marshal(msg)
   149  	if err != nil {
   150  		return err
   151  	}
   152  
   153  	return c.sendRaw(b, writeErr)
   154  }
   155  
   156  // SendError sends the msgjson.Error to the peer in a ResponsePayload.
   157  func (c *WSLink) SendError(id uint64, rpcErr *msgjson.Error) {
   158  	msg, err := msgjson.NewResponse(id, nil, rpcErr)
   159  	if err != nil {
   160  		c.log.Errorf("SendError: failed to create error message %q: %v", rpcErr.Message, err)
   161  	}
   162  	err = c.Send(msg)
   163  	if err != nil {
   164  		c.log.Debugf("SendError: failed to send message to peer %s: %v", c.addr, err)
   165  	}
   166  }
   167  
   168  // Connect begins processing input and output messages. Do not send messages
   169  // until connected.
   170  func (c *WSLink) Connect(ctx context.Context) (*sync.WaitGroup, error) {
   171  	// Set the initial read deadline now that the ping ticker is about to be
   172  	// started. The pong handler will set subsequent read deadlines. 2x ping
   173  	// period is a very generous initial pong wait; the readWait provided to
   174  	// NewConnection could be stored and used here (once) instead.
   175  	if !atomic.CompareAndSwapUint32(&c.on, 0, 1) {
   176  		return nil, fmt.Errorf("attempted to Start a running WSLink")
   177  	}
   178  	linkCtx, quit := context.WithCancel(ctx)
   179  
   180  	// Note that there is a brief window where c.on is true but quit and stopped
   181  	// are not set.
   182  	c.quit = quit
   183  	c.stopped = make(chan struct{}) // control signal to block send
   184  	err := c.conn.SetReadDeadline(time.Now().Add(c.pingPeriod * 2))
   185  	if err != nil {
   186  		return nil, fmt.Errorf("failed to set initial read deadline for %v: %w", c.addr, err)
   187  	}
   188  
   189  	c.log.Tracef("Starting websocket messaging with peer %s", c.addr)
   190  	// Start processing input and output.
   191  	c.wg.Add(3)
   192  	go c.inHandler(linkCtx)
   193  	go c.outHandler(linkCtx)
   194  	go c.pingHandler(linkCtx)
   195  	return &c.wg, nil
   196  }
   197  
   198  func (c *WSLink) stop() {
   199  	// Flip the switch into the off position and cancel the context.
   200  	if !atomic.CompareAndSwapUint32(&c.on, 1, 0) {
   201  		return
   202  	}
   203  	// Signal to senders we are done.
   204  	close(c.stopped)
   205  	// Begin shutdown of goroutines, and ultimately connection closure.
   206  	c.quit()
   207  }
   208  
   209  // Done returns a channel that is closed when the link goes down.
   210  func (c *WSLink) Done() <-chan struct{} {
   211  	// Only call Done after connect.
   212  	return c.stopped
   213  }
   214  
   215  // Disconnect begins shutdown of the WSLink, preventing new messages from
   216  // entering the outgoing queue, and ultimately closing the underlying connection
   217  // when all queued messages have been handled. This shutdown process is complete
   218  // when the WaitGroup returned by Connect is Done.
   219  func (c *WSLink) Disconnect() {
   220  	// Cancel the Context and close the stopped channel if not already done.
   221  	c.stop() // false if already disconnected
   222  	// NOTE: outHandler closes the c.conn on its return.
   223  }
   224  
   225  // handleMessage wraps the configured message handler so that it recovers from
   226  // panics and responds to the peer.
   227  func (c *WSLink) handleMessage(msg *msgjson.Message) {
   228  	defer func() {
   229  		if pv := recover(); pv != nil {
   230  			c.log.Criticalf("Uh-oh! Panic while handling message from %v.\n\n"+
   231  				"Message:\n\n%#v\n\nPanic:\n\n%v\n\nStack:\n\n%v\n\n", c.addr, msg, pv, string(debug.Stack()))
   232  			if msg.Type == msgjson.Request {
   233  				c.SendError(msg.ID, msgjson.NewError(msgjson.RPCInternalError, "internal error"))
   234  			}
   235  		}
   236  	}()
   237  
   238  	rpcErr := c.handler(msg)
   239  	if rpcErr != nil {
   240  		// TODO: figure out how to fix this not making sense when the msg is
   241  		// a response, not a request!
   242  		c.SendError(msg.ID, rpcErr)
   243  	}
   244  }
   245  
   246  // inHandler handles all incoming messages for the websocket connection. It must
   247  // be run as a goroutine.
   248  func (c *WSLink) inHandler(ctx context.Context) {
   249  	// Ensure the connection is closed.
   250  	defer c.wg.Done()
   251  	defer c.stop()
   252  out:
   253  	for {
   254  		// Quit when the context is closed.
   255  		if ctx.Err() != nil {
   256  			break out
   257  		}
   258  		// Block until a message is received or an error occurs.
   259  		_, msgBytes, err := c.conn.ReadMessage()
   260  		if err != nil {
   261  			// Only log the error if it is unexpected (not a disconnect).
   262  			if websocket.IsCloseError(err, websocket.CloseGoingAway,
   263  				websocket.CloseNormalClosure, websocket.CloseNoStatusReceived) {
   264  				break out // clean Close from client
   265  			}
   266  			var opErr *net.OpError
   267  			if errors.As(err, &opErr) && opErr.Op == "read" &&
   268  				(strings.Contains(opErr.Err.Error(), "use of closed network connection") || // we hung up
   269  					strings.Contains(opErr.Err.Error(), "connection reset by peer")) { // they hung up
   270  				break out
   271  			}
   272  
   273  			c.log.Errorf("Websocket receive error from peer %s: %v (%T)", c.addr, err, err)
   274  			break out
   275  		}
   276  
   277  		if c.RawHandler != nil {
   278  			c.RawHandler(msgBytes)
   279  			continue
   280  		}
   281  
   282  		// Attempt to unmarshal the request. Only requests that successfully decode
   283  		// will be accepted by the server, though failure to decode does not force
   284  		// a disconnect.
   285  		msg := new(msgjson.Message)
   286  		err = json.Unmarshal(msgBytes, msg)
   287  		if err != nil {
   288  			c.SendError(1, msgjson.NewError(msgjson.RPCParseError, "failed to parse message"))
   289  			continue
   290  		}
   291  		if (msg.Type == msgjson.Request || msg.Type == msgjson.Response) && msg.ID == 0 { // also covers msgBytes []byte("null")
   292  			c.SendError(1, msgjson.NewError(msgjson.RPCParseError, "request and response ids cannot be zero"))
   293  			continue
   294  		}
   295  		c.handleMessage(msg)
   296  	}
   297  }
   298  
// outHandler is the only goroutine (together with the writer goroutine it
// spawns) that writes application messages to c.conn. It moves each
// *sendData received on c.outChan into a local queue. The child writer
// goroutine pops that queue front-first and writes the entries in order. It
// must be run as a goroutine and is counted in c.wg.
//
// Deferred calls run in reverse registration order, so on return it:
//  1. waits for the writer goroutine,
//  2. attempts to write everything still in the queue or buffered in c.outChan,
//  3. calls c.stop,
//  4. sends a websocket Close control frame unless a write already failed,
//  5. closes c.conn, and
//  6. marks c.wg Done.
func (c *WSLink) outHandler(ctx context.Context) {
	// Ensure the connection is closed.
	defer c.wg.Done()
	defer c.conn.Close() // close the Conn
	var writeFailed bool
	defer func() {
		// Unless we are returning because of a write error, try to send a Close
		// control message before closing the connection.
		if writeFailed {
			c.log.Debugf("Connection already dead. Not sending Close control message.")
			return
		}
		_ = c.conn.WriteControl(websocket.CloseMessage,
			websocket.FormatCloseMessage(websocket.CloseNormalClosure, "bye"),
			time.Now().Add(time.Second))
	}()
	defer c.stop() // in the event of context cancellation vs Disconnect call

	// Synchronize access to the output queue and the trigger channel.
	var mtx sync.Mutex
	outQueue := make([]*sendData, 0, 128)
	// buffer length 1 since the writer loop triggers itself.
	// Invariant (maintained under mtx): a token is sent only when the queue
	// goes from empty to non-empty (main loop) or is still non-empty after a
	// pop (writer), so at most one token is ever pending and the sends below
	// never block.
	trigger := make(chan struct{}, 1)

	// Relay a write error to senders waiting for one.
	relayError := func(errChan chan<- error, err error) {
		if errChan != nil {
			errChan <- err
		}
	}

	// write writes one queued message and reports the outcome to its sender,
	// if any. It is called from the writer goroutine while the link runs, and
	// from the shutdown drain only after that goroutine has exited, so
	// writeFailed and the counters are never accessed concurrently.
	var writeCount, lostCount int
	write := func(sd *sendData) {
		// If the link is shutting down with previous write errors, skip
		// attempting to send and reply to the sender with an error.
		if writeFailed {
			lostCount++
			relayError(sd.ret, errors.New("connection closed"))
			return
		}
		c.conn.SetWriteDeadline(time.Now().Add(writeWait))
		err := c.conn.WriteMessage(websocket.TextMessage, sd.data)
		if err != nil {
			lostCount++
			relayError(sd.ret, err)
			// The connection is now considered dead: No more Sends should queue
			// messages, goroutines should return gracefully, queued messages
			// will error quickly, and shutdown will not try to send a Close
			// control frame.
			writeFailed = true
			c.stop()
			return
		}
		writeCount++
		// Closing (rather than sending nil) unblocks a SendNow waiter with a
		// nil error.
		if sd.ret != nil {
			close(sd.ret)
		}
	}

	// On shutdown, process any queued senders before closing the connection, if
	// it is still up.
	//
	// NOTE(review): when ctx is cancelled by the parent rather than via stop,
	// c.on is still 1 while this drain runs (c.stop is deferred to run after
	// it). A sender can therefore still enqueue into c.outChan after the drain
	// loop has finished. A SendNow caller in that window might then wait on
	// its result channel indefinitely. Confirm whether this matters in
	// practice.
	defer func() {
		// Send any messages in the outQueue or outChan. First drain the
		// buffered channel of data sent prior to stop, but before it could be
		// put in the outQueue.
	out:
		for {
			select {
			case sd := <-c.outChan:
				outQueue = append(outQueue, sd)
			default:
				break out
			}
		}
		// Attempt sending all queued outgoing messages.
		for _, sd := range outQueue {
			write(sd)
		}
		// NOTE: This also addresses a full trigger channel, but there is no
		// need to drain it, just the outQueue so SendNow never hangs.

		c.log.Tracef("Sent %d and dropped %d messages to %v before shutdown.",
			writeCount, lostCount, c.addr)
	}()

	// Top of defer stack: before clean-up, wait for writer goroutine
	var wg sync.WaitGroup
	defer wg.Wait()

	// Writer goroutine: on each trigger token, pop the front of the queue and
	// write it outside the lock so that queuing is not blocked by a slow write.
	wg.Add(1)
	go func() {
		defer wg.Done()

		for {
			select {
			case <-ctx.Done():
				return
			case <-trigger:
				mtx.Lock()
				// pop front
				sd := outQueue[0]
				//outQueue[0] = nil // allow realloc w/o this element
				//outQueue = outQueue[1:] // reduces length *and* capacity, but no copy now
				// Or, to reduce or eliminate reallocs at the expense of frequent copies:
				copy(outQueue, outQueue[1:])
				outQueue[len(outQueue)-1] = nil
				outQueue = outQueue[:len(outQueue)-1]
				if len(outQueue) > 0 {
					trigger <- struct{}{}
				}
				// len(outQueue) may be longer when we get back here, but only
				// this loop reduces it.
				mtx.Unlock()
				write(sd)
			}
		}
	}()

	// Main loop: move messages from c.outChan into the queue until the
	// Context is cancelled.
	for {
		select {
		case <-ctx.Done():
			return
		case sd := <-c.outChan:
			mtx.Lock()
			// push back
			initCap := cap(outQueue)
			outQueue = append(outQueue, sd)
			if newCap := cap(outQueue); newCap > initCap {
				c.log.Infof("Outgoing message queue capacity increased from %d to %d for %v.",
					initCap, newCap, c.addr)
				// The capacity 7168 is a heuristic for when the slice shift on
				// the pop front operation starts to become a performance issue.
				// It is also a reasonable queue size limitation to prevent
				// excessive memory use. If there are thousands of queued
				// messages, something is wrong with the client, or the server
				// is spamming excessively.
				if newCap >= 7168 {
					c.log.Warnf("Stopping client %v with outgoing message queue of length %d, capacity %d",
						c.addr, len(outQueue), newCap)
					c.stop()
				}
			}
			// If we just repopulated an empty queue, trigger the writer,
			// otherwise the writer will trigger itself until the queue is
			// empty.
			if len(outQueue) == 1 {
				trigger <- struct{}{}
			} // else, len>1 and writer will self trigger
			mtx.Unlock()
		}
	}
}
   451  
   452  // pingHandler sends periodic pings to the client.
   453  func (c *WSLink) pingHandler(ctx context.Context) {
   454  	defer c.wg.Done()
   455  	ticker := time.NewTicker(c.pingPeriod)
   456  	defer ticker.Stop()
   457  	ping := []byte{}
   458  out:
   459  	for {
   460  		// Send any messages ready for send until the quit channel is
   461  		// closed.
   462  		select {
   463  		case <-ticker.C:
   464  			err := c.conn.WriteControl(websocket.PingMessage, ping, time.Now().Add(writeWait))
   465  			if err != nil {
   466  				c.stop()
   467  				// Don't really care what the error is, but log it at debug level.
   468  				c.log.Debugf("WriteMessage ping error: %v", err)
   469  				break out
   470  			}
   471  		case <-ctx.Done():
   472  			break out
   473  		}
   474  	}
   475  }
   476  
   477  // Off will return true if the link has disconnected.
   478  func (c *WSLink) Off() bool {
   479  	return atomic.LoadUint32(&c.on) == 0
   480  }
   481  
   482  // Addr returns the string-encoded IP address.
   483  func (c *WSLink) Addr() string {
   484  	return c.addr
   485  }
   486  
   487  // SetReadLimit should only be called before starting the Connection, or in a
   488  // request or response handler that is run synchronously with other text or
   489  // binary frame reads (e.g. ReadMessage).
   490  func (c *WSLink) SetReadLimit(limit int64) {
   491  	c.conn.SetReadLimit(limit)
   492  }
   493  
   494  // NewConnection attempts to to upgrade the http connection to a websocket
   495  // Connection. If the upgrade fails, a reply will be sent with an appropriate
   496  // error code.
   497  func NewConnection(w http.ResponseWriter, r *http.Request, readTimeout time.Duration) (Connection, error) {
   498  	ws, err := upgrader.Upgrade(w, r, nil)
   499  	if err != nil {
   500  		var hsErr websocket.HandshakeError
   501  		if errors.As(err, &hsErr) {
   502  			err = dex.NewError(ErrHandshake, hsErr.Error())
   503  			// gorilla already replies with an error in this case.
   504  		} else {
   505  			// No context to add to the error, so do not bother to wrap it, but
   506  			// no response has been sent by the Upgrader.
   507  
   508  			// Other than websocket.HandshakeError, there are only two possible
   509  			// non-nil error conditions: "client sent data before handshake is
   510  			// complete" and a write error with the "HTTP/1.1 101 Switching
   511  			// Protocols" response. In the first case, this is a client error,
   512  			// so we respond with a StatusBadRequest. In the second case, a
   513  			// failed write almost certainly indicates the connection is down.
   514  			http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
   515  		}
   516  
   517  		return nil, err
   518  	}
   519  
   520  	// Configure the pong handler.
   521  	ws.SetPongHandler(func(string) error {
   522  		return ws.SetReadDeadline(time.Now().Add(readTimeout))
   523  	})
   524  
   525  	// Unauthenticated connections have a small read limit.
   526  	ws.SetReadLimit(defaultReadLimit)
   527  
   528  	// Do not set an initial read deadline until pinging begins.
   529  
   530  	return ws, nil
   531  }