github.com/pingcap/tiflow@v0.0.0-20240520035814-5bf52d54e205/pkg/tcpserver/tcp_server_test.go (about) 1 // Copyright 2021 PingCAP, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // See the License for the specific language governing permissions and 12 // limitations under the License. 13 14 package tcpserver 15 16 import ( 17 "context" 18 "fmt" 19 "io" 20 "net/http" 21 "path" 22 "sync" 23 "testing" 24 "time" 25 26 grpcTesting "github.com/grpc-ecosystem/go-grpc-middleware/testing" 27 grpcTestingProto "github.com/grpc-ecosystem/go-grpc-middleware/testing/testproto" 28 "github.com/integralist/go-findroot/find" 29 "github.com/phayes/freeport" 30 "github.com/pingcap/tiflow/pkg/httputil" 31 "github.com/pingcap/tiflow/pkg/security" 32 "github.com/stretchr/testify/require" 33 "google.golang.org/grpc" 34 ) 35 36 func TestTCPServerInsecureHTTP1(t *testing.T) { 37 port, err := freeport.GetFreePort() 38 require.NoError(t, err) 39 addr := fmt.Sprintf("127.0.0.1:%d", port) 40 41 server, err := NewTCPServer(addr, &security.Credential{}) 42 require.NoError(t, err) 43 defer func() { 44 err := server.Close() 45 require.NoError(t, err) 46 }() 47 48 ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) 49 defer cancel() 50 51 var wg sync.WaitGroup 52 53 wg.Add(1) 54 go func() { 55 defer wg.Done() 56 err := server.Run(ctx) 57 require.Error(t, err) 58 require.Regexp(t, ".*ErrTCPServerClosed.*", err.Error()) 59 }() 60 61 wg.Add(1) 62 go func() { 63 defer wg.Done() 64 testWithHTTPWorkload(ctx, t, server, addr, &security.Credential{}) 65 cancel() 66 }() 67 68 wg.Wait() 69 } 70 71 func TestTCPServerTLSHTTP1(t *testing.T) { 72 port, err := freeport.GetFreePort() 73 require.NoError(t, err) 74 addr := fmt.Sprintf("127.0.0.1:%d", port) 75 76 server, err := NewTCPServer(addr, makeCredential4Testing(t)) 77 require.NoError(t, err) 78 require.True(t, server.IsTLSEnabled()) 79 80 defer func() { 81 err := server.Close() 82 require.NoError(t, err) 83 }() 84 85 ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) 86 defer cancel() 87 88 var wg sync.WaitGroup 89 90 wg.Add(1) 91 go func() { 92 defer wg.Done() 93 err := server.Run(ctx) 94 require.Error(t, err) 95 require.Regexp(t, ".*ErrTCPServerClosed.*", err.Error()) 96 }() 97 98 wg.Add(1) 99 go func() { 100 defer wg.Done() 101 defer cancel() 102 testWithHTTPWorkload(ctx, t, server, addr, makeCredential4Testing(t)) 103 }() 104 105 wg.Wait() 106 } 107 108 func TestTCPServerInsecureGrpc(t *testing.T) { 109 port, err := freeport.GetFreePort() 110 require.NoError(t, err) 111 addr := fmt.Sprintf("127.0.0.1:%d", port) 112 113 server, err := NewTCPServer(addr, &security.Credential{}) 114 require.NoError(t, err) 115 116 defer func() { 117 err := server.Close() 118 require.NoError(t, err) 119 }() 120 121 ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) 122 defer cancel() 123 124 var wg sync.WaitGroup 125 126 wg.Add(1) 127 go func() { 128 defer wg.Done() 129 err := server.Run(ctx) 130 require.Error(t, err) 131 require.Regexp(t, ".*ErrTCPServerClosed.*", err.Error()) 132 }() 133 134 wg.Add(1) 135 go func() { 136 defer wg.Done() 137 testWithGrpcWorkload(ctx, t, server, addr, &security.Credential{}) 138 cancel() 139 }() 140 141 wg.Wait() 142 } 143 144 func TestTCPServerTLSGrpc(t *testing.T) { 145 port, err := freeport.GetFreePort() 146 require.NoError(t, err) 147 addr := fmt.Sprintf("127.0.0.1:%d", port) 148 149 server, err := NewTCPServer(addr, makeCredential4Testing(t)) 150 require.NoError(t, err) 151 require.True(t, server.IsTLSEnabled()) 152 153 defer func() { 154 err := server.Close() 155 require.NoError(t, err) 156 }() 157 158 ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) 159 defer cancel() 160 161 var wg sync.WaitGroup 162 163 wg.Add(1) 164 go func() { 165 defer wg.Done() 166 err := server.Run(ctx) 167 require.Error(t, err) 168 require.Regexp(t, ".*ErrTCPServerClosed.*", err.Error()) 169 }() 170 171 wg.Add(1) 172 go func() { 173 defer wg.Done() 174 testWithGrpcWorkload(ctx, t, server, addr, makeCredential4Testing(t)) 175 cancel() 176 }() 177 178 wg.Wait() 179 } 180 181 func makeCredential4Testing(t *testing.T) *security.Credential { 182 stat, err := find.Repo() 183 require.NoError(t, err) 184 185 tlsPath := fmt.Sprintf("%s/tests/integration_tests/_certificates/", stat.Path) 186 return &security.Credential{ 187 CAPath: path.Join(tlsPath, "ca.pem"), 188 CertPath: path.Join(tlsPath, "server.pem"), 189 KeyPath: path.Join(tlsPath, "server-key.pem"), 190 CertAllowedCN: nil, 191 } 192 } 193 194 func testWithHTTPWorkload(_ context.Context, t *testing.T, server TCPServer, addr string, credentials *security.Credential) { 195 httpServer := &http.Server{} 196 http.HandleFunc("/", func(writer http.ResponseWriter, _ *http.Request) { 197 writer.WriteHeader(200) 198 _, err := writer.Write([]byte("ok")) 199 require.NoError(t, err) 200 }) 201 defer func() { 202 http.DefaultServeMux = http.NewServeMux() 203 }() 204 205 var wg sync.WaitGroup 206 207 wg.Add(1) 208 go func() { 209 defer wg.Done() 210 err := httpServer.Serve(server.HTTP1Listener()) 211 if err != nil && err != http.ErrServerClosed { 212 require.FailNow(t, 213 "unexpected error from http server", 214 "%d", 215 err.Error()) 216 } 217 }() 218 219 scheme := "http" 220 if credentials.IsTLSEnabled() { 221 scheme = "https" 222 } 223 224 cli, err := httputil.NewClient(credentials) 225 require.NoError(t, err) 226 227 uri := fmt.Sprintf("%s://%s/", scheme, addr) 228 resp, err := cli.Get(context.Background(), uri) 229 require.NoError(t, err) 230 defer func() { 231 _ = resp.Body.Close() 232 }() 233 require.Equal(t, 200, resp.StatusCode) 234 235 body, err := io.ReadAll(resp.Body) 236 require.NoError(t, err) 237 require.Equal(t, "ok", string(body)) 238 239 err = httpServer.Close() 240 require.NoError(t, err) 241 242 wg.Wait() 243 } 244 245 func testWithGrpcWorkload(ctx context.Context, t *testing.T, server TCPServer, addr string, credentials *security.Credential) { 246 grpcServer := grpc.NewServer() 247 service := &grpcTesting.TestPingService{T: t} 248 grpcTestingProto.RegisterTestServiceServer(grpcServer, service) 249 250 var wg sync.WaitGroup 251 252 wg.Add(1) 253 go func() { 254 defer wg.Done() 255 err := grpcServer.Serve(server.GrpcListener()) 256 require.NoError(t, err) 257 }() 258 259 var conn *grpc.ClientConn 260 if credentials.IsTLSEnabled() { 261 tlsOptions, err := credentials.ToGRPCDialOption() 262 require.NoError(t, err) 263 conn, err = grpc.Dial(addr, tlsOptions) 264 require.NoError(t, err) 265 } else { 266 var err error 267 conn, err = grpc.Dial(addr, grpc.WithInsecure()) 268 require.NoError(t, err) 269 } 270 defer func() { 271 _ = conn.Close() 272 }() 273 274 client := grpcTestingProto.NewTestServiceClient(conn) 275 276 for i := 0; i < 5; i++ { 277 result, err := client.Ping(ctx, &grpcTestingProto.PingRequest{ 278 Value: fmt.Sprintf("%d", i), 279 }) 280 require.NoError(t, err) 281 require.Equal(t, fmt.Sprintf("%d", i), result.Value) 282 } 283 284 wg.Add(1) 285 go func() { 286 defer wg.Done() 287 defer grpcServer.GracefulStop() 288 289 stream, err := client.PingStream(ctx) 290 require.NoError(t, err) 291 292 for i := 0; i < 10; i++ { 293 err := stream.Send(&grpcTestingProto.PingRequest{ 294 Value: fmt.Sprintf("%d", i), 295 }) 296 require.NoError(t, err) 297 298 received, err := stream.Recv() 299 require.NoError(t, err) 300 require.Equal(t, fmt.Sprintf("%d", i), received.Value) 301 } 302 }() 303 304 wg.Wait() 305 } 306 307 func TestTcpServerClose(t *testing.T) { 308 port, err := freeport.GetFreePort() 309 require.NoError(t, err) 310 addr := fmt.Sprintf("127.0.0.1:%d", port) 311 312 ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) 313 defer cancel() 314 315 server, err := NewTCPServer(addr, &security.Credential{}) 316 require.NoError(t, err) 317 318 var wg sync.WaitGroup 319 wg.Add(1) 320 go func() { 321 defer wg.Done() 322 err := server.Run(ctx) 323 require.Error(t, err) 324 require.Regexp(t, ".*ErrTCPServerClosed.*", err.Error()) 325 }() 326 327 httpServer := &http.Server{} 328 http.HandleFunc("/", func(writer http.ResponseWriter, _ *http.Request) { 329 writer.WriteHeader(200) 330 _, err := writer.Write([]byte("ok")) 331 require.NoError(t, err) 332 }) 333 defer func() { 334 http.DefaultServeMux = http.NewServeMux() 335 }() 336 337 wg.Add(1) 338 go func() { 339 defer wg.Done() 340 err := httpServer.Serve(server.HTTP1Listener()) 341 require.Error(t, err) 342 require.Regexp(t, ".*mux: server closed.*", err.Error()) 343 }() 344 345 cli, err := httputil.NewClient(&security.Credential{}) 346 require.NoError(t, err) 347 348 uri := fmt.Sprintf("http://%s/", addr) 349 resp, err := cli.Get(context.Background(), uri) 350 require.NoError(t, err) 351 defer func() { 352 _ = resp.Body.Close() 353 }() 354 require.Equal(t, 200, resp.StatusCode) 355 356 // Close should be idempotent. 357 for i := 0; i < 3; i++ { 358 err := server.Close() 359 require.NoError(t, err) 360 } 361 362 wg.Wait() 363 }