github.com/ydb-platform/ydb-go-sdk/v3@v3.57.0/internal/balancer/local_dc_test.go (about)

     1  package balancer
     2  
     3  import (
     4  	"context"
     5  	"net"
     6  	"testing"
     7  
     8  	"github.com/stretchr/testify/require"
     9  
    10  	"github.com/ydb-platform/ydb-go-sdk/v3/balancers"
    11  	"github.com/ydb-platform/ydb-go-sdk/v3/config"
    12  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/conn"
    13  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/endpoint"
    14  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/mock"
    15  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/xtest"
    16  )
    17  
    18  var localIP = net.IPv4(127, 0, 0, 1)
    19  
    20  type discoveryMock struct {
    21  	endpoints []endpoint.Endpoint
    22  }
    23  
    24  // implement discovery.Client
    25  func (d discoveryMock) Close(ctx context.Context) error {
    26  	return nil
    27  }
    28  
    29  func (d discoveryMock) Discover(ctx context.Context) ([]endpoint.Endpoint, error) {
    30  	return d.endpoints, nil
    31  }
    32  
    33  func TestCheckFastestAddress(t *testing.T) {
    34  	ctx := context.Background()
    35  
    36  	t.Run("Ok", func(t *testing.T) {
    37  		var firstCount int64
    38  		var secondCount int64
    39  
    40  		for i := 0; i < 100; i++ {
    41  			listen1, err := net.ListenTCP("tcp", &net.TCPAddr{IP: localIP})
    42  			require.NoError(t, err)
    43  			listen2, err := net.ListenTCP("tcp", &net.TCPAddr{IP: localIP})
    44  			require.NoError(t, err)
    45  			addr1 := listen1.Addr().String()
    46  			addr2 := listen2.Addr().String()
    47  
    48  			fastest := checkFastestAddress(ctx, []string{addr1, addr2})
    49  			require.NotEmpty(t, fastest)
    50  
    51  			switch fastest {
    52  			case addr1:
    53  				firstCount++
    54  			case addr2:
    55  				secondCount++
    56  			default:
    57  				require.Contains(t, []string{addr1, addr2}, fastest)
    58  			}
    59  
    60  			_ = listen1.Close()
    61  			_ = listen2.Close()
    62  		}
    63  		require.NotEmpty(t, firstCount)
    64  		require.NotEmpty(t, secondCount)
    65  	})
    66  	t.Run("HasError", func(t *testing.T) {
    67  		listen1, err := net.ListenTCP("tcp", &net.TCPAddr{IP: localIP})
    68  		require.NoError(t, err)
    69  		listen2, err := net.ListenTCP("tcp", &net.TCPAddr{IP: localIP})
    70  		require.NoError(t, err)
    71  		addr1 := listen1.Addr().String()
    72  		addr2 := listen2.Addr().String()
    73  
    74  		_ = listen2.Close() // for can't accept connections
    75  
    76  		fastest := checkFastestAddress(ctx, []string{addr1, addr2})
    77  		require.Equal(t, addr1, fastest)
    78  
    79  		_ = listen1.Close()
    80  	})
    81  	t.Run("AllErrors", func(t *testing.T) {
    82  		listen1, err := net.ListenTCP("tcp", &net.TCPAddr{IP: localIP})
    83  		require.NoError(t, err)
    84  		listen2, err := net.ListenTCP("tcp", &net.TCPAddr{IP: localIP})
    85  		require.NoError(t, err)
    86  		addr1 := listen1.Addr().String()
    87  		addr2 := listen2.Addr().String()
    88  
    89  		_ = listen1.Close() // for can't accept connections
    90  		_ = listen2.Close() // for can't accept connections
    91  
    92  		res := checkFastestAddress(ctx, []string{addr1, addr2})
    93  		require.Empty(t, res)
    94  	})
    95  }
    96  
    97  func TestDetectLocalDC(t *testing.T) {
    98  	ctx := context.Background()
    99  	xtest.TestManyTimesWithName(t, "Ok", func(t testing.TB) {
   100  		listen1, err := net.ListenTCP("tcp", &net.TCPAddr{IP: localIP})
   101  		require.NoError(t, err)
   102  		defer func() { _ = listen1.Close() }()
   103  
   104  		listen2, err := net.ListenTCP("tcp", &net.TCPAddr{IP: localIP})
   105  		require.NoError(t, err)
   106  		listen2Addr := listen2.Addr().String()
   107  		_ = listen2.Close() // force close, for not accept tcp connections
   108  
   109  		dc, err := detectLocalDC(ctx, []endpoint.Endpoint{
   110  			&mock.Endpoint{LocationField: "a", AddrField: "grpc://" + listen1.Addr().String()},
   111  			&mock.Endpoint{LocationField: "b", AddrField: "grpc://" + listen2Addr},
   112  		})
   113  		require.NoError(t, err)
   114  		require.Equal(t, "a", dc)
   115  	})
   116  	t.Run("Empty", func(t *testing.T) {
   117  		res, err := detectLocalDC(ctx, nil)
   118  		require.Equal(t, "", res)
   119  		require.Error(t, err)
   120  	})
   121  	t.Run("OneDC", func(t *testing.T) {
   122  		res, err := detectLocalDC(ctx, []endpoint.Endpoint{
   123  			&mock.Endpoint{LocationField: "a"},
   124  			&mock.Endpoint{LocationField: "a"},
   125  		})
   126  		require.NoError(t, err)
   127  		require.Equal(t, "a", res)
   128  	})
   129  }
   130  
   131  func TestLocalDCDiscovery(t *testing.T) {
   132  	ctx := context.Background()
   133  	cfg := config.New(
   134  		config.WithBalancer(balancers.PreferLocalDC(balancers.Default())),
   135  	)
   136  	r := &Balancer{
   137  		driverConfig: cfg,
   138  		config:       *cfg.Balancer(),
   139  		pool:         conn.NewPool(context.Background(), cfg),
   140  		discoveryClient: discoveryMock{endpoints: []endpoint.Endpoint{
   141  			&mock.Endpoint{AddrField: "a:123", LocationField: "a"},
   142  			&mock.Endpoint{AddrField: "b:234", LocationField: "b"},
   143  			&mock.Endpoint{AddrField: "c:456", LocationField: "c"},
   144  		}},
   145  		localDCDetector: func(ctx context.Context, endpoints []endpoint.Endpoint) (string, error) {
   146  			return "b", nil
   147  		},
   148  	}
   149  
   150  	err := r.clusterDiscoveryAttempt(ctx)
   151  	require.NoError(t, err)
   152  
   153  	for i := 0; i < 100; i++ {
   154  		conn, _ := r.connections().GetConnection(ctx)
   155  		require.Equal(t, "b:234", conn.Endpoint().Address())
   156  		require.Equal(t, "b", conn.Endpoint().Location())
   157  	}
   158  }
   159  
   160  func TestExtractHostPort(t *testing.T) {
   161  	table := []struct {
   162  		name    string
   163  		address string
   164  		host    string
   165  		port    string
   166  		err     bool
   167  	}{
   168  		{
   169  			"HostPort",
   170  			"asd:123",
   171  			"asd",
   172  			"123",
   173  			false,
   174  		},
   175  		{
   176  			"HostPortSchema",
   177  			"grpc://asd:123",
   178  			"asd",
   179  			"123",
   180  			false,
   181  		},
   182  		{
   183  			"NoPort",
   184  			"host",
   185  			"",
   186  			"",
   187  			true,
   188  		},
   189  		{
   190  			"Empty",
   191  			"",
   192  			"",
   193  			"",
   194  			true,
   195  		},
   196  	}
   197  	for _, test := range table {
   198  		t.Run(test.name, func(t *testing.T) {
   199  			host, port, err := extractHostPort(test.address)
   200  			require.Equal(t, test.host, host)
   201  			require.Equal(t, test.port, port)
   202  			if test.err {
   203  				require.Error(t, err)
   204  			} else {
   205  				require.NoError(t, err)
   206  			}
   207  		})
   208  	}
   209  }
   210  
   211  func TestGetRandomEndpoints(t *testing.T) {
   212  	source := []endpoint.Endpoint{
   213  		&mock.Endpoint{AddrField: "a"},
   214  		&mock.Endpoint{AddrField: "b"},
   215  		&mock.Endpoint{AddrField: "c"},
   216  	}
   217  
   218  	t.Run("ReturnSource", func(t *testing.T) {
   219  		res := getRandomEndpoints(source, 3)
   220  		require.Equal(t, source, res)
   221  
   222  		res = getRandomEndpoints(source, 4)
   223  		require.Equal(t, source, res)
   224  	})
   225  	xtest.TestManyTimesWithName(t, "SelectRandom", func(t testing.TB) {
   226  		res := getRandomEndpoints(source, 2)
   227  		require.Len(t, res, 2)
   228  		for _, ep := range res {
   229  			require.Contains(t, source, ep)
   230  		}
   231  		require.NotEqual(t, res[0], res[1])
   232  	})
   233  }