github.com/google/martian/v3@v3.3.3/cors/cors_test.go (about) 1 // Copyright 2015 Google Inc. All rights reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package cors 16 17 import ( 18 "net/http" 19 "net/http/httptest" 20 "testing" 21 ) 22 23 func TestServeHTTPSameOrigin(t *testing.T) { 24 var handlerRun bool 25 26 h := NewHandler(http.HandlerFunc( 27 func(rw http.ResponseWriter, req *http.Request) { 28 handlerRun = true 29 })) 30 31 req, err := http.NewRequest("GET", "http://example.com", nil) 32 if err != nil { 33 t.Fatalf("http.NewRequest(): got %v, want no error", err) 34 } 35 rw := httptest.NewRecorder() 36 37 h.ServeHTTP(rw, req) 38 39 if !handlerRun { 40 t.Error("handlerRun: got false, want true") 41 } 42 } 43 44 func TestServeHTTPPreflight(t *testing.T) { 45 var handlerRun bool 46 47 h := NewHandler(http.HandlerFunc( 48 func(rw http.ResponseWriter, req *http.Request) { 49 handlerRun = true 50 })) 51 h.AllowCredentials(true) 52 53 req, err := http.NewRequest("OPTIONS", "http://example.com", nil) 54 if err != nil { 55 t.Fatalf("http.NewRequest(): got %v, want no error", err) 56 } 57 req.Header.Set("Origin", "http://google.com") 58 req.Header.Set("Access-Control-Request-Method", "PUT") 59 req.Header.Set("Access-Control-Request-Headers", "Cors-Test") 60 61 rw := httptest.NewRecorder() 62 63 h.ServeHTTP(rw, req) 64 65 if got, want := rw.Header().Get("Access-Control-Allow-Origin"), "*"; got != want { 66 t.Errorf("rw.Header().Get(%q): got %q, want %q", "Access-Control-Allow-Origin", got, want) 67 } 68 if got, want := rw.Header().Get("Access-Control-Allow-Methods"), "PUT"; got != want { 69 t.Errorf("rw.Header().Get(%q): got %q, want %q", "Access-Control-Allow-Methods", got, want) 70 } 71 if got, want := rw.Header().Get("Access-Control-Allow-Headers"), "Cors-Test"; got != want { 72 t.Errorf("rw.Header().Get(%q): got %q, want %q", "Access-Control-Allow-Headers", got, want) 73 } 74 if got, want := rw.Header().Get("Access-Control-Allow-Credentials"), "true"; got != want { 75 t.Errorf("rw.Header().Get(%q): got %q, want %q", "Access-Control-Allow-Credentials", got, want) 76 } 77 78 if handlerRun { 79 t.Error("handlerRun: got true, want false") 80 } 81 } 82 83 func TestServeHTTPSimple(t *testing.T) { 84 var handlerRun bool 85 86 h := NewHandler(http.HandlerFunc( 87 func(rw http.ResponseWriter, req *http.Request) { 88 handlerRun = true 89 })) 90 h.SetOrigin("http://martian.local") 91 92 req, err := http.NewRequest("GET", "http://example.com", nil) 93 if err != nil { 94 t.Fatalf("http.NewRequest(): got %v, want no error", err) 95 } 96 req.Header.Set("Origin", "http://google.com") 97 req.Header.Set("Access-Control-Request-Method", "GET") 98 req.Header.Set("Access-Control-Request-Headers", "Cors-Test") 99 100 rw := httptest.NewRecorder() 101 102 h.ServeHTTP(rw, req) 103 104 if got, want := rw.Header().Get("Access-Control-Allow-Origin"), "http://martian.local"; got != want { 105 t.Errorf("rw.Header().Get(%q): got %q, want %q", "Access-Control-Allow-Origin", got, want) 106 } 107 if got, want := rw.Header().Get("Access-Control-Allow-Methods"), "GET"; got != want { 108 t.Errorf("rw.Header().Get(%q): got %q, want %q", "Access-Control-Allow-Methods", got, want) 109 } 110 if got, want := rw.Header().Get("Access-Control-Allow-Headers"), "Cors-Test"; got != want { 111 t.Errorf("rw.Header().Get(%q): got %q, want %q", "Access-Control-Allow-Headers", got, want) 112 } 113 114 if !handlerRun { 115 t.Error("handlerRun: got false, want true") 116 } 117 }