github.com/volatiletech/authboss@v2.4.1+incompatible/otp/twofactor/sms2fa/sms_test.go (about)

     1  package sms2fa
     2  
     3  import (
     4  	"context"
     5  	"net/http"
     6  	"net/http/httptest"
     7  	"testing"
     8  
     9  	"github.com/volatiletech/authboss/otp/twofactor"
    10  	"golang.org/x/crypto/bcrypt"
    11  
    12  	"github.com/volatiletech/authboss"
    13  	"github.com/volatiletech/authboss/mocks"
    14  )
    15  
    16  type smsHolderSender string
    17  
    18  func (s *smsHolderSender) Send(ctx context.Context, number, text string) error {
    19  	*s = smsHolderSender(text)
    20  	return nil
    21  }
    22  
    23  func TestSMSSetup(t *testing.T) {
    24  	t.Parallel()
    25  
    26  	ab := authboss.New()
    27  	router := &mocks.Router{}
    28  	renderer := &mocks.Renderer{}
    29  	errHandler := &mocks.ErrorHandler{}
    30  
    31  	ab.Config.Core.Router = router
    32  	ab.Config.Core.ViewRenderer = renderer
    33  	ab.Config.Core.ErrorHandler = errHandler
    34  
    35  	sms := &SMS{Authboss: ab, Sender: new(smsHolderSender)}
    36  	if err := sms.Setup(); err != nil {
    37  		t.Fatal(err)
    38  	}
    39  
    40  	gets := []string{"/2fa/sms/setup", "/2fa/sms/confirm", "/2fa/sms/remove", "/2fa/sms/validate"}
    41  	posts := []string{"/2fa/sms/setup", "/2fa/sms/confirm", "/2fa/sms/remove", "/2fa/sms/validate"}
    42  	if err := router.HasGets(gets...); err != nil {
    43  		t.Error(err)
    44  	}
    45  	if err := router.HasPosts(posts...); err != nil {
    46  		t.Error(err)
    47  	}
    48  }
    49  
    50  type testHarness struct {
    51  	sms    *SMS
    52  	ab     *authboss.Authboss
    53  	sender *smsHolderSender
    54  
    55  	bodyReader *mocks.BodyReader
    56  	responder  *mocks.Responder
    57  	redirector *mocks.Redirector
    58  	session    *mocks.ClientStateRW
    59  	storer     *mocks.ServerStorer
    60  }
    61  
    62  func testSetup() *testHarness {
    63  	harness := &testHarness{}
    64  
    65  	harness.ab = authboss.New()
    66  	harness.bodyReader = &mocks.BodyReader{}
    67  	harness.redirector = &mocks.Redirector{}
    68  	harness.responder = &mocks.Responder{}
    69  	harness.session = mocks.NewClientRW()
    70  	harness.storer = mocks.NewServerStorer()
    71  
    72  	harness.ab.Config.Paths.AuthLoginOK = "/login/ok"
    73  
    74  	harness.ab.Config.Core.BodyReader = harness.bodyReader
    75  	harness.ab.Config.Core.Logger = mocks.Logger{}
    76  	harness.ab.Config.Core.Responder = harness.responder
    77  	harness.ab.Config.Core.Redirector = harness.redirector
    78  	harness.ab.Config.Storage.SessionState = harness.session
    79  	harness.ab.Config.Storage.Server = harness.storer
    80  
    81  	harness.sender = new(smsHolderSender)
    82  	harness.sms = &SMS{Authboss: harness.ab, Sender: harness.sender}
    83  
    84  	return harness
    85  }
    86  
    87  func (h *testHarness) loadClientState(w http.ResponseWriter, r **http.Request) {
    88  	req, err := h.ab.LoadClientState(w, *r)
    89  	if err != nil {
    90  		panic(err)
    91  	}
    92  
    93  	*r = req
    94  }
    95  
    96  func (h *testHarness) putUserInCtx(u *mocks.User, r **http.Request) {
    97  	req := (*r).WithContext(context.WithValue((*r).Context(), authboss.CTXKeyUser, u))
    98  	*r = req
    99  }
   100  
   101  func (h *testHarness) newHTTP(method string, bodyArgs ...string) (*http.Request, *authboss.ClientStateResponseWriter, *httptest.ResponseRecorder) {
   102  	r := mocks.Request(method, bodyArgs...)
   103  	wr := httptest.NewRecorder()
   104  	w := h.ab.NewResponse(wr)
   105  
   106  	return r, w, wr
   107  }
   108  
   109  func (h *testHarness) setSession(key, value string) {
   110  	h.session.ClientValues[key] = value
   111  }
   112  
   113  func TestHijackAuth(t *testing.T) {
   114  	t.Parallel()
   115  
   116  	t.Run("Handled", func(t *testing.T) {
   117  		harness := testSetup()
   118  
   119  		handled, err := harness.sms.HijackAuth(nil, nil, true)
   120  		if handled {
   121  			t.Error("should not be handled")
   122  		}
   123  		if err != nil {
   124  			t.Error(err)
   125  		}
   126  	})
   127  
   128  	t.Run("UserNoSMS", func(t *testing.T) {
   129  		harness := testSetup()
   130  
   131  		r, w, _ := harness.newHTTP("POST")
   132  		r.URL.RawQuery = "test=query"
   133  
   134  		user := &mocks.User{Email: "test@test.com"}
   135  		harness.putUserInCtx(user, &r)
   136  
   137  		harness.loadClientState(w, &r)
   138  		handled, err := harness.sms.HijackAuth(w, r, false)
   139  		if handled {
   140  			t.Error("should not be handled")
   141  		}
   142  		if err != nil {
   143  			t.Error(err)
   144  		}
   145  	})
   146  
   147  	t.Run("Ok", func(t *testing.T) {
   148  		harness := testSetup()
   149  
   150  		handled, err := harness.sms.HijackAuth(nil, nil, true)
   151  		if handled {
   152  			t.Error("should not be handled")
   153  		}
   154  		if err != nil {
   155  			t.Error(err)
   156  		}
   157  
   158  		r, w, _ := harness.newHTTP("POST")
   159  		r.URL.RawQuery = "test=query"
   160  
   161  		user := &mocks.User{Email: "test@test.com", SMSPhoneNumber: "number"}
   162  		harness.putUserInCtx(user, &r)
   163  		harness.loadClientState(w, &r)
   164  
   165  		handled, err = harness.sms.HijackAuth(w, r, false)
   166  		if !handled {
   167  			t.Error("should be handled")
   168  		}
   169  		if err != nil {
   170  			t.Error(err)
   171  		}
   172  
   173  		if len(*harness.sender) == 0 {
   174  			t.Error("a code should have been sent via sms")
   175  		}
   176  
   177  		if _, ok := harness.session.ClientValues[SessionSMSLast]; !ok {
   178  			t.Error("it should record the time it was last sent at")
   179  		}
   180  		if _, ok := harness.session.ClientValues[SessionSMSSecret]; !ok {
   181  			t.Error("there should be a code")
   182  		}
   183  
   184  		opts := harness.redirector.Options
   185  		if opts.Code != http.StatusTemporaryRedirect {
   186  			t.Error("status wrong:", opts.Code)
   187  		}
   188  
   189  		if opts.RedirectPath != "/auth/2fa/sms/validate?test=query" {
   190  			t.Error("redir path wrong:", opts.RedirectPath)
   191  		}
   192  	})
   193  }
   194  
   195  func TestSendCodeSuppression(t *testing.T) {
   196  	t.Parallel()
   197  
   198  	h := testSetup()
   199  	r, w, _ := h.newHTTP("POST")
   200  
   201  	if err := h.sms.SendCodeToUser(w, r, "pid", "phonenumber"); err != nil {
   202  		t.Error(err)
   203  	}
   204  
   205  	// Flush the session sets, reload the client state
   206  	w.WriteHeader(http.StatusOK)
   207  	h.loadClientState(w, &r)
   208  
   209  	// Send again within 10s, hopefully Go can execute that fast :D
   210  	if err := h.sms.SendCodeToUser(w, r, "pid", "phonenumber"); err == nil {
   211  		t.Error("should have errored")
   212  	} else if err != errSMSRateLimit {
   213  		t.Error("it should have blocked the second send")
   214  	}
   215  }
   216  
   217  func TestGetSetup(t *testing.T) {
   218  	t.Parallel()
   219  
   220  	h := testSetup()
   221  	r, w, _ := h.newHTTP("GET")
   222  
   223  	user := &mocks.User{Email: "test@test.com", SMSPhoneNumberSeed: "seednumber"}
   224  	h.storer.Users[user.Email] = user
   225  
   226  	h.setSession(authboss.SessionKey, user.Email)
   227  	h.setSession(SessionSMSSecret, "secret")
   228  	h.setSession(SessionSMSNumber, "number")
   229  	h.loadClientState(w, &r)
   230  
   231  	if err := h.sms.GetSetup(w, r); err != nil {
   232  		t.Error(err)
   233  	}
   234  
   235  	// Flush ClientState
   236  	w.WriteHeader(http.StatusOK)
   237  
   238  	if h.session.ClientValues[SessionSMSSecret] != "" {
   239  		t.Error("session sms secret should be cleared")
   240  	}
   241  	if h.session.ClientValues[SessionSMSNumber] != "" {
   242  		t.Error("session sms number should be cleared")
   243  	}
   244  
   245  	if h.responder.Page != PageSMSSetup {
   246  		t.Error("page wrong:", h.responder.Page)
   247  	}
   248  	if got := h.responder.Data[DataSMSPhoneNumber]; got != "seednumber" {
   249  		t.Error("data wrong:", got)
   250  	}
   251  }
   252  
   253  func TestPostSetup(t *testing.T) {
   254  	t.Parallel()
   255  
   256  	t.Run("NoPhoneNumber", func(t *testing.T) {
   257  		h := testSetup()
   258  		r, w, _ := h.newHTTP("POST")
   259  
   260  		user := &mocks.User{Email: "test@test.com"}
   261  		h.storer.Users[user.Email] = user
   262  		h.setSession(authboss.SessionKey, user.Email)
   263  		h.loadClientState(w, &r)
   264  
   265  		h.bodyReader.Return = mocks.Values{PhoneNumber: ""}
   266  
   267  		if err := h.sms.PostSetup(w, r); err != nil {
   268  			t.Error(err)
   269  		}
   270  
   271  		if h.responder.Page != PageSMSSetup {
   272  			t.Error("page wrong:", h.responder.Page)
   273  		}
   274  		validation := h.responder.Data[authboss.DataValidation].(map[string][]string)
   275  		if got := validation[FormValuePhoneNumber][0]; got != "must provide a phone number" {
   276  			t.Error("data wrong:", got)
   277  		}
   278  	})
   279  
   280  	t.Run("Ok", func(t *testing.T) {
   281  		h := testSetup()
   282  		r, w, _ := h.newHTTP("POST")
   283  
   284  		user := &mocks.User{Email: "test@test.com"}
   285  		h.storer.Users[user.Email] = user
   286  		h.setSession(authboss.SessionKey, user.Email)
   287  		h.loadClientState(w, &r)
   288  
   289  		h.bodyReader.Return = mocks.Values{PhoneNumber: "number"}
   290  
   291  		if err := h.sms.PostSetup(w, r); err != nil {
   292  			t.Error(err)
   293  		}
   294  
   295  		// Flush ClientState
   296  		w.WriteHeader(http.StatusOK)
   297  
   298  		if val := h.session.ClientValues[SessionSMSNumber]; val != "number" {
   299  			t.Error("session value wrong:", val)
   300  		}
   301  		if val := h.session.ClientValues[SessionSMSLast]; len(val) == 0 {
   302  			t.Error("session sms last should be set by send")
   303  		}
   304  
   305  		code := string(*h.sender)
   306  		if val := h.session.ClientValues[SessionSMSSecret]; val != code {
   307  			t.Error("the code should be stored in the session")
   308  		}
   309  
   310  		opts := h.redirector.Options
   311  		if opts.Code != http.StatusTemporaryRedirect {
   312  			t.Error("code was wrong:", opts.Code)
   313  		}
   314  		if opts.RedirectPath != "/auth/2fa/sms/confirm" {
   315  			t.Error("redirect path was wrong:", opts.RedirectPath)
   316  		}
   317  	})
   318  }
   319  
   320  func TestValidatorGet(t *testing.T) {
   321  	t.Parallel()
   322  
   323  	h := testSetup()
   324  	validator := &SMSValidator{SMS: h.sms, Page: PageSMSConfirm}
   325  
   326  	r, w, _ := h.newHTTP("GET")
   327  	if err := validator.Get(w, r); err != nil {
   328  		t.Fatal(err)
   329  	}
   330  
   331  	if h.responder.Page != PageSMSConfirm {
   332  		t.Error("page wrong:", h.responder.Page)
   333  	}
   334  }
   335  
   336  func TestValidatorPostSend(t *testing.T) {
   337  	t.Parallel()
   338  
   339  	h := testSetup()
   340  	validator := &SMSValidator{SMS: h.sms, Page: PageSMSValidate}
   341  
   342  	r, w, _ := h.newHTTP("POST")
   343  
   344  	user := &mocks.User{Email: "test@test.com", SMSPhoneNumber: "number"}
   345  	h.storer.Users[user.Email] = user
   346  	h.setSession(authboss.SessionKey, user.Email)
   347  	h.loadClientState(w, &r)
   348  	h.bodyReader.Return = mocks.Values{}
   349  
   350  	if err := validator.Post(w, r); err != nil {
   351  		t.Fatal(err)
   352  	}
   353  
   354  	if code := string(*h.sender); len(code) == 0 {
   355  		t.Error("should have sent a code")
   356  	}
   357  
   358  	*h.sender = ""
   359  
   360  	// When action is confirm, it retrieves the phone number from
   361  	// the session, not the user.
   362  	validator.Page = PageSMSConfirm
   363  	user.SMSPhoneNumber = ""
   364  	h.setSession(SessionSMSNumber, "number")
   365  	h.loadClientState(w, &r)
   366  
   367  	if err := validator.Post(w, r); err != nil {
   368  		t.Fatal(err)
   369  	}
   370  
   371  	if code := string(*h.sender); len(code) == 0 {
   372  		t.Error("should have sent a code")
   373  	}
   374  }
   375  
   376  func TestValidatorPostOk(t *testing.T) {
   377  	t.Parallel()
   378  
   379  	t.Run("OkConfirm", func(t *testing.T) {
   380  		h := testSetup()
   381  		r, w, _ := h.newHTTP("POST")
   382  		v := &SMSValidator{SMS: h.sms, Page: PageSMSConfirm}
   383  
   384  		user := &mocks.User{Email: "test@test.com"}
   385  		h.storer.Users[user.Email] = user
   386  		h.setSession(authboss.SessionKey, user.Email)
   387  
   388  		code := "code"
   389  		h.setSession(SessionSMSSecret, code)
   390  		h.setSession(SessionSMSNumber, "number")
   391  		h.bodyReader.Return = mocks.Values{Code: code}
   392  
   393  		h.loadClientState(w, &r)
   394  
   395  		if err := v.Post(w, r); err != nil {
   396  			t.Fatal(err)
   397  		}
   398  
   399  		// Flush client state
   400  		w.WriteHeader(http.StatusOK)
   401  
   402  		if h.responder.Page != PageSMSConfirmSuccess {
   403  			t.Error("page wrong:", h.responder.Page)
   404  		}
   405  		if got := h.responder.Data[twofactor.DataRecoveryCodes].([]string); len(got) == 0 {
   406  			t.Error("recovery codes should have been returned")
   407  		}
   408  
   409  		if h.session.ClientValues[SessionSMSNumber] != "" {
   410  			t.Error("session sms number should be cleared")
   411  		}
   412  		if h.session.ClientValues[SessionSMSSecret] != "" {
   413  			t.Error("session sms secret should be cleared")
   414  		}
   415  
   416  		if got := user.GetSMSPhoneNumber(); got != "number" {
   417  			t.Error("sms phone number was wrong:", got)
   418  		}
   419  		if len(user.GetRecoveryCodes()) == 0 {
   420  			t.Error("recovery codes should have been saved")
   421  		}
   422  	})
   423  
   424  	t.Run("OkRemoveWithRecovery", func(t *testing.T) {
   425  		h := testSetup()
   426  		r, w, _ := h.newHTTP("POST")
   427  		v := &SMSValidator{SMS: h.sms, Page: PageSMSRemove}
   428  
   429  		user := &mocks.User{Email: "test@test.com", SMSPhoneNumber: "number"}
   430  		h.storer.Users[user.Email] = user
   431  		h.setSession(authboss.SessionKey, user.Email)
   432  
   433  		codes, err := twofactor.GenerateRecoveryCodes()
   434  		if err != nil {
   435  			t.Fatal(err)
   436  		}
   437  		b, err := bcrypt.GenerateFromPassword([]byte(codes[0]), bcrypt.DefaultCost)
   438  		if err != nil {
   439  			t.Fatal(err)
   440  		}
   441  		user.RecoveryCodes = string(b)
   442  
   443  		h.setSession(SessionSMSSecret, "code-user-never-got")
   444  		h.bodyReader.Return = mocks.Values{Recovery: codes[0]}
   445  
   446  		h.loadClientState(w, &r)
   447  
   448  		if err := v.Post(w, r); err != nil {
   449  			t.Fatal(err)
   450  		}
   451  
   452  		// Flush client state
   453  		w.WriteHeader(http.StatusOK)
   454  
   455  		if h.responder.Page != PageSMSRemoveSuccess {
   456  			t.Error("page wrong:", h.responder.Page)
   457  		}
   458  
   459  		if h.session.ClientValues[authboss.Session2FA] != "" {
   460  			t.Error("session 2fa should be cleared")
   461  		}
   462  
   463  		if len(user.GetSMSPhoneNumber()) != 0 {
   464  			t.Error("sms phone number should be cleared")
   465  		}
   466  		if len(user.GetRecoveryCodes()) != 0 {
   467  			t.Error("last recovery code should have been used")
   468  		}
   469  	})
   470  
   471  	t.Run("OkValidateWithCode", func(t *testing.T) {
   472  		h := testSetup()
   473  		r, w, _ := h.newHTTP("POST")
   474  		v := &SMSValidator{SMS: h.sms, Page: PageSMSValidate}
   475  
   476  		user := &mocks.User{Email: "test@test.com", SMSPhoneNumber: "number"}
   477  		h.storer.Users[user.Email] = user
   478  		h.setSession(authboss.SessionKey, user.Email)
   479  
   480  		codes, err := twofactor.GenerateRecoveryCodes()
   481  		if err != nil {
   482  			t.Fatal(err)
   483  		}
   484  		b, err := bcrypt.GenerateFromPassword([]byte(codes[0]), bcrypt.DefaultCost)
   485  		if err != nil {
   486  			t.Fatal(err)
   487  		}
   488  		user.RecoveryCodes = string(b)
   489  
   490  		h.setSession(SessionSMSSecret, "code-user-never-got")
   491  		h.bodyReader.Return = mocks.Values{Recovery: codes[0]}
   492  
   493  		h.loadClientState(w, &r)
   494  
   495  		if err := v.Post(w, r); err != nil {
   496  			t.Fatal(err)
   497  		}
   498  
   499  		// Flush client state
   500  		w.WriteHeader(http.StatusOK)
   501  
   502  		opts := h.redirector.Options
   503  		if opts.Code != http.StatusTemporaryRedirect {
   504  			t.Error("code was wrong:", opts.Code)
   505  		}
   506  		if opts.RedirectPath != v.Paths.AuthLoginOK {
   507  			t.Error("path was wrong:", opts.RedirectPath)
   508  		}
   509  		if !opts.FollowRedirParam {
   510  			t.Error("redir param is not set")
   511  		}
   512  		if opts.Success != "Successfully Authenticated" {
   513  			t.Error("should have had a success message")
   514  		}
   515  
   516  		if pid := h.session.ClientValues[authboss.SessionKey]; pid != user.Email {
   517  			t.Error("session pid should be set:", pid)
   518  		}
   519  		if twofa := h.session.ClientValues[authboss.Session2FA]; twofa != "sms" {
   520  			t.Error("session 2fa should be sms:", twofa)
   521  		}
   522  
   523  		cleared := []string{SessionSMSSecret, SessionSMSPendingPID, authboss.SessionHalfAuthKey}
   524  		for _, c := range cleared {
   525  			if _, ok := h.session.ClientValues[c]; ok {
   526  				t.Error(c, "was not cleared")
   527  			}
   528  		}
   529  	})
   530  
   531  	t.Run("FailRemoveCode", func(t *testing.T) {
   532  		h := testSetup()
   533  		r, w, _ := h.newHTTP("POST")
   534  		v := &SMSValidator{SMS: h.sms, Page: PageSMSRemove}
   535  
   536  		user := &mocks.User{Email: "test@test.com"}
   537  		h.storer.Users[user.Email] = user
   538  		h.setSession(authboss.SessionKey, user.Email)
   539  
   540  		h.setSession(SessionSMSSecret, "code")
   541  		h.bodyReader.Return = mocks.Values{Code: "badcode"}
   542  
   543  		h.loadClientState(w, &r)
   544  
   545  		if err := v.Post(w, r); err != nil {
   546  			t.Fatal(err)
   547  		}
   548  
   549  		if h.responder.Page != PageSMSRemove {
   550  			t.Error("page wrong:", h.responder.Page)
   551  		}
   552  		validation := h.responder.Data[authboss.DataValidation].(map[string][]string)
   553  		if got := validation[FormValueCode][0]; got != "2fa code was invalid" {
   554  			t.Error("data wrong:", got)
   555  		}
   556  	})
   557  }