code.vegaprotocol.io/vega@v0.79.0/datanode/utils/observer.go (about) 1 // Copyright (C) 2023 Gobalsky Labs Limited 2 // 3 // This program is free software: you can redistribute it and/or modify 4 // it under the terms of the GNU Affero General Public License as 5 // published by the Free Software Foundation, either version 3 of the 6 // License, or (at your option) any later version. 7 // 8 // This program is distributed in the hope that it will be useful, 9 // but WITHOUT ANY WARRANTY; without even the implied warranty of 10 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 11 // GNU Affero General Public License for more details. 12 // 13 // You should have received a copy of the GNU Affero General Public License 14 // along with this program. If not, see <http://www.gnu.org/licenses/>. 15 16 package utils 17 18 import ( 19 "context" 20 "fmt" 21 "sync" 22 "sync/atomic" 23 "time" 24 25 "code.vegaprotocol.io/vega/datanode/contextutil" 26 "code.vegaprotocol.io/vega/logging" 27 ) 28 29 type subscriber[T any] struct { 30 ch chan []T 31 } 32 33 type Observer[T any] struct { 34 subCount atomic.Int32 35 lastSubID uint64 36 name string 37 log *logging.Logger 38 subscribers map[uint64]subscriber[T] 39 mut sync.RWMutex 40 inChSize int 41 outChSize int 42 } 43 44 func NewObserver[T any](name string, log *logging.Logger, inChSize, outChSize int) Observer[T] { 45 return Observer[T]{ 46 name: name, 47 log: log, 48 subscribers: map[uint64]subscriber[T]{}, 49 inChSize: inChSize, 50 outChSize: outChSize, 51 } 52 } 53 54 func (o *Observer[T]) Subscribe(ctx context.Context, filter func(T) bool) (chan []T, uint64) { 55 o.mut.Lock() 56 defer o.mut.Unlock() 57 58 ch := make(chan []T, o.inChSize) 59 o.lastSubID++ 60 o.subscribers[o.lastSubID] = subscriber[T]{ch} 61 62 ip, _ := contextutil.RemoteIPAddrFromContext(ctx) 63 o.logDebug(ip, o.lastSubID, "new subscription") 64 return ch, o.lastSubID 65 } 66 67 func (o *Observer[T]) Unsubscribe(ctx context.Context, ref uint64) error { 68 o.mut.Lock() 69 defer o.mut.Unlock() 70 71 ip, _ := contextutil.RemoteIPAddrFromContext(ctx) 72 73 if len(o.subscribers) == 0 { 74 o.logDebug(ip, ref, "un-subscribe called but, no subscribers connected") 75 return nil 76 } 77 78 if sub, exists := o.subscribers[ref]; exists { 79 close(sub.ch) 80 delete(o.subscribers, ref) 81 return nil 82 } 83 84 return fmt.Errorf("no subscriber with id: %d", ref) 85 } 86 87 func (o *Observer[T]) GetSubscribersCount() int32 { 88 return o.subCount.Load() 89 } 90 91 func (o *Observer[T]) Notify(values []T) { 92 o.mut.Lock() 93 defer o.mut.Unlock() 94 95 if len(o.subscribers) == 0 { 96 return 97 } 98 99 if len(values) == 0 { 100 return 101 } 102 103 for id, sub := range o.subscribers { 104 select { 105 case sub.ch <- values: 106 o.logDebug("", id, "channel updated successfully") 107 default: 108 o.logWarning("", id, "channel could not be updated, closing") 109 delete(o.subscribers, id) // safe to delete from map while iterating 110 close(sub.ch) 111 } 112 } 113 } 114 115 func (o *Observer[T]) Observe(ctx context.Context, retries int, filter func(T) bool) (<-chan []T, uint64) { 116 out := make(chan []T, o.outChSize) 117 in, ref := o.Subscribe(ctx, filter) 118 ip, _ := contextutil.RemoteIPAddrFromContext(ctx) 119 120 go func() { 121 o.subCount.Add(1) 122 defer o.subCount.Add(-1) 123 124 ctx, cancel := context.WithCancel(ctx) 125 defer cancel() 126 127 for { 128 select { 129 case <-ctx.Done(): 130 o.logDebug(ip, ref, "closed connection") 131 if err := o.Unsubscribe(ctx, ref); err != nil { 132 o.logError(ip, ref, "failure un-subscribing when context.Done()") 133 } 134 close(out) 135 return 136 137 case values, ok := <-in: 138 if !ok { 139 // 'in' channel may get closed because Notify() couldn't write to it 140 close(out) 141 return 142 } 143 144 filtered := make([]T, 0, len(values)) 145 for _, value := range values { 146 if filter(value) { 147 filtered = append(filtered, value) 148 } 149 } 150 if len(filtered) == 0 { 151 continue 152 } 153 retryCount := retries 154 success := false 155 for !success && retryCount >= 0 { 156 select { 157 case out <- filtered: 158 retryCount = retries 159 success = true 160 o.logDebug(ip, ref, "sent successfully") 161 default: 162 retryCount-- 163 if retryCount > 0 { 164 o.logDebug(ip, ref, "not sent, retrying") 165 } 166 time.Sleep(time.Duration(10) * time.Millisecond) 167 } 168 } 169 if !success && retryCount <= 0 { 170 o.logWarning(ip, ref, "hit the retry limit") 171 cancel() 172 } 173 } 174 } 175 }() 176 177 return out, ref 178 } 179 180 func (o *Observer[T]) logDebug(ip string, ref uint64, msg string) { 181 o.log.Debug( 182 fmt.Sprintf("%s subscriber: %s", o.name, msg), 183 logging.Uint64("id", ref), 184 logging.String("ip-address", ip), 185 ) 186 } 187 188 func (o *Observer[T]) logWarning(ip string, ref uint64, msg string) { 189 o.log.Warn( 190 fmt.Sprintf("%s subscriber: %s", o.name, msg), 191 logging.Uint64("id", ref), 192 logging.String("ip-address", ip), 193 ) 194 } 195 196 func (o *Observer[T]) logError(ip string, ref uint64, msg string) { 197 o.log.Error( 198 fmt.Sprintf("%s subscriber: %s", o.name, msg), 199 logging.Uint64("id", ref), 200 logging.String("ip-address", ip), 201 ) 202 }