github.com/jxgolibs/go-oauth2-server@v1.0.1/oauth/authenticate_test.go (about)

     1  package oauth_test
     2  
     3  import (
     4  	"time"
     5  
     6  	"github.com/RichardKnop/go-oauth2-server/models"
     7  	"github.com/RichardKnop/go-oauth2-server/oauth"
     8  	"github.com/RichardKnop/go-oauth2-server/session"
     9  	"github.com/RichardKnop/uuid"
    10  	"github.com/jinzhu/gorm"
    11  	"github.com/stretchr/testify/assert"
    12  )
    13  
    14  func (suite *OauthTestSuite) TestAuthenticate() {
    15  	var (
    16  		accessToken *models.OauthAccessToken
    17  		err         error
    18  	)
    19  
    20  	// Insert some test access tokens
    21  	testAccessTokens := []*models.OauthAccessToken{
    22  		// Expired access token
    23  		{
    24  			MyGormModel: models.MyGormModel{
    25  				ID:        uuid.New(),
    26  				CreatedAt: time.Now().UTC(),
    27  			},
    28  			Token:     "test_expired_token",
    29  			ExpiresAt: time.Now().UTC().Add(-10 * time.Second),
    30  			Client:    suite.clients[0],
    31  			User:      suite.users[0],
    32  		},
    33  		// Access token without a user
    34  		{
    35  			MyGormModel: models.MyGormModel{
    36  				ID:        uuid.New(),
    37  				CreatedAt: time.Now().UTC(),
    38  			},
    39  			Token:     "test_client_token",
    40  			ExpiresAt: time.Now().UTC().Add(+10 * time.Second),
    41  			Client:    suite.clients[0],
    42  		},
    43  		// Access token with a user
    44  		{
    45  			MyGormModel: models.MyGormModel{
    46  				ID:        uuid.New(),
    47  				CreatedAt: time.Now().UTC(),
    48  			},
    49  			Token:     "test_user_token",
    50  			ExpiresAt: time.Now().UTC().Add(+10 * time.Second),
    51  			Client:    suite.clients[0],
    52  			User:      suite.users[0],
    53  		},
    54  	}
    55  	for _, testAccessToken := range testAccessTokens {
    56  		err := suite.db.Create(testAccessToken).Error
    57  		assert.NoError(suite.T(), err, "Inserting test data failed")
    58  	}
    59  
    60  	// Test passing an empty token
    61  	accessToken, err = suite.service.Authenticate("")
    62  
    63  	// Access token should be nil
    64  	assert.Nil(suite.T(), accessToken)
    65  
    66  	// Correct error should be returned
    67  	if assert.NotNil(suite.T(), err) {
    68  		assert.Equal(suite.T(), oauth.ErrAccessTokenNotFound, err)
    69  	}
    70  
    71  	// Test passing a bogus token
    72  	accessToken, err = suite.service.Authenticate("bogus")
    73  
    74  	// Access token should be nil
    75  	assert.Nil(suite.T(), accessToken)
    76  
    77  	// Correct error should be returned
    78  	if assert.NotNil(suite.T(), err) {
    79  		assert.Equal(suite.T(), oauth.ErrAccessTokenNotFound, err)
    80  	}
    81  
    82  	// Test passing an expired token
    83  	accessToken, err = suite.service.Authenticate("test_expired_token")
    84  
    85  	// Access token should be nil
    86  	assert.Nil(suite.T(), accessToken)
    87  
    88  	// Correct error should be returned
    89  	if assert.NotNil(suite.T(), err) {
    90  		assert.Equal(suite.T(), oauth.ErrAccessTokenExpired, err)
    91  	}
    92  
    93  	// Test passing a valid client token
    94  	accessToken, err = suite.service.Authenticate("test_client_token")
    95  
    96  	// Correct access token should be returned
    97  	if assert.NotNil(suite.T(), accessToken) {
    98  		assert.Equal(suite.T(), "test_client_token", accessToken.Token)
    99  		assert.EqualValues(suite.T(), suite.clients[0].ID, accessToken.ClientID.String)
   100  		assert.False(suite.T(), accessToken.UserID.Valid)
   101  	}
   102  
   103  	// Error should be nil
   104  	assert.Nil(suite.T(), err)
   105  
   106  	// Test passing a valid user token
   107  	accessToken, err = suite.service.Authenticate("test_user_token")
   108  
   109  	// Correct access token should be returned
   110  	if assert.NotNil(suite.T(), accessToken) {
   111  		assert.Equal(suite.T(), "test_user_token", accessToken.Token)
   112  		assert.EqualValues(suite.T(), suite.clients[0].ID, accessToken.ClientID.String)
   113  		assert.EqualValues(suite.T(), suite.users[0].ID, accessToken.UserID.String)
   114  	}
   115  
   116  	// Error should be nil
   117  	assert.Nil(suite.T(), err)
   118  }
   119  
   120  func (suite *OauthTestSuite) TestAuthenticateRollingRefreshToken() {
   121  	var (
   122  		testAccessTokens  []*models.OauthAccessToken
   123  		testRefreshTokens []*models.OauthRefreshToken
   124  		accessToken       *models.OauthAccessToken
   125  		err               error
   126  		refreshTokens     []*models.OauthRefreshToken
   127  	)
   128  
   129  	// Insert some test access tokens
   130  	testAccessTokens = []*models.OauthAccessToken{
   131  		{
   132  			MyGormModel: models.MyGormModel{
   133  				ID:        uuid.New(),
   134  				CreatedAt: time.Now().UTC(),
   135  			},
   136  			Token:     "test_token_1",
   137  			ExpiresAt: time.Now().UTC().Add(+10 * time.Second),
   138  			Client:    suite.clients[0],
   139  			User:      suite.users[0],
   140  		},
   141  		{
   142  			MyGormModel: models.MyGormModel{
   143  				ID:        uuid.New(),
   144  				CreatedAt: time.Now().UTC(),
   145  			},
   146  			Token:     "test_token_2",
   147  			ExpiresAt: time.Now().UTC().Add(+10 * time.Second),
   148  			Client:    suite.clients[0],
   149  		},
   150  		{
   151  			MyGormModel: models.MyGormModel{
   152  				ID:        uuid.New(),
   153  				CreatedAt: time.Now().UTC(),
   154  			},
   155  			Token:     "test_token_3",
   156  			ExpiresAt: time.Now().UTC().Add(+10 * time.Second),
   157  			Client:    suite.clients[0],
   158  			User:      suite.users[1],
   159  		},
   160  	}
   161  	for _, testAccessToken := range testAccessTokens {
   162  		err = suite.db.Create(testAccessToken).Error
   163  		assert.NoError(suite.T(), err, "Inserting test data failed")
   164  	}
   165  
   166  	// Insert some test access tokens
   167  	testRefreshTokens = []*models.OauthRefreshToken{
   168  		{
   169  			MyGormModel: models.MyGormModel{
   170  				ID:        uuid.New(),
   171  				CreatedAt: time.Now().UTC(),
   172  			},
   173  			Token:     "test_token_1",
   174  			ExpiresAt: time.Now().UTC().Add(+10 * time.Second),
   175  			Client:    suite.clients[0],
   176  			User:      suite.users[0],
   177  		},
   178  		{
   179  			MyGormModel: models.MyGormModel{
   180  				ID:        uuid.New(),
   181  				CreatedAt: time.Now().UTC(),
   182  			},
   183  			Token:     "test_token_2",
   184  			ExpiresAt: time.Now().UTC().Add(+10 * time.Second),
   185  			Client:    suite.clients[0],
   186  		},
   187  		{
   188  			MyGormModel: models.MyGormModel{
   189  				ID:        uuid.New(),
   190  				CreatedAt: time.Now().UTC(),
   191  			},
   192  			Token:     "test_token_3",
   193  			ExpiresAt: time.Now().UTC().Add(+10 * time.Second),
   194  			Client:    suite.clients[0],
   195  			User:      suite.users[1],
   196  		},
   197  	}
   198  	for _, testRefreshToken := range testRefreshTokens {
   199  		err = suite.db.Create(testRefreshToken).Error
   200  		assert.NoError(suite.T(), err, "Inserting test data failed")
   201  	}
   202  
   203  	// Authenticate with the first access token
   204  	now1 := time.Now().UTC()
   205  	gorm.NowFunc = func() time.Time {
   206  		return now1
   207  	}
   208  	accessToken, err = suite.service.Authenticate("test_token_1")
   209  	assert.Nil(suite.T(), err)
   210  	assert.Equal(suite.T(), "test_token_1", accessToken.Token)
   211  	assert.EqualValues(suite.T(), suite.clients[0].ID, accessToken.ClientID.String)
   212  	assert.EqualValues(suite.T(), suite.users[0].ID, accessToken.UserID.String)
   213  
   214  	// First refresh token expiration date should be extended
   215  	refreshTokens = make([]*models.OauthRefreshToken, len(testRefreshTokens))
   216  	err = suite.db.Where(
   217  		"token IN ('test_token_1', 'test_token_2', 'test_token_3')",
   218  	).Order("created_at").Find(&refreshTokens).Error
   219  	assert.Nil(suite.T(), err)
   220  	assert.Equal(suite.T(), "test_token_1", refreshTokens[0].Token)
   221  	assert.Equal(
   222  		suite.T(),
   223  		now1.Unix()+int64(suite.cnf.Oauth.RefreshTokenLifetime),
   224  		refreshTokens[0].ExpiresAt.Unix(),
   225  	)
   226  	assert.Equal(suite.T(), "test_token_2", refreshTokens[1].Token)
   227  	assert.Equal(
   228  		suite.T(),
   229  		testRefreshTokens[1].ExpiresAt.Unix(),
   230  		refreshTokens[1].ExpiresAt.Unix(),
   231  	)
   232  	assert.Equal(suite.T(), "test_token_3", refreshTokens[2].Token)
   233  	assert.Equal(
   234  		suite.T(),
   235  		testRefreshTokens[2].ExpiresAt.Unix(),
   236  		refreshTokens[2].ExpiresAt.Unix(),
   237  	)
   238  
   239  	// Authenticate with the second access token
   240  	now2 := time.Now().UTC()
   241  	gorm.NowFunc = func() time.Time {
   242  		return now2
   243  	}
   244  	accessToken, err = suite.service.Authenticate("test_token_2")
   245  	assert.Nil(suite.T(), err)
   246  	assert.Equal(suite.T(), "test_token_2", accessToken.Token)
   247  	assert.EqualValues(suite.T(), suite.clients[0].ID, accessToken.ClientID.String)
   248  	assert.False(suite.T(), accessToken.UserID.Valid)
   249  
   250  	// Second refresh token expiration date should be extended
   251  	refreshTokens = make([]*models.OauthRefreshToken, len(testRefreshTokens))
   252  	err = suite.db.Where(
   253  		"token IN ('test_token_1', 'test_token_2', 'test_token_3')",
   254  	).Order("created_at").Find(&refreshTokens).Error
   255  	assert.Nil(suite.T(), err)
   256  	assert.Equal(suite.T(), "test_token_1", refreshTokens[0].Token)
   257  	assert.Equal(
   258  		suite.T(),
   259  		now1.Unix()+int64(suite.cnf.Oauth.RefreshTokenLifetime),
   260  		refreshTokens[0].ExpiresAt.Unix(),
   261  	)
   262  	assert.Equal(suite.T(), "test_token_2", refreshTokens[1].Token)
   263  	assert.Equal(
   264  		suite.T(),
   265  		now2.Unix()+int64(suite.cnf.Oauth.RefreshTokenLifetime),
   266  		refreshTokens[1].ExpiresAt.Unix(),
   267  	)
   268  	assert.Equal(suite.T(), "test_token_3", refreshTokens[2].Token)
   269  	assert.Equal(
   270  		suite.T(),
   271  		testRefreshTokens[2].ExpiresAt.Unix(),
   272  		refreshTokens[2].ExpiresAt.Unix(),
   273  	)
   274  
   275  	// Authenticate with the third access token
   276  	now3 := time.Now().UTC()
   277  	gorm.NowFunc = func() time.Time {
   278  		return now3
   279  	}
   280  	accessToken, err = suite.service.Authenticate("test_token_3")
   281  	assert.Nil(suite.T(), err)
   282  	assert.Equal(suite.T(), "test_token_3", accessToken.Token)
   283  	assert.EqualValues(suite.T(), suite.clients[0].ID, accessToken.ClientID.String)
   284  	assert.EqualValues(suite.T(), suite.users[1].ID, accessToken.UserID.String)
   285  
   286  	// First refresh token expiration date should be extended
   287  	refreshTokens = make([]*models.OauthRefreshToken, len(testRefreshTokens))
   288  	err = suite.db.Where(
   289  		"token IN ('test_token_1', 'test_token_2', 'test_token_3')",
   290  	).Order("created_at").Find(&refreshTokens).Error
   291  	assert.Nil(suite.T(), err)
   292  	assert.Equal(suite.T(), "test_token_1", refreshTokens[0].Token)
   293  	assert.Equal(
   294  		suite.T(),
   295  		now1.Unix()+int64(suite.cnf.Oauth.RefreshTokenLifetime),
   296  		refreshTokens[0].ExpiresAt.Unix(),
   297  	)
   298  	assert.Equal(suite.T(), "test_token_2", refreshTokens[1].Token)
   299  	assert.Equal(
   300  		suite.T(),
   301  		now2.Unix()+int64(suite.cnf.Oauth.RefreshTokenLifetime),
   302  		refreshTokens[1].ExpiresAt.Unix(),
   303  	)
   304  	assert.Equal(suite.T(), "test_token_3", refreshTokens[2].Token)
   305  	assert.Equal(
   306  		suite.T(),
   307  		now3.Unix()+int64(suite.cnf.Oauth.RefreshTokenLifetime),
   308  		refreshTokens[2].ExpiresAt.Unix(),
   309  	)
   310  }
   311  
   312  func (suite *OauthTestSuite) TestClearUserTokens() {
   313  	var (
   314  		testAccessTokens  []*models.OauthAccessToken
   315  		testRefreshTokens []*models.OauthRefreshToken
   316  		err               error
   317  		testUserSession   *session.UserSession
   318  	)
   319  
   320  	// Insert some test access tokens
   321  	testAccessTokens = []*models.OauthAccessToken{
   322  		{
   323  			MyGormModel: models.MyGormModel{
   324  				ID:        uuid.New(),
   325  				CreatedAt: time.Now().UTC(),
   326  			},
   327  			Token:     "test_token_1",
   328  			ExpiresAt: time.Now().UTC().Add(+10 * time.Second),
   329  			Client:    suite.clients[0],
   330  			User:      suite.users[0],
   331  		},
   332  		{
   333  			MyGormModel: models.MyGormModel{
   334  				ID:        uuid.New(),
   335  				CreatedAt: time.Now().UTC(),
   336  			},
   337  			Token:     "test_token_2",
   338  			ExpiresAt: time.Now().UTC().Add(+10 * time.Second),
   339  			Client:    suite.clients[1],
   340  			User:      suite.users[0],
   341  		},
   342  		{
   343  			MyGormModel: models.MyGormModel{
   344  				ID:        uuid.New(),
   345  				CreatedAt: time.Now().UTC(),
   346  			},
   347  			Token:     "test_token_3",
   348  			ExpiresAt: time.Now().UTC().Add(+10 * time.Second),
   349  			Client:    suite.clients[0],
   350  			User:      suite.users[1],
   351  		},
   352  	}
   353  	for _, testAccessToken := range testAccessTokens {
   354  		err = suite.db.Create(testAccessToken).Error
   355  		assert.NoError(suite.T(), err, "Inserting test data failed")
   356  	}
   357  
   358  	// Insert some test access tokens
   359  	testRefreshTokens = []*models.OauthRefreshToken{
   360  		{
   361  			MyGormModel: models.MyGormModel{
   362  				ID:        uuid.New(),
   363  				CreatedAt: time.Now().UTC(),
   364  			},
   365  			Token:     "test_token_1",
   366  			ExpiresAt: time.Now().UTC().Add(+10 * time.Second),
   367  			Client:    suite.clients[0],
   368  			User:      suite.users[0],
   369  		},
   370  		{
   371  			MyGormModel: models.MyGormModel{
   372  				ID:        uuid.New(),
   373  				CreatedAt: time.Now().UTC(),
   374  			},
   375  			Token:     "test_token_2",
   376  			ExpiresAt: time.Now().UTC().Add(+10 * time.Second),
   377  			Client:    suite.clients[1],
   378  			User:      suite.users[0],
   379  		},
   380  		{
   381  			MyGormModel: models.MyGormModel{
   382  				ID:        uuid.New(),
   383  				CreatedAt: time.Now().UTC(),
   384  			},
   385  			Token:     "test_token_3",
   386  			ExpiresAt: time.Now().UTC().Add(+10 * time.Second),
   387  			Client:    suite.clients[0],
   388  			User:      suite.users[1],
   389  		},
   390  	}
   391  	for _, testRefreshToken := range testRefreshTokens {
   392  		err = suite.db.Create(testRefreshToken).Error
   393  		assert.NoError(suite.T(), err, "Inserting test data failed")
   394  	}
   395  
   396  	testUserSession = &session.UserSession{
   397  		ClientID:     suite.clients[0].Key,
   398  		Username:     suite.users[0].Username,
   399  		AccessToken:  "test_token_1",
   400  		RefreshToken: "test_token_1",
   401  	}
   402  
   403  	// Remove test_token_1 from accress and refresh token tables
   404  	suite.service.ClearUserTokens(testUserSession)
   405  
   406  	// Assert that the refresh token was removed
   407  	found := !models.OauthRefreshTokenPreload(suite.db).Where("token = ?", testUserSession.RefreshToken).First(&models.OauthRefreshToken{}).RecordNotFound()
   408  	assert.Equal(suite.T(), false, found)
   409  
   410  	// Assert that the access token was removed
   411  	found = !models.OauthAccessTokenPreload(suite.db).Where("token = ?", testUserSession.AccessToken).First(&models.OauthAccessToken{}).RecordNotFound()
   412  	assert.Equal(suite.T(), false, found)
   413  
   414  	// Assert that the other two tokens are still there
   415  	// Refresh tokens
   416  	found = !models.OauthRefreshTokenPreload(suite.db).Where("token = ?", "test_token_2").First(&models.OauthRefreshToken{}).RecordNotFound()
   417  	assert.Equal(suite.T(), true, found)
   418  	found = !models.OauthRefreshTokenPreload(suite.db).Where("token = ?", "test_token_3").First(&models.OauthRefreshToken{}).RecordNotFound()
   419  	assert.Equal(suite.T(), true, found)
   420  
   421  	// Access tokens
   422  	found = !models.OauthAccessTokenPreload(suite.db).Where("token = ?", "test_token_2").First(&models.OauthAccessToken{}).RecordNotFound()
   423  	assert.Equal(suite.T(), true, found)
   424  	found = !models.OauthAccessTokenPreload(suite.db).Where("token = ?", "test_token_3").First(&models.OauthAccessToken{}).RecordNotFound()
   425  	assert.Equal(suite.T(), true, found)
   426  
   427  }