github.com/bitfinexcom/bitfinex-api-go@v0.0.0-20210608095005-9e0b26f200fb/v2/websocket/subscriptions.go (about) 1 package websocket 2 3 import ( 4 "fmt" 5 "github.com/op/go-logging" 6 "strings" 7 "sync" 8 "time" 9 ) 10 11 type SubscriptionRequest struct { 12 SubID string `json:"subId"` 13 Event string `json:"event"` 14 15 // authenticated 16 APIKey string `json:"apiKey,omitempty"` 17 AuthSig string `json:"authSig,omitempty"` 18 AuthPayload string `json:"authPayload,omitempty"` 19 AuthNonce string `json:"authNonce,omitempty"` 20 Filter []string `json:"filter,omitempty"` 21 DMS int `json:"dms,omitempty"` // dead man switch 22 23 // unauthenticated 24 Channel string `json:"channel,omitempty"` 25 Symbol string `json:"symbol,omitempty"` 26 Precision string `json:"prec,omitempty"` 27 Frequency string `json:"freq,omitempty"` 28 Key string `json:"key,omitempty"` 29 Len string `json:"len,omitempty"` 30 Pair string `json:"pair,omitempty"` 31 } 32 33 const MaxChannels = 25 34 35 func (s *SubscriptionRequest) String() string { 36 if s.Key == "" && s.Channel != "" && s.Symbol != "" { 37 return fmt.Sprintf("%s %s", s.Channel, s.Symbol) 38 } 39 if s.Channel != "" && s.Symbol != "" && s.Precision != "" && s.Frequency != "" { 40 return fmt.Sprintf("%s %s %s %s", s.Channel, s.Symbol, s.Precision, s.Frequency) 41 } 42 if s.Channel != "" && s.Symbol != "" { 43 return fmt.Sprintf("%s %s", s.Channel, s.Symbol) 44 } 45 return "" 46 } 47 48 type HeartbeatDisconnect struct { 49 Subscription *subscription 50 Error error 51 } 52 53 type UnsubscribeRequest struct { 54 Event string `json:"event"` 55 ChanID int64 `json:"chanId"` 56 } 57 58 type subscription struct { 59 ChanID int64 60 SocketId SocketId 61 pending bool 62 Public bool 63 64 Request *SubscriptionRequest 65 66 hbDeadline time.Time 67 } 68 69 func isPublic(request *SubscriptionRequest) bool { 70 switch request.Channel { 71 case ChanBook: 72 return true 73 case ChanCandles: 74 return true 75 case ChanTicker: 76 return true 77 case ChanTrades: 78 return true 79 case ChanStatus: 80 return true 81 } 82 return false 83 } 84 85 func newSubscription(socketId SocketId, request *SubscriptionRequest) *subscription { 86 return &subscription{ 87 ChanID: -1, 88 SocketId: socketId, 89 Request: request, 90 pending: true, 91 Public: isPublic(request), 92 } 93 } 94 95 func (s subscription) SubID() string { 96 return s.Request.SubID 97 } 98 99 func (s subscription) Pending() bool { 100 return s.pending 101 } 102 103 func newSubscriptions(heartbeatTimeout time.Duration, log *logging.Logger) *subscriptions { 104 subs := &subscriptions{ 105 subsBySubID: make(map[string]*subscription), 106 subsByChanID: make(map[int64]*subscription), 107 subsBySocketId: make(map[SocketId]SubscriptionSet), 108 hbTimeout: heartbeatTimeout, 109 hbShutdown: make(chan struct{}), 110 hbDisconnect: make(chan HeartbeatDisconnect), 111 hbSleep: heartbeatTimeout / time.Duration(4), 112 log: log, 113 lock: &sync.RWMutex{}, 114 } 115 go subs.control() 116 return subs 117 } 118 119 // nolint 120 type heartbeat struct { 121 ChanID int64 122 *time.Time 123 } 124 125 type subscriptions struct { 126 lock *sync.RWMutex 127 log *logging.Logger 128 129 subsBySocketId map[SocketId]SubscriptionSet // subscripts map indexed by socket id 130 subsBySubID map[string]*subscription // subscription map indexed by subscription ID 131 subsByChanID map[int64]*subscription // subscription map indexed by channel ID 132 133 hbActive bool 134 hbDisconnect chan HeartbeatDisconnect // disconnect parent due to heartbeat timeout 135 hbTimeout time.Duration 136 hbSleep time.Duration 137 hbShutdown chan struct{} 138 } 139 140 // SubscriptionSet is a typed version of an array of subscription pointers, intended to meet the sortable interface. 141 // We need to sort Reset()'s return values for tests with more than 1 subscription (range map order is undefined) 142 type SubscriptionSet []*subscription 143 144 func (s SubscriptionSet) Len() int { 145 return len(s) 146 } 147 func (s SubscriptionSet) Less(i, j int) bool { 148 return strings.Compare(s[i].SubID(), s[j].SubID()) < 0 149 } 150 func (s SubscriptionSet) Swap(i, j int) { 151 s[i], s[j] = s[j], s[i] 152 } 153 func (s SubscriptionSet) RemoveByChannelId(chanId int64) SubscriptionSet { 154 rIndex := -1 155 for i, sub := range s { 156 if sub.ChanID == chanId { 157 rIndex = i 158 break 159 } 160 } 161 if rIndex >= 0 { 162 return append(s[:rIndex], s[rIndex+1:]...) 163 } 164 return s 165 } 166 167 func (s SubscriptionSet) RemoveBySubscriptionId(subID string) SubscriptionSet { 168 rIndex := -1 169 for i, sub := range s { 170 if sub.SubID() == subID { 171 rIndex = i 172 break 173 } 174 } 175 if rIndex >= 0 { 176 return append(s[:rIndex], s[rIndex+1:]...) 177 } 178 return s 179 } 180 181 func (s *subscriptions) heartbeat(chanID int64) { 182 s.lock.Lock() 183 defer s.lock.Unlock() 184 if sub, ok := s.subsByChanID[chanID]; ok { 185 sub.hbDeadline = time.Now().Add(s.hbTimeout) 186 } 187 } 188 189 func (s *subscriptions) sweep(exp time.Time) { 190 s.lock.RLock() 191 if !s.hbActive { 192 s.lock.RUnlock() 193 return 194 } 195 disconnects := make([]HeartbeatDisconnect, 0) 196 for _, sub := range s.subsByChanID { 197 if exp.After(sub.hbDeadline) { 198 s.hbActive = false 199 hbErr := HeartbeatDisconnect{ 200 Subscription: sub, 201 Error: fmt.Errorf("heartbeat disconnect on channel %d expired at %s (%s timeout)", sub.ChanID, sub.hbDeadline, s.hbTimeout), 202 } 203 disconnects = append(disconnects, hbErr) 204 } 205 } 206 s.lock.RUnlock() 207 for _, dis := range disconnects { 208 s.hbDisconnect <- dis 209 } 210 } 211 212 func (s *subscriptions) control() { 213 for { 214 select { 215 case <-s.hbShutdown: 216 return 217 default: 218 } 219 s.sweep(time.Now()) 220 time.Sleep(s.hbSleep) 221 } 222 } 223 224 // Close is terminal. Do not call heartbeat after close. 225 func (s *subscriptions) Close() { 226 s.ResetAll() 227 close(s.hbShutdown) 228 } 229 230 231 // Reset clears all subscriptions assigned to the given socket ID, and returns 232 // a slice of the existing subscriptions prior to reset 233 func (s *subscriptions) ResetSocketSubscriptions(socketId SocketId) []*subscription { 234 var retSubs []*subscription 235 s.lock.Lock() 236 if set, ok := s.subsBySocketId[socketId]; ok { 237 for _, sub := range set { 238 retSubs = append(retSubs, sub) 239 // remove from chanId array 240 delete(s.subsByChanID, sub.ChanID) 241 // remove from subId array 242 delete(s.subsBySubID, sub.SubID()) 243 } 244 } 245 s.subsBySocketId[socketId] = make(SubscriptionSet, 0) 246 s.lock.Unlock() 247 return retSubs 248 } 249 250 // Removes all tracked subscriptions 251 func (s *subscriptions) ResetAll() { 252 s.lock.Lock() 253 s.subsBySubID = make(map[string]*subscription) 254 s.subsByChanID = make(map[int64]*subscription) 255 s.subsBySocketId = make(map[SocketId]SubscriptionSet) 256 s.lock.Unlock() 257 } 258 259 // ListenDisconnect returns an error channel which receives a message when a heartbeat has expired a channel. 260 func (s *subscriptions) ListenDisconnect() <-chan HeartbeatDisconnect { 261 return s.hbDisconnect 262 } 263 264 func (s *subscriptions) add(socketId SocketId, sub *SubscriptionRequest) *subscription { 265 s.lock.Lock() 266 defer s.lock.Unlock() 267 subscription := newSubscription(socketId, sub) 268 s.subsBySubID[sub.SubID] = subscription 269 if _, ok := s.subsBySocketId[socketId]; !ok { 270 s.subsBySocketId[socketId] = make(SubscriptionSet, 0) 271 } 272 s.subsBySocketId[socketId] = append(s.subsBySocketId[socketId], subscription) 273 return subscription 274 } 275 276 func (s *subscriptions) removeByChannelID(chanID int64) error { 277 s.lock.Lock() 278 defer s.lock.Unlock() 279 // remove from socketId map 280 sub, ok := s.subsByChanID[chanID] 281 if !ok { 282 return fmt.Errorf("could not find channel ID %d", chanID) 283 } 284 delete(s.subsByChanID, chanID) 285 delete(s.subsBySubID, sub.SubID()) 286 // remove from socket map 287 if _, ok := s.subsBySocketId[sub.SocketId]; ok { 288 s.subsBySocketId[sub.SocketId] = s.subsBySocketId[sub.SocketId].RemoveByChannelId(chanID) 289 } 290 return nil 291 } 292 293 // nolint:megacheck 294 func (s *subscriptions) removeBySubscriptionID(subID string) error { 295 s.lock.Lock() 296 defer s.lock.Unlock() 297 sub, ok := s.subsBySubID[subID] 298 if !ok { 299 return fmt.Errorf("could not find subscription ID %s to remove", subID) 300 } 301 // exists, remove both indices 302 delete(s.subsBySubID, subID) 303 delete(s.subsByChanID, sub.ChanID) 304 // remove from socket map 305 if _, ok := s.subsBySocketId[sub.SocketId]; ok { 306 s.subsBySocketId[sub.SocketId] = s.subsBySocketId[sub.SocketId].RemoveBySubscriptionId(subID) 307 } 308 return nil 309 } 310 311 func (s *subscriptions) activate(subID string, chanID int64) error { 312 s.lock.Lock() 313 defer s.lock.Unlock() 314 315 if sub, ok := s.subsBySubID[subID]; ok { 316 if chanID != 0 { 317 s.log.Infof("activated subscription %s %s for channel %d", sub.Request.Channel, sub.Request.Symbol, chanID) 318 } 319 sub.pending = false 320 sub.ChanID = chanID 321 sub.hbDeadline = time.Now().Add(s.hbTimeout) 322 s.subsByChanID[chanID] = sub 323 s.hbActive = true 324 return nil 325 } 326 327 return fmt.Errorf("could not find subscription ID %s to activate", subID) 328 } 329 330 func (s *subscriptions) lookupBySocketChannelID(chanID int64, sId SocketId) (*subscription, error) { 331 s.lock.RLock() 332 defer s.lock.RUnlock() 333 if subs, ok := s.subsBySocketId[sId]; ok { 334 for _, s := range subs { 335 if s.ChanID == chanID { 336 return s, nil 337 } 338 } 339 } 340 return nil, fmt.Errorf("could not find subscription for channel ID %d and socket sId %d", chanID, sId) 341 } 342 343 func (s *subscriptions) lookupBySubscriptionID(subID string) (*subscription, error) { 344 s.lock.RLock() 345 defer s.lock.RUnlock() 346 if sub, ok := s.subsBySubID[subID]; ok { 347 return sub, nil 348 } 349 return nil, fmt.Errorf("could not find subscription ID %s", subID) 350 } 351 352 func (s *subscriptions) lookupBySocketId(socketId SocketId) (*SubscriptionSet, error) { 353 s.lock.RLock() 354 defer s.lock.RUnlock() 355 if set, ok := s.subsBySocketId[socketId]; ok { 356 return &set, nil 357 } 358 return nil, fmt.Errorf("could not find subscription with socketId %d", socketId) 359 }