github.com/adityamillind98/nomad@v0.11.8/nomad/rpc_test.go

package nomad

import (
	"context"
	"crypto/tls"
	"fmt"
	"io"
	"net"
	"net/rpc"
	"os"
	"path"
	"testing"
	"time"

	"github.com/hashicorp/go-msgpack/codec"
	msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc"
	cstructs "github.com/hashicorp/nomad/client/structs"
	"github.com/hashicorp/nomad/helper/pool"
	"github.com/hashicorp/nomad/helper/testlog"
	"github.com/hashicorp/nomad/helper/tlsutil"
	"github.com/hashicorp/nomad/helper/uuid"
	"github.com/hashicorp/nomad/nomad/mock"
	"github.com/hashicorp/nomad/nomad/structs"
	"github.com/hashicorp/nomad/nomad/structs/config"
	"github.com/hashicorp/nomad/testutil"
	"github.com/hashicorp/raft"
	"github.com/hashicorp/yamux"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// rpcClient is a test helper function that returns a ClientCodec to use to
// make rpc calls to the passed server.
func rpcClient(t *testing.T, s *Server) rpc.ClientCodec {
	addr := s.config.RPCAddr
	conn, err := net.DialTimeout("tcp", addr.String(), time.Second)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	// Write the Nomad RPC byte to set the mode
	conn.Write([]byte{byte(pool.RpcNomad)})
	return pool.NewClientCodec(conn)
}

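// TestRPC_forwardLeader asserts that an RPC can be forwarded from a server to
// the current cluster leader.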
func TestRPC_forwardLeader(t *testing.T) {
	t.Parallel()

	s1, cleanupS1 := TestServer(t, func(c *Config) {
		c.BootstrapExpect = 2
	})
	defer cleanupS1()
	s2, cleanupS2 := TestServer(t, func(c *Config) {
		c.BootstrapExpect = 2
	})
	defer cleanupS2()
	TestJoin(t, s1, s2)
	testutil.WaitForLeader(t, s1.RPC)
	testutil.WaitForLeader(t, s2.RPC)

	isLeader, remote := s1.getLeader()
	if !isLeader && remote == nil {
		t.Fatalf("missing leader")
	}

	if remote != nil {
		var out struct{}
		err := s1.forwardLeader(remote, "Status.Ping", struct{}{}, &out)
		if err != nil {
			t.Fatalf("err: %v", err)
		}
	}

	isLeader, remote = s2.getLeader()
	if !isLeader && remote == nil {
		t.Fatalf("missing leader")
	}

	if remote != nil {
		var out struct{}
		err := s2.forwardLeader(remote, "Status.Ping", struct{}{}, &out)
		if err != nil {
			t.Fatalf("err: %v", err)
		}
	}
}

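// TestRPC_WaitForConsistentReads asserts that RPCs fail with
// ErrNotReadyForConsistentReads while the server is not ready for consistent
// reads and succeed once it becomes ready.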
func TestRPC_WaitForConsistentReads(t *testing.T) {
	t.Parallel()

	s1, cleanupS1 := TestServer(t, func(c *Config) {
		c.RPCHoldTimeout = 20 * time.Millisecond
	})
	defer cleanupS1()
	testutil.WaitForLeader(t, s1.RPC)

	isLeader, _ := s1.getLeader()
	require.True(t, isLeader)
	require.True(t, s1.isReadyForConsistentReads())

	s1.resetConsistentReadReady()
	require.False(t, s1.isReadyForConsistentReads())

	codec := rpcClient(t, s1)

	get := &structs.JobListRequest{
		QueryOptions: structs.QueryOptions{
			Region:    "global",
			Namespace: "default",
		},
	}

	// check timeout while waiting for consistency
	var resp structs.JobListResponse
	err := msgpackrpc.CallWithCodec(codec, "Job.List", get, &resp)
	require.Error(t, err)
	require.Contains(t, err.Error(), structs.ErrNotReadyForConsistentReads.Error())

	// check we wait and block
	go func() {
		time.Sleep(5 * time.Millisecond)
		s1.setConsistentReadReady()
	}()

	err = msgpackrpc.CallWithCodec(codec, "Job.List", get, &resp)
	require.NoError(t, err)
}

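// TestRPC_forwardRegion asserts that RPCs can be forwarded to servers in
// another region.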
func TestRPC_forwardRegion(t *testing.T) {
	t.Parallel()

	s1, cleanupS1 := TestServer(t, nil)
	defer cleanupS1()
	s2, cleanupS2 := TestServer(t, func(c *Config) {
		c.Region = "global"
	})
	defer cleanupS2()
	TestJoin(t, s1, s2)
	testutil.WaitForLeader(t, s1.RPC)
	testutil.WaitForLeader(t, s2.RPC)

	var out struct{}
	err := s1.forwardRegion("global", "Status.Ping", struct{}{}, &out)
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	err = s2.forwardRegion("global", "Status.Ping", struct{}{}, &out)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
}

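// TestRPC_getServer asserts that servers in a region can be looked up by
// either name or ID.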
func TestRPC_getServer(t *testing.T) {
	t.Parallel()

	s1, cleanupS1 := TestServer(t, nil)
	defer cleanupS1()
	s2, cleanupS2 := TestServer(t, func(c *Config) {
		c.Region = "global"
	})
	defer cleanupS2()
	TestJoin(t, s1, s2)
	testutil.WaitForLeader(t, s1.RPC)
	testutil.WaitForLeader(t, s2.RPC)

	// Lookup by name
	srv, err := s1.getServer("global", s2.serf.LocalMember().Name)
	require.NoError(t, err)

	require.Equal(t, srv.Name, s2.serf.LocalMember().Name)

	// Lookup by id
	srv, err = s2.getServer("global", s1.serf.LocalMember().Tags["id"])
	require.NoError(t, err)

	require.Equal(t, srv.Name, s1.serf.LocalMember().Name)
}

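// TestRPC_PlaintextRPCSucceedsWhenInUpgradeMode asserts that a plaintext RPC
// is accepted when TLS is enabled with RPCUpgradeMode set.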
func TestRPC_PlaintextRPCSucceedsWhenInUpgradeMode(t *testing.T) {
	t.Parallel()
	assert := assert.New(t)

	const (
		cafile  = "../helper/tlsutil/testdata/ca.pem"
		foocert = "../helper/tlsutil/testdata/nomad-foo.pem"
		fookey  = "../helper/tlsutil/testdata/nomad-foo-key.pem"
	)
	dir := tmpDir(t)
	defer os.RemoveAll(dir)

	s1, cleanupS1 := TestServer(t, func(c *Config) {
		c.DataDir = path.Join(dir, "node1")
		c.TLSConfig = &config.TLSConfig{
			EnableRPC:            true,
			VerifyServerHostname: true,
			CAFile:               cafile,
			CertFile:             foocert,
			KeyFile:              fookey,
			RPCUpgradeMode:       true,
		}
	})
	defer cleanupS1()

	codec := rpcClient(t, s1)

	// Create the register request
	node := mock.Node()
	req := &structs.NodeRegisterRequest{
		Node:         node,
		WriteRequest: structs.WriteRequest{Region: "global"},
	}

	var resp structs.GenericResponse
	err := msgpackrpc.CallWithCodec(codec, "Node.Register", req, &resp)
	assert.Nil(err)

	// Check that heartbeatTimers has the heartbeat ID
	_, ok := s1.heartbeatTimers[node.ID]
	assert.True(ok)
}

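// TestRPC_PlaintextRPCFailsWhenNotInUpgradeMode asserts that a plaintext RPC
// is rejected when the server requires TLS and RPCUpgradeMode is not set.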
func TestRPC_PlaintextRPCFailsWhenNotInUpgradeMode(t *testing.T) {
	t.Parallel()
	assert := assert.New(t)

	const (
		cafile  = "../helper/tlsutil/testdata/ca.pem"
		foocert = "../helper/tlsutil/testdata/nomad-foo.pem"
		fookey  = "../helper/tlsutil/testdata/nomad-foo-key.pem"
	)
	dir := tmpDir(t)
	defer os.RemoveAll(dir)

	s1, cleanupS1 := TestServer(t, func(c *Config) {
		c.DataDir = path.Join(dir, "node1")
		c.TLSConfig = &config.TLSConfig{
			EnableRPC:            true,
			VerifyServerHostname: true,
			CAFile:               cafile,
			CertFile:             foocert,
			KeyFile:              fookey,
		}
	})
	defer cleanupS1()

	codec := rpcClient(t, s1)

	node := mock.Node()
	req := &structs.NodeRegisterRequest{
		Node:         node,
		WriteRequest: structs.WriteRequest{Region: "global"},
	}

	var resp structs.GenericResponse
	err := msgpackrpc.CallWithCodec(codec, "Node.Register", req, &resp)
	assert.NotNil(err)
}

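// TestRPC_streamingRpcConn_badMethod asserts that opening a streaming RPC to
// an unknown method returns an unknown method error.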
func TestRPC_streamingRpcConn_badMethod(t *testing.T) {
	t.Parallel()
	require := require.New(t)

	s1, cleanupS1 := TestServer(t, func(c *Config) {
		c.BootstrapExpect = 2
	})
	defer cleanupS1()
	s2, cleanupS2 := TestServer(t, func(c *Config) {
		c.BootstrapExpect = 2
	})
	defer cleanupS2()
	TestJoin(t, s1, s2)
	testutil.WaitForLeader(t, s1.RPC)
	testutil.WaitForLeader(t, s2.RPC)

	s1.peerLock.RLock()
	ok, parts := isNomadServer(s2.LocalMember())
	require.True(ok)
	server := s1.localPeers[raft.ServerAddress(parts.Addr.String())]
	require.NotNil(server)
	s1.peerLock.RUnlock()

	conn, err := s1.streamingRpc(server, "Bogus")
	require.Nil(conn)
	require.NotNil(err)
	require.Contains(err.Error(), "Bogus")
	require.True(structs.IsErrUnknownMethod(err))
}

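// TestRPC_streamingRpcConn_badMethod_TLS asserts the same unknown method
// behavior when the servers communicate over TLS.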
func TestRPC_streamingRpcConn_badMethod_TLS(t *testing.T) {
	t.Parallel()
	require := require.New(t)

	const (
		cafile  = "../helper/tlsutil/testdata/ca.pem"
		foocert = "../helper/tlsutil/testdata/nomad-foo.pem"
		fookey  = "../helper/tlsutil/testdata/nomad-foo-key.pem"
	)
	dir := tmpDir(t)
	defer os.RemoveAll(dir)
	s1, cleanupS1 := TestServer(t, func(c *Config) {
		c.Region = "regionFoo"
		c.BootstrapExpect = 2
		c.DevMode = false
		c.DataDir = path.Join(dir, "node1")
		c.TLSConfig = &config.TLSConfig{
			EnableHTTP:           true,
			EnableRPC:            true,
			VerifyServerHostname: true,
			CAFile:               cafile,
			CertFile:             foocert,
			KeyFile:              fookey,
		}
	})
	defer cleanupS1()

	s2, cleanupS2 := TestServer(t, func(c *Config) {
		c.Region = "regionFoo"
		c.BootstrapExpect = 2
		c.DevMode = false
		c.DataDir = path.Join(dir, "node2")
		c.TLSConfig = &config.TLSConfig{
			EnableHTTP:           true,
			EnableRPC:            true,
			VerifyServerHostname: true,
			CAFile:               cafile,
			CertFile:             foocert,
			KeyFile:              fookey,
		}
	})
	defer cleanupS2()

	TestJoin(t, s1, s2)
	testutil.WaitForLeader(t, s1.RPC)

	s1.peerLock.RLock()
	ok, parts := isNomadServer(s2.LocalMember())
	require.True(ok)
	server := s1.localPeers[raft.ServerAddress(parts.Addr.String())]
	require.NotNil(server)
	s1.peerLock.RUnlock()

	conn, err := s1.streamingRpc(server, "Bogus")
	require.Nil(conn)
	require.NotNil(err)
	require.Contains(err.Error(), "Bogus")
	require.True(structs.IsErrUnknownMethod(err))
}

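// TestRPC_streamingRpcConn_goodMethod_Plaintext asserts that a streaming RPC
// to a known method (FileSystem.Logs) can be established between servers
// without TLS.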
func TestRPC_streamingRpcConn_goodMethod_Plaintext(t *testing.T) {
	t.Parallel()
	require := require.New(t)
	dir := tmpDir(t)
	defer os.RemoveAll(dir)
	s1, cleanupS1 := TestServer(t, func(c *Config) {
		c.Region = "regionFoo"
		c.BootstrapExpect = 2
		c.DevMode = false
		c.DataDir = path.Join(dir, "node1")
	})
	defer cleanupS1()

	s2, cleanupS2 := TestServer(t, func(c *Config) {
		c.Region = "regionFoo"
		c.BootstrapExpect = 2
		c.DevMode = false
		c.DataDir = path.Join(dir, "node2")
	})
	defer cleanupS2()

	TestJoin(t, s1, s2)
	testutil.WaitForLeader(t, s1.RPC)

	s1.peerLock.RLock()
	ok, parts := isNomadServer(s2.LocalMember())
	require.True(ok)
	server := s1.localPeers[raft.ServerAddress(parts.Addr.String())]
	require.NotNil(server)
	s1.peerLock.RUnlock()

	conn, err := s1.streamingRpc(server, "FileSystem.Logs")
	require.NotNil(conn)
	require.NoError(err)

	decoder := codec.NewDecoder(conn, structs.MsgpackHandle)
	encoder := codec.NewEncoder(conn, structs.MsgpackHandle)

	allocID := uuid.Generate()
	require.NoError(encoder.Encode(cstructs.FsStreamRequest{
		AllocID: allocID,
		QueryOptions: structs.QueryOptions{
			Region: "regionFoo",
		},
	}))

	var result cstructs.StreamErrWrapper
	require.NoError(decoder.Decode(&result))
	require.Empty(result.Payload)
	require.True(structs.IsErrUnknownAllocation(result.Error))
}

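// TestRPC_streamingRpcConn_goodMethod_TLS asserts that a streaming RPC to a
// known method (FileSystem.Logs) can be established between servers over TLS.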
func TestRPC_streamingRpcConn_goodMethod_TLS(t *testing.T) {
	t.Parallel()
	require := require.New(t)

	const (
		cafile  = "../helper/tlsutil/testdata/ca.pem"
		foocert = "../helper/tlsutil/testdata/nomad-foo.pem"
		fookey  = "../helper/tlsutil/testdata/nomad-foo-key.pem"
	)
	dir := tmpDir(t)
	defer os.RemoveAll(dir)
	s1, cleanupS1 := TestServer(t, func(c *Config) {
		c.Region = "regionFoo"
		c.BootstrapExpect = 2
		c.DevMode = false
		c.DataDir = path.Join(dir, "node1")
		c.TLSConfig = &config.TLSConfig{
			EnableHTTP:           true,
			EnableRPC:            true,
			VerifyServerHostname: true,
			CAFile:               cafile,
			CertFile:             foocert,
			KeyFile:              fookey,
		}
	})
	defer cleanupS1()

	s2, cleanupS2 := TestServer(t, func(c *Config) {
		c.Region = "regionFoo"
		c.BootstrapExpect = 2
		c.DevMode = false
		c.DataDir = path.Join(dir, "node2")
		c.TLSConfig = &config.TLSConfig{
			EnableHTTP:           true,
			EnableRPC:            true,
			VerifyServerHostname: true,
			CAFile:               cafile,
			CertFile:             foocert,
			KeyFile:              fookey,
		}
	})
	defer cleanupS2()

	TestJoin(t, s1, s2)
	testutil.WaitForLeader(t, s1.RPC)

	s1.peerLock.RLock()
	ok, parts := isNomadServer(s2.LocalMember())
	require.True(ok)
	server := s1.localPeers[raft.ServerAddress(parts.Addr.String())]
	require.NotNil(server)
	s1.peerLock.RUnlock()

	conn, err := s1.streamingRpc(server, "FileSystem.Logs")
	require.NotNil(conn)
	require.NoError(err)

	decoder := codec.NewDecoder(conn, structs.MsgpackHandle)
	encoder := codec.NewEncoder(conn, structs.MsgpackHandle)

	allocID := uuid.Generate()
	require.NoError(encoder.Encode(cstructs.FsStreamRequest{
		AllocID: allocID,
		QueryOptions: structs.QueryOptions{
			Region: "regionFoo",
		},
	}))

	var result cstructs.StreamErrWrapper
	require.NoError(decoder.Decode(&result))
	require.Empty(result.Payload)
	require.True(structs.IsErrUnknownAllocation(result.Error))
}

// COMPAT: Remove in 0.10
// This is a very low level test to assert that the V2 handling works. It
// makes manual RPC calls since no helpers exist at this point; we are only
// implementing support for v2 but not using it yet. In the future we can
// switch the conn pool to establishing v2 connections and we can deprecate
// this test.
func TestRPC_handleMultiplexV2(t *testing.T) {
	t.Parallel()
	require := require.New(t)

	s, cleanupS := TestServer(t, nil)
	defer cleanupS()
	testutil.WaitForLeader(t, s.RPC)

	p1, p2 := net.Pipe()
	defer p1.Close()
	defer p2.Close()

	// Start the handler
	doneCh := make(chan struct{})
	go func() {
		s.handleConn(context.Background(), p2, &RPCContext{Conn: p2})
		close(doneCh)
	}()

	// Establish the MultiplexV2 connection
	_, err := p1.Write([]byte{byte(pool.RpcMultiplexV2)})
	require.Nil(err)

	// Make two streams
	conf := yamux.DefaultConfig()
	conf.LogOutput = nil
	conf.Logger = testlog.Logger(t)
	session, err := yamux.Client(p1, conf)
	require.Nil(err)

	s1, err := session.Open()
	require.Nil(err)
	defer s1.Close()

	s2, err := session.Open()
	require.Nil(err)
	defer s2.Close()

	// Make an RPC
	_, err = s1.Write([]byte{byte(pool.RpcNomad)})
	require.Nil(err)

	args := &structs.GenericRequest{}
	var l string
	err = msgpackrpc.CallWithCodec(pool.NewClientCodec(s1), "Status.Leader", args, &l)
	require.Nil(err)
	require.NotEmpty(l)

	// Make a streaming RPC
	_, err = s2.Write([]byte{byte(pool.RpcStreaming)})
	require.Nil(err)

	_, err = s.streamingRpcImpl(s2, "Bogus")
	require.NotNil(err)
	require.Contains(err.Error(), "Bogus")
	require.True(structs.IsErrUnknownMethod(err))
}

// TestRPC_TLS_in_TLS asserts that trying to nest TLS connections fails.
func TestRPC_TLS_in_TLS(t *testing.T) {
	t.Parallel()

	const (
		cafile  = "../helper/tlsutil/testdata/ca.pem"
		foocert = "../helper/tlsutil/testdata/nomad-foo.pem"
		fookey  = "../helper/tlsutil/testdata/nomad-foo-key.pem"
	)

	s, cleanup := TestServer(t, func(c *Config) {
		c.TLSConfig = &config.TLSConfig{
			EnableRPC: true,
			CAFile:    cafile,
			CertFile:  foocert,
			KeyFile:   fookey,
		}
	})
	defer func() {
		cleanup()

		//TODO Avoid panics from logging during shutdown
		time.Sleep(1 * time.Second)
	}()

	conn, err := net.DialTimeout("tcp", s.config.RPCAddr.String(), time.Second)
	require.NoError(t, err)
	defer conn.Close()

	_, err = conn.Write([]byte{byte(pool.RpcTLS)})
	require.NoError(t, err)

	// Client TLS verification isn't necessary for
	// our assertions
	tlsConf, err := tlsutil.NewTLSConfiguration(s.config.TLSConfig, false, true)
	require.NoError(t, err)
	outTLSConf, err := tlsConf.OutgoingTLSConfig()
	require.NoError(t, err)
	outTLSConf.InsecureSkipVerify = true

	// Do initial handshake
	tlsConn := tls.Client(conn, outTLSConf)
	require.NoError(t, tlsConn.Handshake())
	conn = tlsConn

	// Try to create a nested TLS connection
	_, err = conn.Write([]byte{byte(pool.RpcTLS)})
	require.NoError(t, err)

	// Attempts at nested TLS connections should cause a disconnect
	buf := []byte{0}
	conn.SetReadDeadline(time.Now().Add(1 * time.Second))
	n, err := conn.Read(buf)
	require.Zero(t, n)
	require.Equal(t, io.EOF, err)
}

// TestRPC_Limits_OK asserts that all valid limits combinations
// (tls/timeout/conns) work.
//
// Invalid limits are tested in command/agent/agent_test.go
func TestRPC_Limits_OK(t *testing.T) {
	t.Parallel()

	const (
		cafile   = "../helper/tlsutil/testdata/ca.pem"
		foocert  = "../helper/tlsutil/testdata/nomad-foo.pem"
		fookey   = "../helper/tlsutil/testdata/nomad-foo-key.pem"
		maxConns = 10 // limit must be < this for testing
	)

	cases := []struct {
		tls           bool
		timeout       time.Duration
		limit         int
		assertTimeout bool
		assertLimit   bool
	}{
		{
			tls:           false,
			timeout:       5 * time.Second,
			limit:         0,
			assertTimeout: true,
			assertLimit:   false,
		},
		{
			tls:           true,
			timeout:       5 * time.Second,
			limit:         0,
			assertTimeout: true,
			assertLimit:   false,
		},
		{
			tls:           false,
			timeout:       0,
			limit:         0,
			assertTimeout: false,
			assertLimit:   false,
		},
		{
			tls:           true,
			timeout:       0,
			limit:         0,
			assertTimeout: false,
			assertLimit:   false,
		},
		{
			tls:           false,
			timeout:       0,
			limit:         2,
			assertTimeout: false,
			assertLimit:   true,
		},
		{
			tls:           true,
			timeout:       0,
			limit:         2,
			assertTimeout: false,
			assertLimit:   true,
		},
		{
			tls:           false,
			timeout:       5 * time.Second,
			limit:         2,
			assertTimeout: true,
			assertLimit:   true,
		},
		{
			tls:           true,
			timeout:       5 * time.Second,
			limit:         2,
			assertTimeout: true,
			assertLimit:   true,
		},
	}

	assertTimeout := func(t *testing.T, s *Server, useTLS bool, timeout time.Duration) {
		// Increase timeout to detect timeouts
		clientTimeout := timeout + time.Second

		conn, err := net.DialTimeout("tcp", s.config.RPCAddr.String(), 1*time.Second)
		require.NoError(t, err)
		defer conn.Close()

		buf := []byte{0}
		readDeadline := time.Now().Add(clientTimeout)
		conn.SetReadDeadline(readDeadline)
		n, err := conn.Read(buf)
		require.Zero(t, n)
		if timeout == 0 {
			// Server should *not* have timed out.
			// Now() should always be after the client read deadline, but
			// isn't a sufficient assertion for correctness as slow tests
			// may cause this to be true even if the server timed out.
			now := time.Now()
			require.Truef(t, now.After(readDeadline),
				"Client read deadline (%s) should be in the past (before %s)", readDeadline, now)

			testutil.RequireDeadlineErr(t, err)
			return
		}

		// Server *should* have timed out (EOF)
		require.Equal(t, io.EOF, err)

		// Create a new connection to assert timeout doesn't
		// apply after first byte.
		conn, err = net.DialTimeout("tcp", s.config.RPCAddr.String(), time.Second)
		require.NoError(t, err)
		defer conn.Close()

		if useTLS {
			_, err := conn.Write([]byte{byte(pool.RpcTLS)})
			require.NoError(t, err)

			// Client TLS verification isn't necessary for
			// our assertions
			tlsConf, err := tlsutil.NewTLSConfiguration(s.config.TLSConfig, false, true)
			require.NoError(t, err)
			outTLSConf, err := tlsConf.OutgoingTLSConfig()
			require.NoError(t, err)
			outTLSConf.InsecureSkipVerify = true

			tlsConn := tls.Client(conn, outTLSConf)
			require.NoError(t, tlsConn.Handshake())

			conn = tlsConn
		}

		// Writing the Nomad RPC byte should be sufficient to
		// disable the handshake timeout
		n, err = conn.Write([]byte{byte(pool.RpcNomad)})
		require.NoError(t, err)
		require.Equal(t, 1, n)

		// Read should timeout due to client timeout, not
		// server's timeout
		readDeadline = time.Now().Add(clientTimeout)
		conn.SetReadDeadline(readDeadline)
		n, err = conn.Read(buf)
		require.Zero(t, n)
		testutil.RequireDeadlineErr(t, err)
	}

	assertNoLimit := func(t *testing.T, addr string) {
		var err error

		// Create max connections
		conns := make([]net.Conn, maxConns)
		errCh := make(chan error, maxConns)
		for i := 0; i < maxConns; i++ {
			conns[i], err = net.DialTimeout("tcp", addr, 1*time.Second)
			require.NoError(t, err)
			defer conns[i].Close()

			go func(i int) {
				buf := []byte{0}
				readDeadline := time.Now().Add(1 * time.Second)
				conns[i].SetReadDeadline(readDeadline)
				n, err := conns[i].Read(buf)
				if n > 0 {
					errCh <- fmt.Errorf("n > 0: %d", n)
					return
				}
				errCh <- err
			}(i)
		}

		// Now assert each error is a clientside read deadline error
		deadline := time.After(10 * time.Second)
		for i := 0; i < maxConns; i++ {
			select {
			case <-deadline:
				t.Fatalf("timed out waiting for conn error %d/%d", i+1, maxConns)
			case err := <-errCh:
				testutil.RequireDeadlineErr(t, err)
			}
		}
	}

	assertLimit := func(t *testing.T, addr string, limit int) {
		var err error

		// Create limit connections
		conns := make([]net.Conn, limit)
		errCh := make(chan error, limit)
		for i := range conns {
			conns[i], err = net.DialTimeout("tcp", addr, 1*time.Second)
			require.NoError(t, err)
			defer conns[i].Close()

			go func(i int) {
				buf := []byte{0}
				n, err := conns[i].Read(buf)
				if n > 0 {
					errCh <- fmt.Errorf("n > 0: %d", n)
					return
				}
				errCh <- err
			}(i)
		}

		// Assert a new connection is dropped
		conn, err := net.DialTimeout("tcp", addr, 1*time.Second)
		require.NoError(t, err)
		defer conn.Close()

		buf := []byte{0}
		deadline := time.Now().Add(6 * time.Second)
		conn.SetReadDeadline(deadline)
		n, err := conn.Read(buf)
		require.Zero(t, n)
		require.Equal(t, io.EOF, err)

		// Assert existing connections are ok
	ERRCHECK:
		select {
		case err := <-errCh:
			t.Errorf("unexpected error from idle connection: (%T) %v", err, err)
			goto ERRCHECK
		default:
		}

		// Cleanup
		for _, conn := range conns {
			conn.Close()
		}
		for i := range conns {
			select {
			case err := <-errCh:
				require.Contains(t, err.Error(), "use of closed network connection")
			case <-time.After(10 * time.Second):
				t.Fatalf("timed out waiting for connection %d/%d to close", i, len(conns))
			}
		}
	}

	for i := range cases {
		tc := cases[i]
		name := fmt.Sprintf("%d-tls-%t-timeout-%s-limit-%v", i, tc.tls, tc.timeout, tc.limit)
		t.Run(name, func(t *testing.T) {
			t.Parallel()

			if tc.limit >= maxConns {
				t.Fatalf("test fixture failure: cannot assert limit (%d) >= max (%d)", tc.limit, maxConns)
			}
			if tc.assertTimeout && tc.timeout == 0 {
				t.Fatalf("test fixture failure: cannot assert timeout when no timeout set (0)")
			}

			s, cleanup := TestServer(t, func(c *Config) {
				if tc.tls {
					c.TLSConfig = &config.TLSConfig{
						EnableRPC: true,
						CAFile:    cafile,
						CertFile:  foocert,
						KeyFile:   fookey,
					}
				}
				c.RPCHandshakeTimeout = tc.timeout
				c.RPCMaxConnsPerClient = tc.limit
			})
			defer func() {
				cleanup()

				//TODO Avoid panics from logging during shutdown
				time.Sleep(1 * time.Second)
			}()

			assertTimeout(t, s, tc.tls, tc.timeout)

			if tc.assertLimit {
				// There's a race between assertTimeout(false) closing
				// its connection and the HTTP server noticing and
				// untracking it. Since there's no way to coordinate
				// when this occurs, sleeping is the only way to avoid
				// asserting limits before the timed out connection is
				// untracked.
				time.Sleep(1 * time.Second)

				assertLimit(t, s.config.RPCAddr.String(), tc.limit)
			} else {
				assertNoLimit(t, s.config.RPCAddr.String())
			}
		})
	}
}

// TestRPC_Limits_Streaming asserts that the streaming RPC limit is lower than
// the overall connection limit to prevent DoS via server-routed streaming API
// calls.
func TestRPC_Limits_Streaming(t *testing.T) {
	t.Parallel()

	s, cleanup := TestServer(t, func(c *Config) {
		limits := config.DefaultLimits()
		c.RPCMaxConnsPerClient = *limits.RPCMaxConnsPerClient
	})
	defer func() {
		cleanup()

		//TODO Avoid panics from logging during shutdown
		time.Sleep(1 * time.Second)
	}()

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	errCh := make(chan error, 1)

	// Create a streaming connection
	dialStreamer := func() net.Conn {
		conn, err := net.DialTimeout("tcp", s.config.RPCAddr.String(), 1*time.Second)
		require.NoError(t, err)

		_, err = conn.Write([]byte{byte(pool.RpcStreaming)})
		require.NoError(t, err)
		return conn
	}

	// Create up to the limit streaming connections
	streamers := make([]net.Conn, s.config.RPCMaxConnsPerClient-config.LimitsNonStreamingConnsPerClient)
	for i := range streamers {
		streamers[i] = dialStreamer()

		go func(i int) {
			// Streamer should never die until test exits
			buf := []byte{0}
			_, err := streamers[i].Read(buf)
			if ctx.Err() != nil {
				// Error is expected when test finishes
				return
			}

			t.Logf("connection %d died with error: (%T) %v", i, err, err)

			// Send unexpected errors back
			if err != nil {
				select {
				case errCh <- err:
				case <-ctx.Done():
				default:
					// Only send first error
				}
			}
		}(i)
	}

	defer func() {
		cancel()
		for _, conn := range streamers {
			conn.Close()
		}
	}()

	// Assert no streamer errors have occurred
	select {
	case err := <-errCh:
		t.Fatalf("unexpected error from blocking streaming RPCs: (%T) %v", err, err)
	case <-time.After(500 * time.Millisecond):
		// Ok! No connections were rejected immediately.
	}

	// Assert subsequent streaming RPCs are rejected
	conn := dialStreamer()
	t.Logf("expect connection to be rejected due to limit")
	buf := []byte{0}
	conn.SetReadDeadline(time.Now().Add(3 * time.Second))
	_, err := conn.Read(buf)
	require.Equalf(t, io.EOF, err, "expected io.EOF but found: (%T) %v", err, err)

	// Assert no streamer errors have occurred
	select {
	case err := <-errCh:
		t.Fatalf("unexpected error from blocking streaming RPCs: %v", err)
	default:
	}

	// Subsequent non-streaming RPC should be OK
	conn, err = net.DialTimeout("tcp", s.config.RPCAddr.String(), 1*time.Second)
	require.NoError(t, err)
	_, err = conn.Write([]byte{byte(pool.RpcNomad)})
	require.NoError(t, err)

	conn.SetReadDeadline(time.Now().Add(1 * time.Second))
	_, err = conn.Read(buf)
	testutil.RequireDeadlineErr(t, err)

	// Close 1 streamer and assert another is allowed
	t.Logf("expect streaming connection 0 to exit with error")
	streamers[0].Close()
	<-errCh

	// Assert that new connections are allowed. The server may not
	// immediately notice the closed connection, so the first attempts may
	// be rejected (io.EOF); the first non-EOF result must be a client
	// read-deadline error.
	testutil.WaitForResult(func() (bool, error) {
		conn = dialStreamer()
		conn.SetReadDeadline(time.Now().Add(1 * time.Second))
		_, err = conn.Read(buf)
		if err == io.EOF {
			return false, fmt.Errorf("connection was rejected")
		}

		testutil.RequireDeadlineErr(t, err)
		return true, nil
	}, func(err error) {
		require.NoError(t, err)
	})
}