github.com/google/go-safeweb@v0.0.0-20231219055052-64d8cfc90fbb/safehttp/incoming_request_test.go (about)

     1  // Copyright 2020 Google LLC
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //	https://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package safehttp_test
    16  
import (
	"context"
	"io"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"

	"github.com/google/go-cmp/cmp"
	"github.com/google/go-safeweb/safehttp"
	"github.com/google/go-safeweb/safehttp/safehttptest"
)
    28  
    29  func TestIncomingRequestCookie(t *testing.T) {
    30  	var tests = []struct {
    31  		name      string
    32  		req       *http.Request
    33  		wantName  string
    34  		wantValue string
    35  	}{
    36  		{
    37  			name: "Basic",
    38  			req: func() *http.Request {
    39  				r := httptest.NewRequest(http.MethodGet, "/", nil)
    40  				r.Header.Set("Cookie", "foo=bar")
    41  				return r
    42  			}(),
    43  			wantName:  "foo",
    44  			wantValue: "bar",
    45  		},
    46  		{
    47  			name: "Multiple cookies with the same name",
    48  			req: func() *http.Request {
    49  				r := httptest.NewRequest(http.MethodGet, "/", nil)
    50  				r.Header.Add("Cookie", "foo=bar; foo=xyz")
    51  				r.Header.Add("Cookie", "foo=pizza")
    52  				return r
    53  			}(),
    54  			wantName:  "foo",
    55  			wantValue: "bar",
    56  		},
    57  	}
    58  
    59  	for _, tt := range tests {
    60  		t.Run(tt.name, func(t *testing.T) {
    61  			ir := safehttp.NewIncomingRequest(tt.req)
    62  			c, err := ir.Cookie(tt.wantName)
    63  			if err != nil {
    64  				t.Errorf(`ir.Cookie(tt.wantName) got: %v want: nil`, err)
    65  			}
    66  
    67  			if got := c.Name(); got != tt.wantName {
    68  				t.Errorf("c.Name() got: %v want: %v", got, tt.wantName)
    69  			}
    70  
    71  			if got := c.Value(); got != tt.wantValue {
    72  				t.Errorf(`c.Value() got: %v want: %v`, got, tt.wantValue)
    73  			}
    74  		})
    75  	}
    76  }
    77  
    78  func TestIncomingRequestCookieNotFound(t *testing.T) {
    79  	r := httptest.NewRequest(http.MethodGet, "/", nil)
    80  	ir := safehttp.NewIncomingRequest(r)
    81  	if _, err := ir.Cookie("foo"); err == nil {
    82  		t.Error(`ir.Cookie("foo") got: nil want: error`)
    83  	}
    84  }
    85  
    86  func TestIncomingRequestCookies(t *testing.T) {
    87  	var tests = []struct {
    88  		name       string
    89  		req        *http.Request
    90  		wantNames  []string
    91  		wantValues []string
    92  	}{
    93  		{
    94  			name: "One",
    95  			req: func() *http.Request {
    96  				r := httptest.NewRequest(http.MethodGet, "/", nil)
    97  				r.Header.Set("Cookie", "foo=bar")
    98  				return r
    99  			}(),
   100  			wantNames:  []string{"foo"},
   101  			wantValues: []string{"bar"},
   102  		},
   103  		{
   104  			name: "Multiple",
   105  			req: func() *http.Request {
   106  				r := httptest.NewRequest(http.MethodGet, "/", nil)
   107  				r.Header.Add("Cookie", "foo=bar; a=b")
   108  				r.Header.Add("Cookie", "pizza=hawaii")
   109  				return r
   110  			}(),
   111  			wantNames:  []string{"foo", "a", "pizza"},
   112  			wantValues: []string{"bar", "b", "hawaii"},
   113  		},
   114  		{
   115  			name:       "None",
   116  			req:        httptest.NewRequest(http.MethodGet, "/", nil),
   117  			wantNames:  []string{},
   118  			wantValues: []string{},
   119  		},
   120  	}
   121  
   122  	for _, tt := range tests {
   123  		t.Run(tt.name, func(t *testing.T) {
   124  			ir := safehttp.NewIncomingRequest(tt.req)
   125  			cl := ir.Cookies()
   126  
   127  			if got, want := len(cl), len(tt.wantNames); got != want {
   128  				t.Errorf("len(cl) got: %v want: %v", got, want)
   129  			}
   130  
   131  			for i, c := range cl {
   132  				if got := c.Name(); got != tt.wantNames[i] {
   133  					t.Errorf("c.Name() got: %v want: %v", got, tt.wantNames[i])
   134  				}
   135  
   136  				if got := c.Value(); got != tt.wantValues[i] {
   137  					t.Errorf(`c.Value() got: %v want: %v`, got, tt.wantValues[i])
   138  				}
   139  			}
   140  		})
   141  
   142  	}
   143  }
   144  
   145  type pizza struct {
   146  	val string
   147  }
   148  
   149  type pizzaKey string
   150  
   151  func TestRequestWithContext(t *testing.T) {
   152  	tests := []struct {
   153  		name    string
   154  		key     pizzaKey
   155  		wantVal *pizza
   156  		wantOk  bool
   157  	}{
   158  		{
   159  			name:    "Value set for key",
   160  			key:     pizzaKey("1234"),
   161  			wantOk:  true,
   162  			wantVal: &pizza{val: "margeritta"},
   163  		},
   164  		{
   165  			name:    "Value not set for key",
   166  			key:     pizzaKey("5678"),
   167  			wantOk:  false,
   168  			wantVal: nil,
   169  		},
   170  	}
   171  	for _, test := range tests {
   172  		req := httptest.NewRequest(safehttp.MethodGet, "/", nil)
   173  		ir := safehttp.NewIncomingRequest(req)
   174  		ir = ir.WithContext(context.WithValue(ir.Context(), pizzaKey("1234"), &pizza{val: "margeritta"}))
   175  
   176  		got, ok := ir.Context().Value(test.key).(*pizza)
   177  		if ok != test.wantOk {
   178  			t.Errorf("type match: got %v, want %v", ok, test.wantOk)
   179  		}
   180  		if diff := cmp.Diff(test.wantVal, got, cmp.AllowUnexported(pizza{})); diff != "" {
   181  			t.Errorf("ir.Context().Value(test.key): mismatch (-want +got): \n%s", diff)
   182  		}
   183  	}
   184  }
   185  
   186  func TestRequestSetNilContext(t *testing.T) {
   187  	req := httptest.NewRequest(safehttp.MethodGet, "/", nil)
   188  	ir := safehttp.NewIncomingRequest(req)
   189  
   190  	defer func() {
   191  		if r := recover(); r != nil {
   192  			return
   193  		}
   194  		t.Errorf(`ir.SetContext(nil): expected panic`)
   195  	}()
   196  
   197  	// Avoids a linter complaint about a nil context being passed as argument.
   198  	// In this case, we explicitly want to test that a nil context results in an error.
   199  	var nilContext context.Context
   200  	ir.WithContext(nilContext)
   201  }
   202  
   203  func TestIncomingRequestPostForm(t *testing.T) {
   204  	methods := []string{
   205  		safehttp.MethodPost,
   206  		safehttp.MethodPut,
   207  		safehttp.MethodPatch,
   208  	}
   209  
   210  	for _, m := range methods {
   211  		t.Run(m, func(t *testing.T) {
   212  			r := safehttptest.NewRequest(m, "/", strings.NewReader("a=b"))
   213  			r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
   214  
   215  			f, err := r.PostForm()
   216  			if err != nil {
   217  				t.Errorf("r.PostForm() got: %v want: nil", err)
   218  			}
   219  
   220  			if got, want := f.String("a", ""), "b"; got != want {
   221  				t.Errorf(`f.String("a", "") got: %q want: %q`, got, want)
   222  			}
   223  
   224  			if err := f.Err(); err != nil {
   225  				t.Errorf("f.Err() got: %v want: nil", err)
   226  			}
   227  		})
   228  	}
   229  }
   230  
   231  func TestIncomingRequestInvalidPostForm(t *testing.T) {
   232  	tests := []struct {
   233  		name string
   234  		req  *safehttp.IncomingRequest
   235  	}{
   236  		{
   237  			name: "GET method",
   238  			req:  safehttptest.NewRequest(safehttp.MethodGet, "/", nil),
   239  		},
   240  		{
   241  			name: "Wrong content type",
   242  			req: func() *safehttp.IncomingRequest {
   243  				r := safehttptest.NewRequest(safehttp.MethodPost, "/", nil)
   244  				r.Header.Set("Content-Type", "blah/blah")
   245  				return r
   246  			}(),
   247  		},
   248  		{
   249  			// Note that net/http.Request.ParseForm also parses url parameters and
   250  			// the errors that occur are returned.
   251  			name: "Invalid url parameter",
   252  			req: func() *safehttp.IncomingRequest {
   253  				r := safehttptest.NewRequest(safehttp.MethodPost, "http://foo.com/asdf?%xx=yy", nil)
   254  				r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
   255  				return r
   256  			}(),
   257  		},
   258  	}
   259  
   260  	for _, tt := range tests {
   261  		t.Run(tt.name, func(t *testing.T) {
   262  			if _, err := tt.req.PostForm(); err == nil {
   263  				t.Error("tt.req.PostForm() got: nil want: error")
   264  			}
   265  		})
   266  	}
   267  }
   268  
   269  func TestIncomingRequestMultipartForm(t *testing.T) {
   270  	methods := []string{
   271  		safehttp.MethodPost,
   272  		safehttp.MethodPut,
   273  		safehttp.MethodPatch,
   274  	}
   275  
   276  	for _, m := range methods {
   277  		t.Run(m, func(t *testing.T) {
   278  			body := "--123\r\n" +
   279  				"Content-Disposition: form-data; name=\"a\"\r\n" +
   280  				"\r\n" +
   281  				"b\r\n" +
   282  				"--123--\r\n"
   283  			r := safehttptest.NewRequest(m, "/", strings.NewReader(body))
   284  			r.Header.Set("Content-Type", `multipart/form-data; boundary="123"`)
   285  
   286  			f, err := r.MultipartForm(1000)
   287  			if err != nil {
   288  				t.Errorf("r.MultipartForm(1000) got: %v want: nil", err)
   289  			}
   290  
   291  			if got, want := f.String("a", ""), "b"; got != want {
   292  				t.Errorf(`f.String("a", "") got: %q want: %q`, got, want)
   293  			}
   294  
   295  			if err := f.Err(); err != nil {
   296  				t.Errorf("f.Err() got: %v want: nil", err)
   297  			}
   298  		})
   299  	}
   300  }
   301  
   302  func TestIncomingRequestMultipartFormNegativeMemory(t *testing.T) {
   303  	body := "--123\r\n" +
   304  		"Content-Disposition: form-data; name=\"a\"\r\n" +
   305  		"\r\n" +
   306  		"b\r\n" +
   307  		"--123--\r\n"
   308  	r := safehttptest.NewRequest(safehttp.MethodPost, "/", strings.NewReader(body))
   309  	r.Header.Set("Content-Type", `multipart/form-data; boundary="123"`)
   310  
   311  	f, err := r.MultipartForm(-1)
   312  	if err != nil {
   313  		t.Errorf("r.MultipartForm(-1) got: %v want: nil", err)
   314  	}
   315  
   316  	if got, want := f.String("a", ""), "b"; got != want {
   317  		t.Errorf(`f.String("a", "") got: %q want: %q`, got, want)
   318  	}
   319  
   320  	if err := f.Err(); err != nil {
   321  		t.Errorf("f.Err() got: %v want: nil", err)
   322  	}
   323  }
   324  
   325  func TestIncomingRequestInvalidMultipartForm(t *testing.T) {
   326  	tests := []struct {
   327  		name string
   328  		req  *safehttp.IncomingRequest
   329  	}{
   330  		{
   331  			name: "GET method",
   332  			req:  safehttptest.NewRequest(safehttp.MethodGet, "/", nil),
   333  		},
   334  		{
   335  			name: "Wrong content type",
   336  			req: func() *safehttp.IncomingRequest {
   337  				r := safehttptest.NewRequest(safehttp.MethodPost, "/", nil)
   338  				r.Header.Set("Content-Type", "blah/blah")
   339  				return r
   340  			}(),
   341  		},
   342  		{
   343  			// Note that net/http.Request.ParseMultipartForm also parses url parameters
   344  			// and the errors that occur are returned.
   345  			name: "Invalid url parameter",
   346  			req: func() *safehttp.IncomingRequest {
   347  				r := safehttptest.NewRequest(safehttp.MethodPost, "http://foo.com/asdf?%xx=yy", nil)
   348  				r.Header.Set("Content-Type", "multipart/form-data")
   349  				return r
   350  			}(),
   351  		},
   352  	}
   353  
   354  	for _, tt := range tests {
   355  		t.Run(tt.name, func(t *testing.T) {
   356  			_, err := tt.req.MultipartForm(1000)
   357  			if err == nil {
   358  				t.Error("tt.req.ir.MultipartForm(1000) got: nil want: error")
   359  			}
   360  		})
   361  	}
   362  }
   363  
   364  func TestIncomingRequestMultipartFileUpload(t *testing.T) {
   365  	body := "--123\r\n" +
   366  		"Content-Disposition: form-data; name=\"file\"; filename=\"myfile\"\r\n" +
   367  		"\r\n" +
   368  		"file content\r\n" +
   369  		"--123--\r\n"
   370  	r := safehttptest.NewRequest(safehttp.MethodPost, "/", strings.NewReader(body))
   371  	r.Header.Set("Content-Type", `multipart/form-data; boundary="123"`)
   372  
   373  	f, err := r.MultipartForm(1024)
   374  	if err != nil {
   375  		t.Errorf("r.MultipartForm(1024): got err %v", err)
   376  	}
   377  
   378  	fhs := f.File("file")
   379  	if fhs == nil {
   380  		t.Error(`f.File("file"): got nil, want file header`)
   381  	}
   382  	defer f.RemoveFiles()
   383  
   384  	file, err := fhs[0].Open()
   385  	if err != nil {
   386  		t.Fatalf("fhs[0].Open(): got err %v, want nil", err)
   387  	}
   388  
   389  	content := make([]byte, 12)
   390  	file.Read(content)
   391  	if want, got := "file content", string(content); want != got {
   392  		t.Errorf("file.Read(content): got %s, want %s", got, want)
   393  	}
   394  }
   395  
   396  func TestIncomingRequestMultipartFormAndFileUpload(t *testing.T) {
   397  	body := "--123\r\n" +
   398  		"Content-Disposition: form-data; name=\"key\"\r\n" +
   399  		"\r\n" +
   400  		"12\r\n" +
   401  		"--123\r\n" +
   402  		"Content-Disposition: form-data; name=\"file\"; filename=\"myfile\"\r\n" +
   403  		"\r\n" +
   404  		"file content\r\n" +
   405  		"--123--\r\n"
   406  	r := safehttptest.NewRequest(safehttp.MethodPost, "/", strings.NewReader(body))
   407  	r.Header.Set("Content-Type", `multipart/form-data; boundary="123"`)
   408  
   409  	f, err := r.MultipartForm(1024)
   410  	if err != nil {
   411  		t.Errorf("r.MultipartForm(1024): got err %v", err)
   412  	}
   413  
   414  	if want, got := int64(12), f.Int64("key", 0); want != got {
   415  		t.Errorf(`f.Int64("key", 0): got %d, want %d`, got, want)
   416  	}
   417  	if err := f.Err(); err != nil {
   418  		t.Errorf("f.Err(): got err %v", err)
   419  	}
   420  
   421  	fhs := f.File("file")
   422  	if fhs == nil {
   423  		t.Error(`f.File("file"): got nil, want file header`)
   424  	}
   425  	defer f.RemoveFiles()
   426  
   427  	file, err := fhs[0].Open()
   428  	if err != nil {
   429  		t.Fatalf("fhs[0].Open(): got err %v, want nil", err)
   430  	}
   431  
   432  	content := make([]byte, 12)
   433  	file.Read(content)
   434  	if want, got := "file content", string(content); want != got {
   435  		t.Errorf("file.Read(content): got %s, want %s", got, want)
   436  	}
   437  }
   438  
   439  func TestIncomingRequestFileUploadMissingContent(t *testing.T) {
   440  	body := "--123\r\n" +
   441  		"Content-Disposition: form-data; name=\"file\"; filename=\"myfile\"\r\n" +
   442  		"\r\n" +
   443  		"--123--\r\n"
   444  	r := safehttptest.NewRequest(safehttp.MethodPost, "/", strings.NewReader(body))
   445  	r.Header.Set("Content-Type", `multipart/form-data; boundary="123"`)
   446  
   447  	f, err := r.MultipartForm(1024)
   448  	if err != nil {
   449  		t.Errorf("r.MultipartForm(1024): got err %v", err)
   450  	}
   451  
   452  	fhs := f.File("file")
   453  	if fhs == nil {
   454  		t.Error(`f.File("file"): got nil, want file header`)
   455  	}
   456  	defer f.RemoveFiles()
   457  
   458  	file, err := fhs[0].Open()
   459  	if err != nil {
   460  		t.Fatalf("fhs[0].Open(): got err %v, want nil", err)
   461  	}
   462  
   463  	content := make([]byte, 0)
   464  	file.Read(content)
   465  	if want, got := "", string(content); want != got {
   466  		t.Errorf("file.Read(content): got %s, want %s", got, want)
   467  	}
   468  }