go.chromium.org/luci@v0.0.0-20240309015107-7cdc2e660f33/server/auth/oauth_test.go (about)

     1  // Copyright 2015 The LUCI Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package auth
    16  
    17  import (
    18  	"context"
    19  	"encoding/json"
    20  	"net/http"
    21  	"net/http/httptest"
    22  	"testing"
    23  	"time"
    24  
    25  	"golang.org/x/oauth2"
    26  
    27  	"go.chromium.org/luci/common/clock"
    28  	"go.chromium.org/luci/common/clock/testclock"
    29  
    30  	"go.chromium.org/luci/server/caching"
    31  	"go.chromium.org/luci/server/caching/cachingtest"
    32  
    33  	. "github.com/smartystreets/goconvey/convey"
    34  	. "go.chromium.org/luci/common/testing/assertions"
    35  )
    36  
    37  const testScope = "https://example.com/scopes/user.email"
    38  
// tokenInfo is the JSON body returned by the fake tokeninfo HTTP server in
// TestGoogleOAuth2Method. Subtests mutate its fields to simulate different
// validation responses. The JSON tags presumably mirror the fields the real
// tokeninfo endpoint returns and GoogleOAuth2Method reads (defined elsewhere);
// confirm against the implementation if the wire format changes.
type tokenInfo struct {
	// Audience is the OAuth client ID the token was issued to.
	Audience string `json:"aud"`
	// Email is the token owner's email address.
	Email string `json:"email"`
	// EmailVerified is "true" or "false", encoded as a string.
	EmailVerified string `json:"email_verified"`
	// Error is an error description reported by the endpoint.
	Error string `json:"error_description"`
	// ExpiresIn is the remaining token lifetime in seconds, as a decimal string.
	ExpiresIn string `json:"expires_in"`
	// Scope is a space-separated list of granted scopes.
	Scope string `json:"scope"`
}
    47  
    48  func TestGoogleOAuth2Method(t *testing.T) {
    49  	t.Parallel()
    50  
    51  	Convey("with mock backend", t, func(c C) {
    52  		ctx := caching.WithEmptyProcessCache(context.Background())
    53  		ctx = cachingtest.WithGlobalCache(ctx, map[string]caching.BlobCache{
    54  			oauthValidationCache.Parameters().GlobalNamespace: cachingtest.NewBlobCache(),
    55  		})
    56  		ctx, tc := testclock.UseTime(ctx, testclock.TestRecentTimeUTC)
    57  
    58  		tc.SetTimerCallback(func(d time.Duration, t clock.Timer) {
    59  			if testclock.HasTags(t, "oauth-tokeninfo-retry") {
    60  				tc.Add(d)
    61  			}
    62  		})
    63  
    64  		goodUser := &User{
    65  			Identity: "user:abc@example.com",
    66  			Email:    "abc@example.com",
    67  			ClientID: "client_id",
    68  		}
    69  
    70  		checks := 0
    71  		info := tokenInfo{
    72  			Audience:      "client_id",
    73  			Email:         "abc@example.com",
    74  			EmailVerified: "true",
    75  			ExpiresIn:     "3600",
    76  			Scope:         testScope + " other stuff",
    77  		}
    78  		status := http.StatusOK
    79  		ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    80  			checks++
    81  			w.WriteHeader(status)
    82  			c.So(json.NewEncoder(w).Encode(&info), ShouldBeNil)
    83  		}))
    84  		defer ts.Close()
    85  
    86  		ctx = ModifyConfig(ctx, func(cfg Config) Config {
    87  			cfg.AnonymousTransport = func(context.Context) http.RoundTripper {
    88  				return http.DefaultTransport
    89  			}
    90  			return cfg
    91  		})
    92  
    93  		call := func(header string) (*User, error) {
    94  			m := GoogleOAuth2Method{
    95  				Scopes:            []string{testScope},
    96  				tokenInfoEndpoint: ts.URL,
    97  			}
    98  			req := makeRequest()
    99  			req.FakeHeader.Set("Authorization", header)
   100  			u, _, err := m.Authenticate(ctx, req)
   101  			return u, err
   102  		}
   103  
   104  		Convey("Works with end users", func() {
   105  			u, err := call("Bearer access_token")
   106  			So(err, ShouldBeNil)
   107  			So(u, ShouldResemble, goodUser)
   108  		})
   109  
   110  		Convey("Works with service accounts", func() {
   111  			info.Email = "something@example.gserviceaccount.com"
   112  			info.Audience = "ignored"
   113  			u, err := call("Bearer access_token")
   114  			So(err, ShouldBeNil)
   115  			So(u, ShouldResemble, &User{
   116  				Identity: "user:something@example.gserviceaccount.com",
   117  				Email:    "something@example.gserviceaccount.com",
   118  			})
   119  		})
   120  
   121  		Convey("Valid tokens are cached", func() {
   122  			u, err := call("Bearer access_token")
   123  			So(err, ShouldBeNil)
   124  			So(u, ShouldResemble, goodUser)
   125  			So(checks, ShouldEqual, 1)
   126  
   127  			// Hit the process cache.
   128  			u, err = call("Bearer access_token")
   129  			So(err, ShouldBeNil)
   130  			So(u, ShouldResemble, goodUser)
   131  
   132  			// Hit the global cache by clearing the local one.
   133  			ctx = caching.WithEmptyProcessCache(ctx)
   134  			u, err = call("Bearer access_token")
   135  			So(err, ShouldBeNil)
   136  			So(u, ShouldResemble, goodUser)
   137  
   138  			// No new calls to the token endpoints.
   139  			So(checks, ShouldEqual, 1)
   140  
   141  			// Advance time until the token expires, but the validation outcome is
   142  			// still cached.
   143  			tc.Add(time.Hour + time.Second)
   144  			status = http.StatusBadRequest
   145  
   146  			// Correctly identified as expired, no new calls to the token endpoint.
   147  			_, err = call("Bearer access_token")
   148  			So(err, ShouldEqual, ErrBadOAuthToken)
   149  			So(checks, ShouldEqual, 1)
   150  
   151  			// Advance time a bit more until the token is evicted from the cache.
   152  			tc.Add(15 * time.Minute)
   153  
   154  			// Correctly identified as expired, via a call to the token endpoint.
   155  			_, err = call("Bearer access_token")
   156  			So(err, ShouldEqual, ErrBadOAuthToken)
   157  			So(checks, ShouldEqual, 2)
   158  		})
   159  
   160  		Convey("Bad tokens are cached", func() {
   161  			status = http.StatusBadRequest
   162  
   163  			_, err := call("Bearer access_token")
   164  			So(err, ShouldEqual, ErrBadOAuthToken)
   165  			So(checks, ShouldEqual, 1)
   166  
   167  			// Hit the process cache.
   168  			_, err = call("Bearer access_token")
   169  			So(err, ShouldEqual, ErrBadOAuthToken)
   170  
   171  			// Hit the global cache by clearing the local one.
   172  			ctx = caching.WithEmptyProcessCache(ctx)
   173  			_, err = call("Bearer access_token")
   174  			So(err, ShouldEqual, ErrBadOAuthToken)
   175  
   176  			// Advance time a little bit, the outcome is still cached.
   177  			tc.Add(5 * time.Minute)
   178  			_, err = call("Bearer access_token")
   179  			So(err, ShouldEqual, ErrBadOAuthToken)
   180  			So(checks, ShouldEqual, 1)
   181  
   182  			// Advance time until the cache entry expire, the token is rechecked.
   183  			tc.Add(15 * time.Minute)
   184  			_, err = call("Bearer access_token")
   185  			So(err, ShouldEqual, ErrBadOAuthToken)
   186  			So(checks, ShouldEqual, 2)
   187  		})
   188  
   189  		Convey("Bad header", func() {
   190  			_, err := call("broken")
   191  			So(err, ShouldErrLike, "oauth: bad Authorization header")
   192  		})
   193  
   194  		Convey("HTTP 500", func() {
   195  			status = http.StatusInternalServerError
   196  			_, err := call("Bearer access_token")
   197  			So(err, ShouldErrLike, "transient error")
   198  		})
   199  
   200  		Convey("Error response", func() {
   201  			status = http.StatusBadRequest
   202  			info.Error = "OMG, error"
   203  			_, err := call("Bearer access_token")
   204  			So(err, ShouldEqual, ErrBadOAuthToken)
   205  		})
   206  
   207  		Convey("No email", func() {
   208  			info.Email = ""
   209  			_, err := call("Bearer access_token")
   210  			So(err, ShouldEqual, ErrBadOAuthToken)
   211  		})
   212  
   213  		Convey("Email not verified", func() {
   214  			info.EmailVerified = "false"
   215  			_, err := call("Bearer access_token")
   216  			So(err, ShouldEqual, ErrBadOAuthToken)
   217  		})
   218  
   219  		Convey("Bad expires_in", func() {
   220  			info.ExpiresIn = "not a number"
   221  			_, err := call("Bearer access_token")
   222  			So(err, ShouldErrLike, "transient error") // see the comment in GetTokenInfo
   223  		})
   224  
   225  		Convey("Zero expires_in", func() {
   226  			info.ExpiresIn = "0"
   227  			_, err := call("Bearer access_token")
   228  			So(err, ShouldEqual, ErrBadOAuthToken)
   229  		})
   230  
   231  		Convey("No audience", func() {
   232  			info.Audience = ""
   233  			_, err := call("Bearer access_token")
   234  			So(err, ShouldEqual, ErrBadOAuthToken)
   235  		})
   236  
   237  		Convey("No scope", func() {
   238  			info.Scope = ""
   239  			_, err := call("Bearer access_token")
   240  			So(err, ShouldEqual, ErrBadOAuthToken)
   241  		})
   242  
   243  		Convey("Bad email", func() {
   244  			info.Email = "@@@@"
   245  			_, err := call("Bearer access_token")
   246  			So(err, ShouldEqual, ErrBadOAuthToken)
   247  		})
   248  
   249  		Convey("Missing required scope", func() {
   250  			info.Scope = "some other scopes"
   251  			_, err := call("Bearer access_token")
   252  			So(err, ShouldEqual, ErrBadOAuthToken)
   253  		})
   254  	})
   255  }
   256  
   257  func TestGetUserCredentials(t *testing.T) {
   258  	t.Parallel()
   259  
   260  	ctx := context.Background()
   261  	m := GoogleOAuth2Method{}
   262  
   263  	call := func(hdr string) (*oauth2.Token, error) {
   264  		req := makeRequest()
   265  		req.FakeHeader.Set("Authorization", hdr)
   266  		return m.GetUserCredentials(ctx, req)
   267  	}
   268  
   269  	Convey("Works", t, func() {
   270  		tok, err := call("Bearer abc.def")
   271  		So(err, ShouldBeNil)
   272  		So(tok, ShouldResemble, &oauth2.Token{
   273  			AccessToken: "abc.def",
   274  			TokenType:   "Bearer",
   275  		})
   276  	})
   277  
   278  	Convey("Normalizes header", t, func() {
   279  		tok, err := call("  bearer    abc.def")
   280  		So(err, ShouldBeNil)
   281  		So(tok, ShouldResemble, &oauth2.Token{
   282  			AccessToken: "abc.def",
   283  			TokenType:   "Bearer",
   284  		})
   285  	})
   286  
   287  	Convey("Bad headers", t, func() {
   288  		_, err := call("")
   289  		So(err, ShouldEqual, ErrBadAuthorizationHeader)
   290  		_, err = call("abc.def")
   291  		So(err, ShouldEqual, ErrBadAuthorizationHeader)
   292  		_, err = call("Basic abc.def")
   293  		So(err, ShouldEqual, ErrBadAuthorizationHeader)
   294  	})
   295  }