github.com/bitfinexcom/bitfinex-api-go@v0.0.0-20210608095005-9e0b26f200fb/v2/websocket/subscriptions.go (about)

     1  package websocket
     2  
     3  import (
     4  	"fmt"
     5  	"github.com/op/go-logging"
     6  	"strings"
     7  	"sync"
     8  	"time"
     9  )
    10  
    // SubscriptionRequest is the JSON payload sent to subscribe to a channel.
// Fields tagged omitempty are left out of the encoded JSON when unset, so one
// type covers both the authenticated and the unauthenticated request shapes.
type SubscriptionRequest struct {
	SubID string `json:"subId"`
	Event string `json:"event"`

	// authenticated
	APIKey      string   `json:"apiKey,omitempty"`
	AuthSig     string   `json:"authSig,omitempty"`
	AuthPayload string   `json:"authPayload,omitempty"`
	AuthNonce   string   `json:"authNonce,omitempty"`
	Filter      []string `json:"filter,omitempty"`
	DMS         int      `json:"dms,omitempty"` // dead man switch

	// unauthenticated
	Channel   string `json:"channel,omitempty"`
	Symbol    string `json:"symbol,omitempty"`
	Precision string `json:"prec,omitempty"`
	Frequency string `json:"freq,omitempty"`
	Key       string `json:"key,omitempty"`
	Len       string `json:"len,omitempty"`
	Pair      string `json:"pair,omitempty"`
}

// MaxChannels is a channel-count limit of 25.
// NOTE(review): it is not used in this file; presumably callers enforce it as a
// per-socket subscription cap — confirm.
const MaxChannels = 25

// String returns a short, space-separated description of the request.
// It returns an empty string when Channel is empty.
//
// The description is one of:
//   - "channel symbol precision frequency" when symbol, precision and frequency are all set
//   - "channel symbol" when only the symbol is set
//   - "channel key" for key-addressed requests without a symbol
//
// The 4-field form is checked first. Previously a `Key == ""` branch ran before
// it, so requests without a Key (e.g. book requests with prec/freq) could never
// show their precision and frequency.
func (s *SubscriptionRequest) String() string {
	switch {
	case s.Channel == "":
		return ""
	case s.Symbol != "" && s.Precision != "" && s.Frequency != "":
		return s.Channel + " " + s.Symbol + " " + s.Precision + " " + s.Frequency
	case s.Symbol != "":
		return s.Channel + " " + s.Symbol
	case s.Key != "":
		return s.Channel + " " + s.Key
	default:
		return ""
	}
}
    47  
    48  type HeartbeatDisconnect struct {
    49  	Subscription *subscription
    50  	Error        error
    51  }
    52  
    // UnsubscribeRequest is the JSON payload used to unsubscribe from the
// channel identified by ChanID. Field order fixes the encoded key order.
type UnsubscribeRequest struct {
	Event  string `json:"event"`  // event name
	ChanID int64  `json:"chanId"` // channel to leave
}
    57  
    58  type subscription struct {
    59  	ChanID     int64
    60  	SocketId   SocketId
    61  	pending    bool
    62  	Public     bool
    63  
    64  	Request    *SubscriptionRequest
    65  
    66  	hbDeadline time.Time
    67  }
    68  
    69  func isPublic(request *SubscriptionRequest) bool {
    70  	switch request.Channel {
    71  	case ChanBook:
    72  		return true
    73  	case ChanCandles:
    74  		return true
    75  	case ChanTicker:
    76  		return true
    77  	case ChanTrades:
    78  		return true
    79  	case ChanStatus:
    80  		return true
    81  	}
    82  	return false
    83  }
    84  
    85  func newSubscription(socketId SocketId, request *SubscriptionRequest) *subscription {
    86  	return &subscription{
    87  		ChanID:  -1,
    88  		SocketId: socketId,
    89  		Request: request,
    90  		pending: true,
    91  		Public:  isPublic(request),
    92  	}
    93  }
    94  
    95  func (s subscription) SubID() string {
    96  	return s.Request.SubID
    97  }
    98  
    99  func (s subscription) Pending() bool {
   100  	return s.pending
   101  }
   102  
   103  func newSubscriptions(heartbeatTimeout time.Duration, log *logging.Logger) *subscriptions {
   104  	subs := &subscriptions{
   105  		subsBySubID:  make(map[string]*subscription),
   106  		subsByChanID: make(map[int64]*subscription),
   107  		subsBySocketId: make(map[SocketId]SubscriptionSet),
   108  		hbTimeout:    heartbeatTimeout,
   109  		hbShutdown:   make(chan struct{}),
   110  		hbDisconnect: make(chan HeartbeatDisconnect),
   111  		hbSleep:      heartbeatTimeout / time.Duration(4),
   112  		log:          log,
   113  		lock:         &sync.RWMutex{},
   114  	}
   115  	go subs.control()
   116  	return subs
   117  }
   118  
   // heartbeat pairs a channel ID with an embedded timestamp pointer.
