github.com/google/go-safeweb@v0.0.0-20231219055052-64d8cfc90fbb/internal/requesttesting/headers/useragent_test.go (about) 1 // Copyright 2020 Google LLC 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // https://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package headers 16 17 import ( 18 "context" 19 "net/http" 20 "testing" 21 22 "github.com/google/go-cmp/cmp" 23 "github.com/google/go-safeweb/internal/requesttesting" 24 ) 25 26 func TestUserAgent(t *testing.T) { 27 type testWant struct { 28 headers map[string][]string 29 useragent string 30 } 31 32 var tests = []struct { 33 name string 34 request []byte 35 want testWant 36 }{ 37 { 38 name: "Basic", 39 request: []byte("GET / HTTP/1.1\r\n" + 40 "Host: localhost:8080\r\n" + 41 "User-Agent: BlahBlah\r\n" + 42 "\r\n"), 43 want: testWant{ 44 headers: map[string][]string{"User-Agent": {"BlahBlah"}}, 45 useragent: "BlahBlah", 46 }, 47 }, 48 { 49 name: "CasingOrdering1", 50 request: []byte("GET / HTTP/1.1\r\n" + 51 "Host: localhost:8080\r\n" + 52 "user-Agent: BlahBlah\r\n" + 53 "User-Agent: FooFoo\r\n" + 54 "\r\n"), 55 want: testWant{ 56 headers: map[string][]string{"User-Agent": {"BlahBlah", "FooFoo"}}, 57 useragent: "BlahBlah", 58 }, 59 }, 60 { 61 name: "CasingOrdering1", 62 request: []byte("GET / HTTP/1.1\r\n" + 63 "Host: localhost:8080\r\n" + 64 "User-Agent: BlahBlah\r\n" + 65 "user-Agent: FooFoo\r\n" + 66 "\r\n"), 67 want: testWant{ 68 headers: map[string][]string{"User-Agent": {"BlahBlah", "FooFoo"}}, 69 useragent: "BlahBlah", 70 }, 71 }, 72 } 73 74 for _, tt := range tests { 75 t.Run(tt.name, func(t *testing.T) { 76 resp, err := requesttesting.MakeRequest(context.Background(), tt.request, func(r *http.Request) { 77 if diff := cmp.Diff(tt.want.headers, map[string][]string(r.Header)); diff != "" { 78 t.Errorf("r.Header mismatch (-want +got):\n%s", diff) 79 } 80 81 if r.UserAgent() != tt.want.useragent { 82 t.Errorf("r.UserAgent() got: %q want: %q", r.UserAgent(), tt.want.useragent) 83 } 84 }) 85 if err != nil { 86 t.Fatalf("MakeRequest() got err: %v", err) 87 } 88 89 if got, want := extractStatus(resp), statusOK; !matchStatus(got, want) { 90 t.Errorf("status code got: %q want: %q", got, want) 91 } 92 }) 93 } 94 } 95 96 func TestUserAgentOrdering(t *testing.T) { 97 // The documentation of http.Request.UserAgent() doesn't clearly specify 98 // that only the first User-Agent header is used and that the other ones 99 // are ignored. This could potentially lead to security issues if two 100 // HTTP servers that look at different headers are chained together. 101 // 102 // The desired behavior would be to respond with 400 (Bad Request) 103 // when there is more than one User-Agent header. 104 105 request := []byte("GET / HTTP/1.1\r\n" + 106 "Host: localhost:8080\r\n" + 107 "User-Agent: BlahBlah\r\n" + 108 "User-Agent: FooFoo\r\n" + 109 "\r\n") 110 111 t.Run("Current behavior", func(t *testing.T) { 112 resp, err := requesttesting.MakeRequest(context.Background(), request, func(r *http.Request) { 113 wantHeaders := map[string][]string{"User-Agent": {"BlahBlah", "FooFoo"}} 114 if diff := cmp.Diff(wantHeaders, map[string][]string(r.Header)); diff != "" { 115 t.Errorf("r.Header mismatch (-want +got):\n%s", diff) 116 } 117 118 if want := "BlahBlah"; r.UserAgent() != want { 119 t.Errorf("r.UserAgent() got: %q want: %q", r.UserAgent(), want) 120 } 121 }) 122 if err != nil { 123 t.Fatalf("MakeRequest() got err: %v want: nil", err) 124 } 125 126 if got, want := extractStatus(resp), statusOK; !matchStatus(got, want) { 127 t.Errorf("status code got: %q want: %q", got, want) 128 } 129 }) 130 131 t.Run("Desired behavior", func(t *testing.T) { 132 t.Skip() 133 resp, err := requesttesting.MakeRequest(context.Background(), request, func(r *http.Request) { 134 t.Error("Expected handler to not be called!") 135 }) 136 if err != nil { 137 t.Fatalf("MakeRequest() got err: %v want: nil", err) 138 } 139 140 if got, want := extractStatus(resp), statusBadRequestPrefix; !matchStatus(got, want) { 141 t.Errorf("status code got: %q want: %q", got, want) 142 } 143 }) 144 }