decred.org/dcrdex@v1.0.5/client/comms/wsconn.go (about)

     1  // This code is available on the terms of the project LICENSE.md file,
     2  // also available online at https://blueoakcouncil.org/license/1.0.0.
     3  
     4  package comms
     5  
     6  import (
     7  	"context"
     8  	"crypto/tls"
     9  	"crypto/x509"
    10  	"encoding/json"
    11  	"errors"
    12  	"fmt"
    13  	"net"
    14  	"net/http"
    15  	"net/url"
    16  	"regexp"
    17  	"strings"
    18  	"sync"
    19  	"sync/atomic"
    20  	"time"
    21  
    22  	"decred.org/dcrdex/dex"
    23  	"decred.org/dcrdex/dex/msgjson"
    24  	"github.com/gorilla/websocket"
    25  )
    26  
const (
	// readBuffSize is the capacity of the buffered channel that carries
	// messages read from the websocket connection (see wsConn.readCh).
	readBuffSize = 128

	// writeWait is the maximum time allowed for a write to the connection. It
	// is used for both regular message writes and pong replies.
	writeWait = time.Second * 3

	// reconnectInterval is the initial and increment between reconnect tries.
	reconnectInterval = 5 * time.Second

	// maxReconnectInterval is the maximum allowed reconnect interval.
	maxReconnectInterval = time.Minute

	// DefaultResponseTimeout is the default timeout for responses after a
	// request is successfully sent. It is also used as the websocket handshake
	// timeout when dialing.
	DefaultResponseTimeout = time.Minute
)
    44  
    45  // ConnectionStatus represents the current status of the websocket connection.
    46  type ConnectionStatus uint32
    47  
    48  const (
    49  	Disconnected ConnectionStatus = iota
    50  	Connected
    51  	InvalidCert
    52  )
    53  
    54  // String gives a human readable string for each connection status.
    55  func (cs ConnectionStatus) String() string {
    56  	switch cs {
    57  	case Disconnected:
    58  		return "disconnected"
    59  	case Connected:
    60  		return "connected"
    61  	case InvalidCert:
    62  		return "invalid certificate"
    63  	default:
    64  		return "unknown status"
    65  	}
    66  }
    67  
    68  // invalidCertRegexp is a regexp that helps check for non-typed x509 errors
    69  // caused by or related to an invalid cert.
    70  var invalidCertRegexp = regexp.MustCompile(".*(unknown authority|not standards compliant|not trusted)")
    71  
    72  // isErrorInvalidCert checks if the provided error is one of the different
    73  // variant of an invalid cert error returned from the x509 package.
    74  func isErrorInvalidCert(err error) bool {
    75  	var invalidCertErr x509.CertificateInvalidError
    76  	var unknownCertAuthErr x509.UnknownAuthorityError
    77  	var hostNameErr x509.HostnameError
    78  	return errors.As(err, &invalidCertErr) || errors.As(err, &hostNameErr) ||
    79  		errors.As(err, &unknownCertAuthErr) || invalidCertRegexp.MatchString(err.Error())
    80  }
    81  
    82  // ErrInvalidCert is the error returned when attempting to use an invalid cert
    83  // to set up a ws connection.
    84  var ErrInvalidCert = fmt.Errorf("invalid certificate")
    85  
    86  // ErrCertRequired is the error returned when a ws connection fails because no
    87  // cert was provided.
    88  var ErrCertRequired = fmt.Errorf("certificate required")
    89  
