github.com/calendreco/cors@v0.0.0-20140529192254-d98a0ee7dc61/cors_test.go (about)

     1  // Copyright 2014 Google Inc. All Rights Reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package cors
    16  
    17  import (
    18  	"net/http"
    19  	"net/http/httptest"
    20  	"strings"
    21  	"testing"
    22  	"time"
    23  
    24  	"github.com/go-martini/martini"
    25  )
    26  
    27  func Test_AllowAll(t *testing.T) {
    28  	recorder := httptest.NewRecorder()
    29  	m := martini.New()
    30  	m.Use(Allow(&Options{
    31  		AllowAllOrigins: true,
    32  	}))
    33  
    34  	r, _ := http.NewRequest("PUT", "foo", nil)
    35  	m.ServeHTTP(recorder, r)
    36  
    37  	if recorder.HeaderMap.Get(headerAllowOrigin) != "*" {
    38  		t.Errorf("Allow-Origin header should be *")
    39  	}
    40  }
    41  
    42  func Test_AllowRegexMatch(t *testing.T) {
    43  	recorder := httptest.NewRecorder()
    44  	m := martini.New()
    45  	m.Use(Allow(&Options{
    46  		AllowOrigins: []string{"https://aaa.com", "https://*.foo.com"},
    47  	}))
    48  
    49  	origin := "https://bar.foo.com"
    50  	r, _ := http.NewRequest("PUT", "foo", nil)
    51  	r.Header.Add("Origin", origin)
    52  	m.ServeHTTP(recorder, r)
    53  
    54  	headerValue := recorder.HeaderMap.Get(headerAllowOrigin)
    55  	if headerValue != origin {
    56  		t.Errorf("Allow-Origin header should be %v, found %v", origin, headerValue)
    57  	}
    58  }
    59  
    60  func Test_AllowRegexNoMatch(t *testing.T) {
    61  	recorder := httptest.NewRecorder()
    62  	m := martini.New()
    63  	m.Use(Allow(&Options{
    64  		AllowOrigins: []string{"https://*.foo.com"},
    65  	}))
    66  
    67  	origin := "https://ww.foo.com.evil.com"
    68  	r, _ := http.NewRequest("PUT", "foo", nil)
    69  	r.Header.Add("Origin", origin)
    70  	m.ServeHTTP(recorder, r)
    71  
    72  	headerValue := recorder.HeaderMap.Get(headerAllowOrigin)
    73  	if headerValue != "" {
    74  		t.Errorf("Allow-Origin header should not exist, found %v", headerValue)
    75  	}
    76  }
    77  
    78  func Test_OtherHeaders(t *testing.T) {
    79  	recorder := httptest.NewRecorder()
    80  	m := martini.New()
    81  	m.Use(Allow(&Options{
    82  		AllowAllOrigins:  true,
    83  		AllowCredentials: true,
    84  		AllowMethods:     []string{"PATCH", "GET"},
    85  		AllowHeaders:     []string{"Origin", "X-whatever"},
    86  		ExposeHeaders:    []string{"Content-Length", "Hello"},
    87  		MaxAge:           5 * time.Minute,
    88  	}))
    89  
    90  	r, _ := http.NewRequest("PUT", "foo", nil)
    91  	m.ServeHTTP(recorder, r)
    92  
    93  	credentialsVal := recorder.HeaderMap.Get(headerAllowCredentials)
    94  	methodsVal := recorder.HeaderMap.Get(headerAllowMethods)
    95  	headersVal := recorder.HeaderMap.Get(headerAllowHeaders)
    96  	exposedHeadersVal := recorder.HeaderMap.Get(headerExposeHeaders)
    97  	maxAgeVal := recorder.HeaderMap.Get(headerMaxAge)
    98  
    99  	if credentialsVal != "true" {
   100  		t.Errorf("Allow-Credentials is expected to be true, found %v", credentialsVal)
   101  	}
   102  
   103  	if methodsVal != "PATCH,GET" {
   104  		t.Errorf("Allow-Methods is expected to be PATCH,GET; found %v", methodsVal)
   105  	}
   106  
   107  	if headersVal != "Origin,X-whatever" {
   108  		t.Errorf("Allow-Headers is expected to be Origin,X-whatever; found %v", headersVal)
   109  	}
   110  
   111  	if exposedHeadersVal != "Content-Length,Hello" {
   112  		t.Errorf("Expose-Headers are expected to be Content-Length,Hello. Found %v", exposedHeadersVal)
   113  	}
   114  
   115  	if maxAgeVal != "300" {
   116  		t.Errorf("Max-Age is expected to be 300, found %v", maxAgeVal)
   117  	}
   118  }
   119  
   120  func Test_DefaultAllowHeaders(t *testing.T) {
   121  	recorder := httptest.NewRecorder()
   122  	m := martini.New()
   123  	m.Use(Allow(&Options{
   124  		AllowAllOrigins: true,
   125  	}))
   126  
   127  	r, _ := http.NewRequest("PUT", "foo", nil)
   128  	m.ServeHTTP(recorder, r)
   129  
   130  	headersVal := recorder.HeaderMap.Get(headerAllowHeaders)
   131  	if headersVal != "Origin,Accept,Content-Type,Authorization" {
   132  		t.Errorf("Allow-Headers is expected to be Origin,Accept,Content-Type,Authorization; found %v", headersVal)
   133  	}
   134  }
   135  
   136  func Test_Preflight(t *testing.T) {
   137  	recorder := httptest.NewRecorder()
   138  	m := martini.Classic()
   139  	m.Use(Allow(&Options{
   140  		AllowAllOrigins: true,
   141  		AllowMethods:    []string{"PUT", "PATCH"},
   142  		AllowHeaders:    []string{"Origin", "X-whatever", "X-CaseSensitive"},
   143  	}))
   144  
   145  	m.Options("foo", func(res http.ResponseWriter) {
   146  		res.WriteHeader(500)
   147  	})
   148  
   149  	r, _ := http.NewRequest("OPTIONS", "foo", nil)
   150  	r.Header.Add(headerRequestMethod, "PUT")
   151  	r.Header.Add(headerRequestHeaders, "X-whatever, x-casesensitive")
   152  	m.ServeHTTP(recorder, r)
   153  
   154  	methodsVal := recorder.HeaderMap.Get(headerAllowMethods)
   155  	headersVal := recorder.HeaderMap.Get(headerAllowHeaders)
   156  	originVal := recorder.HeaderMap.Get(headerAllowOrigin)
   157  
   158  	if methodsVal != "PUT,PATCH" {
   159  		t.Errorf("Allow-Methods is expected to be PUT,PATCH, found %v", methodsVal)
   160  	}
   161  
   162  	if !strings.Contains(headersVal, "X-whatever") {
   163  		t.Errorf("Allow-Headers is expected to contain X-whatever, found %v", headersVal)
   164  	}
   165  
   166  	if !strings.Contains(headersVal, "x-casesensitive") {
   167  		t.Errorf("Allow-Headers is expected to contain x-casesensitive, found %v", headersVal)
   168  	}
   169  
   170  	if originVal != "*" {
   171  		t.Errorf("Allow-Origin is expected to be *, found %v", originVal)
   172  	}
   173  
   174  	if recorder.Code != http.StatusOK {
   175  		t.Errorf("Status code is expected to be 200, found %d", recorder.Code)
   176  	}
   177  }
   178  
   179  func Benchmark_WithoutCORS(b *testing.B) {
   180  	recorder := httptest.NewRecorder()
   181  	m := martini.New()
   182  
   183  	b.ResetTimer()
   184  	for i := 0; i < 100; i++ {
   185  		r, _ := http.NewRequest("PUT", "foo", nil)
   186  		m.ServeHTTP(recorder, r)
   187  	}
   188  }
   189  
   190  func Benchmark_WithCORS(b *testing.B) {
   191  	recorder := httptest.NewRecorder()
   192  	m := martini.New()
   193  	m.Use(Allow(&Options{
   194  		AllowAllOrigins:  true,
   195  		AllowCredentials: true,
   196  		AllowMethods:     []string{"PATCH", "GET"},
   197  		AllowHeaders:     []string{"Origin", "X-whatever"},
   198  		MaxAge:           5 * time.Minute,
   199  	}))
   200  
   201  	b.ResetTimer()
   202  	for i := 0; i < 100; i++ {
   203  		r, _ := http.NewRequest("PUT", "foo", nil)
   204  		m.ServeHTTP(recorder, r)
   205  	}
   206  }