github.com/m3db/m3@v1.5.0/src/x/net/http/cors/cors_test.go (about)

     1  // Copyright (c) 2018 Uber Technologies, Inc.
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    19  // THE SOFTWARE.
    20  //
    21  // Derived from https://github.com/etcd-io/etcd/tree/v3.2.10/pkg/cors under
    22  // http://www.apache.org/licenses/LICENSE-2.0#redistribution .
    23  // See https://github.com/m3db/m3/blob/master/NOTICES.txt for the original copyright.
    24  
    25  package cors
    26  
    27  import (
    28  	"net/http"
    29  	"net/http/httptest"
    30  	"reflect"
    31  	"testing"
    32  )
    33  
    34  func TestCORSInfo(t *testing.T) {
    35  	tests := []struct {
    36  		s     string
    37  		winfo Info
    38  		ws    string
    39  	}{
    40  		{"", Info{}, ""},
    41  		{"http://127.0.0.1", Info{"http://127.0.0.1": true}, "http://127.0.0.1"},
    42  		{"*", Info{"*": true}, "*"},
    43  		// with space around
    44  		{" http://127.0.0.1 ", Info{"http://127.0.0.1": true}, "http://127.0.0.1"},
    45  		// multiple addrs
    46  		{
    47  			"http://127.0.0.1,http://127.0.0.2",
    48  			Info{"http://127.0.0.1": true, "http://127.0.0.2": true},
    49  			"http://127.0.0.1,http://127.0.0.2",
    50  		},
    51  	}
    52  	for i, tt := range tests {
    53  		info := Info{}
    54  		if err := info.Set(tt.s); err != nil {
    55  			t.Errorf("#%d: set error = %v, want nil", i, err)
    56  		}
    57  		if !reflect.DeepEqual(info, tt.winfo) {
    58  			t.Errorf("#%d: info = %v, want %v", i, info, tt.winfo)
    59  		}
    60  		if g := info.String(); g != tt.ws {
    61  			t.Errorf("#%d: info string = %s, want %s", i, g, tt.ws)
    62  		}
    63  	}
    64  }
    65  
    66  func TestCORSInfoOriginAllowed(t *testing.T) {
    67  	tests := []struct {
    68  		set      string
    69  		origin   string
    70  		wallowed bool
    71  	}{
    72  		{"http://127.0.0.1,http://127.0.0.2", "http://127.0.0.1", true},
    73  		{"http://127.0.0.1,http://127.0.0.2", "http://127.0.0.2", true},
    74  		{"http://127.0.0.1,http://127.0.0.2", "*", false},
    75  		{"http://127.0.0.1,http://127.0.0.2", "http://127.0.0.3", false},
    76  		{"*", "*", true},
    77  		{"*", "http://127.0.0.1", true},
    78  	}
    79  	for i, tt := range tests {
    80  		info := Info{}
    81  		if err := info.Set(tt.set); err != nil {
    82  			t.Errorf("#%d: set error = %v, want nil", i, err)
    83  		}
    84  		if g := info.OriginAllowed(tt.origin); g != tt.wallowed {
    85  			t.Errorf("#%d: allowed = %v, want %v", i, g, tt.wallowed)
    86  		}
    87  	}
    88  }
    89  
    90  func TestCORSHandler(t *testing.T) {
    91  	info := &Info{}
    92  	if err := info.Set("http://127.0.0.1,http://127.0.0.2"); err != nil {
    93  		t.Fatalf("unexpected set error: %v", err)
    94  	}
    95  	h := &Handler{
    96  		Handler: http.NotFoundHandler(),
    97  		Info:    info,
    98  	}
    99  
   100  	header := func(origin string) http.Header {
   101  		return http.Header{
   102  			"Access-Control-Allow-Methods": []string{"POST, GET, OPTIONS, PUT, DELETE"},
   103  			"Access-Control-Allow-Origin":  []string{origin},
   104  			"Access-Control-Allow-Headers": []string{"accept, content-type, authorization"},
   105  		}
   106  	}
   107  	tests := []struct {
   108  		method  string
   109  		origin  string
   110  		wcode   int
   111  		wheader http.Header
   112  	}{
   113  		{"GET", "http://127.0.0.1", http.StatusNotFound, header("http://127.0.0.1")},
   114  		{"GET", "http://127.0.0.2", http.StatusNotFound, header("http://127.0.0.2")},
   115  		{"GET", "http://127.0.0.3", http.StatusNotFound, http.Header{}},
   116  		{"OPTIONS", "http://127.0.0.1", http.StatusOK, header("http://127.0.0.1")},
   117  	}
   118  	for i, tt := range tests {
   119  		rr := httptest.NewRecorder()
   120  		req := &http.Request{
   121  			Method: tt.method,
   122  			Header: http.Header{"Origin": []string{tt.origin}},
   123  		}
   124  		h.ServeHTTP(rr, req)
   125  		if rr.Code != tt.wcode {
   126  			t.Errorf("#%d: code = %v, want %v", i, rr.Code, tt.wcode)
   127  		}
   128  		// it is set by http package, and there is no need to test it
   129  		rr.HeaderMap.Del("Content-Type")
   130  		rr.HeaderMap.Del("X-Content-Type-Options")
   131  		if !reflect.DeepEqual(rr.HeaderMap, tt.wheader) {
   132  			t.Errorf("#%d: header = %+v, want %+v", i, rr.HeaderMap, tt.wheader)
   133  		}
   134  	}
   135  }