github.com/polygon-io/client-go@v1.16.4/websocket/polygon.go (about)

     1  package polygonws
     2  
     3  import (
     4  	"encoding/json"
     5  	"errors"
     6  	"fmt"
     7  	"net/url"
     8  	"strings"
     9  	"sync"
    10  	"time"
    11  
    12  	"github.com/cenkalti/backoff/v4"
    13  	"github.com/gorilla/websocket"
    14  	"github.com/polygon-io/client-go/websocket/models"
    15  	"golang.org/x/exp/maps"
    16  	"golang.org/x/exp/slices"
    17  	"gopkg.in/tomb.v2"
    18  )
    19  
    20  const (
    21  	writeWait      = 5 * time.Second
    22  	pongWait       = 10 * time.Second
    23  	pingPeriod     = (pongWait * 9) / 10
    24  	maxMessageSize = 1000000 // 1MB
    25  )
    26  
// Client defines a client to the Polygon WebSocket API.
//
// A Client is created with New, connected with Connect and shut down with
// Close. Each connection runs a read thread and a write thread (tracked by
// rwtomb). A single process thread (tracked by ptomb) decodes incoming frames
// and publishes them on the output channel.
type Client struct {
	apiKey string // sent as the params of the auth message on every (re)connect
	feed   Feed   // base feed URL; also selects the model used for "AM" events
	market Market // appended to the feed path; gates which topics may be (un)subscribed
	url    string // dial target built in New from feed and market

	shouldClose bool            // set on final shutdown; makes reconnect return immediately
	backoff     backoff.BackOff // retry policy used by Connect and reconnect

	mtx    sync.Mutex // serializes Connect, Subscribe, Unsubscribe, Close and reconnect
	rwtomb tomb.Tomb  // lifetime of the per-connection read/write threads
	ptomb  tomb.Tomb  // lifetime of the process thread; not restarted on reconnect

	conn   *websocket.Conn      // current connection; nil when not connected
	rQueue chan json.RawMessage // frames handed from the read thread to the process thread
	wQueue chan json.RawMessage // outbound messages for the write thread; recreated on each connect
	subs   subscriptions        // current subscriptions, replayed after every (re)connect

	rawData              bool       // forward undecoded JSON instead of typed models
	bypassRawDataRouting bool       // with rawData, forward whole frames without splitting them
	output               chan any   // delivered messages; closed on final shutdown
	err                  chan error // fatal errors (e.g. auth failure, reconnect exhausted)

	log Logger
}
    53  
    54  // New creates a client for the Polygon WebSocket API.
    55  func New(config Config) (*Client, error) {
    56  	if err := config.validate(); err != nil {
    57  		return nil, fmt.Errorf("invalid client options: %w", err)
    58  	}
    59  
    60  	c := &Client{
    61  		apiKey:               config.APIKey,
    62  		feed:                 config.Feed,
    63  		market:               config.Market,
    64  		backoff:              backoff.NewExponentialBackOff(),
    65  		rQueue:               make(chan json.RawMessage, 10000),
    66  		wQueue:               make(chan json.RawMessage, 1000),
    67  		subs:                 make(subscriptions),
    68  		rawData:              config.RawData,
    69  		bypassRawDataRouting: config.BypassRawDataRouting,
    70  		output:               make(chan any, 100000),
    71  		err:                  make(chan error),
    72  		log:                  config.Log,
    73  	}
    74  
    75  	uri, err := url.Parse(string(c.feed))
    76  	if err != nil {
    77  		return nil, fmt.Errorf("invalid data feed format: %v", err)
    78  	}
    79  	uri.Path = strings.Join([]string{uri.Path, string(c.market)}, "/")
    80  	c.url = uri.String()
    81  
    82  	if config.MaxRetries != nil {
    83  		c.backoff = backoff.WithMaxRetries(c.backoff, *config.MaxRetries)
    84  	}
    85  
    86  	return c, nil
    87  }
    88  
    89  // Connect dials the WebSocket server and starts the read/write and process threads.
    90  // If any subscription messages are pushed before connecting, it will also send those
    91  // to the server.
    92  func (c *Client) Connect() error {
    93  	c.mtx.Lock()
    94  	defer c.mtx.Unlock()
    95  
    96  	if c.conn != nil {
    97  		return nil
    98  	}
    99  
   100  	notify := func(err error, _ time.Duration) {
   101  		c.log.Errorf(err.Error())
   102  	}
   103  	if err := backoff.RetryNotify(c.connect(false), c.backoff, notify); err != nil {
   104  		return err
   105  	}
   106  
   107  	return nil
   108  }
   109  
   110  // Subscribe sends a subscription message for a topic and set of tickers. If no
   111  // tickers are passed, it will subscribe to all tickers for a given topic.
   112  func (c *Client) Subscribe(topic Topic, tickers ...string) error {
   113  	c.mtx.Lock()
   114  	defer c.mtx.Unlock()
   115  
   116  	if !c.market.supports(topic) {
   117  		return fmt.Errorf("topic '%v' not supported for market '%v'", topic.prefix(), c.market)
   118  	}
   119  
   120  	if len(tickers) == 0 || slices.Contains(tickers, "*") {
   121  		tickers = []string{"*"}
   122  	}
   123  
   124  	subscribe, err := getSub(models.Subscribe, topic, tickers...)
   125  	if err != nil {
   126  		return err
   127  	}
   128  
   129  	c.subs.add(topic, tickers...)
   130  	c.wQueue <- subscribe
   131  
   132  	return nil
   133  }
   134  
   135  // Unsubscribe sends a message to unsubscribe from a topic and set of tickers. If no
   136  // tickers are passed, it will unsubscribe from all tickers for a given topic.
   137  func (c *Client) Unsubscribe(topic Topic, tickers ...string) error {
   138  	c.mtx.Lock()
   139  	defer c.mtx.Unlock()
   140  
   141  	if !c.market.supports(topic) {
   142  		return fmt.Errorf("topic '%v' not supported for market '%v'", topic.prefix(), c.market)
   143  	}
   144  
   145  	if len(tickers) == 0 || slices.Contains(tickers, "*") {
   146  		tickers = maps.Keys(c.subs[topic])
   147  	}
   148  
   149  	unsubscribe, err := getSub(models.Unsubscribe, topic, tickers...)
   150  	if err != nil {
   151  		return err
   152  	}
   153  
   154  	c.subs.delete(topic, tickers...)
   155  	c.wQueue <- unsubscribe
   156  
   157  	return nil
   158  }
   159  
   160  // Output returns the output queue.
   161  func (c *Client) Output() <-chan any {
   162  	return c.output
   163  }
   164  
   165  // Error returns an error channel. If the client hits a fatal error (e.g. auth failed),
   166  // it will push an error to this channel and close the connection.
   167  func (c *Client) Error() <-chan error {
   168  	return c.err
   169  }
   170  
   171  // Close attempts to gracefully close the connection to the server.
   172  func (c *Client) Close() {
   173  	c.mtx.Lock()
   174  	defer c.mtx.Unlock()
   175  	c.close(false)
   176  }
   177  
   178  func newConn(uri string) (*websocket.Conn, error) {
   179  	conn, res, err := websocket.DefaultDialer.Dial(uri, nil)
   180  	if err != nil {
   181  		return nil, fmt.Errorf("failed to dial server: %w", err)
   182  	} else if res.StatusCode != 101 {
   183  		return nil, errors.New("server failed to switch protocols")
   184  	}
   185  
   186  	conn.SetReadLimit(maxMessageSize)
   187  	if err := conn.SetReadDeadline(time.Now().Add(pongWait)); err != nil {
   188  		return nil, fmt.Errorf("failed to set read deadline: %w", err)
   189  	}
   190  	conn.SetPongHandler(func(string) error {
   191  		return conn.SetReadDeadline(time.Now().Add(pongWait))
   192  	})
   193  
   194  	return conn, nil
   195  }
   196  
   197  func (c *Client) connect(reconnect bool) func() error {
   198  	return func() error {
   199  		// dial the server
   200  		conn, err := newConn(c.url)
   201  		if err != nil {
   202  			return err
   203  		}
   204  		c.conn = conn
   205  
   206  		// reset write queue and push auth message
   207  		c.wQueue = make(chan json.RawMessage, 1000)
   208  		auth, err := json.Marshal(models.ControlMessage{
   209  			Action: models.Auth,
   210  			Params: c.apiKey,
   211  		})
   212  		if err != nil {
   213  			return fmt.Errorf("failed to marshal auth message: %w", err)
   214  		}
   215  		c.wQueue <- auth
   216  
   217  		// push subscription messages
   218  		subs := c.subs.get()
   219  		for _, msg := range subs {
   220  			c.wQueue <- msg
   221  		}
   222  
   223  		// start the threads
   224  		c.rwtomb = tomb.Tomb{}
   225  		c.rwtomb.Go(c.read)
   226  		c.rwtomb.Go(c.write)
   227  		if !reconnect {
   228  			c.ptomb = tomb.Tomb{}
   229  			c.ptomb.Go(c.process)
   230  		}
   231  
   232  		return nil
   233  	}
   234  }
   235  
   236  func (c *Client) reconnect() {
   237  	c.mtx.Lock()
   238  	defer c.mtx.Unlock()
   239  
   240  	if c.shouldClose {
   241  		return
   242  	}
   243  
   244  	c.log.Debugf("unexpected disconnect: reconnecting")
   245  	c.close(true)
   246  
   247  	notify := func(err error, _ time.Duration) {
   248  		c.log.Errorf(err.Error())
   249  	}
   250  	err := backoff.RetryNotify(c.connect(true), c.backoff, notify)
   251  	if err != nil {
   252  		err = fmt.Errorf("error reconnecting: %w: closing connection", err)
   253  		c.log.Errorf(err.Error())
   254  		c.close(false)
   255  		c.err <- err
   256  	}
   257  }
   258  
   259  func (c *Client) closeOutput() {
   260  	close(c.output)
   261  	c.log.Debugf("output channel closed")
   262  }
   263  
   264  func (c *Client) close(reconnect bool) {
   265  	if c.conn == nil {
   266  		return
   267  	}
   268  
   269  	c.rwtomb.Kill(nil)
   270  	if err := c.rwtomb.Wait(); err != nil {
   271  		c.log.Errorf("r/w threads closed: %v", err)
   272  	}
   273  
   274  	if !reconnect {
   275  		c.ptomb.Kill(nil)
   276  		if err := c.ptomb.Wait(); err != nil {
   277  			c.log.Errorf("process thread closed: %v", err)
   278  		}
   279  		c.shouldClose = true
   280  		c.closeOutput()
   281  	}
   282  
   283  	if c.conn != nil {
   284  		_ = c.conn.Close()
   285  		c.conn = nil
   286  	}
   287  }
   288  
   289  func (c *Client) read() error {
   290  	defer func() {
   291  		c.log.Debugf("read thread closed")
   292  	}()
   293  
   294  	for {
   295  		select {
   296  		case <-c.rwtomb.Dying():
   297  			return nil
   298  		default:
   299  			_, msg, err := c.conn.ReadMessage()
   300  			if err != nil {
   301  				if websocket.IsCloseError(err, websocket.CloseNormalClosure) {
   302  					return nil
   303  				} else if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure) {
   304  					return fmt.Errorf("connection closed unexpectedly: %w", err)
   305  				}
   306  				return fmt.Errorf("failed to read message: %w", err)
   307  			}
   308  			c.rQueue <- msg
   309  		}
   310  	}
   311  }
   312  
   313  func (c *Client) write() error {
   314  	ticker := time.NewTicker(pingPeriod)
   315  	defer func() {
   316  		c.log.Debugf("write thread closed")
   317  		ticker.Stop()
   318  		go c.reconnect()
   319  	}()
   320  
   321  	for {
   322  		select {
   323  		case <-c.rwtomb.Dying():
   324  			if err := c.conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(writeWait)); err != nil {
   325  				return fmt.Errorf("failed to gracefully close: %w", err)
   326  			}
   327  			return nil
   328  		case <-ticker.C:
   329  			if err := c.conn.WriteControl(websocket.PingMessage, nil, time.Now().Add(writeWait)); err != nil {
   330  				return fmt.Errorf("failed to send ping message: %w", err)
   331  			}
   332  		case msg := <-c.wQueue:
   333  			if err := c.conn.SetWriteDeadline(time.Now().Add(writeWait)); err != nil {
   334  				return fmt.Errorf("failed to set write deadline: %w", err)
   335  			}
   336  			if err := c.conn.WriteMessage(websocket.TextMessage, msg); err != nil {
   337  				return fmt.Errorf("failed to send message: %w", err)
   338  			}
   339  		}
   340  	}
   341  }
   342  
   343  func (c *Client) process() (err error) {
   344  	defer func() {
   345  		// this client should close if it hits a fatal error (e.g. auth failed)
   346  		c.log.Debugf("process thread closed")
   347  		if err != nil {
   348  			go c.Close()
   349  			c.err <- err
   350  		}
   351  	}()
   352  
   353  	for {
   354  		select {
   355  		case <-c.ptomb.Dying():
   356  			return nil
   357  		case data := <-c.rQueue:
   358  			if c.rawData && c.bypassRawDataRouting {
   359  				c.output <- data // push raw bytes to output channel
   360  				continue
   361  			}
   362  
   363  			var msgs []json.RawMessage
   364  			if err := json.Unmarshal(data, &msgs); err != nil {
   365  				c.log.Errorf("failed to process raw messages: %v", err)
   366  				continue
   367  			}
   368  			if err := c.route(msgs); err != nil {
   369  				return err
   370  			}
   371  		}
   372  	}
   373  }
   374  
   375  func (c *Client) route(msgs []json.RawMessage) error {
   376  	for _, msg := range msgs {
   377  		var ev models.EventType
   378  		err := json.Unmarshal(msg, &ev)
   379  		if err != nil {
   380  			c.log.Errorf("failed to process message: %v", err)
   381  			continue
   382  		}
   383  
   384  		switch ev.EventType {
   385  		case "status":
   386  			if err := c.handleStatus(msg); err != nil {
   387  				return err
   388  			}
   389  		default:
   390  			c.handleData(ev.EventType, msg)
   391  		}
   392  	}
   393  
   394  	return nil
   395  }
   396  
   397  func (c *Client) handleStatus(msg json.RawMessage) error {
   398  	var cm models.ControlMessage
   399  	if err := json.Unmarshal(msg, &cm); err != nil {
   400  		c.log.Errorf("failed to unmarshal message: %v", err)
   401  		return nil
   402  	}
   403  
   404  	switch cm.Status {
   405  	case "connected":
   406  		c.log.Debugf("connection successful")
   407  	case "auth_success":
   408  		c.log.Debugf("authentication successful")
   409  	case "auth_failed":
   410  		// this is a fatal error so need to close the connection
   411  		return errors.New("authentication failed: closing connection")
   412  	case "success":
   413  		c.log.Debugf("received a successful status message: %v", sanitize(cm.Message))
   414  	case "error":
   415  		c.log.Errorf("received an error status message: %v", sanitize(cm.Message))
   416  	default:
   417  		c.log.Infof("unknown status message '%v': %v", sanitize(cm.Status), sanitize(cm.Message))
   418  	}
   419  
   420  	return nil
   421  }
   422  
   423  func (c *Client) handleData(eventType string, msg json.RawMessage) {
   424  	if c.rawData {
   425  		c.output <- msg // push raw JSON to output channel
   426  		return
   427  	}
   428  
   429  	switch eventType {
   430  	case "A":
   431  		var out models.EquityAgg
   432  		if err := json.Unmarshal(msg, &out); err != nil {
   433  			c.log.Errorf("failed to unmarshal message: %v", err)
   434  			return
   435  		}
   436  		c.output <- out
   437  	case "AM":
   438  		switch c.market {
   439  		case Forex, Crypto:
   440  			if c.feed == LaunchpadFeed {
   441  				var out models.EquityAgg
   442  				if err := json.Unmarshal(msg, &out); err != nil {
   443  					c.log.Errorf("failed to unmarshal message: %v", err)
   444  					return
   445  				}
   446  				c.output <- out
   447  			} else {
   448  				var out models.CurrencyAgg
   449  				if err := json.Unmarshal(msg, &out); err != nil {
   450  					c.log.Errorf("failed to unmarshal message: %v", err)
   451  					return
   452  				}
   453  				c.output <- out
   454  			}
   455  
   456  		default:
   457  			var out models.EquityAgg
   458  			if err := json.Unmarshal(msg, &out); err != nil {
   459  				c.log.Errorf("failed to unmarshal message: %v", err)
   460  				return
   461  			}
   462  			c.output <- out
   463  		}
   464  	case "CA", "CAS":
   465  		var out models.CurrencyAgg
   466  		if err := json.Unmarshal(msg, &out); err != nil {
   467  			c.log.Errorf("failed to unmarshal message: %v", err)
   468  			return
   469  		}
   470  		c.output <- out
   471  	case "XA", "XAS":
   472  		var out models.CurrencyAgg
   473  		if err := json.Unmarshal(msg, &out); err != nil {
   474  			c.log.Errorf("failed to unmarshal message: %v", err)
   475  			return
   476  		}
   477  		c.output <- out
   478  	case "T":
   479  		var out models.EquityTrade
   480  		if err := json.Unmarshal(msg, &out); err != nil {
   481  			c.log.Errorf("failed to unmarshal message: %v", err)
   482  			return
   483  		}
   484  		c.output <- out
   485  	case "XT":
   486  		var out models.CryptoTrade
   487  		if err := json.Unmarshal(msg, &out); err != nil {
   488  			c.log.Errorf("failed to unmarshal message: %v", err)
   489  			return
   490  		}
   491  		c.output <- out
   492  	case "Q":
   493  		var out models.EquityQuote
   494  		if err := json.Unmarshal(msg, &out); err != nil {
   495  			c.log.Errorf("failed to unmarshal message: %v", err)
   496  			return
   497  		}
   498  		c.output <- out
   499  	case "C":
   500  		var out models.ForexQuote
   501  		if err := json.Unmarshal(msg, &out); err != nil {
   502  			c.log.Errorf("failed to unmarshal message: %v", err)
   503  			return
   504  		}
   505  		c.output <- out
   506  	case "XQ":
   507  		var out models.CryptoQuote
   508  		if err := json.Unmarshal(msg, &out); err != nil {
   509  			c.log.Errorf("failed to unmarshal message: %v", err)
   510  			return
   511  		}
   512  		c.output <- out
   513  	case "NOI":
   514  		var out models.Imbalance
   515  		if err := json.Unmarshal(msg, &out); err != nil {
   516  			c.log.Errorf("failed to unmarshal message: %v", err)
   517  			return
   518  		}
   519  		c.output <- out
   520  	case "LULD":
   521  		var out models.LimitUpLimitDown
   522  		if err := json.Unmarshal(msg, &out); err != nil {
   523  			c.log.Errorf("failed to unmarshal message: %v", err)
   524  			return
   525  		}
   526  		c.output <- out
   527  	case "XL2":
   528  		var out models.Level2Book
   529  		if err := json.Unmarshal(msg, &out); err != nil {
   530  			c.log.Errorf("failed to unmarshal message: %v", err)
   531  			return
   532  		}
   533  		c.output <- out
   534  	case "V":
   535  		var out models.IndexValue
   536  		if err := json.Unmarshal(msg, &out); err != nil {
   537  			c.log.Errorf("failed to unmarshal message: %v", err)
   538  			return
   539  		}
   540  		c.output <- out
   541  	case "LV":
   542  		var out models.LaunchpadValue
   543  		if err := json.Unmarshal(msg, &out); err != nil {
   544  			c.log.Errorf("failed to unmarshal message: %v", err)
   545  			return
   546  		}
   547  		c.output <- out
   548  	case "FMV":
   549  		var out models.FairMarketValue
   550  		if err := json.Unmarshal(msg, &out); err != nil {
   551  			c.log.Errorf("failed to unmarshal message: %v", err)
   552  		}
   553  		c.output <- out
   554  
   555  	default:
   556  		c.log.Infof("unknown message type '%s'", sanitize(eventType))
   557  	}
   558  }
   559  
   560  func sanitize(s string) string {
   561  	return strings.Replace(s, "\n", "", -1)
   562  }