// Source: github.com/m3db/m3@v1.5.1-0.20231129193456-75a402aa583b/src/msg/producer/writer/consumer_writer_test.go

// Copyright (c) 2018 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
20 21 package writer 22 23 import ( 24 "context" 25 "io" 26 "net" 27 "sync" 28 "testing" 29 "time" 30 31 "github.com/fortytw2/leaktest" 32 "github.com/golang/mock/gomock" 33 "github.com/stretchr/testify/assert" 34 "github.com/stretchr/testify/require" 35 "github.com/uber-go/tally" 36 37 "github.com/m3db/m3/src/msg/generated/proto/msgpb" 38 "github.com/m3db/m3/src/msg/protocol/proto" 39 "github.com/m3db/m3/src/x/retry" 40 xtest "github.com/m3db/m3/src/x/test" 41 ) 42 43 var ( 44 testMsg = msgpb.Message{ 45 Metadata: msgpb.Metadata{ 46 Shard: 100, 47 Id: 200, 48 }, 49 Value: []byte("foooooooo"), 50 } 51 52 testEncoder = proto.NewEncoder(nil) 53 ) 54 55 func TestNewConsumerWriter(t *testing.T) { 56 defer leaktest.Check(t)() 57 58 lis, err := net.Listen("tcp", "127.0.0.1:0") 59 require.NoError(t, err) 60 defer lis.Close() 61 62 ctrl := xtest.NewController(t) 63 defer ctrl.Finish() 64 65 mockRouter := NewMockackRouter(ctrl) 66 67 opts := testOptions() 68 69 w := newConsumerWriter(lis.Addr().String(), mockRouter, opts, testConsumerWriterMetrics()).(*consumerWriterImpl) 70 require.Equal(t, 0, len(w.resetCh)) 71 72 var wg sync.WaitGroup 73 74 wg.Add(1) 75 go func() { 76 testConsumeAndAckOnConnectionListener(t, lis, opts.EncoderOptions(), opts.DecoderOptions()) 77 wg.Done() 78 }() 79 80 require.NoError(t, write(w, &testMsg)) 81 82 wg.Add(1) 83 mockRouter.EXPECT(). 84 Ack(newMetadataFromProto(testMsg.Metadata)). 85 Do(func(interface{}) { wg.Done() }). 86 Return(nil) 87 88 w.Init() 89 wg.Wait() 90 91 w.Close() 92 // Make sure the connection is closed after closing the consumer writer. 93 _, err = w.writeState.conns[0].conn.Read([]byte{}) 94 require.Error(t, err) 95 require.Contains(t, err.Error(), "closed network connection") 96 } 97 98 // TODO: tests for multiple connection writers. 
99 100 func TestConsumerWriterSignalResetConnection(t *testing.T) { 101 lis, err := net.Listen("tcp", "127.0.0.1:0") 102 require.NoError(t, err) 103 defer lis.Close() 104 105 w := newConsumerWriter(lis.Addr().String(), nil, testOptions(), testConsumerWriterMetrics()).(*consumerWriterImpl) 106 require.Equal(t, 0, len(w.resetCh)) 107 108 var called int 109 w.connectFn = func(addr string) (io.ReadWriteCloser, error) { 110 called++ 111 return uninitializedReadWriter{}, nil 112 } 113 114 w.notifyReset(nil) 115 require.Equal(t, 1, len(w.resetCh)) 116 require.True(t, w.resetTooSoon()) 117 118 now := time.Now() 119 w.nowFn = func() time.Time { return now.Add(1 * time.Hour) } 120 require.Equal(t, 1, len(w.resetCh)) 121 require.False(t, w.resetTooSoon()) 122 require.NoError(t, w.resetWithConnectFn(w.newConnectFn(connectOptions{retry: false}))) 123 require.Equal(t, 1, called) 124 require.Equal(t, 1, len(w.resetCh)) 125 126 // Reset won't do anything as it is too soon since last reset. 127 require.True(t, w.resetTooSoon()) 128 129 w.nowFn = func() time.Time { return now.Add(2 * time.Hour) } 130 require.False(t, w.resetTooSoon()) 131 require.NoError(t, w.resetWithConnectFn(w.newConnectFn(connectOptions{retry: false}))) 132 require.Equal(t, 2, called) 133 } 134 135 func TestConsumerWriterResetConnection(t *testing.T) { 136 w := newConsumerWriter("badAddress", nil, testOptions(), testConsumerWriterMetrics()).(*consumerWriterImpl) 137 require.Equal(t, 1, len(w.resetCh)) 138 err := write(w, &testMsg) 139 require.Error(t, err) 140 require.Equal(t, errInvalidConnection, err) 141 142 var called int 143 conn := new(net.TCPConn) 144 w.connectFn = func(addr string) (io.ReadWriteCloser, error) { 145 called++ 146 require.Equal(t, "badAddress", addr) 147 return conn, nil 148 } 149 w.resetWithConnectFn(w.newConnectFn(connectOptions{retry: true})) 150 require.Equal(t, 1, called) 151 } 152 153 func TestConsumerWriterRetryableConnectionBackgroundReset(t *testing.T) { 154 w := 
newConsumerWriter("badAddress", nil, testOptions(), testConsumerWriterMetrics()).(*consumerWriterImpl) 155 require.Equal(t, 1, len(w.resetCh)) 156 157 var lock sync.Mutex 158 var called int 159 conn := new(net.TCPConn) 160 w.connectFn = func(addr string) (io.ReadWriteCloser, error) { 161 lock.Lock() 162 defer lock.Unlock() 163 164 called++ 165 require.Equal(t, "badAddress", addr) 166 return conn, nil 167 } 168 169 now := time.Now() 170 w.nowFn = func() time.Time { return now.Add(1 * time.Hour) } 171 w.Init() 172 for { 173 lock.Lock() 174 c := called 175 lock.Unlock() 176 if c > 0 { 177 break 178 } 179 time.Sleep(100 * time.Millisecond) 180 } 181 w.Close() 182 } 183 184 func TestConsumerWriterWriteErrorTriggerReset(t *testing.T) { 185 defer leaktest.Check(t)() 186 187 opts := testOptions() 188 w := newConsumerWriter("badAddr", nil, opts.SetConnectionOptions( 189 opts.ConnectionOptions().SetWriteBufferSize(1000), 190 ), testConsumerWriterMetrics()).(*consumerWriterImpl) 191 <-w.resetCh 192 require.Equal(t, 0, len(w.resetCh)) 193 err := write(w, &testMsg) 194 require.Error(t, err) 195 require.Equal(t, errInvalidConnection, err) 196 require.Equal(t, 0, len(w.resetCh)) 197 w.writeState.Lock() 198 w.writeState.validConns = true 199 w.writeState.Unlock() 200 201 err = write(w, &testMsg) 202 require.NoError(t, err) 203 for { 204 // The writer will need to wait until buffered size to try the flush 205 // and then realize the connection is broken. 
206 err = write(w, &testMsg) 207 if err != nil { 208 break 209 } 210 } 211 require.Error(t, err) 212 require.Equal(t, errInvalidConnection, err) 213 require.Equal(t, 1, len(w.resetCh)) 214 } 215 216 func TestConsumerWriterFlushWriteAfterFlushErrorTriggerReset(t *testing.T) { 217 defer leaktest.Check(t)() 218 219 opts := testOptions() 220 w := newConsumerWriter("badAddr", nil, opts.SetConnectionOptions( 221 opts.ConnectionOptions().SetWriteBufferSize(1000), 222 ), testConsumerWriterMetrics()).(*consumerWriterImpl) 223 <-w.resetCh 224 require.Equal(t, 0, len(w.resetCh)) 225 err := write(w, &testMsg) 226 require.Error(t, err) 227 require.Equal(t, errInvalidConnection, err) 228 require.Equal(t, 0, len(w.resetCh)) 229 w.writeState.Lock() 230 w.writeState.validConns = true 231 w.writeState.Unlock() 232 233 // The write will be buffered in the bufio.Writer, and will 234 // not return err because it has not tried to flush yet. 235 require.NoError(t, write(w, &testMsg)) 236 237 w.writeState.Lock() 238 require.Error(t, w.writeState.conns[0].w.Flush()) 239 w.writeState.Unlock() 240 241 // Flush err will be stored in bufio.Writer, the next time 242 // Write is called, the err will be returned. 
243 err = write(w, &testMsg) 244 require.Error(t, err) 245 require.Equal(t, errInvalidConnection, err) 246 require.Equal(t, 1, len(w.resetCh)) 247 } 248 249 func TestConsumerWriterReadErrorTriggerReset(t *testing.T) { 250 defer leaktest.Check(t)() 251 252 opts := testOptions() 253 w := newConsumerWriter("badAddr", nil, opts, testConsumerWriterMetrics()).(*consumerWriterImpl) 254 <-w.resetCh 255 w.writeState.Lock() 256 w.writeState.validConns = true 257 w.writeState.Unlock() 258 require.Equal(t, 0, len(w.resetCh)) 259 err := w.readAcks(0) 260 require.Error(t, err) 261 require.Equal(t, errInvalidConnection, err) 262 require.Equal(t, 1, len(w.resetCh)) 263 w.Close() 264 } 265 266 func TestAutoReset(t *testing.T) { 267 defer leaktest.Check(t)() 268 269 ctrl := xtest.NewController(t) 270 defer ctrl.Finish() 271 272 mockRouter := NewMockackRouter(ctrl) 273 274 opts := testOptions() 275 276 w := newConsumerWriter( 277 "badAddress", 278 mockRouter, 279 opts, 280 testConsumerWriterMetrics(), 281 ).(*consumerWriterImpl) 282 require.Equal(t, 1, len(w.resetCh)) 283 require.Error(t, write(w, &testMsg)) 284 285 clientConn, serverConn := net.Pipe() 286 defer clientConn.Close() 287 defer serverConn.Close() 288 289 go func() { 290 testConsumeAndAckOnConnection(t, serverConn, opts.EncoderOptions(), opts.DecoderOptions()) 291 }() 292 293 w.connectFn = func(addr string) (io.ReadWriteCloser, error) { 294 return clientConn, nil 295 } 296 297 var wg sync.WaitGroup 298 wg.Add(1) 299 mockRouter.EXPECT(). 300 Ack(newMetadataFromProto(testMsg.Metadata)). 301 Do(func(interface{}) { wg.Done() }). 
302 Return(nil) 303 304 w.Init() 305 306 start := time.Now() 307 for time.Since(start) < 15*time.Second { 308 w.writeState.Lock() 309 validConns := w.writeState.validConns 310 w.writeState.Unlock() 311 if validConns { 312 break 313 } 314 time.Sleep(100 * time.Millisecond) 315 } 316 require.NoError(t, write(w, &testMsg)) 317 wg.Wait() 318 319 w.Close() 320 } 321 322 func TestConsumerWriterClose(t *testing.T) { 323 lis, err := net.Listen("tcp", "127.0.0.1:0") 324 require.NoError(t, err) 325 defer lis.Close() 326 327 w := newConsumerWriter(lis.Addr().String(), nil, nil, testConsumerWriterMetrics()).(*consumerWriterImpl) 328 require.Equal(t, 0, len(w.resetCh)) 329 w.Close() 330 // Safe to close again. 331 w.Close() 332 _, ok := <-w.doneCh 333 require.False(t, ok) 334 } 335 336 func TestConsumerWriterCloseWhileDecoding(t *testing.T) { 337 defer leaktest.Check(t)() 338 339 lis, err := net.Listen("tcp", "127.0.0.1:0") 340 require.NoError(t, err) 341 defer lis.Close() 342 343 opts := testOptions() 344 345 w := newConsumerWriter(lis.Addr().String(), nil, opts, testConsumerWriterMetrics()).(*consumerWriterImpl) 346 347 var wg sync.WaitGroup 348 wg.Add(1) 349 go func() { 350 wg.Done() 351 require.Error(t, w.writeState.conns[0].decoder.Decode(&testMsg)) 352 }() 353 wg.Wait() 354 time.Sleep(time.Second) 355 w.Close() 356 } 357 358 func TestConsumerWriterResetWhileDecoding(t *testing.T) { 359 defer leaktest.Check(t)() 360 361 lis, err := net.Listen("tcp", "127.0.0.1:0") 362 require.NoError(t, err) 363 defer lis.Close() 364 365 opts := testOptions() 366 367 w := newConsumerWriter(lis.Addr().String(), nil, opts, testConsumerWriterMetrics()).(*consumerWriterImpl) 368 369 var wg sync.WaitGroup 370 wg.Add(1) 371 go func() { 372 wg.Done() 373 374 w.writeState.Lock() 375 conn := w.writeState.conns[0] 376 w.writeState.Unlock() 377 378 require.Error(t, conn.decoder.Decode(&testMsg)) 379 }() 380 wg.Wait() 381 time.Sleep(time.Second) 382 w.reset(resetOptions{ 383 connections: 
[]io.ReadWriteCloser{new(net.TCPConn)}, 384 at: w.nowFn(), 385 validConns: true, 386 }) 387 } 388 389 // Interface solely for mocking. 390 //nolint:deadcode,unused 391 type contextDialer interface { 392 DialContext(ctx context.Context, network string, addr string) (net.Conn, error) 393 } 394 395 type keepAlivableConn struct { 396 net.Conn 397 keepAlivable 398 } 399 400 func TestConsumerWriter_connectNoRetry(t *testing.T) { 401 type testDeps struct { 402 Ctrl *gomock.Controller 403 MockDialer *MockcontextDialer 404 Listener net.Listener 405 } 406 407 newTestWriter := func(deps testDeps, opts Options) *consumerWriterImpl { 408 return newConsumerWriter( 409 deps.Listener.Addr().String(), 410 nil, 411 opts, 412 testConsumerWriterMetrics(), 413 ).(*consumerWriterImpl) 414 } 415 416 mustClose := func(t *testing.T, c io.ReadWriteCloser) { 417 require.NoError(t, c.Close()) 418 } 419 420 setup := func(t *testing.T) testDeps { 421 ctrl := gomock.NewController(t) 422 423 lis, err := net.Listen("tcp", "127.0.0.1:0") 424 require.NoError(t, err) 425 t.Cleanup(func() { 426 require.NoError(t, lis.Close()) 427 }) 428 429 return testDeps{ 430 Ctrl: ctrl, 431 Listener: lis, 432 MockDialer: NewMockcontextDialer(ctrl), 433 } 434 } 435 type dialArgs struct { 436 Ctx context.Context 437 Network string 438 Addr string 439 } 440 441 // Other tests in this file cover the case where dialer isn't set explicitly (default). 
442 t.Run("uses net.Dialer where dialer is unset", func(t *testing.T) { 443 defer leaktest.Check(t)() 444 tdeps := setup(t) 445 opts := testOptions() 446 w := newTestWriter(tdeps, opts.SetConnectionOptions(opts.ConnectionOptions().SetContextDialer(nil))) 447 conn, err := w.connectNoRetryWithTimeout(tdeps.Listener.Addr().String()) 448 require.NoError(t, err) 449 defer mustClose(t, conn) 450 451 _, err = conn.Write([]byte("test")) 452 require.NoError(t, err) 453 }) 454 t.Run("uses dialer and respects timeout", func(t *testing.T) { 455 defer leaktest.Check(t)() 456 457 tdeps := setup(t) 458 var capturedArgs dialArgs 459 tdeps.MockDialer.EXPECT().DialContext(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( 460 func(ctx context.Context, network string, addr string) (net.Conn, error) { 461 capturedArgs.Ctx = ctx 462 return (&net.Dialer{}).DialContext(ctx, network, addr) 463 }, 464 ).MinTimes(1) 465 466 const testDialTimeout = 45 * time.Second 467 opts := testOptions() 468 opts = opts.SetConnectionOptions(opts.ConnectionOptions(). 469 SetContextDialer(tdeps.MockDialer.DialContext). 470 SetDialTimeout(testDialTimeout), 471 ) 472 473 start := time.Now() 474 w := newTestWriter(tdeps, opts) 475 conn, err := w.connectNoRetry(tdeps.Listener.Addr().String()) 476 477 require.NoError(t, err) 478 defer mustClose(t, conn) 479 480 deadline, ok := capturedArgs.Ctx.Deadline() 481 require.True(t, ok) 482 // Start is taken *before* we try to connect, so the deadline must = start + <some_time> + testDialTimeout. 483 // Therefore deadline - start >= testDialTimeout. 484 assert.True(t, deadline.Sub(start) >= testDialTimeout) 485 }) 486 487 t.Run("sets KeepAlive where possible", func(t *testing.T) { 488 tdeps := setup(t) 489 // Deep mocking here is solely because it's not easy to get the keep alive off an actual TCP connection 490 // (have to drop down to the syscall layer). 
491 const testKeepAlive = 56 * time.Minute 492 mockConn := NewMockkeepAlivable(tdeps.Ctrl) 493 mockConn.EXPECT().SetKeepAlivePeriod(testKeepAlive).Times(2) 494 mockConn.EXPECT().SetKeepAlive(true).Times(2) 495 496 tdeps.MockDialer.EXPECT().DialContext(gomock.Any(), gomock.Any(), gomock.Any()).Return(keepAlivableConn{ 497 keepAlivable: mockConn, 498 }, nil).Times(2) 499 500 opts := testOptions() 501 opts = opts.SetConnectionOptions(opts.ConnectionOptions(). 502 SetKeepAlivePeriod(testKeepAlive). 503 SetContextDialer(tdeps.MockDialer.DialContext), 504 ) 505 w := newTestWriter(tdeps, opts) 506 _, err := w.connectNoRetryWithTimeout("foobar") 507 require.NoError(t, err) 508 }) 509 510 t.Run("handles non TCP connections gracefully", func(t *testing.T) { 511 tdeps := setup(t) 512 tdeps.MockDialer.EXPECT().DialContext(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( 513 func(ctx context.Context, network string, addr string) (net.Conn, error) { 514 srv, client := net.Pipe() 515 require.NoError(t, srv.Close()) 516 return client, nil 517 }, 518 ).MinTimes(1) 519 520 defer leaktest.Check(t)() 521 522 opts := testOptions() 523 opts = opts.SetConnectionOptions(opts.ConnectionOptions(). 524 SetContextDialer(tdeps.MockDialer.DialContext), 525 ) 526 527 w := newConsumerWriter( 528 "foobar", 529 nil, 530 opts, 531 testConsumerWriterMetrics(), 532 ).(*consumerWriterImpl) 533 conn, err := w.connectNoRetryWithTimeout("foobar") 534 require.NoError(t, err) 535 defer mustClose(t, conn) 536 537 _, isTCPConn := conn.Conn.(*net.TCPConn) 538 assert.False(t, isTCPConn) 539 }) 540 } 541 542 func testOptions() Options { 543 return NewOptions(). 544 SetTopicName("topicName"). 545 SetTopicWatchInitTimeout(100 * time.Millisecond). 546 SetPlacementWatchInitTimeout(100 * time.Millisecond). 547 SetMessageQueueNewWritesScanInterval(100 * time.Millisecond). 548 SetMessageQueueFullScanInterval(200 * time.Millisecond). 549 SetMessageRetryNanosFn( 550 NextRetryNanosFn( 551 retry.NewOptions(). 
552 SetInitialBackoff(100 * time.Millisecond). 553 SetMaxBackoff(500 * time.Millisecond), 554 ), 555 ). 556 SetAckErrorRetryOptions(retry.NewOptions().SetInitialBackoff(200 * time.Millisecond).SetMaxBackoff(time.Second)). 557 SetConnectionOptions(testConnectionOptions()) 558 } 559 560 func testConnectionOptions() ConnectionOptions { 561 return NewConnectionOptions(). 562 SetNumConnections(1). 563 SetRetryOptions(retry.NewOptions().SetInitialBackoff(200 * time.Millisecond).SetMaxBackoff(time.Second)). 564 SetFlushInterval(100 * time.Millisecond). 565 SetResetDelay(100 * time.Millisecond) 566 } 567 568 func testConsumeAndAckOnConnection( 569 t *testing.T, 570 conn net.Conn, 571 encOpts proto.Options, 572 decOpts proto.Options, 573 ) { 574 serverEncoder := proto.NewEncoder(encOpts) 575 serverDecoder := proto.NewDecoder(conn, decOpts, 10) 576 var msg msgpb.Message 577 assert.NoError(t, serverDecoder.Decode(&msg)) 578 579 err := serverEncoder.Encode(&msgpb.Ack{ 580 Metadata: []msgpb.Metadata{ 581 msg.Metadata, 582 }, 583 }) 584 assert.NoError(t, err) 585 _, err = conn.Write(serverEncoder.Bytes()) 586 assert.NoError(t, err) 587 } 588 589 func testConsumeAndAckOnConnectionListener( 590 t *testing.T, 591 lis net.Listener, 592 encOpts proto.Options, 593 decOpts proto.Options, 594 ) { 595 conn, err := lis.Accept() 596 require.NoError(t, err) 597 defer conn.Close() 598 599 testConsumeAndAckOnConnection(t, conn, encOpts, decOpts) 600 } 601 602 func testConsumerWriterMetrics() consumerWriterMetrics { 603 return newConsumerWriterMetrics(tally.NoopScope) 604 } 605 606 func write(w consumerWriter, m proto.Marshaler) error { 607 err := testEncoder.Encode(m) 608 if err != nil { 609 return err 610 } 611 return w.Write(0, testEncoder.Bytes()) 612 }