github.com/grafviktor/keep-my-secret@v0.9.10-0.20230908165355-19f35cce90e5/internal/api/web/middleware/cors_test.go (about) 1 package middleware 2 3 import ( 4 "net/http" 5 "net/http/httptest" 6 "testing" 7 8 "github.com/grafviktor/keep-my-secret/internal/config" 9 ) 10 11 type mockHandler struct { 12 called bool 13 } 14 15 func (m *mockHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { 16 m.called = true 17 } 18 19 func TestEnableCORS(t *testing.T) { 20 // Create a mock HTTP handler for testing 21 mock := &mockHandler{} 22 23 // Create a mock HTTP request 24 req := httptest.NewRequest("OPTIONS", "http://example.com", nil) 25 26 // Pretend that I'm a kinda browser. Unit test won't set this header for us 27 req.Header.Set("Origin", "http://example.com") 28 29 // Create a response recorder 30 res := httptest.NewRecorder() 31 32 // Create an instance of the middleware with DevMode enabled 33 middlewareInstance := middleware{config: config.AppConfig{DevMode: true}} 34 35 // Call the EnableCORS middleware with the mock handler 36 corsHandler := middlewareInstance.EnableCORS(mock) 37 38 // Process the request using the CORS middleware 39 corsHandler.ServeHTTP(res, req) 40 41 // Verify that the CORS headers are set correctly when DevMode is enabled 42 expectedHeadersDevMode := map[string]string{ 43 "Access-Control-Allow-Origin": "http://example.com", 44 "Access-Control-Allow-Credentials": "true", 45 "Access-Control-Expose-Headers": "*", 46 "Access-Control-Allow-Methods": "GET,POST,PUT,PATCH,DELETE,OPTIONS", 47 "Access-Control-Allow-Headers": "Accept, Content-Type, X-CSRF-Token, Authorization", 48 } 49 50 for header, expectedValue := range expectedHeadersDevMode { 51 actualValue := res.Header().Get(header) 52 if actualValue != expectedValue { 53 t.Errorf("Expected %s header to be '%s', but got '%s'", header, expectedValue, actualValue) 54 } 55 } 56 57 // Reset the recorder for the next test 58 res = httptest.NewRecorder() 59 60 // Create an instance of the middleware with DevMode disabled 61 middlewareInstance.config.DevMode = false 62 63 // Call the EnableCORS middleware with the mock handler 64 corsHandler = middlewareInstance.EnableCORS(mock) 65 66 // Process the request using the CORS middleware 67 corsHandler.ServeHTTP(res, req) 68 69 // Verify that the CORS headers are not set when DevMode is disabled 70 for header := range expectedHeadersDevMode { 71 actualValue := res.Header().Get(header) 72 if actualValue != "" { 73 t.Errorf("Expected %s header to be empty, but got '%s'", header, actualValue) 74 } 75 } 76 77 if !mock.called { 78 t.Error("Expected ServeHTTP be called") 79 } 80 }