github.com/adharshmk96/stk@v1.2.3/pkg/middleware/cors_test.go (about) 1 package middleware_test 2 3 import ( 4 "net/http" 5 "testing" 6 7 "github.com/adharshmk96/stk/gsk" 8 "github.com/adharshmk96/stk/pkg/middleware" 9 "github.com/stretchr/testify/assert" 10 ) 11 12 func TestCORSDefault(t *testing.T) { 13 // Create a new server instance 14 config := &gsk.ServerConfig{ 15 Port: "8888", 16 } 17 s := gsk.New(config) 18 19 s.Use(middleware.CORS()) 20 21 // Register a test route and handler 22 s.Get("/", func(c *gsk.Context) { 23 c.Status(http.StatusOK).JSONResponse("OK") 24 }) 25 26 t.Run("Non-preflight request", func(t *testing.T) { 27 // Run the test request 28 testParams := gsk.TestParams{ 29 Headers: map[string]string{ 30 "Origin": "example.com", 31 }, 32 } 33 rr, _ := s.Test("GET", "/", nil, testParams) 34 35 expectedHeaders := map[string]string{ 36 "Access-Control-Allow-Origin": "example.com", 37 "Access-Control-Allow-Methods": "POST, GET, OPTIONS, PUT, DELETE, PATCH", 38 "Access-Control-Allow-Headers": "Accept, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization", 39 } 40 41 // expect http.StatusOK 42 if rr.Code != http.StatusOK { 43 t.Errorf("Expected response code %d, but got %d", http.StatusOK, rr.Code) 44 } 45 46 for header, expectedValue := range expectedHeaders { 47 if value := rr.Header().Get(header); value != expectedValue { 48 t.Errorf("Expected %s header to be %q, but got %q", header, expectedValue, value) 49 } 50 } 51 }) 52 53 } 54 55 func TestCORSAllowedOrigin(t *testing.T) { 56 // Create a new server instance 57 config := &gsk.ServerConfig{ 58 Port: "8888", 59 } 60 61 AllowedOrigins := []string{ 62 "example.com", 63 } 64 s := gsk.New(config) 65 66 s.Use(middleware.CORS(middleware.CORSConfig{ 67 AllowedOrigins: AllowedOrigins, 68 })) 69 70 // Register a test route and handler 71 s.Get("/", func(c *gsk.Context) { 72 c.Status(http.StatusOK).JSONResponse("OK") 73 }) 74 75 t.Run("Non-preflight request from example.com", func(t *testing.T) { 76 77 // Run the test request 78 testParams := gsk.TestParams{ 79 Headers: map[string]string{ 80 "Origin": "example.com", 81 }, 82 } 83 rr, _ := s.Test("GET", "/", nil, testParams) 84 85 expectedHeaders := map[string]string{ 86 "Access-Control-Allow-Origin": "example.com", 87 "Access-Control-Allow-Methods": "POST, GET, OPTIONS, PUT, DELETE, PATCH", 88 "Access-Control-Allow-Headers": "Accept, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization", 89 } 90 91 assert.Equal(t, http.StatusOK, rr.Code) 92 93 for header, expectedValue := range expectedHeaders { 94 value := rr.Header().Get(header) 95 assert.Equal(t, expectedValue, value) 96 } 97 }) 98 99 t.Run("Non-preflight request with invalid origin", func(t *testing.T) { 100 101 // Run the test request 102 testParams := gsk.TestParams{ 103 Headers: map[string]string{ 104 "Origin": "invalid.com", 105 }, 106 } 107 rr, _ := s.Test("GET", "/", nil, testParams) 108 109 expectedHeaders := map[string]string{ 110 "Access-Control-Allow-Origin": "", 111 "Access-Control-Allow-Methods": "", 112 "Access-Control-Allow-Headers": "", 113 } 114 115 assert.Equal(t, http.StatusForbidden, rr.Code) 116 117 for header, expectedValue := range expectedHeaders { 118 value := rr.Header().Get(header) 119 assert.Equal(t, expectedValue, value) 120 } 121 }) 122 123 t.Run("Preflight request with example.com", func(t *testing.T) { 124 125 // Run the test request 126 testParams := gsk.TestParams{ 127 Headers: map[string]string{ 128 "Origin": "example.com", 129 "Access-Control-Request-Method": "POST", 130 }, 131 } 132 rr, _ := s.Test("OPTIONS", "/", nil, testParams) 133 134 // NOTE: thie is behaviour from the router package 135 // change this if we are chaning the router 136 expectedHeaders := map[string]string{ 137 "Access-Control-Allow-Origin": "example.com", 138 "Access-Control-Allow-Methods": "POST, GET, OPTIONS, PUT, DELETE, PATCH", 139 "Access-Control-Allow-Headers": "Accept, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization", 140 } 141 142 assert.Equal(t, http.StatusNoContent, rr.Code) 143 144 for header, expectedValue := range expectedHeaders { 145 value := rr.Header().Get(header) 146 assert.Equal(t, expectedValue, value) 147 } 148 }) 149 150 t.Run("Preflight request with invalid origin", func(t *testing.T) { 151 152 // Run the test request 153 testParams := gsk.TestParams{ 154 Headers: map[string]string{ 155 "Origin": "invalid.com", 156 "Access-Control-Request-Method": "POST", 157 }, 158 } 159 rr, _ := s.Test("OPTIONS", "/", nil, testParams) 160 161 expectedHeaders := map[string]string{ 162 "Access-Control-Allow-Origin": "", 163 "Access-Control-Allow-Methods": "", 164 "Access-Control-Allow-Headers": "", 165 } 166 167 // TODO - this should be checked later on 168 assert.Equal(t, http.StatusForbidden, rr.Code) 169 170 for header, expectedValue := range expectedHeaders { 171 value := rr.Header().Get(header) 172 assert.Equal(t, expectedValue, value) 173 } 174 }) 175 176 }