github.com/pf-qiu/concourse/v6@v6.7.3-0.20201207032516-1f455d73275f/atc/db/notifications_bus.go (about)

     1  package db
     2  
     3  import (
     4  	"database/sql"
     5  	"sync"
     6  
     7  	"github.com/lib/pq"
     8  )
     9  
    10  //go:generate counterfeiter . Listener
    11  
    12  type Listener interface {
    13  	Close() error
    14  	Listen(channel string) error
    15  	Unlisten(channel string) error
    16  	NotificationChannel() <-chan *pq.Notification
    17  }
    18  
    19  //go:generate counterfeiter . Executor
    20  
    21  type Executor interface {
    22  	Exec(statement string, args ...interface{}) (sql.Result, error)
    23  }
    24  
    25  type NotificationsBus interface {
    26  	Notify(channel string) error
    27  	Listen(channel string) (chan bool, error)
    28  	Unlisten(channel string, notify chan bool) error
    29  	Close() error
    30  }
    31  
    32  type notificationsBus struct {
    33  	sync.Mutex
    34  
    35  	listener Listener
    36  	executor Executor
    37  
    38  	notifications *notificationsMap
    39  }
    40  
    41  func NewNotificationsBus(listener Listener, executor Executor) *notificationsBus {
    42  	bus := &notificationsBus{
    43  		listener:      listener,
    44  		executor:      executor,
    45  		notifications: newNotificationsMap(),
    46  	}
    47  
    48  	go bus.wait()
    49  
    50  	return bus
    51  }
    52  
    53  func (bus *notificationsBus) Close() error {
    54  	return bus.listener.Close()
    55  }
    56  
    57  func (bus *notificationsBus) Notify(channel string) error {
    58  	_, err := bus.executor.Exec("NOTIFY " + channel)
    59  	return err
    60  }
    61  
    62  func (bus *notificationsBus) Listen(channel string) (chan bool, error) {
    63  	bus.Lock()
    64  	defer bus.Unlock()
    65  
    66  	if bus.notifications.empty(channel) {
    67  		err := bus.listener.Listen(channel)
    68  		if err != nil {
    69  			return nil, err
    70  		}
    71  	}
    72  
    73  	notify := make(chan bool, 1)
    74  	bus.notifications.register(channel, notify)
    75  	return notify, nil
    76  }
    77  
    78  func (bus *notificationsBus) Unlisten(channel string, notify chan bool) error {
    79  	bus.Lock()
    80  	defer bus.Unlock()
    81  
    82  	bus.notifications.unregister(channel, notify)
    83  
    84  	if bus.notifications.empty(channel) {
    85  		return bus.listener.Unlisten(channel)
    86  	}
    87  
    88  	return nil
    89  }
    90  
    91  func (bus *notificationsBus) wait() {
    92  	for {
    93  		notification, ok := <-bus.listener.NotificationChannel()
    94  		if !ok {
    95  			break
    96  		}
    97  
    98  		if notification != nil {
    99  			bus.handleNotification(notification)
   100  		} else {
   101  			bus.handleReconnect()
   102  		}
   103  	}
   104  }
   105  
   106  func (bus *notificationsBus) handleNotification(notification *pq.Notification) {
   107  	// alert any relevant listeners of notification being received
   108  	// (nonblocking)
   109  	bus.notifications.eachForChannel(notification.Channel, func(sink chan bool) {
   110  		select {
   111  		case sink <- true:
   112  			// notified of message being received (or queued up)
   113  		default:
   114  			// already had notification queued up; no need to handle it twice
   115  		}
   116  	})
   117  }
   118  
   119  func (bus *notificationsBus) handleReconnect() {
   120  	// alert all listeners of connection break so they can check for things
   121  	// they may have missed
   122  	bus.notifications.each(func(sink chan bool) {
   123  		select {
   124  		case sink <- false:
   125  			// notify that connection was lost, so listener can check for
   126  			// things that may have changed while connection was lost
   127  		default:
   128  			// already had notification queued up; no need to check for
   129  			// anything missed since something will be notified anyway
   130  		}
   131  	})
   132  }
   133  
   134  func newNotificationsMap() *notificationsMap {
   135  	return &notificationsMap{
   136  		notifications: map[string]map[chan bool]struct{}{},
   137  	}
   138  }
   139  
   140  type notificationsMap struct {
   141  	sync.RWMutex
   142  
   143  	notifications map[string]map[chan bool]struct{}
   144  }
   145  
   146  func (m *notificationsMap) empty(channel string) bool {
   147  	m.RLock()
   148  	defer m.RUnlock()
   149  
   150  	return len(m.notifications[channel]) == 0
   151  }
   152  
   153  func (m *notificationsMap) register(channel string, notify chan bool) {
   154  	m.Lock()
   155  	defer m.Unlock()
   156  
   157  	sinks, found := m.notifications[channel]
   158  	if !found {
   159  		sinks = map[chan bool]struct{}{}
   160  		m.notifications[channel] = sinks
   161  	}
   162  
   163  	sinks[notify] = struct{}{}
   164  }
   165  
   166  func (m *notificationsMap) unregister(channel string, notify chan bool) {
   167  	m.Lock()
   168  	defer m.Unlock()
   169  
   170  	delete(m.notifications[channel], notify)
   171  }
   172  
   173  func (m *notificationsMap) each(f func(chan bool)) {
   174  	m.RLock()
   175  	defer m.RUnlock()
   176  
   177  	for _, sinks := range m.notifications {
   178  		for sink := range sinks {
   179  			f(sink)
   180  		}
   181  	}
   182  }
   183  
   184  func (m *notificationsMap) eachForChannel(channel string, f func(chan bool)) {
   185  	m.RLock()
   186  	defer m.RUnlock()
   187  
   188  	for sink := range m.notifications[channel] {
   189  		f(sink)
   190  	}
   191  }