github.com/google/go-safeweb@v0.0.0-20231219055052-64d8cfc90fbb/internal/requesttesting/headers/referer_test.go (about) 1 // Copyright 2020 Google LLC 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // https://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package headers 16 17 import ( 18 "context" 19 "net/http" 20 "testing" 21 22 "github.com/google/go-cmp/cmp" 23 24 "github.com/google/go-safeweb/internal/requesttesting" 25 ) 26 27 func TestReferer(t *testing.T) { 28 type testWant struct { 29 headers map[string][]string 30 referer string 31 } 32 33 var tests = []struct { 34 name string 35 request []byte 36 want testWant 37 }{ 38 { 39 name: "Basic", 40 request: []byte("GET / HTTP/1.1\r\n" + 41 "Host: localhost:8080\r\n" + 42 "Referer: http://example.com\r\n" + 43 "\r\n"), 44 want: testWant{ 45 headers: map[string][]string{"Referer": {"http://example.com"}}, 46 referer: "http://example.com", 47 }, 48 }, 49 { 50 name: "CasingOrdering1", 51 request: []byte("GET / HTTP/1.1\r\n" + 52 "Host: localhost:8080\r\n" + 53 "referer: http://example.com\r\n" + 54 "Referer: http://evil.com\r\n" + 55 "\r\n"), 56 want: testWant{ 57 headers: map[string][]string{"Referer": {"http://example.com", "http://evil.com"}}, 58 referer: "http://example.com", 59 }, 60 }, 61 { 62 name: "CasingOrdering2", 63 request: []byte("GET / HTTP/1.1\r\n" + 64 "Host: localhost:8080\r\n" + 65 "Referer: http://example.com\r\n" + 66 "referer: http://evil.com\r\n" + 67 "\r\n"), 68 want: testWant{ 69 headers: map[string][]string{"Referer": {"http://example.com", "http://evil.com"}}, 70 referer: "http://example.com", 71 }, 72 }, 73 } 74 75 for _, tt := range tests { 76 t.Run(tt.name, func(t *testing.T) { 77 resp, err := requesttesting.MakeRequest(context.Background(), tt.request, func(r *http.Request) { 78 if diff := cmp.Diff(tt.want.headers, map[string][]string(r.Header)); diff != "" { 79 t.Errorf("r.Header mismatch (-want +got):\n%s", diff) 80 } 81 82 if r.Referer() != tt.want.referer { 83 t.Errorf("r.Referer() got: %q want: %q", r.Referer(), tt.want.referer) 84 } 85 }) 86 if err != nil { 87 t.Fatalf("MakeRequest() got err: %v", err) 88 } 89 90 if got, want := extractStatus(resp), statusOK; !matchStatus(got, want) { 91 t.Errorf("status code got: %q want: %q", got, want) 92 } 93 }) 94 } 95 } 96 97 func TestRefererOrdering(t *testing.T) { 98 // The documentation of http.Request.Referer() doesn't clearly specify 99 // that only the first Referer header is used and that the other ones 100 // are ignored. This could potentially lead to security issues if two 101 // HTTP servers that look at different headers are chained together. 102 // 103 // The desired behavior would be to respond with 400 (Bad Request) 104 // when there is more than one Referer header. 105 106 request := []byte("GET / HTTP/1.1\r\n" + 107 "Host: localhost:8080\r\n" + 108 "Referer: http://example.com\r\n" + 109 "Referer: http://evil.com\r\n" + 110 "\r\n") 111 112 t.Run("Current behavior", func(t *testing.T) { 113 resp, err := requesttesting.MakeRequest(context.Background(), request, func(r *http.Request) { 114 wantHeaders := map[string][]string{"Referer": {"http://example.com", "http://evil.com"}} 115 if diff := cmp.Diff(wantHeaders, map[string][]string(r.Header)); diff != "" { 116 t.Errorf("r.Header mismatch (-want +got):\n%s", diff) 117 } 118 119 if want := "http://example.com"; r.Referer() != want { 120 t.Errorf("r.Referer() got: %q want: %q", r.Referer(), want) 121 } 122 }) 123 if err != nil { 124 t.Fatalf("MakeRequest() got err: %v want: nil", err) 125 } 126 127 if got, want := extractStatus(resp), statusOK; !matchStatus(got, want) { 128 t.Errorf("status code got: %q want: %q", got, want) 129 } 130 }) 131 132 t.Run("Desired behavior", func(t *testing.T) { 133 t.Skip() 134 resp, err := requesttesting.MakeRequest(context.Background(), request, func(r *http.Request) { 135 t.Error("Expected handler to not be called!") 136 }) 137 if err != nil { 138 t.Fatalf("MakeRequest() got err: %v want: nil", err) 139 } 140 141 if got, want := extractStatus(resp), statusBadRequestPrefix; !matchStatus(got, want) { 142 t.Errorf("status code got: %q want: %q", got, want) 143 } 144 }) 145 }