go.chromium.org/luci@v0.0.0-20240309015107-7cdc2e660f33/server/encryptedcookies/method_test.go (about)

     1  // Copyright 2021 The LUCI Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package encryptedcookies
    16  
    17  import (
    18  	"context"
    19  	"crypto/rsa"
    20  	"encoding/base64"
    21  	"encoding/binary"
    22  	"encoding/json"
    23  	"errors"
    24  	"fmt"
    25  	"net/http"
    26  	"net/http/httptest"
    27  	"net/url"
    28  	"strings"
    29  	"sync/atomic"
    30  	"testing"
    31  	"time"
    32  
    33  	"github.com/google/tink/go/aead"
    34  	"github.com/google/tink/go/keyset"
    35  	"github.com/google/tink/go/tink"
    36  
    37  	"go.chromium.org/luci/auth/identity"
    38  	"go.chromium.org/luci/common/clock"
    39  	"go.chromium.org/luci/common/clock/testclock"
    40  	"go.chromium.org/luci/common/logging/gologger"
    41  	"go.chromium.org/luci/common/retry/transient"
    42  	"go.chromium.org/luci/gae/impl/memory"
    43  	"go.chromium.org/luci/server/auth"
    44  	"go.chromium.org/luci/server/auth/authtest"
    45  	"go.chromium.org/luci/server/auth/openid"
    46  	"go.chromium.org/luci/server/auth/signing/signingtest"
    47  	"go.chromium.org/luci/server/caching"
    48  	"go.chromium.org/luci/server/encryptedcookies/internal"
    49  	"go.chromium.org/luci/server/encryptedcookies/session/datastore"
    50  	"go.chromium.org/luci/server/router"
    51  
    52  	. "github.com/smartystreets/goconvey/convey"
    53  )
    54  
    55  func TestMethod(t *testing.T) {
    56  	t.Parallel()
    57  
    58  	Convey("With mocks", t, func(c C) {
    59  		// Instantiate ID provider with mocked time only. It doesn't need other
    60  		// context features.
    61  		ctx := context.Background()
    62  		ctx, tc := testclock.UseTime(ctx, testclock.TestRecentTimeUTC)
    63  		provider := openIDProviderFake{
    64  			ExpectedClientID:     "client_id",
    65  			ExpectedClientSecret: "client_secret",
    66  			ExpectedRedirectURI:  "https://primary.example.com/auth/openid/callback",
    67  			UserSub:              "user-sub",
    68  			UserEmail:            "someone@example.com",
    69  			UserName:             "Someone Example",
    70  			UserPicture:          "https://example.com/picture",
    71  			RefreshToken:         "good_refresh_token",
    72  		}
    73  		provider.Init(ctx, c)
    74  		defer provider.Close()
    75  
    76  		// Primary encryption keys used to encrypt cookies.
    77  		kh, err := keyset.NewHandle(aead.AES256GCMKeyTemplate())
    78  		So(err, ShouldBeNil)
    79  		ae, err := aead.New(kh)
    80  		So(err, ShouldBeNil)
    81  
    82  		// Install the rest of the context stuff used by AuthMethod.
    83  		ctx = memory.Use(ctx)
    84  		ctx = gologger.StdConfig.Use(ctx)
    85  		ctx = caching.WithEmptyProcessCache(ctx)
    86  		ctx = authtest.MockAuthConfig(ctx)
    87  
    88  		sessionStore := &datastore.Store{}
    89  
    90  		makeMethod := func(requiredScopes, optionalScopes []string) *AuthMethod {
    91  			return &AuthMethod{
    92  				OpenIDConfig: func(context.Context) (*OpenIDConfig, error) {
    93  					return &OpenIDConfig{
    94  						DiscoveryURL: provider.DiscoveryURL(),
    95  						ClientID:     provider.ExpectedClientID,
    96  						ClientSecret: provider.ExpectedClientSecret,
    97  						RedirectURI:  provider.ExpectedRedirectURI,
    98  					}, nil
    99  				},
   100  				AEADProvider:        func(context.Context) tink.AEAD { return ae },
   101  				Sessions:            sessionStore,
   102  				RequiredScopes:      requiredScopes,
   103  				OptionalScopes:      optionalScopes,
   104  				ExposeStateEndpoint: true,
   105  			}
   106  		}
   107  
   108  		methodV1 := makeMethod([]string{"scope1"}, []string{"scope2"})
   109  
   110  		call := func(h router.Handler, host string, url *url.URL, header http.Header) *http.Response {
   111  			rw := httptest.NewRecorder()
   112  			h(&router.Context{
   113  				Request: (&http.Request{
   114  					URL:    url,
   115  					Host:   host,
   116  					Header: header,
   117  				}).WithContext(ctx),
   118  				Writer: rw,
   119  			})
   120  			return rw.Result()
   121  		}
   122  
   123  		performLogin := func(method *AuthMethod) (callbackRawQuery string) {
   124  			So(method.Warmup(ctx), ShouldBeNil)
   125  
   126  			loginURL, err := method.LoginURL(ctx, "/some/dest")
   127  			So(err, ShouldBeNil)
   128  			So(loginURL, ShouldEqual, "/auth/openid/login?r=%2Fsome%2Fdest")
   129  			parsed, _ := url.Parse(loginURL)
   130  
   131  			// Hitting the generated login URL generates a redirect to the provider.
   132  			resp := call(method.loginHandler, "dest.example.com", parsed, nil)
   133  			So(resp.StatusCode, ShouldEqual, http.StatusFound)
   134  			authURL, _ := url.Parse(resp.Header.Get("Location"))
   135  			So(authURL.Host, ShouldEqual, provider.Host())
   136  			So(authURL.Path, ShouldEqual, "/authorization")
   137  
   138  			// After the user logs in, the provider generates a redirect to the
   139  			// callback URI with some query parameters.
   140  			return provider.CallbackRawQuery(authURL.Query())
   141  		}
   142  
   143  		performCallback := func(method *AuthMethod, callbackRawQuery string) (cookie *http.Cookie, deleted []*http.Cookie) {
   144  			// Provider calls us back on our primary host (not "dest.example.com").
   145  			resp := call(method.callbackHandler, "primary.example.com", &url.URL{
   146  				RawQuery: callbackRawQuery,
   147  			}, nil)
   148  
   149  			// We got a redirect to "dest.example.com".
   150  			So(resp.StatusCode, ShouldEqual, http.StatusFound)
   151  			So(resp.Header.Get("Location"), ShouldEqual, "https://dest.example.com?"+callbackRawQuery)
   152  
   153  			// Now hitting the same callback on "dest.example.com".
   154  			resp = call(method.callbackHandler, "dest.example.com", &url.URL{
   155  				RawQuery: callbackRawQuery,
   156  			}, nil)
   157  
   158  			// Got a redirect to the final destination URL.
   159  			So(resp.StatusCode, ShouldEqual, http.StatusFound)
   160  			So(resp.Header.Get("Location"), ShouldEqual, "/some/dest")
   161  
   162  			// And we've got some session cookie!
   163  			for _, c := range resp.Cookies() {
   164  				if c.MaxAge == -1 {
   165  					deleted = append(deleted, c)
   166  				} else {
   167  					// Should have at most one new cookie.
   168  					So(cookie, ShouldBeNil)
   169  					cookie = c
   170  				}
   171  			}
   172  			So(cookie, ShouldNotBeNil)
   173  			return
   174  		}
   175  
   176  		phonyRequest := func(cookie *http.Cookie) auth.RequestMetadata {
   177  			req, _ := http.NewRequest("GET", "https://dest.example.com/phony", nil)
   178  			if cookie != nil {
   179  				req.Header.Add("Cookie", fmt.Sprintf("%s=%s", cookie.Name, cookie.Value))
   180  			}
   181  			return auth.RequestMetadataForHTTP(req)
   182  		}
   183  
   184  		Convey("Full flow", func() {
   185  			callbackRawQueryV1 := performLogin(methodV1)
   186  			cookieV1, deleted := performCallback(methodV1, callbackRawQueryV1)
   187  
   188  			// Set the cookie on the correct path.
   189  			So(cookieV1.Path, ShouldEqual, internal.UnlimitedCookiePath)
   190  
   191  			// Removed a potentially stale cookie on a different path.
   192  			So(deleted, ShouldHaveLength, 1)
   193  			So(deleted[0].Path, ShouldEqual, internal.LimitedCookiePath)
   194  
   195  			// Handed out 1 access token thus far.
   196  			So(provider.AccessTokensMinted(), ShouldEqual, 1)
   197  
   198  			Convey("Code reuse is forbidden", func() {
   199  				// Trying to use the authorization code again fails.
   200  				resp := call(methodV1.callbackHandler, "dest.example.com", &url.URL{
   201  					RawQuery: callbackRawQueryV1,
   202  				}, nil)
   203  				So(resp.StatusCode, ShouldEqual, http.StatusBadRequest)
   204  			})
   205  
   206  			Convey("No cookies => method is skipped", func() {
   207  				user, session, err := methodV1.Authenticate(ctx, phonyRequest(nil))
   208  				So(err, ShouldBeNil)
   209  				So(user, ShouldBeNil)
   210  				So(session, ShouldBeNil)
   211  			})
   212  
   213  			Convey("Good cookie works", func() {
   214  				user, session, err := methodV1.Authenticate(ctx, phonyRequest(cookieV1))
   215  				So(err, ShouldBeNil)
   216  				So(user, ShouldResemble, &auth.User{
   217  					Identity: identity.Identity("user:" + provider.UserEmail),
   218  					Email:    provider.UserEmail,
   219  					Name:     provider.UserName,
   220  					Picture:  provider.UserPicture,
   221  				})
   222  				So(session, ShouldNotBeNil)
   223  
   224  				// Can grab the stored access token.
   225  				tok, err := session.AccessToken(ctx)
   226  				So(err, ShouldBeNil)
   227  				So(tok.AccessToken, ShouldEqual, "access_token_1")
   228  				So(tok.Expiry.Sub(testclock.TestRecentTimeUTC), ShouldEqual, time.Hour)
   229  
   230  				// Can grab the stored ID token.
   231  				tok, err = session.IDToken(ctx)
   232  				So(err, ShouldBeNil)
   233  				So(tok.AccessToken, ShouldStartWith, "eyJhbG") // JWT header
   234  				So(tok.Expiry.Sub(testclock.TestRecentTimeUTC), ShouldEqual, time.Hour)
   235  			})
   236  
   237  			Convey("Malformed cookie is ignored", func() {
   238  				user, session, err := methodV1.Authenticate(ctx, phonyRequest(&http.Cookie{
   239  					Name:  cookieV1.Name,
   240  					Value: cookieV1.Value[:20],
   241  				}))
   242  				So(err, ShouldBeNil)
   243  				So(user, ShouldBeNil)
   244  				So(session, ShouldBeNil)
   245  			})
   246  
   247  			Convey("Missing datastore session", func() {
   248  				methodV1.Sessions.(*datastore.Store).Namespace = "another"
   249  				user, session, err := methodV1.Authenticate(ctx, phonyRequest(cookieV1))
   250  				So(err, ShouldBeNil)
   251  				So(user, ShouldBeNil)
   252  				So(session, ShouldBeNil)
   253  			})
   254  
   255  			Convey("After short-lived tokens expire", func() {
   256  				tc.Add(2 * time.Hour)
   257  
   258  				Convey("Session refresh OK", func() {
   259  					// The session is still valid.
   260  					user, session, err := methodV1.Authenticate(ctx, phonyRequest(cookieV1))
   261  					So(err, ShouldBeNil)
   262  					So(user, ShouldNotBeNil)
   263  					So(session, ShouldNotBeNil)
   264  
   265  					// Tokens have been refreshed.
   266  					So(provider.AccessTokensMinted(), ShouldEqual, 2)
   267  
   268  					// auth.Session returns the refreshed token.
   269  					tok, err := session.AccessToken(ctx)
   270  					So(err, ShouldBeNil)
   271  					So(tok.AccessToken, ShouldEqual, "access_token_2")
   272  					So(tok.Expiry.Sub(testclock.TestRecentTimeUTC), ShouldEqual, 3*time.Hour)
   273  
   274  					// No need to refresh anymore.
   275  					user, session, err = methodV1.Authenticate(ctx, phonyRequest(cookieV1))
   276  					So(err, ShouldBeNil)
   277  					So(user, ShouldNotBeNil)
   278  					So(session, ShouldNotBeNil)
   279  					So(provider.AccessTokensMinted(), ShouldEqual, 2)
   280  				})
   281  
   282  				Convey("Session refresh transient fail", func() {
   283  					provider.TransientErr = errors.New("boom")
   284  
   285  					_, _, err := methodV1.Authenticate(ctx, phonyRequest(cookieV1))
   286  					So(err, ShouldNotBeNil)
   287  					So(transient.Tag.In(err), ShouldBeTrue)
   288  				})
   289  
   290  				Convey("Session refresh fatal fail", func() {
   291  					provider.RefreshToken = "another-token"
   292  
   293  					// Refresh fails and closes the session.
   294  					user, session, err := methodV1.Authenticate(ctx, phonyRequest(cookieV1))
   295  					So(err, ShouldBeNil)
   296  					So(user, ShouldBeNil)
   297  					So(session, ShouldBeNil)
   298  
   299  					// Using the closed session is unsuccessful.
   300  					user, session, err = methodV1.Authenticate(ctx, phonyRequest(cookieV1))
   301  					So(err, ShouldBeNil)
   302  					So(user, ShouldBeNil)
   303  					So(session, ShouldBeNil)
   304  				})
   305  			})
   306  
   307  			Convey("Logout works", func() {
   308  				logoutURL, err := methodV1.LogoutURL(ctx, "/some/dest")
   309  				So(err, ShouldBeNil)
   310  				So(logoutURL, ShouldEqual, "/auth/openid/logout?r=%2Fsome%2Fdest")
   311  				parsed, _ := url.Parse(logoutURL)
   312  
   313  				resp := call(methodV1.logoutHandler, "primary.example.com", parsed, http.Header{
   314  					"Cookie": {fmt.Sprintf("%s=%s", cookieV1.Name, cookieV1.Value)},
   315  				})
   316  
   317  				// Got a redirect to the final destination URL.
   318  				So(resp.StatusCode, ShouldEqual, http.StatusFound)
   319  				So(resp.Header.Get("Location"), ShouldEqual, "/some/dest")
   320  
   321  				// Cookies are removed.
   322  				cookies := resp.Cookies()
   323  				paths := []string{}
   324  				for _, c := range cookies {
   325  					So(c.Name, ShouldEqual, internal.SessionCookieName)
   326  					So(c.Value, ShouldEqual, "deleted")
   327  					So(c.MaxAge, ShouldEqual, -1)
   328  					paths = append(paths, c.Path)
   329  				}
   330  				So(paths, ShouldResemble, []string{internal.UnlimitedCookiePath, internal.LimitedCookiePath})
   331  
   332  				// It also no longer works.
   333  				user, session, err := methodV1.Authenticate(ctx, phonyRequest(cookieV1))
   334  				So(err, ShouldBeNil)
   335  				So(user, ShouldBeNil)
   336  				So(session, ShouldBeNil)
   337  
   338  				// The refresh token was not revoked.
   339  				So(provider.Revoked, ShouldBeNil)
   340  
   341  				// Hitting logout again (resending the cookie) succeeds.
   342  				resp = call(methodV1.logoutHandler, "primary.example.com", parsed, http.Header{
   343  					"Cookie": {fmt.Sprintf("%s=%s", cookieV1.Name, cookieV1.Value)},
   344  				})
   345  				So(resp.StatusCode, ShouldEqual, http.StatusFound)
   346  				So(resp.Header.Get("Location"), ShouldEqual, "/some/dest")
   347  				So(provider.Revoked, ShouldBeNil)
   348  			})
   349  
   350  			Convey("Add additional optional scope works", func() {
   351  				methodV2 := makeMethod([]string{"scope1"}, []string{"scope2", "scope3"})
   352  
   353  				user, session, err := methodV2.Authenticate(ctx, phonyRequest(cookieV1))
   354  				So(err, ShouldBeNil)
   355  				So(user, ShouldResemble, &auth.User{
   356  					Identity: identity.Identity("user:" + provider.UserEmail),
   357  					Email:    provider.UserEmail,
   358  					Name:     provider.UserName,
   359  					Picture:  provider.UserPicture,
   360  				})
   361  				So(session, ShouldNotBeNil)
   362  			})
   363  
   364  			Convey("Promote optional scope works", func() {
   365  				methodV2 := makeMethod([]string{"scope1", "scope2"}, nil)
   366  
   367  				user, session, err := methodV2.Authenticate(ctx, phonyRequest(cookieV1))
   368  				So(err, ShouldBeNil)
   369  				So(user, ShouldResemble, &auth.User{
   370  					Identity: identity.Identity("user:" + provider.UserEmail),
   371  					Email:    provider.UserEmail,
   372  					Name:     provider.UserName,
   373  					Picture:  provider.UserPicture,
   374  				})
   375  				So(session, ShouldNotBeNil)
   376  			})
   377  
   378  			Convey("Add additional required scope invalidates old sessions", func() {
   379  				methodV2 := makeMethod([]string{"scope1", "scope3"}, []string{"scope2"})
   380  
   381  				user, session, err := methodV2.Authenticate(ctx, phonyRequest(cookieV1))
   382  				So(err, ShouldBeNil)
   383  				So(user, ShouldBeNil)
   384  				So(session, ShouldBeNil)
   385  
   386  				// The cookie no longer works with the old method.
   387  				user, session, err = methodV1.Authenticate(ctx, phonyRequest(cookieV1))
   388  				So(err, ShouldBeNil)
   389  				So(user, ShouldBeNil)
   390  				So(session, ShouldBeNil)
   391  			})
   392  
   393  			Convey("Additional scopes are decided during login not callback", func() {
   394  				methodV2 := makeMethod([]string{"scope1"}, []string{"scope2", "scope3"})
   395  				methodV3 := makeMethod([]string{"scope1", "scope3"}, []string{"scope2"})
   396  
   397  				// User hit the login handle in the v1 but the callback is handled by
   398  				// v2.
   399  				callbackRawQueryV1 := performLogin(methodV1)
   400  				cookieV1, _ := performCallback(methodV2, callbackRawQueryV1)
   401  
   402  				// Cookies produced by login requests in methodV1 does not have the
   403  				// added scope.
   404  				user, session, err := methodV3.Authenticate(ctx, phonyRequest(cookieV1))
   405  				So(err, ShouldBeNil)
   406  				So(user, ShouldBeNil)
   407  				So(session, ShouldBeNil)
   408  
   409  				// User hit the login handle in the v2 but the callback is handled by
   410  				// v1.
   411  				callbackRawQueryV2 := performLogin(methodV2)
   412  				cookieV2, _ := performCallback(methodV1, callbackRawQueryV2)
   413  
   414  				// Cookies produced by login requests in methodV2 does have the added
   415  				// scope.
   416  				user, session, err = methodV3.Authenticate(ctx, phonyRequest(cookieV2))
   417  				So(err, ShouldBeNil)
   418  				So(user, ShouldResemble, &auth.User{
   419  					Identity: identity.Identity("user:" + provider.UserEmail),
   420  					Email:    provider.UserEmail,
   421  					Name:     provider.UserName,
   422  					Picture:  provider.UserPicture,
   423  				})
   424  				So(session, ShouldNotBeNil)
   425  			})
   426  		})
   427  
   428  		Convey("LimitCookieExposure cookie works", func() {
   429  			method := makeMethod([]string{"scope1"}, []string{"scope2"})
   430  			method.LimitCookieExposure = true
   431  
   432  			cookie, deleted := performCallback(method, performLogin(method))
   433  
   434  			// Set the cookie on the correct path.
   435  			So(cookie.Path, ShouldEqual, internal.LimitedCookiePath)
   436  			So(cookie.SameSite, ShouldEqual, http.SameSiteStrictMode)
   437  
   438  			// Removed a potentially stale cookie on a different path.
   439  			So(deleted, ShouldHaveLength, 1)
   440  			So(deleted[0].Path, ShouldEqual, internal.UnlimitedCookiePath)
   441  
   442  			user, _, _ := method.Authenticate(ctx, phonyRequest(cookie))
   443  			So(user, ShouldResemble, &auth.User{
   444  				Identity: identity.Identity("user:" + provider.UserEmail),
   445  				Email:    provider.UserEmail,
   446  				Name:     provider.UserName,
   447  				Picture:  provider.UserPicture,
   448  			})
   449  		})
   450  
   451  		Convey("State endpoint works", func() {
   452  			r := router.New()
   453  			r.Use(router.MiddlewareChain{
   454  				func(rc *router.Context, next router.Handler) {
   455  					rc.Request = rc.Request.WithContext(ctx)
   456  					next(rc)
   457  				},
   458  			})
   459  			methodV1.InstallHandlers(r, nil)
   460  
   461  			callState := func(cookie *http.Cookie) (code int, state *auth.StateEndpointResponse) {
   462  				rw := httptest.NewRecorder()
   463  				req := httptest.NewRequest("GET", stateURL, nil)
   464  				if cookie != nil {
   465  					req.Header.Add("Cookie", fmt.Sprintf("%s=%s", cookie.Name, cookie.Value))
   466  				}
   467  				req.Header.Add("Sec-Fetch-Site", "same-origin")
   468  				r.ServeHTTP(rw, req)
   469  				res := rw.Result()
   470  				code = res.StatusCode
   471  				if code == 200 {
   472  					state = &auth.StateEndpointResponse{}
   473  					So(json.NewDecoder(res.Body).Decode(&state), ShouldBeNil)
   474  				}
   475  				return
   476  			}
   477  
   478  			goodCookie, _ := performCallback(methodV1, performLogin(methodV1))
   479  			So(provider.AccessTokensMinted(), ShouldEqual, 1)
   480  			tc.Add(30 * time.Minute)
   481  
   482  			Convey("No cookie", func() {
   483  				code, state := callState(nil)
   484  				So(code, ShouldEqual, 200)
   485  				So(state, ShouldResemble, &auth.StateEndpointResponse{Identity: "anonymous:anonymous"})
   486  			})
   487  
   488  			Convey("Valid cookie", func() {
   489  				code, state := callState(goodCookie)
   490  				So(code, ShouldEqual, 200)
   491  				So(state, ShouldResemble, &auth.StateEndpointResponse{
   492  					Identity:             "user:someone@example.com",
   493  					Email:                "someone@example.com",
   494  					Picture:              "https://example.com/picture",
   495  					AccessToken:          "access_token_1",
   496  					AccessTokenExpiry:    testclock.TestRecentTimeUTC.Add(time.Hour).Unix(),
   497  					AccessTokenExpiresIn: 1800,
   498  					IDToken:              state.IDToken, // checked separately
   499  					IDTokenExpiry:        testclock.TestRecentTimeUTC.Add(time.Hour).Unix(),
   500  					IDTokenExpiresIn:     1800,
   501  				})
   502  				So(state.IDToken, ShouldStartWith, "eyJhbG") // JWT header
   503  
   504  				// Still only 1 token minted overall.
   505  				So(provider.AccessTokensMinted(), ShouldEqual, 1)
   506  			})
   507  
   508  			Convey("Refreshes tokens", func() {
   509  				tc.Add(time.Hour) // make sure existing tokens expire
   510  
   511  				code, state := callState(goodCookie)
   512  				So(code, ShouldEqual, 200)
   513  				So(state, ShouldResemble, &auth.StateEndpointResponse{
   514  					Identity:             "user:someone@example.com",
   515  					Email:                "someone@example.com",
   516  					Picture:              "https://example.com/picture",
   517  					AccessToken:          "access_token_2",
   518  					AccessTokenExpiry:    testclock.TestRecentTimeUTC.Add(2*time.Hour + 30*time.Minute).Unix(),
   519  					AccessTokenExpiresIn: 3600,
   520  					IDToken:              state.IDToken, // checked separately
   521  					IDTokenExpiry:        testclock.TestRecentTimeUTC.Add(2*time.Hour + 30*time.Minute).Unix(),
   522  					IDTokenExpiresIn:     3600,
   523  				})
   524  				So(state.IDToken, ShouldStartWith, "eyJhbG") // JWT header
   525  
   526  				// Minted a new token.
   527  				So(provider.AccessTokensMinted(), ShouldEqual, 2)
   528  			})
   529  		})
   530  	})
   531  }
   532  
   533  ////////////////////////////////////////////////////////////////////////////////
   534  
   535  const (
   536  	fakeSigningKeyID = "signing-key"
   537  	fakeIssuer       = "https://issuer.example.com"
   538  )
   539  
