go.chromium.org/luci@v0.0.0-20240309015107-7cdc2e660f33/server/auth/auth_test.go (about)

     1  // Copyright 2015 The LUCI Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package auth
    16  
    17  import (
    18  	"context"
    19  	"fmt"
    20  	"net"
    21  	"net/http"
    22  	"net/http/httptest"
    23  	"testing"
    24  
    25  	"golang.org/x/oauth2"
    26  
    27  	"go.chromium.org/luci/auth/identity"
    28  	"go.chromium.org/luci/common/errors"
    29  	"go.chromium.org/luci/common/retry/transient"
    30  
    31  	"go.chromium.org/luci/server/auth/authdb"
    32  	"go.chromium.org/luci/server/auth/realms"
    33  	"go.chromium.org/luci/server/auth/service/protocol"
    34  	"go.chromium.org/luci/server/auth/signing"
    35  	"go.chromium.org/luci/server/router"
    36  
    37  	. "github.com/smartystreets/goconvey/convey"
    38  	. "go.chromium.org/luci/common/testing/assertions"
    39  )
    40  
    41  func TestAuthenticate(t *testing.T) {
    42  	t.Parallel()
    43  
    44  	Convey("Happy path", t, func() {
    45  		c := injectTestDB(context.Background(), &fakeDB{
    46  			allowedClientID: "some_client_id",
    47  		})
    48  		auth := Authenticator{
    49  			Methods: []Method{fakeAuthMethod{clientID: "some_client_id"}},
    50  		}
    51  		req := makeRequest()
    52  		req.FakeRemoteAddr = "1.2.3.4"
    53  		c, err := auth.Authenticate(c, req)
    54  		So(err, ShouldBeNil)
    55  
    56  		So(CurrentUser(c), ShouldResemble, &User{
    57  			Identity: "user:abc@example.com",
    58  			Email:    "abc@example.com",
    59  			ClientID: "some_client_id",
    60  		})
    61  
    62  		So(GetState(c).PeerIP().String(), ShouldEqual, "1.2.3.4")
    63  
    64  		url, err := LoginURL(c, "login")
    65  		So(err, ShouldBeNil)
    66  		So(url, ShouldEqual, "http://fake.login.url/login")
    67  
    68  		url, err = LogoutURL(c, "logout")
    69  		So(err, ShouldBeNil)
    70  		So(url, ShouldEqual, "http://fake.logout.url/logout")
    71  
    72  		tok, extra, err := GetState(c).UserCredentials()
    73  		So(err, ShouldBeNil)
    74  		So(tok, ShouldResemble, &oauth2.Token{AccessToken: "token-abc@example.com"})
    75  		So(extra, ShouldHaveLength, 0)
    76  	})
    77  
    78  	Convey("Custom EndUserIP implementation", t, func() {
    79  		req := makeRequest()
    80  		req.FakeHeader.Add("X-Custom-IP", "4.5.6.7")
    81  
    82  		c := injectTestDB(context.Background(), &fakeDB{})
    83  		c = ModifyConfig(c, func(cfg Config) Config {
    84  			cfg.EndUserIP = func(r RequestMetadata) string { return r.Header("X-Custom-IP") }
    85  			return cfg
    86  		})
    87  
    88  		auth := Authenticator{
    89  			Methods: []Method{fakeAuthMethod{email: "zzz@example.com"}},
    90  		}
    91  		c, err := auth.Authenticate(c, req)
    92  		So(err, ShouldBeNil)
    93  		So(GetState(c).PeerIP().String(), ShouldEqual, "4.5.6.7")
    94  	})
    95  
    96  	Convey("No methods given", t, func() {
    97  		c := injectTestDB(context.Background(), &fakeDB{
    98  			allowedClientID: "some_client_id",
    99  		})
   100  		auth := Authenticator{}
   101  		_, err := auth.Authenticate(c, makeRequest())
   102  		So(err, ShouldEqual, ErrNotConfigured)
   103  	})
   104  
   105  	Convey("IsAllowedOAuthClientID on default DB", t, func() {
   106  		c := context.Background()
   107  		auth := Authenticator{
   108  			Methods: []Method{fakeAuthMethod{clientID: "some_client_id"}},
   109  		}
   110  		_, err := auth.Authenticate(c, makeRequest())
   111  		So(err, ShouldErrLike, "the library is not properly configured")
   112  	})
   113  
   114  	Convey("IsAllowedOAuthClientID with invalid client_id", t, func() {
   115  		c := injectTestDB(context.Background(), &fakeDB{
   116  			allowedClientID: "some_client_id",
   117  		})
   118  		c = injectFrontendClientID(c, "frontend_client_id")
   119  		auth := Authenticator{
   120  			Methods: []Method{fakeAuthMethod{clientID: "another_client_id"}},
   121  		}
   122  		_, err := auth.Authenticate(c, makeRequest())
   123  		So(err, ShouldEqual, ErrBadClientID)
   124  	})
   125  
   126  	Convey("IsAllowedOAuthClientID with frontend client_id", t, func() {
   127  		c := injectTestDB(context.Background(), &fakeDB{
   128  			allowedClientID: "some_client_id",
   129  		})
   130  		c = injectFrontendClientID(c, "frontend_client_id")
   131  		auth := Authenticator{
   132  			Methods: []Method{fakeAuthMethod{clientID: "frontend_client_id"}},
   133  		}
   134  		_, err := auth.Authenticate(c, makeRequest())
   135  		So(err, ShouldBeNil) // success!
   136  	})
   137  
   138  	Convey("IP allowlist restriction works", t, func() {
   139  		db, err := authdb.NewSnapshotDB(&protocol.AuthDB{
   140  			IpWhitelistAssignments: []*protocol.AuthIPWhitelistAssignment{
   141  				{
   142  					Identity:    "user:abc@example.com",
   143  					IpWhitelist: "allowlist",
   144  				},
   145  			},
   146  			IpWhitelists: []*protocol.AuthIPWhitelist{
   147  				{
   148  					Name: "allowlist",
   149  					Subnets: []string{
   150  						"1.2.3.4/32",
   151  					},
   152  				},
   153  			},
   154  		}, "http://auth-service", 1234, false)
   155  		So(err, ShouldBeNil)
   156  
   157  		c := injectTestDB(context.Background(), db)
   158  
   159  		Convey("User is using IP allowlist and IP is in the allowlist.", func() {
   160  			auth := Authenticator{
   161  				Methods: []Method{fakeAuthMethod{email: "abc@example.com"}},
   162  			}
   163  			req := makeRequest()
   164  			req.FakeRemoteAddr = "1.2.3.4"
   165  			c, err := auth.Authenticate(c, req)
   166  			So(err, ShouldBeNil)
   167  			So(CurrentIdentity(c), ShouldEqual, identity.Identity("user:abc@example.com"))
   168  		})
   169  
   170  		Convey("User is using IP allowlist and IP is NOT in the allowlist.", func() {
   171  			auth := Authenticator{
   172  				Methods: []Method{fakeAuthMethod{email: "abc@example.com"}},
   173  			}
   174  			req := makeRequest()
   175  			req.FakeRemoteAddr = "1.2.3.5"
   176  			_, err := auth.Authenticate(c, req)
   177  			So(err, ShouldEqual, ErrForbiddenIP)
   178  		})
   179  
   180  		Convey("User is not using IP allowlist.", func() {
   181  			auth := Authenticator{
   182  				Methods: []Method{fakeAuthMethod{email: "def@example.com"}},
   183  			}
   184  			req := makeRequest()
   185  			req.FakeRemoteAddr = "1.2.3.5"
   186  			c, err := auth.Authenticate(c, req)
   187  			So(err, ShouldBeNil)
   188  			So(CurrentIdentity(c), ShouldEqual, identity.Identity("user:def@example.com"))
   189  		})
   190  	})
   191  
   192  	Convey("X-Luci-Project works", t, func() {
   193  		c := injectTestDB(context.Background(), &fakeDB{
   194  			groups: map[string][]identity.Identity{
   195  				InternalServicesGroup: {"user:allowed@example.com"},
   196  			},
   197  		})
   198  
   199  		Convey("Allowed", func() {
   200  			auth := Authenticator{
   201  				Methods: []Method{fakeAuthMethod{email: "allowed@example.com"}},
   202  			}
   203  			req := makeRequest()
   204  			req.FakeHeader.Set(XLUCIProjectHeader, "test-proj")
   205  			c, err := auth.Authenticate(c, req)
   206  			So(err, ShouldBeNil)
   207  			So(CurrentIdentity(c), ShouldEqual, identity.Identity("project:test-proj"))
   208  
   209  			tok, extra, err := GetState(c).UserCredentials()
   210  			So(err, ShouldBeNil)
   211  			So(tok, ShouldResemble, &oauth2.Token{AccessToken: "token-allowed@example.com"})
   212  			So(extra, ShouldResemble, map[string]string{XLUCIProjectHeader: "test-proj"})
   213  		})
   214  
   215  		Convey("Forbidden", func() {
   216  			auth := Authenticator{
   217  				Methods: []Method{fakeAuthMethod{email: "unknown@example.com"}},
   218  			}
   219  			req := makeRequest()
   220  			req.FakeHeader.Set(XLUCIProjectHeader, "test-proj")
   221  			_, err := auth.Authenticate(c, req)
   222  			So(err, ShouldEqual, ErrProjectHeaderForbidden)
   223  		})
   224  
   225  		Convey("Bad project ID", func() {
   226  			auth := Authenticator{
   227  				Methods: []Method{fakeAuthMethod{email: "allowed@example.com"}},
   228  			}
   229  			req := makeRequest()
   230  			req.FakeHeader.Set(XLUCIProjectHeader, "?????")
   231  			_, err := auth.Authenticate(c, req)
   232  			So(err, ShouldErrLike, "bad value")
   233  		})
   234  	})
   235  }
   236  
   237  func TestMiddleware(t *testing.T) {
   238  	t.Parallel()
   239  
   240  	handler := func(c *router.Context) {
   241  		fmt.Fprintf(c.Writer, "%s", CurrentIdentity(c.Request.Context()))
   242  	}
   243  
   244  	call := func(a *Authenticator) *httptest.ResponseRecorder {
   245  		req, err := http.NewRequest("GET", "http://example.com/foo", nil)
   246  		So(err, ShouldBeNil)
   247  		w := httptest.NewRecorder()
   248  		router.RunMiddleware(&router.Context{
   249  			Writer: w,
   250  			Request: req.WithContext(injectTestDB(context.Background(), &fakeDB{
   251  				allowedClientID: "some_client_id",
   252  			})),
   253  		}, router.NewMiddlewareChain(a.GetMiddleware()), handler)
   254  		return w
   255  	}
   256  
   257  	Convey("Happy path", t, func() {
   258  		rr := call(&Authenticator{
   259  			Methods: []Method{fakeAuthMethod{clientID: "some_client_id"}},
   260  		})
   261  		So(rr.Code, ShouldEqual, 200)
   262  		So(rr.Body.String(), ShouldEqual, "user:abc@example.com")
   263  	})
   264  
   265  	Convey("Fatal error", t, func() {
   266  		rr := call(&Authenticator{
   267  			Methods: []Method{fakeAuthMethod{clientID: "another_client_id"}},
   268  		})
   269  		So(rr.Code, ShouldEqual, 403)
   270  		So(rr.Body.String(), ShouldEqual, ErrBadClientID.Error()+"\n")
   271  	})
   272  
   273  	Convey("Transient error", t, func() {
   274  		rr := call(&Authenticator{
   275  			Methods: []Method{fakeAuthMethod{err: errors.New("boo", transient.Tag)}},
   276  		})
   277  		So(rr.Code, ShouldEqual, 500)
   278  		So(rr.Body.String(), ShouldEqual, "Internal Server Error\n")
   279  	})
   280  }
   281  
   282  ///
   283  
   284  type fakeRequest struct {
   285  	FakeRemoteAddr string
   286  	FakeHost       string
   287  	FakeHeader     http.Header
   288  }
   289  
   290  func (r *fakeRequest) Header(key string) string                { return r.FakeHeader.Get(key) }
   291  func (r *fakeRequest) Cookie(key string) (*http.Cookie, error) { return nil, fmt.Errorf("no cookie") }
   292  func (r *fakeRequest) RemoteAddr() string                      { return r.FakeRemoteAddr }
   293  func (r *fakeRequest) Host() string                            { return r.FakeHost }
   294  
   295  func makeRequest() *fakeRequest {
   296  	return &fakeRequest{
   297  		FakeRemoteAddr: "127.0.0.1",
   298  		FakeHost:       "some-url",
   299  		FakeHeader:     map[string][]string{},
   300  	}
   301  }
   302  
   303  ///
   304  
   305  // fakeAuthMethod implements Method.
   306  type fakeAuthMethod struct {
   307  	err      error
   308  	clientID string
   309  	email    string
   310  	observe  func(RequestMetadata)
   311  }
   312  
   313  func (m fakeAuthMethod) Authenticate(_ context.Context, r RequestMetadata) (*User, Session, error) {
   314  	if m.observe != nil {
   315  		m.observe(r)
   316  	}
   317  	if m.err != nil {
   318  		return nil, nil, m.err
   319  	}
   320  	email := m.email
   321  	if email == "" {
   322  		email = "abc@example.com"
   323  	}
   324  	return &User{
   325  		Identity: identity.Identity("user:" + email),
   326  		Email:    email,
   327  		ClientID: m.clientID,
   328  	}, nil, nil
   329  }
   330  
   331  func (m fakeAuthMethod) LoginURL(ctx context.Context, dest string) (string, error) {
   332  	return "http://fake.login.url/" + dest, nil
   333  }
   334  
   335  func (m fakeAuthMethod) LogoutURL(ctx context.Context, dest string) (string, error) {
   336  	return "http://fake.logout.url/" + dest, nil
   337  }
   338  
   339  func (m fakeAuthMethod) GetUserCredentials(context.Context, RequestMetadata) (*oauth2.Token, error) {
   340  	email := m.email
   341  	if email == "" {
   342  		email = "abc@example.com"
   343  	}
   344  	return &oauth2.Token{AccessToken: "token-" + email}, nil
   345  }
   346  
   347  func injectTestDB(ctx context.Context, d authdb.DB) context.Context {
   348  	return ModifyConfig(ctx, func(cfg Config) Config {
   349  		cfg.DBProvider = func(ctx context.Context) (authdb.DB, error) {
   350  			return d, nil
   351  		}
   352  		return cfg
   353  	})
   354  }
   355  
   356  func injectFrontendClientID(ctx context.Context, clientID string) context.Context {
   357  	return ModifyConfig(ctx, func(cfg Config) Config {
   358  		cfg.FrontendClientID = func(context.Context) (string, error) {
   359  			return clientID, nil
   360  		}
   361  		return cfg
   362  	})
   363  }
   364  
   365  ///
   366  
   367  // fakeDB implements DB.
   368  type fakeDB struct {
   369  	allowedClientID string
   370  	internalService string
   371  	authServiceURL  string
   372  	tokenServiceURL string
   373  	groups          map[string][]identity.Identity
   374  	realmData       map[string]*protocol.RealmData
   375  }
   376  
   377  func (db *fakeDB) IsAllowedOAuthClientID(ctx context.Context, email, clientID string) (bool, error) {
   378  	return clientID == db.allowedClientID, nil
   379  }
   380  
   381  func (db *fakeDB) IsInternalService(ctx context.Context, hostname string) (bool, error) {
   382  	return hostname == db.internalService, nil
   383  }
   384  
   385  func (db *fakeDB) IsMember(ctx context.Context, id identity.Identity, groups []string) (bool, error) {
   386  	for _, g := range groups {
   387  		for _, member := range db.groups[g] {
   388  			if id == member {
   389  				return true, nil
   390  			}
   391  		}
   392  	}
   393  	return false, nil
   394  }
   395  
   396  func (db *fakeDB) CheckMembership(ctx context.Context, id identity.Identity, groups []string) ([]string, error) {
   397  	panic("not implemented")
   398  }
   399  
   400  func (db *fakeDB) HasPermission(ctx context.Context, id identity.Identity, perm realms.Permission, realm string, attrs realms.Attrs) (bool, error) {
   401  	return false, errors.New("fakeDB: HasPermission is not implemented")
   402  }
   403  
   404  func (db *fakeDB) QueryRealms(ctx context.Context, id identity.Identity, perm realms.Permission, project string, attrs realms.Attrs) ([]string, error) {
   405  	return nil, errors.New("fakeDB: QueryRealms is not implemented")
   406  }
   407  
   408  func (db *fakeDB) FilterKnownGroups(ctx context.Context, groups []string) ([]string, error) {
   409  	return nil, errors.New("fakeDB: FilterKnownGroups is not implemented")
   410  }
   411  
   412  func (db *fakeDB) GetCertificates(ctx context.Context, id identity.Identity) (*signing.PublicCertificates, error) {
   413  	return nil, errors.New("fakeDB: GetCertificates is not implemented")
   414  }
   415  
   416  func (db *fakeDB) GetAllowlistForIdentity(ctx context.Context, ident identity.Identity) (string, error) {
   417  	return "", nil
   418  }
   419  
   420  func (db *fakeDB) IsAllowedIP(ctx context.Context, ip net.IP, allowlist string) (bool, error) {
   421  	return allowlist == "bots" && ip.String() == "1.2.3.4", nil
   422  }
   423  
   424  func (db *fakeDB) GetAuthServiceURL(ctx context.Context) (string, error) {
   425  	if db.authServiceURL == "" {
   426  		return "", errors.New("fakeDB: GetAuthServiceURL is not configured")
   427  	}
   428  	return db.authServiceURL, nil
   429  }
   430  
   431  func (db *fakeDB) GetTokenServiceURL(ctx context.Context) (string, error) {
   432  	if db.tokenServiceURL == "" {
   433  		return "", errors.New("fakeDB: GetTokenServiceURL is not configured")
   434  	}
   435  	return db.tokenServiceURL, nil
   436  }
   437  
   438  func (db *fakeDB) GetRealmData(ctx context.Context, realm string) (*protocol.RealmData, error) {
   439  	return db.realmData[realm], nil
   440  }