github.com/google/go-safeweb@v0.0.0-20231219055052-64d8cfc90fbb/safehttp/plugins/xsrf/xsrfhtml/xsrf_test.go (about)

     1  // Copyright 2020 Google LLC
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //	https://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package xsrfhtml
    16  
    17  import (
    18  	"strings"
    19  	"testing"
    20  
    21  	"github.com/google/go-safeweb/safehttp"
    22  	"github.com/google/go-safeweb/safehttp/safehttptest"
    23  	"golang.org/x/net/xsrftoken"
    24  )
    25  
    26  var (
    27  	formTokenTests = []struct {
    28  		name, cookieVal, host string
    29  		wantStatus            safehttp.StatusCode
    30  	}{
    31  		{
    32  			name:       "Valid token",
    33  			cookieVal:  "abcdef",
    34  			host:       "go.dev",
    35  			wantStatus: safehttp.StatusOK,
    36  		},
    37  		{
    38  			name:       "Invalid host in token generation",
    39  			cookieVal:  "abcdef",
    40  			host:       "google.com",
    41  			wantStatus: safehttp.StatusForbidden,
    42  		},
    43  		{
    44  			name:       "Invalid cookie value in token generation",
    45  			cookieVal:  "evilvalue",
    46  			host:       "go.dev",
    47  			wantStatus: safehttp.StatusForbidden,
    48  		},
    49  	}
    50  )
    51  
    52  func TestTokenPost(t *testing.T) {
    53  	for _, test := range formTokenTests {
    54  		t.Run(test.name, func(t *testing.T) {
    55  			fakeRW, rr := safehttptest.NewFakeResponseWriter()
    56  			tok := xsrftoken.Generate("testSecretAppKey", test.cookieVal, test.host)
    57  			req := safehttptest.NewRequest(safehttp.MethodPost, "https://go.dev/", strings.NewReader(TokenKey+"="+tok))
    58  			req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
    59  			req.Header.Set("Cookie", cookieIDKey+"=abcdef")
    60  
    61  			i := Interceptor{SecretAppKey: "testSecretAppKey"}
    62  			i.Before(fakeRW, req, nil)
    63  
    64  			if got := rr.Code; got != int(test.wantStatus) {
    65  				t.Errorf("rr.Code: got %v, want %v", got, test.wantStatus)
    66  			}
    67  		})
    68  	}
    69  }
    70  
    71  func TestTokenMultipart(t *testing.T) {
    72  	for _, test := range formTokenTests {
    73  		t.Run(test.name, func(t *testing.T) {
    74  			fakeRW, rr := safehttptest.NewFakeResponseWriter()
    75  			tok := xsrftoken.Generate("testSecretAppKey", test.cookieVal, test.host)
    76  			b := "--123\r\n" +
    77  				"Content-Disposition: form-data; name=\"xsrf-token\"\r\n" +
    78  				"\r\n" +
    79  				tok + "\r\n" +
    80  				"--123--\r\n"
    81  			req := safehttptest.NewRequest(safehttp.MethodPost, "https://go.dev/", strings.NewReader(b))
    82  			req.Header.Set("Content-Type", `multipart/form-data; boundary="123"`)
    83  			req.Header.Set("Cookie", cookieIDKey+"=abcdef")
    84  
    85  			i := Interceptor{SecretAppKey: "testSecretAppKey"}
    86  			i.Before(fakeRW, req, nil)
    87  
    88  			if got := rr.Code; got != int(test.wantStatus) {
    89  				t.Errorf("rr.Code: got %v, want %v", got, test.wantStatus)
    90  			}
    91  		})
    92  	}
    93  }
    94  
    95  func TestMalformedForm(t *testing.T) {
    96  	fakeRW, rr := safehttptest.NewFakeResponseWriter()
    97  	req := safehttptest.NewRequest(safehttp.MethodPost, "https://foo.com/pizza", nil)
    98  	req.Header.Set("Content-Type", "wrong")
    99  	req.Header.Set("Cookie", cookieIDKey+"=abcdef")
   100  
   101  	i := Interceptor{SecretAppKey: "testSecretAppKey"}
   102  	i.Before(fakeRW, req, nil)
   103  
   104  	if want, got := int(safehttp.StatusBadRequest), rr.Code; got != want {
   105  		t.Errorf("rr.Code: got %v, want %v", got, want)
   106  	}
   107  }
   108  
   109  func TestMissingTokenInBody(t *testing.T) {
   110  	tests := []struct {
   111  		name string
   112  		req  *safehttp.IncomingRequest
   113  	}{
   114  		{
   115  			name: "Missing token in POST request with form",
   116  			req: func() *safehttp.IncomingRequest {
   117  				req := safehttptest.NewRequest(safehttp.MethodPost, "/", strings.NewReader("foo=bar"))
   118  				req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
   119  				req.Header.Set("Cookie", cookieIDKey+"=abcdef")
   120  				return req
   121  			}(),
   122  		},
   123  		{
   124  			name: "Missing token in PATCH request with form",
   125  			req: func() *safehttp.IncomingRequest {
   126  				req := safehttptest.NewRequest(safehttp.MethodPatch, "/", strings.NewReader("foo=bar"))
   127  				req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
   128  				req.Header.Set("Cookie", cookieIDKey+"=abcdef")
   129  				return req
   130  			}(),
   131  		},
   132  		{
   133  			name: "Missing token in POST request with multipart form",
   134  			req: func() *safehttp.IncomingRequest {
   135  				b := "--123\r\n" +
   136  					"Content-Disposition: form-data; name=\"foo\"\r\n" +
   137  					"\r\n" +
   138  					"bar\r\n" +
   139  					"--123--\r\n"
   140  				req := safehttptest.NewRequest(safehttp.MethodPost, "/", strings.NewReader(b))
   141  				req.Header.Set("Content-Type", `multipart/form-data; boundary="123"`)
   142  				req.Header.Set("Cookie", cookieIDKey+"=abcdef")
   143  				return req
   144  			}(),
   145  		},
   146  		{
   147  			name: "Missing token in PATCH request with multipart form",
   148  			req: func() *safehttp.IncomingRequest {
   149  				b := "--123\r\n" +
   150  					"Content-Disposition: form-data; name=\"foo\"\r\n" +
   151  					"\r\n" +
   152  					"bar\r\n" +
   153  					"--123--\r\n"
   154  				req := safehttptest.NewRequest(safehttp.MethodPatch, "/", strings.NewReader(b))
   155  				req.Header.Set("Content-Type", `multipart/form-data; boundary="123"`)
   156  				req.Header.Set("Cookie", cookieIDKey+"=abcdef")
   157  				return req
   158  			}(),
   159  		},
   160  	}
   161  	for _, test := range tests {
   162  		t.Run(test.name, func(t *testing.T) {
   163  			fakeRW, rr := safehttptest.NewFakeResponseWriter()
   164  
   165  			i := Interceptor{SecretAppKey: "testSecretAppKey"}
   166  			i.Before(fakeRW, test.req, nil)
   167  
   168  			if want, got := safehttp.StatusUnauthorized, rr.Code; got != int(want) {
   169  				t.Errorf("rr.Code: got %v, want %v", got, want)
   170  			}
   171  		})
   172  	}
   173  }
   174  
   175  func TestMissingCookieInGetRequest(t *testing.T) {
   176  	fakeRW, rr := safehttptest.NewFakeResponseWriter()
   177  	req := safehttptest.NewRequest(safehttp.MethodGet, "https://foo.com/pizza", nil)
   178  
   179  	i := Interceptor{SecretAppKey: "testSecretAppKey"}
   180  	i.Commit(fakeRW, req, nil, nil)
   181  
   182  	if want, got := safehttp.StatusOK, rr.Code; got != int(want) {
   183  		t.Errorf("rr.Code: got %v, want %v", got, want)
   184  	}
   185  
   186  	if len(fakeRW.Cookies) != 1 {
   187  		t.Errorf("len(Cookies) = %v, want 1", len(fakeRW.Cookies))
   188  	}
   189  	if got, want := fakeRW.Cookies[0].String(), "HttpOnly; Secure; SameSite=Strict"; !strings.Contains(got, want) {
   190  		t.Errorf("XSRF cookie got %q, want to contain %q", got, want)
   191  	}
   192  }
   193  
   194  func TestMissingCookieInPostRequest(t *testing.T) {
   195  	tests := []struct {
   196  		name       string
   197  		stage      func(it *Interceptor, rw safehttp.ResponseWriter, req *safehttp.IncomingRequest)
   198  		wantStatus safehttp.StatusCode
   199  	}{
   200  		{
   201  			name: "In Before stage",
   202  			stage: func(it *Interceptor, rw safehttp.ResponseWriter, req *safehttp.IncomingRequest) {
   203  
   204  				it.Before(rw, req, nil)
   205  			},
   206  			wantStatus: safehttp.StatusForbidden,
   207  		},
   208  		{
   209  			name: "In Commit stage",
   210  			stage: func(it *Interceptor, rw safehttp.ResponseWriter, req *safehttp.IncomingRequest) {
   211  				it.Commit(rw, req, nil, nil)
   212  			},
   213  			wantStatus: safehttp.StatusOK,
   214  		},
   215  	}
   216  
   217  	for _, test := range tests {
   218  		t.Run(test.name, func(t *testing.T) {
   219  			fakeRW, rr := safehttptest.NewFakeResponseWriter()
   220  			req := safehttptest.NewRequest(safehttp.MethodPost, "/", strings.NewReader("foo=bar"))
   221  			req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
   222  
   223  			test.stage(&Interceptor{SecretAppKey: "testSecretAppKey"}, fakeRW, req)
   224  
   225  			if gotStatus := rr.Code; gotStatus != int(test.wantStatus) {
   226  				t.Errorf("rr.Code: got %v, want %v", gotStatus, test.wantStatus)
   227  			}
   228  		})
   229  	}
   230  
   231  }
   232  
   233  func TestCommitTokenInResponse(t *testing.T) {
   234  	fakeRW, rr := safehttptest.NewFakeResponseWriter()
   235  	req := safehttptest.NewRequest(safehttp.MethodGet, "https://foo.com/pizza", nil)
   236  
   237  	i := Interceptor{SecretAppKey: "testSecretAppKey"}
   238  	tr := &safehttp.TemplateResponse{}
   239  	i.Commit(fakeRW, req, tr, nil)
   240  
   241  	tok, ok := tr.FuncMap["XSRFToken"]
   242  	if !ok {
   243  		t.Fatal(`tr.FuncMap["XSRFToken"] not found`)
   244  	}
   245  
   246  	fn, ok := tok.(func() string)
   247  	if !ok {
   248  		t.Fatalf(`tr.FuncMap["XSRFToken"]: got %T, want "func() string"`, fn)
   249  	}
   250  	if got := fn(); got == "" {
   251  		t.Error(`tr.FuncMap["XSRFToken"](): got empty token`, got)
   252  	}
   253  
   254  	if want, got := safehttp.StatusOK, rr.Code; got != int(want) {
   255  		t.Errorf("rr.Code: got %v, want %v", got, want)
   256  	}
   257  
   258  	if want, got := "", rr.Body.String(); got != want {
   259  		t.Errorf("rr.Body.String(): got %q want %q", got, want)
   260  	}
   261  }
   262  
   263  func TestCommitNotTemplateResponse(t *testing.T) {
   264  	fakeRW, rr := safehttptest.NewFakeResponseWriter()
   265  	req := safehttptest.NewRequest(safehttp.MethodGet, "https://foo.com/pizza", nil)
   266  
   267  	i := Interceptor{SecretAppKey: "testSecretAppKey"}
   268  	i.Commit(fakeRW, req, safehttp.NoContentResponse{}, nil)
   269  
   270  	if want, got := safehttp.StatusOK, rr.Code; got != int(want) {
   271  		t.Errorf("rr.Code: got %v, want %v", got, want)
   272  	}
   273  
   274  	if want, got := "", rr.Body.String(); got != want {
   275  		t.Errorf("rr.Body.String(): got %q want %q", got, want)
   276  	}
   277  }