github.com/juju/juju@v0.0.0-20240430160146-1752b71fcf00/state/bakerystorage/rootkeys_test.go (about)

     1  // Copyright 2014-2022 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package bakerystorage
     5  
     6  import (
     7  	"context"
     8  	"fmt"
     9  	"time"
    10  
    11  	"github.com/go-macaroon-bakery/macaroon-bakery/v3/bakery"
    12  	"github.com/go-macaroon-bakery/macaroon-bakery/v3/bakery/dbrootkeystore"
    13  	"github.com/juju/mgo/v3"
    14  	mgotesting "github.com/juju/mgo/v3/testing"
    15  	gc "gopkg.in/check.v1"
    16  
    17  	"github.com/juju/juju/testing"
    18  )
    19  
    20  type RootKeySuite struct {
    21  	testing.BaseSuite
    22  	mgotesting.MgoSuite
    23  }
    24  
    25  var _ = gc.Suite(&RootKeySuite{})
    26  
    27  func (s *RootKeySuite) SetUpSuite(c *gc.C) {
    28  	s.MgoSuite.SetUpSuite(c)
    29  	s.BaseSuite.SetUpSuite(c)
    30  }
    31  
    32  func (s *RootKeySuite) TearDownSuite(c *gc.C) {
    33  	s.BaseSuite.TearDownSuite(c)
    34  	s.MgoSuite.TearDownSuite(c)
    35  }
    36  
    37  func (s *RootKeySuite) SetUpTest(c *gc.C) {
    38  	s.MgoSuite.SetUpTest(c)
    39  	s.BaseSuite.SetUpTest(c)
    40  }
    41  
    42  func (s *RootKeySuite) TearDownTest(c *gc.C) {
    43  	s.BaseSuite.TearDownTest(c)
    44  	s.MgoSuite.TearDownTest(c)
    45  }
    46  
    47  var epoch = time.Date(2000, time.January, 1, 0, 0, 0, 0, time.UTC)
    48  
    49  var isValidWithPolicyTests = []struct {
    50  	about  string
    51  	policy Policy
    52  	now    time.Time
    53  	key    dbrootkeystore.RootKey
    54  	expect bool
    55  }{{
    56  	about: "success",
    57  	policy: Policy{
    58  		GenerateInterval: 2 * time.Minute,
    59  		ExpiryDuration:   3 * time.Minute,
    60  	},
    61  	now: epoch.Add(20 * time.Minute),
    62  	key: dbrootkeystore.RootKey{
    63  		Created: epoch.Add(19 * time.Minute),
    64  		Expires: epoch.Add(24 * time.Minute),
    65  		Id:      []byte("id"),
    66  		RootKey: []byte("key"),
    67  	},
    68  	expect: true,
    69  }, {
    70  	about: "empty root key",
    71  	policy: Policy{
    72  		GenerateInterval: 2 * time.Minute,
    73  		ExpiryDuration:   3 * time.Minute,
    74  	},
    75  	now:    epoch.Add(20 * time.Minute),
    76  	key:    dbrootkeystore.RootKey{},
    77  	expect: false,
    78  }, {
    79  	about: "created too early",
    80  	policy: Policy{
    81  		GenerateInterval: 2 * time.Minute,
    82  		ExpiryDuration:   3 * time.Minute,
    83  	},
    84  	now: epoch.Add(20 * time.Minute),
    85  	key: dbrootkeystore.RootKey{
    86  		Created: epoch.Add(18*time.Minute - time.Millisecond),
    87  		Expires: epoch.Add(24 * time.Minute),
    88  		Id:      []byte("id"),
    89  		RootKey: []byte("key"),
    90  	},
    91  	expect: false,
    92  }, {
    93  	about: "expires too early",
    94  	policy: Policy{
    95  		GenerateInterval: 2 * time.Minute,
    96  		ExpiryDuration:   3 * time.Minute,
    97  	},
    98  	now: epoch.Add(20 * time.Minute),
    99  	key: dbrootkeystore.RootKey{
   100  		Created: epoch.Add(19 * time.Minute),
   101  		Expires: epoch.Add(21 * time.Minute),
   102  		Id:      []byte("id"),
   103  		RootKey: []byte("key"),
   104  	},
   105  	expect: false,
   106  }, {
   107  	about: "expires too late",
   108  	policy: Policy{
   109  		GenerateInterval: 2 * time.Minute,
   110  		ExpiryDuration:   3 * time.Minute,
   111  	},
   112  	now: epoch.Add(20 * time.Minute),
   113  	key: dbrootkeystore.RootKey{
   114  		Created: epoch.Add(19 * time.Minute),
   115  		Expires: epoch.Add(25*time.Minute + time.Millisecond),
   116  		Id:      []byte("id"),
   117  		RootKey: []byte("key"),
   118  	},
   119  	expect: false,
   120  }}
   121  
   122  func (s *RootKeySuite) TestIsValidWithPolicy(c *gc.C) {
   123  	var now time.Time
   124  	s.PatchValue(&clock, clockVal(&now))
   125  	for i, test := range isValidWithPolicyTests {
   126  		c.Logf("test %d: %v", i, test.about)
   127  		c.Assert(test.key.IsValidWithPolicy(dbrootkeystore.Policy(test.policy), test.now), gc.Equals, test.expect)
   128  	}
   129  }
   130  
   131  func (s *RootKeySuite) TestRootKeyUsesKeysValidWithPolicy(c *gc.C) {
   132  	// We re-use the TestIsValidWithPolicy tests so that we
   133  	// know that the mongo logic uses the same behaviour.
   134  	var now time.Time
   135  	s.PatchValue(&clock, clockVal(&now))
   136  	for _, test := range isValidWithPolicyTests {
   137  		if test.key.RootKey == nil {
   138  			// We don't store empty root keys in the database.
   139  			c.Log("skipping test with empty root key")
   140  			continue
   141  		}
   142  		coll := s.testColl(c)
   143  		// Prime the collection with the root key document.
   144  		err := coll.Insert(test.key)
   145  		c.Assert(err, gc.IsNil, gc.Commentf(test.about))
   146  
   147  		store := NewRootKeys(10).NewStore(coll, test.policy)
   148  		now = test.now
   149  		key, id, err := store.RootKey(context.Background())
   150  		c.Assert(err, gc.IsNil, gc.Commentf(test.about))
   151  		if test.expect {
   152  			c.Assert(string(id), gc.Equals, "id", gc.Commentf(test.about))
   153  			c.Assert(string(key), gc.Equals, "key", gc.Commentf(test.about))
   154  		} else {
   155  			// If it didn't match then RootKey will have
   156  			// generated a new key.
   157  			c.Assert(key, gc.HasLen, 24, gc.Commentf(test.about))
   158  			c.Assert(id, gc.HasLen, 32, gc.Commentf(test.about))
   159  		}
   160  		err = coll.DropCollection()
   161  		c.Assert(err, gc.IsNil, gc.Commentf(test.about))
   162  	}
   163  }
   164  
   165  func (s *RootKeySuite) TestRootKey(c *gc.C) {
   166  	now := epoch
   167  	s.PatchValue(&clock, clockVal(&now))
   168  	coll := s.testColl(c)
   169  
   170  	store := NewRootKeys(10).NewStore(coll, Policy{
   171  		GenerateInterval: 2 * time.Minute,
   172  		ExpiryDuration:   5 * time.Minute,
   173  	})
   174  	key, id, err := store.RootKey(context.Background())
   175  	c.Assert(err, gc.IsNil)
   176  	c.Assert(key, gc.HasLen, 24)
   177  	c.Assert(id, gc.HasLen, 32)
   178  
   179  	// If we get a key within the generate interval, we should
   180  	// get the same one.
   181  	now = epoch.Add(time.Minute)
   182  	key1, id1, err := store.RootKey(context.Background())
   183  	c.Assert(err, gc.IsNil)
   184  	c.Assert(key1, gc.DeepEquals, key)
   185  	c.Assert(id1, gc.DeepEquals, id)
   186  
   187  	// A different store instance should get the same root key.
   188  	store1 := NewRootKeys(10).NewStore(coll, Policy{
   189  		GenerateInterval: 2 * time.Minute,
   190  		ExpiryDuration:   5 * time.Minute,
   191  	})
   192  	key1, id1, err = store1.RootKey(context.Background())
   193  	c.Assert(err, gc.IsNil)
   194  	c.Assert(key1, gc.DeepEquals, key)
   195  	c.Assert(id1, gc.DeepEquals, id)
   196  
   197  	// After the generation interval has passed, we should generate a new key.
   198  	now = epoch.Add(2*time.Minute + time.Second)
   199  	key1, id1, err = store.RootKey(context.Background())
   200  	c.Assert(err, gc.IsNil)
   201  	c.Assert(key, gc.HasLen, 24)
   202  	c.Assert(id, gc.HasLen, 32)
   203  	c.Assert(key1, gc.Not(gc.DeepEquals), key)
   204  	c.Assert(id1, gc.Not(gc.DeepEquals), id)
   205  
   206  	// The other store should pick it up too.
   207  	key2, id2, err := store1.RootKey(context.Background())
   208  	c.Assert(err, gc.IsNil)
   209  	c.Assert(key2, gc.DeepEquals, key1)
   210  	c.Assert(id2, gc.DeepEquals, id1)
   211  }
   212  
   213  func (s *RootKeySuite) TestRootKeyDefaultGenerateInterval(c *gc.C) {
   214  	now := epoch
   215  	s.PatchValue(&clock, clockVal(&now))
   216  	coll := s.testColl(c)
   217  	store := NewRootKeys(10).NewStore(coll, Policy{
   218  		ExpiryDuration: 5 * time.Minute,
   219  	})
   220  	key, id, err := store.RootKey(context.Background())
   221  	c.Assert(err, gc.IsNil)
   222  
   223  	now = epoch.Add(5 * time.Minute)
   224  	key1, id1, err := store.RootKey(context.Background())
   225  	c.Assert(err, gc.IsNil)
   226  	c.Assert(key1, gc.DeepEquals, key)
   227  	c.Assert(id1, gc.DeepEquals, id)
   228  
   229  	now = epoch.Add(5*time.Minute + time.Millisecond)
   230  	key1, id1, err = store.RootKey(context.Background())
   231  	c.Assert(err, gc.IsNil)
   232  	c.Assert(string(key1), gc.Not(gc.Equals), string(key))
   233  	c.Assert(string(id1), gc.Not(gc.Equals), string(id))
   234  }
   235  
   236  var preferredRootKeyTests = []struct {
   237  	about    string
   238  	now      time.Time
   239  	keys     []dbrootkeystore.RootKey
   240  	policy   Policy
   241  	expectId []byte
   242  }{{
   243  	about: "latest creation time is preferred",
   244  	now:   epoch.Add(5 * time.Minute),
   245  	keys: []dbrootkeystore.RootKey{{
   246  		Created: epoch.Add(4 * time.Minute),
   247  		Expires: epoch.Add(15 * time.Minute),
   248  		Id:      []byte("id0"),
   249  		RootKey: []byte("key0"),
   250  	}, {
   251  		Created: epoch.Add(5*time.Minute + 30*time.Second),
   252  		Expires: epoch.Add(16 * time.Minute),
   253  		Id:      []byte("id1"),
   254  		RootKey: []byte("key1"),
   255  	}, {
   256  		Created: epoch.Add(5 * time.Minute),
   257  		Expires: epoch.Add(16 * time.Minute),
   258  		Id:      []byte("id2"),
   259  		RootKey: []byte("key2"),
   260  	}},
   261  	policy: Policy{
   262  		GenerateInterval: 5 * time.Minute,
   263  		ExpiryDuration:   7 * time.Minute,
   264  	},
   265  	expectId: []byte("id1"),
   266  }, {
   267  	about: "ineligible keys are exluded",
   268  	now:   epoch.Add(5 * time.Minute),
   269  	keys: []dbrootkeystore.RootKey{{
   270  		Created: epoch.Add(4 * time.Minute),
   271  		Expires: epoch.Add(15 * time.Minute),
   272  		Id:      []byte("id0"),
   273  		RootKey: []byte("key0"),
   274  	}, {
   275  		Created: epoch.Add(5 * time.Minute),
   276  		Expires: epoch.Add(16*time.Minute + 30*time.Second),
   277  		Id:      []byte("id1"),
   278  		RootKey: []byte("key1"),
   279  	}, {
   280  		Created: epoch.Add(6 * time.Minute),
   281  		Expires: epoch.Add(time.Hour),
   282  		Id:      []byte("id2"),
   283  		RootKey: []byte("key2"),
   284  	}},
   285  	policy: Policy{
   286  		GenerateInterval: 5 * time.Minute,
   287  		ExpiryDuration:   7 * time.Minute,
   288  	},
   289  	expectId: []byte("id1"),
   290  }}
   291  
   292  func (s *RootKeySuite) TestPreferredRootKeyFromDatabase(c *gc.C) {
   293  	var now time.Time
   294  	s.PatchValue(&clock, clockVal(&now))
   295  	for _, test := range preferredRootKeyTests {
   296  		coll := s.testColl(c)
   297  		for _, key := range test.keys {
   298  			err := coll.Insert(key)
   299  			c.Assert(err, gc.IsNil, gc.Commentf(test.about))
   300  		}
   301  		store := NewRootKeys(10).NewStore(coll, test.policy)
   302  		now = test.now
   303  		_, id, err := store.RootKey(context.Background())
   304  		c.Assert(err, gc.IsNil, gc.Commentf(test.about))
   305  		c.Assert(id, gc.DeepEquals, test.expectId, gc.Commentf(test.about))
   306  		err = coll.DropCollection()
   307  		c.Assert(err, gc.IsNil, gc.Commentf(test.about))
   308  	}
   309  }
   310  
   311  func (s *RootKeySuite) TestPreferredRootKeyFromCache(c *gc.C) {
   312  	var now time.Time
   313  	s.PatchValue(&clock, clockVal(&now))
   314  	for _, test := range preferredRootKeyTests {
   315  		coll := s.testColl(c)
   316  		for _, key := range test.keys {
   317  			err := coll.Insert(key)
   318  			c.Assert(err, gc.IsNil)
   319  		}
   320  		store := NewRootKeys(10).NewStore(coll, test.policy)
   321  		// Ensure that all the keys are in cache by getting all of them.
   322  		for _, key := range test.keys {
   323  			got, err := store.Get(context.Background(), key.Id)
   324  			c.Assert(err, gc.IsNil, gc.Commentf(test.about))
   325  			c.Assert(got, gc.DeepEquals, key.RootKey, gc.Commentf(test.about))
   326  		}
   327  		// Remove all the keys from the collection so that
   328  		// we know we must be acquiring them from the cache.
   329  		_, err := coll.RemoveAll(nil)
   330  		c.Assert(err, gc.IsNil, gc.Commentf(test.about))
   331  
   332  		// Test that RootKey returns the expected key.
   333  		now = test.now
   334  		_, id, err := store.RootKey(context.Background())
   335  		c.Assert(err, gc.IsNil, gc.Commentf(test.about))
   336  		c.Assert(id, gc.DeepEquals, test.expectId, gc.Commentf(test.about))
   337  		err = coll.DropCollection()
   338  		c.Assert(err, gc.IsNil, gc.Commentf(test.about))
   339  	}
   340  }
   341  
   342  func (s *RootKeySuite) TestGet(c *gc.C) {
   343  	now := epoch
   344  	s.PatchValue(&clock, clockVal(&now))
   345  
   346  	coll := s.testColl(c)
   347  	store := NewRootKeys(5).NewStore(coll, Policy{
   348  		GenerateInterval: 1 * time.Minute,
   349  		ExpiryDuration:   30 * time.Minute,
   350  	})
   351  	type idKey struct {
   352  		id  string
   353  		key []byte
   354  	}
   355  	var keys []idKey
   356  	keyIds := make(map[string]bool)
   357  	for i := 0; i < 20; i++ {
   358  		key, id, err := store.RootKey(context.Background())
   359  		c.Assert(err, gc.IsNil)
   360  		c.Assert(keyIds[string(id)], gc.Equals, false)
   361  		keys = append(keys, idKey{string(id), key})
   362  		now = now.Add(time.Minute + time.Second)
   363  	}
   364  	for i, k := range keys {
   365  		key, err := store.Get(context.Background(), []byte(k.id))
   366  		c.Assert(err, gc.IsNil, gc.Commentf("key %d (%s)", i, k.id))
   367  		c.Assert(key, gc.DeepEquals, k.key, gc.Commentf("key %d (%s)", i, k.id))
   368  	}
   369  	// Check that the keys are cached.
   370  	//
   371  	// Since the cache size is 5, the most recent 5 items will be in
   372  	// the primary cache; the 5 items before that will be in the old
   373  	// cache and nothing else will be cached.
   374  	//
   375  	// The first time we fetch an item from the old cache, a new
   376  	// primary cache will be allocated, all existing items in the
   377  	// old cache except that item will be evicted, and all items in
   378  	// the current primary cache moved to the old cache.
   379  	//
   380  	// The upshot of that is that all but the first 6 calls to Get
   381  	// should result in a database fetch.
   382  
   383  	var fetched []string
   384  	s.PatchValue(&mgoCollectionFindId, func(coll *mgo.Collection, id interface{}) *mgo.Query {
   385  		fetched = append(fetched, string(id.([]byte)))
   386  		return coll.FindId(id)
   387  	})
   388  	c.Logf("testing cache")
   389  
   390  	for i := len(keys) - 1; i >= 0; i-- {
   391  		k := keys[i]
   392  		key, err := store.Get(context.Background(), []byte(k.id))
   393  		c.Assert(err, gc.IsNil)
   394  		c.Assert(err, gc.IsNil, gc.Commentf("key %d (%s)", i, k.id))
   395  		c.Assert(key, gc.DeepEquals, k.key, gc.Commentf("key %d (%s)", i, k.id))
   396  	}
   397  	c.Assert(len(fetched), gc.Equals, len(keys)-6)
   398  	for i, id := range fetched {
   399  		c.Assert(id, gc.Equals, keys[len(keys)-6-i-1].id)
   400  	}
   401  }
   402  
   403  func (s *RootKeySuite) TestGetCachesMisses(c *gc.C) {
   404  	coll := s.testColl(c)
   405  	store := NewRootKeys(5).NewStore(coll, Policy{
   406  		GenerateInterval: 1 * time.Minute,
   407  		ExpiryDuration:   30 * time.Minute,
   408  	})
   409  	var fetched []string
   410  	s.PatchValue(&mgoCollectionFindId, func(coll *mgo.Collection, id interface{}) *mgo.Query {
   411  		fetched = append(fetched, fmt.Sprintf("%#v", id))
   412  		return coll.FindId(id)
   413  	})
   414  	key, err := store.Get(context.Background(), []byte("foo"))
   415  	c.Assert(err, gc.Equals, bakery.ErrNotFound)
   416  	c.Assert(key, gc.IsNil)
   417  	// This should check twice first using a []byte second using a string
   418  	c.Assert(fetched, gc.DeepEquals, []string{fmt.Sprintf("%#v", []byte("foo")), fmt.Sprintf("%#v", "foo")})
   419  	fetched = nil
   420  
   421  	key, err = store.Get(context.Background(), []byte("foo"))
   422  	c.Assert(err, gc.Equals, bakery.ErrNotFound)
   423  	c.Assert(key, gc.IsNil)
   424  	c.Assert(fetched, gc.IsNil)
   425  }
   426  
   427  func (s *RootKeySuite) TestGetExpiredItemFromCache(c *gc.C) {
   428  	now := epoch
   429  	s.PatchValue(&clock, clockVal(&now))
   430  	coll := s.testColl(c)
   431  	store := NewRootKeys(10).NewStore(coll, Policy{
   432  		ExpiryDuration: 5 * time.Minute,
   433  	})
   434  	_, id, err := store.RootKey(context.Background())
   435  	c.Assert(err, gc.IsNil)
   436  
   437  	s.PatchValue(&mgoCollectionFindId, func(*mgo.Collection, interface{}) *mgo.Query {
   438  		c.Errorf("FindId unexpectedly called")
   439  		return nil
   440  	})
   441  
   442  	now = epoch.Add(15 * time.Minute)
   443  
   444  	_, err = store.Get(context.Background(), id)
   445  	c.Assert(err, gc.Equals, bakery.ErrNotFound)
   446  }
   447  
   448  func (s *RootKeySuite) TestEnsureIndex(c *gc.C) {
   449  	keys := NewRootKeys(5)
   450  	coll := s.testColl(c)
   451  	err := keys.EnsureIndex(coll)
   452  	c.Assert(err, gc.IsNil)
   453  
   454  	// This code can take up to 60s to run; there's no way
   455  	// to force it to run more quickly, but it provides reassurance
   456  	// that the code actually works.
   457  	// Reenable the rest of this test if concerned about index behaviour.
   458  
   459  	c.Skip("test runs too slowly")
   460  
   461  	_, id1, err := keys.NewStore(coll, Policy{
   462  		ExpiryDuration: 100 * time.Millisecond,
   463  	}).RootKey(context.Background())
   464  
   465  	c.Assert(err, gc.IsNil)
   466  
   467  	_, id2, err := keys.NewStore(coll, Policy{
   468  		ExpiryDuration: time.Hour,
   469  	}).RootKey(context.Background())
   470  
   471  	c.Assert(err, gc.IsNil)
   472  	c.Assert(id2, gc.Not(gc.Equals), id1)
   473  
   474  	// Sanity check that the keys are in the collection.
   475  	n, err := coll.Find(nil).Count()
   476  	c.Assert(err, gc.IsNil)
   477  	c.Assert(n, gc.Equals, 2)
   478  	for i := 0; i < 100; i++ {
   479  		n, err := coll.Find(nil).Count()
   480  		c.Assert(err, gc.IsNil)
   481  		switch n {
   482  		case 1:
   483  			return
   484  		case 2:
   485  			time.Sleep(time.Second)
   486  		default:
   487  			c.Fatalf("unexpected key count %v", n)
   488  		}
   489  	}
   490  	c.Fatalf("key was never removed from database")
   491  }
   492  
   493  type legacyRootKeyDoc struct {
   494  	Id      string `bson:"_id"`
   495  	Created time.Time
   496  	Expires time.Time
   497  	RootKey []byte
   498  }
   499  
   500  func (s *RootKeySuite) TestLegacy(c *gc.C) {
   501  	coll := s.testColl(c)
   502  	err := coll.Insert(&legacyRootKeyDoc{
   503  		Id:      "foo",
   504  		RootKey: []byte("a key"),
   505  		Created: time.Now(),
   506  		Expires: time.Now().Add(10 * time.Minute),
   507  	})
   508  	c.Assert(err, gc.IsNil)
   509  	store := NewRootKeys(10).NewStore(coll, Policy{
   510  		ExpiryDuration: 5 * time.Minute,
   511  	})
   512  	rk, err := store.Get(context.Background(), []byte("foo"))
   513  	c.Assert(err, gc.IsNil)
   514  	c.Assert(string(rk), gc.Equals, "a key")
   515  }
   516  
   517  func (s *RootKeySuite) TestUsesSessionFromContext(c *gc.C) {
   518  	coll := s.testColl(c)
   519  
   520  	s1 := coll.Database.Session.Copy()
   521  	s2 := coll.Database.Session.Copy()
   522  	s.AddCleanup(func(c *gc.C) {
   523  		s2.Close()
   524  	})
   525  
   526  	coll = coll.With(s1)
   527  	store := NewRootKeys(10).NewStore(coll, Policy{
   528  		ExpiryDuration: 5 * time.Minute,
   529  	})
   530  	s1.Close()
   531  
   532  	ctx := ContextWithMgoSession(context.Background(), s2)
   533  	_, _, err := store.RootKey(ctx)
   534  	c.Assert(err, gc.Equals, nil)
   535  }
   536  
   537  func (s *RootKeySuite) TestDoneContext(c *gc.C) {
   538  	store := NewRootKeys(10).NewStore(s.testColl(c), Policy{
   539  		ExpiryDuration: 5 * time.Minute,
   540  	})
   541  
   542  	ctx, cancel := context.WithCancel(context.Background())
   543  	cancel()
   544  	_, _, err := store.RootKey(ctx)
   545  	c.Assert(err, gc.ErrorMatches, `cannot query existing keys: context canceled`)
   546  }
   547  
   548  func (s *RootKeySuite) testColl(c *gc.C) *mgo.Collection {
   549  	return s.Session.DB("test").C("rootkeyitems")
   550  }
   551  
   552  func clockVal(t *time.Time) dbrootkeystore.Clock {
   553  	return clockFunc(func() time.Time {
   554  		return *t
   555  	})
   556  }
   557  
   558  type clockFunc func() time.Time
   559  
   560  func (f clockFunc) Now() time.Time {
   561  	return f()
   562  }