github.com/lingyao2333/mo-zero@v1.4.1/rest/internal/cors/handlers_test.go (about) 1 package cors 2 3 import ( 4 "net/http" 5 "net/http/httptest" 6 "testing" 7 8 "github.com/stretchr/testify/assert" 9 ) 10 11 func TestCorsHandlerWithOrigins(t *testing.T) { 12 tests := []struct { 13 name string 14 origins []string 15 reqOrigin string 16 expect string 17 }{ 18 { 19 name: "allow all origins", 20 expect: allOrigins, 21 }, 22 { 23 name: "allow one origin", 24 origins: []string{"http://local"}, 25 reqOrigin: "http://local", 26 expect: "http://local", 27 }, 28 { 29 name: "allow many origins", 30 origins: []string{"http://local", "http://remote"}, 31 reqOrigin: "http://local", 32 expect: "http://local", 33 }, 34 { 35 name: "allow sub origins", 36 origins: []string{"local", "remote"}, 37 reqOrigin: "sub.local", 38 expect: "sub.local", 39 }, 40 { 41 name: "allow all origins", 42 reqOrigin: "http://local", 43 expect: "*", 44 }, 45 { 46 name: "allow many origins with all mark", 47 origins: []string{"http://local", "http://remote", "*"}, 48 reqOrigin: "http://another", 49 expect: "http://another", 50 }, 51 { 52 name: "not allow origin", 53 origins: []string{"http://local", "http://remote"}, 54 reqOrigin: "http://another", 55 }, 56 } 57 58 methods := []string{ 59 http.MethodOptions, 60 http.MethodGet, 61 http.MethodPost, 62 } 63 64 for _, test := range tests { 65 for _, method := range methods { 66 test := test 67 t.Run(test.name+"-handler", func(t *testing.T) { 68 r := httptest.NewRequest(method, "http://localhost", http.NoBody) 69 r.Header.Set(originHeader, test.reqOrigin) 70 w := httptest.NewRecorder() 71 handler := NotAllowedHandler(nil, test.origins...) 72 handler.ServeHTTP(w, r) 73 if method == http.MethodOptions { 74 assert.Equal(t, http.StatusNoContent, w.Result().StatusCode) 75 } else { 76 assert.Equal(t, http.StatusNotFound, w.Result().StatusCode) 77 } 78 assert.Equal(t, test.expect, w.Header().Get(allowOrigin)) 79 }) 80 t.Run(test.name+"-handler-custom", func(t *testing.T) { 81 r := httptest.NewRequest(method, "http://localhost", http.NoBody) 82 r.Header.Set(originHeader, test.reqOrigin) 83 w := httptest.NewRecorder() 84 handler := NotAllowedHandler(func(w http.ResponseWriter) { 85 w.Header().Set("foo", "bar") 86 }, test.origins...) 87 handler.ServeHTTP(w, r) 88 if method == http.MethodOptions { 89 assert.Equal(t, http.StatusNoContent, w.Result().StatusCode) 90 } else { 91 assert.Equal(t, http.StatusNotFound, w.Result().StatusCode) 92 } 93 assert.Equal(t, test.expect, w.Header().Get(allowOrigin)) 94 assert.Equal(t, "bar", w.Header().Get("foo")) 95 }) 96 } 97 } 98 99 for _, test := range tests { 100 for _, method := range methods { 101 test := test 102 t.Run(test.name+"-middleware", func(t *testing.T) { 103 r := httptest.NewRequest(method, "http://localhost", http.NoBody) 104 r.Header.Set(originHeader, test.reqOrigin) 105 w := httptest.NewRecorder() 106 handler := Middleware(nil, test.origins...)(func(w http.ResponseWriter, r *http.Request) { 107 w.WriteHeader(http.StatusOK) 108 }) 109 handler.ServeHTTP(w, r) 110 if method == http.MethodOptions { 111 assert.Equal(t, http.StatusNoContent, w.Result().StatusCode) 112 } else { 113 assert.Equal(t, http.StatusOK, w.Result().StatusCode) 114 } 115 assert.Equal(t, test.expect, w.Header().Get(allowOrigin)) 116 }) 117 t.Run(test.name+"-middleware-custom", func(t *testing.T) { 118 r := httptest.NewRequest(method, "http://localhost", http.NoBody) 119 r.Header.Set(originHeader, test.reqOrigin) 120 w := httptest.NewRecorder() 121 handler := Middleware(func(header http.Header) { 122 header.Set("foo", "bar") 123 }, test.origins...)(func(w http.ResponseWriter, r *http.Request) { 124 w.WriteHeader(http.StatusOK) 125 }) 126 handler.ServeHTTP(w, r) 127 if method == http.MethodOptions { 128 assert.Equal(t, http.StatusNoContent, w.Result().StatusCode) 129 } else { 130 assert.Equal(t, http.StatusOK, w.Result().StatusCode) 131 } 132 assert.Equal(t, test.expect, w.Header().Get(allowOrigin)) 133 assert.Equal(t, "bar", w.Header().Get("foo")) 134 }) 135 } 136 } 137 }