github.com/orangenpresse/up@v0.6.0/http/cors/cors_test.go (about)

     1  package cors
     2  
     3  import (
     4  	"fmt"
     5  	"net/http"
     6  	"net/http/httptest"
     7  	"testing"
     8  
     9  	"github.com/apex/up"
    10  	"github.com/tj/assert"
    11  )
    12  
    13  var hello = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    14  	fmt.Fprint(w, "Hello World")
    15  })
    16  
    17  func TestCORS_disabled(t *testing.T) {
    18  	c, err := up.ParseConfigString(`{
    19  		"name": "app"
    20  	}`)
    21  
    22  	assert.NoError(t, err, "config")
    23  
    24  	h := New(c, hello)
    25  
    26  	t.Run("GET", func(t *testing.T) {
    27  		res := httptest.NewRecorder()
    28  		req := httptest.NewRequest("GET", "/", nil)
    29  
    30  		req.Header.Set("Origin", "https://example.com")
    31  
    32  		h.ServeHTTP(res, req)
    33  
    34  		header := make(http.Header)
    35  		header.Add("Content-Type", "text/plain; charset=utf-8")
    36  
    37  		assert.Equal(t, 200, res.Code)
    38  		assert.Equal(t, header, res.HeaderMap)
    39  		assert.Equal(t, "Hello World", res.Body.String())
    40  	})
    41  }
    42  
    43  func TestCORS_defaults(t *testing.T) {
    44  	c, err := up.ParseConfigString(`{
    45  		"name": "app",
    46  		"cors": {}
    47  	}`)
    48  
    49  	assert.NoError(t, err, "config")
    50  
    51  	h := New(c, hello)
    52  
    53  	t.Run("GET", func(t *testing.T) {
    54  		res := httptest.NewRecorder()
    55  		req := httptest.NewRequest("GET", "/", nil)
    56  
    57  		req.Header.Set("Origin", "https://example.com")
    58  
    59  		h.ServeHTTP(res, req)
    60  
    61  		header := make(http.Header)
    62  		header.Add("Content-Type", "text/plain; charset=utf-8")
    63  		header.Add("Vary", "Origin")
    64  		header.Add("Access-Control-Allow-Origin", "*")
    65  
    66  		assert.Equal(t, 200, res.Code)
    67  		assert.Equal(t, header, res.HeaderMap)
    68  		assert.Equal(t, "Hello World", res.Body.String())
    69  	})
    70  
    71  	t.Run("OPTIONS", func(t *testing.T) {
    72  		res := httptest.NewRecorder()
    73  		req := httptest.NewRequest("OPTIONS", "/", nil)
    74  
    75  		req.Header.Set("Access-Control-Request-Method", "POST")
    76  		req.Header.Set("Origin", "https://example.com")
    77  		req.Header.Set("Access-Control-Request-Headers", "Content-Type")
    78  
    79  		h.ServeHTTP(res, req)
    80  
    81  		header := make(http.Header)
    82  		header.Add("Vary", "Origin")
    83  		header.Add("Vary", "Access-Control-Request-Method")
    84  		header.Add("Vary", "Access-Control-Request-Headers")
    85  		header.Add("Access-Control-Allow-Methods", "POST")
    86  		header.Add("Access-Control-Allow-Headers", "Content-Type")
    87  		header.Add("Access-Control-Allow-Origin", "*")
    88  
    89  		assert.Equal(t, 200, res.Code)
    90  		assert.Equal(t, header, res.HeaderMap)
    91  		assert.Equal(t, "", res.Body.String())
    92  	})
    93  }
    94  
    95  func TestCORS_options(t *testing.T) {
    96  	c := up.MustParseConfigString(`{
    97  		"name": "app",
    98  		"cors": {
    99  			"allowed_origins": ["https://apex.sh"],
   100  			"allowed_methods": ["GET"],
   101  			"allow_credentials": true,
   102  			"max_age": 86400
   103  		}
   104  	}`)
   105  
   106  	h := New(c, hello)
   107  
   108  	t.Run("GET", func(t *testing.T) {
   109  		res := httptest.NewRecorder()
   110  		req := httptest.NewRequest("GET", "/", nil)
   111  
   112  		req.Header.Set("Origin", "https://example.com")
   113  
   114  		h.ServeHTTP(res, req)
   115  
   116  		header := make(http.Header)
   117  		header.Add("Content-Type", "text/plain; charset=utf-8")
   118  		header.Add("Vary", "Origin")
   119  
   120  		assert.Equal(t, 200, res.Code)
   121  		assert.Equal(t, header, res.HeaderMap)
   122  		assert.Equal(t, "Hello World", res.Body.String())
   123  	})
   124  
   125  	t.Run("OPTIONS", func(t *testing.T) {
   126  		res := httptest.NewRecorder()
   127  		req := httptest.NewRequest("OPTIONS", "/", nil)
   128  
   129  		req.Header.Set("Access-Control-Request-Method", "POST")
   130  		req.Header.Set("Origin", "https://example.com")
   131  		req.Header.Set("Access-Control-Request-Headers", "Content-Type")
   132  
   133  		h.ServeHTTP(res, req)
   134  
   135  		header := make(http.Header)
   136  		header.Add("Vary", "Origin")
   137  		header.Add("Vary", "Access-Control-Request-Method")
   138  		header.Add("Vary", "Access-Control-Request-Headers")
   139  
   140  		assert.Equal(t, 200, res.Code)
   141  		assert.Equal(t, header, res.HeaderMap)
   142  		assert.Equal(t, "", res.Body.String())
   143  	})
   144  }