github.com/iqoqo/nomad@v0.11.3-0.20200911112621-d7021c74d101/nomad/rpc_test.go

package nomad

import (
	"context"
	"crypto/tls"
	"fmt"
	"io"
	"net"
	"net/rpc"
	"os"
	"path"
	"testing"
	"time"

	"github.com/hashicorp/go-msgpack/codec"
	msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc"
	cstructs "github.com/hashicorp/nomad/client/structs"
	"github.com/hashicorp/nomad/helper/pool"
	"github.com/hashicorp/nomad/helper/testlog"
	"github.com/hashicorp/nomad/helper/tlsutil"
	"github.com/hashicorp/nomad/helper/uuid"
	"github.com/hashicorp/nomad/nomad/mock"
	"github.com/hashicorp/nomad/nomad/structs"
	"github.com/hashicorp/nomad/nomad/structs/config"
	"github.com/hashicorp/nomad/testutil"
	"github.com/hashicorp/raft"
	"github.com/hashicorp/yamux"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// rpcClient is a test helper method to return a ClientCodec to use to make rpc
// calls to the passed server.
func rpcClient(t *testing.T, s *Server) rpc.ClientCodec {
	addr := s.config.RPCAddr
	conn, err := net.DialTimeout("tcp", addr.String(), time.Second)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	// Write the Nomad RPC byte to set the mode
	conn.Write([]byte{byte(pool.RpcNomad)})
	return pool.NewClientCodec(conn)
}

func TestRPC_forwardLeader(t *testing.T) {
	t.Parallel()

	s1, cleanupS1 := TestServer(t, func(c *Config) {
		c.BootstrapExpect = 2
	})
	defer cleanupS1()
	s2, cleanupS2 := TestServer(t, func(c *Config) {
		c.BootstrapExpect = 2
	})
	defer cleanupS2()
	TestJoin(t, s1, s2)
	testutil.WaitForLeader(t, s1.RPC)
	testutil.WaitForLeader(t, s2.RPC)

	isLeader, remote := s1.getLeader()
	if !isLeader && remote == nil {
		t.Fatalf("missing leader")
	}

	if remote != nil {
		var out struct{}
		err := s1.forwardLeader(remote, "Status.Ping", struct{}{}, &out)
		if err != nil {
			t.Fatalf("err: %v", err)
		}
	}

	isLeader, remote = s2.getLeader()
	if !isLeader && remote == nil {
		t.Fatalf("missing leader")
	}

	if remote != nil {
		var out struct{}
		err := s2.forwardLeader(remote, "Status.Ping", struct{}{}, &out)
		if err != nil {
			t.Fatalf("err: %v", err)
		}
	}
}

func TestRPC_WaitForConsistentReads(t *testing.T) {
	t.Parallel()

	s1, cleanupS2 := TestServer(t, func(c *Config) {
		c.RPCHoldTimeout = 20 * time.Millisecond
	})
	defer cleanupS2()
	testutil.WaitForLeader(t, s1.RPC)

	isLeader, _ := s1.getLeader()
	require.True(t, isLeader)
	require.True(t, s1.isReadyForConsistentReads())

	s1.resetConsistentReadReady()
	require.False(t, s1.isReadyForConsistentReads())

	codec := rpcClient(t, s1)

	get := &structs.JobListRequest{
		QueryOptions: structs.QueryOptions{
			Region:    "global",
			Namespace: "default",
		},
	}

	// check timeout while waiting for consistency
	var resp structs.JobListResponse
	err := msgpackrpc.CallWithCodec(codec, "Job.List", get, &resp)
	require.Error(t, err)
	require.Contains(t, err.Error(), structs.ErrNotReadyForConsistentReads.Error())

	// check we wait and block
	go func() {
		time.Sleep(5 * time.Millisecond)
		s1.setConsistentReadReady()
	}()

	err = msgpackrpc.CallWithCodec(codec, "Job.List", get, &resp)
	require.NoError(t, err)
}
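// pingViaCodec is an illustrative sketch, not part of the original test file:
// it shows the minimal round trip the rpcClient helper above enables (dial,
// send the single RPC mode byte, then issue a plain msgpack RPC). The helper
// name is ours; it only uses calls already exercised elsewhere in this file
// (rpcClient, msgpackrpc.CallWithCodec, and the Status.Ping endpoint).
func pingViaCodec(t *testing.T, s *Server) {
	// rpcClient dials the server and writes pool.RpcNomad as the mode byte.
	codec := rpcClient(t, s)

	// Status.Ping is the no-op RPC the forwarding tests above rely on.
	var out struct{}
	require.NoError(t, msgpackrpc.CallWithCodec(codec, "Status.Ping", struct{}{}, &out))
}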
func TestRPC_forwardRegion(t *testing.T) {
	t.Parallel()

	s1, cleanupS1 := TestServer(t, nil)
	defer cleanupS1()
	s2, cleanupS2 := TestServer(t, func(c *Config) {
		c.Region = "global"
	})
	defer cleanupS2()
	TestJoin(t, s1, s2)
	testutil.WaitForLeader(t, s1.RPC)
	testutil.WaitForLeader(t, s2.RPC)

	var out struct{}
	err := s1.forwardRegion("global", "Status.Ping", struct{}{}, &out)
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	err = s2.forwardRegion("global", "Status.Ping", struct{}{}, &out)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
}

func TestRPC_getServer(t *testing.T) {
	t.Parallel()

	s1, cleanupS1 := TestServer(t, nil)
	defer cleanupS1()
	s2, cleanupS2 := TestServer(t, func(c *Config) {
		c.Region = "global"
	})
	defer cleanupS2()
	TestJoin(t, s1, s2)
	testutil.WaitForLeader(t, s1.RPC)
	testutil.WaitForLeader(t, s2.RPC)

	// Lookup by name
	srv, err := s1.getServer("global", s2.serf.LocalMember().Name)
	require.NoError(t, err)

	require.Equal(t, srv.Name, s2.serf.LocalMember().Name)

	// Lookup by id
	srv, err = s2.getServer("global", s1.serf.LocalMember().Tags["id"])
	require.NoError(t, err)

	require.Equal(t, srv.Name, s1.serf.LocalMember().Name)
}

func TestRPC_PlaintextRPCSucceedsWhenInUpgradeMode(t *testing.T) {
	t.Parallel()
	assert := assert.New(t)

	const (
		cafile  = "../helper/tlsutil/testdata/ca.pem"
		foocert = "../helper/tlsutil/testdata/nomad-foo.pem"
		fookey  = "../helper/tlsutil/testdata/nomad-foo-key.pem"
	)
	dir := tmpDir(t)
	defer os.RemoveAll(dir)

	s1, cleanupS1 := TestServer(t, func(c *Config) {
		c.DataDir = path.Join(dir, "node1")
		c.TLSConfig = &config.TLSConfig{
			EnableRPC:            true,
			VerifyServerHostname: true,
			CAFile:               cafile,
			CertFile:             foocert,
			KeyFile:              fookey,
			RPCUpgradeMode:       true,
		}
	})
	defer cleanupS1()

	codec := rpcClient(t, s1)

	// Create the register request
	node := mock.Node()
	req := &structs.NodeRegisterRequest{
		Node:         node,
		WriteRequest: structs.WriteRequest{Region: "global"},
	}

	var resp structs.GenericResponse
	err := msgpackrpc.CallWithCodec(codec, "Node.Register", req, &resp)
	assert.Nil(err)

	// Check that heartbeatTimers has the heartbeat ID
	_, ok := s1.heartbeatTimers[node.ID]
	assert.True(ok)
}

func TestRPC_PlaintextRPCFailsWhenNotInUpgradeMode(t *testing.T) {
	t.Parallel()
	assert := assert.New(t)

	const (
		cafile  = "../helper/tlsutil/testdata/ca.pem"
		foocert = "../helper/tlsutil/testdata/nomad-foo.pem"
		fookey  = "../helper/tlsutil/testdata/nomad-foo-key.pem"
	)
	dir := tmpDir(t)
	defer os.RemoveAll(dir)

	s1, cleanupS1 := TestServer(t, func(c *Config) {
		c.DataDir = path.Join(dir, "node1")
		c.TLSConfig = &config.TLSConfig{
			EnableRPC:            true,
			VerifyServerHostname: true,
			CAFile:               cafile,
			CertFile:             foocert,
			KeyFile:              fookey,
		}
	})
	defer cleanupS1()

	codec := rpcClient(t, s1)

	node := mock.Node()
	req := &structs.NodeRegisterRequest{
		Node:         node,
		WriteRequest: structs.WriteRequest{Region: "global"},
	}

	var resp structs.GenericResponse
	err := msgpackrpc.CallWithCodec(codec, "Node.Register", req, &resp)
	assert.NotNil(err)
}
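// dialTLSRPC is an illustrative sketch, not part of the original test file: it
// shows the client-side upgrade the two plaintext tests above are contrasted
// against. The client writes the pool.RpcTLS mode byte and then wraps the
// connection with tls.Client, mirroring the handshake performed verbatim in
// TestRPC_TLS_in_TLS further below. The helper name is ours, and it assumes
// the server was started with a TLSConfig like the ones used in these tests.
func dialTLSRPC(t *testing.T, s *Server) net.Conn {
	conn, err := net.DialTimeout("tcp", s.config.RPCAddr.String(), time.Second)
	require.NoError(t, err)

	// The first byte on the wire selects the TLS sub-protocol.
	_, err = conn.Write([]byte{byte(pool.RpcTLS)})
	require.NoError(t, err)

	// Client-side verification is skipped, as in the TLS tests below.
	tlsConf, err := tlsutil.NewTLSConfiguration(s.config.TLSConfig, false, true)
	require.NoError(t, err)
	outTLSConf, err := tlsConf.OutgoingTLSConfig()
	require.NoError(t, err)
	outTLSConf.InsecureSkipVerify = true

	tlsConn := tls.Client(conn, outTLSConf)
	require.NoError(t, tlsConn.Handshake())
	return tlsConn
}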
func TestRPC_streamingRpcConn_badMethod(t *testing.T) {
	t.Parallel()
	require := require.New(t)

	s1, cleanupS1 := TestServer(t, func(c *Config) {
		c.BootstrapExpect = 2
	})
	defer cleanupS1()
	s2, cleanupS2 := TestServer(t, func(c *Config) {
		c.BootstrapExpect = 2
	})
	defer cleanupS2()
	TestJoin(t, s1, s2)
	testutil.WaitForLeader(t, s1.RPC)
	testutil.WaitForLeader(t, s2.RPC)

	s1.peerLock.RLock()
	ok, parts := isNomadServer(s2.LocalMember())
	require.True(ok)
	server := s1.localPeers[raft.ServerAddress(parts.Addr.String())]
	require.NotNil(server)
	s1.peerLock.RUnlock()

	conn, err := s1.streamingRpc(server, "Bogus")
	require.Nil(conn)
	require.NotNil(err)
	require.Contains(err.Error(), "Bogus")
	require.True(structs.IsErrUnknownMethod(err))
}

func TestRPC_streamingRpcConn_badMethod_TLS(t *testing.T) {
	t.Parallel()
	require := require.New(t)

	const (
		cafile  = "../helper/tlsutil/testdata/ca.pem"
		foocert = "../helper/tlsutil/testdata/nomad-foo.pem"
		fookey  = "../helper/tlsutil/testdata/nomad-foo-key.pem"
	)
	dir := tmpDir(t)
	defer os.RemoveAll(dir)
	s1, cleanupS1 := TestServer(t, func(c *Config) {
		c.Region = "regionFoo"
		c.BootstrapExpect = 2
		c.DevMode = false
		c.DataDir = path.Join(dir, "node1")
		c.TLSConfig = &config.TLSConfig{
			EnableHTTP:           true,
			EnableRPC:            true,
			VerifyServerHostname: true,
			CAFile:               cafile,
			CertFile:             foocert,
			KeyFile:              fookey,
		}
	})
	defer cleanupS1()

	s2, cleanupS2 := TestServer(t, func(c *Config) {
		c.Region = "regionFoo"
		c.BootstrapExpect = 2
		c.DevMode = false
		c.DataDir = path.Join(dir, "node2")
		c.TLSConfig = &config.TLSConfig{
			EnableHTTP:           true,
			EnableRPC:            true,
			VerifyServerHostname: true,
			CAFile:               cafile,
			CertFile:             foocert,
			KeyFile:              fookey,
		}
	})
	defer cleanupS2()

	TestJoin(t, s1, s2)
	testutil.WaitForLeader(t, s1.RPC)

	s1.peerLock.RLock()
	ok, parts := isNomadServer(s2.LocalMember())
	require.True(ok)
	server := s1.localPeers[raft.ServerAddress(parts.Addr.String())]
	require.NotNil(server)
	s1.peerLock.RUnlock()

	conn, err := s1.streamingRpc(server, "Bogus")
	require.Nil(conn)
	require.NotNil(err)
	require.Contains(err.Error(), "Bogus")
	require.True(structs.IsErrUnknownMethod(err))
}

func TestRPC_streamingRpcConn_goodMethod_Plaintext(t *testing.T) {
	t.Parallel()
	require := require.New(t)
	dir := tmpDir(t)
	defer os.RemoveAll(dir)
	s1, cleanupS1 := TestServer(t, func(c *Config) {
		c.Region = "regionFoo"
		c.BootstrapExpect = 2
		c.DevMode = false
		c.DataDir = path.Join(dir, "node1")
	})
	defer cleanupS1()

	s2, cleanupS2 := TestServer(t, func(c *Config) {
		c.Region = "regionFoo"
		c.BootstrapExpect = 2
		c.DevMode = false
		c.DataDir = path.Join(dir, "node2")
	})
	defer cleanupS2()

	TestJoin(t, s1, s2)
	testutil.WaitForLeader(t, s1.RPC)

	s1.peerLock.RLock()
	ok, parts := isNomadServer(s2.LocalMember())
	require.True(ok)
	server := s1.localPeers[raft.ServerAddress(parts.Addr.String())]
	require.NotNil(server)
	s1.peerLock.RUnlock()

	conn, err := s1.streamingRpc(server, "FileSystem.Logs")
	require.NotNil(conn)
	require.NoError(err)

	decoder := codec.NewDecoder(conn, structs.MsgpackHandle)
	encoder := codec.NewEncoder(conn, structs.MsgpackHandle)

	allocID := uuid.Generate()
	require.NoError(encoder.Encode(cstructs.FsStreamRequest{
		AllocID: allocID,
		QueryOptions: structs.QueryOptions{
			Region: "regionFoo",
		},
	}))

	var result cstructs.StreamErrWrapper
	require.NoError(decoder.Decode(&result))
	require.Empty(result.Payload)
	require.True(structs.IsErrUnknownAllocation(result.Error))
}
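// localPeerFor is an illustrative sketch, not part of the original test file:
// it factors out the peer lookup that the streamingRpc tests in this file
// repeat inline (resolve the remote server's parts from its serf member, then
// find the matching entry in localPeers). The helper name is ours.
func localPeerFor(t *testing.T, s, remote *Server) *serverParts {
	s.peerLock.RLock()
	defer s.peerLock.RUnlock()

	ok, parts := isNomadServer(remote.LocalMember())
	require.True(t, ok)

	server := s.localPeers[raft.ServerAddress(parts.Addr.String())]
	require.NotNil(t, server)
	return server
}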
func TestRPC_streamingRpcConn_goodMethod_TLS(t *testing.T) {
	t.Parallel()
	require := require.New(t)

	const (
		cafile  = "../helper/tlsutil/testdata/ca.pem"
		foocert = "../helper/tlsutil/testdata/nomad-foo.pem"
		fookey  = "../helper/tlsutil/testdata/nomad-foo-key.pem"
	)
	dir := tmpDir(t)
	defer os.RemoveAll(dir)
	s1, cleanupS1 := TestServer(t, func(c *Config) {
		c.Region = "regionFoo"
		c.BootstrapExpect = 2
		c.DevMode = false
		c.DataDir = path.Join(dir, "node1")
		c.TLSConfig = &config.TLSConfig{
			EnableHTTP:           true,
			EnableRPC:            true,
			VerifyServerHostname: true,
			CAFile:               cafile,
			CertFile:             foocert,
			KeyFile:              fookey,
		}
	})
	defer cleanupS1()

	s2, cleanupS2 := TestServer(t, func(c *Config) {
		c.Region = "regionFoo"
		c.BootstrapExpect = 2
		c.DevMode = false
		c.DataDir = path.Join(dir, "node2")
		c.TLSConfig = &config.TLSConfig{
			EnableHTTP:           true,
			EnableRPC:            true,
			VerifyServerHostname: true,
			CAFile:               cafile,
			CertFile:             foocert,
			KeyFile:              fookey,
		}
	})
	defer cleanupS2()

	TestJoin(t, s1, s2)
	testutil.WaitForLeader(t, s1.RPC)

	s1.peerLock.RLock()
	ok, parts := isNomadServer(s2.LocalMember())
	require.True(ok)
	server := s1.localPeers[raft.ServerAddress(parts.Addr.String())]
	require.NotNil(server)
	s1.peerLock.RUnlock()

	conn, err := s1.streamingRpc(server, "FileSystem.Logs")
	require.NotNil(conn)
	require.NoError(err)

	decoder := codec.NewDecoder(conn, structs.MsgpackHandle)
	encoder := codec.NewEncoder(conn, structs.MsgpackHandle)

	allocID := uuid.Generate()
	require.NoError(encoder.Encode(cstructs.FsStreamRequest{
		AllocID: allocID,
		QueryOptions: structs.QueryOptions{
			Region: "regionFoo",
		},
	}))

	var result cstructs.StreamErrWrapper
	require.NoError(decoder.Decode(&result))
	require.Empty(result.Payload)
	require.True(structs.IsErrUnknownAllocation(result.Error))
}
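// requestLogsOverStream is an illustrative sketch, not part of the original
// test file: it captures the framing used by the two goodMethod tests above.
// A streaming RPC speaks msgpack over the returned connection, so the client
// encodes an FsStreamRequest and decodes StreamErrWrapper replies. The helper
// name and parameters are ours.
func requestLogsOverStream(t *testing.T, conn io.ReadWriter, region, allocID string) cstructs.StreamErrWrapper {
	encoder := codec.NewEncoder(conn, structs.MsgpackHandle)
	decoder := codec.NewDecoder(conn, structs.MsgpackHandle)

	// Send the request frame for the alloc whose logs we want.
	require.NoError(t, encoder.Encode(cstructs.FsStreamRequest{
		AllocID:      allocID,
		QueryOptions: structs.QueryOptions{Region: region},
	}))

	// The first reply frame carries either a payload chunk or an error.
	var result cstructs.StreamErrWrapper
	require.NoError(t, decoder.Decode(&result))
	return result
}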
// COMPAT: Remove in 0.10
// This is a very low level test to assert that the V2 handling works. It
// makes manual RPC calls since no helpers exist yet: we are only implementing
// support for v2, not using it. In the future we can switch the conn pool to
// establishing v2 connections and deprecate this test.
func TestRPC_handleMultiplexV2(t *testing.T) {
	t.Parallel()
	require := require.New(t)

	s, cleanupS := TestServer(t, nil)
	defer cleanupS()
	testutil.WaitForLeader(t, s.RPC)

	p1, p2 := net.Pipe()
	defer p1.Close()
	defer p2.Close()

	// Start the handler
	doneCh := make(chan struct{})
	go func() {
		s.handleConn(context.Background(), p2, &RPCContext{Conn: p2})
		close(doneCh)
	}()

	// Establish the MultiplexV2 connection
	_, err := p1.Write([]byte{byte(pool.RpcMultiplexV2)})
	require.Nil(err)

	// Make two streams
	conf := yamux.DefaultConfig()
	conf.LogOutput = nil
	conf.Logger = testlog.Logger(t)
	session, err := yamux.Client(p1, conf)
	require.Nil(err)

	s1, err := session.Open()
	require.Nil(err)
	defer s1.Close()

	s2, err := session.Open()
	require.Nil(err)
	defer s2.Close()

	// Make an RPC
	_, err = s1.Write([]byte{byte(pool.RpcNomad)})
	require.Nil(err)

	args := &structs.GenericRequest{}
	var l string
	err = msgpackrpc.CallWithCodec(pool.NewClientCodec(s1), "Status.Leader", args, &l)
	require.Nil(err)
	require.NotEmpty(l)

	// Make a streaming RPC
	_, err = s2.Write([]byte{byte(pool.RpcStreaming)})
	require.Nil(err)

	_, err = s.streamingRpcImpl(s2, "Bogus")
	require.NotNil(err)
	require.Contains(err.Error(), "Bogus")
	require.True(structs.IsErrUnknownMethod(err))
}

// TestRPC_TLS_in_TLS asserts that trying to nest TLS connections fails.
func TestRPC_TLS_in_TLS(t *testing.T) {
	t.Parallel()

	const (
		cafile  = "../helper/tlsutil/testdata/ca.pem"
		foocert = "../helper/tlsutil/testdata/nomad-foo.pem"
		fookey  = "../helper/tlsutil/testdata/nomad-foo-key.pem"
	)

	s, cleanup := TestServer(t, func(c *Config) {
		c.TLSConfig = &config.TLSConfig{
			EnableRPC: true,
			CAFile:    cafile,
			CertFile:  foocert,
			KeyFile:   fookey,
		}
	})
	defer func() {
		cleanup()

		// TODO Avoid panics from logging during shutdown
		time.Sleep(1 * time.Second)
	}()

	conn, err := net.DialTimeout("tcp", s.config.RPCAddr.String(), time.Second)
	require.NoError(t, err)
	defer conn.Close()

	_, err = conn.Write([]byte{byte(pool.RpcTLS)})
	require.NoError(t, err)

	// Client TLS verification isn't necessary for
	// our assertions
	tlsConf, err := tlsutil.NewTLSConfiguration(s.config.TLSConfig, false, true)
	require.NoError(t, err)
	outTLSConf, err := tlsConf.OutgoingTLSConfig()
	require.NoError(t, err)
	outTLSConf.InsecureSkipVerify = true

	// Do initial handshake
	tlsConn := tls.Client(conn, outTLSConf)
	require.NoError(t, tlsConn.Handshake())
	conn = tlsConn

	// Try to create a nested TLS connection
	_, err = conn.Write([]byte{byte(pool.RpcTLS)})
	require.NoError(t, err)

	// Attempts at nested TLS connections should cause a disconnect
	buf := []byte{0}
	conn.SetReadDeadline(time.Now().Add(1 * time.Second))
	n, err := conn.Read(buf)
	require.Zero(t, n)
	require.Equal(t, io.EOF, err)
}
// TestRPC_Limits_OK asserts that all valid limits combinations
// (tls/timeout/conns) work.
//
// Invalid limits are tested in command/agent/agent_test.go
func TestRPC_Limits_OK(t *testing.T) {
	t.Parallel()

	const (
		cafile   = "../helper/tlsutil/testdata/ca.pem"
		foocert  = "../helper/tlsutil/testdata/nomad-foo.pem"
		fookey   = "../helper/tlsutil/testdata/nomad-foo-key.pem"
		maxConns = 10 // limit must be < this for testing
	)

	cases := []struct {
		tls           bool
		timeout       time.Duration
		limit         int
		assertTimeout bool
		assertLimit   bool
	}{
		{
			tls:           false,
			timeout:       5 * time.Second,
			limit:         0,
			assertTimeout: true,
			assertLimit:   false,
		},
		{
			tls:           true,
			timeout:       5 * time.Second,
			limit:         0,
			assertTimeout: true,
			assertLimit:   false,
		},
		{
			tls:           false,
			timeout:       0,
			limit:         0,
			assertTimeout: false,
			assertLimit:   false,
		},
		{
			tls:           true,
			timeout:       0,
			limit:         0,
			assertTimeout: false,
			assertLimit:   false,
		},
		{
			tls:           false,
			timeout:       0,
			limit:         2,
			assertTimeout: false,
			assertLimit:   true,
		},
		{
			tls:           true,
			timeout:       0,
			limit:         2,
			assertTimeout: false,
			assertLimit:   true,
		},
		{
			tls:           false,
			timeout:       5 * time.Second,
			limit:         2,
			assertTimeout: true,
			assertLimit:   true,
		},
		{
			tls:           true,
			timeout:       5 * time.Second,
			limit:         2,
			assertTimeout: true,
			assertLimit:   true,
		},
	}

	assertTimeout := func(t *testing.T, s *Server, useTLS bool, timeout time.Duration) {
		// Increase timeout to detect timeouts
		clientTimeout := timeout + time.Second

		conn, err := net.DialTimeout("tcp", s.config.RPCAddr.String(), 1*time.Second)
		require.NoError(t, err)
		defer conn.Close()

		buf := []byte{0}
		readDeadline := time.Now().Add(clientTimeout)
		conn.SetReadDeadline(readDeadline)
		n, err := conn.Read(buf)
		require.Zero(t, n)
		if timeout == 0 {
			// Server should *not* have timed out.
			// Now() should always be after the client read deadline, but
			// isn't a sufficient assertion for correctness as slow tests
			// may cause this to be true even if the server timed out.
			now := time.Now()
			require.Truef(t, now.After(readDeadline),
				"Client read deadline (%s) should be in the past (before %s)", readDeadline, now)

			testutil.RequireDeadlineErr(t, err)
			return
		}

		// Server *should* have timed out (EOF)
		require.Equal(t, io.EOF, err)

		// Create a new connection to assert timeout doesn't
		// apply after first byte.
		conn, err = net.DialTimeout("tcp", s.config.RPCAddr.String(), time.Second)
		require.NoError(t, err)
		defer conn.Close()

		if useTLS {
			_, err := conn.Write([]byte{byte(pool.RpcTLS)})
			require.NoError(t, err)

			// Client TLS verification isn't necessary for
			// our assertions
			tlsConf, err := tlsutil.NewTLSConfiguration(s.config.TLSConfig, false, true)
			require.NoError(t, err)
			outTLSConf, err := tlsConf.OutgoingTLSConfig()
			require.NoError(t, err)
			outTLSConf.InsecureSkipVerify = true

			tlsConn := tls.Client(conn, outTLSConf)
			require.NoError(t, tlsConn.Handshake())

			conn = tlsConn
		}

		// Writing the Nomad RPC byte should be sufficient to
		// disable the handshake timeout
		n, err = conn.Write([]byte{byte(pool.RpcNomad)})
		require.NoError(t, err)
		require.Equal(t, 1, n)

		// Read should timeout due to client timeout, not
		// server's timeout
		readDeadline = time.Now().Add(clientTimeout)
		conn.SetReadDeadline(readDeadline)
		n, err = conn.Read(buf)
		require.Zero(t, n)
		testutil.RequireDeadlineErr(t, err)
	}

	assertNoLimit := func(t *testing.T, addr string) {
		var err error

		// Create max connections
		conns := make([]net.Conn, maxConns)
		errCh := make(chan error, maxConns)
		for i := 0; i < maxConns; i++ {
			conns[i], err = net.DialTimeout("tcp", addr, 1*time.Second)
			require.NoError(t, err)
			defer conns[i].Close()

			go func(i int) {
				buf := []byte{0}
				readDeadline := time.Now().Add(1 * time.Second)
				conns[i].SetReadDeadline(readDeadline)
				n, err := conns[i].Read(buf)
				if n > 0 {
					errCh <- fmt.Errorf("n > 0: %d", n)
					return
				}
				errCh <- err
			}(i)
		}

		// Now assert each error is a clientside read deadline error
		deadline := time.After(10 * time.Second)
		for i := 0; i < maxConns; i++ {
			select {
			case <-deadline:
				t.Fatalf("timed out waiting for conn error %d/%d", i+1, maxConns)
			case err := <-errCh:
				testutil.RequireDeadlineErr(t, err)
			}
		}
	}

	assertLimit := func(t *testing.T, addr string, limit int) {
		var err error

		// Create limit connections
		conns := make([]net.Conn, limit)
		errCh := make(chan error, limit)
		for i := range conns {
			conns[i], err = net.DialTimeout("tcp", addr, 1*time.Second)
			require.NoError(t, err)
			defer conns[i].Close()

			go func(i int) {
				buf := []byte{0}
				n, err := conns[i].Read(buf)
				if n > 0 {
					errCh <- fmt.Errorf("n > 0: %d", n)
					return
				}
				errCh <- err
			}(i)
		}

		// Assert a new connection is dropped
		conn, err := net.DialTimeout("tcp", addr, 1*time.Second)
		require.NoError(t, err)
		defer conn.Close()

		buf := []byte{0}
		deadline := time.Now().Add(6 * time.Second)
		conn.SetReadDeadline(deadline)
		n, err := conn.Read(buf)
		require.Zero(t, n)
		require.Equal(t, io.EOF, err)

		// Assert existing connections are ok
	ERRCHECK:
		select {
		case err := <-errCh:
			t.Errorf("unexpected error from idle connection: (%T) %v", err, err)
			goto ERRCHECK
		default:
		}

		// Cleanup
		for _, conn := range conns {
			conn.Close()
		}
		for i := range conns {
			select {
			case err := <-errCh:
				require.Contains(t, err.Error(), "use of closed network connection")
			case <-time.After(10 * time.Second):
				t.Fatalf("timed out waiting for connection %d/%d to close", i, len(conns))
			}
		}
	}
	for i := range cases {
		tc := cases[i]
		name := fmt.Sprintf("%d-tls-%t-timeout-%s-limit-%v", i, tc.tls, tc.timeout, tc.limit)
		t.Run(name, func(t *testing.T) {
			t.Parallel()

			if tc.limit >= maxConns {
				t.Fatalf("test fixture failure: cannot assert limit (%d) >= max (%d)", tc.limit, maxConns)
			}
			if tc.assertTimeout && tc.timeout == 0 {
				t.Fatalf("test fixture failure: cannot assert timeout when no timeout set (0)")
			}

			s, cleanup := TestServer(t, func(c *Config) {
				if tc.tls {
					c.TLSConfig = &config.TLSConfig{
						EnableRPC: true,
						CAFile:    cafile,
						CertFile:  foocert,
						KeyFile:   fookey,
					}
				}
				c.RPCHandshakeTimeout = tc.timeout
				c.RPCMaxConnsPerClient = tc.limit
			})
			defer func() {
				cleanup()

				// TODO Avoid panics from logging during shutdown
				time.Sleep(1 * time.Second)
			}()

			assertTimeout(t, s, tc.tls, tc.timeout)

			if tc.assertLimit {
				// There's a race between assertTimeout(false) closing
				// its connection and the HTTP server noticing and
				// untracking it. Since there's no way to coordinate
				// when this occurs, sleeping is the only way to avoid
				// asserting limits before the timed out connection is
				// untracked.
				time.Sleep(1 * time.Second)

				assertLimit(t, s.config.RPCAddr.String(), tc.limit)
			} else {
				assertNoLimit(t, s.config.RPCAddr.String())
			}
		})
	}
}

// TestRPC_Limits_Streaming asserts that the streaming RPC limit is lower than
// the overall connection limit to prevent DoS via server-routed streaming API
// calls.
func TestRPC_Limits_Streaming(t *testing.T) {
	t.Parallel()

	s, cleanup := TestServer(t, func(c *Config) {
		limits := config.DefaultLimits()
		c.RPCMaxConnsPerClient = *limits.RPCMaxConnsPerClient
	})
	defer func() {
		cleanup()

		// TODO Avoid panics from logging during shutdown
		time.Sleep(1 * time.Second)
	}()

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	errCh := make(chan error, 1)

	// Create a streaming connection
	dialStreamer := func() net.Conn {
		conn, err := net.DialTimeout("tcp", s.config.RPCAddr.String(), 1*time.Second)
		require.NoError(t, err)

		_, err = conn.Write([]byte{byte(pool.RpcStreaming)})
		require.NoError(t, err)
		return conn
	}

	// Create up to the limit streaming connections
	streamers := make([]net.Conn, s.config.RPCMaxConnsPerClient-config.LimitsNonStreamingConnsPerClient)
	for i := range streamers {
		streamers[i] = dialStreamer()

		go func(i int) {
			// Streamer should never die until test exits
			buf := []byte{0}
			_, err := streamers[i].Read(buf)
			if ctx.Err() != nil {
				// Error is expected when test finishes
				return
			}

			t.Logf("connection %d died with error: (%T) %v", i, err, err)

			// Send unexpected errors back
			if err != nil {
				select {
				case errCh <- err:
				case <-ctx.Done():
				default:
					// Only send first error
				}
			}
		}(i)
	}

	defer func() {
		cancel()
		for _, conn := range streamers {
			conn.Close()
		}
	}()

	// Assert no streamer errors have occurred
	select {
	case err := <-errCh:
		t.Fatalf("unexpected error from blocking streaming RPCs: (%T) %v", err, err)
	case <-time.After(500 * time.Millisecond):
		// Ok! No connections were rejected immediately.
	}

	// Assert subsequent streaming RPCs are rejected
	conn := dialStreamer()
	t.Logf("expect connection to be rejected due to limit")
	buf := []byte{0}
	conn.SetReadDeadline(time.Now().Add(3 * time.Second))
	_, err := conn.Read(buf)
	require.Equalf(t, io.EOF, err, "expected io.EOF but found: (%T) %v", err, err)

	// Assert no streamer errors have occurred
	select {
	case err := <-errCh:
		t.Fatalf("unexpected error from blocking streaming RPCs: %v", err)
	default:
	}

	// Subsequent non-streaming RPC should be OK
	conn, err = net.DialTimeout("tcp", s.config.RPCAddr.String(), 1*time.Second)
	require.NoError(t, err)
	_, err = conn.Write([]byte{byte(pool.RpcNomad)})
	require.NoError(t, err)

	conn.SetReadDeadline(time.Now().Add(1 * time.Second))
	_, err = conn.Read(buf)
	testutil.RequireDeadlineErr(t, err)

	// Close 1 streamer and assert another is allowed
	t.Logf("expect streaming connection 0 to exit with error")
	streamers[0].Close()
	<-errCh

	// Assert that new connections are allowed.
	// Due to the distributed nature here, the server may not immediately
	// recognize the connection closure, so the first attempts may be rejected
	// (i.e. EOF), but the first non-EOF read must be a read-deadline error.
	testutil.WaitForResult(func() (bool, error) {
		conn = dialStreamer()
		conn.SetReadDeadline(time.Now().Add(1 * time.Second))
		_, err = conn.Read(buf)
		if err == io.EOF {
			return false, fmt.Errorf("connection was rejected")
		}

		testutil.RequireDeadlineErr(t, err)
		return true, nil
	}, func(err error) {
		require.NoError(t, err)
	})
}