github.com/bigcommerce/nomad@v0.9.3-bc/nomad/rpc_test.go (about)

     1  package nomad
     2  
     3  import (
     4  	"context"
     5  	"net"
     6  	"net/rpc"
     7  	"os"
     8  	"path"
     9  	"testing"
    10  	"time"
    11  
    12  	msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc"
    13  	"github.com/hashicorp/nomad/helper/pool"
    14  	"github.com/hashicorp/nomad/helper/testlog"
    15  	"github.com/hashicorp/nomad/nomad/mock"
    16  	"github.com/hashicorp/nomad/nomad/structs"
    17  	"github.com/hashicorp/nomad/nomad/structs/config"
    18  	"github.com/hashicorp/nomad/testutil"
    19  	"github.com/hashicorp/raft"
    20  	"github.com/hashicorp/yamux"
    21  	"github.com/stretchr/testify/assert"
    22  	"github.com/stretchr/testify/require"
    23  )
    24  
    25  // rpcClient is a test helper method to return a ClientCodec to use to make rpc
    26  // calls to the passed server.
    27  func rpcClient(t *testing.T, s *Server) rpc.ClientCodec {
    28  	addr := s.config.RPCAddr
    29  	conn, err := net.DialTimeout("tcp", addr.String(), time.Second)
    30  	if err != nil {
    31  		t.Fatalf("err: %v", err)
    32  	}
    33  	// Write the Nomad RPC byte to set the mode
    34  	conn.Write([]byte{byte(pool.RpcNomad)})
    35  	return pool.NewClientCodec(conn)
    36  }
    37  
    38  func TestRPC_forwardLeader(t *testing.T) {
    39  	t.Parallel()
    40  	s1 := TestServer(t, nil)
    41  	defer s1.Shutdown()
    42  	s2 := TestServer(t, func(c *Config) {
    43  		c.DevDisableBootstrap = true
    44  	})
    45  	defer s2.Shutdown()
    46  	TestJoin(t, s1, s2)
    47  	testutil.WaitForLeader(t, s1.RPC)
    48  	testutil.WaitForLeader(t, s2.RPC)
    49  
    50  	isLeader, remote := s1.getLeader()
    51  	if !isLeader && remote == nil {
    52  		t.Fatalf("missing leader")
    53  	}
    54  
    55  	if remote != nil {
    56  		var out struct{}
    57  		err := s1.forwardLeader(remote, "Status.Ping", struct{}{}, &out)
    58  		if err != nil {
    59  			t.Fatalf("err: %v", err)
    60  		}
    61  	}
    62  
    63  	isLeader, remote = s2.getLeader()
    64  	if !isLeader && remote == nil {
    65  		t.Fatalf("missing leader")
    66  	}
    67  
    68  	if remote != nil {
    69  		var out struct{}
    70  		err := s2.forwardLeader(remote, "Status.Ping", struct{}{}, &out)
    71  		if err != nil {
    72  			t.Fatalf("err: %v", err)
    73  		}
    74  	}
    75  }
    76  
    77  func TestRPC_forwardRegion(t *testing.T) {
    78  	t.Parallel()
    79  	s1 := TestServer(t, nil)
    80  	defer s1.Shutdown()
    81  	s2 := TestServer(t, func(c *Config) {
    82  		c.Region = "region2"
    83  	})
    84  	defer s2.Shutdown()
    85  	TestJoin(t, s1, s2)
    86  	testutil.WaitForLeader(t, s1.RPC)
    87  	testutil.WaitForLeader(t, s2.RPC)
    88  
    89  	var out struct{}
    90  	err := s1.forwardRegion("region2", "Status.Ping", struct{}{}, &out)
    91  	if err != nil {
    92  		t.Fatalf("err: %v", err)
    93  	}
    94  
    95  	err = s2.forwardRegion("global", "Status.Ping", struct{}{}, &out)
    96  	if err != nil {
    97  		t.Fatalf("err: %v", err)
    98  	}
    99  }
   100  
   101  func TestRPC_PlaintextRPCSucceedsWhenInUpgradeMode(t *testing.T) {
   102  	t.Parallel()
   103  	assert := assert.New(t)
   104  
   105  	const (
   106  		cafile  = "../helper/tlsutil/testdata/ca.pem"
   107  		foocert = "../helper/tlsutil/testdata/nomad-foo.pem"
   108  		fookey  = "../helper/tlsutil/testdata/nomad-foo-key.pem"
   109  	)
   110  	dir := tmpDir(t)
   111  	defer os.RemoveAll(dir)
   112  
   113  	s1 := TestServer(t, func(c *Config) {
   114  		c.DataDir = path.Join(dir, "node1")
   115  		c.TLSConfig = &config.TLSConfig{
   116  			EnableRPC:            true,
   117  			VerifyServerHostname: true,
   118  			CAFile:               cafile,
   119  			CertFile:             foocert,
   120  			KeyFile:              fookey,
   121  			RPCUpgradeMode:       true,
   122  		}
   123  	})
   124  	defer s1.Shutdown()
   125  
   126  	codec := rpcClient(t, s1)
   127  
   128  	// Create the register request
   129  	node := mock.Node()
   130  	req := &structs.NodeRegisterRequest{
   131  		Node:         node,
   132  		WriteRequest: structs.WriteRequest{Region: "global"},
   133  	}
   134  
   135  	var resp structs.GenericResponse
   136  	err := msgpackrpc.CallWithCodec(codec, "Node.Register", req, &resp)
   137  	assert.Nil(err)
   138  
   139  	// Check that heartbeatTimers has the heartbeat ID
   140  	_, ok := s1.heartbeatTimers[node.ID]
   141  	assert.True(ok)
   142  }
   143  
   144  func TestRPC_PlaintextRPCFailsWhenNotInUpgradeMode(t *testing.T) {
   145  	t.Parallel()
   146  	assert := assert.New(t)
   147  
   148  	const (
   149  		cafile  = "../helper/tlsutil/testdata/ca.pem"
   150  		foocert = "../helper/tlsutil/testdata/nomad-foo.pem"
   151  		fookey  = "../helper/tlsutil/testdata/nomad-foo-key.pem"
   152  	)
   153  	dir := tmpDir(t)
   154  	defer os.RemoveAll(dir)
   155  
   156  	s1 := TestServer(t, func(c *Config) {
   157  		c.DataDir = path.Join(dir, "node1")
   158  		c.TLSConfig = &config.TLSConfig{
   159  			EnableRPC:            true,
   160  			VerifyServerHostname: true,
   161  			CAFile:               cafile,
   162  			CertFile:             foocert,
   163  			KeyFile:              fookey,
   164  		}
   165  	})
   166  	defer s1.Shutdown()
   167  
   168  	codec := rpcClient(t, s1)
   169  
   170  	node := mock.Node()
   171  	req := &structs.NodeRegisterRequest{
   172  		Node:         node,
   173  		WriteRequest: structs.WriteRequest{Region: "global"},
   174  	}
   175  
   176  	var resp structs.GenericResponse
   177  	err := msgpackrpc.CallWithCodec(codec, "Node.Register", req, &resp)
   178  	assert.NotNil(err)
   179  }
   180  
   181  func TestRPC_streamingRpcConn_badMethod(t *testing.T) {
   182  	t.Parallel()
   183  	require := require.New(t)
   184  
   185  	s1 := TestServer(t, nil)
   186  	defer s1.Shutdown()
   187  	s2 := TestServer(t, func(c *Config) {
   188  		c.DevDisableBootstrap = true
   189  	})
   190  	defer s2.Shutdown()
   191  	TestJoin(t, s1, s2)
   192  	testutil.WaitForLeader(t, s1.RPC)
   193  	testutil.WaitForLeader(t, s2.RPC)
   194  
   195  	s1.peerLock.RLock()
   196  	ok, parts := isNomadServer(s2.LocalMember())
   197  	require.True(ok)
   198  	server := s1.localPeers[raft.ServerAddress(parts.Addr.String())]
   199  	require.NotNil(server)
   200  	s1.peerLock.RUnlock()
   201  
   202  	conn, err := s1.streamingRpc(server, "Bogus")
   203  	require.Nil(conn)
   204  	require.NotNil(err)
   205  	require.Contains(err.Error(), "Bogus")
   206  	require.True(structs.IsErrUnknownMethod(err))
   207  }
   208  
   209  func TestRPC_streamingRpcConn_badMethod_TLS(t *testing.T) {
   210  	t.Parallel()
   211  	require := require.New(t)
   212  	const (
   213  		cafile  = "../helper/tlsutil/testdata/ca.pem"
   214  		foocert = "../helper/tlsutil/testdata/nomad-foo.pem"
   215  		fookey  = "../helper/tlsutil/testdata/nomad-foo-key.pem"
   216  	)
   217  	dir := tmpDir(t)
   218  	defer os.RemoveAll(dir)
   219  	s1 := TestServer(t, func(c *Config) {
   220  		c.Region = "regionFoo"
   221  		c.BootstrapExpect = 2
   222  		c.DevMode = false
   223  		c.DevDisableBootstrap = true
   224  		c.DataDir = path.Join(dir, "node1")
   225  		c.TLSConfig = &config.TLSConfig{
   226  			EnableHTTP:           true,
   227  			EnableRPC:            true,
   228  			VerifyServerHostname: true,
   229  			CAFile:               cafile,
   230  			CertFile:             foocert,
   231  			KeyFile:              fookey,
   232  		}
   233  	})
   234  	defer s1.Shutdown()
   235  
   236  	s2 := TestServer(t, func(c *Config) {
   237  		c.Region = "regionFoo"
   238  		c.BootstrapExpect = 2
   239  		c.DevMode = false
   240  		c.DevDisableBootstrap = true
   241  		c.DataDir = path.Join(dir, "node2")
   242  		c.TLSConfig = &config.TLSConfig{
   243  			EnableHTTP:           true,
   244  			EnableRPC:            true,
   245  			VerifyServerHostname: true,
   246  			CAFile:               cafile,
   247  			CertFile:             foocert,
   248  			KeyFile:              fookey,
   249  		}
   250  	})
   251  	defer s2.Shutdown()
   252  
   253  	TestJoin(t, s1, s2)
   254  	testutil.WaitForLeader(t, s1.RPC)
   255  
   256  	s1.peerLock.RLock()
   257  	ok, parts := isNomadServer(s2.LocalMember())
   258  	require.True(ok)
   259  	server := s1.localPeers[raft.ServerAddress(parts.Addr.String())]
   260  	require.NotNil(server)
   261  	s1.peerLock.RUnlock()
   262  
   263  	conn, err := s1.streamingRpc(server, "Bogus")
   264  	require.Nil(conn)
   265  	require.NotNil(err)
   266  	require.Contains(err.Error(), "Bogus")
   267  	require.True(structs.IsErrUnknownMethod(err))
   268  }
   269  
   270  // COMPAT: Remove in 0.10
   271  // This is a very low level test to assert that the V2 handling works. It is
   272  // making manual RPC calls since no helpers exist at this point since we are
   273  // only implementing support for v2 but not using it yet. In the future we can
   274  // switch the conn pool to establishing v2 connections and we can deprecate this
   275  // test.
   276  func TestRPC_handleMultiplexV2(t *testing.T) {
   277  	t.Parallel()
   278  	require := require.New(t)
   279  	s := TestServer(t, nil)
   280  	defer s.Shutdown()
   281  	testutil.WaitForLeader(t, s.RPC)
   282  
   283  	p1, p2 := net.Pipe()
   284  	defer p1.Close()
   285  	defer p2.Close()
   286  
   287  	// Start the handler
   288  	doneCh := make(chan struct{})
   289  	go func() {
   290  		s.handleConn(context.Background(), p2, &RPCContext{Conn: p2})
   291  		close(doneCh)
   292  	}()
   293  
   294  	// Establish the MultiplexV2 connection
   295  	_, err := p1.Write([]byte{byte(pool.RpcMultiplexV2)})
   296  	require.Nil(err)
   297  
   298  	// Make two streams
   299  	conf := yamux.DefaultConfig()
   300  	conf.LogOutput = nil
   301  	conf.Logger = testlog.Logger(t)
   302  	session, err := yamux.Client(p1, conf)
   303  	require.Nil(err)
   304  
   305  	s1, err := session.Open()
   306  	require.Nil(err)
   307  	defer s1.Close()
   308  
   309  	s2, err := session.Open()
   310  	require.Nil(err)
   311  	defer s2.Close()
   312  
   313  	// Make an RPC
   314  	_, err = s1.Write([]byte{byte(pool.RpcNomad)})
   315  	require.Nil(err)
   316  
   317  	args := &structs.GenericRequest{}
   318  	var l string
   319  	err = msgpackrpc.CallWithCodec(pool.NewClientCodec(s1), "Status.Leader", args, &l)
   320  	require.Nil(err)
   321  	require.NotEmpty(l)
   322  
   323  	// Make a streaming RPC
   324  	err = s.streamingRpcImpl(s2, s.Region(), "Bogus")
   325  	require.NotNil(err)
   326  	require.Contains(err.Error(), "Bogus")
   327  	require.True(structs.IsErrUnknownMethod(err))
   328  
   329  }