github.com/cozy/cozy-stack@v0.0.0-20240603063001-31110fa4cae1/model/oauth/client_test.go (about)

     1  package oauth_test
     2  
     3  import (
     4  	"encoding/json"
     5  	"testing"
     6  	"time"
     7  
     8  	"github.com/cozy/cozy-stack/model/instance"
     9  	"github.com/cozy/cozy-stack/model/job"
    10  	"github.com/cozy/cozy-stack/model/oauth"
    11  	"github.com/cozy/cozy-stack/pkg/config/config"
    12  	"github.com/cozy/cozy-stack/pkg/consts"
    13  	"github.com/cozy/cozy-stack/pkg/couchdb"
    14  	"github.com/cozy/cozy-stack/pkg/couchdb/mango"
    15  	"github.com/cozy/cozy-stack/pkg/metadata"
    16  	"github.com/cozy/cozy-stack/tests/testutils"
    17  	jwt "github.com/golang-jwt/jwt/v5"
    18  	"github.com/stretchr/testify/assert"
    19  	"github.com/stretchr/testify/require"
    20  
    21  	_ "github.com/cozy/cozy-stack/model/notification/center"
    22  	_ "github.com/cozy/cozy-stack/worker/mails"
    23  )
    24  
    25  var c = &oauth.Client{
    26  	CouchID: "my-client-id",
    27  }
    28  
    29  func TestClient(t *testing.T) {
    30  	if testing.Short() {
    31  		t.Skip("an instance is required for this test: test skipped due to the use of --short flag")
    32  	}
    33  
    34  	config.UseTestFile(t)
    35  	conf := config.GetConfig()
    36  	conf.Contexts[config.DefaultInstanceContext] = map[string]interface{}{"manager_url": "http://manager.example.org"}
    37  	setup := testutils.NewSetup(t, t.Name())
    38  	testInstance := setup.GetTestInstance()
    39  
    40  	t.Run("CreateJWT", func(t *testing.T) {
    41  		tokenString, err := c.CreateJWT(testInstance, "test", "foo:read")
    42  		assert.NoError(t, err)
    43  
    44  		token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
    45  			_, ok := token.Method.(*jwt.SigningMethodHMAC)
    46  			assert.True(t, ok, "The signing method should be HMAC")
    47  			return testInstance.OAuthSecret, nil
    48  		})
    49  		assert.NoError(t, err)
    50  		assert.True(t, token.Valid)
    51  
    52  		claims, ok := token.Claims.(jwt.MapClaims)
    53  		assert.True(t, ok, "Claims can be parsed as standard claims")
    54  		assert.Equal(t, []interface{}{"test"}, claims["aud"])
    55  		assert.Equal(t, testInstance.Domain, claims["iss"])
    56  		assert.Equal(t, "my-client-id", claims["sub"])
    57  		assert.Equal(t, "foo:read", claims["scope"])
    58  	})
    59  
    60  	t.Run("ParseJWT", func(t *testing.T) {
    61  		tokenString, err := c.CreateJWT(testInstance, "refresh", "foo:read")
    62  		assert.NoError(t, err)
    63  
    64  		claims, ok := c.ValidToken(testInstance, consts.RefreshTokenAudience, tokenString)
    65  		assert.True(t, ok, "The token must be valid")
    66  		assert.Equal(t, jwt.ClaimStrings{"refresh"}, claims.Audience)
    67  		assert.Equal(t, testInstance.Domain, claims.Issuer)
    68  		assert.Equal(t, "my-client-id", claims.Subject)
    69  		assert.Equal(t, "foo:read", claims.Scope)
    70  	})
    71  
    72  	t.Run("ParseJWTInvalidAudience", func(t *testing.T) {
    73  		tokenString, err := c.CreateJWT(testInstance, "access", "foo:read")
    74  		assert.NoError(t, err)
    75  		_, ok := c.ValidToken(testInstance, consts.RefreshTokenAudience, tokenString)
    76  		assert.False(t, ok, "The token should be invalid")
    77  	})
    78  
    79  	t.Run("CreateClient", func(t *testing.T) {
    80  		client := &oauth.Client{
    81  			ClientName:   "foo",
    82  			RedirectURIs: []string{"https://foobar"},
    83  			SoftwareID:   "bar",
    84  
    85  			NotificationPlatform:    "android",
    86  			NotificationDeviceToken: "foobar",
    87  		}
    88  		assert.Nil(t, client.Create(testInstance))
    89  
    90  		client2 := &oauth.Client{
    91  			ClientName:   "foo",
    92  			RedirectURIs: []string{"https://foobar"},
    93  			SoftwareID:   "bar",
    94  
    95  			NotificationPlatform:    "ios",
    96  			NotificationDeviceToken: "foobar",
    97  		}
    98  		assert.Nil(t, client2.Create(testInstance))
    99  
   100  		client3 := &oauth.Client{
   101  			ClientName:   "foo",
   102  			RedirectURIs: []string{"https://foobar"},
   103  			SoftwareID:   "bar",
   104  		}
   105  		assert.Nil(t, client3.Create(testInstance))
   106  
   107  		client4 := &oauth.Client{
   108  			ClientName:   "foo-2",
   109  			RedirectURIs: []string{"https://foobar"},
   110  			SoftwareID:   "bar",
   111  		}
   112  		assert.Nil(t, client4.Create(testInstance))
   113  
   114  		assert.Equal(t, "foo", client.ClientName)
   115  		assert.Equal(t, "foo-2", client2.ClientName)
   116  		assert.Equal(t, "foo-3", client3.ClientName)
   117  		assert.Equal(t, "foo-2-2", client4.ClientName)
   118  	})
   119  
   120  	t.Run("CreateClientWithNotifications", func(t *testing.T) {
   121  		goodClient := &oauth.Client{
   122  			ClientName:   "client-5",
   123  			RedirectURIs: []string{"https://foobar"},
   124  			SoftwareID:   "bar",
   125  		}
   126  		if !assert.Nil(t, goodClient.Create(testInstance)) {
   127  			return
   128  		}
   129  
   130  		{
   131  			var err error
   132  			goodClient, err = oauth.FindClient(testInstance, goodClient.ClientID)
   133  			require.NoError(t, err)
   134  		}
   135  
   136  		{
   137  			client := goodClient.Clone().(*oauth.Client)
   138  			client.NotificationPlatform = "android"
   139  			assert.Nil(t, client.Update(testInstance, goodClient))
   140  		}
   141  
   142  		{
   143  			client := goodClient.Clone().(*oauth.Client)
   144  			client.NotificationPlatform = "unknown"
   145  			assert.NotNil(t, client.Update(testInstance, goodClient))
   146  		}
   147  	})
   148  
   149  	t.Run("CreateClientWithClientsLimit", func(t *testing.T) {
   150  		var pending, notPending, notificationWithoutPremium, notificationWithPremium *oauth.Client
   151  		t.Cleanup(func() {
   152  			// Delete created clients
   153  			pending, err := oauth.FindClient(testInstance, pending.ClientID)
   154  			require.NoError(t, err)
   155  			require.Nil(t, pending.Delete(testInstance))
   156  
   157  			notPending, err := oauth.FindClient(testInstance, notPending.ClientID)
   158  			require.NoError(t, err)
   159  			require.Nil(t, notPending.Delete(testInstance))
   160  
   161  			notificationWithoutPremium, err := oauth.FindClient(testInstance, notificationWithoutPremium.ClientID)
   162  			require.NoError(t, err)
   163  			require.Nil(t, notificationWithoutPremium.Delete(testInstance))
   164  
   165  			notificationWithPremium, err := oauth.FindClient(testInstance, notificationWithPremium.ClientID)
   166  			require.NoError(t, err)
   167  			require.Nil(t, notificationWithPremium.Delete(testInstance))
   168  		})
   169  
   170  		pending = &oauth.Client{
   171  			ClientName:   "pending",
   172  			ClientKind:   "mobile",
   173  			RedirectURIs: []string{"https://foobar"},
   174  			SoftwareID:   "bar",
   175  		}
   176  		require.Nil(t, pending.Create(testInstance))
   177  		assertClientsLimitAlertMailWasNotSent(t, testInstance)
   178  
   179  		notPending = &oauth.Client{
   180  			ClientName:   "notPending",
   181  			ClientKind:   "mobile",
   182  			RedirectURIs: []string{"https://foobar"},
   183  			SoftwareID:   "bar",
   184  		}
   185  		require.Nil(t, notPending.Create(testInstance, oauth.NotPending))
   186  		assertClientsLimitAlertMailWasNotSent(t, testInstance)
   187  
   188  		testutils.WithFlag(t, testInstance, "cozy.oauthclients.max", float64(1))
   189  
   190  		notificationWithoutPremium = &oauth.Client{
   191  			ClientName:   "notificationWithoutPremium",
   192  			ClientKind:   "mobile",
   193  			RedirectURIs: []string{"https://foobar"},
   194  			SoftwareID:   "bar",
   195  		}
   196  		require.Nil(t, notificationWithoutPremium.Create(testInstance, oauth.NotPending))
   197  		premiumLink := assertClientsLimitAlertMailWasSent(t, testInstance, "notificationWithoutPremium", 1)
   198  		assert.Empty(t, premiumLink)
   199  
   200  		testutils.WithManager(t, testInstance)
   201  
   202  		notificationWithPremium = &oauth.Client{
   203  			ClientName:   "notificationWithPremium",
   204  			ClientKind:   "mobile",
   205  			RedirectURIs: []string{"https://foobar"},
   206  			SoftwareID:   "bar",
   207  		}
   208  		require.Nil(t, notificationWithPremium.Create(testInstance, oauth.NotPending))
   209  		premiumLink = assertClientsLimitAlertMailWasSent(t, testInstance, "notificationWithPremium", 1)
   210  		assert.NotEmpty(t, premiumLink)
   211  	})
   212  
   213  	t.Run("GetConnectedUserClients", func(t *testing.T) {
   214  		browser := &oauth.Client{
   215  			ClientName:   "browser",
   216  			ClientKind:   "browser",
   217  			RedirectURIs: []string{"https://foobar"},
   218  			SoftwareID:   "bar",
   219  		}
   220  		require.Nil(t, browser.Create(testInstance, oauth.NotPending))
   221  
   222  		desktop := &oauth.Client{
   223  			ClientName:   "desktop",
   224  			ClientKind:   "desktop",
   225  			RedirectURIs: []string{"https://foobar"},
   226  			SoftwareID:   "bar",
   227  		}
   228  		require.Nil(t, desktop.Create(testInstance, oauth.NotPending))
   229  
   230  		mobile := &oauth.Client{
   231  			ClientName:   "mobile",
   232  			ClientKind:   "mobile",
   233  			RedirectURIs: []string{"https://foobar"},
   234  			SoftwareID:   "bar",
   235  		}
   236  		require.Nil(t, mobile.Create(testInstance, oauth.NotPending))
   237  
   238  		pending := &oauth.Client{
   239  			ClientName:   "pending",
   240  			ClientKind:   "desktop",
   241  			RedirectURIs: []string{"https://foobar"},
   242  			SoftwareID:   "bar",
   243  		}
   244  		require.Nil(t, pending.Create(testInstance))
   245  
   246  		sharing := &oauth.Client{
   247  			ClientName:   "sharing",
   248  			ClientKind:   "sharing",
   249  			RedirectURIs: []string{"https://foobar"},
   250  			SoftwareID:   "bar",
   251  		}
   252  		require.Nil(t, sharing.Create(testInstance, oauth.NotPending))
   253  
   254  		incomplete := &oauth.Client{
   255  			ClientName:   "incomplete",
   256  			RedirectURIs: []string{"https://foobar"},
   257  			SoftwareID:   "bar",
   258  		}
   259  		require.Nil(t, incomplete.Create(testInstance, oauth.NotPending))
   260  
   261  		clients, _, err := oauth.GetConnectedUserClients(testInstance, 100, "")
   262  		require.NoError(t, err)
   263  
   264  		assert.Len(t, clients, 3)
   265  		assert.Equal(t, clients[0].ClientName, "browser")
   266  		assert.Equal(t, clients[1].ClientName, "desktop")
   267  		assert.Equal(t, clients[2].ClientName, "mobile")
   268  	})
   269  
   270  	t.Run("ParseJWTInvalidIssuer", func(t *testing.T) {
   271  		other := &instance.Instance{
   272  			OAuthSecret: testInstance.OAuthSecret,
   273  			Domain:      "other.example.com",
   274  		}
   275  		tokenString, err := c.CreateJWT(other, "refresh", "foo:read")
   276  		assert.NoError(t, err)
   277  		_, ok := c.ValidToken(testInstance, consts.RefreshTokenAudience, tokenString)
   278  		assert.False(t, ok, "The token should be invalid")
   279  	})
   280  
   281  	t.Run("ParseJWTInvalidSubject", func(t *testing.T) {
   282  		other := &oauth.Client{
   283  			CouchID: "my-other-client",
   284  		}
   285  		tokenString, err := other.CreateJWT(testInstance, "refresh", "foo:read")
   286  		assert.NoError(t, err)
   287  		_, ok := c.ValidToken(testInstance, consts.RefreshTokenAudience, tokenString)
   288  		assert.False(t, ok, "The token should be invalid")
   289  	})
   290  
   291  	t.Run("ParseGoodSoftwareID", func(t *testing.T) {
   292  		goodClient := &oauth.Client{
   293  			ClientName:   "client-5",
   294  			RedirectURIs: []string{"https://foobar"},
   295  			SoftwareID:   "registry://drive",
   296  		}
   297  		err := goodClient.CheckSoftwareID(testInstance)
   298  		assert.Nil(t, err)
   299  	})
   300  
   301  	t.Run("ParseHttpSoftwareID", func(t *testing.T) {
   302  		goodClient := &oauth.Client{
   303  			ClientName:   "client-5",
   304  			RedirectURIs: []string{"https://foobar"},
   305  			SoftwareID:   "https://github.com/cozy-labs/cozy-desktop",
   306  		}
   307  		err := goodClient.CheckSoftwareID(testInstance)
   308  		assert.Nil(t, err)
   309  	})
   310  
   311  	t.Run("SortCLientsByCreatedAtDesc", func(t *testing.T) {
   312  		t0 := time.Now().Add(-1 * time.Minute)
   313  		t1 := t0.Add(10 * time.Second)
   314  		t2 := t1.Add(10 * time.Second)
   315  		clients := []*oauth.Client{
   316  			{CouchID: "a", Metadata: &metadata.CozyMetadata{CreatedAt: t2}},
   317  			{CouchID: "d"},
   318  			{CouchID: "c", Metadata: &metadata.CozyMetadata{CreatedAt: t0}},
   319  			{CouchID: "e"},
   320  			{CouchID: "b", Metadata: &metadata.CozyMetadata{CreatedAt: t1}},
   321  		}
   322  		oauth.SortClientsByCreatedAtDesc(clients)
   323  		require.Len(t, clients, 5)
   324  		assert.Equal(t, "a", clients[0].CouchID)
   325  		assert.Equal(t, "b", clients[1].CouchID)
   326  		assert.Equal(t, "c", clients[2].CouchID)
   327  		assert.Equal(t, "d", clients[3].CouchID)
   328  		assert.Equal(t, "e", clients[4].CouchID)
   329  	})
   330  
   331  	t.Run("CheckOAuthClientsLimitReached", func(t *testing.T) {
   332  		require.NoError(t, couchdb.ResetDB(testInstance, consts.OAuthClients))
   333  
   334  		// Create the OAuth client for the flagship app
   335  		flagship := oauth.Client{
   336  			RedirectURIs: []string{"cozy://flagship"},
   337  			ClientName:   "flagship-app",
   338  			ClientKind:   "mobile",
   339  			SoftwareID:   "github.com/cozy/cozy-stack/testing/flagship",
   340  			Flagship:     true,
   341  		}
   342  		require.Nil(t, flagship.Create(testInstance, oauth.NotPending))
   343  
   344  		clients, _, err := oauth.GetConnectedUserClients(testInstance, 100, "")
   345  		require.NoError(t, err)
   346  		require.Equal(t, len(clients), 1)
   347  
   348  		var reached, exceeded bool
   349  
   350  		reached, exceeded = oauth.CheckOAuthClientsLimitReached(testInstance, 0)
   351  		require.True(t, reached)
   352  		require.True(t, exceeded)
   353  
   354  		reached, exceeded = oauth.CheckOAuthClientsLimitReached(testInstance, 1)
   355  		require.True(t, reached)
   356  		require.False(t, exceeded)
   357  
   358  		reached, exceeded = oauth.CheckOAuthClientsLimitReached(testInstance, 2)
   359  		require.False(t, reached)
   360  		require.False(t, exceeded)
   361  
   362  		reached, exceeded = oauth.CheckOAuthClientsLimitReached(testInstance, -1)
   363  		require.False(t, reached)
   364  		require.False(t, exceeded)
   365  	})
   366  
   367  	t.Run("checkPlayIntegrityAttestation", func(t *testing.T) {
   368  		config := config.GetConfig()
   369  		config.Flagship.PlayIntegrityDecryptionKeys = []string{"bVcBAv0eO64NKIvDoRHpnTOZVxAkhMuFwRHrTEMr23U="}
   370  		config.Flagship.PlayIntegrityVerificationKeys = []string{"MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAElTF2uARN7oxfoDWyERYMe6QutI2NqS+CAtVmsPDIRjBBxF96fYojFVXRRsMb86PjkE21Ol+sO1YuspY+YuDRMw=="}
   371  
   372  		req := oauth.AttestationRequest{
   373  			Platform:    "android",
   374  			Issuer:      "playintegrity",
   375  			Challenge:   "testtesttesttesttesttesttest",
   376  			Attestation: "eyJhbGciOiJBMjU2S1ciLCJlbmMiOiJBMjU2R0NNIn0.MZIbC3rTckzCtg4rdAatQObifb3hkgJSq7-_XYTLItiCjkOyEjORlQ.-Z-6QJyEx4Bf4fNp.4vFq2XQvgQESouc5fF-oSixpYwWL2FBDzHfw1ay8nHmCXAgYfJ1yRPJm09dvJWJ5Iez4-HvfRWkwstZ4gtGYr4SX42h7L0vWkcv8yJ-12X9kUAFM_7ylpBLWiDEHnd0SeqpSeiAut_XXD81A_SncaenicMzDi0QKqeD6bdAkY67h46hnuyektYU4AsK9nVRPStaEfNiREJ017PuRVP3JQZVk4vAvg0jMfdY3BnaQ3AiEMb6uredrgP29gIIs0mGwcvc7ONyVRZ4_gSDSmfqKBjG-7HuC_rmC9CL2cUoz_JRxY0njvJi7isyfoTVZMyI4TKbUQckTKvv1Ysv11FxlTVsQqmOkVKtHOemS-G9ji23rq-LcGHG1DyriNqd3aFjMD6s1p5tFpxg7Eyc3pEm4f1Ig4S-sOC6BsTjqM_cNyqCuNbfwtQSE1pnh7yI7pcsfLPRisoODng0wTYXAqA4mvATf60eKSrPGb6vD47owlV-CbxLkG3PpVhjIpLIGknFSJnkzeIdgTR5XWUsQKVJ6ppW4mq8tO_C4KNHNISKimUhmFekG1w1rZ_suAvaC5Oz6NKn4iVMXpNm3N8nuBCkwbenN_A7334rSMHS12Ye1QRiH54VuUksUmzeUiFxaubkEJGVHwxYDN_lwQZ7bzSZbMfW46_-rK98SC3JNkif4Ucdl52fWY8Mpaf41PYGv6H7QAnY94wkAZGJPmaCzicDs5UbAiCI.fqVFSJEaY7GiqCga4-CMuw",
   377  		}
   378  		require.NoError(t, oauth.CheckPlayIntegrityAttestationForTestingPurpose(req))
   379  	})
   380  }
   381  
   382  func assertClientsLimitAlertMailWasNotSent(t *testing.T, instance *instance.Instance) {
   383  	var jobs []job.Job
   384  	couchReq := &couchdb.FindRequest{
   385  		UseIndex: "by-worker-and-state",
   386  		Selector: mango.And(
   387  			mango.Equal("worker", "sendmail"),
   388  			mango.Exists("state"),
   389  		),
   390  		Sort: mango.SortBy{
   391  			mango.SortByField{Field: "worker", Direction: "desc"},
   392  		},
   393  		Limit: 1,
   394  	}
   395  	err := couchdb.FindDocs(instance, consts.Jobs, couchReq, &jobs)
   396  	assert.NoError(t, err)
   397  
   398  	// Mail sent for the device connection
   399  	assert.Len(t, jobs, 0)
   400  }
   401  
   402  func assertClientsLimitAlertMailWasSent(t *testing.T, instance *instance.Instance, clientName string, clientsLimit int) string {
   403  	var jobs []job.Job
   404  	couchReq := &couchdb.FindRequest{
   405  		UseIndex: "by-worker-and-state",
   406  		Selector: mango.And(
   407  			mango.Equal("worker", "sendmail"),
   408  			mango.Exists("state"),
   409  		),
   410  		Sort: mango.SortBy{
   411  			mango.SortByField{Field: "worker", Direction: "desc"},
   412  		},
   413  		Limit: 1,
   414  	}
   415  	err := couchdb.FindDocs(instance, consts.Jobs, couchReq, &jobs)
   416  	assert.NoError(t, err)
   417  	assert.Len(t, jobs, 1)
   418  
   419  	var msg map[string]interface{}
   420  	err = json.Unmarshal(jobs[0].Message, &msg)
   421  	assert.NoError(t, err)
   422  
   423  	assert.Equal(t, msg["mode"], "noreply")
   424  	assert.Equal(t, msg["template_name"], "notifications_oauthclients")
   425  
   426  	values := msg["template_values"].(map[string]interface{})
   427  	assert.Equal(t, values["ClientName"], clientName)
   428  	assert.Equal(t, values["ClientsLimit"], float64(clientsLimit))
   429  
   430  	return values["OffersLink"].(string)
   431  }