github.com/google/go-safeweb@v0.0.0-20231219055052-64d8cfc90fbb/safehttp/incoming_request_test.go (about) 1 // Copyright 2020 Google LLC 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // https://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package safehttp_test 16 17 import ( 18 "context" 19 "net/http" 20 "net/http/httptest" 21 "strings" 22 "testing" 23 24 "github.com/google/go-cmp/cmp" 25 "github.com/google/go-safeweb/safehttp" 26 "github.com/google/go-safeweb/safehttp/safehttptest" 27 ) 28 29 func TestIncomingRequestCookie(t *testing.T) { 30 var tests = []struct { 31 name string 32 req *http.Request 33 wantName string 34 wantValue string 35 }{ 36 { 37 name: "Basic", 38 req: func() *http.Request { 39 r := httptest.NewRequest(http.MethodGet, "/", nil) 40 r.Header.Set("Cookie", "foo=bar") 41 return r 42 }(), 43 wantName: "foo", 44 wantValue: "bar", 45 }, 46 { 47 name: "Multiple cookies with the same name", 48 req: func() *http.Request { 49 r := httptest.NewRequest(http.MethodGet, "/", nil) 50 r.Header.Add("Cookie", "foo=bar; foo=xyz") 51 r.Header.Add("Cookie", "foo=pizza") 52 return r 53 }(), 54 wantName: "foo", 55 wantValue: "bar", 56 }, 57 } 58 59 for _, tt := range tests { 60 t.Run(tt.name, func(t *testing.T) { 61 ir := safehttp.NewIncomingRequest(tt.req) 62 c, err := ir.Cookie(tt.wantName) 63 if err != nil { 64 t.Errorf(`ir.Cookie(tt.wantName) got: %v want: nil`, err) 65 } 66 67 if got := c.Name(); got != tt.wantName { 68 t.Errorf("c.Name() got: %v want: %v", got, tt.wantName) 69 } 70 71 if got := c.Value(); got != tt.wantValue { 72 t.Errorf(`c.Value() got: %v want: %v`, got, tt.wantValue) 73 } 74 }) 75 } 76 } 77 78 func TestIncomingRequestCookieNotFound(t *testing.T) { 79 r := httptest.NewRequest(http.MethodGet, "/", nil) 80 ir := safehttp.NewIncomingRequest(r) 81 if _, err := ir.Cookie("foo"); err == nil { 82 t.Error(`ir.Cookie("foo") got: nil want: error`) 83 } 84 } 85 86 func TestIncomingRequestCookies(t *testing.T) { 87 var tests = []struct { 88 name string 89 req *http.Request 90 wantNames []string 91 wantValues []string 92 }{ 93 { 94 name: "One", 95 req: func() *http.Request { 96 r := httptest.NewRequest(http.MethodGet, "/", nil) 97 r.Header.Set("Cookie", "foo=bar") 98 return r 99 }(), 100 wantNames: []string{"foo"}, 101 wantValues: []string{"bar"}, 102 }, 103 { 104 name: "Multiple", 105 req: func() *http.Request { 106 r := httptest.NewRequest(http.MethodGet, "/", nil) 107 r.Header.Add("Cookie", "foo=bar; a=b") 108 r.Header.Add("Cookie", "pizza=hawaii") 109 return r 110 }(), 111 wantNames: []string{"foo", "a", "pizza"}, 112 wantValues: []string{"bar", "b", "hawaii"}, 113 }, 114 { 115 name: "None", 116 req: httptest.NewRequest(http.MethodGet, "/", nil), 117 wantNames: []string{}, 118 wantValues: []string{}, 119 }, 120 } 121 122 for _, tt := range tests { 123 t.Run(tt.name, func(t *testing.T) { 124 ir := safehttp.NewIncomingRequest(tt.req) 125 cl := ir.Cookies() 126 127 if got, want := len(cl), len(tt.wantNames); got != want { 128 t.Errorf("len(cl) got: %v want: %v", got, want) 129 } 130 131 for i, c := range cl { 132 if got := c.Name(); got != tt.wantNames[i] { 133 t.Errorf("c.Name() got: %v want: %v", got, tt.wantNames[i]) 134 } 135 136 if got := c.Value(); got != tt.wantValues[i] { 137 t.Errorf(`c.Value() got: %v want: %v`, got, tt.wantValues[i]) 138 } 139 } 140 }) 141 142 } 143 } 144 145 type pizza struct { 146 val string 147 } 148 149 type pizzaKey string 150 151 func TestRequestWithContext(t *testing.T) { 152 tests := []struct { 153 name string 154 key pizzaKey 155 wantVal *pizza 156 wantOk bool 157 }{ 158 { 159 name: "Value set for key", 160 key: pizzaKey("1234"), 161 wantOk: true, 162 wantVal: &pizza{val: "margeritta"}, 163 }, 164 { 165 name: "Value not set for key", 166 key: pizzaKey("5678"), 167 wantOk: false, 168 wantVal: nil, 169 }, 170 } 171 for _, test := range tests { 172 req := httptest.NewRequest(safehttp.MethodGet, "/", nil) 173 ir := safehttp.NewIncomingRequest(req) 174 ir = ir.WithContext(context.WithValue(ir.Context(), pizzaKey("1234"), &pizza{val: "margeritta"})) 175 176 got, ok := ir.Context().Value(test.key).(*pizza) 177 if ok != test.wantOk { 178 t.Errorf("type match: got %v, want %v", ok, test.wantOk) 179 } 180 if diff := cmp.Diff(test.wantVal, got, cmp.AllowUnexported(pizza{})); diff != "" { 181 t.Errorf("ir.Context().Value(test.key): mismatch (-want +got): \n%s", diff) 182 } 183 } 184 } 185 186 func TestRequestSetNilContext(t *testing.T) { 187 req := httptest.NewRequest(safehttp.MethodGet, "/", nil) 188 ir := safehttp.NewIncomingRequest(req) 189 190 defer func() { 191 if r := recover(); r != nil { 192 return 193 } 194 t.Errorf(`ir.SetContext(nil): expected panic`) 195 }() 196 197 // Avoids a linter complaint about a nil context being passed as argument. 198 // In this case, we explicitly want to test that a nil context results in an error. 199 var nilContext context.Context 200 ir.WithContext(nilContext) 201 } 202 203 func TestIncomingRequestPostForm(t *testing.T) { 204 methods := []string{ 205 safehttp.MethodPost, 206 safehttp.MethodPut, 207 safehttp.MethodPatch, 208 } 209 210 for _, m := range methods { 211 t.Run(m, func(t *testing.T) { 212 r := safehttptest.NewRequest(m, "/", strings.NewReader("a=b")) 213 r.Header.Set("Content-Type", "application/x-www-form-urlencoded") 214 215 f, err := r.PostForm() 216 if err != nil { 217 t.Errorf("r.PostForm() got: %v want: nil", err) 218 } 219 220 if got, want := f.String("a", ""), "b"; got != want { 221 t.Errorf(`f.String("a", "") got: %q want: %q`, got, want) 222 } 223 224 if err := f.Err(); err != nil { 225 t.Errorf("f.Err() got: %v want: nil", err) 226 } 227 }) 228 } 229 } 230 231 func TestIncomingRequestInvalidPostForm(t *testing.T) { 232 tests := []struct { 233 name string 234 req *safehttp.IncomingRequest 235 }{ 236 { 237 name: "GET method", 238 req: safehttptest.NewRequest(safehttp.MethodGet, "/", nil), 239 }, 240 { 241 name: "Wrong content type", 242 req: func() *safehttp.IncomingRequest { 243 r := safehttptest.NewRequest(safehttp.MethodPost, "/", nil) 244 r.Header.Set("Content-Type", "blah/blah") 245 return r 246 }(), 247 }, 248 { 249 // Note that net/http.Request.ParseForm also parses url parameters and 250 // the errors that occur are returned. 251 name: "Invalid url parameter", 252 req: func() *safehttp.IncomingRequest { 253 r := safehttptest.NewRequest(safehttp.MethodPost, "http://foo.com/asdf?%xx=yy", nil) 254 r.Header.Set("Content-Type", "application/x-www-form-urlencoded") 255 return r 256 }(), 257 }, 258 } 259 260 for _, tt := range tests { 261 t.Run(tt.name, func(t *testing.T) { 262 if _, err := tt.req.PostForm(); err == nil { 263 t.Error("tt.req.PostForm() got: nil want: error") 264 } 265 }) 266 } 267 } 268 269 func TestIncomingRequestMultipartForm(t *testing.T) { 270 methods := []string{ 271 safehttp.MethodPost, 272 safehttp.MethodPut, 273 safehttp.MethodPatch, 274 } 275 276 for _, m := range methods { 277 t.Run(m, func(t *testing.T) { 278 body := "--123\r\n" + 279 "Content-Disposition: form-data; name=\"a\"\r\n" + 280 "\r\n" + 281 "b\r\n" + 282 "--123--\r\n" 283 r := safehttptest.NewRequest(m, "/", strings.NewReader(body)) 284 r.Header.Set("Content-Type", `multipart/form-data; boundary="123"`) 285 286 f, err := r.MultipartForm(1000) 287 if err != nil { 288 t.Errorf("r.MultipartForm(1000) got: %v want: nil", err) 289 } 290 291 if got, want := f.String("a", ""), "b"; got != want { 292 t.Errorf(`f.String("a", "") got: %q want: %q`, got, want) 293 } 294 295 if err := f.Err(); err != nil { 296 t.Errorf("f.Err() got: %v want: nil", err) 297 } 298 }) 299 } 300 } 301 302 func TestIncomingRequestMultipartFormNegativeMemory(t *testing.T) { 303 body := "--123\r\n" + 304 "Content-Disposition: form-data; name=\"a\"\r\n" + 305 "\r\n" + 306 "b\r\n" + 307 "--123--\r\n" 308 r := safehttptest.NewRequest(safehttp.MethodPost, "/", strings.NewReader(body)) 309 r.Header.Set("Content-Type", `multipart/form-data; boundary="123"`) 310 311 f, err := r.MultipartForm(-1) 312 if err != nil { 313 t.Errorf("r.MultipartForm(-1) got: %v want: nil", err) 314 } 315 316 if got, want := f.String("a", ""), "b"; got != want { 317 t.Errorf(`f.String("a", "") got: %q want: %q`, got, want) 318 } 319 320 if err := f.Err(); err != nil { 321 t.Errorf("f.Err() got: %v want: nil", err) 322 } 323 } 324 325 func TestIncomingRequestInvalidMultipartForm(t *testing.T) { 326 tests := []struct { 327 name string 328 req *safehttp.IncomingRequest 329 }{ 330 { 331 name: "GET method", 332 req: safehttptest.NewRequest(safehttp.MethodGet, "/", nil), 333 }, 334 { 335 name: "Wrong content type", 336 req: func() *safehttp.IncomingRequest { 337 r := safehttptest.NewRequest(safehttp.MethodPost, "/", nil) 338 r.Header.Set("Content-Type", "blah/blah") 339 return r 340 }(), 341 }, 342 { 343 // Note that net/http.Request.ParseMultipartForm also parses url parameters 344 // and the errors that occur are returned. 345 name: "Invalid url parameter", 346 req: func() *safehttp.IncomingRequest { 347 r := safehttptest.NewRequest(safehttp.MethodPost, "http://foo.com/asdf?%xx=yy", nil) 348 r.Header.Set("Content-Type", "multipart/form-data") 349 return r 350 }(), 351 }, 352 } 353 354 for _, tt := range tests { 355 t.Run(tt.name, func(t *testing.T) { 356 _, err := tt.req.MultipartForm(1000) 357 if err == nil { 358 t.Error("tt.req.ir.MultipartForm(1000) got: nil want: error") 359 } 360 }) 361 } 362 } 363 364 func TestIncomingRequestMultipartFileUpload(t *testing.T) { 365 body := "--123\r\n" + 366 "Content-Disposition: form-data; name=\"file\"; filename=\"myfile\"\r\n" + 367 "\r\n" + 368 "file content\r\n" + 369 "--123--\r\n" 370 r := safehttptest.NewRequest(safehttp.MethodPost, "/", strings.NewReader(body)) 371 r.Header.Set("Content-Type", `multipart/form-data; boundary="123"`) 372 373 f, err := r.MultipartForm(1024) 374 if err != nil { 375 t.Errorf("r.MultipartForm(1024): got err %v", err) 376 } 377 378 fhs := f.File("file") 379 if fhs == nil { 380 t.Error(`f.File("file"): got nil, want file header`) 381 } 382 defer f.RemoveFiles() 383 384 file, err := fhs[0].Open() 385 if err != nil { 386 t.Fatalf("fhs[0].Open(): got err %v, want nil", err) 387 } 388 389 content := make([]byte, 12) 390 file.Read(content) 391 if want, got := "file content", string(content); want != got { 392 t.Errorf("file.Read(content): got %s, want %s", got, want) 393 } 394 } 395 396 func TestIncomingRequestMultipartFormAndFileUpload(t *testing.T) { 397 body := "--123\r\n" + 398 "Content-Disposition: form-data; name=\"key\"\r\n" + 399 "\r\n" + 400 "12\r\n" + 401 "--123\r\n" + 402 "Content-Disposition: form-data; name=\"file\"; filename=\"myfile\"\r\n" + 403 "\r\n" + 404 "file content\r\n" + 405 "--123--\r\n" 406 r := safehttptest.NewRequest(safehttp.MethodPost, "/", strings.NewReader(body)) 407 r.Header.Set("Content-Type", `multipart/form-data; boundary="123"`) 408 409 f, err := r.MultipartForm(1024) 410 if err != nil { 411 t.Errorf("r.MultipartForm(1024): got err %v", err) 412 } 413 414 if want, got := int64(12), f.Int64("key", 0); want != got { 415 t.Errorf(`f.Int64("key", 0): got %d, want %d`, got, want) 416 } 417 if err := f.Err(); err != nil { 418 t.Errorf("f.Err(): got err %v", err) 419 } 420 421 fhs := f.File("file") 422 if fhs == nil { 423 t.Error(`f.File("file"): got nil, want file header`) 424 } 425 defer f.RemoveFiles() 426 427 file, err := fhs[0].Open() 428 if err != nil { 429 t.Fatalf("fhs[0].Open(): got err %v, want nil", err) 430 } 431 432 content := make([]byte, 12) 433 file.Read(content) 434 if want, got := "file content", string(content); want != got { 435 t.Errorf("file.Read(content): got %s, want %s", got, want) 436 } 437 } 438 439 func TestIncomingRequestFileUploadMissingContent(t *testing.T) { 440 body := "--123\r\n" + 441 "Content-Disposition: form-data; name=\"file\"; filename=\"myfile\"\r\n" + 442 "\r\n" + 443 "--123--\r\n" 444 r := safehttptest.NewRequest(safehttp.MethodPost, "/", strings.NewReader(body)) 445 r.Header.Set("Content-Type", `multipart/form-data; boundary="123"`) 446 447 f, err := r.MultipartForm(1024) 448 if err != nil { 449 t.Errorf("r.MultipartForm(1024): got err %v", err) 450 } 451 452 fhs := f.File("file") 453 if fhs == nil { 454 t.Error(`f.File("file"): got nil, want file header`) 455 } 456 defer f.RemoveFiles() 457 458 file, err := fhs[0].Open() 459 if err != nil { 460 t.Fatalf("fhs[0].Open(): got err %v, want nil", err) 461 } 462 463 content := make([]byte, 0) 464 file.Read(content) 465 if want, got := "", string(content); want != got { 466 t.Errorf("file.Read(content): got %s, want %s", got, want) 467 } 468 }