github.com/Finschia/ostracon@v1.1.5/privval/signer_listener_endpoint.go (about)

     1  package privval
     2  
     3  import (
     4  	"fmt"
     5  	"net"
     6  	"time"
     7  
     8  	privvalproto "github.com/tendermint/tendermint/proto/tendermint/privval"
     9  
    10  	"github.com/Finschia/ostracon/libs/log"
    11  	"github.com/Finschia/ostracon/privval/internal"
    12  	"github.com/Finschia/ostracon/libs/service"
    13  	tmsync "github.com/Finschia/ostracon/libs/sync"
    14  	ocprivvalproto "github.com/Finschia/ostracon/proto/ostracon/privval"
    15  )
    16  
// SignerListenerEndpointOption sets an optional parameter on the SignerListenerEndpoint.
//
// It is a function that mutates the endpoint in place. NewSignerListenerEndpoint
// invokes every option it receives, in argument order, after the defaults have
// been assigned.
type SignerListenerEndpointOption func(*SignerListenerEndpoint)
    19  
    20  // SignerListenerEndpointTimeoutReadWrite sets the read and write timeout for
    21  // connections from external signing processes.
    22  //
    23  // Default: 5s
    24  func SignerListenerEndpointTimeoutReadWrite(timeout time.Duration) SignerListenerEndpointOption {
    25  	return func(sl *SignerListenerEndpoint) { sl.signerEndpoint.timeoutReadWrite = timeout }
    26  }
    27  
    28  // SignerListenerEndpointAllowAddress sets the address to allow
    29  // connections from the only allowed addresses
    30  func SignerListenerEndpointAllowAddress(protocol string, allowedAddresses []string) SignerListenerEndpointOption {
    31  	return func(sl *SignerListenerEndpoint) {
    32  		if protocol == "tcp" {
    33  			sl.connFilter = internal.NewIpFilter(allowedAddresses, sl.Logger)
    34  			return
    35  		}
    36  		sl.connFilter = internal.NewNullObject()
    37  	}
    38  }
    39  
// SignerListenerEndpoint listens for an external process to dial in and keeps
// the connection alive by dropping and reconnecting.
//
// The process will send pings every ~3s (read/write timeout * 2/3) to keep the
// connection alive.
type SignerListenerEndpoint struct {
	signerEndpoint

	listener              net.Listener  // listener that incoming signer connections are accepted from
	connectRequestCh      chan struct{} // signals serviceLoop to accept a new connection
	connectionAvailableCh chan net.Conn // carries accepted connections from serviceLoop to waiters

	timeoutAccept time.Duration // maximum wait for a connection inside SendRequest
	pingTimer     *time.Ticker  // ticks drive pingLoop; reset after each successful SendRequest
	pingInterval  time.Duration // ping period, set in OnStart to 2/3 of the read/write timeout

	instanceMtx tmsync.Mutex              // Serializes the public methods (OnStop, WaitForConnection, SendRequest)
	connFilter  internal.ConnectionFilter // optional filter for remote addresses; nil means no filtering
}
    59  
    60  // NewSignerListenerEndpoint returns an instance of SignerListenerEndpoint.
    61  func NewSignerListenerEndpoint(
    62  	logger log.Logger,
    63  	listener net.Listener,
    64  	options ...SignerListenerEndpointOption,
    65  ) *SignerListenerEndpoint {
    66  	sl := &SignerListenerEndpoint{
    67  		listener:      listener,
    68  		timeoutAccept: defaultTimeoutAcceptSeconds * time.Second,
    69  	}
    70  
    71  	sl.BaseService = *service.NewBaseService(logger, "SignerListenerEndpoint", sl)
    72  	sl.signerEndpoint.timeoutReadWrite = defaultTimeoutReadWriteSeconds * time.Second
    73  
    74  	for _, optionFunc := range options {
    75  		optionFunc(sl)
    76  	}
    77  
    78  	return sl
    79  }
    80  
    81  // OnStart implements service.Service.
    82  func (sl *SignerListenerEndpoint) OnStart() error {
    83  	sl.connectRequestCh = make(chan struct{})
    84  	sl.connectionAvailableCh = make(chan net.Conn)
    85  
    86  	// NOTE: ping timeout must be less than read/write timeout
    87  	sl.pingInterval = time.Duration(sl.signerEndpoint.timeoutReadWrite.Milliseconds()*2/3) * time.Millisecond
    88  	sl.pingTimer = time.NewTicker(sl.pingInterval)
    89  
    90  	go sl.serviceLoop()
    91  	go sl.pingLoop()
    92  
    93  	sl.connectRequestCh <- struct{}{}
    94  
    95  	return nil
    96  }
    97  
