github.com/sohaha/zlsgo@v1.7.13-0.20240501141223-10dd1a906f76/znet/cors/cors_test.go (about) 1 package cors_test 2 3 import ( 4 "net/http" 5 "net/http/httptest" 6 "testing" 7 8 zls "github.com/sohaha/zlsgo" 9 "github.com/sohaha/zlsgo/znet" 10 "github.com/sohaha/zlsgo/znet/cors" 11 "github.com/sohaha/zlsgo/zstring" 12 ) 13 14 func TestNewAllowHeaders(t *testing.T) { 15 tt := zls.NewTest(t) 16 17 r := znet.New("TestNewAllowHeaders") 18 r.SetMode(znet.ProdMode) 19 20 addAllowHeader, h := cors.NewAllowHeaders() 21 r.Use(h) 22 r.GET("/TestNewAllowHeaders", func(c *znet.Context) { 23 c.Log.Debug("ok") 24 c.String(200, zstring.Rand(10, "abc")) 25 }) 26 addAllowHeader("AllowTest") 27 w := httptest.NewRecorder() 28 req, _ := http.NewRequest("OPTIONS", "/TestNewAllowHeaders", nil) 29 req.Header.Add("AllowTest", "https://qq.com") 30 req.Header.Add("Origin", "https://qq.com") 31 r.ServeHTTP(w, req) 32 tt.Equal(http.StatusNoContent, w.Code) 33 tt.Equal(0, w.Body.Len()) 34 35 w = httptest.NewRecorder() 36 req, _ = http.NewRequest("GET", "/TestNewAllowHeaders", nil) 37 req.Header.Add("AllowTest", "https://qq.com") 38 req.Header.Add("Origin", "https://qq.com") 39 r.ServeHTTP(w, req) 40 tt.Equal(http.StatusOK, w.Code) 41 tt.Equal(10, w.Body.Len()) 42 } 43 44 func TestDefault(t *testing.T) { 45 tt := zls.NewTest(t) 46 47 r := znet.New("TestDefault") 48 r.SetMode(znet.ProdMode) 49 50 r.Any("/cors", func(c *znet.Context) { 51 c.String(200, zstring.Rand(10, "abc")) 52 }, cors.Default()) 53 w := httptest.NewRecorder() 54 req, _ := http.NewRequest("OPTIONS", "/cors", nil) 55 req.Header.Add("Origin", "https://qq.com") 56 req.Host = "baidu.com" 57 r.ServeHTTP(w, req) 58 tt.Equal(http.StatusNoContent, w.Code) 59 tt.Equal(0, w.Body.Len()) 60 61 r.Any("/cors2", func(c *znet.Context) { 62 c.String(200, zstring.Rand(10, "abc")) 63 }, cors.New(&cors.Config{Domains: []string{"*://?q.com"}})) 64 w = httptest.NewRecorder() 65 req, _ = http.NewRequest("OPTIONS", "/cors2", nil) 66 req.Header.Add("Origin", "https://qq.com") 67 req.Host = "baidu.com" 68 r.ServeHTTP(w, req) 69 tt.Equal(http.StatusNoContent, w.Code) 70 tt.Equal(0, w.Body.Len()) 71 72 r.Any("/cors3", func(c *znet.Context) { 73 c.String(200, zstring.Rand(10, "abc")) 74 }, cors.New(&cors.Config{Domains: []string{"*://?q.com"}})) 75 w = httptest.NewRecorder() 76 req, _ = http.NewRequest("OPTIONS", "/cors3", nil) 77 req.Header.Add("Origin", "https://qa.com") 78 req.Host = "baidu.com" 79 r.ServeHTTP(w, req) 80 tt.Equal(http.StatusForbidden, w.Code) 81 tt.Equal(0, w.Body.Len()) 82 83 r.Any("/cors3", func(c *znet.Context) { 84 c.String(200, zstring.Rand(10, "abc")) 85 }, cors.New(&cors.Config{Domains: []string{"*://?q.com"}, CustomHandler: func(conf *cors.Config, c *znet.Context) { 86 c.Log.Debug(conf.Headers) 87 }})) 88 w = httptest.NewRecorder() 89 req, _ = http.NewRequest("OPTIONS", "/cors3", nil) 90 req.Header.Add("Origin", "https://qa.com") 91 req.Host = "baidu.com" 92 r.ServeHTTP(w, req) 93 tt.Equal(http.StatusForbidden, w.Code) 94 tt.Equal(0, w.Body.Len()) 95 96 }