code.pfad.fr/gohmekit@v0.2.1/hapip/notification/manager.go (about) 1 package notification 2 3 import ( 4 "context" 5 "errors" 6 "io" 7 "net" 8 "net/http" 9 "sync" 10 "time" 11 12 "github.com/go-kit/log" 13 ) 14 15 // Manager manages notifications. 16 // m.ConnContext & m.ConnState must be registered on the http.Server respective fields. 17 type Manager[Key comparable] struct { 18 // If a connection is active, events will not be sent (to prevent corruption). 19 // EventBuffer indicates the number of events to buffer (per connection), 20 // which will be delivered when the connection is back to idle. 21 EventBuffer int 22 23 // CoalesceDuration will prevent too many events from being delivered after another. 24 // Standard recommands 1s 25 CoalesceDuration time.Duration 26 Logger log.Logger 27 28 muConn sync.Mutex 29 conn map[net.Conn]*notifier 30 31 muSubscribers sync.Mutex 32 subscribers map[Key]map[net.Conn]bool 33 } 34 35 func (m *Manager[Key]) HookConnEvents(server *http.Server) { 36 if previous := server.ConnState; previous != nil { 37 server.ConnState = func(c net.Conn, cs http.ConnState) { 38 previous(c, cs) 39 m.ConnState(c, cs) 40 } 41 } else { 42 server.ConnState = m.ConnState 43 } 44 45 if previous := server.ConnContext; previous != nil { 46 server.ConnContext = func(ctx context.Context, c net.Conn) context.Context { 47 ctx = previous(ctx, c) 48 return m.ConnContext(ctx, c) 49 } 50 } else { 51 server.ConnContext = m.ConnContext 52 } 53 } 54 55 type contextKey string 56 57 var contextKeyConn = contextKey("conn") 58 59 func (m *Manager[Key]) ConnContext(ctx context.Context, c net.Conn) context.Context { 60 nf := ¬ifier{ 61 conn: c, 62 stateCh: make(chan http.ConnState), // not buffered, to be blocking on ConnState 63 eventCh: make(chan io.WriterTo, m.EventBuffer), 64 coalesceDuration: m.CoalesceDuration, 65 logger: m.Logger, 66 } 67 68 m.muConn.Lock() 69 if m.conn == nil { 70 m.conn = map[net.Conn]*notifier{c: nf} 71 } else { 72 m.conn[c] = nf 73 } 74 m.muConn.Unlock() 75 76 go func() { 77 defer m.removeFromSubscribers(c) 78 defer m.removeFromConns(c) 79 nf.run(c) 80 }() 81 82 return context.WithValue(ctx, contextKeyConn, c) 83 } 84 85 func getConnValue(ctx context.Context) net.Conn { 86 conn, _ := ctx.Value(contextKeyConn).(net.Conn) 87 return conn 88 } 89 90 func (m *Manager[Key]) ConnState(c net.Conn, s http.ConnState) { 91 if s == http.StateNew { 92 return 93 } 94 m.muConn.Lock() 95 nf := m.conn[c] 96 m.muConn.Unlock() 97 98 nf.stateCh <- s 99 } 100 101 var errNoConn = errors.New("net.Conn not found in ctx") 102 103 func (m *Manager[Key]) Subscribe(ctx context.Context, id Key, v bool) error { 104 c := getConnValue(ctx) 105 if c == nil { 106 return errNoConn 107 } 108 109 m.muSubscribers.Lock() 110 defer m.muSubscribers.Unlock() 111 112 switch { 113 case m.subscribers == nil: 114 m.subscribers = map[Key]map[net.Conn]bool{ 115 id: {c: v}, 116 } 117 case m.subscribers[id] == nil: 118 m.subscribers[id] = map[net.Conn]bool{c: v} 119 default: 120 m.subscribers[id][c] = v 121 } 122 123 return nil 124 } 125 func (m *Manager[Key]) IsSubscribed(ctx context.Context, id Key) (bool, error) { 126 c := getConnValue(ctx) 127 if c == nil { 128 return false, errNoConn 129 } 130 131 m.muSubscribers.Lock() 132 defer m.muSubscribers.Unlock() 133 134 if m.subscribers == nil || m.subscribers[id] == nil { 135 return false, nil 136 } 137 138 return m.subscribers[id][c], nil 139 } 140 141 // Coalescer allows event to be merged together to prevent 142 // sending too many events in a short time. 143 type Coalescer interface { 144 io.WriterTo 145 146 // Coalescer must return itself on error (not nil!) 147 Coalesce(io.WriterTo) (Coalescer, error) 148 } 149 150 // Publish an event. If the event is a Coalescer it might be throtteled. 151 func (m *Manager[Key]) Publish(ctx context.Context, id Key, event io.WriterTo) { 152 publisher := getConnValue(ctx) 153 conns := m.subscribedConns(id, publisher) 154 if len(conns) == 0 { 155 return 156 } 157 158 notifiers := m.notifiers(conns) 159 160 for _, nf := range notifiers { 161 select { 162 case nf.eventCh <- event: 163 default: 164 } 165 } 166 } 167 168 func (m *Manager[Key]) notifiers(conns []net.Conn) []*notifier { 169 notifiers := make([]*notifier, 0, len(conns)) 170 171 m.muConn.Lock() 172 defer m.muConn.Unlock() 173 174 for _, c := range conns { 175 nf := m.conn[c] 176 if nf == nil { 177 continue 178 } 179 notifiers = append(notifiers, nf) 180 } 181 return notifiers 182 } 183 184 func (m *Manager[Key]) subscribedConns(id Key, publisher net.Conn) []net.Conn { 185 m.muSubscribers.Lock() 186 defer m.muSubscribers.Unlock() 187 188 if m.subscribers == nil || m.subscribers[id] == nil { 189 return nil 190 } 191 192 var conns []net.Conn //nolint:prealloc 193 for c, v := range m.subscribers[id] { 194 if !v || c == publisher { 195 continue 196 } 197 conns = append(conns, c) 198 } 199 return conns 200 } 201 202 type notifier struct { 203 conn net.Conn 204 stateCh chan http.ConnState 205 eventCh chan io.WriterTo 206 coalesceDuration time.Duration 207 logger log.Logger 208 } 209 210 func (n *notifier) run(c net.Conn) { 211 state := http.StateNew 212 var lastEvent time.Time 213 for { 214 // wait for the connection to be idle... 215 for state != http.StateIdle { 216 if state == http.StateClosed || state == http.StateHijacked { 217 return 218 } 219 state = <-n.stateCh 220 } 221 222 select { 223 case state = <-n.stateCh: 224 case event := <-n.eventCh: 225 226 nextEventDelay := n.coalesceDuration - time.Since(lastEvent) 227 if nextEventDelay > 0 { 228 delayExpired := time.After(nextEventDelay) 229 acc, canWait := event.(Coalescer) 230 for canWait { 231 canWait = false 232 n.logger.Log("action", "coalesce", "duration", nextEventDelay) 233 select { 234 case state = <-n.stateCh: // send event right away and hope for the best... 235 case <-delayExpired: 236 case ev := <-n.eventCh: 237 var err error 238 acc, err = acc.Coalesce(ev) 239 if err != nil { 240 n.logger.Log("action", "coalesce", "err", err) 241 _, err := ev.WriteTo(c) 242 if err != nil { 243 n.logger.Log("action", "coalesce-fallback", "err", err) 244 } 245 delayExpired = time.After(n.coalesceDuration) 246 } 247 _, canWait = ev.(Coalescer) 248 } 249 event = acc 250 } 251 } 252 253 _, err := event.WriteTo(c) 254 lastEvent = time.Now() 255 n.logger.Log("action", "write", "err", err) 256 } 257 } 258 } 259 260 func (m *Manager[Key]) removeFromSubscribers(c net.Conn) { 261 m.muSubscribers.Lock() 262 defer m.muSubscribers.Unlock() 263 264 if m.subscribers == nil { 265 return 266 } 267 268 for _, conns := range m.subscribers { 269 delete(conns, c) 270 } 271 } 272 273 func (m *Manager[Key]) removeFromConns(c net.Conn) { 274 m.muConn.Lock() 275 m.Logger.Log("action", "connection removal") 276 delete(m.conn, c) 277 m.muConn.Unlock() 278 }