github.com/grafviktor/keep-my-secret@v0.9.10-0.20230908165355-19f35cce90e5/internal/api/auth/auth_test.go (about)

     1  package auth
     2  
     3  import (
     4  	"net/http"
     5  	"net/http/httptest"
     6  	"testing"
     7  	"time"
     8  
     9  	"github.com/golang-jwt/jwt/v4"
    10  	"github.com/stretchr/testify/require"
    11  
    12  	"github.com/grafviktor/keep-my-secret/internal/config"
    13  )
    14  
    15  const cookieName = "refresh_cookie"
    16  
    17  func TestGetRefreshCookie(t *testing.T) {
    18  	// Initialize my Auth object with relevant settings
    19  	auth := &Auth{
    20  		CookieName:    cookieName,
    21  		CookiePath:    "/",
    22  		RefreshExpiry: 24 * time.Hour,
    23  	}
    24  
    25  	// Call the GetRefreshCookie function
    26  	refreshToken := "your_refresh_token"
    27  	cookie := auth.GetRefreshCookie(refreshToken)
    28  
    29  	// Assert that the cookie has the expected values
    30  	if cookie.Name != cookieName {
    31  		t.Errorf("Expected cookie name '%s', got '%s'", cookieName, cookie.Name)
    32  	}
    33  
    34  	if cookie.Path != "/" {
    35  		t.Errorf("Expected cookie path '/', got '%s'", cookie.Path)
    36  	}
    37  
    38  	if cookie.Value != refreshToken {
    39  		t.Errorf("Expected cookie value '%s', got '%s'", refreshToken, cookie.Value)
    40  	}
    41  
    42  	if cookie.MaxAge != int(24*60*60) {
    43  		t.Errorf("Expected cookie MaxAge '%d', got '%d'", int(24*60*60), cookie.MaxAge)
    44  	}
    45  
    46  	if cookie.SameSite != siteMode {
    47  		t.Errorf("Expected SameSite mode '%v', got '%v'", siteMode, cookie.SameSite)
    48  	}
    49  
    50  	if !cookie.HttpOnly {
    51  		t.Error("Expected HttpOnly to be true")
    52  	}
    53  
    54  	if !cookie.Secure {
    55  		t.Error("Expected Secure to be true")
    56  	}
    57  }
    58  
    59  func TestGetExpiredRefreshCookie(t *testing.T) {
    60  	// Initialize my Auth object with relevant settings
    61  	auth := &Auth{
    62  		CookieName: "refresh_cookie",
    63  		CookiePath: "/",
    64  	}
    65  
    66  	// Call the GetExpiredRefreshCookie function
    67  	cookie := auth.GetExpiredRefreshCookie()
    68  
    69  	// Assert that the cookie has the expected values
    70  	if cookie.Name != "refresh_cookie" {
    71  		t.Errorf("Expected cookie name 'refresh_cookie', got '%s'", cookie.Name)
    72  	}
    73  
    74  	if cookie.Path != "/" {
    75  		t.Errorf("Expected cookie path '/', got '%s'", cookie.Path)
    76  	}
    77  
    78  	if cookie.Value != "" {
    79  		t.Errorf("Expected empty cookie value, got '%s'", cookie.Value)
    80  	}
    81  
    82  	if !cookie.Expires.Equal(time.Unix(0, 0)) {
    83  		t.Errorf("Expected cookie expiry '1970-01-01 00:00:00 UTC', got '%s'", cookie.Expires)
    84  	}
    85  
    86  	if cookie.MaxAge != -1 {
    87  		t.Errorf("Expected MaxAge '-1', got '%d'", cookie.MaxAge)
    88  	}
    89  
    90  	if cookie.SameSite != siteMode {
    91  		t.Errorf("Expected SameSite mode '%v', got '%v'", siteMode, cookie.SameSite)
    92  	}
    93  
    94  	if !cookie.HttpOnly {
    95  		t.Error("Expected HttpOnly to be true")
    96  	}
    97  
    98  	if !cookie.Secure {
    99  		t.Error("Expected Secure to be true")
   100  	}
   101  }
   102  
   103  func TestVerifyAuthHeader(t *testing.T) {
   104  	envConfig := config.EnvConfig{
   105  		Secret: "romeo romeo whiskey",
   106  		Domain: "localhost",
   107  	}
   108  	ac := config.New(envConfig)
   109  
   110  	// Create a test HTTP request with an Authorization header
   111  	req, err := http.NewRequest("GET", "/", nil)
   112  	if err != nil {
   113  		t.Fatal(err)
   114  	}
   115  
   116  	jwtUser := JWTUser{ID: "user@localhost"}
   117  	newAuth := New(ac)
   118  	tokenPair, err := newAuth.GenerateTokenPair(&jwtUser)
   119  	if err != nil {
   120  		t.Errorf("Unexpected error: %v", err)
   121  	}
   122  	tokenString := tokenPair.AccessToken
   123  	// tokenString := GenerateTokenPair(Secret, JWTIssuer, time.Now().Add(1*time.Hour))
   124  	req.Header.Add("Authorization", "Bearer "+tokenString)
   125  
   126  	// Create a test ResponseRecorder
   127  	rr := httptest.NewRecorder()
   128  
   129  	// Call the VerifyAuthHeader function
   130  	verifier := JWTVerifier{}
   131  	token, claims, err := verifier.VerifyAuthHeader(ac, rr, req)
   132  	// Check for expected results
   133  	if err != nil {
   134  		t.Errorf("Unexpected error: %v", err)
   135  	}
   136  
   137  	if token != tokenString {
   138  		t.Errorf("Expected token '%s', got '%s'", tokenString, token)
   139  	}
   140  
   141  	require.NotNil(t, claims)
   142  
   143  	// if claims == nil {
   144  	// 	t.Error("Expected non-nil claims")
   145  	// }
   146  
   147  	if claims.Issuer != ac.JWTIssuer {
   148  		t.Errorf("Expected issuer '%s', got '%s'", ac.JWTIssuer, claims.Issuer)
   149  	}
   150  
   151  	// Test cases with invalid headers
   152  	testCases := []struct {
   153  		headerValue string
   154  		expectedErr string
   155  	}{
   156  		{"", "no auth header"},
   157  		{"InvalidHeader", "invalid auth header"},
   158  		{"Bearer InvalidToken", "token contains an invalid number of segments"},
   159  		{"Bearer " + generateJWTToken("WrongSecret", ac.JWTIssuer, time.Now().Add(1*time.Hour)), "signature is invalid"},
   160  		{"Bearer " + generateJWTToken(ac.Secret, ac.JWTIssuer, time.Now().Add(-1*time.Hour)), "expired token"},
   161  		{"Bearer " + generateJWTToken(ac.Secret, "WrongIssuer", time.Now().Add(1*time.Hour)), "invalid issuer"},
   162  	}
   163  
   164  	for _, tc := range testCases {
   165  		req.Header.Set("Authorization", tc.headerValue)
   166  		rr = httptest.NewRecorder()
   167  		_, _, err := verifier.VerifyAuthHeader(ac, rr, req)
   168  		if err == nil || err.Error() != tc.expectedErr {
   169  			t.Errorf("Expected error: '%s', got: '%v'", tc.expectedErr, err)
   170  		}
   171  	}
   172  }
   173  
   174  func generateJWTToken(secretKey, issuer string, expiration time.Time) string {
   175  	claims := jwt.MapClaims{
   176  		"iss": issuer,
   177  		"exp": expiration.Unix(),
   178  	}
   179  
   180  	token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
   181  	tokenString, _ := token.SignedString([]byte(secretKey))
   182  	return tokenString
   183  }