github.com/astaxie/beego@v1.12.3/plugins/cors/cors_test.go (about) 1 // Copyright 2014 beego Author. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package cors 16 17 import ( 18 "net/http" 19 "net/http/httptest" 20 "strings" 21 "testing" 22 "time" 23 24 "github.com/astaxie/beego" 25 "github.com/astaxie/beego/context" 26 ) 27 28 // HTTPHeaderGuardRecorder is httptest.ResponseRecorder with own http.Header 29 type HTTPHeaderGuardRecorder struct { 30 *httptest.ResponseRecorder 31 savedHeaderMap http.Header 32 } 33 34 // NewRecorder return HttpHeaderGuardRecorder 35 func NewRecorder() *HTTPHeaderGuardRecorder { 36 return &HTTPHeaderGuardRecorder{httptest.NewRecorder(), nil} 37 } 38 39 func (gr *HTTPHeaderGuardRecorder) WriteHeader(code int) { 40 gr.ResponseRecorder.WriteHeader(code) 41 gr.savedHeaderMap = gr.ResponseRecorder.Header() 42 } 43 44 func (gr *HTTPHeaderGuardRecorder) Header() http.Header { 45 if gr.savedHeaderMap != nil { 46 // headers were written. clone so we don't get updates 47 clone := make(http.Header) 48 for k, v := range gr.savedHeaderMap { 49 clone[k] = v 50 } 51 return clone 52 } 53 return gr.ResponseRecorder.Header() 54 } 55 56 func Test_AllowAll(t *testing.T) { 57 recorder := httptest.NewRecorder() 58 handler := beego.NewControllerRegister() 59 handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ 60 AllowAllOrigins: true, 61 })) 62 handler.Any("/foo", func(ctx *context.Context) { 63 ctx.Output.SetStatus(500) 64 }) 65 r, _ := http.NewRequest("PUT", "/foo", nil) 66 handler.ServeHTTP(recorder, r) 67 68 if recorder.HeaderMap.Get(headerAllowOrigin) != "*" { 69 t.Errorf("Allow-Origin header should be *") 70 } 71 } 72 73 func Test_AllowRegexMatch(t *testing.T) { 74 recorder := httptest.NewRecorder() 75 handler := beego.NewControllerRegister() 76 handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ 77 AllowOrigins: []string{"https://aaa.com", "https://*.foo.com"}, 78 })) 79 handler.Any("/foo", func(ctx *context.Context) { 80 ctx.Output.SetStatus(500) 81 }) 82 origin := "https://bar.foo.com" 83 r, _ := http.NewRequest("PUT", "/foo", nil) 84 r.Header.Add("Origin", origin) 85 handler.ServeHTTP(recorder, r) 86 87 headerValue := recorder.HeaderMap.Get(headerAllowOrigin) 88 if headerValue != origin { 89 t.Errorf("Allow-Origin header should be %v, found %v", origin, headerValue) 90 } 91 } 92 93 func Test_AllowRegexNoMatch(t *testing.T) { 94 recorder := httptest.NewRecorder() 95 handler := beego.NewControllerRegister() 96 handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ 97 AllowOrigins: []string{"https://*.foo.com"}, 98 })) 99 handler.Any("/foo", func(ctx *context.Context) { 100 ctx.Output.SetStatus(500) 101 }) 102 origin := "https://ww.foo.com.evil.com" 103 r, _ := http.NewRequest("PUT", "/foo", nil) 104 r.Header.Add("Origin", origin) 105 handler.ServeHTTP(recorder, r) 106 107 headerValue := recorder.HeaderMap.Get(headerAllowOrigin) 108 if headerValue != "" { 109 t.Errorf("Allow-Origin header should not exist, found %v", headerValue) 110 } 111 } 112 113 func Test_OtherHeaders(t *testing.T) { 114 recorder := httptest.NewRecorder() 115 handler := beego.NewControllerRegister() 116 handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ 117 AllowAllOrigins: true, 118 AllowCredentials: true, 119 AllowMethods: []string{"PATCH", "GET"}, 120 AllowHeaders: []string{"Origin", "X-whatever"}, 121 ExposeHeaders: []string{"Content-Length", "Hello"}, 122 MaxAge: 5 * time.Minute, 123 })) 124 handler.Any("/foo", func(ctx *context.Context) { 125 ctx.Output.SetStatus(500) 126 }) 127 r, _ := http.NewRequest("PUT", "/foo", nil) 128 handler.ServeHTTP(recorder, r) 129 130 credentialsVal := recorder.HeaderMap.Get(headerAllowCredentials) 131 methodsVal := recorder.HeaderMap.Get(headerAllowMethods) 132 headersVal := recorder.HeaderMap.Get(headerAllowHeaders) 133 exposedHeadersVal := recorder.HeaderMap.Get(headerExposeHeaders) 134 maxAgeVal := recorder.HeaderMap.Get(headerMaxAge) 135 136 if credentialsVal != "true" { 137 t.Errorf("Allow-Credentials is expected to be true, found %v", credentialsVal) 138 } 139 140 if methodsVal != "PATCH,GET" { 141 t.Errorf("Allow-Methods is expected to be PATCH,GET; found %v", methodsVal) 142 } 143 144 if headersVal != "Origin,X-whatever" { 145 t.Errorf("Allow-Headers is expected to be Origin,X-whatever; found %v", headersVal) 146 } 147 148 if exposedHeadersVal != "Content-Length,Hello" { 149 t.Errorf("Expose-Headers are expected to be Content-Length,Hello. Found %v", exposedHeadersVal) 150 } 151 152 if maxAgeVal != "300" { 153 t.Errorf("Max-Age is expected to be 300, found %v", maxAgeVal) 154 } 155 } 156 157 func Test_DefaultAllowHeaders(t *testing.T) { 158 recorder := httptest.NewRecorder() 159 handler := beego.NewControllerRegister() 160 handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ 161 AllowAllOrigins: true, 162 })) 163 handler.Any("/foo", func(ctx *context.Context) { 164 ctx.Output.SetStatus(500) 165 }) 166 167 r, _ := http.NewRequest("PUT", "/foo", nil) 168 handler.ServeHTTP(recorder, r) 169 170 headersVal := recorder.HeaderMap.Get(headerAllowHeaders) 171 if headersVal != "Origin,Accept,Content-Type,Authorization" { 172 t.Errorf("Allow-Headers is expected to be Origin,Accept,Content-Type,Authorization; found %v", headersVal) 173 } 174 } 175 176 func Test_Preflight(t *testing.T) { 177 recorder := NewRecorder() 178 handler := beego.NewControllerRegister() 179 handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ 180 AllowAllOrigins: true, 181 AllowMethods: []string{"PUT", "PATCH"}, 182 AllowHeaders: []string{"Origin", "X-whatever", "X-CaseSensitive"}, 183 })) 184 185 handler.Any("/foo", func(ctx *context.Context) { 186 ctx.Output.SetStatus(200) 187 }) 188 189 r, _ := http.NewRequest("OPTIONS", "/foo", nil) 190 r.Header.Add(headerRequestMethod, "PUT") 191 r.Header.Add(headerRequestHeaders, "X-whatever, x-casesensitive") 192 handler.ServeHTTP(recorder, r) 193 194 headers := recorder.Header() 195 methodsVal := headers.Get(headerAllowMethods) 196 headersVal := headers.Get(headerAllowHeaders) 197 originVal := headers.Get(headerAllowOrigin) 198 199 if methodsVal != "PUT,PATCH" { 200 t.Errorf("Allow-Methods is expected to be PUT,PATCH, found %v", methodsVal) 201 } 202 203 if !strings.Contains(headersVal, "X-whatever") { 204 t.Errorf("Allow-Headers is expected to contain X-whatever, found %v", headersVal) 205 } 206 207 if !strings.Contains(headersVal, "x-casesensitive") { 208 t.Errorf("Allow-Headers is expected to contain x-casesensitive, found %v", headersVal) 209 } 210 211 if originVal != "*" { 212 t.Errorf("Allow-Origin is expected to be *, found %v", originVal) 213 } 214 215 if recorder.Code != http.StatusOK { 216 t.Errorf("Status code is expected to be 200, found %d", recorder.Code) 217 } 218 } 219 220 func Benchmark_WithoutCORS(b *testing.B) { 221 recorder := httptest.NewRecorder() 222 handler := beego.NewControllerRegister() 223 beego.BConfig.RunMode = beego.PROD 224 handler.Any("/foo", func(ctx *context.Context) { 225 ctx.Output.SetStatus(500) 226 }) 227 b.ResetTimer() 228 r, _ := http.NewRequest("PUT", "/foo", nil) 229 for i := 0; i < b.N; i++ { 230 handler.ServeHTTP(recorder, r) 231 } 232 } 233 234 func Benchmark_WithCORS(b *testing.B) { 235 recorder := httptest.NewRecorder() 236 handler := beego.NewControllerRegister() 237 beego.BConfig.RunMode = beego.PROD 238 handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ 239 AllowAllOrigins: true, 240 AllowCredentials: true, 241 AllowMethods: []string{"PATCH", "GET"}, 242 AllowHeaders: []string{"Origin", "X-whatever"}, 243 MaxAge: 5 * time.Minute, 244 })) 245 handler.Any("/foo", func(ctx *context.Context) { 246 ctx.Output.SetStatus(500) 247 }) 248 b.ResetTimer() 249 r, _ := http.NewRequest("PUT", "/foo", nil) 250 for i := 0; i < b.N; i++ { 251 handler.ServeHTTP(recorder, r) 252 } 253 }