github.com/ydb-platform/ydb-go-sdk/v3@v3.57.0/internal/balancer/local_dc_test.go (about) 1 package balancer 2 3 import ( 4 "context" 5 "net" 6 "testing" 7 8 "github.com/stretchr/testify/require" 9 10 "github.com/ydb-platform/ydb-go-sdk/v3/balancers" 11 "github.com/ydb-platform/ydb-go-sdk/v3/config" 12 "github.com/ydb-platform/ydb-go-sdk/v3/internal/conn" 13 "github.com/ydb-platform/ydb-go-sdk/v3/internal/endpoint" 14 "github.com/ydb-platform/ydb-go-sdk/v3/internal/mock" 15 "github.com/ydb-platform/ydb-go-sdk/v3/internal/xtest" 16 ) 17 18 var localIP = net.IPv4(127, 0, 0, 1) 19 20 type discoveryMock struct { 21 endpoints []endpoint.Endpoint 22 } 23 24 // implement discovery.Client 25 func (d discoveryMock) Close(ctx context.Context) error { 26 return nil 27 } 28 29 func (d discoveryMock) Discover(ctx context.Context) ([]endpoint.Endpoint, error) { 30 return d.endpoints, nil 31 } 32 33 func TestCheckFastestAddress(t *testing.T) { 34 ctx := context.Background() 35 36 t.Run("Ok", func(t *testing.T) { 37 var firstCount int64 38 var secondCount int64 39 40 for i := 0; i < 100; i++ { 41 listen1, err := net.ListenTCP("tcp", &net.TCPAddr{IP: localIP}) 42 require.NoError(t, err) 43 listen2, err := net.ListenTCP("tcp", &net.TCPAddr{IP: localIP}) 44 require.NoError(t, err) 45 addr1 := listen1.Addr().String() 46 addr2 := listen2.Addr().String() 47 48 fastest := checkFastestAddress(ctx, []string{addr1, addr2}) 49 require.NotEmpty(t, fastest) 50 51 switch fastest { 52 case addr1: 53 firstCount++ 54 case addr2: 55 secondCount++ 56 default: 57 require.Contains(t, []string{addr1, addr2}, fastest) 58 } 59 60 _ = listen1.Close() 61 _ = listen2.Close() 62 } 63 require.NotEmpty(t, firstCount) 64 require.NotEmpty(t, secondCount) 65 }) 66 t.Run("HasError", func(t *testing.T) { 67 listen1, err := net.ListenTCP("tcp", &net.TCPAddr{IP: localIP}) 68 require.NoError(t, err) 69 listen2, err := net.ListenTCP("tcp", &net.TCPAddr{IP: localIP}) 70 require.NoError(t, err) 71 addr1 := listen1.Addr().String() 72 addr2 := listen2.Addr().String() 73 74 _ = listen2.Close() // for can't accept connections 75 76 fastest := checkFastestAddress(ctx, []string{addr1, addr2}) 77 require.Equal(t, addr1, fastest) 78 79 _ = listen1.Close() 80 }) 81 t.Run("AllErrors", func(t *testing.T) { 82 listen1, err := net.ListenTCP("tcp", &net.TCPAddr{IP: localIP}) 83 require.NoError(t, err) 84 listen2, err := net.ListenTCP("tcp", &net.TCPAddr{IP: localIP}) 85 require.NoError(t, err) 86 addr1 := listen1.Addr().String() 87 addr2 := listen2.Addr().String() 88 89 _ = listen1.Close() // for can't accept connections 90 _ = listen2.Close() // for can't accept connections 91 92 res := checkFastestAddress(ctx, []string{addr1, addr2}) 93 require.Empty(t, res) 94 }) 95 } 96 97 func TestDetectLocalDC(t *testing.T) { 98 ctx := context.Background() 99 xtest.TestManyTimesWithName(t, "Ok", func(t testing.TB) { 100 listen1, err := net.ListenTCP("tcp", &net.TCPAddr{IP: localIP}) 101 require.NoError(t, err) 102 defer func() { _ = listen1.Close() }() 103 104 listen2, err := net.ListenTCP("tcp", &net.TCPAddr{IP: localIP}) 105 require.NoError(t, err) 106 listen2Addr := listen2.Addr().String() 107 _ = listen2.Close() // force close, for not accept tcp connections 108 109 dc, err := detectLocalDC(ctx, []endpoint.Endpoint{ 110 &mock.Endpoint{LocationField: "a", AddrField: "grpc://" + listen1.Addr().String()}, 111 &mock.Endpoint{LocationField: "b", AddrField: "grpc://" + listen2Addr}, 112 }) 113 require.NoError(t, err) 114 require.Equal(t, "a", dc) 115 }) 116 t.Run("Empty", func(t *testing.T) { 117 res, err := detectLocalDC(ctx, nil) 118 require.Equal(t, "", res) 119 require.Error(t, err) 120 }) 121 t.Run("OneDC", func(t *testing.T) { 122 res, err := detectLocalDC(ctx, []endpoint.Endpoint{ 123 &mock.Endpoint{LocationField: "a"}, 124 &mock.Endpoint{LocationField: "a"}, 125 }) 126 require.NoError(t, err) 127 require.Equal(t, "a", res) 128 }) 129 } 130 131 func TestLocalDCDiscovery(t *testing.T) { 132 ctx := context.Background() 133 cfg := config.New( 134 config.WithBalancer(balancers.PreferLocalDC(balancers.Default())), 135 ) 136 r := &Balancer{ 137 driverConfig: cfg, 138 config: *cfg.Balancer(), 139 pool: conn.NewPool(context.Background(), cfg), 140 discoveryClient: discoveryMock{endpoints: []endpoint.Endpoint{ 141 &mock.Endpoint{AddrField: "a:123", LocationField: "a"}, 142 &mock.Endpoint{AddrField: "b:234", LocationField: "b"}, 143 &mock.Endpoint{AddrField: "c:456", LocationField: "c"}, 144 }}, 145 localDCDetector: func(ctx context.Context, endpoints []endpoint.Endpoint) (string, error) { 146 return "b", nil 147 }, 148 } 149 150 err := r.clusterDiscoveryAttempt(ctx) 151 require.NoError(t, err) 152 153 for i := 0; i < 100; i++ { 154 conn, _ := r.connections().GetConnection(ctx) 155 require.Equal(t, "b:234", conn.Endpoint().Address()) 156 require.Equal(t, "b", conn.Endpoint().Location()) 157 } 158 } 159 160 func TestExtractHostPort(t *testing.T) { 161 table := []struct { 162 name string 163 address string 164 host string 165 port string 166 err bool 167 }{ 168 { 169 "HostPort", 170 "asd:123", 171 "asd", 172 "123", 173 false, 174 }, 175 { 176 "HostPortSchema", 177 "grpc://asd:123", 178 "asd", 179 "123", 180 false, 181 }, 182 { 183 "NoPort", 184 "host", 185 "", 186 "", 187 true, 188 }, 189 { 190 "Empty", 191 "", 192 "", 193 "", 194 true, 195 }, 196 } 197 for _, test := range table { 198 t.Run(test.name, func(t *testing.T) { 199 host, port, err := extractHostPort(test.address) 200 require.Equal(t, test.host, host) 201 require.Equal(t, test.port, port) 202 if test.err { 203 require.Error(t, err) 204 } else { 205 require.NoError(t, err) 206 } 207 }) 208 } 209 } 210 211 func TestGetRandomEndpoints(t *testing.T) { 212 source := []endpoint.Endpoint{ 213 &mock.Endpoint{AddrField: "a"}, 214 &mock.Endpoint{AddrField: "b"}, 215 &mock.Endpoint{AddrField: "c"}, 216 } 217 218 t.Run("ReturnSource", func(t *testing.T) { 219 res := getRandomEndpoints(source, 3) 220 require.Equal(t, source, res) 221 222 res = getRandomEndpoints(source, 4) 223 require.Equal(t, source, res) 224 }) 225 xtest.TestManyTimesWithName(t, "SelectRandom", func(t testing.TB) { 226 res := getRandomEndpoints(source, 2) 227 require.Len(t, res, 2) 228 for _, ep := range res { 229 require.Contains(t, source, ep) 230 } 231 require.NotEqual(t, res[0], res[1]) 232 }) 233 }