github.com/franciscocpg/up@v0.1.10/http/cors/cors_test.go (about) 1 package cors 2 3 import ( 4 "fmt" 5 "net/http" 6 "net/http/httptest" 7 "testing" 8 9 "github.com/apex/up" 10 "github.com/tj/assert" 11 ) 12 13 var hello = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 14 fmt.Fprint(w, "Hello World") 15 }) 16 17 func TestCORS_disabled(t *testing.T) { 18 c, err := up.ParseConfigString(`{}`) 19 20 assert.NoError(t, err, "config") 21 22 h := New(c, hello) 23 24 t.Run("GET", func(t *testing.T) { 25 res := httptest.NewRecorder() 26 req := httptest.NewRequest("GET", "/", nil) 27 28 req.Header.Set("Origin", "https://example.com") 29 30 h.ServeHTTP(res, req) 31 32 header := make(http.Header) 33 header.Add("Content-Type", "text/plain; charset=utf-8") 34 35 assert.Equal(t, 200, res.Code) 36 assert.Equal(t, header, res.HeaderMap) 37 assert.Equal(t, "Hello World", res.Body.String()) 38 }) 39 } 40 41 func TestCORS_defaults(t *testing.T) { 42 c, err := up.ParseConfigString(`{ 43 "cors": {} 44 }`) 45 46 assert.NoError(t, err, "config") 47 48 h := New(c, hello) 49 50 t.Run("GET", func(t *testing.T) { 51 res := httptest.NewRecorder() 52 req := httptest.NewRequest("GET", "/", nil) 53 54 req.Header.Set("Origin", "https://example.com") 55 56 h.ServeHTTP(res, req) 57 58 header := make(http.Header) 59 header.Add("Content-Type", "text/plain; charset=utf-8") 60 header.Add("Vary", "Origin") 61 header.Add("Access-Control-Allow-Origin", "*") 62 63 assert.Equal(t, 200, res.Code) 64 assert.Equal(t, header, res.HeaderMap) 65 assert.Equal(t, "Hello World", res.Body.String()) 66 }) 67 68 t.Run("OPTIONS", func(t *testing.T) { 69 res := httptest.NewRecorder() 70 req := httptest.NewRequest("OPTIONS", "/", nil) 71 72 req.Header.Set("Access-Control-Request-Method", "POST") 73 req.Header.Set("Origin", "https://example.com") 74 req.Header.Set("Access-Control-Request-Headers", "Content-Type") 75 76 h.ServeHTTP(res, req) 77 78 header := make(http.Header) 79 header.Add("Vary", "Origin") 80 header.Add("Vary", "Access-Control-Request-Method") 81 header.Add("Vary", "Access-Control-Request-Headers") 82 header.Add("Access-Control-Allow-Methods", "POST") 83 header.Add("Access-Control-Allow-Headers", "Content-Type") 84 header.Add("Access-Control-Allow-Origin", "*") 85 86 assert.Equal(t, 200, res.Code) 87 assert.Equal(t, header, res.HeaderMap) 88 assert.Equal(t, "", res.Body.String()) 89 }) 90 } 91 92 func TestCORS_options(t *testing.T) { 93 c := up.MustParseConfigString(`{ 94 "cors": { 95 "allowed_origins": ["https://apex.sh"], 96 "allowed_methods": ["GET"], 97 "allow_credentials": true, 98 "max_age": 86400 99 } 100 }`) 101 102 h := New(c, hello) 103 104 t.Run("GET", func(t *testing.T) { 105 res := httptest.NewRecorder() 106 req := httptest.NewRequest("GET", "/", nil) 107 108 req.Header.Set("Origin", "https://example.com") 109 110 h.ServeHTTP(res, req) 111 112 header := make(http.Header) 113 header.Add("Content-Type", "text/plain; charset=utf-8") 114 header.Add("Vary", "Origin") 115 116 assert.Equal(t, 200, res.Code) 117 assert.Equal(t, header, res.HeaderMap) 118 assert.Equal(t, "Hello World", res.Body.String()) 119 }) 120 121 t.Run("OPTIONS", func(t *testing.T) { 122 res := httptest.NewRecorder() 123 req := httptest.NewRequest("OPTIONS", "/", nil) 124 125 req.Header.Set("Access-Control-Request-Method", "POST") 126 req.Header.Set("Origin", "https://example.com") 127 req.Header.Set("Access-Control-Request-Headers", "Content-Type") 128 129 h.ServeHTTP(res, req) 130 131 header := make(http.Header) 132 header.Add("Vary", "Origin") 133 header.Add("Vary", "Access-Control-Request-Method") 134 header.Add("Vary", "Access-Control-Request-Headers") 135 136 assert.Equal(t, 200, res.Code) 137 assert.Equal(t, header, res.HeaderMap) 138 assert.Equal(t, "", res.Body.String()) 139 }) 140 }