github.com/verrazzano/verrazzano@v1.7.0/authproxy/src/cors/cors_test.go (about) 1 // Copyright (c) 2023, Oracle and/or its affiliates. 2 // Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl. 3 4 package cors 5 6 import ( 7 "fmt" 8 "net/http" 9 "net/http/httptest" 10 "testing" 11 12 "github.com/stretchr/testify/assert" 13 ) 14 15 func TestAddCORSHeaders(t *testing.T) { 16 testURL := "https://some-url.example.com" 17 optionsReqAllowHeaders := "authorization, content-type" 18 optionsReqAllowMethods := "GET, HEAD, POST, PUT, DELETE, OPTIONS, PATCH" 19 ingressHostVal := "someorigin.example.com" 20 validOrigin := fmt.Sprintf("https://%s", ingressHostVal) 21 tests := []struct { 22 name string 23 reqMethod string 24 ingressHost string 25 originHeader []string // use array to test invalid case of multiple origin headers 26 expectCORSHeaders bool 27 expectOptHeaders bool 28 want int 29 wantErr bool 30 }{ 31 {"No origin header", http.MethodGet, ingressHostVal, []string{}, false, false, http.StatusOK, false}, 32 {"Multiple origin headers", http.MethodGet, ingressHostVal, []string{"origin1", "origin2"}, false, false, http.StatusBadRequest, true}, 33 {"Disallowed origin header GET request", http.MethodGet, ingressHostVal, []string{"https://notallowed"}, false, false, http.StatusOK, false}, 34 {"Disallowed origin header POST request", http.MethodPost, ingressHostVal, []string{"https://notallowed"}, false, false, http.StatusForbidden, true}, 35 {"Valid origin header GET request", http.MethodGet, ingressHostVal, []string{validOrigin}, true, false, http.StatusOK, false}, 36 {"Valid origin header OPTIONS request", http.MethodOptions, ingressHostVal, []string{validOrigin}, true, true, http.StatusOK, false}, 37 } 38 for _, tt := range tests { 39 t.Run(tt.name, func(t *testing.T) { 40 req, err := http.NewRequest(tt.reqMethod, testURL, nil) 41 assert.Nil(t, err) 42 for _, org := range tt.originHeader { 43 req.Header.Add("Origin", org) 44 } 45 rw := httptest.NewRecorder() 46 got, err := AddCORSHeaders(req, rw, tt.ingressHost) 47 if (err != nil) != tt.wantErr { 48 t.Errorf("AddCORSHeaders() error = %v, wantErr %v", err, tt.wantErr) 49 return 50 } 51 if got != tt.want { 52 t.Errorf("AddCORSHeaders() got = %v, want %v", got, tt.want) 53 } 54 55 expectedAllowCreds := "" 56 expectedAllowOrigin := "" 57 if tt.expectCORSHeaders { 58 expectedAllowCreds = "true" 59 expectedAllowOrigin = tt.originHeader[0] 60 } 61 assert.Equal(t, rw.Header().Get("Access-Control-Allow-Origin"), expectedAllowOrigin) 62 assert.Equal(t, rw.Header().Get("Access-Control-Allow-Credentials"), expectedAllowCreds) 63 64 expectedAllowHeaders := "" 65 expectedAllowMethods := "" 66 if tt.expectOptHeaders { 67 expectedAllowHeaders = optionsReqAllowHeaders 68 expectedAllowMethods = optionsReqAllowMethods 69 } 70 assert.Equal(t, rw.Header().Get("Access-Control-Allow-Headers"), expectedAllowHeaders) 71 assert.Equal(t, rw.Header().Get("Access-Control-Allow-Methods"), expectedAllowMethods) 72 }) 73 } 74 } 75 76 func TestOriginAllowed(t *testing.T) { 77 ingressHostVal := "someorigin.example.com" 78 oneAllowedOrigin := "https://allowedorigin.example.com" 79 defaultAllowedOriginFunc := func() string { return "" } 80 oneAllowedOriginFunc := func() string { return oneAllowedOrigin } 81 multiAllowedOriginsFunc := func() string { return fmt.Sprintf("https://alsoallowed.example.com,%s", oneAllowedOrigin) } 82 83 tests := []struct { 84 name string 85 ingressHost string 86 origin string 87 allowedOriginFunc func() string 88 want bool 89 }{ 90 {"origin equals ingress host", ingressHostVal, fmt.Sprintf("https://%s", ingressHostVal), defaultAllowedOriginFunc, true}, 91 {"origin has value 'null'", ingressHostVal, "null", defaultAllowedOriginFunc, false}, 92 {"origin not equal to ingress host, no allow list", ingressHostVal, "https://otherorigin.example.com", defaultAllowedOriginFunc, false}, 93 {"origin not equal to ingress host, in allow list with one entry", ingressHostVal, oneAllowedOrigin, oneAllowedOriginFunc, true}, 94 {"origin not equal to ingress host, in allow list with multiple entries", ingressHostVal, oneAllowedOrigin, multiAllowedOriginsFunc, true}, 95 {"origin not equal to ingress host, not in allow list", ingressHostVal, "someotheroriginentirely", multiAllowedOriginsFunc, false}, 96 } 97 for _, tt := range tests { 98 t.Run(tt.name, func(t *testing.T) { 99 allowedOriginsFunc = tt.allowedOriginFunc 100 defer func() { 101 allowedOriginsFunc = defaultAllowedOriginFunc 102 }() 103 104 if got := originAllowed(tt.origin, tt.ingressHost); got != tt.want { 105 t.Errorf("originAllowed() = %v, want %v", got, tt.want) 106 } 107 }) 108 } 109 }