// WsConn is an interface for a websocket client. The descriptions below
// reflect the wsConn implementation in this file.
type WsConn interface {
	// NextID returns the next request ID.
	NextID() uint64
	// IsDown indicates if the connection is known to be down.
	IsDown() bool
	// Send marshals msg to JSON and writes it to the connection.
	Send(msg *msgjson.Message) error
	// SendRaw writes the raw bytes to the connection as a text message.
	SendRaw(b []byte) error
	// Request sends a request and registers respHandler to be run when the
	// response arrives within DefaultResponseTimeout.
	Request(msg *msgjson.Message, respHandler func(*msgjson.Message)) error
	// RequestRaw is like Request, but for a pre-encoded message with the given
	// message ID.
	RequestRaw(msgID uint64, rawMsg []byte, respHandler func(*msgjson.Message)) error
	// RequestWithTimeout is like Request, but with a caller-specified timeout
	// and an expire function that is run if no response arrives in time.
	RequestWithTimeout(msg *msgjson.Message, respHandler func(*msgjson.Message), expireTime time.Duration, expire func()) error
	// Connect establishes the connection and starts the connection's
	// goroutines, which run until ctx is canceled.
	Connect(ctx context.Context) (*sync.WaitGroup, error)
	// MessageSource returns the channel on which non-response messages from
	// the server are delivered.
	MessageSource() <-chan *msgjson.Message
	// UpdateURL changes the URL used for subsequent connection attempts.
	UpdateURL(string)
}
   103  
// responseHandler tracks an outstanding request. One is registered by logReq
// for each request sent, and it is removed when the response is received, the
// expiration timer fires, or the connection is shut down.
type responseHandler struct {
	expiration *time.Timer            // fires the expire function if no response is received in time
	f          func(*msgjson.Message) // run with the response message
	abort      func()                 // only to be run at most once, and not if f ran
}
   111  
// WsCfg is the configuration struct for initializing a WsConn.
type WsCfg struct {
	// URL is the websocket endpoint URL.
	URL string

	// PingWait is the maximum time to wait for a ping from the server. This
	// should be larger than the server's ping interval to allow for network
	// latency. It must not be negative.
	PingWait time.Duration

	// Cert is the server's certificate, PEM encoded. If provided, it is
	// added to the root CA pool used to verify the server.
	Cert []byte

	// ReconnectSync runs the needed reconnection synchronization after
	// a reconnect.
	ReconnectSync func()

	// ConnectEventFunc runs whenever connection status changes.
	//
	// NOTE: Disconnect event notifications may lag behind actual
	// disconnections.
	ConnectEventFunc func(ConnectionStatus)

	// Logger is the logger for the WsConn.
	Logger dex.Logger

	// NetDialContext specifies an optional dialer context to use. If it is
	// nil, the proxy settings from the environment are used instead.
	NetDialContext func(context.Context, string, string) (net.Conn, error)

	// RawHandler overrides the msgjson parsing and forwards all messages to
	// the provided function.
	RawHandler func([]byte)

	// DisableAutoReconnect disables automatic reconnection.
	DisableAutoReconnect bool

	// ConnectHeaders are the HTTP headers sent with the websocket handshake
	// request.
	ConnectHeaders http.Header

	// EchoPingData will echo any data from pings as the pong data.
	EchoPingData bool
}
   153  
// wsConn represents a client websocket connection. It implements WsConn.
type wsConn struct {
	// 64-bit atomic variables first. See
	// https://golang.org/pkg/sync/atomic/#pkg-note-BUG.
	rID    uint64             // request ID counter, incremented atomically by NextID
	cancel context.CancelFunc // cancels the internal context created by Connect
	wg     sync.WaitGroup     // tracks the read, keepAlive, shutdown, and response handler goroutines
	log    dex.Logger
	cfg    *WsCfg
	tlsCfg *tls.Config               // built once in NewWsConn from cfg.URL and cfg.Cert
	readCh chan *msgjson.Message     // non-response messages for MessageSource consumers
	urlV   atomic.Value // string

	// wsMtx guards the ws field and serializes writes to the connection.
	wsMtx sync.Mutex
	ws    *websocket.Conn

	connectionStatus uint32 // atomic

	// reqMtx guards respHandlers.
	reqMtx       sync.RWMutex
	respHandlers map[uint64]*responseHandler

	reconnectCh chan struct{} // trigger for immediate reconnect
}

