github.com/aakash4dev/cometbft@v0.38.2/privval/signer_client_test.go (about) 1 package privval 2 3 import ( 4 "fmt" 5 "testing" 6 "time" 7 8 "github.com/stretchr/testify/assert" 9 "github.com/stretchr/testify/require" 10 11 "github.com/aakash4dev/cometbft/crypto" 12 "github.com/aakash4dev/cometbft/crypto/tmhash" 13 cmtrand "github.com/aakash4dev/cometbft/libs/rand" 14 cryptoproto "github.com/aakash4dev/cometbft/proto/tendermint/crypto" 15 privvalproto "github.com/aakash4dev/cometbft/proto/tendermint/privval" 16 cmtproto "github.com/aakash4dev/cometbft/proto/tendermint/types" 17 "github.com/aakash4dev/cometbft/types" 18 ) 19 20 type signerTestCase struct { 21 chainID string 22 mockPV types.PrivValidator 23 signerClient *SignerClient 24 signerServer *SignerServer 25 } 26 27 func getSignerTestCases(t *testing.T) []signerTestCase { 28 testCases := make([]signerTestCase, 0) 29 30 // Get test cases for each possible dialer (DialTCP / DialUnix / etc) 31 for _, dtc := range getDialerTestCases(t) { 32 chainID := cmtrand.Str(12) 33 mockPV := types.NewMockPV() 34 35 // get a pair of signer listener, signer dialer endpoints 36 sl, sd := getMockEndpoints(t, dtc.addr, dtc.dialer) 37 sc, err := NewSignerClient(sl, chainID) 38 require.NoError(t, err) 39 ss := NewSignerServer(sd, chainID, mockPV) 40 41 err = ss.Start() 42 require.NoError(t, err) 43 44 tc := signerTestCase{ 45 chainID: chainID, 46 mockPV: mockPV, 47 signerClient: sc, 48 signerServer: ss, 49 } 50 51 testCases = append(testCases, tc) 52 } 53 54 return testCases 55 } 56 57 func TestSignerClose(t *testing.T) { 58 for _, tc := range getSignerTestCases(t) { 59 err := tc.signerClient.Close() 60 assert.NoError(t, err) 61 62 err = tc.signerServer.Stop() 63 assert.NoError(t, err) 64 } 65 } 66 67 func TestSignerPing(t *testing.T) { 68 for _, tc := range getSignerTestCases(t) { 69 tc := tc 70 t.Cleanup(func() { 71 if err := tc.signerServer.Stop(); err != nil { 72 t.Error(err) 73 } 74 }) 75 t.Cleanup(func() { 76 if err := tc.signerClient.Close(); err != nil { 77 t.Error(err) 78 } 79 }) 80 81 err := tc.signerClient.Ping() 82 assert.NoError(t, err) 83 } 84 } 85 86 func TestSignerGetPubKey(t *testing.T) { 87 for _, tc := range getSignerTestCases(t) { 88 tc := tc 89 t.Cleanup(func() { 90 if err := tc.signerServer.Stop(); err != nil { 91 t.Error(err) 92 } 93 }) 94 t.Cleanup(func() { 95 if err := tc.signerClient.Close(); err != nil { 96 t.Error(err) 97 } 98 }) 99 100 pubKey, err := tc.signerClient.GetPubKey() 101 require.NoError(t, err) 102 expectedPubKey, err := tc.mockPV.GetPubKey() 103 require.NoError(t, err) 104 105 assert.Equal(t, expectedPubKey, pubKey) 106 107 pubKey, err = tc.signerClient.GetPubKey() 108 require.NoError(t, err) 109 expectedpk, err := tc.mockPV.GetPubKey() 110 require.NoError(t, err) 111 expectedAddr := expectedpk.Address() 112 113 assert.Equal(t, expectedAddr, pubKey.Address()) 114 } 115 } 116 117 func TestSignerProposal(t *testing.T) { 118 for _, tc := range getSignerTestCases(t) { 119 ts := time.Now() 120 hash := cmtrand.Bytes(tmhash.Size) 121 have := &types.Proposal{ 122 Type: cmtproto.ProposalType, 123 Height: 1, 124 Round: 2, 125 POLRound: 2, 126 BlockID: types.BlockID{Hash: hash, PartSetHeader: types.PartSetHeader{Hash: hash, Total: 2}}, 127 Timestamp: ts, 128 } 129 want := &types.Proposal{ 130 Type: cmtproto.ProposalType, 131 Height: 1, 132 Round: 2, 133 POLRound: 2, 134 BlockID: types.BlockID{Hash: hash, PartSetHeader: types.PartSetHeader{Hash: hash, Total: 2}}, 135 Timestamp: ts, 136 } 137 138 tc := tc 139 t.Cleanup(func() { 140 if err := tc.signerServer.Stop(); err != nil { 141 t.Error(err) 142 } 143 }) 144 t.Cleanup(func() { 145 if err := tc.signerClient.Close(); err != nil { 146 t.Error(err) 147 } 148 }) 149 150 require.NoError(t, tc.mockPV.SignProposal(tc.chainID, want.ToProto())) 151 require.NoError(t, tc.signerClient.SignProposal(tc.chainID, have.ToProto())) 152 153 assert.Equal(t, want.Signature, have.Signature) 154 } 155 } 156 157 func TestSignerVote(t *testing.T) { 158 for _, tc := range getSignerTestCases(t) { 159 ts := time.Now() 160 hash := cmtrand.Bytes(tmhash.Size) 161 valAddr := cmtrand.Bytes(crypto.AddressSize) 162 want := &types.Vote{ 163 Type: cmtproto.PrecommitType, 164 Height: 1, 165 Round: 2, 166 BlockID: types.BlockID{Hash: hash, PartSetHeader: types.PartSetHeader{Hash: hash, Total: 2}}, 167 Timestamp: ts, 168 ValidatorAddress: valAddr, 169 ValidatorIndex: 1, 170 } 171 172 have := &types.Vote{ 173 Type: cmtproto.PrecommitType, 174 Height: 1, 175 Round: 2, 176 BlockID: types.BlockID{Hash: hash, PartSetHeader: types.PartSetHeader{Hash: hash, Total: 2}}, 177 Timestamp: ts, 178 ValidatorAddress: valAddr, 179 ValidatorIndex: 1, 180 } 181 182 tc := tc 183 t.Cleanup(func() { 184 if err := tc.signerServer.Stop(); err != nil { 185 t.Error(err) 186 } 187 }) 188 t.Cleanup(func() { 189 if err := tc.signerClient.Close(); err != nil { 190 t.Error(err) 191 } 192 }) 193 194 require.NoError(t, tc.mockPV.SignVote(tc.chainID, want.ToProto())) 195 require.NoError(t, tc.signerClient.SignVote(tc.chainID, have.ToProto())) 196 197 assert.Equal(t, want.Signature, have.Signature) 198 } 199 } 200 201 func TestSignerVoteResetDeadline(t *testing.T) { 202 for _, tc := range getSignerTestCases(t) { 203 ts := time.Now() 204 hash := cmtrand.Bytes(tmhash.Size) 205 valAddr := cmtrand.Bytes(crypto.AddressSize) 206 want := &types.Vote{ 207 Type: cmtproto.PrecommitType, 208 Height: 1, 209 Round: 2, 210 BlockID: types.BlockID{Hash: hash, PartSetHeader: types.PartSetHeader{Hash: hash, Total: 2}}, 211 Timestamp: ts, 212 ValidatorAddress: valAddr, 213 ValidatorIndex: 1, 214 } 215 216 have := &types.Vote{ 217 Type: cmtproto.PrecommitType, 218 Height: 1, 219 Round: 2, 220 BlockID: types.BlockID{Hash: hash, PartSetHeader: types.PartSetHeader{Hash: hash, Total: 2}}, 221 Timestamp: ts, 222 ValidatorAddress: valAddr, 223 ValidatorIndex: 1, 224 } 225 226 tc := tc 227 t.Cleanup(func() { 228 if err := tc.signerServer.Stop(); err != nil { 229 t.Error(err) 230 } 231 }) 232 t.Cleanup(func() { 233 if err := tc.signerClient.Close(); err != nil { 234 t.Error(err) 235 } 236 }) 237 238 time.Sleep(testTimeoutReadWrite2o3) 239 240 require.NoError(t, tc.mockPV.SignVote(tc.chainID, want.ToProto())) 241 require.NoError(t, tc.signerClient.SignVote(tc.chainID, have.ToProto())) 242 assert.Equal(t, want.Signature, have.Signature) 243 244 // TODO(jleni): Clarify what is actually being tested 245 246 // This would exceed the deadline if it was not extended by the previous message 247 time.Sleep(testTimeoutReadWrite2o3) 248 249 require.NoError(t, tc.mockPV.SignVote(tc.chainID, want.ToProto())) 250 require.NoError(t, tc.signerClient.SignVote(tc.chainID, have.ToProto())) 251 assert.Equal(t, want.Signature, have.Signature) 252 } 253 } 254 255 func TestSignerVoteKeepAlive(t *testing.T) { 256 for _, tc := range getSignerTestCases(t) { 257 ts := time.Now() 258 hash := cmtrand.Bytes(tmhash.Size) 259 valAddr := cmtrand.Bytes(crypto.AddressSize) 260 want := &types.Vote{ 261 Type: cmtproto.PrecommitType, 262 Height: 1, 263 Round: 2, 264 BlockID: types.BlockID{Hash: hash, PartSetHeader: types.PartSetHeader{Hash: hash, Total: 2}}, 265 Timestamp: ts, 266 ValidatorAddress: valAddr, 267 ValidatorIndex: 1, 268 } 269 270 have := &types.Vote{ 271 Type: cmtproto.PrecommitType, 272 Height: 1, 273 Round: 2, 274 BlockID: types.BlockID{Hash: hash, PartSetHeader: types.PartSetHeader{Hash: hash, Total: 2}}, 275 Timestamp: ts, 276 ValidatorAddress: valAddr, 277 ValidatorIndex: 1, 278 } 279 280 tc := tc 281 t.Cleanup(func() { 282 if err := tc.signerServer.Stop(); err != nil { 283 t.Error(err) 284 } 285 }) 286 t.Cleanup(func() { 287 if err := tc.signerClient.Close(); err != nil { 288 t.Error(err) 289 } 290 }) 291 292 // Check that even if the client does not request a 293 // signature for a long time. The service is still available 294 295 // in this particular case, we use the dialer logger to ensure that 296 // test messages are properly interleaved in the test logs 297 tc.signerServer.Logger.Debug("TEST: Forced Wait -------------------------------------------------") 298 time.Sleep(testTimeoutReadWrite * 3) 299 tc.signerServer.Logger.Debug("TEST: Forced Wait DONE---------------------------------------------") 300 301 require.NoError(t, tc.mockPV.SignVote(tc.chainID, want.ToProto())) 302 require.NoError(t, tc.signerClient.SignVote(tc.chainID, have.ToProto())) 303 304 assert.Equal(t, want.Signature, have.Signature) 305 } 306 } 307 308 func TestSignerSignProposalErrors(t *testing.T) { 309 for _, tc := range getSignerTestCases(t) { 310 // Replace service with a mock that always fails 311 tc.signerServer.privVal = types.NewErroringMockPV() 312 tc.mockPV = types.NewErroringMockPV() 313 314 tc := tc 315 t.Cleanup(func() { 316 if err := tc.signerServer.Stop(); err != nil { 317 t.Error(err) 318 } 319 }) 320 t.Cleanup(func() { 321 if err := tc.signerClient.Close(); err != nil { 322 t.Error(err) 323 } 324 }) 325 326 ts := time.Now() 327 hash := cmtrand.Bytes(tmhash.Size) 328 proposal := &types.Proposal{ 329 Type: cmtproto.ProposalType, 330 Height: 1, 331 Round: 2, 332 POLRound: 2, 333 BlockID: types.BlockID{Hash: hash, PartSetHeader: types.PartSetHeader{Hash: hash, Total: 2}}, 334 Timestamp: ts, 335 Signature: []byte("signature"), 336 } 337 338 err := tc.signerClient.SignProposal(tc.chainID, proposal.ToProto()) 339 require.Equal(t, err.(*RemoteSignerError).Description, types.ErroringMockPVErr.Error()) 340 341 err = tc.mockPV.SignProposal(tc.chainID, proposal.ToProto()) 342 require.Error(t, err) 343 344 err = tc.signerClient.SignProposal(tc.chainID, proposal.ToProto()) 345 require.Error(t, err) 346 } 347 } 348 349 func TestSignerSignVoteErrors(t *testing.T) { 350 for _, tc := range getSignerTestCases(t) { 351 ts := time.Now() 352 hash := cmtrand.Bytes(tmhash.Size) 353 valAddr := cmtrand.Bytes(crypto.AddressSize) 354 vote := &types.Vote{ 355 Type: cmtproto.PrecommitType, 356 Height: 1, 357 Round: 2, 358 BlockID: types.BlockID{Hash: hash, PartSetHeader: types.PartSetHeader{Hash: hash, Total: 2}}, 359 Timestamp: ts, 360 ValidatorAddress: valAddr, 361 ValidatorIndex: 1, 362 Signature: []byte("signature"), 363 } 364 365 // Replace signer service privval with one that always fails 366 tc.signerServer.privVal = types.NewErroringMockPV() 367 tc.mockPV = types.NewErroringMockPV() 368 369 tc := tc 370 t.Cleanup(func() { 371 if err := tc.signerServer.Stop(); err != nil { 372 t.Error(err) 373 } 374 }) 375 t.Cleanup(func() { 376 if err := tc.signerClient.Close(); err != nil { 377 t.Error(err) 378 } 379 }) 380 381 err := tc.signerClient.SignVote(tc.chainID, vote.ToProto()) 382 require.Equal(t, err.(*RemoteSignerError).Description, types.ErroringMockPVErr.Error()) 383 384 err = tc.mockPV.SignVote(tc.chainID, vote.ToProto()) 385 require.Error(t, err) 386 387 err = tc.signerClient.SignVote(tc.chainID, vote.ToProto()) 388 require.Error(t, err) 389 } 390 } 391 392 func brokenHandler(_ types.PrivValidator, request privvalproto.Message, _ string) (privvalproto.Message, error) { 393 var res privvalproto.Message 394 var err error 395 396 switch r := request.Sum.(type) { 397 // This is broken and will answer most requests with a pubkey response 398 case *privvalproto.Message_PubKeyRequest: 399 res = mustWrapMsg(&privvalproto.PubKeyResponse{PubKey: cryptoproto.PublicKey{}, Error: nil}) 400 case *privvalproto.Message_SignVoteRequest: 401 res = mustWrapMsg(&privvalproto.PubKeyResponse{PubKey: cryptoproto.PublicKey{}, Error: nil}) 402 case *privvalproto.Message_SignProposalRequest: 403 res = mustWrapMsg(&privvalproto.PubKeyResponse{PubKey: cryptoproto.PublicKey{}, Error: nil}) 404 case *privvalproto.Message_PingRequest: 405 err, res = nil, mustWrapMsg(&privvalproto.PingResponse{}) 406 default: 407 err = fmt.Errorf("unknown msg: %v", r) 408 } 409 410 return res, err 411 } 412 413 func TestSignerUnexpectedResponse(t *testing.T) { 414 for _, tc := range getSignerTestCases(t) { 415 tc.signerServer.privVal = types.NewMockPV() 416 tc.mockPV = types.NewMockPV() 417 418 tc.signerServer.SetRequestHandler(brokenHandler) 419 420 tc := tc 421 t.Cleanup(func() { 422 if err := tc.signerServer.Stop(); err != nil { 423 t.Error(err) 424 } 425 }) 426 t.Cleanup(func() { 427 if err := tc.signerClient.Close(); err != nil { 428 t.Error(err) 429 } 430 }) 431 432 ts := time.Now() 433 want := &types.Vote{Timestamp: ts, Type: cmtproto.PrecommitType} 434 435 e := tc.signerClient.SignVote(tc.chainID, want.ToProto()) 436 assert.EqualError(t, e, "empty response") 437 } 438 }