// Copyright (c) 2018 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
20 21 package writer 22 23 import ( 24 "context" 25 "io" 26 "net" 27 "sync" 28 "testing" 29 "time" 30 31 "github.com/fortytw2/leaktest" 32 "github.com/golang/mock/gomock" 33 "github.com/stretchr/testify/assert" 34 "github.com/stretchr/testify/require" 35 "github.com/uber-go/tally" 36 37 "github.com/m3db/m3/src/msg/generated/proto/msgpb" 38 "github.com/m3db/m3/src/msg/protocol/proto" 39 "github.com/m3db/m3/src/x/pool" 40 "github.com/m3db/m3/src/x/retry" 41 xtest "github.com/m3db/m3/src/x/test" 42 ) 43 44 var ( 45 testMsg = msgpb.Message{ 46 Metadata: msgpb.Metadata{ 47 Shard: 100, 48 Id: 200, 49 }, 50 Value: []byte("foooooooo"), 51 } 52 53 testEncoder = proto.NewEncoder(nil) 54 ) 55 56 func TestNewConsumerWriter(t *testing.T) { 57 defer leaktest.Check(t)() 58 59 lis, err := net.Listen("tcp", "127.0.0.1:0") 60 require.NoError(t, err) 61 defer lis.Close() 62 63 ctrl := xtest.NewController(t) 64 defer ctrl.Finish() 65 66 mockRouter := NewMockackRouter(ctrl) 67 68 opts := testOptions() 69 70 w := newConsumerWriter(lis.Addr().String(), mockRouter, opts, testConsumerWriterMetrics()).(*consumerWriterImpl) 71 require.Equal(t, 0, len(w.resetCh)) 72 73 var wg sync.WaitGroup 74 75 wg.Add(1) 76 go func() { 77 testConsumeAndAckOnConnectionListener(t, lis, opts.EncoderOptions(), opts.DecoderOptions()) 78 wg.Done() 79 }() 80 81 require.NoError(t, write(w, &testMsg)) 82 83 wg.Add(1) 84 mockRouter.EXPECT(). 85 Ack(newMetadataFromProto(testMsg.Metadata)). 86 Do(func(interface{}) { wg.Done() }). 87 Return(nil) 88 89 w.Init() 90 wg.Wait() 91 92 w.Close() 93 // Make sure the connection is closed after closing the consumer writer. 94 _, err = w.writeState.conns[0].conn.Read([]byte{}) 95 require.Error(t, err) 96 require.Contains(t, err.Error(), "closed network connection") 97 } 98 99 // TODO: tests for multiple connection writers. 
100 101 func TestConsumerWriterSignalResetConnection(t *testing.T) { 102 lis, err := net.Listen("tcp", "127.0.0.1:0") 103 require.NoError(t, err) 104 defer lis.Close() 105 106 w := newConsumerWriter(lis.Addr().String(), nil, testOptions(), testConsumerWriterMetrics()).(*consumerWriterImpl) 107 require.Equal(t, 0, len(w.resetCh)) 108 109 var called int 110 w.connectFn = func(addr string) (io.ReadWriteCloser, error) { 111 called++ 112 return uninitializedReadWriter{}, nil 113 } 114 115 w.notifyReset(nil) 116 require.Equal(t, 1, len(w.resetCh)) 117 require.True(t, w.resetTooSoon()) 118 119 now := time.Now() 120 w.nowFn = func() time.Time { return now.Add(1 * time.Hour) } 121 require.Equal(t, 1, len(w.resetCh)) 122 require.False(t, w.resetTooSoon()) 123 require.NoError(t, w.resetWithConnectFn(w.newConnectFn(connectOptions{retry: false}))) 124 require.Equal(t, 1, called) 125 require.Equal(t, 1, len(w.resetCh)) 126 127 // Reset won't do anything as it is too soon since last reset. 128 require.True(t, w.resetTooSoon()) 129 130 w.nowFn = func() time.Time { return now.Add(2 * time.Hour) } 131 require.False(t, w.resetTooSoon()) 132 require.NoError(t, w.resetWithConnectFn(w.newConnectFn(connectOptions{retry: false}))) 133 require.Equal(t, 2, called) 134 } 135 136 func TestConsumerWriterResetConnection(t *testing.T) { 137 w := newConsumerWriter("badAddress", nil, testOptions(), testConsumerWriterMetrics()).(*consumerWriterImpl) 138 require.Equal(t, 1, len(w.resetCh)) 139 err := write(w, &testMsg) 140 require.Error(t, err) 141 require.Equal(t, errInvalidConnection, err) 142 143 var called int 144 conn := new(net.TCPConn) 145 w.connectFn = func(addr string) (io.ReadWriteCloser, error) { 146 called++ 147 require.Equal(t, "badAddress", addr) 148 return conn, nil 149 } 150 w.resetWithConnectFn(w.newConnectFn(connectOptions{retry: true})) 151 require.Equal(t, 1, called) 152 } 153 154 func TestConsumerWriterRetryableConnectionBackgroundReset(t *testing.T) { 155 w := 
newConsumerWriter("badAddress", nil, testOptions(), testConsumerWriterMetrics()).(*consumerWriterImpl) 156 require.Equal(t, 1, len(w.resetCh)) 157 158 var lock sync.Mutex 159 var called int 160 conn := new(net.TCPConn) 161 w.connectFn = func(addr string) (io.ReadWriteCloser, error) { 162 lock.Lock() 163 defer lock.Unlock() 164 165 called++ 166 require.Equal(t, "badAddress", addr) 167 return conn, nil 168 } 169 170 now := time.Now() 171 w.nowFn = func() time.Time { return now.Add(1 * time.Hour) } 172 w.Init() 173 for { 174 lock.Lock() 175 c := called 176 lock.Unlock() 177 if c > 0 { 178 break 179 } 180 time.Sleep(100 * time.Millisecond) 181 } 182 w.Close() 183 } 184 185 func TestConsumerWriterWriteErrorTriggerReset(t *testing.T) { 186 defer leaktest.Check(t)() 187 188 opts := testOptions() 189 w := newConsumerWriter("badAddr", nil, opts.SetConnectionOptions( 190 opts.ConnectionOptions().SetWriteBufferSize(1000), 191 ), testConsumerWriterMetrics()).(*consumerWriterImpl) 192 <-w.resetCh 193 require.Equal(t, 0, len(w.resetCh)) 194 err := write(w, &testMsg) 195 require.Error(t, err) 196 require.Equal(t, errInvalidConnection, err) 197 require.Equal(t, 0, len(w.resetCh)) 198 w.writeState.Lock() 199 w.writeState.validConns = true 200 w.writeState.Unlock() 201 202 err = write(w, &testMsg) 203 require.NoError(t, err) 204 for { 205 // The writer will need to wait until buffered size to try the flush 206 // and then realize the connection is broken. 
207 err = write(w, &testMsg) 208 if err != nil { 209 break 210 } 211 } 212 require.Error(t, err) 213 require.Equal(t, errInvalidConnection, err) 214 require.Equal(t, 1, len(w.resetCh)) 215 } 216 217 func TestConsumerWriterFlushWriteAfterFlushErrorTriggerReset(t *testing.T) { 218 defer leaktest.Check(t)() 219 220 opts := testOptions() 221 w := newConsumerWriter("badAddr", nil, opts.SetConnectionOptions( 222 opts.ConnectionOptions().SetWriteBufferSize(1000), 223 ), testConsumerWriterMetrics()).(*consumerWriterImpl) 224 <-w.resetCh 225 require.Equal(t, 0, len(w.resetCh)) 226 err := write(w, &testMsg) 227 require.Error(t, err) 228 require.Equal(t, errInvalidConnection, err) 229 require.Equal(t, 0, len(w.resetCh)) 230 w.writeState.Lock() 231 w.writeState.validConns = true 232 w.writeState.Unlock() 233 234 // The write will be buffered in the bufio.Writer, and will 235 // not return err because it has not tried to flush yet. 236 require.NoError(t, write(w, &testMsg)) 237 238 w.writeState.Lock() 239 require.Error(t, w.writeState.conns[0].w.Flush()) 240 w.writeState.Unlock() 241 242 // Flush err will be stored in bufio.Writer, the next time 243 // Write is called, the err will be returned. 
244 err = write(w, &testMsg) 245 require.Error(t, err) 246 require.Equal(t, errInvalidConnection, err) 247 require.Equal(t, 1, len(w.resetCh)) 248 } 249 250 func TestConsumerWriterReadErrorTriggerReset(t *testing.T) { 251 defer leaktest.Check(t)() 252 253 opts := testOptions() 254 w := newConsumerWriter("badAddr", nil, opts, testConsumerWriterMetrics()).(*consumerWriterImpl) 255 <-w.resetCh 256 w.writeState.Lock() 257 w.writeState.validConns = true 258 w.writeState.Unlock() 259 require.Equal(t, 0, len(w.resetCh)) 260 err := w.readAcks(0) 261 require.Error(t, err) 262 require.Equal(t, errInvalidConnection, err) 263 require.Equal(t, 1, len(w.resetCh)) 264 w.Close() 265 } 266 267 func TestAutoReset(t *testing.T) { 268 defer leaktest.Check(t)() 269 270 ctrl := xtest.NewController(t) 271 defer ctrl.Finish() 272 273 mockRouter := NewMockackRouter(ctrl) 274 275 opts := testOptions() 276 277 w := newConsumerWriter( 278 "badAddress", 279 mockRouter, 280 opts, 281 testConsumerWriterMetrics(), 282 ).(*consumerWriterImpl) 283 require.Equal(t, 1, len(w.resetCh)) 284 require.Error(t, write(w, &testMsg)) 285 286 clientConn, serverConn := net.Pipe() 287 defer clientConn.Close() 288 defer serverConn.Close() 289 290 go func() { 291 testConsumeAndAckOnConnection(t, serverConn, opts.EncoderOptions(), opts.DecoderOptions()) 292 }() 293 294 w.connectFn = func(addr string) (io.ReadWriteCloser, error) { 295 return clientConn, nil 296 } 297 298 var wg sync.WaitGroup 299 wg.Add(1) 300 mockRouter.EXPECT(). 301 Ack(newMetadataFromProto(testMsg.Metadata)). 302 Do(func(interface{}) { wg.Done() }). 
303 Return(nil) 304 305 w.Init() 306 307 start := time.Now() 308 for time.Since(start) < 15*time.Second { 309 w.writeState.Lock() 310 validConns := w.writeState.validConns 311 w.writeState.Unlock() 312 if validConns { 313 break 314 } 315 time.Sleep(100 * time.Millisecond) 316 } 317 require.NoError(t, write(w, &testMsg)) 318 wg.Wait() 319 320 w.Close() 321 } 322 323 func TestConsumerWriterClose(t *testing.T) { 324 lis, err := net.Listen("tcp", "127.0.0.1:0") 325 require.NoError(t, err) 326 defer lis.Close() 327 328 w := newConsumerWriter(lis.Addr().String(), nil, nil, testConsumerWriterMetrics()).(*consumerWriterImpl) 329 require.Equal(t, 0, len(w.resetCh)) 330 w.Close() 331 // Safe to close again. 332 w.Close() 333 _, ok := <-w.doneCh 334 require.False(t, ok) 335 } 336 337 func TestConsumerWriterCloseWhileDecoding(t *testing.T) { 338 defer leaktest.Check(t)() 339 340 lis, err := net.Listen("tcp", "127.0.0.1:0") 341 require.NoError(t, err) 342 defer lis.Close() 343 344 opts := testOptions() 345 346 w := newConsumerWriter(lis.Addr().String(), nil, opts, testConsumerWriterMetrics()).(*consumerWriterImpl) 347 348 var wg sync.WaitGroup 349 wg.Add(1) 350 go func() { 351 wg.Done() 352 require.Error(t, w.writeState.conns[0].decoder.Decode(&testMsg)) 353 }() 354 wg.Wait() 355 time.Sleep(time.Second) 356 w.Close() 357 } 358 359 func TestConsumerWriterResetWhileDecoding(t *testing.T) { 360 defer leaktest.Check(t)() 361 362 lis, err := net.Listen("tcp", "127.0.0.1:0") 363 require.NoError(t, err) 364 defer lis.Close() 365 366 opts := testOptions() 367 368 w := newConsumerWriter(lis.Addr().String(), nil, opts, testConsumerWriterMetrics()).(*consumerWriterImpl) 369 370 var wg sync.WaitGroup 371 wg.Add(1) 372 go func() { 373 wg.Done() 374 375 w.writeState.Lock() 376 conn := w.writeState.conns[0] 377 w.writeState.Unlock() 378 379 require.Error(t, conn.decoder.Decode(&testMsg)) 380 }() 381 wg.Wait() 382 time.Sleep(time.Second) 383 w.reset(resetOptions{ 384 connections: 
[]io.ReadWriteCloser{new(net.TCPConn)}, 385 at: w.nowFn(), 386 validConns: true, 387 }) 388 } 389 390 // Interface solely for mocking. 391 //nolint:deadcode,unused 392 type contextDialer interface { 393 DialContext(ctx context.Context, network string, addr string) (net.Conn, error) 394 } 395 396 type keepAlivableConn struct { 397 net.Conn 398 keepAlivable 399 } 400 401 func TestConsumerWriter_connectNoRetry(t *testing.T) { 402 type testDeps struct { 403 Ctrl *gomock.Controller 404 MockDialer *MockcontextDialer 405 Listener net.Listener 406 } 407 408 newTestWriter := func(deps testDeps, opts Options) *consumerWriterImpl { 409 return newConsumerWriter( 410 deps.Listener.Addr().String(), 411 nil, 412 opts, 413 testConsumerWriterMetrics(), 414 ).(*consumerWriterImpl) 415 } 416 417 mustClose := func(t *testing.T, c io.ReadWriteCloser) { 418 require.NoError(t, c.Close()) 419 } 420 421 setup := func(t *testing.T) testDeps { 422 ctrl := gomock.NewController(t) 423 424 lis, err := net.Listen("tcp", "127.0.0.1:0") 425 require.NoError(t, err) 426 t.Cleanup(func() { 427 require.NoError(t, lis.Close()) 428 }) 429 430 return testDeps{ 431 Ctrl: ctrl, 432 Listener: lis, 433 MockDialer: NewMockcontextDialer(ctrl), 434 } 435 } 436 type dialArgs struct { 437 Ctx context.Context 438 Network string 439 Addr string 440 } 441 442 // Other tests in this file cover the case where dialer isn't set explicitly (default). 
443 t.Run("uses net.Dialer where dialer is unset", func(t *testing.T) { 444 defer leaktest.Check(t)() 445 tdeps := setup(t) 446 opts := testOptions() 447 w := newTestWriter(tdeps, opts.SetConnectionOptions(opts.ConnectionOptions().SetContextDialer(nil))) 448 conn, err := w.connectNoRetryWithTimeout(tdeps.Listener.Addr().String()) 449 require.NoError(t, err) 450 defer mustClose(t, conn) 451 452 _, err = conn.Write([]byte("test")) 453 require.NoError(t, err) 454 }) 455 t.Run("uses dialer and respects timeout", func(t *testing.T) { 456 defer leaktest.Check(t)() 457 458 tdeps := setup(t) 459 var capturedArgs dialArgs 460 tdeps.MockDialer.EXPECT().DialContext(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( 461 func(ctx context.Context, network string, addr string) (net.Conn, error) { 462 capturedArgs.Ctx = ctx 463 return (&net.Dialer{}).DialContext(ctx, network, addr) 464 }, 465 ).MinTimes(1) 466 467 const testDialTimeout = 45 * time.Second 468 opts := testOptions() 469 opts = opts.SetConnectionOptions(opts.ConnectionOptions(). 470 SetContextDialer(tdeps.MockDialer.DialContext). 471 SetDialTimeout(testDialTimeout), 472 ) 473 474 start := time.Now() 475 w := newTestWriter(tdeps, opts) 476 conn, err := w.connectNoRetry(tdeps.Listener.Addr().String()) 477 478 require.NoError(t, err) 479 defer mustClose(t, conn) 480 481 deadline, ok := capturedArgs.Ctx.Deadline() 482 require.True(t, ok) 483 // Start is taken *before* we try to connect, so the deadline must = start + <some_time> + testDialTimeout. 484 // Therefore deadline - start >= testDialTimeout. 485 assert.True(t, deadline.Sub(start) >= testDialTimeout) 486 }) 487 488 t.Run("sets KeepAlive where possible", func(t *testing.T) { 489 tdeps := setup(t) 490 // Deep mocking here is solely because it's not easy to get the keep alive off an actual TCP connection 491 // (have to drop down to the syscall layer). 
492 const testKeepAlive = 56 * time.Minute 493 mockConn := NewMockkeepAlivable(tdeps.Ctrl) 494 mockConn.EXPECT().SetKeepAlivePeriod(testKeepAlive).Times(2) 495 mockConn.EXPECT().SetKeepAlive(true).Times(2) 496 497 tdeps.MockDialer.EXPECT().DialContext(gomock.Any(), gomock.Any(), gomock.Any()).Return(keepAlivableConn{ 498 keepAlivable: mockConn, 499 }, nil).Times(2) 500 501 opts := testOptions() 502 opts = opts.SetConnectionOptions(opts.ConnectionOptions(). 503 SetKeepAlivePeriod(testKeepAlive). 504 SetContextDialer(tdeps.MockDialer.DialContext), 505 ) 506 w := newTestWriter(tdeps, opts) 507 _, err := w.connectNoRetryWithTimeout("foobar") 508 require.NoError(t, err) 509 }) 510 511 t.Run("handles non TCP connections gracefully", func(t *testing.T) { 512 tdeps := setup(t) 513 tdeps.MockDialer.EXPECT().DialContext(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( 514 func(ctx context.Context, network string, addr string) (net.Conn, error) { 515 srv, client := net.Pipe() 516 require.NoError(t, srv.Close()) 517 return client, nil 518 }, 519 ).MinTimes(1) 520 521 defer leaktest.Check(t)() 522 523 opts := testOptions() 524 opts = opts.SetConnectionOptions(opts.ConnectionOptions(). 525 SetContextDialer(tdeps.MockDialer.DialContext), 526 ) 527 528 w := newConsumerWriter( 529 "foobar", 530 nil, 531 opts, 532 testConsumerWriterMetrics(), 533 ).(*consumerWriterImpl) 534 conn, err := w.connectNoRetryWithTimeout("foobar") 535 require.NoError(t, err) 536 defer mustClose(t, conn) 537 538 _, isTCPConn := conn.Conn.(*net.TCPConn) 539 assert.False(t, isTCPConn) 540 }) 541 } 542 543 func testOptions() Options { 544 return NewOptions(). 545 SetTopicName("topicName"). 546 SetTopicWatchInitTimeout(100 * time.Millisecond). 547 SetPlacementWatchInitTimeout(100 * time.Millisecond). 548 SetMessagePoolOptions(pool.NewObjectPoolOptions().SetSize(1)). 549 SetMessageQueueNewWritesScanInterval(100 * time.Millisecond). 550 SetMessageQueueFullScanInterval(200 * time.Millisecond). 
551 SetMessageRetryNanosFn( 552 NextRetryNanosFn( 553 retry.NewOptions(). 554 SetInitialBackoff(100 * time.Millisecond). 555 SetMaxBackoff(500 * time.Millisecond), 556 ), 557 ). 558 SetAckErrorRetryOptions(retry.NewOptions().SetInitialBackoff(200 * time.Millisecond).SetMaxBackoff(time.Second)). 559 SetConnectionOptions(testConnectionOptions()) 560 } 561 562 func testConnectionOptions() ConnectionOptions { 563 return NewConnectionOptions(). 564 SetNumConnections(1). 565 SetRetryOptions(retry.NewOptions().SetInitialBackoff(200 * time.Millisecond).SetMaxBackoff(time.Second)). 566 SetFlushInterval(100 * time.Millisecond). 567 SetResetDelay(100 * time.Millisecond) 568 } 569 570 func testConsumeAndAckOnConnection( 571 t *testing.T, 572 conn net.Conn, 573 encOpts proto.Options, 574 decOpts proto.Options, 575 ) { 576 serverEncoder := proto.NewEncoder(encOpts) 577 serverDecoder := proto.NewDecoder(conn, decOpts, 10) 578 var msg msgpb.Message 579 assert.NoError(t, serverDecoder.Decode(&msg)) 580 581 err := serverEncoder.Encode(&msgpb.Ack{ 582 Metadata: []msgpb.Metadata{ 583 msg.Metadata, 584 }, 585 }) 586 assert.NoError(t, err) 587 _, err = conn.Write(serverEncoder.Bytes()) 588 assert.NoError(t, err) 589 } 590 591 func testConsumeAndAckOnConnectionListener( 592 t *testing.T, 593 lis net.Listener, 594 encOpts proto.Options, 595 decOpts proto.Options, 596 ) { 597 conn, err := lis.Accept() 598 require.NoError(t, err) 599 defer conn.Close() 600 601 testConsumeAndAckOnConnection(t, conn, encOpts, decOpts) 602 } 603 604 func testConsumerWriterMetrics() consumerWriterMetrics { 605 return newConsumerWriterMetrics(tally.NoopScope) 606 } 607 608 func write(w consumerWriter, m proto.Marshaler) error { 609 err := testEncoder.Encode(m) 610 if err != nil { 611 return err 612 } 613 return w.Write(0, testEncoder.Bytes()) 614 }