// Compile-time check that *wsConn satisfies the WsConn interface.
var _ WsConn = (*wsConn)(nil)
   179  
   180  // NewWsConn creates a client websocket connection.
   181  func NewWsConn(cfg *WsCfg) (WsConn, error) {
   182  	if cfg.PingWait < 0 {
   183  		return nil, fmt.Errorf("ping wait cannot be negative")
   184  	}
   185  
   186  	uri, err := url.Parse(cfg.URL)
   187  	if err != nil {
   188  		return nil, fmt.Errorf("error parsing URL: %w", err)
   189  	}
   190  
   191  	rootCAs, _ := x509.SystemCertPool()
   192  	if rootCAs == nil {
   193  		rootCAs = x509.NewCertPool()
   194  	}
   195  
   196  	if len(cfg.Cert) > 0 {
   197  		if ok := rootCAs.AppendCertsFromPEM(cfg.Cert); !ok {
   198  			return nil, ErrInvalidCert
   199  		}
   200  	}
   201  
   202  	tlsConfig := &tls.Config{
   203  		RootCAs:    rootCAs,
   204  		MinVersion: tls.VersionTLS12,
   205  		ServerName: uri.Hostname(),
   206  	}
   207  
   208  	conn := &wsConn{
   209  		cfg:          cfg,
   210  		log:          cfg.Logger,
   211  		tlsCfg:       tlsConfig,
   212  		readCh:       make(chan *msgjson.Message, readBuffSize),
   213  		respHandlers: make(map[uint64]*responseHandler),
   214  		reconnectCh:  make(chan struct{}, 1),
   215  	}
   216  	conn.urlV.Store(cfg.URL)
   217  
   218  	return conn, nil
   219  }
   220  
   221  func (conn *wsConn) UpdateURL(uri string) {
   222  	conn.urlV.Store(uri)
   223  }
   224  
   225  func (conn *wsConn) url() string {
   226  	return conn.urlV.Load().(string)
   227  }
   228  
   229  // IsDown indicates if the connection is known to be down.
   230  func (conn *wsConn) IsDown() bool {
   231  	return atomic.LoadUint32(&conn.connectionStatus) != uint32(Connected)
   232  }
   233  
   234  // setConnectionStatus updates the connection's status and runs the
   235  // ConnectEventFunc in case of a change.
   236  func (conn *wsConn) setConnectionStatus(status ConnectionStatus) {
   237  	oldStatus := atomic.SwapUint32(&conn.connectionStatus, uint32(status))
   238  	statusChange := oldStatus != uint32(status)
   239  	if statusChange && conn.cfg.ConnectEventFunc != nil {
   240  		conn.cfg.ConnectEventFunc(status)
   241  	}
   242  }
   243  
   244  // connect attempts to establish a websocket connection.
   245  func (conn *wsConn) connect(ctx context.Context) error {
   246  	dialer := &websocket.Dialer{
   247  		HandshakeTimeout: DefaultResponseTimeout,
   248  		TLSClientConfig:  conn.tlsCfg,
   249  	}
   250  	if conn.cfg.NetDialContext != nil {
   251  		dialer.NetDialContext = conn.cfg.NetDialContext
   252  	} else {
   253  		dialer.Proxy = http.ProxyFromEnvironment
   254  	}
   255  
   256  	ws, _, err := dialer.DialContext(ctx, conn.url(), conn.cfg.ConnectHeaders)
   257  	if err != nil {
   258  		if isErrorInvalidCert(err) {
   259  			conn.setConnectionStatus(InvalidCert)
   260  			if len(conn.cfg.Cert) == 0 {
   261  				return dex.NewError(ErrCertRequired, err.Error())
   262  			}
   263  			return dex.NewError(ErrInvalidCert, err.Error())
   264  		}
   265  		conn.setConnectionStatus(Disconnected)
   266  		return err
   267  	}
   268  
   269  	// Set the initial read deadline for the first ping. Subsequent read
   270  	// deadlines are set in the ping handler.
   271  	err = ws.SetReadDeadline(time.Now().Add(conn.cfg.PingWait))
   272  	if err != nil {
   273  		conn.log.Errorf("set read deadline failed: %v", err)
   274  		return err
   275  	}
   276  
   277  	echoPing := conn.cfg.EchoPingData
   278  
   279  	ws.SetPingHandler(func(appData string) error {
   280  		now := time.Now()
   281  
   282  		// Set the deadline for the next ping.
   283  		err := ws.SetReadDeadline(now.Add(conn.cfg.PingWait))
   284  		if err != nil {
   285  			conn.log.Errorf("set read deadline failed: %v", err)
   286  			return err
   287  		}
   288  
   289  		var data []byte
   290  		if echoPing {
   291  			data = []byte(appData)
   292  		}
   293  
   294  		// Respond with a pong.
   295  		err = ws.WriteControl(websocket.PongMessage, data, now.Add(writeWait))
   296  		if err != nil {
   297  			// read loop handles reconnect
   298  			conn.log.Errorf("pong write error: %v", err)
   299  			return err
   300  		}
   301  
   302  		return nil
   303  	})
   304  
   305  	conn.wsMtx.Lock()
   306  	// If keepAlive called connect, the wsConn's current websocket.Conn may need
   307  	// to be closed depending on the error that triggered the reconnect.
   308  	if conn.ws != nil {
   309  		conn.close()
   310  	}
   311  	conn.ws = ws
   312  	conn.wsMtx.Unlock()
   313  
   314  	conn.setConnectionStatus(Connected)
   315  	conn.wg.Add(1)
   316  	go func() {
   317  		defer conn.wg.Done()
   318  		if conn.cfg.RawHandler != nil {
   319  			conn.readRaw(ctx)
   320  		} else {
   321  			conn.read(ctx)
   322  		}
   323  	}()
   324  
   325  	return nil
   326  }
   327  
   328  func (conn *wsConn) SetReadLimit(limit int64) {
   329  	conn.wsMtx.Lock()
   330  	ws := conn.ws
   331  	conn.wsMtx.Unlock()
   332  	if ws != nil {
   333  		ws.SetReadLimit(limit)
   334  	}
   335  }
   336  
   337  func (conn *wsConn) handleReadError(err error) {
   338  	reconnect := func() {
   339  		conn.setConnectionStatus(Disconnected)
   340  		if !conn.cfg.DisableAutoReconnect {
   341  			conn.reconnectCh <- struct{}{}
   342  		}
   343  	}
   344  
   345  	var netErr net.Error
   346  	if errors.As(err, &netErr) && netErr.Timeout() {
   347  		conn.log.Errorf("Read timeout on connection to %s.", conn.url())
   348  		reconnect()
   349  		return
   350  	}
   351  	// TODO: Now that wsConn goroutines have contexts that are canceled
   352  	// on shutdown, we do not have to infer the source and severity of
   353  	// the error; just reconnect in ALL other cases, and remove the
   354  	// following legacy checks.
   355  
   356  	// Expected close errors (1000 and 1001) ... but if the server
   357  	// closes we still want to reconnect. (???)
   358  	if websocket.IsCloseError(err, websocket.CloseGoingAway,
   359  		websocket.CloseNormalClosure) ||
   360  		strings.Contains(err.Error(), "websocket: close sent") {
   361  		reconnect()
   362  		return
   363  	}
   364  
   365  	var opErr *net.OpError
   366  	if errors.As(err, &opErr) && opErr.Op == "read" {
   367  		if strings.Contains(opErr.Err.Error(), "use of closed network connection") {
   368  			conn.log.Errorf("read quitting: %v", err)
   369  			reconnect()
   370  			return
   371  		}
   372  	}
   373  
   374  	// Log all other errors and trigger a reconnection.
   375  	conn.log.Errorf("read error (%v), attempting reconnection", err)
   376  	reconnect()
   377  }
   378  
   379  func (conn *wsConn) close() {
   380  	// Attempt to send a close message in case the connection is still live.
   381  	msg := websocket.FormatCloseMessage(websocket.CloseNormalClosure, "bye")
   382  	_ = conn.ws.WriteControl(websocket.CloseMessage, msg,
   383  		time.Now().Add(50*time.Millisecond)) // ignore any error
   384  	// Forcibly close the underlying connection.
   385  	conn.ws.Close()
   386  }
   387  
   388  func (conn *wsConn) readRaw(ctx context.Context) {
   389  	for {
   390  		// Lock since conn.ws may be set by connect.
   391  		conn.wsMtx.Lock()
   392  		ws := conn.ws
   393  		conn.wsMtx.Unlock()
   394  
   395  		// Block until a message is received or an error occurs.
   396  		_, msgBytes, err := ws.ReadMessage()
   397  		// Drop the read error on context cancellation.
   398  		if ctx.Err() != nil {
   399  			return
   400  		}
   401  		if err != nil {
   402  			conn.handleReadError(err)
   403  			return
   404  		}
   405  		conn.cfg.RawHandler(msgBytes)
   406  	}
   407  }
   408  
