github.com/zimmski/negroni-cors@v0.0.0-20210610102725-61d36ce1db64/cors_test.go (about)

     1  // Copyright 2014 Google Inc. All Rights Reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package cors
    16  
    17  import (
    18  	"net/http"
    19  	"net/http/httptest"
    20  	"testing"
    21  	"time"
    22  
    23  	"github.com/urfave/negroni"
    24  )
    25  
    26  type HttpHeaderGuardRecorder struct {
    27  	*httptest.ResponseRecorder
    28  	savedHeaderMap http.Header
    29  }
    30  
    31  func NewRecorder() *HttpHeaderGuardRecorder {
    32  	return &HttpHeaderGuardRecorder{httptest.NewRecorder(), nil}
    33  }
    34  
    35  func (gr *HttpHeaderGuardRecorder) WriteHeader(code int) {
    36  	gr.ResponseRecorder.WriteHeader(code)
    37  	gr.savedHeaderMap = gr.ResponseRecorder.Header()
    38  }
    39  
    40  func (gr *HttpHeaderGuardRecorder) Header() http.Header {
    41  	if gr.savedHeaderMap != nil {
    42  		// headers were written. clone so we don't get updates
    43  		clone := make(http.Header)
    44  		for k, v := range gr.savedHeaderMap {
    45  			clone[k] = v
    46  		}
    47  		return clone
    48  	} else {
    49  		return gr.ResponseRecorder.Header()
    50  	}
    51  }
    52  
    53  func Test_AllowAll(t *testing.T) {
    54  	recorder := httptest.NewRecorder()
    55  	n := negroni.New()
    56  	n.Use(NewAllow(&Options{
    57  		AllowAllOrigins: true,
    58  	}))
    59  
    60  	r, _ := http.NewRequest("PUT", "foo", nil)
    61  	n.ServeHTTP(recorder, r)
    62  
    63  	if recorder.HeaderMap.Get(headerAllowOrigin) != "*" {
    64  		t.Errorf("Allow-Origin header should be *")
    65  	}
    66  }
    67  
    68  func Test_AllowRegexMatch(t *testing.T) {
    69  	recorder := httptest.NewRecorder()
    70  	n := negroni.New()
    71  	n.Use(NewAllow(&Options{
    72  		AllowOrigins: []string{"https://aaa.com", "https://*.foo.com"},
    73  	}))
    74  
    75  	origin := "https://bar.foo.com"
    76  	r, _ := http.NewRequest("PUT", "foo", nil)
    77  	r.Header.Add("Origin", origin)
    78  	n.ServeHTTP(recorder, r)
    79  
    80  	headerValue := recorder.HeaderMap.Get(headerAllowOrigin)
    81  	if headerValue != origin {
    82  		t.Errorf("Allow-Origin header should be %v, found %v", origin, headerValue)
    83  	}
    84  }
    85  
    86  func Test_AllowRegexNoMatch(t *testing.T) {
    87  	recorder := httptest.NewRecorder()
    88  	n := negroni.New()
    89  	n.Use(NewAllow(&Options{
    90  		AllowOrigins: []string{"https://*.foo.com"},
    91  	}))
    92  
    93  	origin := "https://ww.foo.com.evil.com"
    94  	r, _ := http.NewRequest("PUT", "foo", nil)
    95  	r.Header.Add("Origin", origin)
    96  	n.ServeHTTP(recorder, r)
    97  
    98  	headerValue := recorder.HeaderMap.Get(headerAllowOrigin)
    99  	if headerValue != "" {
   100  		t.Errorf("Allow-Origin header should not exist, found %v", headerValue)
   101  	}
   102  }
   103  
   104  func Test_OtherHeaders(t *testing.T) {
   105  	recorder := httptest.NewRecorder()
   106  	n := negroni.New()
   107  	n.Use(NewAllow(&Options{
   108  		AllowAllOrigins:  true,
   109  		AllowCredentials: true,
   110  		AllowMethods:     []string{"PATCH", "GET"},
   111  		AllowHeaders:     []string{"Origin", "X-whatever"},
   112  		ExposeHeaders:    []string{"Content-Length", "Hello"},
   113  		MaxAge:           5 * time.Minute,
   114  	}))
   115  
   116  	r, _ := http.NewRequest("PUT", "foo", nil)
   117  	n.ServeHTTP(recorder, r)
   118  
   119  	credentialsVal := recorder.HeaderMap.Get(headerAllowCredentials)
   120  	methodsVal := recorder.HeaderMap.Get(headerAllowMethods)
   121  	headersVal := recorder.HeaderMap.Get(headerAllowHeaders)
   122  	exposedHeadersVal := recorder.HeaderMap.Get(headerExposeHeaders)
   123  	maxAgeVal := recorder.HeaderMap.Get(headerMaxAge)
   124  
   125  	if credentialsVal != "true" {
   126  		t.Errorf("Allow-Credentials is expected to be true, found %v", credentialsVal)
   127  	}
   128  
   129  	if methodsVal != "PATCH,GET" {
   130  		t.Errorf("Allow-Methods is expected to be PATCH,GET; found %v", methodsVal)
   131  	}
   132  
   133  	if headersVal != "Origin,X-whatever" {
   134  		t.Errorf("Allow-Headers is expected to be Origin,X-whatever; found %v", headersVal)
   135  	}
   136  
   137  	if exposedHeadersVal != "Content-Length,Hello" {
   138  		t.Errorf("Expose-Headers are expected to be Content-Length,Hello. Found %v", exposedHeadersVal)
   139  	}
   140  
   141  	if maxAgeVal != "300" {
   142  		t.Errorf("Max-Age is expected to be 300, found %v", maxAgeVal)
   143  	}
   144  }
   145  
   146  func Test_DefaultAllowHeaders(t *testing.T) {
   147  	recorder := httptest.NewRecorder()
   148  	n := negroni.New()
   149  	n.Use(NewAllow(&Options{
   150  		AllowAllOrigins: true,
   151  	}))
   152  
   153  	r, _ := http.NewRequest("PUT", "foo", nil)
   154  	n.ServeHTTP(recorder, r)
   155  
   156  	headersVal := recorder.HeaderMap.Get(headerAllowHeaders)
   157  	if headersVal != "Origin,Accept,Content-Type,Authorization" {
   158  		t.Errorf("Allow-Headers is expected to be Origin,Accept,Content-Type,Authorization; found %v", headersVal)
   159  	}
   160  }
   161  
   162  /*func Test_Preflight(t *testing.T) {
   163  	recorder := NewRecorder()
   164  	n := negroni.Classic()
   165  	n.Use(NewAllow(&Options{
   166  		AllowAllOrigins: true,
   167  		AllowMethods:    []string{"PUT", "PATCH"},
   168  		AllowHeaders:    []string{"Origin", "X-whatever", "X-CaseSensitive"},
   169  	}))
   170  
   171  	n.Options("foo", func(res http.ResponseWriter) {
   172  		res.WriteHeader(500)
   173  	})
   174  
   175  	r, _ := http.NewRequest("OPTIONS", "foo", nil)
   176  	r.Header.Add(headerRequestMethod, "PUT")
   177  	r.Header.Add(headerRequestHeaders, "X-whatever, x-casesensitive")
   178  	n.ServeHTTP(recorder, r)
   179  
   180  	headers := recorder.Header()
   181  	methodsVal := headers.Get(headerAllowMethods)
   182  	headersVal := headers.Get(headerAllowHeaders)
   183  	originVal := headers.Get(headerAllowOrigin)
   184  
   185  	if methodsVal != "PUT,PATCH" {
   186  		t.Errorf("Allow-Methods is expected to be PUT,PATCH, found %v", methodsVal)
   187  	}
   188  
   189  	if !strings.Contains(headersVal, "X-whatever") {
   190  		t.Errorf("Allow-Headers is expected to contain X-whatever, found %v", headersVal)
   191  	}
   192  
   193  	if !strings.Contains(headersVal, "x-casesensitive") {
   194  		t.Errorf("Allow-Headers is expected to contain x-casesensitive, found %v", headersVal)
   195  	}
   196  
   197  	if originVal != "*" {
   198  		t.Errorf("Allow-Origin is expected to be *, found %v", originVal)
   199  	}
   200  
   201  	if recorder.Code != http.StatusOK {
   202  		t.Errorf("Status code is expected to be 200, found %d", recorder.Code)
   203  	}
   204  }*/
   205  
   206  func Benchmark_WithoutCORS(b *testing.B) {
   207  	recorder := httptest.NewRecorder()
   208  	n := negroni.New()
   209  
   210  	b.ResetTimer()
   211  	for i := 0; i < 100; i++ {
   212  		r, _ := http.NewRequest("PUT", "foo", nil)
   213  		n.ServeHTTP(recorder, r)
   214  	}
   215  }
   216  
   217  func Benchmark_WithCORS(b *testing.B) {
   218  	recorder := httptest.NewRecorder()
   219  	n := negroni.New()
   220  	n.Use(NewAllow(&Options{
   221  		AllowAllOrigins:  true,
   222  		AllowCredentials: true,
   223  		AllowMethods:     []string{"PATCH", "GET"},
   224  		AllowHeaders:     []string{"Origin", "X-whatever"},
   225  		MaxAge:           5 * time.Minute,
   226  	}))
   227  
   228  	b.ResetTimer()
   229  	for i := 0; i < 100; i++ {
   230  		r, _ := http.NewRequest("PUT", "foo", nil)
   231  		n.ServeHTTP(recorder, r)
   232  	}
   233  }