github.com/astaxie/beego@v1.12.3/plugins/cors/cors_test.go (about)

     1  // Copyright 2014 beego Author. All Rights Reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package cors
    16  
    17  import (
    18  	"net/http"
    19  	"net/http/httptest"
    20  	"strings"
    21  	"testing"
    22  	"time"
    23  
    24  	"github.com/astaxie/beego"
    25  	"github.com/astaxie/beego/context"
    26  )
    27  
    28  // HTTPHeaderGuardRecorder is httptest.ResponseRecorder with own http.Header
    29  type HTTPHeaderGuardRecorder struct {
    30  	*httptest.ResponseRecorder
    31  	savedHeaderMap http.Header
    32  }
    33  
    34  // NewRecorder return HttpHeaderGuardRecorder
    35  func NewRecorder() *HTTPHeaderGuardRecorder {
    36  	return &HTTPHeaderGuardRecorder{httptest.NewRecorder(), nil}
    37  }
    38  
    39  func (gr *HTTPHeaderGuardRecorder) WriteHeader(code int) {
    40  	gr.ResponseRecorder.WriteHeader(code)
    41  	gr.savedHeaderMap = gr.ResponseRecorder.Header()
    42  }
    43  
    44  func (gr *HTTPHeaderGuardRecorder) Header() http.Header {
    45  	if gr.savedHeaderMap != nil {
    46  		// headers were written. clone so we don't get updates
    47  		clone := make(http.Header)
    48  		for k, v := range gr.savedHeaderMap {
    49  			clone[k] = v
    50  		}
    51  		return clone
    52  	}
    53  	return gr.ResponseRecorder.Header()
    54  }
    55  
    56  func Test_AllowAll(t *testing.T) {
    57  	recorder := httptest.NewRecorder()
    58  	handler := beego.NewControllerRegister()
    59  	handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{
    60  		AllowAllOrigins: true,
    61  	}))
    62  	handler.Any("/foo", func(ctx *context.Context) {
    63  		ctx.Output.SetStatus(500)
    64  	})
    65  	r, _ := http.NewRequest("PUT", "/foo", nil)
    66  	handler.ServeHTTP(recorder, r)
    67  
    68  	if recorder.HeaderMap.Get(headerAllowOrigin) != "*" {
    69  		t.Errorf("Allow-Origin header should be *")
    70  	}
    71  }
    72  
    73  func Test_AllowRegexMatch(t *testing.T) {
    74  	recorder := httptest.NewRecorder()
    75  	handler := beego.NewControllerRegister()
    76  	handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{
    77  		AllowOrigins: []string{"https://aaa.com", "https://*.foo.com"},
    78  	}))
    79  	handler.Any("/foo", func(ctx *context.Context) {
    80  		ctx.Output.SetStatus(500)
    81  	})
    82  	origin := "https://bar.foo.com"
    83  	r, _ := http.NewRequest("PUT", "/foo", nil)
    84  	r.Header.Add("Origin", origin)
    85  	handler.ServeHTTP(recorder, r)
    86  
    87  	headerValue := recorder.HeaderMap.Get(headerAllowOrigin)
    88  	if headerValue != origin {
    89  		t.Errorf("Allow-Origin header should be %v, found %v", origin, headerValue)
    90  	}
    91  }
    92  
    93  func Test_AllowRegexNoMatch(t *testing.T) {
    94  	recorder := httptest.NewRecorder()
    95  	handler := beego.NewControllerRegister()
    96  	handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{
    97  		AllowOrigins: []string{"https://*.foo.com"},
    98  	}))
    99  	handler.Any("/foo", func(ctx *context.Context) {
   100  		ctx.Output.SetStatus(500)
   101  	})
   102  	origin := "https://ww.foo.com.evil.com"
   103  	r, _ := http.NewRequest("PUT", "/foo", nil)
   104  	r.Header.Add("Origin", origin)
   105  	handler.ServeHTTP(recorder, r)
   106  
   107  	headerValue := recorder.HeaderMap.Get(headerAllowOrigin)
   108  	if headerValue != "" {
   109  		t.Errorf("Allow-Origin header should not exist, found %v", headerValue)
   110  	}
   111  }
   112  
   113  func Test_OtherHeaders(t *testing.T) {
   114  	recorder := httptest.NewRecorder()
   115  	handler := beego.NewControllerRegister()
   116  	handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{
   117  		AllowAllOrigins:  true,
   118  		AllowCredentials: true,
   119  		AllowMethods:     []string{"PATCH", "GET"},
   120  		AllowHeaders:     []string{"Origin", "X-whatever"},
   121  		ExposeHeaders:    []string{"Content-Length", "Hello"},
   122  		MaxAge:           5 * time.Minute,
   123  	}))
   124  	handler.Any("/foo", func(ctx *context.Context) {
   125  		ctx.Output.SetStatus(500)
   126  	})
   127  	r, _ := http.NewRequest("PUT", "/foo", nil)
   128  	handler.ServeHTTP(recorder, r)
   129  
   130  	credentialsVal := recorder.HeaderMap.Get(headerAllowCredentials)
   131  	methodsVal := recorder.HeaderMap.Get(headerAllowMethods)
   132  	headersVal := recorder.HeaderMap.Get(headerAllowHeaders)
   133  	exposedHeadersVal := recorder.HeaderMap.Get(headerExposeHeaders)
   134  	maxAgeVal := recorder.HeaderMap.Get(headerMaxAge)
   135  
   136  	if credentialsVal != "true" {
   137  		t.Errorf("Allow-Credentials is expected to be true, found %v", credentialsVal)
   138  	}
   139  
   140  	if methodsVal != "PATCH,GET" {
   141  		t.Errorf("Allow-Methods is expected to be PATCH,GET; found %v", methodsVal)
   142  	}
   143  
   144  	if headersVal != "Origin,X-whatever" {
   145  		t.Errorf("Allow-Headers is expected to be Origin,X-whatever; found %v", headersVal)
   146  	}
   147  
   148  	if exposedHeadersVal != "Content-Length,Hello" {
   149  		t.Errorf("Expose-Headers are expected to be Content-Length,Hello. Found %v", exposedHeadersVal)
   150  	}
   151  
   152  	if maxAgeVal != "300" {
   153  		t.Errorf("Max-Age is expected to be 300, found %v", maxAgeVal)
   154  	}
   155  }
   156  
   157  func Test_DefaultAllowHeaders(t *testing.T) {
   158  	recorder := httptest.NewRecorder()
   159  	handler := beego.NewControllerRegister()
   160  	handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{
   161  		AllowAllOrigins: true,
   162  	}))
   163  	handler.Any("/foo", func(ctx *context.Context) {
   164  		ctx.Output.SetStatus(500)
   165  	})
   166  
   167  	r, _ := http.NewRequest("PUT", "/foo", nil)
   168  	handler.ServeHTTP(recorder, r)
   169  
   170  	headersVal := recorder.HeaderMap.Get(headerAllowHeaders)
   171  	if headersVal != "Origin,Accept,Content-Type,Authorization" {
   172  		t.Errorf("Allow-Headers is expected to be Origin,Accept,Content-Type,Authorization; found %v", headersVal)
   173  	}
   174  }
   175  
   176  func Test_Preflight(t *testing.T) {
   177  	recorder := NewRecorder()
   178  	handler := beego.NewControllerRegister()
   179  	handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{
   180  		AllowAllOrigins: true,
   181  		AllowMethods:    []string{"PUT", "PATCH"},
   182  		AllowHeaders:    []string{"Origin", "X-whatever", "X-CaseSensitive"},
   183  	}))
   184  
   185  	handler.Any("/foo", func(ctx *context.Context) {
   186  		ctx.Output.SetStatus(200)
   187  	})
   188  
   189  	r, _ := http.NewRequest("OPTIONS", "/foo", nil)
   190  	r.Header.Add(headerRequestMethod, "PUT")
   191  	r.Header.Add(headerRequestHeaders, "X-whatever, x-casesensitive")
   192  	handler.ServeHTTP(recorder, r)
   193  
   194  	headers := recorder.Header()
   195  	methodsVal := headers.Get(headerAllowMethods)
   196  	headersVal := headers.Get(headerAllowHeaders)
   197  	originVal := headers.Get(headerAllowOrigin)
   198  
   199  	if methodsVal != "PUT,PATCH" {
   200  		t.Errorf("Allow-Methods is expected to be PUT,PATCH, found %v", methodsVal)
   201  	}
   202  
   203  	if !strings.Contains(headersVal, "X-whatever") {
   204  		t.Errorf("Allow-Headers is expected to contain X-whatever, found %v", headersVal)
   205  	}
   206  
   207  	if !strings.Contains(headersVal, "x-casesensitive") {
   208  		t.Errorf("Allow-Headers is expected to contain x-casesensitive, found %v", headersVal)
   209  	}
   210  
   211  	if originVal != "*" {
   212  		t.Errorf("Allow-Origin is expected to be *, found %v", originVal)
   213  	}
   214  
   215  	if recorder.Code != http.StatusOK {
   216  		t.Errorf("Status code is expected to be 200, found %d", recorder.Code)
   217  	}
   218  }
   219  
   220  func Benchmark_WithoutCORS(b *testing.B) {
   221  	recorder := httptest.NewRecorder()
   222  	handler := beego.NewControllerRegister()
   223  	beego.BConfig.RunMode = beego.PROD
   224  	handler.Any("/foo", func(ctx *context.Context) {
   225  		ctx.Output.SetStatus(500)
   226  	})
   227  	b.ResetTimer()
   228  	r, _ := http.NewRequest("PUT", "/foo", nil)
   229  	for i := 0; i < b.N; i++ {
   230  		handler.ServeHTTP(recorder, r)
   231  	}
   232  }
   233  
   234  func Benchmark_WithCORS(b *testing.B) {
   235  	recorder := httptest.NewRecorder()
   236  	handler := beego.NewControllerRegister()
   237  	beego.BConfig.RunMode = beego.PROD
   238  	handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{
   239  		AllowAllOrigins:  true,
   240  		AllowCredentials: true,
   241  		AllowMethods:     []string{"PATCH", "GET"},
   242  		AllowHeaders:     []string{"Origin", "X-whatever"},
   243  		MaxAge:           5 * time.Minute,
   244  	}))
   245  	handler.Any("/foo", func(ctx *context.Context) {
   246  		ctx.Output.SetStatus(500)
   247  	})
   248  	b.ResetTimer()
   249  	r, _ := http.NewRequest("PUT", "/foo", nil)
   250  	for i := 0; i < b.N; i++ {
   251  		handler.ServeHTTP(recorder, r)
   252  	}
   253  }