code.vegaprotocol.io/vega@v0.79.0/datanode/utils/observer.go (about)

     1  // Copyright (C) 2023 Gobalsky Labs Limited
     2  //
     3  // This program is free software: you can redistribute it and/or modify
     4  // it under the terms of the GNU Affero General Public License as
     5  // published by the Free Software Foundation, either version 3 of the
     6  // License, or (at your option) any later version.
     7  //
     8  // This program is distributed in the hope that it will be useful,
     9  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    10  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    11  // GNU Affero General Public License for more details.
    12  //
    13  // You should have received a copy of the GNU Affero General Public License
    14  // along with this program.  If not, see <http://www.gnu.org/licenses/>.
    15  
    16  package utils
    17  
    18  import (
    19  	"context"
    20  	"fmt"
    21  	"sync"
    22  	"sync/atomic"
    23  	"time"
    24  
    25  	"code.vegaprotocol.io/vega/datanode/contextutil"
    26  	"code.vegaprotocol.io/vega/logging"
    27  )
    28  
    29  type subscriber[T any] struct {
    30  	ch chan []T
    31  }
    32  
    33  type Observer[T any] struct {
    34  	subCount    atomic.Int32
    35  	lastSubID   uint64
    36  	name        string
    37  	log         *logging.Logger
    38  	subscribers map[uint64]subscriber[T]
    39  	mut         sync.RWMutex
    40  	inChSize    int
    41  	outChSize   int
    42  }
    43  
    44  func NewObserver[T any](name string, log *logging.Logger, inChSize, outChSize int) Observer[T] {
    45  	return Observer[T]{
    46  		name:        name,
    47  		log:         log,
    48  		subscribers: map[uint64]subscriber[T]{},
    49  		inChSize:    inChSize,
    50  		outChSize:   outChSize,
    51  	}
    52  }
    53  
    54  func (o *Observer[T]) Subscribe(ctx context.Context, filter func(T) bool) (chan []T, uint64) {
    55  	o.mut.Lock()
    56  	defer o.mut.Unlock()
    57  
    58  	ch := make(chan []T, o.inChSize)
    59  	o.lastSubID++
    60  	o.subscribers[o.lastSubID] = subscriber[T]{ch}
    61  
    62  	ip, _ := contextutil.RemoteIPAddrFromContext(ctx)
    63  	o.logDebug(ip, o.lastSubID, "new subscription")
    64  	return ch, o.lastSubID
    65  }
    66  
    67  func (o *Observer[T]) Unsubscribe(ctx context.Context, ref uint64) error {
    68  	o.mut.Lock()
    69  	defer o.mut.Unlock()
    70  
    71  	ip, _ := contextutil.RemoteIPAddrFromContext(ctx)
    72  
    73  	if len(o.subscribers) == 0 {
    74  		o.logDebug(ip, ref, "un-subscribe called but, no subscribers connected")
    75  		return nil
    76  	}
    77  
    78  	if sub, exists := o.subscribers[ref]; exists {
    79  		close(sub.ch)
    80  		delete(o.subscribers, ref)
    81  		return nil
    82  	}
    83  
    84  	return fmt.Errorf("no subscriber with id: %d", ref)
    85  }
    86  
    87  func (o *Observer[T]) GetSubscribersCount() int32 {
    88  	return o.subCount.Load()
    89  }
    90  
    91  func (o *Observer[T]) Notify(values []T) {
    92  	o.mut.Lock()
    93  	defer o.mut.Unlock()
    94  
    95  	if len(o.subscribers) == 0 {
    96  		return
    97  	}
    98  
    99  	if len(values) == 0 {
   100  		return
   101  	}
   102  
   103  	for id, sub := range o.subscribers {
   104  		select {
   105  		case sub.ch <- values:
   106  			o.logDebug("", id, "channel updated successfully")
   107  		default:
   108  			o.logWarning("", id, "channel could not be updated, closing")
   109  			delete(o.subscribers, id) // safe to delete from map while iterating
   110  			close(sub.ch)
   111  		}
   112  	}
   113  }
   114  
   115  func (o *Observer[T]) Observe(ctx context.Context, retries int, filter func(T) bool) (<-chan []T, uint64) {
   116  	out := make(chan []T, o.outChSize)
   117  	in, ref := o.Subscribe(ctx, filter)
   118  	ip, _ := contextutil.RemoteIPAddrFromContext(ctx)
   119  
   120  	go func() {
   121  		o.subCount.Add(1)
   122  		defer o.subCount.Add(-1)
   123  
   124  		ctx, cancel := context.WithCancel(ctx)
   125  		defer cancel()
   126  
   127  		for {
   128  			select {
   129  			case <-ctx.Done():
   130  				o.logDebug(ip, ref, "closed connection")
   131  				if err := o.Unsubscribe(ctx, ref); err != nil {
   132  					o.logError(ip, ref, "failure un-subscribing when context.Done()")
   133  				}
   134  				close(out)
   135  				return
   136  
   137  			case values, ok := <-in:
   138  				if !ok {
   139  					// 'in' channel may get closed because Notify() couldn't write to it
   140  					close(out)
   141  					return
   142  				}
   143  
   144  				filtered := make([]T, 0, len(values))
   145  				for _, value := range values {
   146  					if filter(value) {
   147  						filtered = append(filtered, value)
   148  					}
   149  				}
   150  				if len(filtered) == 0 {
   151  					continue
   152  				}
   153  				retryCount := retries
   154  				success := false
   155  				for !success && retryCount >= 0 {
   156  					select {
   157  					case out <- filtered:
   158  						retryCount = retries
   159  						success = true
   160  						o.logDebug(ip, ref, "sent successfully")
   161  					default:
   162  						retryCount--
   163  						if retryCount > 0 {
   164  							o.logDebug(ip, ref, "not sent, retrying")
   165  						}
   166  						time.Sleep(time.Duration(10) * time.Millisecond)
   167  					}
   168  				}
   169  				if !success && retryCount <= 0 {
   170  					o.logWarning(ip, ref, "hit the retry limit")
   171  					cancel()
   172  				}
   173  			}
   174  		}
   175  	}()
   176  
   177  	return out, ref
   178  }
   179  
   180  func (o *Observer[T]) logDebug(ip string, ref uint64, msg string) {
   181  	o.log.Debug(
   182  		fmt.Sprintf("%s subscriber: %s", o.name, msg),
   183  		logging.Uint64("id", ref),
   184  		logging.String("ip-address", ip),
   185  	)
   186  }
   187  
   188  func (o *Observer[T]) logWarning(ip string, ref uint64, msg string) {
   189  	o.log.Warn(
   190  		fmt.Sprintf("%s subscriber: %s", o.name, msg),
   191  		logging.Uint64("id", ref),
   192  		logging.String("ip-address", ip),
   193  	)
   194  }
   195  
   196  func (o *Observer[T]) logError(ip string, ref uint64, msg string) {
   197  	o.log.Error(
   198  		fmt.Sprintf("%s subscriber: %s", o.name, msg),
   199  		logging.Uint64("id", ref),
   200  		logging.String("ip-address", ip),
   201  	)
   202  }