github.com/orangenpresse/up@v0.6.0/http/cors/cors_test.go (about) 1 package cors 2 3 import ( 4 "fmt" 5 "net/http" 6 "net/http/httptest" 7 "testing" 8 9 "github.com/apex/up" 10 "github.com/tj/assert" 11 ) 12 13 var hello = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 14 fmt.Fprint(w, "Hello World") 15 }) 16 17 func TestCORS_disabled(t *testing.T) { 18 c, err := up.ParseConfigString(`{ 19 "name": "app" 20 }`) 21 22 assert.NoError(t, err, "config") 23 24 h := New(c, hello) 25 26 t.Run("GET", func(t *testing.T) { 27 res := httptest.NewRecorder() 28 req := httptest.NewRequest("GET", "/", nil) 29 30 req.Header.Set("Origin", "https://example.com") 31 32 h.ServeHTTP(res, req) 33 34 header := make(http.Header) 35 header.Add("Content-Type", "text/plain; charset=utf-8") 36 37 assert.Equal(t, 200, res.Code) 38 assert.Equal(t, header, res.HeaderMap) 39 assert.Equal(t, "Hello World", res.Body.String()) 40 }) 41 } 42 43 func TestCORS_defaults(t *testing.T) { 44 c, err := up.ParseConfigString(`{ 45 "name": "app", 46 "cors": {} 47 }`) 48 49 assert.NoError(t, err, "config") 50 51 h := New(c, hello) 52 53 t.Run("GET", func(t *testing.T) { 54 res := httptest.NewRecorder() 55 req := httptest.NewRequest("GET", "/", nil) 56 57 req.Header.Set("Origin", "https://example.com") 58 59 h.ServeHTTP(res, req) 60 61 header := make(http.Header) 62 header.Add("Content-Type", "text/plain; charset=utf-8") 63 header.Add("Vary", "Origin") 64 header.Add("Access-Control-Allow-Origin", "*") 65 66 assert.Equal(t, 200, res.Code) 67 assert.Equal(t, header, res.HeaderMap) 68 assert.Equal(t, "Hello World", res.Body.String()) 69 }) 70 71 t.Run("OPTIONS", func(t *testing.T) { 72 res := httptest.NewRecorder() 73 req := httptest.NewRequest("OPTIONS", "/", nil) 74 75 req.Header.Set("Access-Control-Request-Method", "POST") 76 req.Header.Set("Origin", "https://example.com") 77 req.Header.Set("Access-Control-Request-Headers", "Content-Type") 78 79 h.ServeHTTP(res, req) 80 81 header := make(http.Header) 82 header.Add("Vary", "Origin") 83 header.Add("Vary", "Access-Control-Request-Method") 84 header.Add("Vary", "Access-Control-Request-Headers") 85 header.Add("Access-Control-Allow-Methods", "POST") 86 header.Add("Access-Control-Allow-Headers", "Content-Type") 87 header.Add("Access-Control-Allow-Origin", "*") 88 89 assert.Equal(t, 200, res.Code) 90 assert.Equal(t, header, res.HeaderMap) 91 assert.Equal(t, "", res.Body.String()) 92 }) 93 } 94 95 func TestCORS_options(t *testing.T) { 96 c := up.MustParseConfigString(`{ 97 "name": "app", 98 "cors": { 99 "allowed_origins": ["https://apex.sh"], 100 "allowed_methods": ["GET"], 101 "allow_credentials": true, 102 "max_age": 86400 103 } 104 }`) 105 106 h := New(c, hello) 107 108 t.Run("GET", func(t *testing.T) { 109 res := httptest.NewRecorder() 110 req := httptest.NewRequest("GET", "/", nil) 111 112 req.Header.Set("Origin", "https://example.com") 113 114 h.ServeHTTP(res, req) 115 116 header := make(http.Header) 117 header.Add("Content-Type", "text/plain; charset=utf-8") 118 header.Add("Vary", "Origin") 119 120 assert.Equal(t, 200, res.Code) 121 assert.Equal(t, header, res.HeaderMap) 122 assert.Equal(t, "Hello World", res.Body.String()) 123 }) 124 125 t.Run("OPTIONS", func(t *testing.T) { 126 res := httptest.NewRecorder() 127 req := httptest.NewRequest("OPTIONS", "/", nil) 128 129 req.Header.Set("Access-Control-Request-Method", "POST") 130 req.Header.Set("Origin", "https://example.com") 131 req.Header.Set("Access-Control-Request-Headers", "Content-Type") 132 133 h.ServeHTTP(res, req) 134 135 header := make(http.Header) 136 header.Add("Vary", "Origin") 137 header.Add("Vary", "Access-Control-Request-Method") 138 header.Add("Vary", "Access-Control-Request-Headers") 139 140 assert.Equal(t, 200, res.Code) 141 assert.Equal(t, header, res.HeaderMap) 142 assert.Equal(t, "", res.Body.String()) 143 }) 144 }