github.com/adharshmk96/stk@v1.2.3/pkg/middleware/cors_test.go (about)

     1  package middleware_test
     2  
     3  import (
     4  	"net/http"
     5  	"testing"
     6  
     7  	"github.com/adharshmk96/stk/gsk"
     8  	"github.com/adharshmk96/stk/pkg/middleware"
     9  	"github.com/stretchr/testify/assert"
    10  )
    11  
    12  func TestCORSDefault(t *testing.T) {
    13  	// Create a new server instance
    14  	config := &gsk.ServerConfig{
    15  		Port: "8888",
    16  	}
    17  	s := gsk.New(config)
    18  
    19  	s.Use(middleware.CORS())
    20  
    21  	// Register a test route and handler
    22  	s.Get("/", func(c *gsk.Context) {
    23  		c.Status(http.StatusOK).JSONResponse("OK")
    24  	})
    25  
    26  	t.Run("Non-preflight request", func(t *testing.T) {
    27  		// Run the test request
    28  		testParams := gsk.TestParams{
    29  			Headers: map[string]string{
    30  				"Origin": "example.com",
    31  			},
    32  		}
    33  		rr, _ := s.Test("GET", "/", nil, testParams)
    34  
    35  		expectedHeaders := map[string]string{
    36  			"Access-Control-Allow-Origin":  "example.com",
    37  			"Access-Control-Allow-Methods": "POST, GET, OPTIONS, PUT, DELETE, PATCH",
    38  			"Access-Control-Allow-Headers": "Accept, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization",
    39  		}
    40  
    41  		// expect http.StatusOK
    42  		if rr.Code != http.StatusOK {
    43  			t.Errorf("Expected response code %d, but got %d", http.StatusOK, rr.Code)
    44  		}
    45  
    46  		for header, expectedValue := range expectedHeaders {
    47  			if value := rr.Header().Get(header); value != expectedValue {
    48  				t.Errorf("Expected %s header to be %q, but got %q", header, expectedValue, value)
    49  			}
    50  		}
    51  	})
    52  
    53  }
    54  
    55  func TestCORSAllowedOrigin(t *testing.T) {
    56  	// Create a new server instance
    57  	config := &gsk.ServerConfig{
    58  		Port: "8888",
    59  	}
    60  
    61  	AllowedOrigins := []string{
    62  		"example.com",
    63  	}
    64  	s := gsk.New(config)
    65  
    66  	s.Use(middleware.CORS(middleware.CORSConfig{
    67  		AllowedOrigins: AllowedOrigins,
    68  	}))
    69  
    70  	// Register a test route and handler
    71  	s.Get("/", func(c *gsk.Context) {
    72  		c.Status(http.StatusOK).JSONResponse("OK")
    73  	})
    74  
    75  	t.Run("Non-preflight request from example.com", func(t *testing.T) {
    76  
    77  		// Run the test request
    78  		testParams := gsk.TestParams{
    79  			Headers: map[string]string{
    80  				"Origin": "example.com",
    81  			},
    82  		}
    83  		rr, _ := s.Test("GET", "/", nil, testParams)
    84  
    85  		expectedHeaders := map[string]string{
    86  			"Access-Control-Allow-Origin":  "example.com",
    87  			"Access-Control-Allow-Methods": "POST, GET, OPTIONS, PUT, DELETE, PATCH",
    88  			"Access-Control-Allow-Headers": "Accept, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization",
    89  		}
    90  
    91  		assert.Equal(t, http.StatusOK, rr.Code)
    92  
    93  		for header, expectedValue := range expectedHeaders {
    94  			value := rr.Header().Get(header)
    95  			assert.Equal(t, expectedValue, value)
    96  		}
    97  	})
    98  
    99  	t.Run("Non-preflight request with invalid origin", func(t *testing.T) {
   100  
   101  		// Run the test request
   102  		testParams := gsk.TestParams{
   103  			Headers: map[string]string{
   104  				"Origin": "invalid.com",
   105  			},
   106  		}
   107  		rr, _ := s.Test("GET", "/", nil, testParams)
   108  
   109  		expectedHeaders := map[string]string{
   110  			"Access-Control-Allow-Origin":  "",
   111  			"Access-Control-Allow-Methods": "",
   112  			"Access-Control-Allow-Headers": "",
   113  		}
   114  
   115  		assert.Equal(t, http.StatusForbidden, rr.Code)
   116  
   117  		for header, expectedValue := range expectedHeaders {
   118  			value := rr.Header().Get(header)
   119  			assert.Equal(t, expectedValue, value)
   120  		}
   121  	})
   122  
   123  	t.Run("Preflight request with example.com", func(t *testing.T) {
   124  
   125  		// Run the test request
   126  		testParams := gsk.TestParams{
   127  			Headers: map[string]string{
   128  				"Origin":                        "example.com",
   129  				"Access-Control-Request-Method": "POST",
   130  			},
   131  		}
   132  		rr, _ := s.Test("OPTIONS", "/", nil, testParams)
   133  
   134  		// NOTE: thie is behaviour from the router package
   135  		// change this if we are chaning the router
   136  		expectedHeaders := map[string]string{
   137  			"Access-Control-Allow-Origin":  "example.com",
   138  			"Access-Control-Allow-Methods": "POST, GET, OPTIONS, PUT, DELETE, PATCH",
   139  			"Access-Control-Allow-Headers": "Accept, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization",
   140  		}
   141  
   142  		assert.Equal(t, http.StatusNoContent, rr.Code)
   143  
   144  		for header, expectedValue := range expectedHeaders {
   145  			value := rr.Header().Get(header)
   146  			assert.Equal(t, expectedValue, value)
   147  		}
   148  	})
   149  
   150  	t.Run("Preflight request with invalid origin", func(t *testing.T) {
   151  
   152  		// Run the test request
   153  		testParams := gsk.TestParams{
   154  			Headers: map[string]string{
   155  				"Origin":                        "invalid.com",
   156  				"Access-Control-Request-Method": "POST",
   157  			},
   158  		}
   159  		rr, _ := s.Test("OPTIONS", "/", nil, testParams)
   160  
   161  		expectedHeaders := map[string]string{
   162  			"Access-Control-Allow-Origin":  "",
   163  			"Access-Control-Allow-Methods": "",
   164  			"Access-Control-Allow-Headers": "",
   165  		}
   166  
   167  		// TODO - this should be checked later on
   168  		assert.Equal(t, http.StatusForbidden, rr.Code)
   169  
   170  		for header, expectedValue := range expectedHeaders {
   171  			value := rr.Header().Get(header)
   172  			assert.Equal(t, expectedValue, value)
   173  		}
   174  	})
   175  
   176  }