github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/rpc/nodedialer/nodedialer_test.go (about)

     1  // Copyright 2019 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package nodedialer
    12  
    13  import (
    14  	"context"
    15  	"fmt"
    16  	"math/rand"
    17  	"net"
    18  	"sync"
    19  	"testing"
    20  	"time"
    21  
    22  	circuit "github.com/cockroachdb/circuitbreaker"
    23  	"github.com/cockroachdb/cockroach/pkg/clusterversion"
    24  	"github.com/cockroachdb/cockroach/pkg/roachpb"
    25  	"github.com/cockroachdb/cockroach/pkg/rpc"
    26  	"github.com/cockroachdb/cockroach/pkg/settings/cluster"
    27  	"github.com/cockroachdb/cockroach/pkg/testutils"
    28  	"github.com/cockroachdb/cockroach/pkg/util/hlc"
    29  	"github.com/cockroachdb/cockroach/pkg/util/leaktest"
    30  	"github.com/cockroachdb/cockroach/pkg/util/log"
    31  	"github.com/cockroachdb/cockroach/pkg/util/stop"
    32  	"github.com/cockroachdb/cockroach/pkg/util/syncutil"
    33  	"github.com/cockroachdb/cockroach/pkg/util/tracing"
    34  	"github.com/cockroachdb/cockroach/pkg/util/uuid"
    35  	"github.com/cockroachdb/errors"
    36  	"github.com/stretchr/testify/assert"
    37  	"google.golang.org/grpc"
    38  )
    39  
    40  const staticNodeID = 1
    41  
    42  func TestNodedialerPositive(t *testing.T) {
    43  	defer leaktest.AfterTest(t)()
    44  	stopper, _, _, _, nd := setUpNodedialerTest(t, staticNodeID)
    45  	defer stopper.Stop(context.Background())
    46  	// Ensure that dialing works.
    47  	breaker := nd.GetCircuitBreaker(1, rpc.DefaultClass)
    48  	assert.True(t, breaker.Ready())
    49  	ctx := context.Background()
    50  	_, err := nd.Dial(ctx, staticNodeID, rpc.DefaultClass)
    51  	assert.Nil(t, err, "failed to dial")
    52  	assert.True(t, breaker.Ready())
    53  	assert.Equal(t, breaker.Failures(), int64(0))
    54  }
    55  
    56  func TestDialNoBreaker(t *testing.T) {
    57  	defer leaktest.AfterTest(t)()
    58  
    59  	ctx := context.Background()
    60  
    61  	// Don't use setUpNodedialerTest because we want access to the underlying clock and rpcContext.
    62  	stopper := stop.NewStopper()
    63  	clock := hlc.NewClock(hlc.UnixNano, time.Nanosecond)
    64  	rpcCtx := newTestContext(clock, stopper)
    65  	rpcCtx.NodeID.Set(ctx, staticNodeID)
    66  	_, ln, _ := newTestServer(t, clock, stopper, true /* useHeartbeat */)
    67  	defer stopper.Stop(ctx)
    68  
    69  	// Test that DialNoBreaker is successful normally.
    70  	nd := New(rpcCtx, newSingleNodeResolver(staticNodeID, ln.Addr()))
    71  	testutils.SucceedsSoon(t, func() error {
    72  		return nd.ConnHealth(staticNodeID, rpc.DefaultClass)
    73  	})
    74  	breaker := nd.GetCircuitBreaker(staticNodeID, rpc.DefaultClass)
    75  	assert.True(t, breaker.Ready())
    76  	_, err := nd.DialNoBreaker(ctx, staticNodeID, rpc.DefaultClass)
    77  	assert.Nil(t, err, "failed to dial")
    78  	assert.True(t, breaker.Ready())
    79  	assert.Equal(t, breaker.Failures(), int64(0))
    80  
    81  	// Test that resolver errors don't trip the breaker.
    82  	boom := fmt.Errorf("boom")
    83  	nd = New(rpcCtx, func(roachpb.NodeID) (net.Addr, error) {
    84  		return nil, boom
    85  	})
    86  	breaker = nd.GetCircuitBreaker(staticNodeID, rpc.DefaultClass)
    87  	_, err = nd.DialNoBreaker(ctx, staticNodeID, rpc.DefaultClass)
    88  	assert.Equal(t, errors.Cause(err), boom)
    89  	assert.True(t, breaker.Ready())
    90  	assert.Equal(t, breaker.Failures(), int64(0))
    91  
    92  	// Test that connection errors don't trip the breaker either.
    93  	// To do this, we have to trick grpc into never successfully dialing
    94  	// the server, because if it succeeds once then it doesn't try again
    95  	// to perform a connection. To trick grpc in this way, we have to
    96  	// set up a server without the heartbeat service running. Without
    97  	// getting a heartbeat, the nodedialer will throw an error thinking
    98  	// that it wasn't able to successfully make a connection.
    99  	_, ln, _ = newTestServer(t, clock, stopper, false /* useHeartbeat */)
   100  	nd = New(rpcCtx, newSingleNodeResolver(staticNodeID, ln.Addr()))
   101  	breaker = nd.GetCircuitBreaker(staticNodeID, rpc.DefaultClass)
   102  	_, err = nd.DialNoBreaker(ctx, staticNodeID, rpc.DefaultClass)
   103  	assert.NotNil(t, err, "expected dial error")
   104  	assert.True(t, breaker.Ready())
   105  	assert.Equal(t, breaker.Failures(), int64(0))
   106  }
   107  
   108  func TestConcurrentCancellationAndTimeout(t *testing.T) {
   109  	defer leaktest.AfterTest(t)()
   110  	stopper, _, _, _, nd := setUpNodedialerTest(t, staticNodeID)
   111  	defer stopper.Stop(context.Background())
   112  	ctx := context.Background()
   113  	breaker := nd.GetCircuitBreaker(staticNodeID, rpc.DefaultClass)
   114  	// Test that when a context is canceled during dialing we always return that
   115  	// error but we never trip the breaker.
   116  	const N = 1000
   117  	var wg sync.WaitGroup
   118  	for i := 0; i < N; i++ {
   119  		wg.Add(2)
   120  		// Jiggle when we cancel relative to when we dial to try to hit cases where
   121  		// cancellation happens during the call to GRPCDial.
   122  		iCtx, cancel := context.WithTimeout(ctx, randDuration(time.Millisecond))
   123  		go func() {
   124  			time.Sleep(randDuration(time.Millisecond))
   125  			cancel()
   126  			wg.Done()
   127  		}()
   128  		go func() {
   129  			time.Sleep(randDuration(time.Millisecond))
   130  			_, err := nd.Dial(iCtx, 1, rpc.DefaultClass)
   131  			if err != nil &&
   132  				!errors.IsAny(err, context.Canceled, context.DeadlineExceeded) {
   133  				t.Errorf("got an unexpected error from Dial: %v", err)
   134  			}
   135  			wg.Done()
   136  		}()
   137  	}
   138  	wg.Wait()
   139  	assert.Equal(t, breaker.Failures(), int64(0))
   140  }
   141  
   142  func TestResolverErrorsTrip(t *testing.T) {
   143  	defer leaktest.AfterTest(t)()
   144  	stopper, rpcCtx, _, _, _ := setUpNodedialerTest(t, staticNodeID)
   145  	defer stopper.Stop(context.Background())
   146  	boom := fmt.Errorf("boom")
   147  	nd := New(rpcCtx, func(id roachpb.NodeID) (net.Addr, error) {
   148  		return nil, boom
   149  	})
   150  	_, err := nd.Dial(context.Background(), staticNodeID, rpc.DefaultClass)
   151  	assert.Equal(t, errors.Cause(err), boom)
   152  	breaker := nd.GetCircuitBreaker(staticNodeID, rpc.DefaultClass)
   153  	assert.False(t, breaker.Ready())
   154  }
   155  
   156  func TestDisconnectsTrip(t *testing.T) {
   157  	defer leaktest.AfterTest(t)()
   158  	stopper, _, ln, hb, nd := setUpNodedialerTest(t, staticNodeID)
   159  	defer stopper.Stop(context.Background())
   160  	ctx := context.Background()
   161  	breaker := nd.GetCircuitBreaker(staticNodeID, rpc.DefaultClass)
   162  
   163  	// Now close the underlying connection from the server side and set the
   164  	// heartbeat service to return errors. This will eventually lead to the client
   165  	// connection being removed and Dial attempts to return an error.
   166  	// While this is going on there will be many clients attempting to
   167  	// connect. These connecting clients will send interesting errors they observe
   168  	// on the errChan. Once an error from Dial is observed the test re-enables the
   169  	// heartbeat service. The test will confirm that the only errors they record
   170  	// in to the breaker are interesting ones as determined by shouldTrip.
   171  	hb.setErr(fmt.Errorf("boom"))
   172  	underlyingNetConn := ln.popConn()
   173  	assert.Nil(t, underlyingNetConn.Close())
   174  	const N = 1000
   175  	breakerEventChan := make(chan circuit.ListenerEvent, N)
   176  	breaker.AddListener(breakerEventChan)
   177  	errChan := make(chan error, N)
   178  	shouldTrip := func(err error) bool {
   179  		return err != nil &&
   180  			!errors.IsAny(err, context.DeadlineExceeded, context.Canceled, circuit.ErrBreakerOpen)
   181  	}
   182  	var wg sync.WaitGroup
   183  	for i := 0; i < N; i++ {
   184  		wg.Add(2)
   185  		iCtx, cancel := context.WithTimeout(ctx, randDuration(time.Millisecond))
   186  		go func() {
   187  			time.Sleep(randDuration(time.Millisecond))
   188  			cancel()
   189  			wg.Done()
   190  		}()
   191  		go func() {
   192  			time.Sleep(randDuration(time.Millisecond))
   193  			_, err := nd.Dial(iCtx, 1, rpc.DefaultClass)
   194  			if shouldTrip(err) {
   195  				errChan <- err
   196  			}
   197  			wg.Done()
   198  		}()
   199  	}
   200  	go func() { wg.Wait(); close(errChan) }()
   201  	var errorsSeen int
   202  	for range errChan {
   203  		if errorsSeen == 0 {
   204  			hb.setErr(nil)
   205  		}
   206  		errorsSeen++
   207  	}
   208  	breaker.RemoveListener(breakerEventChan)
   209  	close(breakerEventChan)
   210  	var failsSeen int
   211  	for ev := range breakerEventChan {
   212  		if ev.Event == circuit.BreakerFail {
   213  			failsSeen++
   214  		}
   215  	}
   216  	// Ensure that all of the interesting errors were seen by the breaker.
   217  	assert.Equal(t, errorsSeen, failsSeen)
   218  
   219  	// Ensure that the connection becomes healthy soon now that the heartbeat
   220  	// service is not returning errors.
   221  	hb.setErr(nil) // reset in case there were no errors
   222  	testutils.SucceedsSoon(t, func() error {
   223  		return nd.ConnHealth(staticNodeID, rpc.DefaultClass)
   224  	})
   225  }
   226  
   227  func setUpNodedialerTest(
   228  	t *testing.T, nodeID roachpb.NodeID,
   229  ) (
   230  	stopper *stop.Stopper,
   231  	rpcCtx *rpc.Context,
   232  	ln *interceptingListener,
   233  	hb *heartbeatService,
   234  	nd *Dialer,
   235  ) {
   236  	stopper = stop.NewStopper()
   237  	clock := hlc.NewClock(hlc.UnixNano, time.Nanosecond)
   238  	// Create an rpc Context and then
   239  	rpcCtx = newTestContext(clock, stopper)
   240  	rpcCtx.NodeID.Set(context.Background(), nodeID)
   241  	_, ln, hb = newTestServer(t, clock, stopper, true /* useHeartbeat */)
   242  	nd = New(rpcCtx, newSingleNodeResolver(nodeID, ln.Addr()))
   243  	testutils.SucceedsSoon(t, func() error {
   244  		return nd.ConnHealth(nodeID, rpc.DefaultClass)
   245  	})
   246  	return stopper, rpcCtx, ln, hb, nd
   247  }
   248  
   249  // randDuration returns a uniform random duration between 0 and max.
   250  func randDuration(max time.Duration) time.Duration {
   251  	return time.Duration(rand.Intn(int(max)))
   252  }
   253  
   254  func newTestServer(
   255  	t testing.TB, clock *hlc.Clock, stopper *stop.Stopper, useHeartbeat bool,
   256  ) (*grpc.Server, *interceptingListener, *heartbeatService) {
   257  	ctx := context.Background()
   258  	localAddr := "127.0.0.1:0"
   259  	ln, err := net.Listen("tcp", localAddr)
   260  	if err != nil {
   261  		t.Fatalf("failed to listed on %v: %v", localAddr, err)
   262  	}
   263  	il := &interceptingListener{Listener: ln}
   264  	s := grpc.NewServer()
   265  	var hb *heartbeatService
   266  	if useHeartbeat {
   267  		hb = &heartbeatService{
   268  			clock:         clock,
   269  			serverVersion: clusterversion.TestingBinaryVersion,
   270  		}
   271  		rpc.RegisterHeartbeatServer(s, hb)
   272  	}
   273  	if err := stopper.RunAsyncTask(ctx, "localServer", func(ctx context.Context) {
   274  		if err := s.Serve(il); err != nil {
   275  			log.Infof(ctx, "server stopped: %v", err)
   276  		}
   277  	}); err != nil {
   278  		t.Fatalf("failed to run test server: %v", err)
   279  	}
   280  	go func() { <-stopper.ShouldQuiesce(); s.Stop() }()
   281  	return s, il, hb
   282  }
   283  
   284  func newTestContext(clock *hlc.Clock, stopper *stop.Stopper) *rpc.Context {
   285  	cfg := testutils.NewNodeTestBaseContext()
   286  	cfg.Insecure = true
   287  	cfg.RPCHeartbeatInterval = 10 * time.Millisecond
   288  	rctx := rpc.NewContext(
   289  		log.AmbientContext{Tracer: tracing.NewTracer()},
   290  		cfg,
   291  		clock,
   292  		stopper,
   293  		cluster.MakeTestingClusterSettings(),
   294  	)
   295  	// Ensure that tests using this test context and restart/shut down
   296  	// their servers do not inadvertently start talking to servers from
   297  	// unrelated concurrent tests.
   298  	rctx.ClusterID.Set(context.Background(), uuid.MakeV4())
   299  
   300  	return rctx
   301  }
   302  
   303  // interceptingListener wraps a net.Listener and provides access to the
   304  // underlying net.Conn objects which that listener Accepts.
   305  type interceptingListener struct {
   306  	net.Listener
   307  	mu struct {
   308  		syncutil.Mutex
   309  		conns []net.Conn
   310  	}
   311  }
   312  
   313  // newSingleNodeResolver returns a Resolver that resolve a single node id
   314  func newSingleNodeResolver(id roachpb.NodeID, addr net.Addr) AddressResolver {
   315  	return func(toResolve roachpb.NodeID) (net.Addr, error) {
   316  		if id == toResolve {
   317  			return addr, nil
   318  		}
   319  		return nil, fmt.Errorf("unknown node id %d", toResolve)
   320  	}
   321  }
   322  
   323  func (il *interceptingListener) Accept() (c net.Conn, err error) {
   324  	defer func() {
   325  		if err == nil {
   326  			il.mu.Lock()
   327  			il.mu.conns = append(il.mu.conns, c)
   328  			il.mu.Unlock()
   329  		}
   330  	}()
   331  	return il.Listener.Accept()
   332  }
   333  
   334  func (il *interceptingListener) popConn() net.Conn {
   335  	il.mu.Lock()
   336  	defer il.mu.Unlock()
   337  	if len(il.mu.conns) == 0 {
   338  		return nil
   339  	}
   340  	c := il.mu.conns[0]
   341  	il.mu.conns = il.mu.conns[1:]
   342  	return c
   343  }
   344  
   345  type errContainer struct {
   346  	syncutil.RWMutex
   347  	err error
   348  }
   349  
   350  func (ec *errContainer) getErr() error {
   351  	ec.RLock()
   352  	defer ec.RUnlock()
   353  	return ec.err
   354  }
   355  
   356  func (ec *errContainer) setErr(err error) {
   357  	ec.Lock()
   358  	defer ec.Unlock()
   359  	ec.err = err
   360  }
   361  
   362  // heartbeatService is a dummy rpc.HeartbeatService which provides a mechanism
   363  // to inject errors.
   364  type heartbeatService struct {
   365  	errContainer
   366  	clock         *hlc.Clock
   367  	serverVersion roachpb.Version
   368  }
   369  
   370  func (hb *heartbeatService) Ping(
   371  	ctx context.Context, args *rpc.PingRequest,
   372  ) (*rpc.PingResponse, error) {
   373  	if err := hb.getErr(); err != nil {
   374  		return nil, err
   375  	}
   376  	return &rpc.PingResponse{
   377  		Pong:          args.Ping,
   378  		ServerTime:    hb.clock.PhysicalNow(),
   379  		ServerVersion: hb.serverVersion,
   380  	}, nil
   381  }