github.com/ari-anchor/sei-tendermint@v0.0.0-20230519144642-dc826b7b56bb/privval/signer_listener_endpoint.go (about)

     1  package privval
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"net"
     7  	"sync"
     8  	"time"
     9  
    10  	"github.com/ari-anchor/sei-tendermint/libs/log"
    11  	"github.com/ari-anchor/sei-tendermint/libs/service"
    12  	privvalproto "github.com/ari-anchor/sei-tendermint/proto/tendermint/privval"
    13  )
    14  
// SignerListenerEndpointOption is a functional option: a function that
// customizes a SignerListenerEndpoint. NewSignerListenerEndpoint applies each
// supplied option, in order, after the defaults have been set.
type SignerListenerEndpointOption func(*SignerListenerEndpoint)
    17  
    18  // SignerListenerEndpointTimeoutReadWrite sets the read and write timeout for
    19  // connections from external signing processes.
    20  //
    21  // Default: 5s
    22  func SignerListenerEndpointTimeoutReadWrite(timeout time.Duration) SignerListenerEndpointOption {
    23  	return func(sl *SignerListenerEndpoint) { sl.signerEndpoint.timeoutReadWrite = timeout }
    24  }
    25  
// SignerListenerEndpoint listens for an external process to dial in and keeps
// the connection alive by dropping and reconnecting.
//
// The endpoint sends a ping request on every tick of pingTimer. The tick
// interval is 2/3 of the read/write timeout, i.e. roughly 3s with the default
// timeout, so that a ping is always issued before the timeout can expire.
type SignerListenerEndpoint struct {
	// Embedded base endpoint; it carries the logger and timeoutReadWrite
	// fields that the constructor and options set.
	signerEndpoint

	// listener is the source of incoming connections: acceptNewConnection
	// calls Accept on it and OnStop closes it.
	listener              net.Listener
	// connectRequestCh signals serviceLoop to accept a new connection.
	connectRequestCh      chan struct{}
	// connectionAvailableCh is where serviceLoop hands over each accepted
	// connection to whoever is waiting for one.
	connectionAvailableCh chan net.Conn

	// timeoutAccept is how long SendRequest waits for a connection.
	timeoutAccept time.Duration
	// pingTimer drives pingLoop; SendRequest resets it after every successful
	// round trip so that no redundant ping is sent.
	pingTimer     *time.Ticker
	// pingInterval is the tick period of pingTimer, computed in OnStart.
	pingInterval  time.Duration

	instanceMtx sync.Mutex // Serializes OnStop, WaitForConnection and SendRequest.
}
    44  
    45  // NewSignerListenerEndpoint returns an instance of SignerListenerEndpoint.
    46  func NewSignerListenerEndpoint(
    47  	logger log.Logger,
    48  	listener net.Listener,
    49  	options ...SignerListenerEndpointOption,
    50  ) *SignerListenerEndpoint {
    51  	sl := &SignerListenerEndpoint{
    52  		listener:      listener,
    53  		timeoutAccept: defaultTimeoutAcceptSeconds * time.Second,
    54  	}
    55  
    56  	sl.signerEndpoint.logger = logger
    57  	sl.BaseService = *service.NewBaseService(logger, "SignerListenerEndpoint", sl)
    58  	sl.signerEndpoint.timeoutReadWrite = defaultTimeoutReadWriteSeconds * time.Second
    59  
    60  	for _, optionFunc := range options {
    61  		optionFunc(sl)
    62  	}
    63  
    64  	return sl
    65  }
    66  
    67  // OnStart implements service.Service.
    68  func (sl *SignerListenerEndpoint) OnStart(ctx context.Context) error {
    69  	sl.connectRequestCh = make(chan struct{})
    70  	sl.connectionAvailableCh = make(chan net.Conn)
    71  
    72  	// NOTE: ping timeout must be less than read/write timeout
    73  	sl.pingInterval = time.Duration(sl.signerEndpoint.timeoutReadWrite.Milliseconds()*2/3) * time.Millisecond
    74  	sl.pingTimer = time.NewTicker(sl.pingInterval)
    75  
    76  	go sl.serviceLoop(ctx)
    77  	go sl.pingLoop(ctx)
    78  
    79  	sl.connectRequestCh <- struct{}{}
    80  
    81  	return nil
    82  }
    83  
    84  // OnStop implements service.Service
    85  func (sl *SignerListenerEndpoint) OnStop() {
    86  	sl.instanceMtx.Lock()
    87  	defer sl.instanceMtx.Unlock()
    88  	_ = sl.Close()
    89  
    90  	// Stop listening
    91  	if sl.listener != nil {
    92  		if err := sl.listener.Close(); err != nil {
    93  			sl.logger.Error("Closing Listener", "err", err)
    94  			sl.listener = nil
    95  		}
    96  	}
    97  
    98  	sl.pingTimer.Stop()
    99  }
   100  
   101  // WaitForConnection waits maxWait for a connection or returns a timeout error
   102  func (sl *SignerListenerEndpoint) WaitForConnection(ctx context.Context, maxWait time.Duration) error {
   103  	sl.instanceMtx.Lock()
   104  	defer sl.instanceMtx.Unlock()
   105  	return sl.ensureConnection(ctx, maxWait)
   106  }
   107  
   108  // SendRequest ensures there is a connection, sends a request and waits for a response
   109  func (sl *SignerListenerEndpoint) SendRequest(ctx context.Context, request privvalproto.Message) (*privvalproto.Message, error) {
   110  	sl.instanceMtx.Lock()
   111  	defer sl.instanceMtx.Unlock()
   112  
   113  	err := sl.ensureConnection(ctx, sl.timeoutAccept)
   114  	if err != nil {
   115  		return nil, err
   116  	}
   117  
   118  	err = sl.WriteMessage(request)
   119  	if err != nil {
   120  		return nil, err
   121  	}
   122  
   123  	res, err := sl.ReadMessage()
   124  	if err != nil {
   125  		return nil, err
   126  	}
   127  
   128  	// Reset pingTimer to avoid sending unnecessary pings.
   129  	sl.pingTimer.Reset(sl.pingInterval)
   130  
   131  	return &res, nil
   132  }
   133  
   134  func (sl *SignerListenerEndpoint) ensureConnection(ctx context.Context, maxWait time.Duration) error {
   135  	if sl.IsConnected() {
   136  		return nil
   137  	}
   138  
   139  	// Is there a connection ready? then use it
   140  	if sl.GetAvailableConnection(sl.connectionAvailableCh) {
   141  		return nil
   142  	}
   143  
   144  	// block until connected or timeout
   145  	sl.logger.Info("SignerListener: Blocking for connection")
   146  	sl.triggerConnect()
   147  	return sl.WaitConnection(ctx, sl.connectionAvailableCh, maxWait)
   148  }
   149  
   150  func (sl *SignerListenerEndpoint) acceptNewConnection() (net.Conn, error) {
   151  	if !sl.IsRunning() || sl.listener == nil {
   152  		return nil, fmt.Errorf("endpoint is closing")
   153  	}
   154  
   155  	// wait for a new conn
   156  	sl.logger.Info("SignerListener: Listening for new connection")
   157  	conn, err := sl.listener.Accept()
   158  	if err != nil {
   159  		return nil, err
   160  	}
   161  
   162  	return conn, nil
   163  }
   164  
   165  func (sl *SignerListenerEndpoint) triggerConnect() {
   166  	select {
   167  	case sl.connectRequestCh <- struct{}{}:
   168  	default:
   169  	}
   170  }
   171  
   172  func (sl *SignerListenerEndpoint) triggerReconnect() {
   173  	sl.DropConnection()
   174  	sl.triggerConnect()
   175  }
   176  
   177  func (sl *SignerListenerEndpoint) serviceLoop(ctx context.Context) {
   178  	for {
   179  		select {
   180  		case <-sl.connectRequestCh:
   181  			{
   182  				conn, err := sl.acceptNewConnection()
   183  				if err == nil {
   184  					sl.logger.Info("SignerListener: Connected")
   185  
   186  					// We have a good connection, wait for someone that needs one otherwise cancellation
   187  					select {
   188  					case sl.connectionAvailableCh <- conn:
   189  					case <-ctx.Done():
   190  						return
   191  					}
   192  				}
   193  
   194  				select {
   195  				case sl.connectRequestCh <- struct{}{}:
   196  				default:
   197  				}
   198  			}
   199  		case <-ctx.Done():
   200  			return
   201  		}
   202  	}
   203  }
   204  
   205  func (sl *SignerListenerEndpoint) pingLoop(ctx context.Context) {
   206  	for {
   207  		select {
   208  		case <-sl.pingTimer.C:
   209  			{
   210  				_, err := sl.SendRequest(ctx, mustWrapMsg(&privvalproto.PingRequest{}))
   211  				if err != nil {
   212  					sl.logger.Error("SignerListener: Ping timeout")
   213  					sl.triggerReconnect()
   214  				}
   215  			}
   216  		case <-ctx.Done():
   217  			return
   218  		}
   219  	}
   220  }