github.com/pingcap/tiflow@v0.0.0-20240520035814-5bf52d54e205/pkg/httputil/httputil_test.go (about)

     1  // Copyright 2020 PingCAP, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package httputil
    15  
    16  import (
    17  	"context"
    18  	"crypto/tls"
    19  	"fmt"
    20  	"io"
    21  	"net/http"
    22  	"net/http/httptest"
    23  	"net/url"
    24  	"os"
    25  	"path/filepath"
    26  	"testing"
    27  	"time"
    28  
    29  	"github.com/pingcap/tidb/pkg/util"
    30  	"github.com/pingcap/tiflow/pkg/security"
    31  	"github.com/stretchr/testify/require"
    32  )
    33  
    34  var httputilServerMsg = "this is httputil test server"
    35  
    36  func TestHttputilNewClient(t *testing.T) {
    37  	t.Parallel()
    38  
    39  	dir, err := os.Getwd()
    40  	require.Nil(t, err)
    41  	certDir := "_certificates"
    42  	serverTLS, err := util.ToTLSConfigWithVerify(
    43  		filepath.Join(dir, certDir, "ca.pem"),
    44  		filepath.Join(dir, certDir, "server.pem"),
    45  		filepath.Join(dir, certDir, "server-key.pem"),
    46  		[]string{},
    47  	)
    48  	require.Nil(t, err)
    49  	server, addr := runServer(http.HandlerFunc(handler), serverTLS)
    50  	defer func() {
    51  		server.Close()
    52  	}()
    53  	credential := &security.Credential{
    54  		CAPath:        filepath.Join(dir, certDir, "ca.pem"),
    55  		CertPath:      filepath.Join(dir, certDir, "client.pem"),
    56  		KeyPath:       filepath.Join(dir, certDir, "client-key.pem"),
    57  		CertAllowedCN: []string{},
    58  	}
    59  	cli, err := NewClient(credential)
    60  	require.Nil(t, err)
    61  	url := fmt.Sprintf("https://%s/", addr)
    62  	resp, err := cli.Get(context.Background(), url)
    63  	require.Nil(t, err)
    64  	defer resp.Body.Close()
    65  	body, err := io.ReadAll(resp.Body)
    66  	require.Nil(t, err)
    67  	require.Equal(t, httputilServerMsg, string(body))
    68  }
    69  
    70  func TestStatusCodeCreated(t *testing.T) {
    71  	t.Parallel()
    72  
    73  	ctx, cancel := context.WithCancel(context.Background())
    74  
    75  	server, addr := runServer(http.HandlerFunc(createHandler), nil)
    76  	defer func() {
    77  		cancel()
    78  		server.Close()
    79  	}()
    80  	cli, err := NewClient(nil)
    81  	require.Nil(t, err)
    82  	url := fmt.Sprintf("http://%s/create", addr)
    83  	respBody, err := cli.DoRequest(ctx, url, http.MethodPost, nil, nil)
    84  	require.NoError(t, err)
    85  	require.Equal(t, []byte(`"{"id": "value"}"`), respBody)
    86  }
    87  
    88  func TestTimeout(t *testing.T) {
    89  	t.Parallel()
    90  
    91  	const timeout = 500 * time.Millisecond
    92  
    93  	server, addr := runServer(sleepHandler(time.Second), nil)
    94  	defer func() {
    95  		server.Close()
    96  	}()
    97  	cli, err := NewClient(nil)
    98  	require.NoError(t, err)
    99  
   100  	cli.SetTimeout(timeout)
   101  	start := time.Now()
   102  	resp, err := cli.Get(context.Background(), fmt.Sprintf("http://%s/", addr))
   103  	if resp != nil && resp.Body != nil {
   104  		require.NoError(t, resp.Body.Close())
   105  	}
   106  	var uErr *url.Error
   107  	require.ErrorAs(t, err, &uErr)
   108  	require.True(t, uErr.Timeout())
   109  	require.GreaterOrEqual(t, time.Since(start), timeout)
   110  }
   111  
   112  func handler(w http.ResponseWriter, req *http.Request) {
   113  	w.Header().Set("Content-Type", "text/plain")
   114  	//nolint:errcheck
   115  	w.Write([]byte(httputilServerMsg))
   116  }
   117  
   118  func createHandler(w http.ResponseWriter, req *http.Request) {
   119  	w.Header().Set("Content-Type", "application/json")
   120  	//nolint:errcheck
   121  	w.WriteHeader(http.StatusCreated)
   122  	w.Write([]byte(`"{"id": "value"}"`))
   123  }
   124  
   125  func sleepHandler(d time.Duration) http.HandlerFunc {
   126  	return func(w http.ResponseWriter, req *http.Request) {
   127  		select {
   128  		case <-time.After(d):
   129  		case <-req.Context().Done():
   130  		}
   131  	}
   132  }
   133  
   134  func runServer(handler http.Handler, tlsCfg *tls.Config) (*httptest.Server, string) {
   135  	server := httptest.NewUnstartedServer(handler)
   136  	addr := server.Listener.Addr().String()
   137  
   138  	if tlsCfg != nil {
   139  		server.TLS = tlsCfg
   140  		server.StartTLS()
   141  	} else {
   142  		server.Start()
   143  	}
   144  	return server, addr
   145  }