go.etcd.io/etcd@v3.3.27+incompatible/pkg/cors/cors_test.go (about)

     1  // Copyright 2015 The etcd Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package cors
    16  
    17  import (
    18  	"net/http"
    19  	"net/http/httptest"
    20  	"reflect"
    21  	"testing"
    22  )
    23  
    24  func TestCORSInfo(t *testing.T) {
    25  	tests := []struct {
    26  		s     string
    27  		winfo CORSInfo
    28  		ws    string
    29  	}{
    30  		{"", CORSInfo{}, ""},
    31  		{"http://127.0.0.1", CORSInfo{"http://127.0.0.1": true}, "http://127.0.0.1"},
    32  		{"*", CORSInfo{"*": true}, "*"},
    33  		// with space around
    34  		{" http://127.0.0.1 ", CORSInfo{"http://127.0.0.1": true}, "http://127.0.0.1"},
    35  		// multiple addrs
    36  		{
    37  			"http://127.0.0.1,http://127.0.0.2",
    38  			CORSInfo{"http://127.0.0.1": true, "http://127.0.0.2": true},
    39  			"http://127.0.0.1,http://127.0.0.2",
    40  		},
    41  	}
    42  	for i, tt := range tests {
    43  		info := CORSInfo{}
    44  		if err := info.Set(tt.s); err != nil {
    45  			t.Errorf("#%d: set error = %v, want nil", i, err)
    46  		}
    47  		if !reflect.DeepEqual(info, tt.winfo) {
    48  			t.Errorf("#%d: info = %v, want %v", i, info, tt.winfo)
    49  		}
    50  		if g := info.String(); g != tt.ws {
    51  			t.Errorf("#%d: info string = %s, want %s", i, g, tt.ws)
    52  		}
    53  	}
    54  }
    55  
    56  func TestCORSInfoOriginAllowed(t *testing.T) {
    57  	tests := []struct {
    58  		set      string
    59  		origin   string
    60  		wallowed bool
    61  	}{
    62  		{"http://127.0.0.1,http://127.0.0.2", "http://127.0.0.1", true},
    63  		{"http://127.0.0.1,http://127.0.0.2", "http://127.0.0.2", true},
    64  		{"http://127.0.0.1,http://127.0.0.2", "*", false},
    65  		{"http://127.0.0.1,http://127.0.0.2", "http://127.0.0.3", false},
    66  		{"*", "*", true},
    67  		{"*", "http://127.0.0.1", true},
    68  	}
    69  	for i, tt := range tests {
    70  		info := CORSInfo{}
    71  		if err := info.Set(tt.set); err != nil {
    72  			t.Errorf("#%d: set error = %v, want nil", i, err)
    73  		}
    74  		if g := info.OriginAllowed(tt.origin); g != tt.wallowed {
    75  			t.Errorf("#%d: allowed = %v, want %v", i, g, tt.wallowed)
    76  		}
    77  	}
    78  }
    79  
    80  func TestCORSHandler(t *testing.T) {
    81  	info := &CORSInfo{}
    82  	if err := info.Set("http://127.0.0.1,http://127.0.0.2"); err != nil {
    83  		t.Fatalf("unexpected set error: %v", err)
    84  	}
    85  	h := &CORSHandler{
    86  		Handler: http.NotFoundHandler(),
    87  		Info:    info,
    88  	}
    89  
    90  	header := func(origin string) http.Header {
    91  		return http.Header{
    92  			"Access-Control-Allow-Methods": []string{"POST, GET, OPTIONS, PUT, DELETE"},
    93  			"Access-Control-Allow-Origin":  []string{origin},
    94  			"Access-Control-Allow-Headers": []string{"accept, content-type, authorization"},
    95  		}
    96  	}
    97  	tests := []struct {
    98  		method  string
    99  		origin  string
   100  		wcode   int
   101  		wheader http.Header
   102  	}{
   103  		{"GET", "http://127.0.0.1", http.StatusNotFound, header("http://127.0.0.1")},
   104  		{"GET", "http://127.0.0.2", http.StatusNotFound, header("http://127.0.0.2")},
   105  		{"GET", "http://127.0.0.3", http.StatusNotFound, http.Header{}},
   106  		{"OPTIONS", "http://127.0.0.1", http.StatusOK, header("http://127.0.0.1")},
   107  	}
   108  	for i, tt := range tests {
   109  		rr := httptest.NewRecorder()
   110  		req := &http.Request{
   111  			Method: tt.method,
   112  			Header: http.Header{"Origin": []string{tt.origin}},
   113  		}
   114  		h.ServeHTTP(rr, req)
   115  		if rr.Code != tt.wcode {
   116  			t.Errorf("#%d: code = %v, want %v", i, rr.Code, tt.wcode)
   117  		}
   118  		// it is set by http package, and there is no need to test it
   119  		rr.HeaderMap.Del("Content-Type")
   120  		rr.HeaderMap.Del("X-Content-Type-Options")
   121  		if !reflect.DeepEqual(rr.HeaderMap, tt.wheader) {
   122  			t.Errorf("#%d: header = %+v, want %+v", i, rr.HeaderMap, tt.wheader)
   123  		}
   124  	}
   125  }