decred.org/dcrdex@v1.0.5/server/comms/link.go (about)

     1  // This code is available on the terms of the project LICENSE.md file,
     2  // also available online at https://blueoakcouncil.org/license/1.0.0.
     3  
     4  package comms
     5  
     6  import (
     7  	"encoding/json"
     8  	"sync"
     9  	"sync/atomic"
    10  	"time"
    11  
    12  	"decred.org/dcrdex/dex/msgjson"
    13  	"decred.org/dcrdex/dex/ws"
    14  )
    15  
    16  const readLimitAuthorized = 262144
    17  
    18  // criticalRoutes are not subject to the rate limiter on websocket connections.
    19  var criticalRoutes = map[string]bool{
    20  	msgjson.ConfigRoute: true,
    21  }
    22  
    23  // Link is an interface for a communication channel with an API client. The
    24  // reference implementation of a Link-satisfying type is the wsLink, which
    25  // passes messages over a websocket connection.
    26  type Link interface {
    27  	// Done returns a channel that is closed when the link goes down.
    28  	Done() <-chan struct{}
    29  	// ID returns a unique ID by which this connection can be identified.
    30  	ID() uint64
    31  	// Addr returns the string-encoded IP address.
    32  	Addr() string
    33  	// Send sends the msgjson.Message to the peer.
    34  	Send(msg *msgjson.Message) error
    35  	// SendRaw sends the raw bytes which is assumed to be a marshalled
    36  	// msgjson.Message to the peer. Can be used to avoid marshalling the
    37  	// same message multiple times.
    38  	SendRaw(b []byte) error
    39  	// SendError sends the msgjson.Error to the peer, with reference to a
    40  	// request message ID.
    41  	SendError(id uint64, rpcErr *msgjson.Error)
    42  	// Request sends the Request-type msgjson.Message to the client and registers
    43  	// a handler for the response.
    44  	Request(msg *msgjson.Message, f func(Link, *msgjson.Message), expireTime time.Duration, expire func()) error
    45  	RequestRaw(msgID uint64, rawMsg []byte, f func(Link, *msgjson.Message), expireTime time.Duration, expire func()) error
    46  	// Banish closes the link and quarantines the client.
    47  	Banish()
    48  	// Disconnect closes the link.
    49  	Disconnect()
    50  	// Authorized should be called from a request handler when the connection
    51  	// becomes authorized. Request handlers must be run synchronous with other
    52  	// reads or it will be a data race with the link's input loop.
    53  	Authorized()
    54  	// SetCustomID
    55  	SetCustomID(string)
    56  	// CustomID
    57  	CustomID() string
    58  }
    59  
// When the DEX sends a request to the client, a responseHandler is created
// to wait for the response. It is stored in wsLink.respHandlers, keyed by the
// request's message ID.
type responseHandler struct {
	// f is invoked with the client's response message.
	f func(Link, *msgjson.Message)
	// expire fires the request's expiration callback if no response is
	// claimed in time. It is stopped by (*wsLink).respHandler when the handler
	// is claimed.
	expire *time.Timer
}
    66  
