github.com/google/go-safeweb@v0.0.0-20231219055052-64d8cfc90fbb/examples/sample-application/secure/auth/auth_test.go (about) 1 // Copyright 2022 Google LLC 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // https://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package auth 16 17 import ( 18 "testing" 19 20 "github.com/google/go-safeweb/examples/sample-application/storage" 21 "github.com/google/go-safeweb/safehttp" 22 "github.com/google/go-safeweb/safehttp/safehttptest" 23 ) 24 25 const TEST_USER = "test" 26 27 func TestInterceptorBefore(t *testing.T) { 28 tests := []struct { 29 name string 30 cfg safehttp.InterceptorConfig 31 hasAuth bool 32 want safehttp.StatusCode 33 }{ 34 { 35 name: "base case, no error", 36 hasAuth: true, 37 cfg: nil, 38 want: safehttp.StatusOK, 39 }, 40 { 41 name: "force skip using config", 42 hasAuth: true, 43 cfg: Skip{}, 44 want: safehttp.StatusOK, 45 }, 46 { 47 name: "missing auth, error", 48 hasAuth: false, 49 cfg: nil, 50 want: safehttp.StatusUnauthorized, 51 }, 52 } 53 for _, tt := range tests { 54 t.Run(tt.name, func(t *testing.T) { 55 withUserDB, token := addTestUser(storage.NewDB()) 56 ip := newTestInterceptor(withUserDB) 57 58 req := safehttptest.NewRequest(safehttp.MethodGet, "/", nil) 59 rw, r := safehttptest.NewFakeResponseWriter() 60 61 if tt.hasAuth { 62 addTestUserCookie(req, token) 63 } 64 65 // Note: "Before" return value is not significant 66 ip.Before(rw, req, tt.cfg) 67 68 if got := r.Code; got != int(tt.want) { 69 t.Errorf("status code got: %d, want %d", got, tt.want) 70 } 71 }) 72 } 73 } 74 75 func TestInterceptorCommit(t *testing.T) { 76 tests := []struct { 77 name string 78 action sessionAction 79 hasCookie bool 80 }{ 81 { 82 name: "clear session, no error", 83 action: clearSess, 84 hasCookie: false, 85 }, { 86 name: "set session, no error", 87 action: setSess, 88 hasCookie: true, 89 }, 90 { 91 name: "unexpected action, skip", 92 action: sessionAction("unexpected"), 93 hasCookie: false, 94 }, 95 } 96 for _, tt := range tests { 97 t.Run(tt.name, func(t *testing.T) { 98 withUserDB, _ := addTestUser(storage.NewDB()) 99 ip := newTestInterceptor(withUserDB) 100 101 req := safehttptest.NewRequest(safehttp.MethodGet, "/", nil) 102 rw, r := safehttptest.NewFakeResponseWriter() 103 104 safehttp.FlightValues(req.Context()).Put(userKey, "user") 105 safehttp.FlightValues(req.Context()).Put(changeSessKey, tt.action) 106 107 ip.Commit(rw, req, r.Result, nil) 108 109 var token string 110 for _, c := range rw.Cookies { 111 if c.Name() == sessionCookie { 112 token = c.Value() 113 } 114 } 115 116 if tt.hasCookie == (token == "") { 117 t.Errorf("token = %q, want %v", token, tt.hasCookie) 118 } 119 }) 120 } 121 } 122 123 func TestInterceptorMatch(t *testing.T) { 124 tests := []struct { 125 name string 126 cfg safehttp.InterceptorConfig 127 want bool 128 }{ 129 { 130 name: "basic case, no error", 131 cfg: Skip{}, 132 want: true, 133 }, 134 { 135 name: "no Skip{}, error", 136 cfg: nil, 137 want: false, 138 }, 139 } 140 for _, tt := range tests { 141 t.Run(tt.name, func(t *testing.T) { 142 ip := newTestInterceptor(nil) 143 if got := ip.Match(tt.cfg); got != tt.want { 144 t.Errorf("Interceptor.Match() = %v, want %v", got, tt.want) 145 } 146 }) 147 } 148 } 149 150 func TestInterceptorUserFromCookie(t *testing.T) { 151 withUserDB, validToken := addTestUser(storage.NewDB()) 152 ip := newTestInterceptor(withUserDB) 153 154 tests := []struct { 155 name string 156 token string 157 want string 158 }{ 159 { 160 name: "basic case, no error", 161 token: validToken, 162 want: TEST_USER, 163 }, 164 { 165 name: "empty cookie, error", 166 token: "", 167 want: "", 168 }, 169 { 170 name: "invalid token, error", 171 token: "not_a_valid_token", 172 want: "", 173 }, 174 } 175 for _, tt := range tests { 176 t.Run(tt.name, func(t *testing.T) { 177 req := safehttptest.NewRequest(safehttp.MethodGet, "/", nil) 178 addTestUserCookie(req, tt.token) 179 180 if got := ip.userFromCookie(req); got != tt.want { 181 t.Errorf("Interceptor.userFromCookie() = %v, want %v", got, tt.want) 182 } 183 }) 184 } 185 } 186 187 func TestSessionManagement(t *testing.T) { 188 want := "wanted" 189 r := safehttptest.NewRequest(safehttp.MethodGet, "/", nil) 190 191 CreateSession(r, want) 192 if got := User(r); got != want { 193 t.Errorf("user id got: %q, want %q", got, want) 194 } 195 196 ClearSession(r) 197 // Note: `ctxSessionAction` already tested inside ctx_test.go 198 if got := ctxSessionAction(r.Context()); got != clearSess { 199 t.Errorf("no clearSess action found in context after ClearSession") 200 } 201 } 202 203 func addTestUserCookie(r *safehttp.IncomingRequest, v string) { 204 r.Header.Add("Cookie", safehttp.NewCookie(sessionCookie, v).String()) 205 } 206 207 func newTestInterceptor(db *storage.DB) Interceptor { 208 if db == nil { 209 db = storage.NewDB() 210 } 211 return Interceptor{ 212 DB: db, 213 } 214 } 215 216 func addTestUser(db *storage.DB) (*storage.DB, string) { 217 token := (*db).GetToken(TEST_USER) 218 return db, token 219 }