// read fetches and parses incoming messages for processing. This should be
// run as a goroutine. Increment the wg before calling read.
//
// Response-type messages are dispatched to the handler registered for the
// message ID, each in its own goroutine. All other messages are sent on
// readCh. The loop returns when the context is canceled or when a read error
// other than a JSON type error occurs, in which case handleReadError is called
// first.
func (conn *wsConn) read(ctx context.Context) {
	for {
		msg := new(msgjson.Message)

		// Lock since conn.ws may be set by connect.
		conn.wsMtx.Lock()
		ws := conn.ws
		conn.wsMtx.Unlock()

		// The read itself does not require locking since only this goroutine
		// uses read functions that are not safe for concurrent use.
		err := ws.ReadJSON(msg)
		// Drop the read error on context cancellation.
		if ctx.Err() != nil {
			return
		}
		if err != nil {
			var mErr *json.UnmarshalTypeError
			if errors.As(err, &mErr) {
				// JSON decode errors are not fatal, log and proceed.
				conn.log.Errorf("json decode error: %v", mErr)
				continue
			}
			conn.handleReadError(err)
			return
		}

		// If the message is a response, find the handler.
		if msg.Type == msgjson.Response {
			// respHandler also removes the handler from the map and stops its
			// expiration timer.
			handler := conn.respHandler(msg.ID)
			if handler == nil {
				b, _ := json.Marshal(msg)
				conn.log.Errorf("No handler found for response: %v", string(b))
				continue
			}
			// Run handlers in a goroutine so that other messages can be
			// received. Include the handler goroutines in the WaitGroup to
			// allow them to complete if the connection master desires.
			conn.wg.Add(1)
			go func() {
				defer conn.wg.Done()
				handler.f(msg)
			}()
			continue
		}
		// NOTE(review): this send blocks when readCh is full, and readCh is
		// closed by the shutdown goroutine started in Connect. Confirm that a
		// send cannot be attempted after that close.
		conn.readCh <- msg
	}
}
   459  
   460  // keepAlive maintains an active websocket connection by reconnecting when
   461  // the established connection is broken. This should be run as a goroutine.
   462  func (conn *wsConn) keepAlive(ctx context.Context) {
   463  	rcInt := reconnectInterval
   464  	for {
   465  		select {
   466  		case <-conn.reconnectCh:
   467  			// Prioritize context cancellation even if there are reconnect
   468  			// requests.
   469  			if ctx.Err() != nil {
   470  				return
   471  			}
   472  
   473  			conn.log.Infof("Attempting to reconnect to %s...", conn.url())
   474  			err := conn.connect(ctx)
   475  			if err != nil {
   476  				conn.log.Errorf("Reconnect failed. Scheduling reconnect to %s in %.1f seconds.",
   477  					conn.url(), rcInt.Seconds())
   478  				time.AfterFunc(rcInt, func() {
   479  					conn.reconnectCh <- struct{}{}
   480  				})
   481  				// Increment the wait up to PingWait.
   482  				if rcInt < maxReconnectInterval {
   483  					rcInt += reconnectInterval
   484  				}
   485  				continue
   486  			}
   487  
   488  			conn.log.Info("Successfully reconnected.")
   489  			rcInt = reconnectInterval
   490  
   491  			// Synchronize after a reconnection.
   492  			if conn.cfg.ReconnectSync != nil {
   493  				conn.cfg.ReconnectSync()
   494  			}
   495  
   496  		case <-ctx.Done():
   497  			return
   498  		}
   499  	}
   500  }
   501  
   502  // NextID returns the next request id.
   503  func (conn *wsConn) NextID() uint64 {
   504  	return atomic.AddUint64(&conn.rID, 1)
   505  }
   506  