// wsLink is the local, per-connection representation of a DEX client.
type wsLink struct {
	*ws.WSLink
	// The id is the unique identifier assigned to this client. It is not set
	// by newWSLink.
	id uint64
	// customID holds the string stored by SetCustomID. An atomic.Value is
	// used so it can be stored and loaded without taking reqMtx.
	customID atomic.Value
	// For DEX-originating requests, the response handler is mapped to the
	// request ID. reqMtx guards respHandlers.
	reqMtx       sync.Mutex
	respHandlers map[uint64]*responseHandler
	// Upon closing, the client's IP address will be quarantined by the server if
	// ban = true. It is set by Banish.
	ban bool
	// dataMeter is a function that will be checked to see if certain data API
	// requests should be denied due to rate limits or if the data API is
	// disabled. This applies to non-critical httpRoutes requests. Only its
	// error result is inspected by handleMessage.
	dataMeter func() (int, error)
	// wsLimiter is a route-based rate limiter. This applies to rpcRoutes.
	wsLimiter *routeLimiter
}
    87  
    88  // newWSLink is a constructor for a new wsLink.
    89  func (s *Server) newWSLink(addr string, conn ws.Connection, wsLimiter *routeLimiter, limitData func() (int, error)) *wsLink {
    90  	var c *wsLink
    91  	c = &wsLink{
    92  		WSLink: ws.NewWSLink(addr, conn, pingPeriod, func(msg *msgjson.Message) *msgjson.Error {
    93  			return s.handleMessage(c, msg)
    94  		}, log.SubLogger("WS")),
    95  		respHandlers: make(map[uint64]*responseHandler),
    96  		dataMeter:    limitData,
    97  		wsLimiter:    wsLimiter,
    98  	}
    99  	return c
   100  }
   101  
   102  // Banish sets the ban flag and closes the client.
   103  func (c *wsLink) Banish() {
   104  	c.ban = true
   105  	c.Disconnect()
   106  }
   107  
   108  // ID returns a unique ID by which this connection can be identified.
   109  func (c *wsLink) ID() uint64 {
   110  	return c.id
   111  }
   112  
   113  func (c *wsLink) SetCustomID(id string) {
   114  	c.customID.Store(id)
   115  }
   116  
   117  func (c *wsLink) CustomID() string {
   118  	if s := c.customID.Load(); s != nil {
   119  		return s.(string)
   120  	}
   121  	return ""
   122  }
   123  
   124  // Addr returns the string-encoded IP address.
   125  func (c *wsLink) Addr() string {
   126  	return c.WSLink.Addr()
   127  }
   128  
   129  // Authorized should be called from a request handler when the connection
   130  // becomes authorized. Unless it is run in a request handler synchronous with
   131  // other reads or prior to starting the link, it will be a data race with the
   132  // link's input loop. dex/ws.(*WsLink).inHandler does not run request handlers
   133  // concurrently with reads.
   134  func (c *wsLink) Authorized() {
   135  	c.SetReadLimit(readLimitAuthorized)
   136  }
   137  
   138  // The WSLink.handler for WSLink.inHandler
   139  func (s *Server) handleMessage(c *wsLink, msg *msgjson.Message) *msgjson.Error {
   140  	switch msg.Type {
   141  	case msgjson.Request:
   142  		if msg.ID == 0 {
   143  			return msgjson.NewError(msgjson.RPCParseError, "request id cannot be zero")
   144  		}
   145  		// Look for a registered WebSocket route handler. This excludes the data
   146  		// API routes, which are part of the httpHandler map.
   147  		handler := s.rpcRoutes[msg.Route]
   148  		if handler != nil {
   149  			if !c.wsLimiter.allow(msg.Route) {
   150  				return msgjson.NewError(msgjson.TooManyRequestsError, "too many requests to %s", msg.Route)
   151  			}
   152  			// Handle the request.
   153  			return handler(c, msg)
   154  		}
   155  
   156  		// Look for an HTTP handler.
   157  		httpHandler := s.httpRoutes[msg.Route]
   158  		if httpHandler == nil {
   159  			return msgjson.NewError(msgjson.RPCUnknownRoute, "unknown route")
   160  		}
   161  
   162  		// If it's not a critical route, check the rate limiters.
   163  		if !criticalRoutes[msg.Route] {
   164  			if _, err := c.dataMeter(); err != nil {
   165  				// These errors are actually formatted nicely for sending, since
   166  				// they are used directly in HTTP errors as well.
   167  				return msgjson.NewError(msgjson.TooManyRequestsError, "metered: %v", err)
   168  			}
   169  		}
   170  
   171  		// Prepare the thing and unmarshal.
   172  		var thing any
   173  		switch msg.Route {
   174  		case msgjson.CandlesRoute:
   175  			thing = new(msgjson.CandlesRequest)
   176  		case msgjson.OrderBookRoute:
   177  			thing = new(msgjson.OrderBookSubscription)
   178  		}
   179  		if thing != nil {
   180  			err := msg.Unmarshal(thing)
   181  			if err != nil {
   182  				return msgjson.NewError(msgjson.RPCParseError, "json parse error")
   183  			}
   184  		}
   185  
   186  		// Process request.
   187  		resp, err := httpHandler(thing)
   188  		if err != nil {
   189  			return msgjson.NewError(msgjson.HTTPRouteError, "handler error: %v", err)
   190  		}
   191  
   192  		// Respond.
   193  		msg, err := msgjson.NewResponse(msg.ID, resp, nil)
   194  		if err == nil {
   195  			err = c.Send(msg)
   196  		}
   197  
   198  		if err != nil {
   199  			log.Errorf("Error sending response to %s for requested route %q: %v", c.Addr(), msg.Route, err)
   200  		}
   201  		return nil
   202  
   203  	case msgjson.Notification:
   204  		// Look for a registered WebSocket route handler. This excludes the data
   205  		// API routes, which are part of the httpHandler map.
   206  		handler := s.rpcRoutes[msg.Route]
   207  		if handler != nil {
   208  			if !c.wsLimiter.allow(msg.Route) {
   209  				return msgjson.NewError(msgjson.TooManyRequestsError, "too many requests to %s", msg.Route)
   210  			}
   211  			// Handle the request.
   212  			return handler(c, msg)
   213  		}
   214  	case msgjson.Response:
   215  		// NOTE: In the event of an error, we respond to a response, which makes
   216  		// no sense. A new mechanism is needed with appropriate client handling.
   217  		if msg.ID == 0 {
   218  			return msgjson.NewError(msgjson.RPCParseError, "response id cannot be 0")
   219  		}
   220  		cb := c.respHandler(msg.ID)
   221  		if cb == nil {
   222  			log.Debugf("comms.handleMessage: handler for msg ID %d not found", msg.ID)
   223  			return msgjson.NewError(msgjson.UnknownResponseID,
   224  				"unknown response ID")
   225  		}
   226  		cb.f(c, msg)
   227  		return nil
   228  	}
   229  	return msgjson.NewError(msgjson.UnknownMessageType, "unknown message type")
   230  }
   231  
   232  func (c *wsLink) expire(id uint64) bool {
   233  	c.reqMtx.Lock()
   234  	defer c.reqMtx.Unlock()
   235  	_, removed := c.respHandlers[id]
   236  	delete(c.respHandlers, id)
   237  	return removed
   238  }
   239  
   240  // logReq stores the response handler in the respHandlers map. Requests to the
   241  // client are associated with a response handler.
   242  func (c *wsLink) logReq(id uint64, respHandler func(Link, *msgjson.Message), expireTime time.Duration, expire func()) {
   243  	c.reqMtx.Lock()
   244  	defer c.reqMtx.Unlock()
   245  	doExpire := func() {
   246  		// Delete the response handler, and call the provided expire function if
   247  		// (*wsLink).respHandler has not already retrieved the handler function
   248  		// for execution.
   249  		if c.expire(id) {
   250  			expire()
   251  		}
   252  	}
   253  	c.respHandlers[id] = &responseHandler{
   254  		f:      respHandler,
   255  		expire: time.AfterFunc(expireTime, doExpire),
   256  	}
   257  }
   258  
   259  // Request sends the message to the client and tracks the response handler. If
   260  // the response handler is called, it is guaranteed that the request Message.ID
   261  // is equal to the response Message.ID passed to the handler (see the
   262  // msgjson.Response case in handleMessage).
   263  func (c *wsLink) Request(msg *msgjson.Message, f func(conn Link, msg *msgjson.Message), expireTime time.Duration, expire func()) error {
   264  	rawMsg, err := json.Marshal(msg)
   265  	if err != nil {
   266  		log.Errorf("Failed to marshal message: %v", err)
   267  		return err
   268  	}
   269  
   270  	err = c.RequestRaw(msg.ID, rawMsg, f, expireTime, expire)
   271  	if err != nil {
   272  		log.Debugf("(*wsLink).Request(route '%s') Send error, unregistering msg ID %d handler: %v",
   273  			msg.Route, msg.ID, err)
   274  	}
   275  	return err
   276  }
   277  
   278  func (c *wsLink) RequestRaw(msgID uint64, rawMsg []byte, f func(conn Link, msg *msgjson.Message), expireTime time.Duration, expire func()) error {
   279  	// log.Tracef("Registering '%s' request ID %d (wsLink)", msg.Route, msg.ID)
   280  	c.logReq(msgID, f, expireTime, expire)
   281  	// Send errors are (1) connection is already down or (2) json marshal
   282  	// failure. Any connection write errors just cause the link to quit as the
   283  	// goroutine that actually does the write does not relay any errors back to
   284  	// the caller. The request will eventually expire when no response comes.
   285  	// This is not ideal - we may consider an error callback, or different
   286  	// Send/SendNow/QueueSend functions.
   287  	err := c.SendRaw(rawMsg)
   288  	if err != nil {
   289  		// Neither expire nor the handler should run. Stop the expire timer
   290  		// created by logReq and delete the response handler it added. The
   291  		// caller receives a non-nil error to deal with it.
   292  
   293  		c.respHandler(msgID) // drop the removed responseHandler
   294  	}
   295  	return err
   296  }
   297  
   298  // respHandler extracts the response handler for the provided request ID if it
   299  // exists, else nil. If the handler exists, it will be deleted from the map and
   300  // the expire Timer stopped.
   301  func (c *wsLink) respHandler(id uint64) *responseHandler {
   302  	c.reqMtx.Lock()
   303  	defer c.reqMtx.Unlock()
   304  	cb, ok := c.respHandlers[id]
   305  	if ok {
   306  		// Stop the expiration Timer. If the Timer fired after respHandler was
   307  		// called, but we found the response handler in the map, wsLink.expire
   308  		// is waiting for the reqMtx lock and will return false, thus preventing
   309  		// the registered expire func from executing.
   310  		cb.expire.Stop()
   311  		delete(c.respHandlers, id)
   312  	}
   313  	return cb
   314  }