github.com/sohaha/zlsgo@v1.7.13-0.20240501141223-10dd1a906f76/znet/cors/cors_test.go (about)

     1  package cors_test
     2  
     3  import (
     4  	"net/http"
     5  	"net/http/httptest"
     6  	"testing"
     7  
     8  	zls "github.com/sohaha/zlsgo"
     9  	"github.com/sohaha/zlsgo/znet"
    10  	"github.com/sohaha/zlsgo/znet/cors"
    11  	"github.com/sohaha/zlsgo/zstring"
    12  )
    13  
    14  func TestNewAllowHeaders(t *testing.T) {
    15  	tt := zls.NewTest(t)
    16  
    17  	r := znet.New("TestNewAllowHeaders")
    18  	r.SetMode(znet.ProdMode)
    19  
    20  	addAllowHeader, h := cors.NewAllowHeaders()
    21  	r.Use(h)
    22  	r.GET("/TestNewAllowHeaders", func(c *znet.Context) {
    23  		c.Log.Debug("ok")
    24  		c.String(200, zstring.Rand(10, "abc"))
    25  	})
    26  	addAllowHeader("AllowTest")
    27  	w := httptest.NewRecorder()
    28  	req, _ := http.NewRequest("OPTIONS", "/TestNewAllowHeaders", nil)
    29  	req.Header.Add("AllowTest", "https://qq.com")
    30  	req.Header.Add("Origin", "https://qq.com")
    31  	r.ServeHTTP(w, req)
    32  	tt.Equal(http.StatusNoContent, w.Code)
    33  	tt.Equal(0, w.Body.Len())
    34  
    35  	w = httptest.NewRecorder()
    36  	req, _ = http.NewRequest("GET", "/TestNewAllowHeaders", nil)
    37  	req.Header.Add("AllowTest", "https://qq.com")
    38  	req.Header.Add("Origin", "https://qq.com")
    39  	r.ServeHTTP(w, req)
    40  	tt.Equal(http.StatusOK, w.Code)
    41  	tt.Equal(10, w.Body.Len())
    42  }
    43  
    44  func TestDefault(t *testing.T) {
    45  	tt := zls.NewTest(t)
    46  
    47  	r := znet.New("TestDefault")
    48  	r.SetMode(znet.ProdMode)
    49  
    50  	r.Any("/cors", func(c *znet.Context) {
    51  		c.String(200, zstring.Rand(10, "abc"))
    52  	}, cors.Default())
    53  	w := httptest.NewRecorder()
    54  	req, _ := http.NewRequest("OPTIONS", "/cors", nil)
    55  	req.Header.Add("Origin", "https://qq.com")
    56  	req.Host = "baidu.com"
    57  	r.ServeHTTP(w, req)
    58  	tt.Equal(http.StatusNoContent, w.Code)
    59  	tt.Equal(0, w.Body.Len())
    60  
    61  	r.Any("/cors2", func(c *znet.Context) {
    62  		c.String(200, zstring.Rand(10, "abc"))
    63  	}, cors.New(&cors.Config{Domains: []string{"*://?q.com"}}))
    64  	w = httptest.NewRecorder()
    65  	req, _ = http.NewRequest("OPTIONS", "/cors2", nil)
    66  	req.Header.Add("Origin", "https://qq.com")
    67  	req.Host = "baidu.com"
    68  	r.ServeHTTP(w, req)
    69  	tt.Equal(http.StatusNoContent, w.Code)
    70  	tt.Equal(0, w.Body.Len())
    71  
    72  	r.Any("/cors3", func(c *znet.Context) {
    73  		c.String(200, zstring.Rand(10, "abc"))
    74  	}, cors.New(&cors.Config{Domains: []string{"*://?q.com"}}))
    75  	w = httptest.NewRecorder()
    76  	req, _ = http.NewRequest("OPTIONS", "/cors3", nil)
    77  	req.Header.Add("Origin", "https://qa.com")
    78  	req.Host = "baidu.com"
    79  	r.ServeHTTP(w, req)
    80  	tt.Equal(http.StatusForbidden, w.Code)
    81  	tt.Equal(0, w.Body.Len())
    82  
    83  	r.Any("/cors3", func(c *znet.Context) {
    84  		c.String(200, zstring.Rand(10, "abc"))
    85  	}, cors.New(&cors.Config{Domains: []string{"*://?q.com"}, CustomHandler: func(conf *cors.Config, c *znet.Context) {
    86  		c.Log.Debug(conf.Headers)
    87  	}}))
    88  	w = httptest.NewRecorder()
    89  	req, _ = http.NewRequest("OPTIONS", "/cors3", nil)
    90  	req.Header.Add("Origin", "https://qa.com")
    91  	req.Host = "baidu.com"
    92  	r.ServeHTTP(w, req)
    93  	tt.Equal(http.StatusForbidden, w.Code)
    94  	tt.Equal(0, w.Body.Len())
    95  
    96  }