github.com/dhax/go-base@v0.0.0-20231004214136-8be7e5c1972b/auth/pwdless/api_test.go (about) 1 package pwdless 2 3 import ( 4 "bytes" 5 "encoding/json" 6 "errors" 7 "fmt" 8 "io" 9 "io/ioutil" 10 "net/http" 11 "net/http/httptest" 12 "os" 13 "strings" 14 "testing" 15 "time" 16 17 "github.com/go-chi/chi/v5" 18 "github.com/spf13/viper" 19 20 "github.com/dhax/go-base/auth/jwt" 21 "github.com/dhax/go-base/email" 22 "github.com/dhax/go-base/logging" 23 ) 24 25 var ( 26 auth *Resource 27 authStore MockAuthStore 28 mailer email.MockMailer 29 ts *httptest.Server 30 ) 31 32 func TestMain(m *testing.M) { 33 viper.SetDefault("auth_login_token_length", 8) 34 viper.SetDefault("auth_login_token_expiry", "11m") 35 viper.SetDefault("auth_jwt_secret", "random") 36 viper.SetDefault("log_level", "error") 37 38 var err error 39 auth, err = NewResource(&authStore, &mailer) 40 if err != nil { 41 fmt.Println(err) 42 os.Exit(1) 43 } 44 45 r := chi.NewRouter() 46 r.Use(logging.NewStructuredLogger(logging.NewLogger())) 47 r.Mount("/", auth.Router()) 48 49 ts = httptest.NewServer(r) 50 51 code := m.Run() 52 ts.Close() 53 os.Exit(code) 54 } 55 56 func TestAuthResource_login(t *testing.T) { 57 authStore.GetAccountByEmailFn = func(email string) (*Account, error) { 58 var err error 59 a := Account{ 60 ID: 1, 61 Email: email, 62 Name: "test", 63 } 64 65 switch email { 66 case "not@exists.io": 67 err = errors.New("sql no row") 68 case "disabled@account.io": 69 a.Active = false 70 case "valid@account.io": 71 a.Active = true 72 } 73 return &a, err 74 } 75 76 mailer.LoginTokenFn = func(n, e string, c email.ContentLoginToken) error { 77 return nil 78 } 79 80 tests := []struct { 81 name string 82 email string 83 status int 84 err error 85 }{ 86 {"missing", "", http.StatusUnauthorized, ErrInvalidLogin}, 87 {"inexistent", "not@exists.io", http.StatusUnauthorized, ErrUnknownLogin}, 88 {"disabled", "disabled@account.io", http.StatusUnauthorized, ErrLoginDisabled}, 89 {"valid", "valid@account.io", http.StatusOK, nil}, 90 } 91 92 for _, tc := range tests { 93 t.Run(tc.name, func(t *testing.T) { 94 req, err := encode(&loginRequest{Email: tc.email}) 95 if err != nil { 96 t.Fatal("failed to encode request body") 97 } 98 res, body := testRequest(t, ts, "POST", "/login", req, "") 99 100 if res.StatusCode != tc.status { 101 t.Errorf("got http status %d, want: %d", res.StatusCode, tc.status) 102 } 103 if tc.err != nil && !strings.Contains(body, tc.err.Error()) { 104 t.Errorf(" got: %s, expected to contain: %s", body, tc.err.Error()) 105 } 106 if tc.err == ErrInvalidLogin && authStore.GetAccountByEmailInvoked { 107 t.Error("GetByLoginToken invoked for invalid email") 108 } 109 if tc.err == nil && !mailer.LoginTokenInvoked { 110 t.Error("emailService.LoginToken not invoked") 111 } 112 authStore.GetAccountByEmailInvoked = false 113 mailer.LoginTokenInvoked = false 114 }) 115 } 116 } 117 118 func TestAuthResource_token(t *testing.T) { 119 authStore.GetAccountFn = func(id int) (*Account, error) { 120 var err error 121 a := Account{ 122 ID: id, 123 Active: true, 124 Name: "test", 125 } 126 switch id { 127 case 2: 128 a.Active = false 129 case 3: 130 // unmodified 131 default: 132 err = errors.New("sql no rows") 133 } 134 return &a, err 135 } 136 authStore.UpdateAccountFn = func(a *Account) error { 137 a.LastLogin = time.Now() 138 return nil 139 } 140 authStore.CreateOrUpdateTokenFn = func(a *jwt.Token) error { 141 return nil 142 } 143 144 tests := []struct { 145 name string 146 token string 147 id int 148 status int 149 err error 150 }{ 151 {"invalid", "#ยง$%", 0, http.StatusUnauthorized, ErrLoginToken}, 152 {"expired", "12345678", 0, http.StatusUnauthorized, ErrLoginToken}, 153 {"deleted_account", "", 1, http.StatusUnauthorized, ErrUnknownLogin}, 154 {"disabled", "", 2, http.StatusUnauthorized, ErrLoginDisabled}, 155 {"valid", "", 3, http.StatusOK, nil}, 156 } 157 158 for _, tc := range tests { 159 t.Run(tc.name, func(t *testing.T) { 160 token := auth.LoginAuth.CreateToken(tc.id) 161 if tc.token != "" { 162 token.Token = tc.token 163 } 164 165 req, err := encode(tokenRequest{Token: token.Token}) 166 if err != nil { 167 t.Fatal("failed to encode request body") 168 } 169 res, body := testRequest(t, ts, "POST", "/token", req, "") 170 171 if res.StatusCode != tc.status { 172 t.Errorf("got http status %d, want: %d", res.StatusCode, tc.status) 173 } 174 if tc.err != nil && !strings.Contains(body, tc.err.Error()) { 175 t.Errorf("got: %s, expected to contain: %s", body, tc.err.Error()) 176 } 177 if tc.err == ErrLoginToken && authStore.CreateOrUpdateTokenInvoked { 178 t.Errorf("CreateOrUpdate invoked despite error %s", tc.err.Error()) 179 } 180 if tc.err == nil && !authStore.CreateOrUpdateTokenInvoked { 181 t.Error("CreateOrUpdate not invoked") 182 } 183 authStore.CreateOrUpdateTokenInvoked = false 184 }) 185 } 186 } 187 188 func TestAuthResource_refresh(t *testing.T) { 189 authStore.GetAccountFn = func(id int) (*Account, error) { 190 a := Account{ 191 Active: true, 192 Name: "Test", 193 } 194 switch id { 195 case 999: 196 a.Active = false 197 } 198 return &a, nil 199 } 200 authStore.UpdateAccountFn = func(a *Account) error { 201 a.LastLogin = time.Now() 202 return nil 203 } 204 205 authStore.GetTokenFn = func(token string) (*jwt.Token, error) { 206 var err error 207 var t jwt.Token 208 t.Expiry = time.Now().Add(1 * time.Minute) 209 210 switch token { 211 case "not_found": 212 err = errors.New("sql no rows") 213 case "expired": 214 t.Expiry = time.Now().Add(-1 * time.Minute) 215 case "disabled": 216 t.AccountID = 999 217 } 218 return &t, err 219 } 220 authStore.CreateOrUpdateTokenFn = func(a *jwt.Token) error { 221 return nil 222 } 223 authStore.DeleteTokenFn = func(t *jwt.Token) error { 224 return nil 225 } 226 227 tests := []struct { 228 name string 229 token string 230 exp time.Duration 231 status int 232 err error 233 }{ 234 {"not_found", "not_found", 1, http.StatusUnauthorized, jwt.ErrTokenExpired}, 235 {"expired", "expired", -1, http.StatusUnauthorized, jwt.ErrTokenExpired}, 236 {"disabled", "disabled", 1, http.StatusUnauthorized, ErrLoginDisabled}, 237 {"valid", "valid", 1, http.StatusOK, nil}, 238 } 239 240 for _, tc := range tests { 241 t.Run(tc.name, func(t *testing.T) { 242 // refreshJWT, err := auth.TokenAuth.CreateRefreshJWT(jwt.RefreshClaims{Token: tc.token}) 243 // if err != nil { 244 // t.Errorf("failed to create refresh jwt") 245 // } 246 refreshJWT := genRefreshJWT(jwt.RefreshClaims{ 247 Token: tc.token, 248 CommonClaims: jwt.CommonClaims{ 249 ExpiresAt: time.Now().Add(time.Minute * tc.exp).UnixNano(), 250 }, 251 }) 252 253 res, body := testRequest(t, ts, "POST", "/refresh", nil, refreshJWT) 254 if res.StatusCode != tc.status { 255 t.Errorf("got http status %d, want: %d", res.StatusCode, tc.status) 256 } 257 if tc.err != nil && !strings.Contains(body, tc.err.Error()) { 258 t.Errorf("got: %s, expected error to contain: %s", body, tc.err.Error()) 259 } 260 if tc.status == http.StatusUnauthorized && authStore.CreateOrUpdateTokenInvoked { 261 t.Errorf("CreateOrUpdate invoked for status %d", tc.status) 262 } 263 if tc.status == http.StatusOK { 264 if !authStore.GetTokenInvoked { 265 t.Errorf("GetByToken not invoked") 266 } 267 if !authStore.CreateOrUpdateTokenInvoked { 268 t.Errorf("CreateOrUpdate not invoked") 269 } 270 if authStore.DeleteTokenInvoked { 271 t.Errorf("Delete should not be invoked") 272 } 273 } 274 authStore.GetTokenInvoked = false 275 authStore.CreateOrUpdateTokenInvoked = false 276 authStore.DeleteTokenInvoked = false 277 }) 278 } 279 } 280 281 func TestAuthResource_logout(t *testing.T) { 282 authStore.GetTokenFn = func(token string) (*jwt.Token, error) { 283 var err error 284 t := jwt.Token{ 285 Expiry: time.Now().Add(1 * time.Minute), 286 } 287 288 switch token { 289 case "notfound": 290 err = errors.New("sql no rows") 291 } 292 return &t, err 293 } 294 authStore.DeleteTokenFn = func(a *jwt.Token) error { 295 return nil 296 } 297 298 tests := []struct { 299 name string 300 token string 301 exp time.Duration 302 status int 303 err error 304 }{ 305 {"notfound", "notfound", 1, http.StatusUnauthorized, jwt.ErrTokenExpired}, 306 {"expired", "valid", -1, http.StatusOK, nil}, 307 {"valid", "valid", 1, http.StatusOK, nil}, 308 } 309 310 for _, tc := range tests { 311 t.Run(tc.name, func(t *testing.T) { 312 refreshJWT := genRefreshJWT(jwt.RefreshClaims{ 313 Token: tc.token, 314 CommonClaims: jwt.CommonClaims{ 315 ExpiresAt: time.Now().Add(time.Minute * tc.exp).UnixNano(), 316 }, 317 }) 318 319 res, body := testRequest(t, ts, "POST", "/logout", nil, refreshJWT) 320 if res.StatusCode != tc.status { 321 t.Errorf("got http status %d, want: %d", res.StatusCode, tc.status) 322 } 323 if tc.err != nil && !strings.Contains(body, tc.err.Error()) { 324 t.Errorf("got: %x, expected error to contain %s", body, tc.err.Error()) 325 } 326 if tc.status == http.StatusUnauthorized && authStore.DeleteTokenInvoked { 327 t.Errorf("Delete invoked for status %d", tc.status) 328 } 329 authStore.DeleteTokenInvoked = false 330 }) 331 } 332 } 333 334 func testRequest(t *testing.T, ts *httptest.Server, method, path string, body io.Reader, token string) (*http.Response, string) { 335 req, err := http.NewRequest(method, ts.URL+path, body) 336 if err != nil { 337 t.Fatal(err) 338 return nil, "" 339 } 340 req.Header.Set("Content-Type", "application/json") 341 if token != "" { 342 req.Header.Set("Authorization", fmt.Sprintf("BEARER %s", token)) 343 } 344 345 resp, err := http.DefaultClient.Do(req) 346 if err != nil { 347 t.Fatal(err) 348 return nil, "" 349 } 350 defer resp.Body.Close() 351 352 respBody, err := ioutil.ReadAll(resp.Body) 353 if err != nil { 354 t.Fatal(err) 355 return nil, "" 356 } 357 358 return resp, string(respBody) 359 } 360 361 // func genJWT(c jwt.AppClaims) string { 362 // claims, _ := jwt.ParseStructToMap(c) 363 364 // _, tokenString, _ := auth.TokenAuth.JwtAuth.Encode(claims) 365 // return tokenString 366 // } 367 368 func genRefreshJWT(c jwt.RefreshClaims) string { 369 claims, _ := jwt.ParseStructToMap(c) 370 371 _, tokenString, _ := auth.TokenAuth.JwtAuth.Encode(claims) 372 return tokenString 373 } 374 375 func encode(v interface{}) (*bytes.Buffer, error) { 376 data := new(bytes.Buffer) 377 err := json.NewEncoder(data).Encode(v) 378 return data, err 379 }