// openIDProviderFake is an in-process fake of an OpenID Connect identity
// provider. It serves discovery, JWKS, token and revocation endpoints from a
// local test HTTP server.
//
// Call Init before use and Close when done.
type openIDProviderFake struct {
	// Client credentials and redirect URI that incoming requests must carry;
	// requests with other values are rejected with HTTP 400.
	ExpectedClientID     string
	ExpectedClientSecret string
	ExpectedRedirectURI  string

	// Profile of the user, copied into the claims of generated ID tokens.
	UserSub     string
	UserEmail   string
	UserName    string
	UserPicture string

	// RefreshToken is handed out on the authorization code exchange and is the
	// only value accepted by the refresh_token grant.
	RefreshToken string
	// Revoked accumulates tokens posted to the revocation endpoint.
	Revoked []string

	// TransientErr, if set, makes the token endpoint reply with HTTP 500 and
	// this error's message.
	TransientErr error

	c             C                   // Convey context used for assertions
	srv           *httptest.Server    // local server hosting the endpoints
	signer        *signingtest.Signer // signs ID tokens; its public key is served via JWKS
	nextAccessTok int64               // number of access tokens minted; accessed atomically
}
   560  
   561  func (f *openIDProviderFake) Init(ctx context.Context, c C) {
   562  	r := router.New()
   563  	r.GET("/discovery", nil, f.discoveryHandler)
   564  	r.GET("/jwks", nil, f.jwksHandler)
   565  	r.POST("/token", nil, f.tokenHandler)
   566  	r.POST("/revocation", nil, f.revocationHandler)
   567  	f.c = c
   568  	f.srv = httptest.NewServer(r)
   569  	f.signer = signingtest.NewSigner(nil)
   570  }
   571  
   572  func (f *openIDProviderFake) Close() {
   573  	f.srv.Close()
   574  }
   575  
   576  func (f *openIDProviderFake) Host() string {
   577  	return strings.TrimPrefix(f.srv.URL, "http://")
   578  }
   579  
   580  func (f *openIDProviderFake) DiscoveryURL() string {
   581  	return f.srv.URL + "/discovery"
   582  }
   583  
   584  func (f *openIDProviderFake) CallbackRawQuery(q url.Values) string {
   585  	f.c.So(q.Get("client_id"), ShouldEqual, f.ExpectedClientID)
   586  	f.c.So(q.Get("redirect_uri"), ShouldEqual, f.ExpectedRedirectURI)
   587  	f.c.So(q.Get("nonce"), ShouldNotEqual, "")
   588  	f.c.So(q.Get("code_challenge_method"), ShouldEqual, "S256")
   589  	f.c.So(q.Get("code_challenge"), ShouldNotEqual, "")
   590  
   591  	// Remember the nonce and the code verifier challenge just by encoding them
   592  	// in the resulting authorization code. This code is opaque to the ID
   593  	// provider clients. It will be passed back to us in /token handler.
   594  	blob, _ := json.Marshal(map[string]string{
   595  		"nonce":          q.Get("nonce"),
   596  		"code_challenge": q.Get("code_challenge"),
   597  	})
   598  	authCode := base64.RawStdEncoding.EncodeToString(blob)
   599  
   600  	return fmt.Sprintf("code=%s&state=%s", authCode, q.Get("state"))
   601  }
   602  
   603  func (f *openIDProviderFake) AccessTokensMinted() int64 {
   604  	return atomic.LoadInt64(&f.nextAccessTok)
   605  }
   606  
   607  func (f *openIDProviderFake) discoveryHandler(ctx *router.Context) {
   608  	json.NewEncoder(ctx.Writer).Encode(map[string]string{
   609  		"issuer":                 fakeIssuer,
   610  		"authorization_endpoint": f.srv.URL + "/authorization", // not actually called in test
   611  		"token_endpoint":         f.srv.URL + "/token",
   612  		"revocation_endpoint":    f.srv.URL + "/revocation",
   613  		"jwks_uri":               f.srv.URL + "/jwks",
   614  	})
   615  }
   616  
   617  func (f *openIDProviderFake) jwksHandler(ctx *router.Context) {
   618  	keys := jwksForTest(fakeSigningKeyID, &f.signer.KeyForTest().PublicKey)
   619  	json.NewEncoder(ctx.Writer).Encode(keys)
   620  }
   621  
   622  func (f *openIDProviderFake) tokenHandler(ctx *router.Context) {
   623  	if f.TransientErr != nil {
   624  		http.Error(ctx.Writer, f.TransientErr.Error(), 500)
   625  		return
   626  	}
   627  	if !f.checkClient(ctx) {
   628  		return
   629  	}
   630  	if ctx.Request.FormValue("redirect_uri") != f.ExpectedRedirectURI {
   631  		http.Error(ctx.Writer, "bad redirect URI", 400)
   632  		return
   633  	}
   634  
   635  	switch ctx.Request.FormValue("grant_type") {
   636  	case "authorization_code":
   637  		code := ctx.Request.FormValue("code")
   638  		codeVerifier := ctx.Request.FormValue("code_verifier")
   639  
   640  		// Parse the code generated in CallbackRawQuery.
   641  		blob, err := base64.RawStdEncoding.DecodeString(code)
   642  		if err != nil {
   643  			http.Error(ctx.Writer, "bad code base64", 400)
   644  			return
   645  		}
   646  		var encodedParams struct {
   647  			Nonce         string `json:"nonce"`
   648  			CodeChallenge string `json:"code_challenge"`
   649  		}
   650  		if err := json.Unmarshal(blob, &encodedParams); err != nil {
   651  			http.Error(ctx.Writer, "bad code JSON", 400)
   652  			return
   653  		}
   654  
   655  		// Verify the given `code_verifier` matches `code_challenge`.
   656  		if internal.DeriveCodeChallenge(codeVerifier) != encodedParams.CodeChallenge {
   657  			http.Error(ctx.Writer, "bad code_challenge", 400)
   658  			return
   659  		}
   660  
   661  		// All is good! Generate tokens.
   662  		json.NewEncoder(ctx.Writer).Encode(map[string]any{
   663  			"expires_in":    3600,
   664  			"refresh_token": f.RefreshToken,
   665  			"access_token":  f.genAccessToken(),
   666  			"id_token":      f.genIDToken(ctx.Request.Context(), encodedParams.Nonce),
   667  		})
   668  
   669  	case "refresh_token":
   670  		refreshToken := ctx.Request.FormValue("refresh_token")
   671  		if refreshToken != f.RefreshToken {
   672  			http.Error(ctx.Writer, "bad refresh token", 400)
   673  			return
   674  		}
   675  		json.NewEncoder(ctx.Writer).Encode(map[string]any{
   676  			"expires_in":   3600,
   677  			"access_token": f.genAccessToken(),
   678  			"id_token":     f.genIDToken(ctx.Request.Context(), ""),
   679  		})
   680  
   681  	default:
   682  		http.Error(ctx.Writer, "unknown grant_type", 400)
   683  	}
   684  }
   685  
   686  func (f *openIDProviderFake) revocationHandler(ctx *router.Context) {
   687  	if !f.checkClient(ctx) {
   688  		return
   689  	}
   690  	f.Revoked = append(f.Revoked, ctx.Request.FormValue("token"))
   691  }
   692  
   693  func (f *openIDProviderFake) checkClient(ctx *router.Context) bool {
   694  	if ctx.Request.FormValue("client_id") != f.ExpectedClientID {
   695  		http.Error(ctx.Writer, "bad client ID", 400)
   696  		return false
   697  	}
   698  	if ctx.Request.FormValue("client_secret") != f.ExpectedClientSecret {
   699  		http.Error(ctx.Writer, "bad client secret", 400)
   700  		return false
   701  	}
   702  	return true
   703  }
   704  
   705  func (f *openIDProviderFake) genAccessToken() string {
   706  	count := atomic.AddInt64(&f.nextAccessTok, 1)
   707  	return fmt.Sprintf("access_token_%d", count)
   708  }
   709  
   710  func (f *openIDProviderFake) genIDToken(ctx context.Context, nonce string) string {
   711  	body, err := json.Marshal(openid.IDToken{
   712  		Iss:           fakeIssuer,
   713  		EmailVerified: true,
   714  		Sub:           f.UserSub,
   715  		Email:         f.UserEmail,
   716  		Name:          f.UserName,
   717  		Picture:       f.UserPicture,
   718  		Aud:           f.ExpectedClientID,
   719  		Iat:           clock.Now(ctx).Unix(),
   720  		Exp:           clock.Now(ctx).Add(time.Hour).Unix(),
   721  		Nonce:         nonce,
   722  	})
   723  	if err != nil {
   724  		panic(err)
   725  	}
   726  	return jwtForTest(ctx, body, fakeSigningKeyID, f.signer)
   727  }
   728  
   729  ////////////////////////////////////////////////////////////////////////////////
   730  
   731  func jwksForTest(keyID string, pubKey *rsa.PublicKey) *openid.JSONWebKeySetStruct {
   732  	modulus := pubKey.N.Bytes()
   733  	exp := []byte{0, 0, 0, 0}
   734  	binary.BigEndian.PutUint32(exp, uint32(pubKey.E))
   735  	return &openid.JSONWebKeySetStruct{
   736  		Keys: []openid.JSONWebKeyStruct{
   737  			{
   738  				Kty: "RSA",
   739  				Alg: "RS256",
   740  				Use: "sig",
   741  				Kid: keyID,
   742  				N:   base64.RawURLEncoding.EncodeToString(modulus),
   743  				E:   base64.RawURLEncoding.EncodeToString(exp),
   744  			},
   745  		},
   746  	}
   747  }
   748  
   749  func jwtForTest(ctx context.Context, body []byte, keyID string, signer *signingtest.Signer) string {
   750  	b64hdr := base64.RawURLEncoding.EncodeToString([]byte(
   751  		fmt.Sprintf(`{"alg": "RS256","kid": "%s"}`, keyID)))
   752  	b64bdy := base64.RawURLEncoding.EncodeToString(body)
   753  	_, sig, err := signer.SignBytes(ctx, []byte(b64hdr+"."+b64bdy))
   754  	if err != nil {
   755  		panic(err)
   756  	}
   757  	return b64hdr + "." + b64bdy + "." + base64.RawURLEncoding.EncodeToString(sig)
   758  }