go.chromium.org/luci@v0.0.0-20240309015107-7cdc2e660f33/auth/internal/disk_cache.go

// Copyright 2017 The LUCI Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package internal

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"os"
	"path/filepath"
	"sort"
	"time"

	"golang.org/x/oauth2"

	"go.chromium.org/luci/common/clock"
	"go.chromium.org/luci/common/data/stringset"
	"go.chromium.org/luci/common/logging"
	"go.chromium.org/luci/common/retry"
	"go.chromium.org/luci/common/retry/transient"
)

const (
	// GCAccessTokenMaxAge defines when to remove unused access tokens from the
	// disk cache.
	//
	// We define "an access token" as an instance of oauth2.Token with
	// RefreshToken set to "".
	//
	// If an access token expired more than GCAccessTokenMaxAge ago, it is
	// evicted from the cache (it is essentially garbage now anyway).
	GCAccessTokenMaxAge = 2 * time.Hour

	// GCRefreshTokenMaxAge defines when to remove unused refresh tokens from the
	// disk cache.
	//
	// We define "a refresh token" as an instance of oauth2.Token with
	// RefreshToken not set to "".
	//
	// Refresh tokens don't expire, but they get neglected and forgotten by
	// users, staying on their disks forever. We remove such tokens if they
	// haven't been used for more than two weeks.
	//
	// This essentially logs out the user on inactivity. We don't actively revoke
	// evicted tokens though, since it's possible the user has copied the token
	// and uses it elsewhere (it happens). Such a token can always be revoked
	// manually from the Google Accounts page, if required.
	GCRefreshTokenMaxAge = 14 * 24 * time.Hour
)
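
// The helper below is an illustrative sketch, not part of the original file:
// it applies the classification above (an oauth2.Token with RefreshToken set
// to "" is "an access token") to pick which GC age applies. The function name
// is hypothetical.
func gcMaxAgeSketch(t *oauth2.Token) time.Duration {
	if t.RefreshToken == "" {
		return GCAccessTokenMaxAge // access tokens become garbage quickly
	}
	return GCRefreshTokenMaxAge // refresh tokens survive two weeks of disuse
}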

// DiskTokenCache implements TokenCache on top of a file.
//
// It uses a single file to store all tokens. If multiple processes try to
// write to it at the same time, only one process wins (so some updates may be
// lost).
//
// TODO(vadimsh): Either use file locking or split the cache into multiple
// files to avoid concurrency issues.
//
// TODO(vadimsh): Once this implementation settles and is deployed everywhere,
// add a cleanup step that removes <cache_dir>/*.tok left from the previous
// version of this code.
type DiskTokenCache struct {
	Context    context.Context // for logging and timing
	SecretsDir string
}
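
// exampleRoundTrip is an illustrative sketch, not part of the original file:
// a minimal Put/Get/Delete round trip through DiskTokenCache. The key and the
// token values here are hypothetical; real keys are constructed elsewhere in
// this package.
func exampleRoundTrip(ctx context.Context) error {
	cache := &DiskTokenCache{Context: ctx, SecretsDir: "/path/to/secrets"}
	key := &CacheKey{Key: "illustrative-key"}

	// Store a dummy token.
	err := cache.PutToken(key, &Token{
		Token: oauth2.Token{AccessToken: "sekret", Expiry: time.Now().Add(time.Hour)},
		Email: "someone@example.com",
	})
	if err != nil {
		return err
	}

	// Read it back. A (nil, nil) result would mean a cache miss.
	tok, err := cache.GetToken(key)
	if err != nil {
		return err
	}
	if tok == nil {
		return fmt.Errorf("unexpected cache miss")
	}

	// Evict it.
	return cache.DeleteToken(key)
}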

type cacheFile struct {
	Cache      []*cacheFileEntry `json:"cache"`
	LastUpdate time.Time         `json:"last_update"`
}

// cacheFileEntry holds one set of cached tokens.
//
// Implements custom JSON marshaling logic to round-trip unknown fields. This
// is useful when new fields are added by newer code, but the token cache is
// still used by older code. Extra fields are better left untouched by the
// older code.
type cacheFileEntry struct {
	key        CacheKey
	token      oauth2.Token
	idToken    string
	email      string
	lastUpdate time.Time

	extra map[string]*json.RawMessage
}

type keyPtr struct {
	key string
	ptr any
}

func (e *cacheFileEntry) structure() []keyPtr {
	return []keyPtr{
		{"key", &e.key},
		{"token", &e.token},
		{"id_token", &e.idToken},
		{"email", &e.email},
		{"last_update", &e.lastUpdate},
	}
}

func (e *cacheFileEntry) UnmarshalJSON(data []byte) error {
	*e = cacheFileEntry{extra: make(map[string]*json.RawMessage)}
	if err := json.Unmarshal(data, &e.extra); err != nil {
		return err
	}
	for _, kv := range e.structure() {
		if raw := e.extra[kv.key]; raw != nil {
			delete(e.extra, kv.key)
			if err := json.Unmarshal([]byte(*raw), kv.ptr); err != nil {
				return fmt.Errorf("when JSON decoding %q - %s", kv.key, err)
			}
		}
	}
	return nil
}

func (e *cacheFileEntry) MarshalJSON() ([]byte, error) {
	// Note: this way of marshaling preserves the order of keys per structure().
	// All unrecognized extra keys are placed at the end, sorted.

	fields := e.structure()
	if len(e.extra) != 0 {
		l := len(fields)
		for k, v := range e.extra {
			fields = append(fields, keyPtr{k, v})
		}
		extra := fields[l:]
		sort.Slice(extra, func(i, j int) bool { return extra[i].key < extra[j].key })
	}

	out := bytes.Buffer{}
	out.WriteString("{")

	first := true
	for _, kv := range fields {
		if !first {
			out.WriteString(",")
		}
		first = false
		fmt.Fprintf(&out, "%q:", kv.key)
		// Note: Encode appends a trailing newline after each value. That's
		// harmless interstitial whitespace in JSON.
		if err := json.NewEncoder(&out).Encode(kv.ptr); err != nil {
			return nil, fmt.Errorf("when JSON encoding %q - %s", kv.key, err)
		}
	}

	out.WriteString("}")
	return out.Bytes(), nil
}
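
// exampleUnknownFields is an illustrative sketch, not part of the original
// file: it demonstrates that fields unknown to this version of the code
// survive an Unmarshal/Marshal round trip via the `extra` map.
func exampleUnknownFields() error {
	in := []byte(`{"email": "someone@example.com", "future_field": 123}`)
	var e cacheFileEntry
	if err := json.Unmarshal(in, &e); err != nil {
		return err
	}
	out, err := json.Marshal(&e)
	if err != nil {
		return err
	}
	// `out` contains the known fields in structure() order, followed by
	// "future_field":123 at the end.
	fmt.Printf("%s\n", out)
	return nil
}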

func (e *cacheFileEntry) isOld(now time.Time) bool {
	delay := GCAccessTokenMaxAge
	if e.token.RefreshToken != "" {
		delay = GCRefreshTokenMaxAge
	}
	exp := e.token.Expiry
	if exp.IsZero() {
		exp = e.lastUpdate
	}
	// Round(0) strips the monotonic clock reading, so the comparison uses wall
	// clock time only.
	return now.Sub(exp.Round(0)) >= delay
}
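
// Illustrative check of the arithmetic above, not part of the original file:
// an access token (no RefreshToken) that expired exactly GCAccessTokenMaxAge
// ago is considered old, while a refresh token with no expiry is judged by
// its lastUpdate time instead.
func exampleIsOld() {
	now := time.Date(2024, time.March, 1, 12, 0, 0, 0, time.UTC)

	accessTok := &cacheFileEntry{
		token: oauth2.Token{Expiry: now.Add(-GCAccessTokenMaxAge)}, // expired at 10:00
	}
	fmt.Println(accessTok.isOld(now)) // true: 2h elapsed >= GCAccessTokenMaxAge

	refreshTok := &cacheFileEntry{
		token:      oauth2.Token{RefreshToken: "rt", AccessToken: "at"},
		lastUpdate: now.Add(-24 * time.Hour), // used yesterday
	}
	fmt.Println(refreshTok.isOld(now)) // false: well under GCRefreshTokenMaxAge
}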

func (c *DiskTokenCache) legacyPath() string {
	return filepath.Join(c.SecretsDir, "creds.json")
}

func (c *DiskTokenCache) tokensPath() string {
	return filepath.Join(c.SecretsDir, "tokens.json")
}

// readCacheFile loads the file with cached tokens.
func (c *DiskTokenCache) readCacheFile(path string) (*cacheFile, error) {
	// Minimize the time the file is locked on Windows by reading it all at once
	// and decoding later.
	//
	// We also need to open it with FILE_SHARE_DELETE sharing mode to allow
	// writeCacheFile() below to replace open files (even though it tries to wait
	// for the file to be closed). For some reason, omitting the FILE_SHARE_DELETE
	// flag causes random sharing violation errors when opening the file for
	// reading.
	f, err := openSharedDelete(path)
	switch {
	case os.IsNotExist(err):
		return &cacheFile{}, nil
	case err != nil:
		return nil, err
	}
	blob, err := io.ReadAll(f)
	f.Close()
	if err != nil {
		return nil, err
	}

	cache := &cacheFile{}
	if err := json.Unmarshal(blob, cache); err != nil {
		// If the cache file got broken somehow, it makes sense to treat it as
		// empty (so it can later be overwritten), since it's unlikely it's going
		// to "fix itself".
		logging.Warningf(c.Context, "The token cache %s is broken: %s", path, err)
		return &cacheFile{}, nil
	}

	return cache, nil
}

// writeCacheFile overwrites the file with cached tokens.
//
// Returns a transient error if the file is locked by some other process and
// can't be updated (this happens on Windows).
func (c *DiskTokenCache) writeCacheFile(path string, cache *cacheFile) error {
	// Nothing left? Remove the file completely.
	if len(cache.Cache) == 0 {
		if err := os.Remove(path); err != nil && !os.IsNotExist(err) {
			return err
		}
		return nil
	}

	blob, err := json.MarshalIndent(cache, "", "  ")
	if err != nil {
		return err
	}

	// Write to a temp file first.
	if err := os.MkdirAll(c.SecretsDir, 0700); err != nil {
		logging.WithError(err).Warningf(c.Context, "Failed to mkdir token cache dir")
		// Carry on, CreateTemp will fail too.
	}
	tmp, err := os.CreateTemp(c.SecretsDir, "tokens.json.*")
	if err != nil {
		return err
	}

	cleanup := func() {
		if err := os.Remove(tmp.Name()); err != nil {
			logging.WithError(err).Warningf(c.Context, "Failed to remove temp creds cache file: %s", tmp.Name())
		}
	}

	_, writeErr := tmp.Write(blob)
	closeErr := tmp.Close()
	switch {
	case writeErr != nil:
		err = writeErr
	case closeErr != nil:
		err = closeErr
	}
	if err != nil {
		cleanup()
		return err
	}

	// Note that CreateTemp creates the file in 0600 mode already, so we don't
	// need to chmod it.
	//
	// On Windows Rename may fail with a sharing violation error if some other
	// process has the file open. We treat it as a transient error, to trigger
	// a retry in updateCache.
	if err = os.Rename(tmp.Name(), path); err != nil {
		cleanup()
		return transient.Tag.Apply(err)
	}
	return nil
}
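
// exampleTransientCheck is an illustrative sketch, not part of the original
// file: a caller bypassing updateCache's retry loop could test whether a
// writeCacheFile failure is retriable via the transient tag applied above.
func exampleTransientCheck(c *DiskTokenCache, cache *cacheFile) error {
	err := c.writeCacheFile(c.tokensPath(), cache)
	if err != nil && transient.Tag.In(err) {
		// e.g. a Windows sharing violation during Rename; safe to retry later.
		logging.Warningf(c.Context, "Transient token cache write failure: %s", err)
	}
	return err
}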

// updateCache reads the token cache files, calls the callback and writes the
// files back if the callback returns true.
//
// It retries a bunch of times when encountering sharing violation errors on
// Windows.
//
// Mutates tokens.json and creds.json. tokens.json is the primary token cache
// file and creds.json is an old one used by the older versions of this
// library. It will eventually be phased out.
//
// TODO(vadimsh): Change this to use file locking - updateCacheFiles is a
// global critical section.
func (c *DiskTokenCache) updateCache(cb func(*cacheFile, time.Time) bool) error {
	retryParams := func() retry.Iterator {
		return &retry.ExponentialBackoff{
			Limited: retry.Limited{
				Delay:    10 * time.Millisecond,
				Retries:  200,
				MaxTotal: 4 * time.Second,
			},
			Multiplier: 1.5,
		}
	}
	return retry.Retry(c.Context, transient.Only(retryParams), func() error {
		return c.updateCacheFiles(cb)
	}, func(err error, _ time.Duration) {
		logging.Warningf(c.Context, "Retrying the failed token cache update: %s", err)
	})
}
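
// Illustrative note, not part of the original file: with the parameters
// above, retry delays grow roughly as 10ms, 15ms, 22.5ms, ... (multiplied by
// 1.5 each attempt), and the loop gives up after 200 retries or roughly 4
// seconds in total, whichever comes first. Only errors tagged as transient
// (e.g. the Rename failure in writeCacheFile) are retried.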

// readCache reads tokens.json and creds.json and merges them.
func (c *DiskTokenCache) readCache() (*cacheFile, time.Time, error) {
	legacyCache, err := c.readCacheFile(c.legacyPath())
	if err != nil {
		return nil, time.Time{}, err
	}
	newCache, err := c.readCacheFile(c.tokensPath())
	if err != nil {
		return nil, time.Time{}, err
	}

	// Merge tokens from legacyCache into newCache, but don't override anything.
	seen := stringset.New(len(newCache.Cache))
	for _, entry := range newCache.Cache {
		seen.Add(entry.key.ToMapKey())
	}
	for _, entry := range legacyCache.Cache {
		if !seen.Has(entry.key.ToMapKey()) {
			newCache.Cache = append(newCache.Cache, entry)
		}
	}

	// If legacyCache didn't exist at all, pretend it was touched in the distant
	// past. This avoids weird-looking "0001-01-01" dates. The seventies were
	// better.
	if legacyCache.LastUpdate.IsZero() {
		legacyCache.LastUpdate = time.Date(1970, time.January, 1, 0, 0, 0, 0, time.UTC)
	}

	return newCache, legacyCache.LastUpdate, nil
}

// updateCacheFiles does one attempt at updating the cache files.
func (c *DiskTokenCache) updateCacheFiles(cb func(*cacheFile, time.Time) bool) error {
	// Read and merge tokens.json and creds.json.
	cache, legacyLastUpdate, err := c.readCache()
	if err != nil {
		return err
	}

	// Apply the mutation.
	now := clock.Now(c.Context).UTC()
	if !cb(cache, now) {
		return nil
	}

	// Tidy up the cache before saving it.
	c.discardOldEntries(cache, now)

	// HACK: Update creds.json, but do not touch its "last_update" time. That way
	// refresh tokens created by newer `cipd auth-login ...` still work with
	// older binaries that look at creds.json, but there's also a way to know
	// when creds.json is not actually used (its `last_update` time would be
	// ancient). This will eventually be used to decide if it is safe to delete
	// creds.json.
	cache.LastUpdate = legacyLastUpdate
	if err := c.writeCacheFile(c.legacyPath(), cache); err != nil {
		return err
	}

	// Update tokens.json as usual, updating its `last_update` field.
	cache.LastUpdate = now
	return c.writeCacheFile(c.tokensPath(), cache)
}

// discardOldEntries filters out old entries.
func (c *DiskTokenCache) discardOldEntries(cache *cacheFile, now time.Time) {
	filtered := cache.Cache[:0]
	for _, entry := range cache.Cache {
		if !entry.isOld(now) {
			filtered = append(filtered, entry)
		} else {
			logging.Debugf(c.Context, "Cleaning up old token cache entry: %s", entry.key.Key)
		}
	}
	cache.Cache = filtered
}

// GetToken reads the token from cache.
//
// Returns (nil, nil) if the token is not in the cache: a miss is not an
// error.
func (c *DiskTokenCache) GetToken(key *CacheKey) (*Token, error) {
	cache, _, err := c.readCache()
	if err != nil {
		return nil, err
	}
	for _, entry := range cache.Cache {
		if EqualCacheKeys(&entry.key, key) {
			return &Token{
				Token:   entry.token,
				IDToken: entry.idToken,
				Email:   entry.email,
			}, nil
		}
	}
	return nil, nil
}
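
// exampleCacheMiss is an illustrative sketch, not part of the original file:
// GetToken reports a miss as (nil, nil) rather than an error, so callers must
// check both return values before using the token.
func exampleCacheMiss(c *DiskTokenCache, key *CacheKey) error {
	tok, err := c.GetToken(key)
	switch {
	case err != nil:
		return err // a real I/O or decoding failure
	case tok == nil:
		// Cache miss: mint and Put a fresh token here (details elided).
	}
	return nil
}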

// PutToken writes the token to cache.
func (c *DiskTokenCache) PutToken(key *CacheKey, tok *Token) error {
	token := tok.Token
	if !token.Expiry.IsZero() {
		token.Expiry = token.Expiry.UTC()
	}
	return c.updateCache(func(cache *cacheFile, now time.Time) bool {
		for _, entry := range cache.Cache {
			if EqualCacheKeys(&entry.key, key) {
				entry.token = token
				entry.idToken = tok.IDToken
				entry.email = tok.Email
				entry.lastUpdate = now
				return true
			}
		}
		cache.Cache = append(cache.Cache, &cacheFileEntry{
			key:        *key,
			token:      token,
			idToken:    tok.IDToken,
			email:      tok.Email,
			lastUpdate: now,
		})
		return true
	})
}

// DeleteToken removes the token from cache.
func (c *DiskTokenCache) DeleteToken(key *CacheKey) error {
	return c.updateCache(func(cache *cacheFile, now time.Time) bool {
		for i, entry := range cache.Cache {
			if EqualCacheKeys(&entry.key, key) {
				cache.Cache = append(cache.Cache[:i], cache.Cache[i+1:]...)
				return true
			}
		}
		return false // not there, this is fine, skip writing the file
	})
}