// Connect connects the client. Any error encountered during the initial
// connection will be returned. An auto-(re)connect goroutine will be started,
// even on error. To terminate it, use Stop() or cancel the context.
//
// The exceptions are when auto-reconnect is disabled or the initial connection
// fails because of a missing or invalid certificate. In those cases nothing is
// left running, the MessageSource channel is closed, and a nil WaitGroup is
// returned with the error. Otherwise, the returned WaitGroup tracks the
// goroutines started by the wsConn, and the returned error is nil even if the
// initial connection attempt failed.
func (conn *wsConn) Connect(ctx context.Context) (*sync.WaitGroup, error) {
	var ctxInternal context.Context
	ctxInternal, conn.cancel = context.WithCancel(ctx)

	err := conn.connect(ctxInternal)
	if err != nil {
		// If the certificate is invalid or missing, do not start the reconnect
		// loop, and return an error with no WaitGroup.
		if conn.cfg.DisableAutoReconnect || errors.Is(err, ErrInvalidCert) || errors.Is(err, ErrCertRequired) {
			conn.cancel()
			conn.wg.Wait() // probably a no-op
			close(conn.readCh)
			return nil, err
		}

		// The read loop would normally trigger keepAlive, but it wasn't started
		// on account of a connect error.
		conn.log.Errorf("Initial connection failed, starting reconnect loop: %v", err)
		// Schedule the first reconnect attempt for keepAlive, started below.
		time.AfterFunc(5*time.Second, func() {
			conn.reconnectCh <- struct{}{}
		})
	}

	if !conn.cfg.DisableAutoReconnect {
		conn.wg.Add(1)
		go func() {
			defer conn.wg.Done()
			conn.keepAlive(ctxInternal)
		}()
	}

	// The shutdown goroutine waits for context cancellation and then closes
	// the connection, aborts outstanding requests, and closes readCh.
	conn.wg.Add(1)
	go func() {
		defer conn.wg.Done()
		<-ctxInternal.Done()
		conn.setConnectionStatus(Disconnected)
		conn.wsMtx.Lock()
		if conn.ws != nil {
			conn.log.Debug("Sending close 1000 (normal) message.")
			conn.close()
		}
		conn.wsMtx.Unlock()

		// Run the expire funcs so request callers don't hang.
		conn.reqMtx.Lock()
		defer conn.reqMtx.Unlock()
		for id, h := range conn.respHandlers {
			delete(conn.respHandlers, id)
			// Since we are holding reqMtx and deleting the handler, no need to
			// check if expiration fired (see logReq), but good to stop it.
			h.expiration.Stop()
			// NOTE: abort runs with reqMtx held.
			h.abort()
		}

		close(conn.readCh) // signal to MessageSource receivers that the wsConn is dead
	}()

	return &conn.wg, nil
}
   569  
   570  // Stop can be used to close the connection and all of the goroutines started by
   571  // Connect. Alternatively, the context passed to Connect may be canceled.
   572  func (conn *wsConn) Stop() {
   573  	conn.cancel()
   574  }
   575  
   576  // Send pushes outgoing messages over the websocket connection. Sending of the
   577  // message is synchronous, so a nil error guarantees that the message was
   578  // successfully sent. A non-nil error may indicate that the connection is known
   579  // to be down, the message failed to marshall to JSON, or writing to the
   580  // websocket link failed.
   581  func (conn *wsConn) Send(msg *msgjson.Message) error {
   582  	if conn.IsDown() {
   583  		return fmt.Errorf("cannot send on a broken connection")
   584  	}
   585  
   586  	// Marshal the Message first so that we don't send junk to the peer even if
   587  	// it fails to marshal completely, which gorilla/websocket.WriteJSON does.
   588  	b, err := json.Marshal(msg)
   589  	if err != nil {
   590  		conn.log.Errorf("Failed to marshal message: %v", err)
   591  		return err
   592  	}
   593  	return conn.SendRaw(b)
   594  }
   595  
   596  // SendRaw sends a raw byte string over the websocket connection.
   597  func (conn *wsConn) SendRaw(b []byte) error {
   598  	if conn.IsDown() {
   599  		return fmt.Errorf("cannot send on a broken connection")
   600  	}
   601  
   602  	conn.wsMtx.Lock()
   603  	defer conn.wsMtx.Unlock()
   604  	err := conn.ws.SetWriteDeadline(time.Now().Add(writeWait))
   605  	if err != nil {
   606  		conn.log.Errorf("Send: failed to set write deadline: %v", err)
   607  		return err
   608  	}
   609  
   610  	err = conn.ws.WriteMessage(websocket.TextMessage, b)
   611  	if err != nil {
   612  		conn.log.Errorf("Send: WriteMessage error: %v", err)
   613  		return err
   614  	}
   615  	return nil
   616  }
   617  
   618  // Request sends the Request-type msgjson.Message to the server and does not
   619  // wait for a response, but records a callback function to run when a response
   620  // is received. A response must be received within DefaultResponseTimeout of the
   621  // request, after which the response handler expires and any late response will
   622  // be ignored. To handle expiration or to set the timeout duration, use
   623  // RequestWithTimeout. Sending of the request is synchronous, so a nil error
   624  // guarantees that the request message was successfully sent.
   625  func (conn *wsConn) Request(msg *msgjson.Message, f func(*msgjson.Message)) error {
   626  	return conn.RequestWithTimeout(msg, f, DefaultResponseTimeout, func() {})
   627  }
   628  
   629  func (conn *wsConn) RequestRaw(msgID uint64, rawMsg []byte, f func(*msgjson.Message)) error {
   630  	return conn.RequestRawWithTimeout(msgID, rawMsg, f, DefaultResponseTimeout, func() {})
   631  }
   632  
   633  // RequestWithTimeout sends the Request-type message and does not wait for a
   634  // response, but records a callback function to run when a response is received.
   635  // If the server responds within expireTime of the request, the response handler
   636  // is called, otherwise the expire function is called. If the response handler
   637  // is called, it is guaranteed that the response Message.ID is equal to the
   638  // request Message.ID. Sending of the request is synchronous, so a nil error
   639  // guarantees that the request message was successfully sent and that either the
   640  // response handler or expire function will be run; a non-nil error guarantees
   641  // that neither function will run.
   642  //
   643  // For example, to wait on a response or timeout:
   644  //
   645  //	errChan := make(chan error, 1)
   646  //
   647  //	err := conn.RequestWithTimeout(reqMsg, func(msg *msgjson.Message) {
   648  //	    errChan <- msg.UnmarshalResult(responseStructPointer)
   649  //	}, timeout, func() {
   650  //	    errChan <- fmt.Errorf("timed out waiting for '%s' response.", route)
   651  //	})
   652  //	if err != nil {
   653  //	    return err // request error
   654  //	}
   655  //	return <-errChan // timeout or response error
   656  func (conn *wsConn) RequestWithTimeout(msg *msgjson.Message, f func(*msgjson.Message), expireTime time.Duration, expire func()) error {
   657  	if msg.Type != msgjson.Request {
   658  		return fmt.Errorf("Message is not a request: %v", msg.Type)
   659  	}
   660  	rawMsg, err := json.Marshal(msg)
   661  	if err != nil {
   662  		conn.log.Errorf("Failed to marshal message: %v", err)
   663  		return err
   664  	}
   665  	err = conn.RequestRawWithTimeout(msg.ID, rawMsg, f, expireTime, expire)
   666  	if err != nil {
   667  		conn.log.Errorf("(*wsConn).Request(route '%s') Send error (%v), unregistering msg ID %d handler",
   668  			msg.Route, err, msg.ID)
   669  	}
   670  	return err
   671  }
   672  
   673  func (conn *wsConn) RequestRawWithTimeout(msgID uint64, rawMsg []byte, f func(*msgjson.Message), expireTime time.Duration, expire func()) error {
   674  
   675  	// Register the response and expire handlers for this request.
   676  	conn.logReq(msgID, f, expireTime, expire)
   677  	err := conn.SendRaw(rawMsg)
   678  	if err != nil {
   679  		// Neither expire nor the handler should run. Stop the expire timer
   680  		// created by logReq and delete the response handler it added. The
   681  		// caller receives a non-nil error to deal with it.
   682  		conn.respHandler(msgID) // drop the responseHandler logged by logReq that is no longer necessary
   683  	}
   684  	return err
   685  }
   686  
   687  func (conn *wsConn) expire(id uint64) bool {
   688  	conn.reqMtx.Lock()
   689  	defer conn.reqMtx.Unlock()
   690  	_, removed := conn.respHandlers[id]
   691  	delete(conn.respHandlers, id)
   692  	return removed
   693  }
   694  
   695  // logReq stores the response handler in the respHandlers map. Requests to the
   696  // client are associated with a response handler.
   697  func (conn *wsConn) logReq(id uint64, respHandler func(*msgjson.Message), expireTime time.Duration, expire func()) {
   698  	conn.reqMtx.Lock()
   699  	defer conn.reqMtx.Unlock()
   700  	doExpire := func() {
   701  		// Delete the response handler, and call the provided expire function if
   702  		// (*wsLink).respHandler has not already retrieved the handler function
   703  		// for execution.
   704  		if conn.expire(id) {
   705  			expire()
   706  		}
   707  	}
   708  	conn.respHandlers[id] = &responseHandler{
   709  		expiration: time.AfterFunc(expireTime, doExpire),
   710  		f:          respHandler,
   711  		abort:      expire,
   712  	}
   713  }
   714  
   715  // respHandler extracts the response handler for the provided request ID if it
   716  // exists, else nil. If the handler exists, it will be deleted from the map.
   717  func (conn *wsConn) respHandler(id uint64) *responseHandler {
   718  	conn.reqMtx.Lock()
   719  	defer conn.reqMtx.Unlock()
   720  	cb, ok := conn.respHandlers[id]
   721  	if ok {
   722  		cb.expiration.Stop()
   723  		delete(conn.respHandlers, id)
   724  	}
   725  	return cb
   726  }
   727  
   728  // MessageSource returns the connection's read source. The returned chan will
   729  // receive requests and notifications from the server, but not responses, which
   730  // have handlers associated with their request. The same channel is returned on
   731  // each call, so there must only be one receiver. When the connection is
   732  // shutdown, the channel will be closed.
   733  func (conn *wsConn) MessageSource() <-chan *msgjson.Message {
   734  	return conn.readCh
   735  }