code.pfad.fr/gohmekit@v0.2.1/hapip/notification/manager.go (about)

     1  package notification
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"io"
     7  	"net"
     8  	"net/http"
     9  	"sync"
    10  	"time"
    11  
    12  	"github.com/go-kit/log"
    13  )
    14  
    15  // Manager manages notifications.
    16  // m.ConnContext & m.ConnState must be registered on the http.Server respective fields.
    17  type Manager[Key comparable] struct {
    18  	// If a connection is active, events will not be sent (to prevent corruption).
    19  	// EventBuffer indicates the number of events to buffer (per connection),
    20  	// which will be delivered when the connection is back to idle.
    21  	EventBuffer int
    22  
    23  	// CoalesceDuration will prevent too many events from being delivered after another.
    24  	// Standard recommands 1s
    25  	CoalesceDuration time.Duration
    26  	Logger           log.Logger
    27  
    28  	muConn sync.Mutex
    29  	conn   map[net.Conn]*notifier
    30  
    31  	muSubscribers sync.Mutex
    32  	subscribers   map[Key]map[net.Conn]bool
    33  }
    34  
    35  func (m *Manager[Key]) HookConnEvents(server *http.Server) {
    36  	if previous := server.ConnState; previous != nil {
    37  		server.ConnState = func(c net.Conn, cs http.ConnState) {
    38  			previous(c, cs)
    39  			m.ConnState(c, cs)
    40  		}
    41  	} else {
    42  		server.ConnState = m.ConnState
    43  	}
    44  
    45  	if previous := server.ConnContext; previous != nil {
    46  		server.ConnContext = func(ctx context.Context, c net.Conn) context.Context {
    47  			ctx = previous(ctx, c)
    48  			return m.ConnContext(ctx, c)
    49  		}
    50  	} else {
    51  		server.ConnContext = m.ConnContext
    52  	}
    53  }
    54  
    55  type contextKey string
    56  
    57  var contextKeyConn = contextKey("conn")
    58  
    59  func (m *Manager[Key]) ConnContext(ctx context.Context, c net.Conn) context.Context {
    60  	nf := &notifier{
    61  		conn:             c,
    62  		stateCh:          make(chan http.ConnState), // not buffered, to be blocking on ConnState
    63  		eventCh:          make(chan io.WriterTo, m.EventBuffer),
    64  		coalesceDuration: m.CoalesceDuration,
    65  		logger:           m.Logger,
    66  	}
    67  
    68  	m.muConn.Lock()
    69  	if m.conn == nil {
    70  		m.conn = map[net.Conn]*notifier{c: nf}
    71  	} else {
    72  		m.conn[c] = nf
    73  	}
    74  	m.muConn.Unlock()
    75  
    76  	go func() {
    77  		defer m.removeFromSubscribers(c)
    78  		defer m.removeFromConns(c)
    79  		nf.run(c)
    80  	}()
    81  
    82  	return context.WithValue(ctx, contextKeyConn, c)
    83  }
    84  
    85  func getConnValue(ctx context.Context) net.Conn {
    86  	conn, _ := ctx.Value(contextKeyConn).(net.Conn)
    87  	return conn
    88  }
    89  
    90  func (m *Manager[Key]) ConnState(c net.Conn, s http.ConnState) {
    91  	if s == http.StateNew {
    92  		return
    93  	}
    94  	m.muConn.Lock()
    95  	nf := m.conn[c]
    96  	m.muConn.Unlock()
    97  
    98  	nf.stateCh <- s
    99  }
   100  
   101  var errNoConn = errors.New("net.Conn not found in ctx")
   102  
   103  func (m *Manager[Key]) Subscribe(ctx context.Context, id Key, v bool) error {
   104  	c := getConnValue(ctx)
   105  	if c == nil {
   106  		return errNoConn
   107  	}
   108  
   109  	m.muSubscribers.Lock()
   110  	defer m.muSubscribers.Unlock()
   111  
   112  	switch {
   113  	case m.subscribers == nil:
   114  		m.subscribers = map[Key]map[net.Conn]bool{
   115  			id: {c: v},
   116  		}
   117  	case m.subscribers[id] == nil:
   118  		m.subscribers[id] = map[net.Conn]bool{c: v}
   119  	default:
   120  		m.subscribers[id][c] = v
   121  	}
   122  
   123  	return nil
   124  }
   125  func (m *Manager[Key]) IsSubscribed(ctx context.Context, id Key) (bool, error) {
   126  	c := getConnValue(ctx)
   127  	if c == nil {
   128  		return false, errNoConn
   129  	}
   130  
   131  	m.muSubscribers.Lock()
   132  	defer m.muSubscribers.Unlock()
   133  
   134  	if m.subscribers == nil || m.subscribers[id] == nil {
   135  		return false, nil
   136  	}
   137  
   138  	return m.subscribers[id][c], nil
   139  }
   140  
   141  // Coalescer allows event to be merged together to prevent
   142  // sending too many events in a short time.
   143  type Coalescer interface {
   144  	io.WriterTo
   145  
   146  	// Coalescer must return itself on error (not nil!)
   147  	Coalesce(io.WriterTo) (Coalescer, error)
   148  }
   149  
   150  // Publish an event. If the event is a Coalescer it might be throtteled.
   151  func (m *Manager[Key]) Publish(ctx context.Context, id Key, event io.WriterTo) {
   152  	publisher := getConnValue(ctx)
   153  	conns := m.subscribedConns(id, publisher)
   154  	if len(conns) == 0 {
   155  		return
   156  	}
   157  
   158  	notifiers := m.notifiers(conns)
   159  
   160  	for _, nf := range notifiers {
   161  		select {
   162  		case nf.eventCh <- event:
   163  		default:
   164  		}
   165  	}
   166  }
   167  
   168  func (m *Manager[Key]) notifiers(conns []net.Conn) []*notifier {
   169  	notifiers := make([]*notifier, 0, len(conns))
   170  
   171  	m.muConn.Lock()
   172  	defer m.muConn.Unlock()
   173  
   174  	for _, c := range conns {
   175  		nf := m.conn[c]
   176  		if nf == nil {
   177  			continue
   178  		}
   179  		notifiers = append(notifiers, nf)
   180  	}
   181  	return notifiers
   182  }
   183  
   184  func (m *Manager[Key]) subscribedConns(id Key, publisher net.Conn) []net.Conn {
   185  	m.muSubscribers.Lock()
   186  	defer m.muSubscribers.Unlock()
   187  
   188  	if m.subscribers == nil || m.subscribers[id] == nil {
   189  		return nil
   190  	}
   191  
   192  	var conns []net.Conn //nolint:prealloc
   193  	for c, v := range m.subscribers[id] {
   194  		if !v || c == publisher {
   195  			continue
   196  		}
   197  		conns = append(conns, c)
   198  	}
   199  	return conns
   200  }
   201  
   202  type notifier struct {
   203  	conn             net.Conn
   204  	stateCh          chan http.ConnState
   205  	eventCh          chan io.WriterTo
   206  	coalesceDuration time.Duration
   207  	logger           log.Logger
   208  }
   209  
