github.com/lingyao2333/mo-zero@v1.4.1/rest/internal/cors/handlers_test.go (about)

     1  package cors
     2  
     3  import (
     4  	"net/http"
     5  	"net/http/httptest"
     6  	"testing"
     7  
     8  	"github.com/stretchr/testify/assert"
     9  )
    10  
    11  func TestCorsHandlerWithOrigins(t *testing.T) {
    12  	tests := []struct {
    13  		name      string
    14  		origins   []string
    15  		reqOrigin string
    16  		expect    string
    17  	}{
    18  		{
    19  			name:   "allow all origins",
    20  			expect: allOrigins,
    21  		},
    22  		{
    23  			name:      "allow one origin",
    24  			origins:   []string{"http://local"},
    25  			reqOrigin: "http://local",
    26  			expect:    "http://local",
    27  		},
    28  		{
    29  			name:      "allow many origins",
    30  			origins:   []string{"http://local", "http://remote"},
    31  			reqOrigin: "http://local",
    32  			expect:    "http://local",
    33  		},
    34  		{
    35  			name:      "allow sub origins",
    36  			origins:   []string{"local", "remote"},
    37  			reqOrigin: "sub.local",
    38  			expect:    "sub.local",
    39  		},
    40  		{
    41  			name:      "allow all origins",
    42  			reqOrigin: "http://local",
    43  			expect:    "*",
    44  		},
    45  		{
    46  			name:      "allow many origins with all mark",
    47  			origins:   []string{"http://local", "http://remote", "*"},
    48  			reqOrigin: "http://another",
    49  			expect:    "http://another",
    50  		},
    51  		{
    52  			name:      "not allow origin",
    53  			origins:   []string{"http://local", "http://remote"},
    54  			reqOrigin: "http://another",
    55  		},
    56  	}
    57  
    58  	methods := []string{
    59  		http.MethodOptions,
    60  		http.MethodGet,
    61  		http.MethodPost,
    62  	}
    63  
    64  	for _, test := range tests {
    65  		for _, method := range methods {
    66  			test := test
    67  			t.Run(test.name+"-handler", func(t *testing.T) {
    68  				r := httptest.NewRequest(method, "http://localhost", http.NoBody)
    69  				r.Header.Set(originHeader, test.reqOrigin)
    70  				w := httptest.NewRecorder()
    71  				handler := NotAllowedHandler(nil, test.origins...)
    72  				handler.ServeHTTP(w, r)
    73  				if method == http.MethodOptions {
    74  					assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
    75  				} else {
    76  					assert.Equal(t, http.StatusNotFound, w.Result().StatusCode)
    77  				}
    78  				assert.Equal(t, test.expect, w.Header().Get(allowOrigin))
    79  			})
    80  			t.Run(test.name+"-handler-custom", func(t *testing.T) {
    81  				r := httptest.NewRequest(method, "http://localhost", http.NoBody)
    82  				r.Header.Set(originHeader, test.reqOrigin)
    83  				w := httptest.NewRecorder()
    84  				handler := NotAllowedHandler(func(w http.ResponseWriter) {
    85  					w.Header().Set("foo", "bar")
    86  				}, test.origins...)
    87  				handler.ServeHTTP(w, r)
    88  				if method == http.MethodOptions {
    89  					assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
    90  				} else {
    91  					assert.Equal(t, http.StatusNotFound, w.Result().StatusCode)
    92  				}
    93  				assert.Equal(t, test.expect, w.Header().Get(allowOrigin))
    94  				assert.Equal(t, "bar", w.Header().Get("foo"))
    95  			})
    96  		}
    97  	}
    98  
    99  	for _, test := range tests {
   100  		for _, method := range methods {
   101  			test := test
   102  			t.Run(test.name+"-middleware", func(t *testing.T) {
   103  				r := httptest.NewRequest(method, "http://localhost", http.NoBody)
   104  				r.Header.Set(originHeader, test.reqOrigin)
   105  				w := httptest.NewRecorder()
   106  				handler := Middleware(nil, test.origins...)(func(w http.ResponseWriter, r *http.Request) {
   107  					w.WriteHeader(http.StatusOK)
   108  				})
   109  				handler.ServeHTTP(w, r)
   110  				if method == http.MethodOptions {
   111  					assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
   112  				} else {
   113  					assert.Equal(t, http.StatusOK, w.Result().StatusCode)
   114  				}
   115  				assert.Equal(t, test.expect, w.Header().Get(allowOrigin))
   116  			})
   117  			t.Run(test.name+"-middleware-custom", func(t *testing.T) {
   118  				r := httptest.NewRequest(method, "http://localhost", http.NoBody)
   119  				r.Header.Set(originHeader, test.reqOrigin)
   120  				w := httptest.NewRecorder()
   121  				handler := Middleware(func(header http.Header) {
   122  					header.Set("foo", "bar")
   123  				}, test.origins...)(func(w http.ResponseWriter, r *http.Request) {
   124  					w.WriteHeader(http.StatusOK)
   125  				})
   126  				handler.ServeHTTP(w, r)
   127  				if method == http.MethodOptions {
   128  					assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
   129  				} else {
   130  					assert.Equal(t, http.StatusOK, w.Result().StatusCode)
   131  				}
   132  				assert.Equal(t, test.expect, w.Header().Get(allowOrigin))
   133  				assert.Equal(t, "bar", w.Header().Get("foo"))
   134  			})
   135  		}
   136  	}
   137  }