github.com/grafviktor/keep-my-secret@v0.9.10-0.20230908165355-19f35cce90e5/internal/api/auth/auth_test.go (about) 1 package auth 2 3 import ( 4 "net/http" 5 "net/http/httptest" 6 "testing" 7 "time" 8 9 "github.com/golang-jwt/jwt/v4" 10 "github.com/stretchr/testify/require" 11 12 "github.com/grafviktor/keep-my-secret/internal/config" 13 ) 14 15 const cookieName = "refresh_cookie" 16 17 func TestGetRefreshCookie(t *testing.T) { 18 // Initialize my Auth object with relevant settings 19 auth := &Auth{ 20 CookieName: cookieName, 21 CookiePath: "/", 22 RefreshExpiry: 24 * time.Hour, 23 } 24 25 // Call the GetRefreshCookie function 26 refreshToken := "your_refresh_token" 27 cookie := auth.GetRefreshCookie(refreshToken) 28 29 // Assert that the cookie has the expected values 30 if cookie.Name != cookieName { 31 t.Errorf("Expected cookie name '%s', got '%s'", cookieName, cookie.Name) 32 } 33 34 if cookie.Path != "/" { 35 t.Errorf("Expected cookie path '/', got '%s'", cookie.Path) 36 } 37 38 if cookie.Value != refreshToken { 39 t.Errorf("Expected cookie value '%s', got '%s'", refreshToken, cookie.Value) 40 } 41 42 if cookie.MaxAge != int(24*60*60) { 43 t.Errorf("Expected cookie MaxAge '%d', got '%d'", int(24*60*60), cookie.MaxAge) 44 } 45 46 if cookie.SameSite != siteMode { 47 t.Errorf("Expected SameSite mode '%v', got '%v'", siteMode, cookie.SameSite) 48 } 49 50 if !cookie.HttpOnly { 51 t.Error("Expected HttpOnly to be true") 52 } 53 54 if !cookie.Secure { 55 t.Error("Expected Secure to be true") 56 } 57 } 58 59 func TestGetExpiredRefreshCookie(t *testing.T) { 60 // Initialize my Auth object with relevant settings 61 auth := &Auth{ 62 CookieName: "refresh_cookie", 63 CookiePath: "/", 64 } 65 66 // Call the GetExpiredRefreshCookie function 67 cookie := auth.GetExpiredRefreshCookie() 68 69 // Assert that the cookie has the expected values 70 if cookie.Name != "refresh_cookie" { 71 t.Errorf("Expected cookie name 'refresh_cookie', got '%s'", cookie.Name) 72 } 73 74 if cookie.Path != "/" { 75 t.Errorf("Expected cookie path '/', got '%s'", cookie.Path) 76 } 77 78 if cookie.Value != "" { 79 t.Errorf("Expected empty cookie value, got '%s'", cookie.Value) 80 } 81 82 if !cookie.Expires.Equal(time.Unix(0, 0)) { 83 t.Errorf("Expected cookie expiry '1970-01-01 00:00:00 UTC', got '%s'", cookie.Expires) 84 } 85 86 if cookie.MaxAge != -1 { 87 t.Errorf("Expected MaxAge '-1', got '%d'", cookie.MaxAge) 88 } 89 90 if cookie.SameSite != siteMode { 91 t.Errorf("Expected SameSite mode '%v', got '%v'", siteMode, cookie.SameSite) 92 } 93 94 if !cookie.HttpOnly { 95 t.Error("Expected HttpOnly to be true") 96 } 97 98 if !cookie.Secure { 99 t.Error("Expected Secure to be true") 100 } 101 } 102 103 func TestVerifyAuthHeader(t *testing.T) { 104 envConfig := config.EnvConfig{ 105 Secret: "romeo romeo whiskey", 106 Domain: "localhost", 107 } 108 ac := config.New(envConfig) 109 110 // Create a test HTTP request with an Authorization header 111 req, err := http.NewRequest("GET", "/", nil) 112 if err != nil { 113 t.Fatal(err) 114 } 115 116 jwtUser := JWTUser{ID: "user@localhost"} 117 newAuth := New(ac) 118 tokenPair, err := newAuth.GenerateTokenPair(&jwtUser) 119 if err != nil { 120 t.Errorf("Unexpected error: %v", err) 121 } 122 tokenString := tokenPair.AccessToken 123 // tokenString := GenerateTokenPair(Secret, JWTIssuer, time.Now().Add(1*time.Hour)) 124 req.Header.Add("Authorization", "Bearer "+tokenString) 125 126 // Create a test ResponseRecorder 127 rr := httptest.NewRecorder() 128 129 // Call the VerifyAuthHeader function 130 verifier := JWTVerifier{} 131 token, claims, err := verifier.VerifyAuthHeader(ac, rr, req) 132 // Check for expected results 133 if err != nil { 134 t.Errorf("Unexpected error: %v", err) 135 } 136 137 if token != tokenString { 138 t.Errorf("Expected token '%s', got '%s'", tokenString, token) 139 } 140 141 require.NotNil(t, claims) 142 143 // if claims == nil { 144 // t.Error("Expected non-nil claims") 145 // } 146 147 if claims.Issuer != ac.JWTIssuer { 148 t.Errorf("Expected issuer '%s', got '%s'", ac.JWTIssuer, claims.Issuer) 149 } 150 151 // Test cases with invalid headers 152 testCases := []struct { 153 headerValue string 154 expectedErr string 155 }{ 156 {"", "no auth header"}, 157 {"InvalidHeader", "invalid auth header"}, 158 {"Bearer InvalidToken", "token contains an invalid number of segments"}, 159 {"Bearer " + generateJWTToken("WrongSecret", ac.JWTIssuer, time.Now().Add(1*time.Hour)), "signature is invalid"}, 160 {"Bearer " + generateJWTToken(ac.Secret, ac.JWTIssuer, time.Now().Add(-1*time.Hour)), "expired token"}, 161 {"Bearer " + generateJWTToken(ac.Secret, "WrongIssuer", time.Now().Add(1*time.Hour)), "invalid issuer"}, 162 } 163 164 for _, tc := range testCases { 165 req.Header.Set("Authorization", tc.headerValue) 166 rr = httptest.NewRecorder() 167 _, _, err := verifier.VerifyAuthHeader(ac, rr, req) 168 if err == nil || err.Error() != tc.expectedErr { 169 t.Errorf("Expected error: '%s', got: '%v'", tc.expectedErr, err) 170 } 171 } 172 } 173 174 func generateJWTToken(secretKey, issuer string, expiration time.Time) string { 175 claims := jwt.MapClaims{ 176 "iss": issuer, 177 "exp": expiration.Unix(), 178 } 179 180 token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) 181 tokenString, _ := token.SignedString([]byte(secretKey)) 182 return tokenString 183 }