github.com/m3db/m3@v1.5.1-0.20231129193456-75a402aa583b/src/aggregator/client/conn_test.go (about)

     1  // Copyright (c) 2018 Uber Technologies, Inc.
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    19  // THE SOFTWARE.
    20  
    21  package client
    22  
    23  import (
    24  	"context"
    25  	"errors"
    26  	"fmt"
    27  	"math"
    28  	"net"
    29  	"sync"
    30  	"testing"
    31  	"time"
    32  
    33  	"github.com/m3db/m3/src/x/clock"
    34  
    35  	"github.com/golang/mock/gomock"
    36  	"github.com/leanovate/gopter"
    37  	"github.com/leanovate/gopter/gen"
    38  	"github.com/leanovate/gopter/prop"
    39  	"github.com/stretchr/testify/assert"
    40  	"github.com/stretchr/testify/require"
    41  )
    42  
    43  const (
    44  	testFakeServerAddr        = "nonexistent"
    45  	testLocalServerAddr       = "127.0.0.1:0"
    46  	testRandomSeeed           = 831992
    47  	testMinSuccessfulTests    = 1000
    48  	testReconnectThreshold    = 1024
    49  	testMaxReconnectThreshold = 8096
    50  )
    51  
    52  var (
    53  	errTestConnect = errors.New("connect error")
    54  	errTestWrite   = errors.New("write error")
    55  )
    56  
    57  func TestConnectionDontReconnectProperties(t *testing.T) {
    58  	props := testConnectionProperties()
    59  	props.Property(
    60  		`When the number of failures is less than or equal to the threshold and the time since last `+
    61  			`connection is less than the maximum duration writes should:
    62  	  - not attempt to reconnect
    63  	  - increment the number of failures`,
    64  		prop.ForAll(
    65  			func(numFailures int32) (bool, error) {
    66  				conn := newConnection(testFakeServerAddr,
    67  					testConnectionOptions().
    68  						SetMaxReconnectDuration(time.Duration(math.MaxInt64)),
    69  				)
    70  				conn.connectWithLockFn = func() error { return errTestConnect }
    71  				conn.numFailures = int(numFailures)
    72  				conn.threshold = testReconnectThreshold
    73  
    74  				if err := conn.Write(nil); err != errNoActiveConnection {
    75  					return false, fmt.Errorf("unexpected error: %v", err)
    76  				}
    77  
    78  				expected := int(numFailures + 1)
    79  				if conn.numFailures != expected {
    80  					return false, fmt.Errorf(
    81  						"expected the number of failures to be: %v, but found: %v", expected, conn.numFailures,
    82  					)
    83  				}
    84  
    85  				return true, nil
    86  			},
    87  			gen.Int32Range(0, testReconnectThreshold),
    88  		))
    89  
    90  	props.TestingRun(t)
    91  }
    92  
    93  func TestConnectionNumFailuresThresholdReconnectProperty(t *testing.T) {
    94  	props := testConnectionProperties()
    95  	props.Property(
    96  		"When number of failures is greater than the threshold, it is multiplied",
    97  		prop.ForAll(
    98  			func(threshold int32) (bool, error) {
    99  				conn := newConnection(testFakeServerAddr, testConnectionOptions())
   100  				conn.connectWithLockFn = func() error { return errTestConnect }
   101  				conn.threshold = int(threshold)
   102  				conn.multiplier = 2
   103  				conn.numFailures = conn.threshold + 1
   104  				conn.maxThreshold = testMaxReconnectThreshold
   105  
   106  				expectedNewThreshold := conn.threshold * conn.multiplier
   107  				if expectedNewThreshold > conn.maxThreshold {
   108  					expectedNewThreshold = conn.maxThreshold
   109  				}
   110  				if err := conn.Write(nil); !errors.Is(err, errTestConnect) {
   111  					return false, fmt.Errorf("unexpected error: %w", err)
   112  				}
   113  
   114  				require.Equal(t, expectedNewThreshold, conn.threshold)
   115  				return true, nil
   116  			},
   117  			gen.Int32Range(1, testMaxReconnectThreshold),
   118  		))
   119  	props.Property(
   120  		"When the number of failures is greater than the threshold writes should attempt to reconnect",
   121  		prop.ForAll(
   122  			func(threshold int32) (bool, error) {
   123  				conn := newConnection(testFakeServerAddr, testConnectionOptions())
   124  				conn.connectWithLockFn = func() error { return errTestConnect }
   125  				conn.threshold = int(threshold)
   126  				conn.numFailures = conn.threshold + 1
   127  				conn.maxThreshold = 2 * conn.numFailures
   128  
   129  				if err := conn.Write(nil); !errors.Is(err, errTestConnect) {
   130  					return false, fmt.Errorf("unexpected error: %w", err)
   131  				}
   132  				return true, nil
   133  			},
   134  			gen.Int32Range(1, testMaxReconnectThreshold),
   135  		))
   136  	props.Property(
   137  		"When the number of failures is greater than the max threshold writes must not attempt to reconnect",
   138  		prop.ForAll(
   139  			func(threshold int32) (bool, error) {
   140  				conn := newConnection(testFakeServerAddr, testConnectionOptions())
   141  				conn.connectWithLockFn = func() error { return errTestConnect }
   142  				// Exhausted max threshold
   143  				conn.threshold = int(threshold)
   144  				conn.maxThreshold = conn.threshold
   145  				conn.maxDuration = math.MaxInt64
   146  				conn.numFailures = conn.maxThreshold + 1
   147  
   148  				if err := conn.Write(nil); !errors.Is(err, errNoActiveConnection) {
   149  					return false, fmt.Errorf("unexpected error: %w", err)
   150  				}
   151  				return true, nil
   152  			},
   153  			gen.Int32Range(1, testMaxReconnectThreshold),
   154  		))
   155  	props.Property(
   156  		`When the number of failures is greater than the max threshold
   157  		 but time since last connection attempt is greater than the maximum duration
   158  		 then writes should attempt to reconnect`,
   159  		prop.ForAll(
   160  			func(delay int64) (bool, error) {
   161  				conn := newConnection(testFakeServerAddr, testConnectionOptions())
   162  				conn.connectWithLockFn = func() error { return errTestConnect }
   163  				// Exhausted max threshold
   164  				conn.threshold = 1
   165  				conn.maxThreshold = conn.threshold
   166  				conn.numFailures = conn.maxThreshold + 1
   167  
   168  				now := time.Now()
   169  				conn.nowFn = func() time.Time { return now }
   170  				conn.lastConnectAttemptNanos = now.UnixNano() - delay
   171  				conn.maxDuration = time.Duration(delay)
   172  
   173  				if err := conn.Write(nil); !errors.Is(err, errTestConnect) {
   174  					return false, fmt.Errorf("unexpected error: %w", err)
   175  				}
   176  				return true, nil
   177  			},
   178  			gen.Int64Range(1, math.MaxInt64),
   179  		))
   180  
   181  	props.TestingRun(t)
   182  }
   183  
   184  func TestConnectionMaxDurationReconnectProperty(t *testing.T) {
   185  	props := testConnectionProperties()
   186  	props.Property(
   187  		"When the time since last connection is greater than the maximum duration writes should attempt to reconnect",
   188  		prop.ForAll(
   189  			func(delay int64) (bool, error) {
   190  				conn := newConnection(testFakeServerAddr, testConnectionOptions())
   191  				conn.connectWithLockFn = func() error { return errTestConnect }
   192  				now := time.Now()
   193  				conn.nowFn = func() time.Time { return now }
   194  				conn.lastConnectAttemptNanos = now.UnixNano() - delay
   195  				conn.maxDuration = time.Duration(delay)
   196  
   197  				if err := conn.Write(nil); err != errTestConnect {
   198  					return false, fmt.Errorf("unexpected error: %v", err)
   199  				}
   200  				return true, nil
   201  			},
   202  			gen.Int64Range(1, math.MaxInt64),
   203  		))
   204  
   205  	props.TestingRun(t)
   206  }
   207  
   208  func TestConnectionReconnectProperties(t *testing.T) {
   209  	props := testConnectionProperties()
   210  	props.Property(
   211  		`When there is no active connection and a write cannot establish one it should:
   212  		- set number of failures to threshold + 2
   213  	  - update the threshold to be min(threshold*multiplier, maxThreshold)`,
   214  		prop.ForAll(
   215  			func(threshold, multiplier int32) (bool, error) {
   216  				conn := newConnection(testFakeServerAddr, testConnectionOptions())
   217  				conn.connectWithLockFn = func() error { return errTestConnect }
   218  				conn.threshold = int(threshold)
   219  				conn.numFailures = conn.threshold + 1
   220  				conn.multiplier = int(multiplier)
   221  				conn.maxThreshold = testMaxReconnectThreshold
   222  
   223  				if err := conn.Write(nil); err != errTestConnect {
   224  					return false, fmt.Errorf("unexpected error: %v", err)
   225  				}
   226  
   227  				if conn.numFailures != int(threshold+2) {
   228  					return false, fmt.Errorf(
   229  						"expected the number of failures to be %d, but found: %v", threshold+2, conn.numFailures,
   230  					)
   231  				}
   232  
   233  				expected := int(threshold * multiplier)
   234  				if expected > testMaxReconnectThreshold {
   235  					expected = testMaxReconnectThreshold
   236  				}
   237  
   238  				if conn.threshold != expected {
   239  					return false, fmt.Errorf(
   240  						"expected the new threshold to be %v, but found: %v", expected, conn.threshold,
   241  					)
   242  				}
   243  
   244  				return true, nil
   245  			},
   246  			gen.Int32Range(1, testMaxReconnectThreshold),
   247  			gen.Int32Range(1, 16),
   248  		))
   249  
   250  	props.TestingRun(t)
   251  }
   252  
   253  func TestConnectionWriteSucceedsOnSecondAttempt(t *testing.T) {
   254  	conn := newConnection(testFakeServerAddr, testConnectionOptions())
   255  	conn.numFailures = 3
   256  	conn.connectWithLockFn = func() error { return nil }
   257  	var count int
   258  	conn.writeWithLockFn = func([]byte) error {
   259  		count++
   260  		if count == 1 {
   261  			return errTestWrite
   262  		}
   263  		return nil
   264  	}
   265  
   266  	require.NoError(t, conn.Write(nil))
   267  	require.Equal(t, 0, conn.numFailures)
   268  	require.Equal(t, 2, conn.threshold)
   269  }
   270  
   271  func TestConnectionWriteFailsOnSecondAttempt(t *testing.T) {
   272  	conn := newConnection(testFakeServerAddr, testConnectionOptions())
   273  	conn.numFailures = 3
   274  	conn.writeWithLockFn = func([]byte) error { return errTestWrite }
   275  	var count int
   276  	conn.connectWithLockFn = func() error {
   277  		count++
   278  		if count == 1 {
   279  			return nil
   280  		}
   281  		return errTestConnect
   282  	}
   283  
   284  	require.Equal(t, errTestConnect, conn.Write(nil))
   285  	require.Equal(t, 1, conn.numFailures)
   286  	require.Equal(t, 2, conn.threshold)
   287  }
   288  
   289  type keepAlivableConn struct {
   290  	net.Conn
   291  	keepAlivable
   292  }
   293  
   294  func TestConnectWithCustomDialer(t *testing.T) {
   295  	testData := []byte("foobar")
   296  	testConnectionTimeout := 5 * time.Second
   297  
   298  	testWithConn := func(t *testing.T, netConn net.Conn) {
   299  		type args struct {
   300  			Ctx     context.Context
   301  			Network string
   302  			Address string
   303  		}
   304  		var capturedArgs args
   305  		dialer := func(ctx context.Context, network string, address string) (net.Conn, error) {
   306  			capturedArgs = args{
   307  				Ctx:     ctx,
   308  				Network: network,
   309  				Address: address,
   310  			}
   311  			return netConn, nil
   312  		}
   313  		opts := testConnectionOptions().
   314  			SetContextDialer(dialer).
   315  			SetConnectionTimeout(testConnectionTimeout)
   316  		addr := "127.0.0.1:5555"
   317  
   318  		conn := newConnection(addr, opts)
   319  		start := time.Now()
   320  		require.NoError(t, conn.Write(testData))
   321  
   322  		assert.Equal(t, addr, capturedArgs.Address)
   323  		assert.Equal(t, tcpProtocol, capturedArgs.Network)
   324  
   325  		deadline, ok := capturedArgs.Ctx.Deadline()
   326  		require.True(t, ok)
   327  		// Start is taken *before* we try to connect, so the deadline must = start + <some_time> + testDialTimeout.
   328  		// Therefore deadline - start >= testDialTimeout.
   329  		assert.True(t, deadline.Sub(start) >= testConnectionTimeout)
   330  	}
   331  
   332  	t.Run("non keep alivable conn", func(t *testing.T) {
   333  		ctrl := gomock.NewController(t)
   334  		mockConn := NewMockConn(ctrl)
   335  
   336  		mockConn.EXPECT().Write(testData)
   337  		mockConn.EXPECT().SetWriteDeadline(gomock.Any())
   338  		testWithConn(t, mockConn)
   339  	})
   340  
   341  	t.Run("keep alivable conn", func(t *testing.T) {
   342  		ctrl := gomock.NewController(t)
   343  		mockConn := NewMockConn(ctrl)
   344  
   345  		mockConn.EXPECT().Write(testData)
   346  		mockConn.EXPECT().SetWriteDeadline(gomock.Any())
   347  
   348  		mockKeepAlivable := NewMockkeepAlivable(ctrl)
   349  		mockKeepAlivable.EXPECT().SetKeepAlive(true)
   350  
   351  		testWithConn(t, keepAlivableConn{
   352  			Conn:         mockConn,
   353  			keepAlivable: mockKeepAlivable,
   354  		})
   355  	})
   356  }
   357  
   358  func TestConnectWriteToServer(t *testing.T) {
   359  	data := []byte("foobar")
   360  
   361  	// Start tcp server.
   362  	var wg sync.WaitGroup
   363  	wg.Add(1)
   364  
   365  	l, err := net.Listen(tcpProtocol, testLocalServerAddr)
   366  	require.NoError(t, err)
   367  	serverAddr := l.Addr().String()
   368  
   369  	go func() {
   370  		defer wg.Done()
   371  
   372  		// Ignore the first testing connection.
   373  		conn, err := l.Accept()
   374  		require.NoError(t, err)
   375  		require.NoError(t, conn.Close())
   376  
   377  		// Read from the second connection.
   378  		conn, err = l.Accept()
   379  		require.NoError(t, err)
   380  		buf := make([]byte, 1024)
   381  		n, err := conn.Read(buf)
   382  		require.NoError(t, err)
   383  		require.Equal(t, data, buf[:n])
   384  		conn.Close() // nolint: errcheck
   385  	}()
   386  
   387  	// Wait until the server starts up.
   388  	testConn, err := net.DialTimeout(tcpProtocol, serverAddr, time.Minute)
   389  	require.NoError(t, err)
   390  	require.NoError(t, testConn.Close())
   391  
   392  	// Create a new connection and assert we can write successfully.
   393  	opts := testConnectionOptions().SetInitReconnectThreshold(0)
   394  	conn := newConnection(serverAddr, opts)
   395  	require.NoError(t, conn.Write(data))
   396  	require.Equal(t, 0, conn.numFailures)
   397  	require.NotNil(t, conn.conn)
   398  
   399  	// Stop the server.
   400  	l.Close() // nolint: errcheck
   401  	wg.Wait()
   402  
   403  	// Close the connection
   404  	conn.Close()
   405  	require.Nil(t, conn.conn)
   406  }
   407  
   408  func testConnectionOptions() ConnectionOptions {
   409  	return NewConnectionOptions().
   410  		SetClockOptions(clock.NewOptions()).
   411  		SetConnectionKeepAlive(true).
   412  		SetConnectionTimeout(100 * time.Millisecond).
   413  		SetInitReconnectThreshold(2).
   414  		SetMaxReconnectThreshold(6).
   415  		SetReconnectThresholdMultiplier(2).
   416  		SetWriteTimeout(100 * time.Millisecond)
   417  }
   418  
   419  func testConnectionProperties() *gopter.Properties {
   420  	params := gopter.DefaultTestParameters()
   421  	params.Rng.Seed(testRandomSeeed)
   422  	params.MinSuccessfulTests = testMinSuccessfulTests
   423  	return gopter.NewProperties(params)
   424  }