// OnStop implements service.Service
//
// While holding instanceMtx it calls Close on the endpoint, closes the
// listener (logging any error), and stops the ping ticker.
func (sl *SignerListenerEndpoint) OnStop() {
	sl.instanceMtx.Lock()
	defer sl.instanceMtx.Unlock()
	// Best effort: the result of Close is deliberately discarded.
	_ = sl.Close()

	// Stop listening
	if sl.listener != nil {
		if err := sl.listener.Close(); err != nil {
			sl.Logger.Error("Closing Listener", "err", err)
			// NOTE(review): the listener reference is cleared only when Close
			// fails; after a successful Close it is retained. Confirm this is
			// intended.
			sl.listener = nil
		}
	}

	sl.pingTimer.Stop()
}
   114  
   115  // WaitForConnection waits maxWait for a connection or returns a timeout error
   116  func (sl *SignerListenerEndpoint) WaitForConnection(maxWait time.Duration) error {
   117  	sl.instanceMtx.Lock()
   118  	defer sl.instanceMtx.Unlock()
   119  	return sl.ensureConnection(maxWait)
   120  }
   121  
   122  // SendRequest ensures there is a connection, sends a request and waits for a response
   123  func (sl *SignerListenerEndpoint) SendRequest(request ocprivvalproto.Message) (*ocprivvalproto.Message, error) {
   124  	sl.instanceMtx.Lock()
   125  	defer sl.instanceMtx.Unlock()
   126  
   127  	err := sl.ensureConnection(sl.timeoutAccept)
   128  	if err != nil {
   129  		return nil, err
   130  	}
   131  
   132  	err = sl.WriteMessage(request)
   133  	if err != nil {
   134  		return nil, err
   135  	}
   136  
   137  	res, err := sl.ReadMessage()
   138  	if err != nil {
   139  		return nil, err
   140  	}
   141  
   142  	// Reset pingTimer to avoid sending unnecessary pings.
   143  	sl.pingTimer.Reset(sl.pingInterval)
   144  
   145  	return &res, nil
   146  }
   147  
   148  func (sl *SignerListenerEndpoint) ensureConnection(maxWait time.Duration) error {
   149  	if sl.IsConnected() {
   150  		return nil
   151  	}
   152  
   153  	// Is there a connection ready? then use it
   154  	if sl.GetAvailableConnection(sl.connectionAvailableCh) {
   155  		return nil
   156  	}
   157  
   158  	// block until connected or timeout
   159  	sl.Logger.Info("SignerListener: Blocking for connection")
   160  	sl.triggerConnect()
   161  	err := sl.WaitConnection(sl.connectionAvailableCh, maxWait)
   162  	if err != nil {
   163  		return err
   164  	}
   165  
   166  	return nil
   167  }
   168  
   169  func (sl *SignerListenerEndpoint) acceptNewConnection() (net.Conn, error) {
   170  	if !sl.IsRunning() || sl.listener == nil {
   171  		return nil, fmt.Errorf("endpoint is closing")
   172  	}
   173  
   174  	// wait for a new conn
   175  	sl.Logger.Info("SignerListener: Listening for new connection")
   176  	conn, err := sl.listener.Accept()
   177  	if err != nil {
   178  		return nil, err
   179  	}
   180  
   181  	return conn, nil
   182  }
   183  
// triggerConnect asks serviceLoop to accept a new connection without ever
// blocking the caller: if nothing is currently receiving on connectRequestCh
// the request is dropped.
func (sl *SignerListenerEndpoint) triggerConnect() {
	select {
	case sl.connectRequestCh <- struct{}{}:
	default:
		// No receiver ready; skip rather than block.
	}
}
   190  
// triggerReconnect drops the current connection and then requests a new one
// via triggerConnect. pingLoop calls it when a ping request fails.
func (sl *SignerListenerEndpoint) triggerReconnect() {
	sl.DropConnection()
	sl.triggerConnect()
}
   195  
// serviceLoop runs in its own goroutine (started by OnStart) and serves
// connect requests until the service quits.
//
// For every signal on connectRequestCh it accepts one connection. A connection
// whose remote address is rejected by the filter is closed and the loop goes
// back to waiting for the next request. An accepted connection is handed over
// on connectionAvailableCh, blocking until someone takes it or the service
// quits. Errors from accepting are not logged here.
func (sl *SignerListenerEndpoint) serviceLoop() {
	for {
		select {
		case <-sl.connectRequestCh:
			{
				conn, err := sl.acceptNewConnection()
				if err == nil {
					remoteAddr := conn.RemoteAddr()
					// A nil result from filter means the address is not allowed.
					if sl.filter(remoteAddr) == nil {
						sl.Logger.Info(fmt.Sprintf("SignerListener: deny a connection request from remote address=%s, expected=%s", remoteAddr, sl.connFilter))
						conn.Close()
						// Skips the re-arm below; the next accept happens only
						// after another connect request arrives.
						continue
					}
					sl.Logger.Info("SignerListener: Connected")

					// We have a good connection, wait for someone that needs one otherwise cancellation
					select {
					case sl.connectionAvailableCh <- conn:
					case <-sl.Quit():
						return
					}
				}

				// NOTE(review): connectRequestCh is unbuffered and the only
				// receiver visible in this file is this goroutine, so this
				// non-blocking send appears to always take the default
				// branch. Confirm.
				select {
				case sl.connectRequestCh <- struct{}{}:
				default:
				}
			}
		case <-sl.Quit():
			return
		}
	}
}
   229  
   230  func (sl *SignerListenerEndpoint) filter(addr net.Addr) net.Addr {
   231  	if sl.connFilter == nil {
   232  		return addr
   233  	}
   234  	return sl.connFilter.Filter(addr)
   235  }
   236  
   237  func (sl *SignerListenerEndpoint) pingLoop() {
   238  	for {
   239  		select {
   240  		case <-sl.pingTimer.C:
   241  			{
   242  				_, err := sl.SendRequest(mustWrapMsg(&privvalproto.PingRequest{}))
   243  				if err != nil {
   244  					sl.Logger.Error("SignerListener: Ping timeout")
   245  					sl.triggerReconnect()
   246  				}
   247  			}
   248  		case <-sl.Quit():
   249  			return
   250  		}
   251  	}
   252  }