github.com/franciscocpg/up@v0.1.10/http/cors/cors_test.go (about)

     1  package cors
     2  
     3  import (
     4  	"fmt"
     5  	"net/http"
     6  	"net/http/httptest"
     7  	"testing"
     8  
     9  	"github.com/apex/up"
    10  	"github.com/tj/assert"
    11  )
    12  
    13  var hello = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    14  	fmt.Fprint(w, "Hello World")
    15  })
    16  
    17  func TestCORS_disabled(t *testing.T) {
    18  	c, err := up.ParseConfigString(`{}`)
    19  
    20  	assert.NoError(t, err, "config")
    21  
    22  	h := New(c, hello)
    23  
    24  	t.Run("GET", func(t *testing.T) {
    25  		res := httptest.NewRecorder()
    26  		req := httptest.NewRequest("GET", "/", nil)
    27  
    28  		req.Header.Set("Origin", "https://example.com")
    29  
    30  		h.ServeHTTP(res, req)
    31  
    32  		header := make(http.Header)
    33  		header.Add("Content-Type", "text/plain; charset=utf-8")
    34  
    35  		assert.Equal(t, 200, res.Code)
    36  		assert.Equal(t, header, res.HeaderMap)
    37  		assert.Equal(t, "Hello World", res.Body.String())
    38  	})
    39  }
    40  
    41  func TestCORS_defaults(t *testing.T) {
    42  	c, err := up.ParseConfigString(`{
    43  		"cors": {}
    44  	}`)
    45  
    46  	assert.NoError(t, err, "config")
    47  
    48  	h := New(c, hello)
    49  
    50  	t.Run("GET", func(t *testing.T) {
    51  		res := httptest.NewRecorder()
    52  		req := httptest.NewRequest("GET", "/", nil)
    53  
    54  		req.Header.Set("Origin", "https://example.com")
    55  
    56  		h.ServeHTTP(res, req)
    57  
    58  		header := make(http.Header)
    59  		header.Add("Content-Type", "text/plain; charset=utf-8")
    60  		header.Add("Vary", "Origin")
    61  		header.Add("Access-Control-Allow-Origin", "*")
    62  
    63  		assert.Equal(t, 200, res.Code)
    64  		assert.Equal(t, header, res.HeaderMap)
    65  		assert.Equal(t, "Hello World", res.Body.String())
    66  	})
    67  
    68  	t.Run("OPTIONS", func(t *testing.T) {
    69  		res := httptest.NewRecorder()
    70  		req := httptest.NewRequest("OPTIONS", "/", nil)
    71  
    72  		req.Header.Set("Access-Control-Request-Method", "POST")
    73  		req.Header.Set("Origin", "https://example.com")
    74  		req.Header.Set("Access-Control-Request-Headers", "Content-Type")
    75  
    76  		h.ServeHTTP(res, req)
    77  
    78  		header := make(http.Header)
    79  		header.Add("Vary", "Origin")
    80  		header.Add("Vary", "Access-Control-Request-Method")
    81  		header.Add("Vary", "Access-Control-Request-Headers")
    82  		header.Add("Access-Control-Allow-Methods", "POST")
    83  		header.Add("Access-Control-Allow-Headers", "Content-Type")
    84  		header.Add("Access-Control-Allow-Origin", "*")
    85  
    86  		assert.Equal(t, 200, res.Code)
    87  		assert.Equal(t, header, res.HeaderMap)
    88  		assert.Equal(t, "", res.Body.String())
    89  	})
    90  }
    91  
    92  func TestCORS_options(t *testing.T) {
    93  	c := up.MustParseConfigString(`{
    94  		"cors": {
    95  			"allowed_origins": ["https://apex.sh"],
    96  			"allowed_methods": ["GET"],
    97  			"allow_credentials": true,
    98  			"max_age": 86400
    99  		}
   100  	}`)
   101  
   102  	h := New(c, hello)
   103  
   104  	t.Run("GET", func(t *testing.T) {
   105  		res := httptest.NewRecorder()
   106  		req := httptest.NewRequest("GET", "/", nil)
   107  
   108  		req.Header.Set("Origin", "https://example.com")
   109  
   110  		h.ServeHTTP(res, req)
   111  
   112  		header := make(http.Header)
   113  		header.Add("Content-Type", "text/plain; charset=utf-8")
   114  		header.Add("Vary", "Origin")
   115  
   116  		assert.Equal(t, 200, res.Code)
   117  		assert.Equal(t, header, res.HeaderMap)
   118  		assert.Equal(t, "Hello World", res.Body.String())
   119  	})
   120  
   121  	t.Run("OPTIONS", func(t *testing.T) {
   122  		res := httptest.NewRecorder()
   123  		req := httptest.NewRequest("OPTIONS", "/", nil)
   124  
   125  		req.Header.Set("Access-Control-Request-Method", "POST")
   126  		req.Header.Set("Origin", "https://example.com")
   127  		req.Header.Set("Access-Control-Request-Headers", "Content-Type")
   128  
   129  		h.ServeHTTP(res, req)
   130  
   131  		header := make(http.Header)
   132  		header.Add("Vary", "Origin")
   133  		header.Add("Vary", "Access-Control-Request-Method")
   134  		header.Add("Vary", "Access-Control-Request-Headers")
   135  
   136  		assert.Equal(t, 200, res.Code)
   137  		assert.Equal(t, header, res.HeaderMap)
   138  		assert.Equal(t, "", res.Body.String())
   139  	})
   140  }