github.com/dhax/go-base@v0.0.0-20231004214136-8be7e5c1972b/auth/pwdless/api_test.go (about)

     1  package pwdless
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/json"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	"io/ioutil"
    10  	"net/http"
    11  	"net/http/httptest"
    12  	"os"
    13  	"strings"
    14  	"testing"
    15  	"time"
    16  
    17  	"github.com/go-chi/chi/v5"
    18  	"github.com/spf13/viper"
    19  
    20  	"github.com/dhax/go-base/auth/jwt"
    21  	"github.com/dhax/go-base/email"
    22  	"github.com/dhax/go-base/logging"
    23  )
    24  
    25  var (
    26  	auth      *Resource
    27  	authStore MockAuthStore
    28  	mailer    email.MockMailer
    29  	ts        *httptest.Server
    30  )
    31  
    32  func TestMain(m *testing.M) {
    33  	viper.SetDefault("auth_login_token_length", 8)
    34  	viper.SetDefault("auth_login_token_expiry", "11m")
    35  	viper.SetDefault("auth_jwt_secret", "random")
    36  	viper.SetDefault("log_level", "error")
    37  
    38  	var err error
    39  	auth, err = NewResource(&authStore, &mailer)
    40  	if err != nil {
    41  		fmt.Println(err)
    42  		os.Exit(1)
    43  	}
    44  
    45  	r := chi.NewRouter()
    46  	r.Use(logging.NewStructuredLogger(logging.NewLogger()))
    47  	r.Mount("/", auth.Router())
    48  
    49  	ts = httptest.NewServer(r)
    50  
    51  	code := m.Run()
    52  	ts.Close()
    53  	os.Exit(code)
    54  }
    55  
    56  func TestAuthResource_login(t *testing.T) {
    57  	authStore.GetAccountByEmailFn = func(email string) (*Account, error) {
    58  		var err error
    59  		a := Account{
    60  			ID:    1,
    61  			Email: email,
    62  			Name:  "test",
    63  		}
    64  
    65  		switch email {
    66  		case "not@exists.io":
    67  			err = errors.New("sql no row")
    68  		case "disabled@account.io":
    69  			a.Active = false
    70  		case "valid@account.io":
    71  			a.Active = true
    72  		}
    73  		return &a, err
    74  	}
    75  
    76  	mailer.LoginTokenFn = func(n, e string, c email.ContentLoginToken) error {
    77  		return nil
    78  	}
    79  
    80  	tests := []struct {
    81  		name   string
    82  		email  string
    83  		status int
    84  		err    error
    85  	}{
    86  		{"missing", "", http.StatusUnauthorized, ErrInvalidLogin},
    87  		{"inexistent", "not@exists.io", http.StatusUnauthorized, ErrUnknownLogin},
    88  		{"disabled", "disabled@account.io", http.StatusUnauthorized, ErrLoginDisabled},
    89  		{"valid", "valid@account.io", http.StatusOK, nil},
    90  	}
    91  
    92  	for _, tc := range tests {
    93  		t.Run(tc.name, func(t *testing.T) {
    94  			req, err := encode(&loginRequest{Email: tc.email})
    95  			if err != nil {
    96  				t.Fatal("failed to encode request body")
    97  			}
    98  			res, body := testRequest(t, ts, "POST", "/login", req, "")
    99  
   100  			if res.StatusCode != tc.status {
   101  				t.Errorf("got http status %d, want: %d", res.StatusCode, tc.status)
   102  			}
   103  			if tc.err != nil && !strings.Contains(body, tc.err.Error()) {
   104  				t.Errorf(" got: %s, expected to contain: %s", body, tc.err.Error())
   105  			}
   106  			if tc.err == ErrInvalidLogin && authStore.GetAccountByEmailInvoked {
   107  				t.Error("GetByLoginToken invoked for invalid email")
   108  			}
   109  			if tc.err == nil && !mailer.LoginTokenInvoked {
   110  				t.Error("emailService.LoginToken not invoked")
   111  			}
   112  			authStore.GetAccountByEmailInvoked = false
   113  			mailer.LoginTokenInvoked = false
   114  		})
   115  	}
   116  }
   117  
   118  func TestAuthResource_token(t *testing.T) {
   119  	authStore.GetAccountFn = func(id int) (*Account, error) {
   120  		var err error
   121  		a := Account{
   122  			ID:     id,
   123  			Active: true,
   124  			Name:   "test",
   125  		}
   126  		switch id {
   127  		case 2:
   128  			a.Active = false
   129  		case 3:
   130  			// unmodified
   131  		default:
   132  			err = errors.New("sql no rows")
   133  		}
   134  		return &a, err
   135  	}
   136  	authStore.UpdateAccountFn = func(a *Account) error {
   137  		a.LastLogin = time.Now()
   138  		return nil
   139  	}
   140  	authStore.CreateOrUpdateTokenFn = func(a *jwt.Token) error {
   141  		return nil
   142  	}
   143  
   144  	tests := []struct {
   145  		name   string
   146  		token  string
   147  		id     int
   148  		status int
   149  		err    error
   150  	}{
   151  		{"invalid", "#ยง$%", 0, http.StatusUnauthorized, ErrLoginToken},
   152  		{"expired", "12345678", 0, http.StatusUnauthorized, ErrLoginToken},
   153  		{"deleted_account", "", 1, http.StatusUnauthorized, ErrUnknownLogin},
   154  		{"disabled", "", 2, http.StatusUnauthorized, ErrLoginDisabled},
   155  		{"valid", "", 3, http.StatusOK, nil},
   156  	}
   157  
   158  	for _, tc := range tests {
   159  		t.Run(tc.name, func(t *testing.T) {
   160  			token := auth.LoginAuth.CreateToken(tc.id)
   161  			if tc.token != "" {
   162  				token.Token = tc.token
   163  			}
   164  
   165  			req, err := encode(tokenRequest{Token: token.Token})
   166  			if err != nil {
   167  				t.Fatal("failed to encode request body")
   168  			}
   169  			res, body := testRequest(t, ts, "POST", "/token", req, "")
   170  
   171  			if res.StatusCode != tc.status {
   172  				t.Errorf("got http status %d, want: %d", res.StatusCode, tc.status)
   173  			}
   174  			if tc.err != nil && !strings.Contains(body, tc.err.Error()) {
   175  				t.Errorf("got: %s, expected to contain: %s", body, tc.err.Error())
   176  			}
   177  			if tc.err == ErrLoginToken && authStore.CreateOrUpdateTokenInvoked {
   178  				t.Errorf("CreateOrUpdate invoked despite error %s", tc.err.Error())
   179  			}
   180  			if tc.err == nil && !authStore.CreateOrUpdateTokenInvoked {
   181  				t.Error("CreateOrUpdate not invoked")
   182  			}
   183  			authStore.CreateOrUpdateTokenInvoked = false
   184  		})
   185  	}
   186  }
   187  
   188  func TestAuthResource_refresh(t *testing.T) {
   189  	authStore.GetAccountFn = func(id int) (*Account, error) {
   190  		a := Account{
   191  			Active: true,
   192  			Name:   "Test",
   193  		}
   194  		switch id {
   195  		case 999:
   196  			a.Active = false
   197  		}
   198  		return &a, nil
   199  	}
   200  	authStore.UpdateAccountFn = func(a *Account) error {
   201  		a.LastLogin = time.Now()
   202  		return nil
   203  	}
   204  
   205  	authStore.GetTokenFn = func(token string) (*jwt.Token, error) {
   206  		var err error
   207  		var t jwt.Token
   208  		t.Expiry = time.Now().Add(1 * time.Minute)
   209  
   210  		switch token {
   211  		case "not_found":
   212  			err = errors.New("sql no rows")
   213  		case "expired":
   214  			t.Expiry = time.Now().Add(-1 * time.Minute)
   215  		case "disabled":
   216  			t.AccountID = 999
   217  		}
   218  		return &t, err
   219  	}
   220  	authStore.CreateOrUpdateTokenFn = func(a *jwt.Token) error {
   221  		return nil
   222  	}
   223  	authStore.DeleteTokenFn = func(t *jwt.Token) error {
   224  		return nil
   225  	}
   226  
   227  	tests := []struct {
   228  		name   string
   229  		token  string
   230  		exp    time.Duration
   231  		status int
   232  		err    error
   233  	}{
   234  		{"not_found", "not_found", 1, http.StatusUnauthorized, jwt.ErrTokenExpired},
   235  		{"expired", "expired", -1, http.StatusUnauthorized, jwt.ErrTokenExpired},
   236  		{"disabled", "disabled", 1, http.StatusUnauthorized, ErrLoginDisabled},
   237  		{"valid", "valid", 1, http.StatusOK, nil},
   238  	}
   239  
   240  	for _, tc := range tests {
   241  		t.Run(tc.name, func(t *testing.T) {
   242  			// refreshJWT, err := auth.TokenAuth.CreateRefreshJWT(jwt.RefreshClaims{Token: tc.token})
   243  			// if err != nil {
   244  			// 	t.Errorf("failed to create refresh jwt")
   245  			// }
   246  			refreshJWT := genRefreshJWT(jwt.RefreshClaims{
   247  				Token: tc.token,
   248  				CommonClaims: jwt.CommonClaims{
   249  					ExpiresAt: time.Now().Add(time.Minute * tc.exp).UnixNano(),
   250  				},
   251  			})
   252  
   253  			res, body := testRequest(t, ts, "POST", "/refresh", nil, refreshJWT)
   254  			if res.StatusCode != tc.status {
   255  				t.Errorf("got http status %d, want: %d", res.StatusCode, tc.status)
   256  			}
   257  			if tc.err != nil && !strings.Contains(body, tc.err.Error()) {
   258  				t.Errorf("got: %s, expected error to contain: %s", body, tc.err.Error())
   259  			}
   260  			if tc.status == http.StatusUnauthorized && authStore.CreateOrUpdateTokenInvoked {
   261  				t.Errorf("CreateOrUpdate invoked for status %d", tc.status)
   262  			}
   263  			if tc.status == http.StatusOK {
   264  				if !authStore.GetTokenInvoked {
   265  					t.Errorf("GetByToken not invoked")
   266  				}
   267  				if !authStore.CreateOrUpdateTokenInvoked {
   268  					t.Errorf("CreateOrUpdate not invoked")
   269  				}
   270  				if authStore.DeleteTokenInvoked {
   271  					t.Errorf("Delete should not be invoked")
   272  				}
   273  			}
   274  			authStore.GetTokenInvoked = false
   275  			authStore.CreateOrUpdateTokenInvoked = false
   276  			authStore.DeleteTokenInvoked = false
   277  		})
   278  	}
   279  }
   280  
   281  func TestAuthResource_logout(t *testing.T) {
   282  	authStore.GetTokenFn = func(token string) (*jwt.Token, error) {
   283  		var err error
   284  		t := jwt.Token{
   285  			Expiry: time.Now().Add(1 * time.Minute),
   286  		}
   287  
   288  		switch token {
   289  		case "notfound":
   290  			err = errors.New("sql no rows")
   291  		}
   292  		return &t, err
   293  	}
   294  	authStore.DeleteTokenFn = func(a *jwt.Token) error {
   295  		return nil
   296  	}
   297  
   298  	tests := []struct {
   299  		name   string
   300  		token  string
   301  		exp    time.Duration
   302  		status int
   303  		err    error
   304  	}{
   305  		{"notfound", "notfound", 1, http.StatusUnauthorized, jwt.ErrTokenExpired},
   306  		{"expired", "valid", -1, http.StatusOK, nil},
   307  		{"valid", "valid", 1, http.StatusOK, nil},
   308  	}
   309  
   310  	for _, tc := range tests {
   311  		t.Run(tc.name, func(t *testing.T) {
   312  			refreshJWT := genRefreshJWT(jwt.RefreshClaims{
   313  				Token: tc.token,
   314  				CommonClaims: jwt.CommonClaims{
   315  					ExpiresAt: time.Now().Add(time.Minute * tc.exp).UnixNano(),
   316  				},
   317  			})
   318  
   319  			res, body := testRequest(t, ts, "POST", "/logout", nil, refreshJWT)
   320  			if res.StatusCode != tc.status {
   321  				t.Errorf("got http status %d, want: %d", res.StatusCode, tc.status)
   322  			}
   323  			if tc.err != nil && !strings.Contains(body, tc.err.Error()) {
   324  				t.Errorf("got: %x, expected error to contain %s", body, tc.err.Error())
   325  			}
   326  			if tc.status == http.StatusUnauthorized && authStore.DeleteTokenInvoked {
   327  				t.Errorf("Delete invoked for status %d", tc.status)
   328  			}
   329  			authStore.DeleteTokenInvoked = false
   330  		})
   331  	}
   332  }
   333  
   334  func testRequest(t *testing.T, ts *httptest.Server, method, path string, body io.Reader, token string) (*http.Response, string) {
   335  	req, err := http.NewRequest(method, ts.URL+path, body)
   336  	if err != nil {
   337  		t.Fatal(err)
   338  		return nil, ""
   339  	}
   340  	req.Header.Set("Content-Type", "application/json")
   341  	if token != "" {
   342  		req.Header.Set("Authorization", fmt.Sprintf("BEARER %s", token))
   343  	}
   344  
   345  	resp, err := http.DefaultClient.Do(req)
   346  	if err != nil {
   347  		t.Fatal(err)
   348  		return nil, ""
   349  	}
   350  	defer resp.Body.Close()
   351  
   352  	respBody, err := ioutil.ReadAll(resp.Body)
   353  	if err != nil {
   354  		t.Fatal(err)
   355  		return nil, ""
   356  	}
   357  
   358  	return resp, string(respBody)
   359  }
   360  
   361  // func genJWT(c jwt.AppClaims) string {
   362  // 	claims, _ := jwt.ParseStructToMap(c)
   363  
   364  // 	_, tokenString, _ := auth.TokenAuth.JwtAuth.Encode(claims)
   365  // 	return tokenString
   366  // }
   367  
   368  func genRefreshJWT(c jwt.RefreshClaims) string {
   369  	claims, _ := jwt.ParseStructToMap(c)
   370  
   371  	_, tokenString, _ := auth.TokenAuth.JwtAuth.Encode(claims)
   372  	return tokenString
   373  }
   374  
   375  func encode(v interface{}) (*bytes.Buffer, error) {
   376  	data := new(bytes.Buffer)
   377  	err := json.NewEncoder(data).Encode(v)
   378  	return data, err
   379  }