github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/rpc/nodedialer/nodedialer_test.go (about) 1 // Copyright 2019 The Cockroach Authors. 2 // 3 // Use of this software is governed by the Business Source License 4 // included in the file licenses/BSL.txt. 5 // 6 // As of the Change Date specified in that file, in accordance with 7 // the Business Source License, use of this software will be governed 8 // by the Apache License, Version 2.0, included in the file 9 // licenses/APL.txt. 10 11 package nodedialer 12 13 import ( 14 "context" 15 "fmt" 16 "math/rand" 17 "net" 18 "sync" 19 "testing" 20 "time" 21 22 circuit "github.com/cockroachdb/circuitbreaker" 23 "github.com/cockroachdb/cockroach/pkg/clusterversion" 24 "github.com/cockroachdb/cockroach/pkg/roachpb" 25 "github.com/cockroachdb/cockroach/pkg/rpc" 26 "github.com/cockroachdb/cockroach/pkg/settings/cluster" 27 "github.com/cockroachdb/cockroach/pkg/testutils" 28 "github.com/cockroachdb/cockroach/pkg/util/hlc" 29 "github.com/cockroachdb/cockroach/pkg/util/leaktest" 30 "github.com/cockroachdb/cockroach/pkg/util/log" 31 "github.com/cockroachdb/cockroach/pkg/util/stop" 32 "github.com/cockroachdb/cockroach/pkg/util/syncutil" 33 "github.com/cockroachdb/cockroach/pkg/util/tracing" 34 "github.com/cockroachdb/cockroach/pkg/util/uuid" 35 "github.com/cockroachdb/errors" 36 "github.com/stretchr/testify/assert" 37 "google.golang.org/grpc" 38 ) 39 40 const staticNodeID = 1 41 42 func TestNodedialerPositive(t *testing.T) { 43 defer leaktest.AfterTest(t)() 44 stopper, _, _, _, nd := setUpNodedialerTest(t, staticNodeID) 45 defer stopper.Stop(context.Background()) 46 // Ensure that dialing works. 47 breaker := nd.GetCircuitBreaker(1, rpc.DefaultClass) 48 assert.True(t, breaker.Ready()) 49 ctx := context.Background() 50 _, err := nd.Dial(ctx, staticNodeID, rpc.DefaultClass) 51 assert.Nil(t, err, "failed to dial") 52 assert.True(t, breaker.Ready()) 53 assert.Equal(t, breaker.Failures(), int64(0)) 54 } 55 56 func TestDialNoBreaker(t *testing.T) { 57 defer leaktest.AfterTest(t)() 58 59 ctx := context.Background() 60 61 // Don't use setUpNodedialerTest because we want access to the underlying clock and rpcContext. 62 stopper := stop.NewStopper() 63 clock := hlc.NewClock(hlc.UnixNano, time.Nanosecond) 64 rpcCtx := newTestContext(clock, stopper) 65 rpcCtx.NodeID.Set(ctx, staticNodeID) 66 _, ln, _ := newTestServer(t, clock, stopper, true /* useHeartbeat */) 67 defer stopper.Stop(ctx) 68 69 // Test that DialNoBreaker is successful normally. 70 nd := New(rpcCtx, newSingleNodeResolver(staticNodeID, ln.Addr())) 71 testutils.SucceedsSoon(t, func() error { 72 return nd.ConnHealth(staticNodeID, rpc.DefaultClass) 73 }) 74 breaker := nd.GetCircuitBreaker(staticNodeID, rpc.DefaultClass) 75 assert.True(t, breaker.Ready()) 76 _, err := nd.DialNoBreaker(ctx, staticNodeID, rpc.DefaultClass) 77 assert.Nil(t, err, "failed to dial") 78 assert.True(t, breaker.Ready()) 79 assert.Equal(t, breaker.Failures(), int64(0)) 80 81 // Test that resolver errors don't trip the breaker. 82 boom := fmt.Errorf("boom") 83 nd = New(rpcCtx, func(roachpb.NodeID) (net.Addr, error) { 84 return nil, boom 85 }) 86 breaker = nd.GetCircuitBreaker(staticNodeID, rpc.DefaultClass) 87 _, err = nd.DialNoBreaker(ctx, staticNodeID, rpc.DefaultClass) 88 assert.Equal(t, errors.Cause(err), boom) 89 assert.True(t, breaker.Ready()) 90 assert.Equal(t, breaker.Failures(), int64(0)) 91 92 // Test that connection errors don't trip the breaker either. 93 // To do this, we have to trick grpc into never successfully dialing 94 // the server, because if it succeeds once then it doesn't try again 95 // to perform a connection. To trick grpc in this way, we have to 96 // set up a server without the heartbeat service running. Without 97 // getting a heartbeat, the nodedialer will throw an error thinking 98 // that it wasn't able to successfully make a connection. 99 _, ln, _ = newTestServer(t, clock, stopper, false /* useHeartbeat */) 100 nd = New(rpcCtx, newSingleNodeResolver(staticNodeID, ln.Addr())) 101 breaker = nd.GetCircuitBreaker(staticNodeID, rpc.DefaultClass) 102 _, err = nd.DialNoBreaker(ctx, staticNodeID, rpc.DefaultClass) 103 assert.NotNil(t, err, "expected dial error") 104 assert.True(t, breaker.Ready()) 105 assert.Equal(t, breaker.Failures(), int64(0)) 106 } 107 108 func TestConcurrentCancellationAndTimeout(t *testing.T) { 109 defer leaktest.AfterTest(t)() 110 stopper, _, _, _, nd := setUpNodedialerTest(t, staticNodeID) 111 defer stopper.Stop(context.Background()) 112 ctx := context.Background() 113 breaker := nd.GetCircuitBreaker(staticNodeID, rpc.DefaultClass) 114 // Test that when a context is canceled during dialing we always return that 115 // error but we never trip the breaker. 116 const N = 1000 117 var wg sync.WaitGroup 118 for i := 0; i < N; i++ { 119 wg.Add(2) 120 // Jiggle when we cancel relative to when we dial to try to hit cases where 121 // cancellation happens during the call to GRPCDial. 122 iCtx, cancel := context.WithTimeout(ctx, randDuration(time.Millisecond)) 123 go func() { 124 time.Sleep(randDuration(time.Millisecond)) 125 cancel() 126 wg.Done() 127 }() 128 go func() { 129 time.Sleep(randDuration(time.Millisecond)) 130 _, err := nd.Dial(iCtx, 1, rpc.DefaultClass) 131 if err != nil && 132 !errors.IsAny(err, context.Canceled, context.DeadlineExceeded) { 133 t.Errorf("got an unexpected error from Dial: %v", err) 134 } 135 wg.Done() 136 }() 137 } 138 wg.Wait() 139 assert.Equal(t, breaker.Failures(), int64(0)) 140 } 141 142 func TestResolverErrorsTrip(t *testing.T) { 143 defer leaktest.AfterTest(t)() 144 stopper, rpcCtx, _, _, _ := setUpNodedialerTest(t, staticNodeID) 145 defer stopper.Stop(context.Background()) 146 boom := fmt.Errorf("boom") 147 nd := New(rpcCtx, func(id roachpb.NodeID) (net.Addr, error) { 148 return nil, boom 149 }) 150 _, err := nd.Dial(context.Background(), staticNodeID, rpc.DefaultClass) 151 assert.Equal(t, errors.Cause(err), boom) 152 breaker := nd.GetCircuitBreaker(staticNodeID, rpc.DefaultClass) 153 assert.False(t, breaker.Ready()) 154 } 155 156 func TestDisconnectsTrip(t *testing.T) { 157 defer leaktest.AfterTest(t)() 158 stopper, _, ln, hb, nd := setUpNodedialerTest(t, staticNodeID) 159 defer stopper.Stop(context.Background()) 160 ctx := context.Background() 161 breaker := nd.GetCircuitBreaker(staticNodeID, rpc.DefaultClass) 162 163 // Now close the underlying connection from the server side and set the 164 // heartbeat service to return errors. This will eventually lead to the client 165 // connection being removed and Dial attempts to return an error. 166 // While this is going on there will be many clients attempting to 167 // connect. These connecting clients will send interesting errors they observe 168 // on the errChan. Once an error from Dial is observed the test re-enables the 169 // heartbeat service. The test will confirm that the only errors they record 170 // in to the breaker are interesting ones as determined by shouldTrip. 171 hb.setErr(fmt.Errorf("boom")) 172 underlyingNetConn := ln.popConn() 173 assert.Nil(t, underlyingNetConn.Close()) 174 const N = 1000 175 breakerEventChan := make(chan circuit.ListenerEvent, N) 176 breaker.AddListener(breakerEventChan) 177 errChan := make(chan error, N) 178 shouldTrip := func(err error) bool { 179 return err != nil && 180 !errors.IsAny(err, context.DeadlineExceeded, context.Canceled, circuit.ErrBreakerOpen) 181 } 182 var wg sync.WaitGroup 183 for i := 0; i < N; i++ { 184 wg.Add(2) 185 iCtx, cancel := context.WithTimeout(ctx, randDuration(time.Millisecond)) 186 go func() { 187 time.Sleep(randDuration(time.Millisecond)) 188 cancel() 189 wg.Done() 190 }() 191 go func() { 192 time.Sleep(randDuration(time.Millisecond)) 193 _, err := nd.Dial(iCtx, 1, rpc.DefaultClass) 194 if shouldTrip(err) { 195 errChan <- err 196 } 197 wg.Done() 198 }() 199 } 200 go func() { wg.Wait(); close(errChan) }() 201 var errorsSeen int 202 for range errChan { 203 if errorsSeen == 0 { 204 hb.setErr(nil) 205 } 206 errorsSeen++ 207 } 208 breaker.RemoveListener(breakerEventChan) 209 close(breakerEventChan) 210 var failsSeen int 211 for ev := range breakerEventChan { 212 if ev.Event == circuit.BreakerFail { 213 failsSeen++ 214 } 215 } 216 // Ensure that all of the interesting errors were seen by the breaker. 217 assert.Equal(t, errorsSeen, failsSeen) 218 219 // Ensure that the connection becomes healthy soon now that the heartbeat 220 // service is not returning errors. 221 hb.setErr(nil) // reset in case there were no errors 222 testutils.SucceedsSoon(t, func() error { 223 return nd.ConnHealth(staticNodeID, rpc.DefaultClass) 224 }) 225 } 226 227 func setUpNodedialerTest( 228 t *testing.T, nodeID roachpb.NodeID, 229 ) ( 230 stopper *stop.Stopper, 231 rpcCtx *rpc.Context, 232 ln *interceptingListener, 233 hb *heartbeatService, 234 nd *Dialer, 235 ) { 236 stopper = stop.NewStopper() 237 clock := hlc.NewClock(hlc.UnixNano, time.Nanosecond) 238 // Create an rpc Context and then 239 rpcCtx = newTestContext(clock, stopper) 240 rpcCtx.NodeID.Set(context.Background(), nodeID) 241 _, ln, hb = newTestServer(t, clock, stopper, true /* useHeartbeat */) 242 nd = New(rpcCtx, newSingleNodeResolver(nodeID, ln.Addr())) 243 testutils.SucceedsSoon(t, func() error { 244 return nd.ConnHealth(nodeID, rpc.DefaultClass) 245 }) 246 return stopper, rpcCtx, ln, hb, nd 247 } 248 249 // randDuration returns a uniform random duration between 0 and max. 250 func randDuration(max time.Duration) time.Duration { 251 return time.Duration(rand.Intn(int(max))) 252 } 253 254 func newTestServer( 255 t testing.TB, clock *hlc.Clock, stopper *stop.Stopper, useHeartbeat bool, 256 ) (*grpc.Server, *interceptingListener, *heartbeatService) { 257 ctx := context.Background() 258 localAddr := "127.0.0.1:0" 259 ln, err := net.Listen("tcp", localAddr) 260 if err != nil { 261 t.Fatalf("failed to listed on %v: %v", localAddr, err) 262 } 263 il := &interceptingListener{Listener: ln} 264 s := grpc.NewServer() 265 var hb *heartbeatService 266 if useHeartbeat { 267 hb = &heartbeatService{ 268 clock: clock, 269 serverVersion: clusterversion.TestingBinaryVersion, 270 } 271 rpc.RegisterHeartbeatServer(s, hb) 272 } 273 if err := stopper.RunAsyncTask(ctx, "localServer", func(ctx context.Context) { 274 if err := s.Serve(il); err != nil { 275 log.Infof(ctx, "server stopped: %v", err) 276 } 277 }); err != nil { 278 t.Fatalf("failed to run test server: %v", err) 279 } 280 go func() { <-stopper.ShouldQuiesce(); s.Stop() }() 281 return s, il, hb 282 } 283 284 func newTestContext(clock *hlc.Clock, stopper *stop.Stopper) *rpc.Context { 285 cfg := testutils.NewNodeTestBaseContext() 286 cfg.Insecure = true 287 cfg.RPCHeartbeatInterval = 10 * time.Millisecond 288 rctx := rpc.NewContext( 289 log.AmbientContext{Tracer: tracing.NewTracer()}, 290 cfg, 291 clock, 292 stopper, 293 cluster.MakeTestingClusterSettings(), 294 ) 295 // Ensure that tests using this test context and restart/shut down 296 // their servers do not inadvertently start talking to servers from 297 // unrelated concurrent tests. 298 rctx.ClusterID.Set(context.Background(), uuid.MakeV4()) 299 300 return rctx 301 } 302 303 // interceptingListener wraps a net.Listener and provides access to the 304 // underlying net.Conn objects which that listener Accepts. 305 type interceptingListener struct { 306 net.Listener 307 mu struct { 308 syncutil.Mutex 309 conns []net.Conn 310 } 311 } 312 313 // newSingleNodeResolver returns a Resolver that resolve a single node id 314 func newSingleNodeResolver(id roachpb.NodeID, addr net.Addr) AddressResolver { 315 return func(toResolve roachpb.NodeID) (net.Addr, error) { 316 if id == toResolve { 317 return addr, nil 318 } 319 return nil, fmt.Errorf("unknown node id %d", toResolve) 320 } 321 } 322 323 func (il *interceptingListener) Accept() (c net.Conn, err error) { 324 defer func() { 325 if err == nil { 326 il.mu.Lock() 327 il.mu.conns = append(il.mu.conns, c) 328 il.mu.Unlock() 329 } 330 }() 331 return il.Listener.Accept() 332 } 333 334 func (il *interceptingListener) popConn() net.Conn { 335 il.mu.Lock() 336 defer il.mu.Unlock() 337 if len(il.mu.conns) == 0 { 338 return nil 339 } 340 c := il.mu.conns[0] 341 il.mu.conns = il.mu.conns[1:] 342 return c 343 } 344 345 type errContainer struct { 346 syncutil.RWMutex 347 err error 348 } 349 350 func (ec *errContainer) getErr() error { 351 ec.RLock() 352 defer ec.RUnlock() 353 return ec.err 354 } 355 356 func (ec *errContainer) setErr(err error) { 357 ec.Lock() 358 defer ec.Unlock() 359 ec.err = err 360 } 361 362 // heartbeatService is a dummy rpc.HeartbeatService which provides a mechanism 363 // to inject errors. 364 type heartbeatService struct { 365 errContainer 366 clock *hlc.Clock 367 serverVersion roachpb.Version 368 } 369 370 func (hb *heartbeatService) Ping( 371 ctx context.Context, args *rpc.PingRequest, 372 ) (*rpc.PingResponse, error) { 373 if err := hb.getErr(); err != nil { 374 return nil, err 375 } 376 return &rpc.PingResponse{ 377 Pong: args.Ping, 378 ServerTime: hb.clock.PhysicalNow(), 379 ServerVersion: hb.serverVersion, 380 }, nil 381 }