github.com/Jeffail/benthos/v3@v3.65.0/internal/http/docs/cors_test.go (about)

     1  package docs
     2  
     3  import (
     4  	"net/http"
     5  	"net/http/httptest"
     6  	"testing"
     7  
     8  	"github.com/stretchr/testify/assert"
     9  	"github.com/stretchr/testify/require"
    10  )
    11  
    12  func TestAPIEnableCORS(t *testing.T) {
    13  	conf := NewServerCORS()
    14  	conf.Enabled = true
    15  	conf.AllowedOrigins = []string{"*"}
    16  
    17  	tmpHandler := http.NewServeMux()
    18  	tmpHandler.HandleFunc("/version", func(w http.ResponseWriter, r *http.Request) {
    19  		_, _ = w.Write([]byte("1.2.3"))
    20  	})
    21  
    22  	handler, err := conf.WrapHandler(tmpHandler)
    23  	require.NoError(t, err)
    24  
    25  	request, _ := http.NewRequest("OPTIONS", "/version", http.NoBody)
    26  	request.Header.Add("Origin", "meow")
    27  	request.Header.Add("Access-Control-Request-Method", "POST")
    28  
    29  	response := httptest.NewRecorder()
    30  	handler.ServeHTTP(response, request)
    31  
    32  	assert.Equal(t, http.StatusOK, response.Code)
    33  	assert.Equal(t, "*", response.Header().Get("Access-Control-Allow-Origin"))
    34  }
    35  
    36  func TestAPIEnableCORSOrigins(t *testing.T) {
    37  	conf := NewServerCORS()
    38  	conf.Enabled = true
    39  	conf.AllowedOrigins = []string{"foo", "bar"}
    40  
    41  	tmpHandler := http.NewServeMux()
    42  	tmpHandler.HandleFunc("/version", func(w http.ResponseWriter, r *http.Request) {
    43  		_, _ = w.Write([]byte("1.2.3"))
    44  	})
    45  
    46  	handler, err := conf.WrapHandler(tmpHandler)
    47  	require.NoError(t, err)
    48  
    49  	request, _ := http.NewRequest("OPTIONS", "/version", http.NoBody)
    50  	request.Header.Add("Origin", "foo")
    51  	request.Header.Add("Access-Control-Request-Method", "POST")
    52  
    53  	response := httptest.NewRecorder()
    54  	handler.ServeHTTP(response, request)
    55  
    56  	assert.Equal(t, http.StatusOK, response.Code)
    57  	assert.Equal(t, "foo", response.Header().Get("Access-Control-Allow-Origin"))
    58  
    59  	request, _ = http.NewRequest("OPTIONS", "/version", http.NoBody)
    60  	request.Header.Add("Origin", "bar")
    61  	request.Header.Add("Access-Control-Request-Method", "POST")
    62  
    63  	response = httptest.NewRecorder()
    64  	handler.ServeHTTP(response, request)
    65  
    66  	assert.Equal(t, http.StatusOK, response.Code)
    67  	assert.Equal(t, "bar", response.Header().Get("Access-Control-Allow-Origin"))
    68  
    69  	request, _ = http.NewRequest("OPTIONS", "/version", http.NoBody)
    70  	request.Header.Add("Origin", "baz")
    71  	request.Header.Add("Access-Control-Request-Method", "POST")
    72  
    73  	response = httptest.NewRecorder()
    74  	handler.ServeHTTP(response, request)
    75  
    76  	assert.Equal(t, http.StatusOK, response.Code)
    77  	assert.Equal(t, "", response.Header().Get("Access-Control-Allow-Origin"))
    78  }
    79  
    80  func TestAPIEnableCORSNoHeaders(t *testing.T) {
    81  	conf := NewServerCORS()
    82  	conf.Enabled = true
    83  
    84  	_, err := conf.WrapHandler(http.NewServeMux())
    85  	require.Error(t, err)
    86  	assert.Contains(t, err.Error(), "must specify at least one allowed origin")
    87  }