github.com/grafviktor/keep-my-secret@v0.9.10-0.20230908165355-19f35cce90e5/internal/api/web/middleware/cors_test.go (about)

     1  package middleware
     2  
     3  import (
     4  	"net/http"
     5  	"net/http/httptest"
     6  	"testing"
     7  
     8  	"github.com/grafviktor/keep-my-secret/internal/config"
     9  )
    10  
    11  type mockHandler struct {
    12  	called bool
    13  }
    14  
    15  func (m *mockHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
    16  	m.called = true
    17  }
    18  
    19  func TestEnableCORS(t *testing.T) {
    20  	// Create a mock HTTP handler for testing
    21  	mock := &mockHandler{}
    22  
    23  	// Create a mock HTTP request
    24  	req := httptest.NewRequest("OPTIONS", "http://example.com", nil)
    25  
    26  	// Pretend that I'm a kinda browser. Unit test won't set this header for us
    27  	req.Header.Set("Origin", "http://example.com")
    28  
    29  	// Create a response recorder
    30  	res := httptest.NewRecorder()
    31  
    32  	// Create an instance of the middleware with DevMode enabled
    33  	middlewareInstance := middleware{config: config.AppConfig{DevMode: true}}
    34  
    35  	// Call the EnableCORS middleware with the mock handler
    36  	corsHandler := middlewareInstance.EnableCORS(mock)
    37  
    38  	// Process the request using the CORS middleware
    39  	corsHandler.ServeHTTP(res, req)
    40  
    41  	// Verify that the CORS headers are set correctly when DevMode is enabled
    42  	expectedHeadersDevMode := map[string]string{
    43  		"Access-Control-Allow-Origin":      "http://example.com",
    44  		"Access-Control-Allow-Credentials": "true",
    45  		"Access-Control-Expose-Headers":    "*",
    46  		"Access-Control-Allow-Methods":     "GET,POST,PUT,PATCH,DELETE,OPTIONS",
    47  		"Access-Control-Allow-Headers":     "Accept, Content-Type, X-CSRF-Token, Authorization",
    48  	}
    49  
    50  	for header, expectedValue := range expectedHeadersDevMode {
    51  		actualValue := res.Header().Get(header)
    52  		if actualValue != expectedValue {
    53  			t.Errorf("Expected %s header to be '%s', but got '%s'", header, expectedValue, actualValue)
    54  		}
    55  	}
    56  
    57  	// Reset the recorder for the next test
    58  	res = httptest.NewRecorder()
    59  
    60  	// Create an instance of the middleware with DevMode disabled
    61  	middlewareInstance.config.DevMode = false
    62  
    63  	// Call the EnableCORS middleware with the mock handler
    64  	corsHandler = middlewareInstance.EnableCORS(mock)
    65  
    66  	// Process the request using the CORS middleware
    67  	corsHandler.ServeHTTP(res, req)
    68  
    69  	// Verify that the CORS headers are not set when DevMode is disabled
    70  	for header := range expectedHeadersDevMode {
    71  		actualValue := res.Header().Get(header)
    72  		if actualValue != "" {
    73  			t.Errorf("Expected %s header to be empty, but got '%s'", header, actualValue)
    74  		}
    75  	}
    76  
    77  	if !mock.called {
    78  		t.Error("Expected ServeHTTP be called")
    79  	}
    80  }