github.com/ari-anchor/sei-tendermint@v0.0.0-20230519144642-dc826b7b56bb/privval/signer_listener_endpoint.go (about) 1 package privval 2 3 import ( 4 "context" 5 "fmt" 6 "net" 7 "sync" 8 "time" 9 10 "github.com/ari-anchor/sei-tendermint/libs/log" 11 "github.com/ari-anchor/sei-tendermint/libs/service" 12 privvalproto "github.com/ari-anchor/sei-tendermint/proto/tendermint/privval" 13 ) 14 15 // SignerListenerEndpointOption sets an optional parameter on the SignerListenerEndpoint. 16 type SignerListenerEndpointOption func(*SignerListenerEndpoint) 17 18 // SignerListenerEndpointTimeoutReadWrite sets the read and write timeout for 19 // connections from external signing processes. 20 // 21 // Default: 5s 22 func SignerListenerEndpointTimeoutReadWrite(timeout time.Duration) SignerListenerEndpointOption { 23 return func(sl *SignerListenerEndpoint) { sl.signerEndpoint.timeoutReadWrite = timeout } 24 } 25 26 // SignerListenerEndpoint listens for an external process to dial in and keeps 27 // the connection alive by dropping and reconnecting. 28 // 29 // The process will send pings every ~3s (read/write timeout * 2/3) to keep the 30 // connection alive. 31 type SignerListenerEndpoint struct { 32 signerEndpoint 33 34 listener net.Listener 35 connectRequestCh chan struct{} 36 connectionAvailableCh chan net.Conn 37 38 timeoutAccept time.Duration 39 pingTimer *time.Ticker 40 pingInterval time.Duration 41 42 instanceMtx sync.Mutex // Ensures instance public methods access, i.e. SendRequest 43 } 44 45 // NewSignerListenerEndpoint returns an instance of SignerListenerEndpoint. 46 func NewSignerListenerEndpoint( 47 logger log.Logger, 48 listener net.Listener, 49 options ...SignerListenerEndpointOption, 50 ) *SignerListenerEndpoint { 51 sl := &SignerListenerEndpoint{ 52 listener: listener, 53 timeoutAccept: defaultTimeoutAcceptSeconds * time.Second, 54 } 55 56 sl.signerEndpoint.logger = logger 57 sl.BaseService = *service.NewBaseService(logger, "SignerListenerEndpoint", sl) 58 sl.signerEndpoint.timeoutReadWrite = defaultTimeoutReadWriteSeconds * time.Second 59 60 for _, optionFunc := range options { 61 optionFunc(sl) 62 } 63 64 return sl 65 } 66 67 // OnStart implements service.Service. 68 func (sl *SignerListenerEndpoint) OnStart(ctx context.Context) error { 69 sl.connectRequestCh = make(chan struct{}) 70 sl.connectionAvailableCh = make(chan net.Conn) 71 72 // NOTE: ping timeout must be less than read/write timeout 73 sl.pingInterval = time.Duration(sl.signerEndpoint.timeoutReadWrite.Milliseconds()*2/3) * time.Millisecond 74 sl.pingTimer = time.NewTicker(sl.pingInterval) 75 76 go sl.serviceLoop(ctx) 77 go sl.pingLoop(ctx) 78 79 sl.connectRequestCh <- struct{}{} 80 81 return nil 82 } 83 84 // OnStop implements service.Service 85 func (sl *SignerListenerEndpoint) OnStop() { 86 sl.instanceMtx.Lock() 87 defer sl.instanceMtx.Unlock() 88 _ = sl.Close() 89 90 // Stop listening 91 if sl.listener != nil { 92 if err := sl.listener.Close(); err != nil { 93 sl.logger.Error("Closing Listener", "err", err) 94 sl.listener = nil 95 } 96 } 97 98 sl.pingTimer.Stop() 99 } 100 101 // WaitForConnection waits maxWait for a connection or returns a timeout error 102 func (sl *SignerListenerEndpoint) WaitForConnection(ctx context.Context, maxWait time.Duration) error { 103 sl.instanceMtx.Lock() 104 defer sl.instanceMtx.Unlock() 105 return sl.ensureConnection(ctx, maxWait) 106 } 107 108 // SendRequest ensures there is a connection, sends a request and waits for a response 109 func (sl *SignerListenerEndpoint) SendRequest(ctx context.Context, request privvalproto.Message) (*privvalproto.Message, error) { 110 sl.instanceMtx.Lock() 111 defer sl.instanceMtx.Unlock() 112 113 err := sl.ensureConnection(ctx, sl.timeoutAccept) 114 if err != nil { 115 return nil, err 116 } 117 118 err = sl.WriteMessage(request) 119 if err != nil { 120 return nil, err 121 } 122 123 res, err := sl.ReadMessage() 124 if err != nil { 125 return nil, err 126 } 127 128 // Reset pingTimer to avoid sending unnecessary pings. 129 sl.pingTimer.Reset(sl.pingInterval) 130 131 return &res, nil 132 } 133 134 func (sl *SignerListenerEndpoint) ensureConnection(ctx context.Context, maxWait time.Duration) error { 135 if sl.IsConnected() { 136 return nil 137 } 138 139 // Is there a connection ready? then use it 140 if sl.GetAvailableConnection(sl.connectionAvailableCh) { 141 return nil 142 } 143 144 // block until connected or timeout 145 sl.logger.Info("SignerListener: Blocking for connection") 146 sl.triggerConnect() 147 return sl.WaitConnection(ctx, sl.connectionAvailableCh, maxWait) 148 } 149 150 func (sl *SignerListenerEndpoint) acceptNewConnection() (net.Conn, error) { 151 if !sl.IsRunning() || sl.listener == nil { 152 return nil, fmt.Errorf("endpoint is closing") 153 } 154 155 // wait for a new conn 156 sl.logger.Info("SignerListener: Listening for new connection") 157 conn, err := sl.listener.Accept() 158 if err != nil { 159 return nil, err 160 } 161 162 return conn, nil 163 } 164 165 func (sl *SignerListenerEndpoint) triggerConnect() { 166 select { 167 case sl.connectRequestCh <- struct{}{}: 168 default: 169 } 170 } 171 172 func (sl *SignerListenerEndpoint) triggerReconnect() { 173 sl.DropConnection() 174 sl.triggerConnect() 175 } 176 177 func (sl *SignerListenerEndpoint) serviceLoop(ctx context.Context) { 178 for { 179 select { 180 case <-sl.connectRequestCh: 181 { 182 conn, err := sl.acceptNewConnection() 183 if err == nil { 184 sl.logger.Info("SignerListener: Connected") 185 186 // We have a good connection, wait for someone that needs one otherwise cancellation 187 select { 188 case sl.connectionAvailableCh <- conn: 189 case <-ctx.Done(): 190 return 191 } 192 } 193 194 select { 195 case sl.connectRequestCh <- struct{}{}: 196 default: 197 } 198 } 199 case <-ctx.Done(): 200 return 201 } 202 } 203 } 204 205 func (sl *SignerListenerEndpoint) pingLoop(ctx context.Context) { 206 for { 207 select { 208 case <-sl.pingTimer.C: 209 { 210 _, err := sl.SendRequest(ctx, mustWrapMsg(&privvalproto.PingRequest{})) 211 if err != nil { 212 sl.logger.Error("SignerListener: Ping timeout") 213 sl.triggerReconnect() 214 } 215 } 216 case <-ctx.Done(): 217 return 218 } 219 } 220 }