github.com/hxx258456/ccgo@v0.0.5-0.20230213014102-48b35f46f66f/handlers/cors_test.go (about)

     1  package handlers
     2  
     3  import (
     4  	"strings"
     5  	"testing"
     6  
     7  	http "github.com/hxx258456/ccgo/gmhttp"
     8  	"github.com/hxx258456/ccgo/gmhttp/httptest"
     9  )
    10  
    11  func TestDefaultCORSHandlerReturnsOk(t *testing.T) {
    12  	r := newRequest("GET", "http://www.example.com/")
    13  	rr := httptest.NewRecorder()
    14  
    15  	testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
    16  
    17  	CORS()(testHandler).ServeHTTP(rr, r)
    18  
    19  	if got, want := rr.Code, http.StatusOK; got != want {
    20  		t.Fatalf("bad status: got %v want %v", got, want)
    21  	}
    22  }
    23  
    24  func TestDefaultCORSHandlerReturnsOkWithOrigin(t *testing.T) {
    25  	r := newRequest("GET", "http://www.example.com/")
    26  	r.Header.Set("Origin", r.URL.String())
    27  
    28  	rr := httptest.NewRecorder()
    29  
    30  	testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
    31  
    32  	CORS()(testHandler).ServeHTTP(rr, r)
    33  
    34  	if got, want := rr.Code, http.StatusOK; got != want {
    35  		t.Fatalf("bad status: got %v want %v", got, want)
    36  	}
    37  }
    38  
    39  func TestCORSHandlerIgnoreOptionsFallsThrough(t *testing.T) {
    40  	r := newRequest("OPTIONS", "http://www.example.com/")
    41  	r.Header.Set("Origin", r.URL.String())
    42  
    43  	rr := httptest.NewRecorder()
    44  
    45  	testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    46  		w.WriteHeader(http.StatusTeapot)
    47  	})
    48  
    49  	CORS(IgnoreOptions())(testHandler).ServeHTTP(rr, r)
    50  
    51  	if got, want := rr.Code, http.StatusTeapot; got != want {
    52  		t.Fatalf("bad status: got %v want %v", got, want)
    53  	}
    54  }
    55  
    56  func TestCORSHandlerSetsExposedHeaders(t *testing.T) {
    57  	// Test default configuration.
    58  	r := newRequest("GET", "http://www.example.com/")
    59  	r.Header.Set("Origin", r.URL.String())
    60  
    61  	rr := httptest.NewRecorder()
    62  
    63  	testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
    64  
    65  	CORS(ExposedHeaders([]string{"X-CORS-TEST"}))(testHandler).ServeHTTP(rr, r)
    66  
    67  	if got, want := rr.Code, http.StatusOK; got != want {
    68  		t.Fatalf("bad status: got %v want %v", got, want)
    69  	}
    70  
    71  	header := rr.HeaderMap.Get(corsExposeHeadersHeader)
    72  	if got, want := header, "X-Cors-Test"; got != want {
    73  		t.Fatalf("bad header: expected %q header, got empty header for method.", want)
    74  	}
    75  }
    76  
    77  func TestCORSHandlerUnsetRequestMethodForPreflightBadRequest(t *testing.T) {
    78  	r := newRequest("OPTIONS", "http://www.example.com/")
    79  	r.Header.Set("Origin", r.URL.String())
    80  
    81  	rr := httptest.NewRecorder()
    82  
    83  	testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
    84  
    85  	CORS(AllowedMethods([]string{"DELETE"}))(testHandler).ServeHTTP(rr, r)
    86  
    87  	if got, want := rr.Code, http.StatusBadRequest; got != want {
    88  		t.Fatalf("bad status: got %v want %v", got, want)
    89  	}
    90  }
    91  
    92  func TestCORSHandlerInvalidRequestMethodForPreflightMethodNotAllowed(t *testing.T) {
    93  	r := newRequest("OPTIONS", "http://www.example.com/")
    94  	r.Header.Set("Origin", r.URL.String())
    95  	r.Header.Set(corsRequestMethodHeader, "DELETE")
    96  
    97  	rr := httptest.NewRecorder()
    98  
    99  	testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
   100  
   101  	CORS()(testHandler).ServeHTTP(rr, r)
   102  
   103  	if got, want := rr.Code, http.StatusMethodNotAllowed; got != want {
   104  		t.Fatalf("bad status: got %v want %v", got, want)
   105  	}
   106  }
   107  
   108  func TestCORSHandlerOptionsRequestMustNotBePassedToNextHandler(t *testing.T) {
   109  	r := newRequest("OPTIONS", "http://www.example.com/")
   110  	r.Header.Set("Origin", r.URL.String())
   111  	r.Header.Set(corsRequestMethodHeader, "GET")
   112  
   113  	rr := httptest.NewRecorder()
   114  
   115  	testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   116  		t.Fatal("Options request must not be passed to next handler")
   117  	})
   118  
   119  	CORS()(testHandler).ServeHTTP(rr, r)
   120  
   121  	if got, want := rr.Code, http.StatusOK; got != want {
   122  		t.Fatalf("bad status: got %v want %v", got, want)
   123  	}
   124  }
   125  
   126  func TestCORSHandlerOptionsRequestMustNotBePassedToNextHandlerWithCustomStatusCode(t *testing.T) {
   127  	statusCode := http.StatusNoContent
   128  	r := newRequest("OPTIONS", "http://www.example.com/")
   129  	r.Header.Set("Origin", r.URL.String())
   130  	r.Header.Set(corsRequestMethodHeader, "GET")
   131  
   132  	rr := httptest.NewRecorder()
   133  
   134  	testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   135  		t.Fatal("Options request must not be passed to next handler")
   136  	})
   137  
   138  	CORS(OptionStatusCode(statusCode))(testHandler).ServeHTTP(rr, r)
   139  
   140  	if got, want := rr.Code, statusCode; got != want {
   141  		t.Fatalf("bad status: got %v want %v", got, want)
   142  	}
   143  }
   144  
   145  func TestCORSHandlerOptionsRequestMustNotBePassedToNextHandlerWhenOriginNotAllowed(t *testing.T) {
   146  	r := newRequest("OPTIONS", "http://www.example.com/")
   147  	r.Header.Set("Origin", r.URL.String())
   148  	r.Header.Set(corsRequestMethodHeader, "GET")
   149  
   150  	rr := httptest.NewRecorder()
   151  
   152  	testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   153  		t.Fatal("Options request must not be passed to next handler")
   154  	})
   155  
   156  	CORS(AllowedOrigins([]string{}))(testHandler).ServeHTTP(rr, r)
   157  
   158  	if got, want := rr.Code, http.StatusOK; got != want {
   159  		t.Fatalf("bad status: got %v want %v", got, want)
   160  	}
   161  }
   162  
   163  func TestCORSHandlerAllowedMethodForPreflight(t *testing.T) {
   164  	r := newRequest("OPTIONS", "http://www.example.com/")
   165  	r.Header.Set("Origin", r.URL.String())
   166  	r.Header.Set(corsRequestMethodHeader, "DELETE")
   167  
   168  	rr := httptest.NewRecorder()
   169  
   170  	testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
   171  
   172  	CORS(AllowedMethods([]string{"DELETE"}))(testHandler).ServeHTTP(rr, r)
   173  
   174  	if got, want := rr.Code, http.StatusOK; got != want {
   175  		t.Fatalf("bad status: got %v want %v", got, want)
   176  	}
   177  
   178  	header := rr.HeaderMap.Get(corsAllowMethodsHeader)
   179  	if got, want := header, "DELETE"; got != want {
   180  		t.Fatalf("bad header: expected %q method header, got %q header.", want, got)
   181  	}
   182  }
   183  
   184  func TestCORSHandlerAllowMethodsNotSetForSimpleRequestPreflight(t *testing.T) {
   185  	for _, method := range defaultCorsMethods {
   186  		r := newRequest("OPTIONS", "http://www.example.com/")
   187  		r.Header.Set("Origin", r.URL.String())
   188  		r.Header.Set(corsRequestMethodHeader, method)
   189  
   190  		rr := httptest.NewRecorder()
   191  
   192  		testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
   193  
   194  		CORS()(testHandler).ServeHTTP(rr, r)
   195  
   196  		if got, want := rr.Code, http.StatusOK; got != want {
   197  			t.Fatalf("bad status: got %v want %v", got, want)
   198  		}
   199  
   200  		header := rr.HeaderMap.Get(corsAllowMethodsHeader)
   201  		if got, want := header, ""; got != want {
   202  			t.Fatalf("bad header: expected %q method header, got %q.", want, got)
   203  		}
   204  	}
   205  }
   206  
   207  func TestCORSHandlerAllowedHeaderNotSetForSimpleRequestPreflight(t *testing.T) {
   208  	for _, simpleHeader := range defaultCorsHeaders {
   209  		r := newRequest("OPTIONS", "http://www.example.com/")
   210  		r.Header.Set("Origin", r.URL.String())
   211  		r.Header.Set(corsRequestMethodHeader, "GET")
   212  		r.Header.Set(corsRequestHeadersHeader, simpleHeader)
   213  
   214  		rr := httptest.NewRecorder()
   215  
   216  		testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
   217  
   218  		CORS()(testHandler).ServeHTTP(rr, r)
   219  
   220  		if got, want := rr.Code, http.StatusOK; got != want {
   221  			t.Fatalf("bad status: got %v want %v", got, want)
   222  		}
   223  
   224  		header := rr.HeaderMap.Get(corsAllowHeadersHeader)
   225  		if got, want := header, ""; got != want {
   226  			t.Fatalf("bad header: expected %q header, got %q.", want, got)
   227  		}
   228  	}
   229  }
   230  
   231  func TestCORSHandlerAllowedHeaderForPreflight(t *testing.T) {
   232  	r := newRequest("OPTIONS", "http://www.example.com/")
   233  	r.Header.Set("Origin", r.URL.String())
   234  	r.Header.Set(corsRequestMethodHeader, "POST")
   235  	r.Header.Set(corsRequestHeadersHeader, "Content-Type")
   236  
   237  	rr := httptest.NewRecorder()
   238  
   239  	testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
   240  
   241  	CORS(AllowedHeaders([]string{"Content-Type"}))(testHandler).ServeHTTP(rr, r)
   242  
   243  	if got, want := rr.Code, http.StatusOK; got != want {
   244  		t.Fatalf("bad status: got %v want %v", got, want)
   245  	}
   246  
   247  	header := rr.HeaderMap.Get(corsAllowHeadersHeader)
   248  	if got, want := header, "Content-Type"; got != want {
   249  		t.Fatalf("bad header: expected %q header, got %q header.", want, got)
   250  	}
   251  }
   252  
   253  func TestCORSHandlerInvalidHeaderForPreflightForbidden(t *testing.T) {
   254  	r := newRequest("OPTIONS", "http://www.example.com/")
   255  	r.Header.Set("Origin", r.URL.String())
   256  	r.Header.Set(corsRequestMethodHeader, "POST")
   257  	r.Header.Set(corsRequestHeadersHeader, "Content-Type")
   258  
   259  	rr := httptest.NewRecorder()
   260  
   261  	testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
   262  
   263  	CORS()(testHandler).ServeHTTP(rr, r)
   264  
   265  	if got, want := rr.Code, http.StatusForbidden; got != want {
   266  		t.Fatalf("bad status: got %v want %v", got, want)
   267  	}
   268  }
   269  
   270  func TestCORSHandlerMaxAgeForPreflight(t *testing.T) {
   271  	r := newRequest("OPTIONS", "http://www.example.com/")
   272  	r.Header.Set("Origin", r.URL.String())
   273  	r.Header.Set(corsRequestMethodHeader, "POST")
   274  
   275  	rr := httptest.NewRecorder()
   276  
   277  	testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
   278  
   279  	CORS(MaxAge(3500))(testHandler).ServeHTTP(rr, r)
   280  
   281  	if got, want := rr.Code, http.StatusOK; got != want {
   282  		t.Fatalf("bad status: got %v want %v", got, want)
   283  	}
   284  
   285  	header := rr.HeaderMap.Get(corsMaxAgeHeader)
   286  	if got, want := header, "600"; got != want {
   287  		t.Fatalf("bad header: expected %q to be %q, got %q.", corsMaxAgeHeader, want, got)
   288  	}
   289  }
   290  
   291  func TestCORSHandlerAllowedCredentials(t *testing.T) {
   292  	r := newRequest("GET", "http://www.example.com/")
   293  	r.Header.Set("Origin", r.URL.String())
   294  
   295  	rr := httptest.NewRecorder()
   296  
   297  	testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
   298  
   299  	CORS(AllowCredentials())(testHandler).ServeHTTP(rr, r)
   300  
   301  	if status := rr.Code; status != http.StatusOK {
   302  		t.Fatalf("bad status: got %v want %v", status, http.StatusOK)
   303  	}
   304  
   305  	header := rr.HeaderMap.Get(corsAllowCredentialsHeader)
   306  	if got, want := header, "true"; got != want {
   307  		t.Fatalf("bad header: expected %q to be %q, got %q.", corsAllowCredentialsHeader, want, got)
   308  	}
   309  }
   310  
   311  func TestCORSHandlerMultipleAllowOriginsSetsVaryHeader(t *testing.T) {
   312  	r := newRequest("GET", "http://www.example.com/")
   313  	r.Header.Set("Origin", r.URL.String())
   314  
   315  	rr := httptest.NewRecorder()
   316  
   317  	testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
   318  
   319  	CORS(AllowedOrigins([]string{r.URL.String(), "http://google.com"}))(testHandler).ServeHTTP(rr, r)
   320  
   321  	if status := rr.Code; status != http.StatusOK {
   322  		t.Fatalf("bad status: got %v want %v", status, http.StatusOK)
   323  	}
   324  
   325  	header := rr.HeaderMap.Get(corsVaryHeader)
   326  	if got, want := header, corsOriginHeader; got != want {
   327  		t.Fatalf("bad header: expected %s to be %q, got %q.", corsVaryHeader, want, got)
   328  	}
   329  }
   330  
   331  func TestCORSWithMultipleHandlers(t *testing.T) {
   332  	var lastHandledBy string
   333  	corsMiddleware := CORS()
   334  
   335  	testHandler1 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   336  		lastHandledBy = "testHandler1"
   337  	})
   338  	testHandler2 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   339  		lastHandledBy = "testHandler2"
   340  	})
   341  
   342  	r1 := newRequest("GET", "http://www.example.com/")
   343  	rr1 := httptest.NewRecorder()
   344  	handler1 := corsMiddleware(testHandler1)
   345  
   346  	corsMiddleware(testHandler2)
   347  
   348  	handler1.ServeHTTP(rr1, r1)
   349  	if lastHandledBy != "testHandler1" {
   350  		t.Fatalf("bad CORS() registration: Handler served should be Handler registered")
   351  	}
   352  }
   353  
   354  func TestCORSOriginValidatorWithImplicitStar(t *testing.T) {
   355  	r := newRequest("GET", "http://a.example.com")
   356  	r.Header.Set("Origin", r.URL.String())
   357  	rr := httptest.NewRecorder()
   358  
   359  	testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
   360  
   361  	originValidator := func(origin string) bool {
   362  		if strings.HasSuffix(origin, ".example.com") {
   363  			return true
   364  		}
   365  		return false
   366  	}
   367  
   368  	CORS(AllowedOriginValidator(originValidator))(testHandler).ServeHTTP(rr, r)
   369  	header := rr.HeaderMap.Get(corsAllowOriginHeader)
   370  	if got, want := header, r.URL.String(); got != want {
   371  		t.Fatalf("bad header: expected %s to be %q, got %q.", corsAllowOriginHeader, want, got)
   372  	}
   373  }
   374  
   375  func TestCORSOriginValidatorWithExplicitStar(t *testing.T) {
   376  	r := newRequest("GET", "http://a.example.com")
   377  	r.Header.Set("Origin", r.URL.String())
   378  	rr := httptest.NewRecorder()
   379  
   380  	testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
   381  
   382  	originValidator := func(origin string) bool {
   383  		if strings.HasSuffix(origin, ".example.com") {
   384  			return true
   385  		}
   386  		return false
   387  	}
   388  
   389  	CORS(
   390  		AllowedOriginValidator(originValidator),
   391  		AllowedOrigins([]string{"*"}),
   392  	)(testHandler).ServeHTTP(rr, r)
   393  	header := rr.HeaderMap.Get(corsAllowOriginHeader)
   394  	if got, want := header, "*"; got != want {
   395  		t.Fatalf("bad header: expected %q to be %q, got %q.", corsAllowOriginHeader, want, got)
   396  	}
   397  }
   398  
   399  func TestCORSAllowStar(t *testing.T) {
   400  	r := newRequest("GET", "http://a.example.com")
   401  	r.Header.Set("Origin", r.URL.String())
   402  	rr := httptest.NewRecorder()
   403  
   404  	testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
   405  
   406  	CORS()(testHandler).ServeHTTP(rr, r)
   407  	header := rr.HeaderMap.Get(corsAllowOriginHeader)
   408  	if got, want := header, "*"; got != want {
   409  		t.Fatalf("bad header: expected %q to be %q, got %q.", corsAllowOriginHeader, want, got)
   410  	}
   411  }