github.com/pingcap/tiflow@v0.0.0-20240520035814-5bf52d54e205/pkg/tcpserver/tcp_server_test.go (about)

     1  // Copyright 2021 PingCAP, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package tcpserver
    15  
    16  import (
    17  	"context"
    18  	"fmt"
    19  	"io"
    20  	"net/http"
    21  	"path"
    22  	"sync"
    23  	"testing"
    24  	"time"
    25  
    26  	grpcTesting "github.com/grpc-ecosystem/go-grpc-middleware/testing"
    27  	grpcTestingProto "github.com/grpc-ecosystem/go-grpc-middleware/testing/testproto"
    28  	"github.com/integralist/go-findroot/find"
    29  	"github.com/phayes/freeport"
    30  	"github.com/pingcap/tiflow/pkg/httputil"
    31  	"github.com/pingcap/tiflow/pkg/security"
    32  	"github.com/stretchr/testify/require"
    33  	"google.golang.org/grpc"
    34  )
    35  
    36  func TestTCPServerInsecureHTTP1(t *testing.T) {
    37  	port, err := freeport.GetFreePort()
    38  	require.NoError(t, err)
    39  	addr := fmt.Sprintf("127.0.0.1:%d", port)
    40  
    41  	server, err := NewTCPServer(addr, &security.Credential{})
    42  	require.NoError(t, err)
    43  	defer func() {
    44  		err := server.Close()
    45  		require.NoError(t, err)
    46  	}()
    47  
    48  	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
    49  	defer cancel()
    50  
    51  	var wg sync.WaitGroup
    52  
    53  	wg.Add(1)
    54  	go func() {
    55  		defer wg.Done()
    56  		err := server.Run(ctx)
    57  		require.Error(t, err)
    58  		require.Regexp(t, ".*ErrTCPServerClosed.*", err.Error())
    59  	}()
    60  
    61  	wg.Add(1)
    62  	go func() {
    63  		defer wg.Done()
    64  		testWithHTTPWorkload(ctx, t, server, addr, &security.Credential{})
    65  		cancel()
    66  	}()
    67  
    68  	wg.Wait()
    69  }
    70  
    71  func TestTCPServerTLSHTTP1(t *testing.T) {
    72  	port, err := freeport.GetFreePort()
    73  	require.NoError(t, err)
    74  	addr := fmt.Sprintf("127.0.0.1:%d", port)
    75  
    76  	server, err := NewTCPServer(addr, makeCredential4Testing(t))
    77  	require.NoError(t, err)
    78  	require.True(t, server.IsTLSEnabled())
    79  
    80  	defer func() {
    81  		err := server.Close()
    82  		require.NoError(t, err)
    83  	}()
    84  
    85  	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
    86  	defer cancel()
    87  
    88  	var wg sync.WaitGroup
    89  
    90  	wg.Add(1)
    91  	go func() {
    92  		defer wg.Done()
    93  		err := server.Run(ctx)
    94  		require.Error(t, err)
    95  		require.Regexp(t, ".*ErrTCPServerClosed.*", err.Error())
    96  	}()
    97  
    98  	wg.Add(1)
    99  	go func() {
   100  		defer wg.Done()
   101  		defer cancel()
   102  		testWithHTTPWorkload(ctx, t, server, addr, makeCredential4Testing(t))
   103  	}()
   104  
   105  	wg.Wait()
   106  }
   107  
   108  func TestTCPServerInsecureGrpc(t *testing.T) {
   109  	port, err := freeport.GetFreePort()
   110  	require.NoError(t, err)
   111  	addr := fmt.Sprintf("127.0.0.1:%d", port)
   112  
   113  	server, err := NewTCPServer(addr, &security.Credential{})
   114  	require.NoError(t, err)
   115  
   116  	defer func() {
   117  		err := server.Close()
   118  		require.NoError(t, err)
   119  	}()
   120  
   121  	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
   122  	defer cancel()
   123  
   124  	var wg sync.WaitGroup
   125  
   126  	wg.Add(1)
   127  	go func() {
   128  		defer wg.Done()
   129  		err := server.Run(ctx)
   130  		require.Error(t, err)
   131  		require.Regexp(t, ".*ErrTCPServerClosed.*", err.Error())
   132  	}()
   133  
   134  	wg.Add(1)
   135  	go func() {
   136  		defer wg.Done()
   137  		testWithGrpcWorkload(ctx, t, server, addr, &security.Credential{})
   138  		cancel()
   139  	}()
   140  
   141  	wg.Wait()
   142  }
   143  
   144  func TestTCPServerTLSGrpc(t *testing.T) {
   145  	port, err := freeport.GetFreePort()
   146  	require.NoError(t, err)
   147  	addr := fmt.Sprintf("127.0.0.1:%d", port)
   148  
   149  	server, err := NewTCPServer(addr, makeCredential4Testing(t))
   150  	require.NoError(t, err)
   151  	require.True(t, server.IsTLSEnabled())
   152  
   153  	defer func() {
   154  		err := server.Close()
   155  		require.NoError(t, err)
   156  	}()
   157  
   158  	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
   159  	defer cancel()
   160  
   161  	var wg sync.WaitGroup
   162  
   163  	wg.Add(1)
   164  	go func() {
   165  		defer wg.Done()
   166  		err := server.Run(ctx)
   167  		require.Error(t, err)
   168  		require.Regexp(t, ".*ErrTCPServerClosed.*", err.Error())
   169  	}()
   170  
   171  	wg.Add(1)
   172  	go func() {
   173  		defer wg.Done()
   174  		testWithGrpcWorkload(ctx, t, server, addr, makeCredential4Testing(t))
   175  		cancel()
   176  	}()
   177  
   178  	wg.Wait()
   179  }
   180  
   181  func makeCredential4Testing(t *testing.T) *security.Credential {
   182  	stat, err := find.Repo()
   183  	require.NoError(t, err)
   184  
   185  	tlsPath := fmt.Sprintf("%s/tests/integration_tests/_certificates/", stat.Path)
   186  	return &security.Credential{
   187  		CAPath:        path.Join(tlsPath, "ca.pem"),
   188  		CertPath:      path.Join(tlsPath, "server.pem"),
   189  		KeyPath:       path.Join(tlsPath, "server-key.pem"),
   190  		CertAllowedCN: nil,
   191  	}
   192  }
   193  
   194  func testWithHTTPWorkload(_ context.Context, t *testing.T, server TCPServer, addr string, credentials *security.Credential) {
   195  	httpServer := &http.Server{}
   196  	http.HandleFunc("/", func(writer http.ResponseWriter, _ *http.Request) {
   197  		writer.WriteHeader(200)
   198  		_, err := writer.Write([]byte("ok"))
   199  		require.NoError(t, err)
   200  	})
   201  	defer func() {
   202  		http.DefaultServeMux = http.NewServeMux()
   203  	}()
   204  
   205  	var wg sync.WaitGroup
   206  
   207  	wg.Add(1)
   208  	go func() {
   209  		defer wg.Done()
   210  		err := httpServer.Serve(server.HTTP1Listener())
   211  		if err != nil && err != http.ErrServerClosed {
   212  			require.FailNow(t,
   213  				"unexpected error from http server",
   214  				"%d",
   215  				err.Error())
   216  		}
   217  	}()
   218  
   219  	scheme := "http"
   220  	if credentials.IsTLSEnabled() {
   221  		scheme = "https"
   222  	}
   223  
   224  	cli, err := httputil.NewClient(credentials)
   225  	require.NoError(t, err)
   226  
   227  	uri := fmt.Sprintf("%s://%s/", scheme, addr)
   228  	resp, err := cli.Get(context.Background(), uri)
   229  	require.NoError(t, err)
   230  	defer func() {
   231  		_ = resp.Body.Close()
   232  	}()
   233  	require.Equal(t, 200, resp.StatusCode)
   234  
   235  	body, err := io.ReadAll(resp.Body)
   236  	require.NoError(t, err)
   237  	require.Equal(t, "ok", string(body))
   238  
   239  	err = httpServer.Close()
   240  	require.NoError(t, err)
   241  
   242  	wg.Wait()
   243  }
   244  
   245  func testWithGrpcWorkload(ctx context.Context, t *testing.T, server TCPServer, addr string, credentials *security.Credential) {
   246  	grpcServer := grpc.NewServer()
   247  	service := &grpcTesting.TestPingService{T: t}
   248  	grpcTestingProto.RegisterTestServiceServer(grpcServer, service)
   249  
   250  	var wg sync.WaitGroup
   251  
   252  	wg.Add(1)
   253  	go func() {
   254  		defer wg.Done()
   255  		err := grpcServer.Serve(server.GrpcListener())
   256  		require.NoError(t, err)
   257  	}()
   258  
   259  	var conn *grpc.ClientConn
   260  	if credentials.IsTLSEnabled() {
   261  		tlsOptions, err := credentials.ToGRPCDialOption()
   262  		require.NoError(t, err)
   263  		conn, err = grpc.Dial(addr, tlsOptions)
   264  		require.NoError(t, err)
   265  	} else {
   266  		var err error
   267  		conn, err = grpc.Dial(addr, grpc.WithInsecure())
   268  		require.NoError(t, err)
   269  	}
   270  	defer func() {
   271  		_ = conn.Close()
   272  	}()
   273  
   274  	client := grpcTestingProto.NewTestServiceClient(conn)
   275  
   276  	for i := 0; i < 5; i++ {
   277  		result, err := client.Ping(ctx, &grpcTestingProto.PingRequest{
   278  			Value: fmt.Sprintf("%d", i),
   279  		})
   280  		require.NoError(t, err)
   281  		require.Equal(t, fmt.Sprintf("%d", i), result.Value)
   282  	}
   283  
   284  	wg.Add(1)
   285  	go func() {
   286  		defer wg.Done()
   287  		defer grpcServer.GracefulStop()
   288  
   289  		stream, err := client.PingStream(ctx)
   290  		require.NoError(t, err)
   291  
   292  		for i := 0; i < 10; i++ {
   293  			err := stream.Send(&grpcTestingProto.PingRequest{
   294  				Value: fmt.Sprintf("%d", i),
   295  			})
   296  			require.NoError(t, err)
   297  
   298  			received, err := stream.Recv()
   299  			require.NoError(t, err)
   300  			require.Equal(t, fmt.Sprintf("%d", i), received.Value)
   301  		}
   302  	}()
   303  
   304  	wg.Wait()
   305  }
   306  
   307  func TestTcpServerClose(t *testing.T) {
   308  	port, err := freeport.GetFreePort()
   309  	require.NoError(t, err)
   310  	addr := fmt.Sprintf("127.0.0.1:%d", port)
   311  
   312  	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
   313  	defer cancel()
   314  
   315  	server, err := NewTCPServer(addr, &security.Credential{})
   316  	require.NoError(t, err)
   317  
   318  	var wg sync.WaitGroup
   319  	wg.Add(1)
   320  	go func() {
   321  		defer wg.Done()
   322  		err := server.Run(ctx)
   323  		require.Error(t, err)
   324  		require.Regexp(t, ".*ErrTCPServerClosed.*", err.Error())
   325  	}()
   326  
   327  	httpServer := &http.Server{}
   328  	http.HandleFunc("/", func(writer http.ResponseWriter, _ *http.Request) {
   329  		writer.WriteHeader(200)
   330  		_, err := writer.Write([]byte("ok"))
   331  		require.NoError(t, err)
   332  	})
   333  	defer func() {
   334  		http.DefaultServeMux = http.NewServeMux()
   335  	}()
   336  
   337  	wg.Add(1)
   338  	go func() {
   339  		defer wg.Done()
   340  		err := httpServer.Serve(server.HTTP1Listener())
   341  		require.Error(t, err)
   342  		require.Regexp(t, ".*mux: server closed.*", err.Error())
   343  	}()
   344  
   345  	cli, err := httputil.NewClient(&security.Credential{})
   346  	require.NoError(t, err)
   347  
   348  	uri := fmt.Sprintf("http://%s/", addr)
   349  	resp, err := cli.Get(context.Background(), uri)
   350  	require.NoError(t, err)
   351  	defer func() {
   352  		_ = resp.Body.Close()
   353  	}()
   354  	require.Equal(t, 200, resp.StatusCode)
   355  
   356  	// Close should be idempotent.
   357  	for i := 0; i < 3; i++ {
   358  		err := server.Close()
   359  		require.NoError(t, err)
   360  	}
   361  
   362  	wg.Wait()
   363  }