github.com/Jeffail/benthos/v3@v3.65.0/internal/http/docs/cors_test.go (about) 1 package docs 2 3 import ( 4 "net/http" 5 "net/http/httptest" 6 "testing" 7 8 "github.com/stretchr/testify/assert" 9 "github.com/stretchr/testify/require" 10 ) 11 12 func TestAPIEnableCORS(t *testing.T) { 13 conf := NewServerCORS() 14 conf.Enabled = true 15 conf.AllowedOrigins = []string{"*"} 16 17 tmpHandler := http.NewServeMux() 18 tmpHandler.HandleFunc("/version", func(w http.ResponseWriter, r *http.Request) { 19 _, _ = w.Write([]byte("1.2.3")) 20 }) 21 22 handler, err := conf.WrapHandler(tmpHandler) 23 require.NoError(t, err) 24 25 request, _ := http.NewRequest("OPTIONS", "/version", http.NoBody) 26 request.Header.Add("Origin", "meow") 27 request.Header.Add("Access-Control-Request-Method", "POST") 28 29 response := httptest.NewRecorder() 30 handler.ServeHTTP(response, request) 31 32 assert.Equal(t, http.StatusOK, response.Code) 33 assert.Equal(t, "*", response.Header().Get("Access-Control-Allow-Origin")) 34 } 35 36 func TestAPIEnableCORSOrigins(t *testing.T) { 37 conf := NewServerCORS() 38 conf.Enabled = true 39 conf.AllowedOrigins = []string{"foo", "bar"} 40 41 tmpHandler := http.NewServeMux() 42 tmpHandler.HandleFunc("/version", func(w http.ResponseWriter, r *http.Request) { 43 _, _ = w.Write([]byte("1.2.3")) 44 }) 45 46 handler, err := conf.WrapHandler(tmpHandler) 47 require.NoError(t, err) 48 49 request, _ := http.NewRequest("OPTIONS", "/version", http.NoBody) 50 request.Header.Add("Origin", "foo") 51 request.Header.Add("Access-Control-Request-Method", "POST") 52 53 response := httptest.NewRecorder() 54 handler.ServeHTTP(response, request) 55 56 assert.Equal(t, http.StatusOK, response.Code) 57 assert.Equal(t, "foo", response.Header().Get("Access-Control-Allow-Origin")) 58 59 request, _ = http.NewRequest("OPTIONS", "/version", http.NoBody) 60 request.Header.Add("Origin", "bar") 61 request.Header.Add("Access-Control-Request-Method", "POST") 62 63 response = httptest.NewRecorder() 64 handler.ServeHTTP(response, request) 65 66 assert.Equal(t, http.StatusOK, response.Code) 67 assert.Equal(t, "bar", response.Header().Get("Access-Control-Allow-Origin")) 68 69 request, _ = http.NewRequest("OPTIONS", "/version", http.NoBody) 70 request.Header.Add("Origin", "baz") 71 request.Header.Add("Access-Control-Request-Method", "POST") 72 73 response = httptest.NewRecorder() 74 handler.ServeHTTP(response, request) 75 76 assert.Equal(t, http.StatusOK, response.Code) 77 assert.Equal(t, "", response.Header().Get("Access-Control-Allow-Origin")) 78 } 79 80 func TestAPIEnableCORSNoHeaders(t *testing.T) { 81 conf := NewServerCORS() 82 conf.Enabled = true 83 84 _, err := conf.WrapHandler(http.NewServeMux()) 85 require.Error(t, err) 86 assert.Contains(t, err.Error(), "must specify at least one allowed origin") 87 }