github.com/google/go-safeweb@v0.0.0-20231219055052-64d8cfc90fbb/safehttp/plugins/xsrf/xsrfangular/xsrf_test.go (about)

     1  // Copyright 2020 Google LLC
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //	https://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package xsrfangular
    16  
    17  import (
    18  	"strings"
    19  	"testing"
    20  
    21  	"github.com/google/go-safeweb/safehttp"
    22  	"github.com/google/go-safeweb/safehttp/safehttptest"
    23  )
    24  
const (
	// cookieName is the cookie name the tests expect the Default() interceptor
	// to set, and the name of the cookie attached to TestPostProtection
	// requests.
	cookieName = "XSRF-TOKEN"
	// headerName is the request header that carries the token in
	// TestPostProtection requests.
	// NOTE(review): presumably mirrors the header name Default() is configured
	// with — confirm against the interceptor implementation.
	headerName = "X-XSRF-TOKEN"
)
    29  
    30  func TestAddCookie(t *testing.T) {
    31  	tests := []struct {
    32  		name, cookie string
    33  		it           *Interceptor
    34  	}{
    35  		{
    36  			name:   "Default interceptor",
    37  			it:     Default(),
    38  			cookie: cookieName,
    39  		},
    40  		{
    41  			name: "Custom interceptor",
    42  			it: &Interceptor{
    43  				TokenCookieName: "FOO-TOKEN",
    44  				TokenHeaderName: "X-FOO-TOKEN",
    45  			},
    46  			cookie: "FOO-TOKEN",
    47  		},
    48  	}
    49  
    50  	for _, test := range tests {
    51  		t.Run(test.name, func(t *testing.T) {
    52  			req := safehttptest.NewRequest(safehttp.MethodGet, "/", nil)
    53  			fakeRW, _ := safehttptest.NewFakeResponseWriter()
    54  			test.it.Commit(fakeRW, req, nil, nil)
    55  
    56  			if len(fakeRW.Cookies) != 1 {
    57  				t.Errorf("len(Cookies) = %v, want 1", len(fakeRW.Cookies))
    58  			}
    59  
    60  			if got, want := fakeRW.Cookies[0].String(), "Path=/; Max-Age=86400; Secure; SameSite=Strict"; !strings.Contains(got, want) {
    61  				t.Errorf("XSRF cookie got %q, want to contain %q", got, want)
    62  			}
    63  		})
    64  	}
    65  }
    66  
    67  func TestAddCookieFail(t *testing.T) {
    68  	req := safehttptest.NewRequest(safehttp.MethodGet, "/", nil)
    69  	fakeRW, _ := safehttptest.NewFakeResponseWriter()
    70  	it := &Interceptor{}
    71  
    72  	defer func() {
    73  		if r := recover(); r == nil {
    74  			t.Fatal("expected panic")
    75  		}
    76  	}()
    77  	it.Commit(fakeRW, req, nil, nil)
    78  }
    79  
    80  func TestPostProtection(t *testing.T) {
    81  	tests := []struct {
    82  		name       string
    83  		req        *safehttp.IncomingRequest
    84  		wantStatus safehttp.StatusCode
    85  		wantHeader map[string][]string
    86  		wantBody   string
    87  	}{
    88  		{
    89  			name: "Same cookie and header",
    90  			req: func() *safehttp.IncomingRequest {
    91  				req := safehttptest.NewRequest(safehttp.MethodPost, "/", nil)
    92  				req.Header.Set("Cookie", cookieName+"="+"1234")
    93  				req.Header.Set(headerName, "1234")
    94  				return req
    95  			}(),
    96  			wantStatus: safehttp.StatusOK,
    97  			wantHeader: map[string][]string{},
    98  			wantBody:   "",
    99  		},
   100  		{
   101  			name: "Different cookie and header",
   102  			req: func() *safehttp.IncomingRequest {
   103  				req := safehttptest.NewRequest(safehttp.MethodPost, "/", nil)
   104  				req.Header.Set("Cookie", cookieName+"="+"5768")
   105  				req.Header.Set(headerName, "1234")
   106  				return req
   107  			}(),
   108  			wantStatus: safehttp.StatusUnauthorized,
   109  			wantHeader: map[string][]string{
   110  				"Content-Type":           {"text/plain; charset=utf-8"},
   111  				"X-Content-Type-Options": {"nosniff"},
   112  			},
   113  			wantBody: "Unauthorized\n",
   114  		},
   115  		{
   116  			name: "Missing header",
   117  			req: func() *safehttp.IncomingRequest {
   118  				req := safehttptest.NewRequest(safehttp.MethodPost, "/", nil)
   119  				req.Header.Set("Cookie", cookieName+"="+"1234")
   120  				return req
   121  			}(),
   122  			wantStatus: safehttp.StatusUnauthorized,
   123  			wantHeader: map[string][]string{
   124  				"Content-Type":           {"text/plain; charset=utf-8"},
   125  				"X-Content-Type-Options": {"nosniff"},
   126  			},
   127  			wantBody: "Unauthorized\n",
   128  		},
   129  		{
   130  			name: "Missing cookie",
   131  			req: func() *safehttp.IncomingRequest {
   132  				req := safehttptest.NewRequest(safehttp.MethodPost, "/", nil)
   133  				req.Header.Set(headerName, "1234")
   134  				return req
   135  			}(),
   136  			wantStatus: safehttp.StatusForbidden,
   137  			wantHeader: map[string][]string{
   138  				"Content-Type":           {"text/plain; charset=utf-8"},
   139  				"X-Content-Type-Options": {"nosniff"},
   140  			},
   141  			wantBody: "Forbidden\n",
   142  		},
   143  	}
   144  
   145  	for _, test := range tests {
   146  		t.Run(test.name, func(t *testing.T) {
   147  			fakeRW, rr := safehttptest.NewFakeResponseWriter()
   148  			i := Default()
   149  			i.Before(fakeRW, test.req, nil)
   150  
   151  			if got := rr.Code; got != int(test.wantStatus) {
   152  				t.Errorf("rr.Code: got %v, want %v", got, test.wantStatus)
   153  			}
   154  		})
   155  	}
   156  }