github.com/GuanceCloud/cliutils@v1.1.21/network/http/gin_test.go (about)

     1  // Unless explicitly stated otherwise all files in this repository are licensed
     2  // under the MIT License.
     3  // This product includes software developed at Guance Cloud (https://www.guance.com/).
     4  // Copyright 2021-present Guance, Inc.
     5  
     6  package http
     7  
import (
	"io"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/stretchr/testify/assert"
)
    18  
    19  func BenchmarkAllMiddlewares(b *testing.B) {
    20  	cases := []struct {
    21  		name string
    22  		ms   []gin.HandlerFunc
    23  	}{
    24  		{
    25  			name: "none",
    26  			ms:   []gin.HandlerFunc{},
    27  		},
    28  		{
    29  			name: "all",
    30  			ms: []gin.HandlerFunc{
    31  				CORSMiddleware, TraceIDMiddleware, RequestLoggerMiddleware,
    32  			},
    33  		},
    34  		{
    35  			name: "cors-0",
    36  			ms: []gin.HandlerFunc{
    37  				CORSMiddlewareV2([]string{}),
    38  			},
    39  		},
    40  		{
    41  			name: "cors-1",
    42  			ms: []gin.HandlerFunc{
    43  				CORSMiddlewareV2([]string{"http://foobar.com"}),
    44  			},
    45  		},
    46  		{
    47  			name: "cors-2",
    48  			ms: []gin.HandlerFunc{
    49  				CORSMiddlewareV2([]string{"www.baidu.com"}),
    50  			},
    51  		},
    52  		{
    53  			name: "trace-id",
    54  			ms: []gin.HandlerFunc{
    55  				TraceIDMiddleware,
    56  			},
    57  		},
    58  		{
    59  			name: "request-logger",
    60  			ms: []gin.HandlerFunc{
    61  				RequestLoggerMiddleware,
    62  			},
    63  		},
    64  	}
    65  
    66  	for _, bc := range cases {
    67  		b.Run(bc.name, func(b *testing.B) {
    68  			r := gin.New()
    69  
    70  			for _, m := range bc.ms {
    71  				r.Use(m)
    72  			}
    73  
    74  			r.Use(gin.LoggerWithConfig(gin.LoggerConfig{
    75  				Formatter: GinLogFormatter,
    76  			}))
    77  
    78  			v1 := r.Group("/v1")
    79  			v1.GET("/get", func(c *gin.Context) { c.Data(400, "application/json", []byte(`{"error": "get-error"}`)) })
    80  
    81  			srv := &http.Server{
    82  				Addr:    `localhost:1234`,
    83  				Handler: r,
    84  			}
    85  
    86  			go func() {
    87  				if err := srv.ListenAndServe(); err != nil {
    88  					b.Log(err)
    89  				}
    90  			}()
    91  
    92  			time.Sleep(time.Second)
    93  
    94  			for i := 0; i < b.N; i++ {
    95  				if !strings.Contains(bc.name, "cors") {
    96  					resp, err := http.Get("http://localhost:1234/v1/get")
    97  					if err != nil {
    98  						b.Logf("get error: %s, ignored", err)
    99  					}
   100  
   101  					if resp.Body != nil {
   102  						io.Copy(io.Discard, resp.Body)
   103  						resp.Body.Close()
   104  					}
   105  				} else {
   106  					req, err := http.NewRequest("GET", "http://localhost:1234/v1/get", nil)
   107  					if err != nil {
   108  						b.Error(err)
   109  					}
   110  					origin := "http://foobar.com"
   111  					req.Header.Set("Origin", origin)
   112  					c := &http.Client{}
   113  					resp, err := c.Do(req)
   114  					if err != nil {
   115  						b.Error(err)
   116  					}
   117  					defer resp.Body.Close()
   118  					got := resp.Header.Get("Access-Control-Allow-Origin")
   119  					if bc.name == "cors-2" {
   120  						origin = ""
   121  					}
   122  					assert.Equal(b, origin, got, "expect %s, got '%s'", origin, got)
   123  				}
   124  			}
   125  			srv.Close()
   126  		})
   127  	}
   128  }
   129  
   130  func TestCORSHeaders_Add(t *testing.T) {
   131  	// Accept, Accept-Encoding, Accept-Language, Authorization, Cache-Control, Content-Language, Content-Length, Content-Type, Origin, X-Csrf-Token, X-Datakit-Uuid, X-Lua, X-Precision, X-Requested-With, X-Rp, X-Token, *
   132  	defaultHeaders := defaultCORSHeader.String()
   133  
   134  	h1 := defaultCORSHeader.Add("content-type  , X-PRECISION")
   135  	assert.Equal(t, defaultHeaders, h1)
   136  
   137  	h2 := defaultCORSHeader.Add("  ")
   138  	assert.Equal(t, defaultHeaders, h2)
   139  
   140  	h3 := defaultCORSHeader.Add("x-Foo ,cache-control , X-BAR")
   141  	assert.Equal(t, "X-Foo, X-Bar, "+defaultHeaders, h3)
   142  
   143  	h4 := defaultCORSHeader.Add(" * ")
   144  	assert.Equal(t, defaultHeaders, h4)
   145  
   146  	h5 := defaultCORSHeader.Add("x-forwarded-for ,x-real-ip , x-client-ip")
   147  	assert.Equal(t, "X-Forwarded-For, X-Real-Ip, X-Client-Ip, "+defaultHeaders, h5)
   148  }
   149  
   150  func TestMiddlewares(t *testing.T) {
   151  	r := gin.New()
   152  
   153  	r.Use(CORSMiddleware)
   154  	r.Use(TraceIDMiddleware)
   155  	r.Use(RequestLoggerMiddleware)
   156  	r.Use(gin.LoggerWithConfig(gin.LoggerConfig{
   157  		Formatter: GinLogFormatter,
   158  	}))
   159  
   160  	t.Setenv("MAX_REQUEST_BODY_LEN", "4")
   161  	Init()
   162  
   163  	v1 := r.Group("/v1")
   164  	v1.GET("/get", func(c *gin.Context) { c.Data(400, "application/json", []byte(`{"error": "get-error"}`)) })
   165  	v1.GET("/get500", func(c *gin.Context) { c.Data(500, "application/json", []byte(`{"error": "get-error"}`)) })
   166  	v1.POST("/post", func(c *gin.Context) { c.Data(400, "application/json", []byte(`{"error": "post-error"}`)) })
   167  	v1.GET("/getok", func(c *gin.Context) { c.Data(200, "application/json", []byte(`{"get": "ok"}`)) })
   168  	v1.POST("/postok", func(c *gin.Context) { c.Data(200, "application/json", []byte(`{"post": "ok"}`)) })
   169  
   170  	srv := &http.Server{
   171  		Addr:    `localhost:1234`,
   172  		Handler: r,
   173  	}
   174  
   175  	go func() {
   176  		if err := srv.ListenAndServe(); err != nil {
   177  			t.Log(err)
   178  		}
   179  	}()
   180  
   181  	defer srv.Close()
   182  
   183  	time.Sleep(time.Second)
   184  
   185  	resp, err := http.Get("http://localhost:1234/v1/get")
   186  	if err != nil {
   187  		t.Logf("get error: %s, ignored", err)
   188  	}
   189  
   190  	if resp.Body != nil {
   191  		io.Copy(io.Discard, resp.Body)
   192  		resp.Body.Close()
   193  	}
   194  
   195  	resp, err = http.Get("http://localhost:1234/v1/get500")
   196  	if err != nil {
   197  		t.Logf("get error: %s, ignored", err)
   198  	}
   199  
   200  	if resp.Body != nil {
   201  		io.Copy(io.Discard, resp.Body)
   202  		resp.Body.Close()
   203  	}
   204  
   205  	resp, err = http.Post("http://localhost:1234/v1/post", "", nil)
   206  	if err != nil {
   207  		t.Logf("get error: %s, ignored", err)
   208  	}
   209  
   210  	if resp.Body != nil {
   211  		io.Copy(io.Discard, resp.Body)
   212  		resp.Body.Close()
   213  	}
   214  
   215  	resp, err = http.Get("http://localhost:1234/v1/getok")
   216  	if err != nil {
   217  		t.Logf("get error: %s, ignored", err)
   218  	}
   219  
   220  	if resp.Body != nil {
   221  		io.Copy(io.Discard, resp.Body)
   222  		resp.Body.Close()
   223  	}
   224  
   225  	resp, err = http.Post("http://localhost:1234/v1/postok", "", nil)
   226  	if err != nil {
   227  		t.Logf("get error: %s, ignored", err)
   228  	}
   229  
   230  	if resp.Body != nil {
   231  		io.Copy(io.Discard, resp.Body)
   232  		resp.Body.Close()
   233  	}
   234  }