github.com/verrazzano/verrazzano@v1.7.0/authproxy/src/cors/cors_test.go (about)

     1  // Copyright (c) 2023, Oracle and/or its affiliates.
     2  // Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl.
     3  
     4  package cors
     5  
     6  import (
     7  	"fmt"
     8  	"net/http"
     9  	"net/http/httptest"
    10  	"testing"
    11  
    12  	"github.com/stretchr/testify/assert"
    13  )
    14  
    15  func TestAddCORSHeaders(t *testing.T) {
    16  	testURL := "https://some-url.example.com"
    17  	optionsReqAllowHeaders := "authorization, content-type"
    18  	optionsReqAllowMethods := "GET, HEAD, POST, PUT, DELETE, OPTIONS, PATCH"
    19  	ingressHostVal := "someorigin.example.com"
    20  	validOrigin := fmt.Sprintf("https://%s", ingressHostVal)
    21  	tests := []struct {
    22  		name              string
    23  		reqMethod         string
    24  		ingressHost       string
    25  		originHeader      []string // use array to test invalid case of multiple origin headers
    26  		expectCORSHeaders bool
    27  		expectOptHeaders  bool
    28  		want              int
    29  		wantErr           bool
    30  	}{
    31  		{"No origin header", http.MethodGet, ingressHostVal, []string{}, false, false, http.StatusOK, false},
    32  		{"Multiple origin headers", http.MethodGet, ingressHostVal, []string{"origin1", "origin2"}, false, false, http.StatusBadRequest, true},
    33  		{"Disallowed origin header GET request", http.MethodGet, ingressHostVal, []string{"https://notallowed"}, false, false, http.StatusOK, false},
    34  		{"Disallowed origin header POST request", http.MethodPost, ingressHostVal, []string{"https://notallowed"}, false, false, http.StatusForbidden, true},
    35  		{"Valid origin header GET request", http.MethodGet, ingressHostVal, []string{validOrigin}, true, false, http.StatusOK, false},
    36  		{"Valid origin header OPTIONS request", http.MethodOptions, ingressHostVal, []string{validOrigin}, true, true, http.StatusOK, false},
    37  	}
    38  	for _, tt := range tests {
    39  		t.Run(tt.name, func(t *testing.T) {
    40  			req, err := http.NewRequest(tt.reqMethod, testURL, nil)
    41  			assert.Nil(t, err)
    42  			for _, org := range tt.originHeader {
    43  				req.Header.Add("Origin", org)
    44  			}
    45  			rw := httptest.NewRecorder()
    46  			got, err := AddCORSHeaders(req, rw, tt.ingressHost)
    47  			if (err != nil) != tt.wantErr {
    48  				t.Errorf("AddCORSHeaders() error = %v, wantErr %v", err, tt.wantErr)
    49  				return
    50  			}
    51  			if got != tt.want {
    52  				t.Errorf("AddCORSHeaders() got = %v, want %v", got, tt.want)
    53  			}
    54  
    55  			expectedAllowCreds := ""
    56  			expectedAllowOrigin := ""
    57  			if tt.expectCORSHeaders {
    58  				expectedAllowCreds = "true"
    59  				expectedAllowOrigin = tt.originHeader[0]
    60  			}
    61  			assert.Equal(t, rw.Header().Get("Access-Control-Allow-Origin"), expectedAllowOrigin)
    62  			assert.Equal(t, rw.Header().Get("Access-Control-Allow-Credentials"), expectedAllowCreds)
    63  
    64  			expectedAllowHeaders := ""
    65  			expectedAllowMethods := ""
    66  			if tt.expectOptHeaders {
    67  				expectedAllowHeaders = optionsReqAllowHeaders
    68  				expectedAllowMethods = optionsReqAllowMethods
    69  			}
    70  			assert.Equal(t, rw.Header().Get("Access-Control-Allow-Headers"), expectedAllowHeaders)
    71  			assert.Equal(t, rw.Header().Get("Access-Control-Allow-Methods"), expectedAllowMethods)
    72  		})
    73  	}
    74  }
    75  
    76  func TestOriginAllowed(t *testing.T) {
    77  	ingressHostVal := "someorigin.example.com"
    78  	oneAllowedOrigin := "https://allowedorigin.example.com"
    79  	defaultAllowedOriginFunc := func() string { return "" }
    80  	oneAllowedOriginFunc := func() string { return oneAllowedOrigin }
    81  	multiAllowedOriginsFunc := func() string { return fmt.Sprintf("https://alsoallowed.example.com,%s", oneAllowedOrigin) }
    82  
    83  	tests := []struct {
    84  		name              string
    85  		ingressHost       string
    86  		origin            string
    87  		allowedOriginFunc func() string
    88  		want              bool
    89  	}{
    90  		{"origin equals ingress host", ingressHostVal, fmt.Sprintf("https://%s", ingressHostVal), defaultAllowedOriginFunc, true},
    91  		{"origin has value 'null'", ingressHostVal, "null", defaultAllowedOriginFunc, false},
    92  		{"origin not equal to ingress host, no allow list", ingressHostVal, "https://otherorigin.example.com", defaultAllowedOriginFunc, false},
    93  		{"origin not equal to ingress host, in allow list with one entry", ingressHostVal, oneAllowedOrigin, oneAllowedOriginFunc, true},
    94  		{"origin not equal to ingress host, in allow list with multiple entries", ingressHostVal, oneAllowedOrigin, multiAllowedOriginsFunc, true},
    95  		{"origin not equal to ingress host, not in allow list", ingressHostVal, "someotheroriginentirely", multiAllowedOriginsFunc, false},
    96  	}
    97  	for _, tt := range tests {
    98  		t.Run(tt.name, func(t *testing.T) {
    99  			allowedOriginsFunc = tt.allowedOriginFunc
   100  			defer func() {
   101  				allowedOriginsFunc = defaultAllowedOriginFunc
   102  			}()
   103  
   104  			if got := originAllowed(tt.origin, tt.ingressHost); got != tt.want {
   105  				t.Errorf("originAllowed() = %v, want %v", got, tt.want)
   106  			}
   107  		})
   108  	}
   109  }