github.com/evdatsion/aphelion-dpos-bft@v0.32.1/privval/signer_validator_endpoint_test.go (about) 1 package privval 2 3 import ( 4 "fmt" 5 "net" 6 "testing" 7 "time" 8 9 "github.com/stretchr/testify/assert" 10 "github.com/stretchr/testify/require" 11 12 "github.com/evdatsion/aphelion-dpos-bft/crypto/ed25519" 13 cmn "github.com/evdatsion/aphelion-dpos-bft/libs/common" 14 "github.com/evdatsion/aphelion-dpos-bft/libs/log" 15 16 "github.com/evdatsion/aphelion-dpos-bft/types" 17 ) 18 19 var ( 20 testTimeoutAccept = defaultTimeoutAcceptSeconds * time.Second 21 22 testTimeoutReadWrite = 100 * time.Millisecond 23 testTimeoutReadWrite2o3 = 66 * time.Millisecond // 2/3 of the other one 24 25 testTimeoutHeartbeat = 10 * time.Millisecond 26 testTimeoutHeartbeat3o2 = 6 * time.Millisecond // 3/2 of the other one 27 ) 28 29 type socketTestCase struct { 30 addr string 31 dialer SocketDialer 32 } 33 34 func socketTestCases(t *testing.T) []socketTestCase { 35 tcpAddr := fmt.Sprintf("tcp://%s", testFreeTCPAddr(t)) 36 unixFilePath, err := testUnixAddr() 37 require.NoError(t, err) 38 unixAddr := fmt.Sprintf("unix://%s", unixFilePath) 39 return []socketTestCase{ 40 { 41 addr: tcpAddr, 42 dialer: DialTCPFn(tcpAddr, testTimeoutReadWrite, ed25519.GenPrivKey()), 43 }, 44 { 45 addr: unixAddr, 46 dialer: DialUnixFn(unixFilePath), 47 }, 48 } 49 } 50 51 func TestSocketPVAddress(t *testing.T) { 52 for _, tc := range socketTestCases(t) { 53 // Execute the test within a closure to ensure the deferred statements 54 // are called between each for loop iteration, for isolated test cases. 55 func() { 56 var ( 57 chainID = cmn.RandStr(12) 58 validatorEndpoint, serviceEndpoint = testSetupSocketPair(t, chainID, types.NewMockPV(), tc.addr, tc.dialer) 59 ) 60 defer validatorEndpoint.Stop() 61 defer serviceEndpoint.Stop() 62 63 serviceAddr := serviceEndpoint.privVal.GetPubKey().Address() 64 validatorAddr := validatorEndpoint.GetPubKey().Address() 65 66 assert.Equal(t, serviceAddr, validatorAddr) 67 }() 68 } 69 } 70 71 func TestSocketPVPubKey(t *testing.T) { 72 for _, tc := range socketTestCases(t) { 73 func() { 74 var ( 75 chainID = cmn.RandStr(12) 76 validatorEndpoint, serviceEndpoint = testSetupSocketPair( 77 t, 78 chainID, 79 types.NewMockPV(), 80 tc.addr, 81 tc.dialer) 82 ) 83 defer validatorEndpoint.Stop() 84 defer serviceEndpoint.Stop() 85 86 clientKey := validatorEndpoint.GetPubKey() 87 privvalPubKey := serviceEndpoint.privVal.GetPubKey() 88 89 assert.Equal(t, privvalPubKey, clientKey) 90 }() 91 } 92 } 93 94 func TestSocketPVProposal(t *testing.T) { 95 for _, tc := range socketTestCases(t) { 96 func() { 97 var ( 98 chainID = cmn.RandStr(12) 99 validatorEndpoint, serviceEndpoint = testSetupSocketPair( 100 t, 101 chainID, 102 types.NewMockPV(), 103 tc.addr, 104 tc.dialer) 105 106 ts = time.Now() 107 privProposal = &types.Proposal{Timestamp: ts} 108 clientProposal = &types.Proposal{Timestamp: ts} 109 ) 110 defer validatorEndpoint.Stop() 111 defer serviceEndpoint.Stop() 112 113 require.NoError(t, serviceEndpoint.privVal.SignProposal(chainID, privProposal)) 114 require.NoError(t, validatorEndpoint.SignProposal(chainID, clientProposal)) 115 116 assert.Equal(t, privProposal.Signature, clientProposal.Signature) 117 }() 118 } 119 } 120 121 func TestSocketPVVote(t *testing.T) { 122 for _, tc := range socketTestCases(t) { 123 func() { 124 var ( 125 chainID = cmn.RandStr(12) 126 validatorEndpoint, serviceEndpoint = testSetupSocketPair( 127 t, 128 chainID, 129 types.NewMockPV(), 130 tc.addr, 131 tc.dialer) 132 133 ts = time.Now() 134 vType = types.PrecommitType 135 want = &types.Vote{Timestamp: ts, Type: vType} 136 have = &types.Vote{Timestamp: ts, Type: vType} 137 ) 138 defer validatorEndpoint.Stop() 139 defer serviceEndpoint.Stop() 140 141 require.NoError(t, serviceEndpoint.privVal.SignVote(chainID, want)) 142 require.NoError(t, validatorEndpoint.SignVote(chainID, have)) 143 assert.Equal(t, want.Signature, have.Signature) 144 }() 145 } 146 } 147 148 func TestSocketPVVoteResetDeadline(t *testing.T) { 149 for _, tc := range socketTestCases(t) { 150 func() { 151 var ( 152 chainID = cmn.RandStr(12) 153 validatorEndpoint, serviceEndpoint = testSetupSocketPair( 154 t, 155 chainID, 156 types.NewMockPV(), 157 tc.addr, 158 tc.dialer) 159 160 ts = time.Now() 161 vType = types.PrecommitType 162 want = &types.Vote{Timestamp: ts, Type: vType} 163 have = &types.Vote{Timestamp: ts, Type: vType} 164 ) 165 defer validatorEndpoint.Stop() 166 defer serviceEndpoint.Stop() 167 168 time.Sleep(testTimeoutReadWrite2o3) 169 170 require.NoError(t, serviceEndpoint.privVal.SignVote(chainID, want)) 171 require.NoError(t, validatorEndpoint.SignVote(chainID, have)) 172 assert.Equal(t, want.Signature, have.Signature) 173 174 // This would exceed the deadline if it was not extended by the previous message 175 time.Sleep(testTimeoutReadWrite2o3) 176 177 require.NoError(t, serviceEndpoint.privVal.SignVote(chainID, want)) 178 require.NoError(t, validatorEndpoint.SignVote(chainID, have)) 179 assert.Equal(t, want.Signature, have.Signature) 180 }() 181 } 182 } 183 184 func TestSocketPVVoteKeepalive(t *testing.T) { 185 for _, tc := range socketTestCases(t) { 186 func() { 187 var ( 188 chainID = cmn.RandStr(12) 189 validatorEndpoint, serviceEndpoint = testSetupSocketPair( 190 t, 191 chainID, 192 types.NewMockPV(), 193 tc.addr, 194 tc.dialer) 195 196 ts = time.Now() 197 vType = types.PrecommitType 198 want = &types.Vote{Timestamp: ts, Type: vType} 199 have = &types.Vote{Timestamp: ts, Type: vType} 200 ) 201 defer validatorEndpoint.Stop() 202 defer serviceEndpoint.Stop() 203 204 time.Sleep(testTimeoutReadWrite * 2) 205 206 require.NoError(t, serviceEndpoint.privVal.SignVote(chainID, want)) 207 require.NoError(t, validatorEndpoint.SignVote(chainID, have)) 208 assert.Equal(t, want.Signature, have.Signature) 209 }() 210 } 211 } 212 213 func TestSocketPVDeadline(t *testing.T) { 214 for _, tc := range socketTestCases(t) { 215 func() { 216 var ( 217 listenc = make(chan struct{}) 218 thisConnTimeout = 100 * time.Millisecond 219 validatorEndpoint = newSignerValidatorEndpoint(log.TestingLogger(), tc.addr, thisConnTimeout) 220 ) 221 222 go func(sc *SignerValidatorEndpoint) { 223 defer close(listenc) 224 225 // Note: the TCP connection times out at the accept() phase, 226 // whereas the Unix domain sockets connection times out while 227 // attempting to fetch the remote signer's public key. 228 assert.True(t, IsConnTimeout(sc.Start())) 229 230 assert.False(t, sc.IsRunning()) 231 }(validatorEndpoint) 232 233 for { 234 _, err := cmn.Connect(tc.addr) 235 if err == nil { 236 break 237 } 238 } 239 240 <-listenc 241 }() 242 } 243 } 244 245 func TestRemoteSignVoteErrors(t *testing.T) { 246 for _, tc := range socketTestCases(t) { 247 func() { 248 var ( 249 chainID = cmn.RandStr(12) 250 validatorEndpoint, serviceEndpoint = testSetupSocketPair( 251 t, 252 chainID, 253 types.NewErroringMockPV(), 254 tc.addr, 255 tc.dialer) 256 257 ts = time.Now() 258 vType = types.PrecommitType 259 vote = &types.Vote{Timestamp: ts, Type: vType} 260 ) 261 defer validatorEndpoint.Stop() 262 defer serviceEndpoint.Stop() 263 264 err := validatorEndpoint.SignVote("", vote) 265 require.Equal(t, err.(*RemoteSignerError).Description, types.ErroringMockPVErr.Error()) 266 267 err = serviceEndpoint.privVal.SignVote(chainID, vote) 268 require.Error(t, err) 269 err = validatorEndpoint.SignVote(chainID, vote) 270 require.Error(t, err) 271 }() 272 } 273 } 274 275 func TestRemoteSignProposalErrors(t *testing.T) { 276 for _, tc := range socketTestCases(t) { 277 func() { 278 var ( 279 chainID = cmn.RandStr(12) 280 validatorEndpoint, serviceEndpoint = testSetupSocketPair( 281 t, 282 chainID, 283 types.NewErroringMockPV(), 284 tc.addr, 285 tc.dialer) 286 287 ts = time.Now() 288 proposal = &types.Proposal{Timestamp: ts} 289 ) 290 defer validatorEndpoint.Stop() 291 defer serviceEndpoint.Stop() 292 293 err := validatorEndpoint.SignProposal("", proposal) 294 require.Equal(t, err.(*RemoteSignerError).Description, types.ErroringMockPVErr.Error()) 295 296 err = serviceEndpoint.privVal.SignProposal(chainID, proposal) 297 require.Error(t, err) 298 299 err = validatorEndpoint.SignProposal(chainID, proposal) 300 require.Error(t, err) 301 }() 302 } 303 } 304 305 func TestErrUnexpectedResponse(t *testing.T) { 306 for _, tc := range socketTestCases(t) { 307 func() { 308 var ( 309 logger = log.TestingLogger() 310 chainID = cmn.RandStr(12) 311 readyCh = make(chan struct{}) 312 errCh = make(chan error, 1) 313 314 serviceEndpoint = NewSignerServiceEndpoint( 315 logger, 316 chainID, 317 types.NewMockPV(), 318 tc.dialer, 319 ) 320 321 validatorEndpoint = newSignerValidatorEndpoint( 322 logger, 323 tc.addr, 324 testTimeoutReadWrite) 325 ) 326 327 testStartEndpoint(t, readyCh, validatorEndpoint) 328 defer validatorEndpoint.Stop() 329 SignerServiceEndpointTimeoutReadWrite(time.Millisecond)(serviceEndpoint) 330 SignerServiceEndpointConnRetries(100)(serviceEndpoint) 331 // we do not want to Start() the remote signer here and instead use the connection to 332 // reply with intentionally wrong replies below: 333 rsConn, err := serviceEndpoint.connect() 334 defer rsConn.Close() 335 require.NoError(t, err) 336 require.NotNil(t, rsConn) 337 // send over public key to get the remote signer running: 338 go testReadWriteResponse(t, &PubKeyResponse{}, rsConn) 339 <-readyCh 340 341 // Proposal: 342 go func(errc chan error) { 343 errc <- validatorEndpoint.SignProposal(chainID, &types.Proposal{}) 344 }(errCh) 345 346 // read request and write wrong response: 347 go testReadWriteResponse(t, &SignedVoteResponse{}, rsConn) 348 err = <-errCh 349 require.Error(t, err) 350 require.Equal(t, err, ErrUnexpectedResponse) 351 352 // Vote: 353 go func(errc chan error) { 354 errc <- validatorEndpoint.SignVote(chainID, &types.Vote{}) 355 }(errCh) 356 // read request and write wrong response: 357 go testReadWriteResponse(t, &SignedProposalResponse{}, rsConn) 358 err = <-errCh 359 require.Error(t, err) 360 require.Equal(t, err, ErrUnexpectedResponse) 361 }() 362 } 363 } 364 365 func TestRetryConnToRemoteSigner(t *testing.T) { 366 for _, tc := range socketTestCases(t) { 367 func() { 368 var ( 369 logger = log.TestingLogger() 370 chainID = cmn.RandStr(12) 371 readyCh = make(chan struct{}) 372 373 serviceEndpoint = NewSignerServiceEndpoint( 374 logger, 375 chainID, 376 types.NewMockPV(), 377 tc.dialer, 378 ) 379 thisConnTimeout = testTimeoutReadWrite 380 validatorEndpoint = newSignerValidatorEndpoint(logger, tc.addr, thisConnTimeout) 381 ) 382 // Ping every: 383 SignerValidatorEndpointSetHeartbeat(testTimeoutHeartbeat)(validatorEndpoint) 384 385 SignerServiceEndpointTimeoutReadWrite(testTimeoutReadWrite)(serviceEndpoint) 386 SignerServiceEndpointConnRetries(10)(serviceEndpoint) 387 388 testStartEndpoint(t, readyCh, validatorEndpoint) 389 defer validatorEndpoint.Stop() 390 require.NoError(t, serviceEndpoint.Start()) 391 assert.True(t, serviceEndpoint.IsRunning()) 392 393 <-readyCh 394 time.Sleep(testTimeoutHeartbeat * 2) 395 396 serviceEndpoint.Stop() 397 rs2 := NewSignerServiceEndpoint( 398 logger, 399 chainID, 400 types.NewMockPV(), 401 tc.dialer, 402 ) 403 // let some pings pass 404 time.Sleep(testTimeoutHeartbeat3o2) 405 require.NoError(t, rs2.Start()) 406 assert.True(t, rs2.IsRunning()) 407 defer rs2.Stop() 408 409 // give the client some time to re-establish the conn to the remote signer 410 // should see sth like this in the logs: 411 // 412 // E[10016-01-10|17:12:46.128] Ping err="remote signer timed out" 413 // I[10016-01-10|17:16:42.447] Re-created connection to remote signer impl=SocketVal 414 time.Sleep(testTimeoutReadWrite * 2) 415 }() 416 } 417 } 418 419 func newSignerValidatorEndpoint(logger log.Logger, addr string, timeoutReadWrite time.Duration) *SignerValidatorEndpoint { 420 proto, address := cmn.ProtocolAndAddress(addr) 421 422 ln, err := net.Listen(proto, address) 423 logger.Info("Listening at", "proto", proto, "address", address) 424 if err != nil { 425 panic(err) 426 } 427 428 var listener net.Listener 429 430 if proto == "unix" { 431 unixLn := NewUnixListener(ln) 432 UnixListenerTimeoutAccept(testTimeoutAccept)(unixLn) 433 UnixListenerTimeoutReadWrite(timeoutReadWrite)(unixLn) 434 listener = unixLn 435 } else { 436 tcpLn := NewTCPListener(ln, ed25519.GenPrivKey()) 437 TCPListenerTimeoutAccept(testTimeoutAccept)(tcpLn) 438 TCPListenerTimeoutReadWrite(timeoutReadWrite)(tcpLn) 439 listener = tcpLn 440 } 441 442 return NewSignerValidatorEndpoint(logger, listener) 443 } 444 445 func testSetupSocketPair( 446 t *testing.T, 447 chainID string, 448 privValidator types.PrivValidator, 449 addr string, 450 socketDialer SocketDialer, 451 ) (*SignerValidatorEndpoint, *SignerServiceEndpoint) { 452 var ( 453 logger = log.TestingLogger() 454 privVal = privValidator 455 readyc = make(chan struct{}) 456 serviceEndpoint = NewSignerServiceEndpoint( 457 logger, 458 chainID, 459 privVal, 460 socketDialer, 461 ) 462 463 thisConnTimeout = testTimeoutReadWrite 464 validatorEndpoint = newSignerValidatorEndpoint(logger, addr, thisConnTimeout) 465 ) 466 467 SignerValidatorEndpointSetHeartbeat(testTimeoutHeartbeat)(validatorEndpoint) 468 SignerServiceEndpointTimeoutReadWrite(testTimeoutReadWrite)(serviceEndpoint) 469 SignerServiceEndpointConnRetries(1e6)(serviceEndpoint) 470 471 testStartEndpoint(t, readyc, validatorEndpoint) 472 473 require.NoError(t, serviceEndpoint.Start()) 474 assert.True(t, serviceEndpoint.IsRunning()) 475 476 <-readyc 477 478 return validatorEndpoint, serviceEndpoint 479 } 480 481 func testReadWriteResponse(t *testing.T, resp RemoteSignerMsg, rsConn net.Conn) { 482 _, err := readMsg(rsConn) 483 require.NoError(t, err) 484 485 err = writeMsg(rsConn, resp) 486 require.NoError(t, err) 487 } 488 489 func testStartEndpoint(t *testing.T, readyCh chan struct{}, sc *SignerValidatorEndpoint) { 490 go func(sc *SignerValidatorEndpoint) { 491 require.NoError(t, sc.Start()) 492 assert.True(t, sc.IsRunning()) 493 494 readyCh <- struct{}{} 495 }(sc) 496 } 497 498 // testFreeTCPAddr claims a free port so we don't block on listener being ready. 499 func testFreeTCPAddr(t *testing.T) string { 500 ln, err := net.Listen("tcp", "127.0.0.1:0") 501 require.NoError(t, err) 502 defer ln.Close() 503 504 return fmt.Sprintf("127.0.0.1:%d", ln.Addr().(*net.TCPAddr).Port) 505 }