// NOTE(review): not referenced anywhere in this file; confirm it is still
// needed before removing it.
// nolint
type heartbeat struct {
	ChanID int64
	*time.Time
}
   124  
   125  type subscriptions struct {
   126  	lock         *sync.RWMutex
   127  	log          *logging.Logger
   128  
   129  	subsBySocketId map[SocketId]SubscriptionSet // subscripts map indexed by socket id
   130  	subsBySubID  map[string]*subscription // subscription map indexed by subscription ID
   131  	subsByChanID map[int64]*subscription  // subscription map indexed by channel ID
   132  
   133  	hbActive     bool
   134  	hbDisconnect chan HeartbeatDisconnect // disconnect parent due to heartbeat timeout
   135  	hbTimeout    time.Duration
   136  	hbSleep      time.Duration
   137  	hbShutdown   chan struct{}
   138  }
   139  
   140  // SubscriptionSet is a typed version of an array of subscription pointers, intended to meet the sortable interface.
   141  // We need to sort Reset()'s return values for tests with more than 1 subscription (range map order is undefined)
   142  type SubscriptionSet []*subscription
   143  
   144  func (s SubscriptionSet) Len() int {
   145  	return len(s)
   146  }
   147  func (s SubscriptionSet) Less(i, j int) bool {
   148  	return strings.Compare(s[i].SubID(), s[j].SubID()) < 0
   149  }
   150  func (s SubscriptionSet) Swap(i, j int) {
   151  	s[i], s[j] = s[j], s[i]
   152  }
   153  func (s SubscriptionSet) RemoveByChannelId(chanId int64) SubscriptionSet {
   154  	rIndex := -1
   155  	for i, sub := range s {
   156  		if sub.ChanID == chanId {
   157  			rIndex = i
   158  			break
   159  		}
   160  	}
   161  	if rIndex >= 0 {
   162  		return append(s[:rIndex], s[rIndex+1:]...)
   163  	}
   164  	return s
   165  }
   166  
   167  func (s SubscriptionSet) RemoveBySubscriptionId(subID string) SubscriptionSet {
   168  	rIndex := -1
   169  	for i, sub := range s {
   170  		if sub.SubID() == subID {
   171  			rIndex = i
   172  			break
   173  		}
   174  	}
   175  	if rIndex >= 0 {
   176  		return append(s[:rIndex], s[rIndex+1:]...)
   177  	}
   178  	return s
   179  }
   180  
   181  func (s *subscriptions) heartbeat(chanID int64) {
   182  	s.lock.Lock()
   183  	defer s.lock.Unlock()
   184  	if sub, ok := s.subsByChanID[chanID]; ok {
   185  		sub.hbDeadline = time.Now().Add(s.hbTimeout)
   186  	}
   187  }
   188  
   189  func (s *subscriptions) sweep(exp time.Time) {
   190  	s.lock.RLock()
   191  	if !s.hbActive {
   192  		s.lock.RUnlock()
   193  		return
   194  	}
   195  	disconnects := make([]HeartbeatDisconnect, 0)
   196  	for _, sub := range s.subsByChanID {
   197  		if exp.After(sub.hbDeadline) {
   198  			s.hbActive = false
   199  			hbErr := HeartbeatDisconnect{
   200  				Subscription: sub,
   201  				Error: fmt.Errorf("heartbeat disconnect on channel %d expired at %s (%s timeout)", sub.ChanID, sub.hbDeadline, s.hbTimeout),
   202  			}
   203  			disconnects = append(disconnects, hbErr)
   204  		}
   205  	}
   206  	s.lock.RUnlock()
   207  	for _, dis := range disconnects {
   208  		s.hbDisconnect <- dis
   209  	}
   210  }
   211  
   212  func (s *subscriptions) control() {
   213  	for {
   214  		select {
   215  		case <-s.hbShutdown:
   216  			return
   217  		default:
   218  		}
   219  		s.sweep(time.Now())
   220  		time.Sleep(s.hbSleep)
   221  	}
   222  }
   223  
   224  // Close is terminal. Do not call heartbeat after close.
   225  func (s *subscriptions) Close() {
   226  	s.ResetAll()
   227  	close(s.hbShutdown)
   228  }
   229  
   230  
   231  // Reset clears all subscriptions assigned to the given socket ID, and returns
   232  // a slice of the existing subscriptions prior to reset
   233  func (s *subscriptions) ResetSocketSubscriptions(socketId SocketId) []*subscription {
   234  	var retSubs []*subscription
   235  	s.lock.Lock()
   236  	if set, ok := s.subsBySocketId[socketId]; ok {
   237  		for _, sub := range set {
   238  			retSubs = append(retSubs, sub)
   239  			// remove from chanId array
   240  			delete(s.subsByChanID, sub.ChanID)
   241  			// remove from subId array
   242  			delete(s.subsBySubID, sub.SubID())
   243  		}
   244  	}
   245  	s.subsBySocketId[socketId] = make(SubscriptionSet, 0)
   246  	s.lock.Unlock()
   247  	return retSubs
   248  }
   249  
   250  // Removes all tracked subscriptions
   251  func (s *subscriptions) ResetAll() {
   252  	s.lock.Lock()
   253  	s.subsBySubID = make(map[string]*subscription)
   254  	s.subsByChanID = make(map[int64]*subscription)
   255  	s.subsBySocketId = make(map[SocketId]SubscriptionSet)
   256  	s.lock.Unlock()
   257  }
   258  
   259  // ListenDisconnect returns an error channel which receives a message when a heartbeat has expired a channel.
   260  func (s *subscriptions) ListenDisconnect() <-chan HeartbeatDisconnect {
   261  	return s.hbDisconnect
   262  }
   263  
   264  func (s *subscriptions) add(socketId SocketId, sub *SubscriptionRequest) *subscription {
   265  	s.lock.Lock()
   266  	defer s.lock.Unlock()
   267  	subscription := newSubscription(socketId, sub)
   268  	s.subsBySubID[sub.SubID] = subscription
   269  	if _, ok := s.subsBySocketId[socketId]; !ok {
   270  		s.subsBySocketId[socketId] = make(SubscriptionSet, 0)
   271  	}
   272  	s.subsBySocketId[socketId] = append(s.subsBySocketId[socketId], subscription)
   273  	return subscription
   274  }
   275  
   276  func (s *subscriptions) removeByChannelID(chanID int64) error {
   277  	s.lock.Lock()
   278  	defer s.lock.Unlock()
   279  	// remove from socketId map
   280  	sub, ok := s.subsByChanID[chanID]
   281  	if !ok {
   282  		return fmt.Errorf("could not find channel ID %d", chanID)
   283  	}
   284  	delete(s.subsByChanID, chanID)
   285  	delete(s.subsBySubID, sub.SubID())
   286  	// remove from socket map
   287  	if _, ok := s.subsBySocketId[sub.SocketId]; ok {
   288  		s.subsBySocketId[sub.SocketId] = s.subsBySocketId[sub.SocketId].RemoveByChannelId(chanID)
   289  	}
   290  	return nil
   291  }
   292  
   293  // nolint:megacheck
   294  func (s *subscriptions) removeBySubscriptionID(subID string) error {
   295  	s.lock.Lock()
   296  	defer s.lock.Unlock()
   297  	sub, ok := s.subsBySubID[subID]
   298  	if !ok {
   299  		return fmt.Errorf("could not find subscription ID %s to remove", subID)
   300  	}
   301  	// exists, remove both indices
   302  	delete(s.subsBySubID, subID)
   303  	delete(s.subsByChanID, sub.ChanID)
   304  	// remove from socket map
   305  	if _, ok := s.subsBySocketId[sub.SocketId]; ok {
   306  		s.subsBySocketId[sub.SocketId] = s.subsBySocketId[sub.SocketId].RemoveBySubscriptionId(subID)
   307  	}
   308  	return nil
   309  }
   310  
   311  func (s *subscriptions) activate(subID string, chanID int64) error {
   312  	s.lock.Lock()
   313  	defer s.lock.Unlock()
   314  
   315  	if sub, ok := s.subsBySubID[subID]; ok {
   316  		if chanID != 0 {
   317  			s.log.Infof("activated subscription %s %s for channel %d", sub.Request.Channel, sub.Request.Symbol, chanID)
   318  		}
   319  		sub.pending = false
   320  		sub.ChanID = chanID
   321  		sub.hbDeadline = time.Now().Add(s.hbTimeout)
   322  		s.subsByChanID[chanID] = sub
   323  		s.hbActive = true
   324  		return nil
   325  	}
   326  
   327  	return fmt.Errorf("could not find subscription ID %s to activate", subID)
   328  }
   329  
   330  func (s *subscriptions) lookupBySocketChannelID(chanID int64, sId SocketId) (*subscription, error) {
   331  	s.lock.RLock()
   332  	defer s.lock.RUnlock()
   333  	if subs, ok := s.subsBySocketId[sId]; ok {
   334  		for _, s := range subs {
   335  			if s.ChanID == chanID {
   336  				return s, nil
   337  			}
   338  		}
   339  	}
   340  	return nil, fmt.Errorf("could not find subscription for channel ID %d and socket sId %d", chanID, sId)
   341  }
   342  
   343  func (s *subscriptions) lookupBySubscriptionID(subID string) (*subscription, error) {
   344  	s.lock.RLock()
   345  	defer s.lock.RUnlock()
   346  	if sub, ok := s.subsBySubID[subID]; ok {
   347  		return sub, nil
   348  	}
   349  	return nil, fmt.Errorf("could not find subscription ID %s", subID)
   350  }
   351  
   352  func (s *subscriptions) lookupBySocketId(socketId SocketId) (*SubscriptionSet, error) {
   353  	s.lock.RLock()
   354  	defer s.lock.RUnlock()
   355  	if set, ok := s.subsBySocketId[socketId]; ok {
   356  		return &set, nil
   357  	}
   358  	return nil, fmt.Errorf("could not find subscription with socketId %d", socketId)
   359  }