// run is the per-connection delivery loop. It returns once the connection is
// reported closed or hijacked on stateCh.
//
// Events are taken from eventCh only after the connection has been reported
// idle, so that they are not written in the middle of an HTTP exchange on c.
// When an event implementing Coalescer arrives less than coalesceDuration
// after the previous write, run waits for the rest of that window, merging
// further events into it, and then writes the result to c.
func (n *notifier) run(c net.Conn) {
	state := http.StateNew
	var lastEvent time.Time // zero value: the first event is never delayed
	for {
		// wait for the connection to be idle...
		for state != http.StateIdle {
			if state == http.StateClosed || state == http.StateHijacked {
				return
			}
			state = <-n.stateCh
		}

		select {
		case state = <-n.stateCh:
			// State changed: loop back and wait for idle again.
		case event := <-n.eventCh:

			nextEventDelay := n.coalesceDuration - time.Since(lastEvent)
			if nextEventDelay > 0 {
				// Still inside the throttling window of the previous write.
				delayExpired := time.After(nextEventDelay)
				acc, canWait := event.(Coalescer)
				// Events that are not Coalescers skip this loop and are written at once.
				for canWait {
					canWait = false
					n.logger.Log("action", "coalesce", "duration", nextEventDelay)
					select {
					case state = <-n.stateCh: // send event right away and hope for the best...
						// NOTE(review): if the new state is not idle, the write below
						// may overlap the server's own use of c — confirm this
						// trade-off is acceptable.
					case <-delayExpired:
					case ev := <-n.eventCh:
						var err error
						acc, err = acc.Coalesce(ev)
						if err != nil {
							// Merge failed: acc is still the previous accumulator (per
							// the Coalescer contract), so write ev on its own and
							// restart the throttling window.
							n.logger.Log("action", "coalesce", "err", err)
							_, err := ev.WriteTo(c)
							if err != nil {
								n.logger.Log("action", "coalesce-fallback", "err", err)
							}
							delayExpired = time.After(n.coalesceDuration)
						}
						// Keep waiting only while incoming events are themselves Coalescers.
						_, canWait = ev.(Coalescer)
					}
					event = acc
				}
			}

			_, err := event.WriteTo(c)
			lastEvent = time.Now()
			n.logger.Log("action", "write", "err", err)
		}
	}
}
   259  
   260  func (m *Manager[Key]) removeFromSubscribers(c net.Conn) {
   261  	m.muSubscribers.Lock()
   262  	defer m.muSubscribers.Unlock()
   263  
   264  	if m.subscribers == nil {
   265  		return
   266  	}
   267  
   268  	for _, conns := range m.subscribers {
   269  		delete(conns, c)
   270  	}
   271  }
   272  
   273  func (m *Manager[Key]) removeFromConns(c net.Conn) {
   274  	m.muConn.Lock()
   275  	m.Logger.Log("action", "connection removal")
   276  	delete(m.conn, c)
   277  	m.muConn.Unlock()
   278  }