github.com/airamrguez/cors@v0.0.0-20140622164850-801f09012ccd/cors_test.go (about)

     1  // Copyright 2014 Google Inc. All Rights Reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package cors
    16  
    17  import (
    18  	"net/http"
    19  	"net/http/httptest"
    20  	"strings"
    21  	"testing"
    22  	"time"
    23  
    24  	"github.com/go-martini/martini"
    25  )
    26  
    27  
    28  type HttpHeaderGuardRecorder struct {
    29  	*httptest.ResponseRecorder
    30  	savedHeaderMap http.Header
    31  }
    32  
    33  
    34  func NewRecorder() *HttpHeaderGuardRecorder {
    35  	return &HttpHeaderGuardRecorder{httptest.NewRecorder(), nil}
    36  }
    37  
    38  
    39  func (gr* HttpHeaderGuardRecorder) WriteHeader(code int) {
    40  	gr.ResponseRecorder.WriteHeader(code)
    41  	gr.savedHeaderMap = gr.ResponseRecorder.Header()
    42  }
    43  
    44  
    45  func (gr* HttpHeaderGuardRecorder) Header() http.Header {
    46  	if gr.savedHeaderMap != nil {
    47  		// headers were written. clone so we don't get updates
    48  		clone := make(http.Header)
    49  		for k, v := range gr.savedHeaderMap {
    50  			clone[k] = v
    51  		}
    52  		return clone
    53  	} else {
    54  		return gr.ResponseRecorder.Header()
    55  	}
    56  }
    57  
    58  
    59  
    60  func Test_AllowAll(t *testing.T) {
    61  	recorder := httptest.NewRecorder()
    62  	m := martini.New()
    63  	m.Use(Allow(&Options{
    64  		AllowAllOrigins: true,
    65  	}))
    66  
    67  	r, _ := http.NewRequest("PUT", "foo", nil)
    68  	m.ServeHTTP(recorder, r)
    69  
    70  	if recorder.HeaderMap.Get(headerAllowOrigin) != "*" {
    71  		t.Errorf("Allow-Origin header should be *")
    72  	}
    73  }
    74  
    75  func Test_AllowRegexMatch(t *testing.T) {
    76  	recorder := httptest.NewRecorder()
    77  	m := martini.New()
    78  	m.Use(Allow(&Options{
    79  		AllowOrigins: []string{"https://aaa.com", "https://*.foo.com"},
    80  	}))
    81  
    82  	origin := "https://bar.foo.com"
    83  	r, _ := http.NewRequest("PUT", "foo", nil)
    84  	r.Header.Add("Origin", origin)
    85  	m.ServeHTTP(recorder, r)
    86  
    87  	headerValue := recorder.HeaderMap.Get(headerAllowOrigin)
    88  	if headerValue != origin {
    89  		t.Errorf("Allow-Origin header should be %v, found %v", origin, headerValue)
    90  	}
    91  }
    92  
    93  func Test_AllowRegexNoMatch(t *testing.T) {
    94  	recorder := httptest.NewRecorder()
    95  	m := martini.New()
    96  	m.Use(Allow(&Options{
    97  		AllowOrigins: []string{"https://*.foo.com"},
    98  	}))
    99  
   100  	origin := "https://ww.foo.com.evil.com"
   101  	r, _ := http.NewRequest("PUT", "foo", nil)
   102  	r.Header.Add("Origin", origin)
   103  	m.ServeHTTP(recorder, r)
   104  
   105  	headerValue := recorder.HeaderMap.Get(headerAllowOrigin)
   106  	if headerValue != "" {
   107  		t.Errorf("Allow-Origin header should not exist, found %v", headerValue)
   108  	}
   109  }
   110  
   111  func Test_OtherHeaders(t *testing.T) {
   112  	recorder := httptest.NewRecorder()
   113  	m := martini.New()
   114  	m.Use(Allow(&Options{
   115  		AllowAllOrigins:  true,
   116  		AllowCredentials: true,
   117  		AllowMethods:     []string{"PATCH", "GET"},
   118  		AllowHeaders:     []string{"Origin", "X-whatever"},
   119  		ExposeHeaders:    []string{"Content-Length", "Hello"},
   120  		MaxAge:           5 * time.Minute,
   121  	}))
   122  
   123  	r, _ := http.NewRequest("PUT", "foo", nil)
   124  	m.ServeHTTP(recorder, r)
   125  
   126  	credentialsVal := recorder.HeaderMap.Get(headerAllowCredentials)
   127  	methodsVal := recorder.HeaderMap.Get(headerAllowMethods)
   128  	headersVal := recorder.HeaderMap.Get(headerAllowHeaders)
   129  	exposedHeadersVal := recorder.HeaderMap.Get(headerExposeHeaders)
   130  	maxAgeVal := recorder.HeaderMap.Get(headerMaxAge)
   131  
   132  	if credentialsVal != "true" {
   133  		t.Errorf("Allow-Credentials is expected to be true, found %v", credentialsVal)
   134  	}
   135  
   136  	if methodsVal != "PATCH,GET" {
   137  		t.Errorf("Allow-Methods is expected to be PATCH,GET; found %v", methodsVal)
   138  	}
   139  
   140  	if headersVal != "Origin,X-whatever" {
   141  		t.Errorf("Allow-Headers is expected to be Origin,X-whatever; found %v", headersVal)
   142  	}
   143  
   144  	if exposedHeadersVal != "Content-Length,Hello" {
   145  		t.Errorf("Expose-Headers are expected to be Content-Length,Hello. Found %v", exposedHeadersVal)
   146  	}
   147  
   148  	if maxAgeVal != "300" {
   149  		t.Errorf("Max-Age is expected to be 300, found %v", maxAgeVal)
   150  	}
   151  }
   152  
   153  func Test_DefaultAllowHeaders(t *testing.T) {
   154  	recorder := httptest.NewRecorder()
   155  	m := martini.New()
   156  	m.Use(Allow(&Options{
   157  		AllowAllOrigins: true,
   158  	}))
   159  
   160  	r, _ := http.NewRequest("PUT", "foo", nil)
   161  	m.ServeHTTP(recorder, r)
   162  
   163  	headersVal := recorder.HeaderMap.Get(headerAllowHeaders)
   164  	if headersVal != "Origin,Accept,Content-Type,Authorization" {
   165  		t.Errorf("Allow-Headers is expected to be Origin,Accept,Content-Type,Authorization; found %v", headersVal)
   166  	}
   167  }
   168  
   169  func Test_Preflight(t *testing.T) {
   170  	recorder := NewRecorder()
   171  	m := martini.Classic()
   172  	m.Use(Allow(&Options{
   173  		AllowAllOrigins: true,
   174  		AllowMethods:    []string{"PUT", "PATCH"},
   175  		AllowHeaders:    []string{"Origin", "X-whatever", "X-CaseSensitive"},
   176  	}))
   177  
   178  	m.Options("foo", func(res http.ResponseWriter) {
   179  		res.WriteHeader(500)
   180  	})
   181  
   182  	r, _ := http.NewRequest("OPTIONS", "foo", nil)
   183  	r.Header.Add(headerRequestMethod, "PUT")
   184  	r.Header.Add(headerRequestHeaders, "X-whatever, x-casesensitive")
   185  	m.ServeHTTP(recorder, r)
   186  
   187  	headers := recorder.Header()
   188  	methodsVal := headers.Get(headerAllowMethods)
   189  	headersVal := headers.Get(headerAllowHeaders)
   190  	originVal := headers.Get(headerAllowOrigin)
   191  
   192  	if methodsVal != "PUT,PATCH" {
   193  		t.Errorf("Allow-Methods is expected to be PUT,PATCH, found %v", methodsVal)
   194  	}
   195  
   196  	if !strings.Contains(headersVal, "X-whatever") {
   197  		t.Errorf("Allow-Headers is expected to contain X-whatever, found %v", headersVal)
   198  	}
   199  
   200  	if !strings.Contains(headersVal, "x-casesensitive") {
   201  		t.Errorf("Allow-Headers is expected to contain x-casesensitive, found %v", headersVal)
   202  	}
   203  
   204  	if originVal != "*" {
   205  		t.Errorf("Allow-Origin is expected to be *, found %v", originVal)
   206  	}
   207  
   208  	if recorder.Code != http.StatusOK {
   209  		t.Errorf("Status code is expected to be 200, found %d", recorder.Code)
   210  	}
   211  }
   212  
   213  func Benchmark_WithoutCORS(b *testing.B) {
   214  	recorder := httptest.NewRecorder()
   215  	m := martini.New()
   216  
   217  	b.ResetTimer()
   218  	for i := 0; i < 100; i++ {
   219  		r, _ := http.NewRequest("PUT", "foo", nil)
   220  		m.ServeHTTP(recorder, r)
   221  	}
   222  }
   223  
   224  func Benchmark_WithCORS(b *testing.B) {
   225  	recorder := httptest.NewRecorder()
   226  	m := martini.New()
   227  	m.Use(Allow(&Options{
   228  		AllowAllOrigins:  true,
   229  		AllowCredentials: true,
   230  		AllowMethods:     []string{"PATCH", "GET"},
   231  		AllowHeaders:     []string{"Origin", "X-whatever"},
   232  		MaxAge:           5 * time.Minute,
   233  	}))
   234  
   235  	b.ResetTimer()
   236  	for i := 0; i < 100; i++ {
   237  		r, _ := http.NewRequest("PUT", "foo", nil)
   238  		m.ServeHTTP(recorder, r)
   239  	}
   240  }