go.chromium.org/luci@v0.0.0-20240309015107-7cdc2e660f33/server/auth/authtest/db.go (about)

     1  // Copyright 2015 The LUCI Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package authtest
    16  
    17  import (
    18  	"context"
    19  	"fmt"
    20  	"math/rand"
    21  	"net"
    22  	"sync"
    23  
    24  	"go.chromium.org/luci/auth/identity"
    25  	"go.chromium.org/luci/common/data/stringset"
    26  	"go.chromium.org/luci/common/errors"
    27  
    28  	"go.chromium.org/luci/server/auth"
    29  	"go.chromium.org/luci/server/auth/authdb"
    30  	"go.chromium.org/luci/server/auth/realms"
    31  	"go.chromium.org/luci/server/auth/service/protocol"
    32  	"go.chromium.org/luci/server/auth/signing"
    33  )
    34  
    35  // FakeDB implements authdb.DB by mocking membership and permission checks.
    36  //
    37  // Initialize it with a bunch of mocks like:
    38  //
    39  // db := authtest.NewFakeDB(
    40  //
    41  //	authtest.MockMembership("user:a@example.com", "group"),
    42  //	authtest.MockPermission("user:a@example.com", "proj:realm", perm),
    43  //	...
    44  //
    45  // )
    46  //
    47  // The list of mocks can also be extended later via db.AddMocks(...).
    48  type FakeDB struct {
    49  	m         sync.RWMutex
    50  	err       error                              // if not nil, return this error
    51  	perID     map[identity.Identity]*mockedForID // id => groups and perms it has
    52  	ips       map[string]stringset.Set           // IP => allowlists it belongs to
    53  	realmData map[string]*protocol.RealmData     // realm name => data
    54  	groups    stringset.Set                      // groups mentioned by the mocks
    55  }
    56  
    57  var _ authdb.DB = (*FakeDB)(nil)
    58  
    59  // Condition evaluates attributes passed to HasPermission and decides if the
    60  // permission should apply.
    61  //
    62  // Used for mocking conditional bindings.
    63  type Condition func(realms.Attrs) bool
    64  
    65  // RestrictAttribute produces a Condition that check the given attribute has any
    66  // of the given values.
    67  //
    68  // Its logic matches AttributeRestriction condition in the RealmsDB.
    69  func RestrictAttribute(attr string, vals ...string) Condition {
    70  	set := stringset.NewFromSlice(vals...)
    71  	return func(attrs realms.Attrs) bool {
    72  		val, ok := attrs[attr]
    73  		return ok && set.Has(val)
    74  	}
    75  }
    76  
    77  // mockedForID is mocked groups and permissions of some identity.
    78  type mockedForID struct {
    79  	groups stringset.Set // a set of group names
    80  	perms  []mockedPerm
    81  }
    82  
    83  // mockedPerm is a single permission of a single identity.
    84  type mockedPerm struct {
    85  	realm string
    86  	perm  realms.Permission
    87  	cond  Condition
    88  }
    89  
    90  // MockedDatum is a return value of various Mock* constructors.
    91  type MockedDatum struct {
    92  	// apply mutates the db to apply the mock, called under the write lock.
    93  	apply func(db *FakeDB)
    94  }
    95  
    96  // MockMembership modifies db to make IsMember(id, group) == true.
    97  func MockMembership(id identity.Identity, group string) MockedDatum {
    98  	return MockedDatum{
    99  		apply: func(db *FakeDB) {
   100  			db.addGroup(group)
   101  			db.mockedForID(id).groups.Add(group)
   102  		},
   103  	}
   104  }
   105  
   106  // MockGroup adds a group (potentially empty) to the fake DB.
   107  func MockGroup(group string, ids []identity.Identity) MockedDatum {
   108  	return MockedDatum{
   109  		apply: func(db *FakeDB) {
   110  			db.addGroup(group)
   111  			for _, id := range ids {
   112  				db.mockedForID(id).groups.Add(group)
   113  			}
   114  		},
   115  	}
   116  }
   117  
   118  // MockPermission modifies db to make HasPermission(id, realm, perm, …) == true.
   119  //
   120  // Panics if `realm` is not a valid globally scoped realm, i.e. it doesn't look
   121  // like "<project>:<realm>".
   122  //
   123  // Optional `conds` allow mocking conditional bindings by defining a condition
   124  // on realms.Attrs that must evaluate to true to allow this permission. Multiple
   125  // `conds` callbacks are AND'ed together to get the final verdict.
   126  func MockPermission(id identity.Identity, realm string, perm realms.Permission, conds ...Condition) MockedDatum {
   127  	if err := realms.ValidateRealmName(realm, realms.GlobalScope); err != nil {
   128  		panic(err)
   129  	}
   130  	return MockedDatum{
   131  		apply: func(db *FakeDB) {
   132  			perID := db.mockedForID(id)
   133  			perID.perms = append(perID.perms, mockedPerm{
   134  				realm: realm,
   135  				perm:  perm,
   136  				cond: func(attrs realms.Attrs) bool {
   137  					for _, cond := range conds {
   138  						if !cond(attrs) {
   139  							return false
   140  						}
   141  					}
   142  					return true
   143  				},
   144  			})
   145  		},
   146  	}
   147  }
   148  
   149  // MockRealmData modifies what db's GetRealmData returns.
   150  //
   151  // Panics if `realm` is not a valid globally scoped realm, i.e. it doesn't look
   152  // like "<project>:<realm>".
   153  func MockRealmData(realm string, data *protocol.RealmData) MockedDatum {
   154  	if err := realms.ValidateRealmName(realm, realms.GlobalScope); err != nil {
   155  		panic(err)
   156  	}
   157  	return MockedDatum{
   158  		apply: func(db *FakeDB) {
   159  			if db.realmData == nil {
   160  				db.realmData = make(map[string]*protocol.RealmData, 1)
   161  			}
   162  			db.realmData[realm] = data
   163  		},
   164  	}
   165  }
   166  
   167  // MockIPAllowlist modifies db to make IsAllowedIP(ip, allowlist) == true.
   168  //
   169  // Panics if `ip` is not a valid IP address.
   170  func MockIPAllowlist(ip, allowlist string) MockedDatum {
   171  	if net.ParseIP(ip) == nil {
   172  		panic(fmt.Sprintf("%q is not a valid IP address", ip))
   173  	}
   174  	return MockedDatum{
   175  		apply: func(db *FakeDB) {
   176  			l, ok := db.ips[ip]
   177  			if !ok {
   178  				l = stringset.New(1)
   179  				if db.ips == nil {
   180  					db.ips = make(map[string]stringset.Set, 1)
   181  				}
   182  				db.ips[ip] = l
   183  			}
   184  			l.Add(allowlist)
   185  		},
   186  	}
   187  }
   188  
   189  // MockError modifies db to make its methods return this error.
   190  //
   191  // `err` may be nil, in which case the previously mocked error is removed.
   192  func MockError(err error) MockedDatum {
   193  	return MockedDatum{
   194  		apply: func(db *FakeDB) { db.err = err },
   195  	}
   196  }
   197  
   198  // NewFakeDB creates a FakeDB populated with the given mocks.
   199  //
   200  // Construct mocks using MockMembership, MockPermission, MockIPAllowlist and
   201  // MockError functions.
   202  func NewFakeDB(mocks ...MockedDatum) *FakeDB {
   203  	db := &FakeDB{}
   204  	db.AddMocks(mocks...)
   205  	return db
   206  }
   207  
   208  // AddMocks applies a bunch of mocks to the state in the db.
   209  func (db *FakeDB) AddMocks(mocks ...MockedDatum) {
   210  	db.m.Lock()
   211  	defer db.m.Unlock()
   212  	for _, m := range mocks {
   213  		m.apply(db)
   214  	}
   215  }
   216  
   217  // Use installs the fake db into the context.
   218  //
   219  // Note that if you use auth.WithState(ctx, &authtest.FakeState{...}), you don't
   220  // need this method. Modify FakeDB in the FakeState instead. See its doc for
   221  // some examples.
   222  func (db *FakeDB) Use(ctx context.Context) context.Context {
   223  	return auth.ModifyConfig(ctx, func(cfg auth.Config) auth.Config {
   224  		cfg.DBProvider = func(context.Context) (authdb.DB, error) {
   225  			return db, nil
   226  		}
   227  		return cfg
   228  	})
   229  }
   230  
   231  // IsMember is part of authdb.DB interface.
   232  func (db *FakeDB) IsMember(ctx context.Context, id identity.Identity, groups []string) (bool, error) {
   233  	hits, err := db.CheckMembership(ctx, id, groups)
   234  	if err != nil {
   235  		return false, err
   236  	}
   237  	return len(hits) > 0, nil
   238  }
   239  
   240  // CheckMembership is part of authdb.DB interface.
   241  func (db *FakeDB) CheckMembership(ctx context.Context, id identity.Identity, groups []string) (out []string, err error) {
   242  	db.m.RLock()
   243  	defer db.m.RUnlock()
   244  
   245  	if db.err != nil {
   246  		return nil, db.err
   247  	}
   248  
   249  	if mocked := db.perID[id]; mocked != nil {
   250  		for _, group := range groups {
   251  			if mocked.groups.Has(group) {
   252  				out = append(out, group)
   253  			}
   254  		}
   255  	}
   256  
   257  	return
   258  }
   259  
   260  // HasPermission is part of authdb.DB interface.
   261  func (db *FakeDB) HasPermission(ctx context.Context, id identity.Identity, perm realms.Permission, realm string, attrs realms.Attrs) (bool, error) {
   262  	// This flips a flag forbidding registration of new permissions. Presumably
   263  	// this should help catching "dynamic" permission registration in tests,
   264  	// before it panics in production.
   265  	realms.ForbidPermissionChanges()
   266  
   267  	db.m.RLock()
   268  	defer db.m.RUnlock()
   269  
   270  	if db.err != nil {
   271  		return false, db.err
   272  	}
   273  
   274  	if mocked := db.perID[id]; mocked != nil {
   275  		for _, mockedPerm := range mocked.perms {
   276  			if mockedPerm.realm == realm && mockedPerm.perm == perm && mockedPerm.cond(attrs) {
   277  				return true, nil
   278  			}
   279  		}
   280  	}
   281  
   282  	return false, nil
   283  }
   284  
   285  // QueryRealms is part of authdb.DB interface.
   286  func (db *FakeDB) QueryRealms(ctx context.Context, id identity.Identity, perm realms.Permission, project string, attrs realms.Attrs) ([]string, error) {
   287  	// This implicitly flips a flag forbidding registration of new permissions.
   288  	// Presumably this should help catching "dynamic" permission registration
   289  	// in tests, before it panics in production. We also need the result to check
   290  	// UsedInQueryRealms flag.
   291  	flags := realms.RegisteredPermissions()
   292  
   293  	db.m.RLock()
   294  	defer db.m.RUnlock()
   295  
   296  	if db.err != nil {
   297  		return nil, db.err
   298  	}
   299  
   300  	if project != "" {
   301  		if err := realms.ValidateProjectName(project); err != nil {
   302  			return nil, err
   303  		}
   304  	}
   305  
   306  	if flags[perm]&realms.UsedInQueryRealms == 0 {
   307  		return nil, errors.Reason("permission %s cannot be used in QueryRealms: it was not flagged with UsedInQueryRealms flag", perm).Err()
   308  	}
   309  
   310  	var out []string
   311  	if mocked := db.perID[id]; mocked != nil {
   312  		for _, mockedPerm := range mocked.perms {
   313  			if mockedPerm.perm == perm && mockedPerm.cond(attrs) {
   314  				if realmProj, _ := realms.Split(mockedPerm.realm); project == "" || project == realmProj {
   315  					out = append(out, mockedPerm.realm)
   316  				}
   317  			}
   318  		}
   319  	}
   320  
   321  	// The result in production in inherently unordered, since ordering it takes
   322  	// time and in many applications the order doesn't really matter, so always
   323  	// sorting it is wasteful. Simulate this behavior in tests too. If callers of
   324  	// QueryRealms want the result ordered, they should sort it themselves.
   325  	rand.Shuffle(len(out), func(i, j int) { out[i], out[j] = out[j], out[i] })
   326  
   327  	return out, nil
   328  }
   329  
   330  // FilterKnownGroups is part of authdb.DB interface.
   331  func (db *FakeDB) FilterKnownGroups(ctx context.Context, groups []string) ([]string, error) {
   332  	db.m.RLock()
   333  	defer db.m.RUnlock()
   334  
   335  	if db.err != nil {
   336  		return nil, db.err
   337  	}
   338  
   339  	var filtered []string
   340  	for _, gr := range groups {
   341  		if db.groups.Has(gr) {
   342  			filtered = append(filtered, gr)
   343  		}
   344  	}
   345  	return filtered, nil
   346  }
   347  
   348  // IsAllowedOAuthClientID is part of authdb.DB interface.
   349  func (db *FakeDB) IsAllowedOAuthClientID(ctx context.Context, email, clientID string) (bool, error) {
   350  	return true, nil
   351  }
   352  
   353  // IsInternalService is part of authdb.DB interface.
   354  func (db *FakeDB) IsInternalService(ctx context.Context, hostname string) (bool, error) {
   355  	return false, nil
   356  }
   357  
   358  // GetCertificates is part of authdb.DB interface.
   359  func (db *FakeDB) GetCertificates(ctx context.Context, id identity.Identity) (*signing.PublicCertificates, error) {
   360  	return nil, fmt.Errorf("GetCertificates is not implemented by FakeDB")
   361  }
   362  
   363  // GetAllowlistForIdentity is part of authdb.DB interface.
   364  func (db *FakeDB) GetAllowlistForIdentity(ctx context.Context, ident identity.Identity) (string, error) {
   365  	return "", nil
   366  }
   367  
   368  // IsAllowedIP is part of authdb.DB interface.
   369  func (db *FakeDB) IsAllowedIP(ctx context.Context, ip net.IP, allowlist string) (bool, error) {
   370  	db.m.RLock()
   371  	defer db.m.RUnlock()
   372  	if db.err != nil {
   373  		return false, db.err
   374  	}
   375  	return db.ips[ip.String()].Has(allowlist), nil
   376  }
   377  
   378  // GetAuthServiceURL is part of authdb.DB interface.
   379  func (db *FakeDB) GetAuthServiceURL(ctx context.Context) (string, error) {
   380  	return "", fmt.Errorf("GetAuthServiceURL is not implemented by FakeDB")
   381  }
   382  
   383  // GetTokenServiceURL is part of authdb.DB interface.
   384  func (db *FakeDB) GetTokenServiceURL(ctx context.Context) (string, error) {
   385  	return "", fmt.Errorf("GetTokenServiceURL is not implemented by FakeDB")
   386  }
   387  
   388  // GetRealmData is part of authdb.DB interface.
   389  func (db *FakeDB) GetRealmData(ctx context.Context, realm string) (*protocol.RealmData, error) {
   390  	db.m.RLock()
   391  	defer db.m.RUnlock()
   392  	return db.realmData[realm], nil
   393  }
   394  
   395  // addGroup adds a group to the list of known groups.
   396  func (db *FakeDB) addGroup(group string) {
   397  	if db.groups == nil {
   398  		db.groups = stringset.New(1)
   399  	}
   400  	db.groups.Add(group)
   401  }
   402  
   403  // mockedForID returns db.perID[id], initializing it if necessary.
   404  //
   405  // Called under the write lock.
   406  func (db *FakeDB) mockedForID(id identity.Identity) *mockedForID {
   407  	m, ok := db.perID[id]
   408  	if !ok {
   409  		m = &mockedForID{groups: stringset.New(0)}
   410  		if db.perID == nil {
   411  			db.perID = make(map[identity.Identity]*mockedForID, 1)
   412  		}
   413  		db.perID[id] = m
   414  	}
   415  	return m
   416  }