github.com/google/go-safeweb@v0.0.0-20231219055052-64d8cfc90fbb/safehttp/plugins/xsrf/xsrfangular/xsrf_test.go (about) 1 // Copyright 2020 Google LLC 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // https://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package xsrfangular 16 17 import ( 18 "strings" 19 "testing" 20 21 "github.com/google/go-safeweb/safehttp" 22 "github.com/google/go-safeweb/safehttp/safehttptest" 23 ) 24 25 const ( 26 cookieName = "XSRF-TOKEN" 27 headerName = "X-XSRF-TOKEN" 28 ) 29 30 func TestAddCookie(t *testing.T) { 31 tests := []struct { 32 name, cookie string 33 it *Interceptor 34 }{ 35 { 36 name: "Default interceptor", 37 it: Default(), 38 cookie: cookieName, 39 }, 40 { 41 name: "Custom interceptor", 42 it: &Interceptor{ 43 TokenCookieName: "FOO-TOKEN", 44 TokenHeaderName: "X-FOO-TOKEN", 45 }, 46 cookie: "FOO-TOKEN", 47 }, 48 } 49 50 for _, test := range tests { 51 t.Run(test.name, func(t *testing.T) { 52 req := safehttptest.NewRequest(safehttp.MethodGet, "/", nil) 53 fakeRW, _ := safehttptest.NewFakeResponseWriter() 54 test.it.Commit(fakeRW, req, nil, nil) 55 56 if len(fakeRW.Cookies) != 1 { 57 t.Errorf("len(Cookies) = %v, want 1", len(fakeRW.Cookies)) 58 } 59 60 if got, want := fakeRW.Cookies[0].String(), "Path=/; Max-Age=86400; Secure; SameSite=Strict"; !strings.Contains(got, want) { 61 t.Errorf("XSRF cookie got %q, want to contain %q", got, want) 62 } 63 }) 64 } 65 } 66 67 func TestAddCookieFail(t *testing.T) { 68 req := safehttptest.NewRequest(safehttp.MethodGet, "/", nil) 69 fakeRW, _ := safehttptest.NewFakeResponseWriter() 70 it := &Interceptor{} 71 72 defer func() { 73 if r := recover(); r == nil { 74 t.Fatal("expected panic") 75 } 76 }() 77 it.Commit(fakeRW, req, nil, nil) 78 } 79 80 func TestPostProtection(t *testing.T) { 81 tests := []struct { 82 name string 83 req *safehttp.IncomingRequest 84 wantStatus safehttp.StatusCode 85 wantHeader map[string][]string 86 wantBody string 87 }{ 88 { 89 name: "Same cookie and header", 90 req: func() *safehttp.IncomingRequest { 91 req := safehttptest.NewRequest(safehttp.MethodPost, "/", nil) 92 req.Header.Set("Cookie", cookieName+"="+"1234") 93 req.Header.Set(headerName, "1234") 94 return req 95 }(), 96 wantStatus: safehttp.StatusOK, 97 wantHeader: map[string][]string{}, 98 wantBody: "", 99 }, 100 { 101 name: "Different cookie and header", 102 req: func() *safehttp.IncomingRequest { 103 req := safehttptest.NewRequest(safehttp.MethodPost, "/", nil) 104 req.Header.Set("Cookie", cookieName+"="+"5768") 105 req.Header.Set(headerName, "1234") 106 return req 107 }(), 108 wantStatus: safehttp.StatusUnauthorized, 109 wantHeader: map[string][]string{ 110 "Content-Type": {"text/plain; charset=utf-8"}, 111 "X-Content-Type-Options": {"nosniff"}, 112 }, 113 wantBody: "Unauthorized\n", 114 }, 115 { 116 name: "Missing header", 117 req: func() *safehttp.IncomingRequest { 118 req := safehttptest.NewRequest(safehttp.MethodPost, "/", nil) 119 req.Header.Set("Cookie", cookieName+"="+"1234") 120 return req 121 }(), 122 wantStatus: safehttp.StatusUnauthorized, 123 wantHeader: map[string][]string{ 124 "Content-Type": {"text/plain; charset=utf-8"}, 125 "X-Content-Type-Options": {"nosniff"}, 126 }, 127 wantBody: "Unauthorized\n", 128 }, 129 { 130 name: "Missing cookie", 131 req: func() *safehttp.IncomingRequest { 132 req := safehttptest.NewRequest(safehttp.MethodPost, "/", nil) 133 req.Header.Set(headerName, "1234") 134 return req 135 }(), 136 wantStatus: safehttp.StatusForbidden, 137 wantHeader: map[string][]string{ 138 "Content-Type": {"text/plain; charset=utf-8"}, 139 "X-Content-Type-Options": {"nosniff"}, 140 }, 141 wantBody: "Forbidden\n", 142 }, 143 } 144 145 for _, test := range tests { 146 t.Run(test.name, func(t *testing.T) { 147 fakeRW, rr := safehttptest.NewFakeResponseWriter() 148 i := Default() 149 i.Before(fakeRW, test.req, nil) 150 151 if got := rr.Code; got != int(test.wantStatus) { 152 t.Errorf("rr.Code: got %v, want %v", got, test.wantStatus) 153 } 154 }) 155 } 156 }