github.com/google/go-safeweb@v0.0.0-20231219055052-64d8cfc90fbb/safehttp/form_test.go (about)

     1  // Copyright 2020 Google LLC
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //	https://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package safehttp
    16  
    17  import (
    18  	"math"
    19  	"mime/multipart"
    20  	"strconv"
    21  	"testing"
    22  
    23  	"github.com/google/go-cmp/cmp"
    24  )
    25  
    26  func TestFormValidInt64(t *testing.T) {
    27  	tests := []struct {
    28  		val  string
    29  		want int64
    30  	}{
    31  		{val: "123", want: 123},
    32  		{val: "9223372036854775807", want: math.MaxInt64},
    33  		{val: "-1", want: -1},
    34  	}
    35  
    36  	for _, tt := range tests {
    37  		t.Run(tt.val, func(t *testing.T) {
    38  			values := map[string][]string{"a": {tt.val}}
    39  			f := Form{values: values}
    40  
    41  			if got := f.Int64("a", 0); got != tt.want {
    42  				t.Errorf(`f.Int64("a", 0) got: %v want: %v`, got, tt.want)
    43  			}
    44  
    45  			if err := f.Err(); err != nil {
    46  				t.Errorf("f.Err() got: %v want: nil", err)
    47  			}
    48  		})
    49  	}
    50  }
    51  
    52  func TestFormInvalidInt64(t *testing.T) {
    53  	tests := []struct {
    54  		name string
    55  		val  string
    56  	}{
    57  		{name: "Overflow", val: "9223372036854775810"},
    58  		{name: "Not a number", val: "abc"},
    59  	}
    60  
    61  	for _, tt := range tests {
    62  		t.Run(tt.name, func(t *testing.T) {
    63  			values := map[string][]string{"a": {tt.val}}
    64  			f := Form{values: values}
    65  
    66  			if got, want := f.Int64("a", 0), int64(0); got != want {
    67  				t.Errorf(`f.Int64("a", 0) got: %v want: %v`, got, want)
    68  			}
    69  
    70  			if err := f.Err(); err == nil {
    71  				t.Error("f.Err() got: nil want: error")
    72  			}
    73  		})
    74  	}
    75  }
    76  
    77  func TestFormValidUint64(t *testing.T) {
    78  	tests := []struct {
    79  		val  string
    80  		want uint64
    81  	}{
    82  		{val: "123", want: 123},
    83  		{val: "18446744073709551615", want: math.MaxUint64},
    84  	}
    85  
    86  	for _, tt := range tests {
    87  		t.Run(tt.val, func(t *testing.T) {
    88  			values := map[string][]string{"a": {tt.val}}
    89  			f := Form{values: values}
    90  
    91  			if got := f.Uint64("a", 0); got != tt.want {
    92  				t.Errorf(`f.Uint64("a", 0) got: %v want: %v`, got, tt.want)
    93  			}
    94  
    95  			if err := f.Err(); err != nil {
    96  				t.Errorf("f.Err() got: %v want: nil", err)
    97  			}
    98  		})
    99  	}
   100  }
   101  
   102  func TestFormInvalidUint(t *testing.T) {
   103  	tests := []struct {
   104  		name string
   105  		val  string
   106  	}{
   107  		{name: "Negative", val: "-1"},
   108  		{name: "Overflow", val: "18446744073709551630"},
   109  		{name: "Not a number", val: "abc"},
   110  	}
   111  
   112  	for _, tt := range tests {
   113  		t.Run(tt.name, func(t *testing.T) {
   114  			values := map[string][]string{"a": {tt.val}}
   115  			f := Form{values: values}
   116  
   117  			if got, want := f.Uint64("a", 0), uint64(0); got != want {
   118  				t.Errorf(`f.Uint64("a", 0) got: %v want: %v`, got, want)
   119  			}
   120  
   121  			if err := f.Err(); err == nil {
   122  				t.Error("f.Err() got: nil want: error")
   123  			}
   124  		})
   125  	}
   126  }
   127  
   128  func TestFormValidString(t *testing.T) {
   129  	tests := []string{
   130  		"b",
   131  		"diavola",
   132  		"ăȚâȘî",
   133  		"\x64\x69\x61\x76\x6f\x6c\x61",
   134  	}
   135  
   136  	for _, val := range tests {
   137  		t.Run(val, func(t *testing.T) {
   138  			values := map[string][]string{"a": {val}}
   139  			f := Form{values: values}
   140  
   141  			if got := f.String("a", ""); got != val {
   142  				t.Errorf(`f.String("a", 0) got: %v want: %v`, got, val)
   143  			}
   144  
   145  			if err := f.Err(); err != nil {
   146  				t.Errorf("f.Err() got: %v want: nil", err)
   147  			}
   148  		})
   149  	}
   150  }
   151  
   152  func TestFormValidFloat64(t *testing.T) {
   153  	tests := []struct {
   154  		val  string
   155  		want float64
   156  	}{
   157  		{
   158  			val:  "1.234",
   159  			want: 1.234,
   160  		},
   161  		{
   162  			val:  strconv.FormatFloat(math.MaxFloat64, 'f', 6, 64),
   163  			want: math.MaxFloat64,
   164  		},
   165  		{
   166  			val:  strconv.FormatFloat(-math.SmallestNonzeroFloat64, 'f', 324, 64),
   167  			want: -math.SmallestNonzeroFloat64,
   168  		},
   169  	}
   170  
   171  	for _, tt := range tests {
   172  		t.Run(tt.val, func(t *testing.T) {
   173  			values := map[string][]string{"a": {tt.val}}
   174  			f := Form{values: values}
   175  
   176  			if got := f.Float64("a", 0); got != tt.want {
   177  				t.Errorf(`f.Float64("a", 0) got: %v want: %v`, got, tt.want)
   178  			}
   179  
   180  			if err := f.Err(); err != nil {
   181  				t.Errorf("f.Err() got: %v want: nil", err)
   182  			}
   183  		})
   184  	}
   185  }
   186  
   187  func TestFormInvalidFloat64(t *testing.T) {
   188  	tests := []struct {
   189  		name string
   190  		val  string
   191  	}{
   192  		{name: "Not a float", val: "abc"},
   193  		{name: "Overflow", val: "1.797693134862315708145274237317043567981e309"},
   194  	}
   195  
   196  	for _, tt := range tests {
   197  		t.Run(tt.name, func(t *testing.T) {
   198  			values := map[string][]string{"a": {tt.val}}
   199  			f := Form{values: values}
   200  
   201  			if got, want := f.Float64("a", 0), float64(0); got != want {
   202  				t.Errorf(`f.Float64("a", 0) got: %v want: %v`, got, want)
   203  			}
   204  
   205  			if err := f.Err(); err == nil {
   206  				t.Error("f.Err() got: nil want: error")
   207  			}
   208  		})
   209  	}
   210  }
   211  
   212  func TestFormValidBool(t *testing.T) {
   213  	tests := []struct {
   214  		val  string
   215  		want bool
   216  	}{
   217  		{val: "true", want: true},
   218  		{val: "True", want: true},
   219  		{val: "TRUE", want: true},
   220  		{val: "t", want: true},
   221  		{val: "T", want: true},
   222  		{val: "1", want: true},
   223  		{val: "false", want: false},
   224  		{val: "False", want: false},
   225  		{val: "FALSE", want: false},
   226  		{val: "f", want: false},
   227  		{val: "F", want: false},
   228  		{val: "0", want: false},
   229  	}
   230  
   231  	for _, tt := range tests {
   232  		t.Run(tt.val, func(t *testing.T) {
   233  			values := map[string][]string{"a": {tt.val}}
   234  			f := Form{values: values}
   235  
   236  			if got := f.Bool("a", false); got != tt.want {
   237  				t.Errorf(`f.Bool("a", 0) got: %v want: %v`, got, tt.want)
   238  			}
   239  
   240  			if err := f.Err(); err != nil {
   241  				t.Errorf("f.Err() got: %v want: nil", err)
   242  			}
   243  		})
   244  	}
   245  }
   246  
   247  func TestFormInvalidBool(t *testing.T) {
   248  	tests := []struct {
   249  		name string
   250  		val  string
   251  	}{
   252  		{name: "Invalid casing", val: "TRuE"},
   253  		{name: "Not a bool", val: "potato"},
   254  	}
   255  
   256  	for _, tt := range tests {
   257  		t.Run(tt.name, func(t *testing.T) {
   258  			values := map[string][]string{"a": {tt.val}}
   259  			f := Form{values: values}
   260  
   261  			if got, want := f.Bool("a", false), false; got != want {
   262  				t.Errorf(`f.Bool("a", 0) got: %v want: %v`, got, want)
   263  			}
   264  
   265  			if err := f.Err(); err == nil {
   266  				t.Error("f.Err() got: nil want: error")
   267  			}
   268  		})
   269  	}
   270  }
   271  
   272  func TestFormValidSlice(t *testing.T) {
   273  	tests := []struct {
   274  		name        string
   275  		values      []string
   276  		placeholder interface{}
   277  		want        interface{}
   278  	}{
   279  		{
   280  			name:        "Int64",
   281  			values:      []string{"-8", "9", "-100"},
   282  			placeholder: &[]int64{},
   283  			want:        &[]int64{-8, 9, -100},
   284  		},
   285  		{
   286  			name:        "Uint64",
   287  			values:      []string{"8", "9", "10"},
   288  			placeholder: &[]uint64{},
   289  			want:        &[]uint64{8, 9, 10},
   290  		},
   291  		{
   292  			name:        "String",
   293  			values:      []string{"margeritta", "diavola", "calzone"},
   294  			placeholder: &[]string{},
   295  			want:        &[]string{"margeritta", "diavola", "calzone"},
   296  		},
   297  		{
   298  			name:        "Float64",
   299  			values:      []string{"1.3", "8.9", "-4.1"},
   300  			placeholder: &[]float64{},
   301  			want:        &[]float64{1.3, 8.9, -4.1},
   302  		},
   303  		{
   304  			name:        "Bool",
   305  			values:      []string{"t", "0", "TRUE"},
   306  			placeholder: &[]bool{},
   307  			want:        &[]bool{true, false, true},
   308  		},
   309  	}
   310  
   311  	for _, tt := range tests {
   312  		t.Run(tt.name, func(t *testing.T) {
   313  			f := Form{values: map[string][]string{"x": tt.values}}
   314  
   315  			f.Slice("x", tt.placeholder)
   316  			if diff := cmp.Diff(tt.want, tt.placeholder); diff != "" {
   317  				t.Errorf("f.Slice: diff (-want +got): \n%v", diff)
   318  			}
   319  
   320  			if err := f.Err(); err != nil {
   321  				t.Errorf("f.Err() got: %v want: nil", err)
   322  			}
   323  		})
   324  	}
   325  }
   326  
   327  func TestFormInvalidSlice(t *testing.T) {
   328  	tests := []struct {
   329  		name        string
   330  		values      []string
   331  		placeholder interface{}
   332  	}{
   333  		{
   334  			name:        "Int64",
   335  			values:      []string{"1", "abc", "1"},
   336  			placeholder: &[]int64{},
   337  		},
   338  		{
   339  			name:        "Uint64",
   340  			values:      []string{"1", "abc", "-1"},
   341  			placeholder: &[]uint64{},
   342  		},
   343  		{
   344  			name:        "Float64",
   345  			values:      []string{"1.3", "abc", "-4.1"},
   346  			placeholder: &[]float64{},
   347  		},
   348  		{
   349  			name:        "Bool",
   350  			values:      []string{"t", "abc", "TRUE"},
   351  			placeholder: &[]bool{},
   352  		},
   353  	}
   354  
   355  	for _, tt := range tests {
   356  		t.Run(tt.name, func(t *testing.T) {
   357  			f := Form{values: map[string][]string{"x": tt.values}}
   358  
   359  			f.Slice("x", tt.placeholder)
   360  
   361  			// TODO: add a check here that tt.placeholder is nil. I (grenfeldt@)
   362  			// can't come up with a way of testing this in a table test.
   363  
   364  			if err := f.Err(); err == nil {
   365  				t.Error("f.Err() got: nil want: error")
   366  			}
   367  		})
   368  	}
   369  }
   370  
   371  func TestFormSliceInvalidPointerType(t *testing.T) {
   372  	f := Form{values: map[string][]string{"x": {"x"}}}
   373  	x := []int16{1}
   374  	f.Slice("x", &x)
   375  	if diff := cmp.Diff([]int16{1}, x); diff != "" {
   376  		t.Errorf("f.Slice: diff (-want +got): \n%v", diff)
   377  	}
   378  	if err := f.Err(); err == nil {
   379  		t.Error("f.Err() got: nil want: error")
   380  	}
   381  }
   382  
   383  func TestFormUnknownParam(t *testing.T) {
   384  	tests := []struct {
   385  		name         string
   386  		getFormValue func(f Form) interface{}
   387  		want         interface{}
   388  	}{
   389  		{
   390  			name: "Int64",
   391  			getFormValue: func(f Form) interface{} {
   392  				return f.Int64("x", -15)
   393  			},
   394  			want: int64(-15),
   395  		},
   396  		{
   397  			name: "Uint64",
   398  			getFormValue: func(f Form) interface{} {
   399  				return f.Uint64("x", 15)
   400  			},
   401  			want: uint64(15),
   402  		},
   403  		{
   404  			name: "String",
   405  			getFormValue: func(f Form) interface{} {
   406  				return f.String("x", "missing")
   407  			},
   408  			want: "missing",
   409  		},
   410  		{
   411  			name: "Float64",
   412  			getFormValue: func(f Form) interface{} {
   413  				return f.Float64("x", 3.14)
   414  			},
   415  			want: 3.14,
   416  		},
   417  		{
   418  			name: "Bool",
   419  			getFormValue: func(f Form) interface{} {
   420  				return f.Bool("x", true)
   421  			},
   422  			want: true,
   423  		},
   424  	}
   425  
   426  	for _, tt := range tests {
   427  		t.Run(tt.name, func(t *testing.T) {
   428  			f := Form{}
   429  			got := tt.getFormValue(f)
   430  			if diff := cmp.Diff(tt.want, got); diff != "" {
   431  				t.Errorf("tt.getFormValue(f) diff (-want +got): \n%v", diff)
   432  			}
   433  		})
   434  	}
   435  }
   436  
   437  func TestFormSliceUnknownParam(t *testing.T) {
   438  	f := Form{}
   439  	p := []int64{1, 2}
   440  	f.Slice("x", &p)
   441  	if p != nil {
   442  		t.Errorf(`f.Slice("x", &p) got p: %v want: nil`, p)
   443  	}
   444  	if err := f.Err(); err != nil {
   445  		t.Errorf("f.Err() got: %v want: nil", err)
   446  	}
   447  }
   448  
   449  func TestFormClearSliceString(t *testing.T) {
   450  	x := []string{"xyz"}
   451  	if err := clearSlice(&x); err != nil {
   452  		t.Errorf("clearSlice(&x) got err: %v want: nil", err)
   453  	}
   454  	if x != nil {
   455  		t.Errorf("clearSlice(&x) got x: %v want: nil", x)
   456  	}
   457  }
   458  
   459  func TestFormClearSliceInt64(t *testing.T) {
   460  	x := []int64{-1}
   461  	if err := clearSlice(&x); err != nil {
   462  		t.Errorf("clearSlice(&x) got err: %v want: nil", err)
   463  	}
   464  	if x != nil {
   465  		t.Errorf("clearSlice(&x) got x: %v want: nil", x)
   466  	}
   467  }
   468  
   469  func TestFormClearSliceUint64(t *testing.T) {
   470  	x := []uint64{1}
   471  	if err := clearSlice(&x); err != nil {
   472  		t.Errorf("clearSlice(&x) got err: %v want: nil", err)
   473  	}
   474  	if x != nil {
   475  		t.Errorf("clearSlice(&x) got x: %v want: nil", x)
   476  	}
   477  }
   478  
   479  func TestFormClearSliceFloat64(t *testing.T) {
   480  	x := []float64{3.14}
   481  	if err := clearSlice(&x); err != nil {
   482  		t.Errorf("clearSlice(&x) got err: %v want: nil", err)
   483  	}
   484  	if x != nil {
   485  		t.Errorf("clearSlice(&x) got x: %v want: nil", x)
   486  	}
   487  }
   488  
   489  func TestFormClearSliceBool(t *testing.T) {
   490  	x := []bool{true}
   491  	if err := clearSlice(&x); err != nil {
   492  		t.Errorf("clearSlice(&x) got err: %v want: nil", err)
   493  	}
   494  	if x != nil {
   495  		t.Errorf("clearSlice(&x) got x: %v want: nil", x)
   496  	}
   497  }
   498  
   499  func TestFormClearSliceUnknownType(t *testing.T) {
   500  	x := []int16{-1}
   501  	if err := clearSlice(&x); err == nil {
   502  		t.Error("clearSlice(&x) got: nil want: error")
   503  	}
   504  }
   505  
   506  func TestMultipartFormValidFile(t *testing.T) {
   507  	fh := &multipart.FileHeader{Filename: "bar"}
   508  	mf := &multipart.Form{
   509  		File: map[string][]*multipart.FileHeader{
   510  			"foo": {fh},
   511  		}}
   512  	f := newMulipartForm(mf)
   513  	fhs := f.File("foo")
   514  	if fhs == nil {
   515  		t.Errorf(`m.File("foo"): got nil, want file headers`)
   516  	}
   517  	if diff := cmp.Diff(fh, fhs[0], cmp.AllowUnexported(multipart.FileHeader{})); diff != "" {
   518  		t.Errorf("file headers mismatch (-want +got):\n%s", diff)
   519  	}
   520  }
   521  
   522  func TestMultipartFormValidFileAndVals(t *testing.T) {
   523  	fh := &multipart.FileHeader{Filename: "bar"}
   524  	mf := &multipart.Form{
   525  		Value: map[string][]string{"number": {"1"}},
   526  		File: map[string][]*multipart.FileHeader{
   527  			"foo": {fh},
   528  		}}
   529  	f := newMulipartForm(mf)
   530  
   531  	if want, got := int64(1), f.Int64("number", 0); want != got {
   532  		t.Errorf(`f.Int64("number"): got %d, want %d`, got, want)
   533  	}
   534  	if err := f.Err(); err != nil {
   535  		t.Errorf(`f.Err(): got err %v`, err)
   536  	}
   537  
   538  	fhs := f.File("foo")
   539  	if fhs == nil {
   540  		t.Errorf(`m.File("foo"): got nil, want file headers`)
   541  	}
   542  	if diff := cmp.Diff(fh, fhs[0], cmp.AllowUnexported(multipart.FileHeader{})); diff != "" {
   543  		t.Errorf("file headers mismatch (-want +got):\n%s", diff)
   544  	}
   545  }
   546  
   547  func TestMultipartFormFileWithPathInName(t *testing.T) {
   548  	fh := &multipart.FileHeader{Filename: "../tmp/myfile.txt"}
   549  	mf := &multipart.Form{
   550  		File: map[string][]*multipart.FileHeader{
   551  			"foo": {fh},
   552  		}}
   553  	f := newMulipartForm(mf)
   554  
   555  	filename := f.File("foo")[0].Filename
   556  	if want, got := "myfile.txt", filename; want != got {
   557  		t.Errorf(`f.File("foo").Filename: got %q, want %q`, got, want)
   558  	}
   559  	if err := f.Err(); err != nil {
   560  		t.Errorf(`f.Err(): got err %v, want nil`, err)
   561  	}
   562  }
   563  
   564  func TestMultipartFormMissingFile(t *testing.T) {
   565  	f := newMulipartForm(&multipart.Form{})
   566  	fhs := f.File("x")
   567  	if fhs != nil {
   568  		t.Errorf(`m.File("x"): got file headers, want nil`)
   569  	}
   570  }