gitee.com/ks-custle/core-gm@v0.0.0-20230922171213-b83bdd97b62c/handlers/cors_test.go (about)

     1  package handlers
     2  
     3  import (
     4  	http "gitee.com/ks-custle/core-gm/gmhttp"
     5  	"gitee.com/ks-custle/core-gm/gmhttp/httptest"
     6  	"strings"
     7  	"testing"
     8  )
     9  
    10  func TestDefaultCORSHandlerReturnsOk(t *testing.T) {
    11  	r := newRequest("GET", "http://www.example.com/")
    12  	rr := httptest.NewRecorder()
    13  
    14  	testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
    15  
    16  	CORS()(testHandler).ServeHTTP(rr, r)
    17  
    18  	if got, want := rr.Code, http.StatusOK; got != want {
    19  		t.Fatalf("bad status: got %v want %v", got, want)
    20  	}
    21  }
    22  
    23  func TestDefaultCORSHandlerReturnsOkWithOrigin(t *testing.T) {
    24  	r := newRequest("GET", "http://www.example.com/")
    25  	r.Header.Set("Origin", r.URL.String())
    26  
    27  	rr := httptest.NewRecorder()
    28  
    29  	testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
    30  
    31  	CORS()(testHandler).ServeHTTP(rr, r)
    32  
    33  	if got, want := rr.Code, http.StatusOK; got != want {
    34  		t.Fatalf("bad status: got %v want %v", got, want)
    35  	}
    36  }
    37  
    38  func TestCORSHandlerIgnoreOptionsFallsThrough(t *testing.T) {
    39  	r := newRequest("OPTIONS", "http://www.example.com/")
    40  	r.Header.Set("Origin", r.URL.String())
    41  
    42  	rr := httptest.NewRecorder()
    43  
    44  	testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    45  		w.WriteHeader(http.StatusTeapot)
    46  	})
    47  
    48  	CORS(IgnoreOptions())(testHandler).ServeHTTP(rr, r)
    49  
    50  	if got, want := rr.Code, http.StatusTeapot; got != want {
    51  		t.Fatalf("bad status: got %v want %v", got, want)
    52  	}
    53  }
    54  
    55  func TestCORSHandlerSetsExposedHeaders(t *testing.T) {
    56  	// Test default configuration.
    57  	r := newRequest("GET", "http://www.example.com/")
    58  	r.Header.Set("Origin", r.URL.String())
    59  
    60  	rr := httptest.NewRecorder()
    61  
    62  	testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
    63  
    64  	CORS(ExposedHeaders([]string{"X-CORS-TEST"}))(testHandler).ServeHTTP(rr, r)
    65  
    66  	if got, want := rr.Code, http.StatusOK; got != want {
    67  		t.Fatalf("bad status: got %v want %v", got, want)
    68  	}
    69  
    70  	header := rr.HeaderMap.Get(corsExposeHeadersHeader)
    71  	if got, want := header, "X-Cors-Test"; got != want {
    72  		t.Fatalf("bad header: expected %q header, got empty header for method.", want)
    73  	}
    74  }
    75  
    76  func TestCORSHandlerUnsetRequestMethodForPreflightBadRequest(t *testing.T) {
    77  	r := newRequest("OPTIONS", "http://www.example.com/")
    78  	r.Header.Set("Origin", r.URL.String())
    79  
    80  	rr := httptest.NewRecorder()
    81  
    82  	testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
    83  
    84  	CORS(AllowedMethods([]string{"DELETE"}))(testHandler).ServeHTTP(rr, r)
    85  
    86  	if got, want := rr.Code, http.StatusBadRequest; got != want {
    87  		t.Fatalf("bad status: got %v want %v", got, want)
    88  	}
    89  }
    90  
    91  func TestCORSHandlerInvalidRequestMethodForPreflightMethodNotAllowed(t *testing.T) {
    92  	r := newRequest("OPTIONS", "http://www.example.com/")
    93  	r.Header.Set("Origin", r.URL.String())
    94  	r.Header.Set(corsRequestMethodHeader, "DELETE")
    95  
    96  	rr := httptest.NewRecorder()
    97  
    98  	testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
    99  
   100  	CORS()(testHandler).ServeHTTP(rr, r)
   101  
   102  	if got, want := rr.Code, http.StatusMethodNotAllowed; got != want {
   103  		t.Fatalf("bad status: got %v want %v", got, want)
   104  	}
   105  }
   106  
   107  func TestCORSHandlerOptionsRequestMustNotBePassedToNextHandler(t *testing.T) {
   108  	r := newRequest("OPTIONS", "http://www.example.com/")
   109  	r.Header.Set("Origin", r.URL.String())
   110  	r.Header.Set(corsRequestMethodHeader, "GET")
   111  
   112  	rr := httptest.NewRecorder()
   113  
   114  	testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   115  		t.Fatal("Options request must not be passed to next handler")
   116  	})
   117  
   118  	CORS()(testHandler).ServeHTTP(rr, r)
   119  
   120  	if got, want := rr.Code, http.StatusOK; got != want {
   121  		t.Fatalf("bad status: got %v want %v", got, want)
   122  	}
   123  }
   124  
   125  func TestCORSHandlerOptionsRequestMustNotBePassedToNextHandlerWithCustomStatusCode(t *testing.T) {
   126  	statusCode := http.StatusNoContent
   127  	r := newRequest("OPTIONS", "http://www.example.com/")
   128  	r.Header.Set("Origin", r.URL.String())
   129  	r.Header.Set(corsRequestMethodHeader, "GET")
   130  
   131  	rr := httptest.NewRecorder()
   132  
   133  	testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   134  		t.Fatal("Options request must not be passed to next handler")
   135  	})
   136  
   137  	CORS(OptionStatusCode(statusCode))(testHandler).ServeHTTP(rr, r)
   138  
   139  	if got, want := rr.Code, statusCode; got != want {
   140  		t.Fatalf("bad status: got %v want %v", got, want)
   141  	}
   142  }
   143  
   144  func TestCORSHandlerOptionsRequestMustNotBePassedToNextHandlerWhenOriginNotAllowed(t *testing.T) {
   145  	r := newRequest("OPTIONS", "http://www.example.com/")
   146  	r.Header.Set("Origin", r.URL.String())
   147  	r.Header.Set(corsRequestMethodHeader, "GET")
   148  
   149  	rr := httptest.NewRecorder()
   150  
   151  	testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   152  		t.Fatal("Options request must not be passed to next handler")
   153  	})
   154  
   155  	CORS(AllowedOrigins([]string{}))(testHandler).ServeHTTP(rr, r)
   156  
   157  	if got, want := rr.Code, http.StatusOK; got != want {
   158  		t.Fatalf("bad status: got %v want %v", got, want)
   159  	}
   160  }
   161  
   162  func TestCORSHandlerAllowedMethodForPreflight(t *testing.T) {
   163  	r := newRequest("OPTIONS", "http://www.example.com/")
   164  	r.Header.Set("Origin", r.URL.String())
   165  	r.Header.Set(corsRequestMethodHeader, "DELETE")
   166  
   167  	rr := httptest.NewRecorder()
   168  
   169  	testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
   170  
   171  	CORS(AllowedMethods([]string{"DELETE"}))(testHandler).ServeHTTP(rr, r)
   172  
   173  	if got, want := rr.Code, http.StatusOK; got != want {
   174  		t.Fatalf("bad status: got %v want %v", got, want)
   175  	}
   176  
   177  	header := rr.HeaderMap.Get(corsAllowMethodsHeader)
   178  	if got, want := header, "DELETE"; got != want {
   179  		t.Fatalf("bad header: expected %q method header, got %q header.", want, got)
   180  	}
   181  }
   182  
   183  func TestCORSHandlerAllowMethodsNotSetForSimpleRequestPreflight(t *testing.T) {
   184  	for _, method := range defaultCorsMethods {
   185  		r := newRequest("OPTIONS", "http://www.example.com/")
   186  		r.Header.Set("Origin", r.URL.String())
   187  		r.Header.Set(corsRequestMethodHeader, method)
   188  
   189  		rr := httptest.NewRecorder()
   190  
   191  		testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
   192  
   193  		CORS()(testHandler).ServeHTTP(rr, r)
   194  
   195  		if got, want := rr.Code, http.StatusOK; got != want {
   196  			t.Fatalf("bad status: got %v want %v", got, want)
   197  		}
   198  
   199  		header := rr.HeaderMap.Get(corsAllowMethodsHeader)
   200  		if got, want := header, ""; got != want {
   201  			t.Fatalf("bad header: expected %q method header, got %q.", want, got)
   202  		}
   203  	}
   204  }
   205  
   206  func TestCORSHandlerAllowedHeaderNotSetForSimpleRequestPreflight(t *testing.T) {
   207  	for _, simpleHeader := range defaultCorsHeaders {
   208  		r := newRequest("OPTIONS", "http://www.example.com/")
   209  		r.Header.Set("Origin", r.URL.String())
   210  		r.Header.Set(corsRequestMethodHeader, "GET")
   211  		r.Header.Set(corsRequestHeadersHeader, simpleHeader)
   212  
   213  		rr := httptest.NewRecorder()
   214  
   215  		testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
   216  
   217  		CORS()(testHandler).ServeHTTP(rr, r)
   218  
   219  		if got, want := rr.Code, http.StatusOK; got != want {
   220  			t.Fatalf("bad status: got %v want %v", got, want)
   221  		}
   222  
   223  		header := rr.HeaderMap.Get(corsAllowHeadersHeader)
   224  		if got, want := header, ""; got != want {
   225  			t.Fatalf("bad header: expected %q header, got %q.", want, got)
   226  		}
   227  	}
   228  }
   229  
   230  func TestCORSHandlerAllowedHeaderForPreflight(t *testing.T) {
   231  	r := newRequest("OPTIONS", "http://www.example.com/")
   232  	r.Header.Set("Origin", r.URL.String())
   233  	r.Header.Set(corsRequestMethodHeader, "POST")
   234  	r.Header.Set(corsRequestHeadersHeader, "Content-Type")
   235  
   236  	rr := httptest.NewRecorder()
   237  
   238  	testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
   239  
   240  	CORS(AllowedHeaders([]string{"Content-Type"}))(testHandler).ServeHTTP(rr, r)
   241  
   242  	if got, want := rr.Code, http.StatusOK; got != want {
   243  		t.Fatalf("bad status: got %v want %v", got, want)
   244  	}
   245  
   246  	header := rr.HeaderMap.Get(corsAllowHeadersHeader)
   247  	if got, want := header, "Content-Type"; got != want {
   248  		t.Fatalf("bad header: expected %q header, got %q header.", want, got)
   249  	}
   250  }
   251  
   252  func TestCORSHandlerInvalidHeaderForPreflightForbidden(t *testing.T) {
   253  	r := newRequest("OPTIONS", "http://www.example.com/")
   254  	r.Header.Set("Origin", r.URL.String())
   255  	r.Header.Set(corsRequestMethodHeader, "POST")
   256  	r.Header.Set(corsRequestHeadersHeader, "Content-Type")
   257  
   258  	rr := httptest.NewRecorder()
   259  
   260  	testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
   261  
   262  	CORS()(testHandler).ServeHTTP(rr, r)
   263  
   264  	if got, want := rr.Code, http.StatusForbidden; got != want {
   265  		t.Fatalf("bad status: got %v want %v", got, want)
   266  	}
   267  }
   268  
   269  func TestCORSHandlerMaxAgeForPreflight(t *testing.T) {
   270  	r := newRequest("OPTIONS", "http://www.example.com/")
   271  	r.Header.Set("Origin", r.URL.String())
   272  	r.Header.Set(corsRequestMethodHeader, "POST")
   273  
   274  	rr := httptest.NewRecorder()
   275  
   276  	testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
   277  
   278  	CORS(MaxAge(3500))(testHandler).ServeHTTP(rr, r)
   279  
   280  	if got, want := rr.Code, http.StatusOK; got != want {
   281  		t.Fatalf("bad status: got %v want %v", got, want)
   282  	}
   283  
   284  	header := rr.HeaderMap.Get(corsMaxAgeHeader)
   285  	if got, want := header, "600"; got != want {
   286  		t.Fatalf("bad header: expected %q to be %q, got %q.", corsMaxAgeHeader, want, got)
   287  	}
   288  }
   289  
   290  func TestCORSHandlerAllowedCredentials(t *testing.T) {
   291  	r := newRequest("GET", "http://www.example.com/")
   292  	r.Header.Set("Origin", r.URL.String())
   293  
   294  	rr := httptest.NewRecorder()
   295  
   296  	testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
   297  
   298  	CORS(AllowCredentials())(testHandler).ServeHTTP(rr, r)
   299  
   300  	if status := rr.Code; status != http.StatusOK {
   301  		t.Fatalf("bad status: got %v want %v", status, http.StatusOK)
   302  	}
   303  
   304  	header := rr.HeaderMap.Get(corsAllowCredentialsHeader)
   305  	if got, want := header, "true"; got != want {
   306  		t.Fatalf("bad header: expected %q to be %q, got %q.", corsAllowCredentialsHeader, want, got)
   307  	}
   308  }
   309  
   310  func TestCORSHandlerMultipleAllowOriginsSetsVaryHeader(t *testing.T) {
   311  	r := newRequest("GET", "http://www.example.com/")
   312  	r.Header.Set("Origin", r.URL.String())
   313  
   314  	rr := httptest.NewRecorder()
   315  
   316  	testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
   317  
   318  	CORS(AllowedOrigins([]string{r.URL.String(), "http://google.com"}))(testHandler).ServeHTTP(rr, r)
   319  
   320  	if status := rr.Code; status != http.StatusOK {
   321  		t.Fatalf("bad status: got %v want %v", status, http.StatusOK)
   322  	}
   323  
   324  	header := rr.HeaderMap.Get(corsVaryHeader)
   325  	if got, want := header, corsOriginHeader; got != want {
   326  		t.Fatalf("bad header: expected %s to be %q, got %q.", corsVaryHeader, want, got)
   327  	}
   328  }
   329  
   330  func TestCORSWithMultipleHandlers(t *testing.T) {
   331  	var lastHandledBy string
   332  	corsMiddleware := CORS()
   333  
   334  	testHandler1 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   335  		lastHandledBy = "testHandler1"
   336  	})
   337  	testHandler2 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   338  		lastHandledBy = "testHandler2"
   339  	})
   340  
   341  	r1 := newRequest("GET", "http://www.example.com/")
   342  	rr1 := httptest.NewRecorder()
   343  	handler1 := corsMiddleware(testHandler1)
   344  
   345  	corsMiddleware(testHandler2)
   346  
   347  	handler1.ServeHTTP(rr1, r1)
   348  	if lastHandledBy != "testHandler1" {
   349  		t.Fatalf("bad CORS() registration: Handler served should be Handler registered")
   350  	}
   351  }
   352  
   353  func TestCORSOriginValidatorWithImplicitStar(t *testing.T) {
   354  	r := newRequest("GET", "http://a.example.com")
   355  	r.Header.Set("Origin", r.URL.String())
   356  	rr := httptest.NewRecorder()
   357  
   358  	testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
   359  
   360  	originValidator := func(origin string) bool {
   361  		if strings.HasSuffix(origin, ".example.com") {
   362  			return true
   363  		}
   364  		return false
   365  	}
   366  
   367  	CORS(AllowedOriginValidator(originValidator))(testHandler).ServeHTTP(rr, r)
   368  	header := rr.HeaderMap.Get(corsAllowOriginHeader)
   369  	if got, want := header, r.URL.String(); got != want {
   370  		t.Fatalf("bad header: expected %s to be %q, got %q.", corsAllowOriginHeader, want, got)
   371  	}
   372  }
   373  
   374  func TestCORSOriginValidatorWithExplicitStar(t *testing.T) {
   375  	r := newRequest("GET", "http://a.example.com")
   376  	r.Header.Set("Origin", r.URL.String())
   377  	rr := httptest.NewRecorder()
   378  
   379  	testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
   380  
   381  	originValidator := func(origin string) bool {
   382  		if strings.HasSuffix(origin, ".example.com") {
   383  			return true
   384  		}
   385  		return false
   386  	}
   387  
   388  	CORS(
   389  		AllowedOriginValidator(originValidator),
   390  		AllowedOrigins([]string{"*"}),
   391  	)(testHandler).ServeHTTP(rr, r)
   392  	header := rr.HeaderMap.Get(corsAllowOriginHeader)
   393  	if got, want := header, "*"; got != want {
   394  		t.Fatalf("bad header: expected %q to be %q, got %q.", corsAllowOriginHeader, want, got)
   395  	}
   396  }
   397  
   398  func TestCORSAllowStar(t *testing.T) {
   399  	r := newRequest("GET", "http://a.example.com")
   400  	r.Header.Set("Origin", r.URL.String())
   401  	rr := httptest.NewRecorder()
   402  
   403  	testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
   404  
   405  	CORS()(testHandler).ServeHTTP(rr, r)
   406  	header := rr.HeaderMap.Get(corsAllowOriginHeader)
   407  	if got, want := header, "*"; got != want {
   408  		t.Fatalf("bad header: expected %q to be %q, got %q.", corsAllowOriginHeader, want, got)
   409  	}
   410  }