k8s.io/apiserver@v0.31.1/pkg/authentication/token/cache/cached_token_authenticator.go (about)

     1  /*
     2  Copyright 2017 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package cache
    18  
    19  import (
    20  	"context"
    21  	"crypto/hmac"
    22  	"crypto/rand"
    23  	"crypto/sha256"
    24  	"encoding/binary"
    25  	"errors"
    26  	"hash"
    27  	"io"
    28  	"runtime"
    29  	"sync"
    30  	"time"
    31  	"unsafe"
    32  
    33  	"golang.org/x/sync/singleflight"
    34  
    35  	apierrors "k8s.io/apimachinery/pkg/api/errors"
    36  	auditinternal "k8s.io/apiserver/pkg/apis/audit"
    37  	"k8s.io/apiserver/pkg/audit"
    38  	"k8s.io/apiserver/pkg/authentication/authenticator"
    39  	"k8s.io/apiserver/pkg/warning"
    40  	"k8s.io/klog/v2"
    41  	"k8s.io/utils/clock"
    42  )
    43  
    44  var errAuthnCrash = apierrors.NewInternalError(errors.New("authentication failed unexpectedly"))
    45  
    46  const sharedLookupTimeout = 30 * time.Second
    47  
    48  // cacheRecord holds the three return values of the authenticator.Token AuthenticateToken method
    49  type cacheRecord struct {
    50  	resp *authenticator.Response
    51  	ok   bool
    52  	err  error
    53  
    54  	// this cache assumes token authn has no side-effects or temporal dependence.
    55  	// neither of these are true for audit annotations set via AddAuditAnnotation.
    56  	//
    57  	// for audit annotations, the assumption is that for some period of time (cache TTL),
    58  	// all requests with the same API audiences and the same bearer token result in the
    59  	// same annotations.  This may not be true if the authenticator sets an annotation
    60  	// based on the current time, but that may be okay since cache TTLs are generally
    61  	// small (seconds).
    62  	annotations map[string]string
    63  	warnings    []*cacheWarning
    64  }
    65  
    66  type cacheWarning struct {
    67  	agent string
    68  	text  string
    69  }
    70  
    71  type cachedTokenAuthenticator struct {
    72  	authenticator authenticator.Token
    73  
    74  	cacheErrs  bool
    75  	successTTL time.Duration
    76  	failureTTL time.Duration
    77  
    78  	cache cache
    79  	group singleflight.Group
    80  
    81  	// hashPool is a per authenticator pool of hash.Hash (to avoid allocations from building the Hash)
    82  	// HMAC with SHA-256 and a random key is used to prevent precomputation and length extension attacks
    83  	// It also mitigates hash map DOS attacks via collisions (the inputs are supplied by untrusted users)
    84  	hashPool *sync.Pool
    85  }
    86  
    87  type cache interface {
    88  	// given a key, return the record, and whether or not it existed
    89  	get(key string) (value *cacheRecord, exists bool)
    90  	// caches the record for the key
    91  	set(key string, value *cacheRecord, ttl time.Duration)
    92  	// removes the record for the key
    93  	remove(key string)
    94  }
    95  
    96  // New returns a token authenticator that caches the results of the specified authenticator. A ttl of 0 bypasses the cache.
    97  func New(authenticator authenticator.Token, cacheErrs bool, successTTL, failureTTL time.Duration) authenticator.Token {
    98  	return newWithClock(authenticator, cacheErrs, successTTL, failureTTL, clock.RealClock{})
    99  }
   100  
   101  func newWithClock(authenticator authenticator.Token, cacheErrs bool, successTTL, failureTTL time.Duration, clock clock.Clock) authenticator.Token {
   102  	randomCacheKey := make([]byte, 32)
   103  	if _, err := rand.Read(randomCacheKey); err != nil {
   104  		panic(err) // rand should never fail
   105  	}
   106  
   107  	return &cachedTokenAuthenticator{
   108  		authenticator: authenticator,
   109  		cacheErrs:     cacheErrs,
   110  		successTTL:    successTTL,
   111  		failureTTL:    failureTTL,
   112  		// Cache performance degrades noticeably when the number of
   113  		// tokens in operation exceeds the size of the cache. It is
   114  		// cheap to make the cache big in the second dimension below,
   115  		// the memory is only consumed when that many tokens are being
   116  		// used. Currently we advertise support 5k nodes and 10k
   117  		// namespaces; a 32k entry cache is therefore a 2x safety
   118  		// margin.
   119  		cache: newStripedCache(32, fnvHashFunc, func() cache { return newSimpleCache(clock) }),
   120  
   121  		hashPool: &sync.Pool{
   122  			New: func() interface{} {
   123  				return hmac.New(sha256.New, randomCacheKey)
   124  			},
   125  		},
   126  	}
   127  }
   128  
   129  // AuthenticateToken implements authenticator.Token
   130  func (a *cachedTokenAuthenticator) AuthenticateToken(ctx context.Context, token string) (*authenticator.Response, bool, error) {
   131  	record := a.doAuthenticateToken(ctx, token)
   132  	if !record.ok || record.err != nil {
   133  		return nil, false, record.err
   134  	}
   135  	for key, value := range record.annotations {
   136  		audit.AddAuditAnnotation(ctx, key, value)
   137  	}
   138  	for _, w := range record.warnings {
   139  		warning.AddWarning(ctx, w.agent, w.text)
   140  	}
   141  	return record.resp, true, nil
   142  }
   143  
   144  func (a *cachedTokenAuthenticator) doAuthenticateToken(ctx context.Context, token string) *cacheRecord {
   145  	doneAuthenticating := stats.authenticating(ctx)
   146  
   147  	auds, audsOk := authenticator.AudiencesFrom(ctx)
   148  
   149  	key := keyFunc(a.hashPool, auds, token)
   150  	if record, ok := a.cache.get(key); ok {
   151  		// Record cache hit
   152  		doneAuthenticating(true)
   153  		return record
   154  	}
   155  
   156  	// Record cache miss
   157  	doneBlocking := stats.blocking(ctx)
   158  	defer doneBlocking()
   159  	defer doneAuthenticating(false)
   160  
   161  	c := a.group.DoChan(key, func() (val interface{}, _ error) {
   162  		// always use one place to read and write the output of AuthenticateToken
   163  		record := &cacheRecord{}
   164  
   165  		doneFetching := stats.fetching(ctx)
   166  		// We're leaving the request handling stack so we need to handle crashes
   167  		// ourselves. Log a stack trace and return a 500 if something panics.
   168  		defer func() {
   169  			if r := recover(); r != nil {
   170  				// make sure to always return a record
   171  				record.err = errAuthnCrash
   172  				val = record
   173  
   174  				// Same as stdlib http server code. Manually allocate stack
   175  				// trace buffer size to prevent excessively large logs
   176  				const size = 64 << 10
   177  				buf := make([]byte, size)
   178  				buf = buf[:runtime.Stack(buf, false)]
   179  				klog.Errorf("%v\n%s", r, buf)
   180  			}
   181  			doneFetching(record.err == nil)
   182  		}()
   183  
   184  		// Check again for a cached record. We may have raced with a fetch.
   185  		if record, ok := a.cache.get(key); ok {
   186  			return record, nil
   187  		}
   188  
   189  		// Detach the context because the lookup may be shared by multiple callers,
   190  		// however propagate the audience.
   191  		ctx, cancel := context.WithTimeout(context.Background(), sharedLookupTimeout)
   192  		defer cancel()
   193  
   194  		if audsOk {
   195  			ctx = authenticator.WithAudiences(ctx, auds)
   196  		}
   197  		recorder := &recorder{}
   198  		ctx = warning.WithWarningRecorder(ctx, recorder)
   199  
   200  		ctx = audit.WithAuditContext(ctx)
   201  		ac := audit.AuditContextFrom(ctx)
   202  		// since this is shared work between multiple requests, we have no way of knowing if any
   203  		// particular request supports audit annotations.  thus we always attempt to record them.
   204  		ac.Event.Level = auditinternal.LevelMetadata
   205  
   206  		record.resp, record.ok, record.err = a.authenticator.AuthenticateToken(ctx, token)
   207  		record.annotations = ac.Event.Annotations
   208  		record.warnings = recorder.extractWarnings()
   209  
   210  		if !a.cacheErrs && record.err != nil {
   211  			return record, nil
   212  		}
   213  
   214  		switch {
   215  		case record.ok && a.successTTL > 0:
   216  			a.cache.set(key, record, a.successTTL)
   217  		case !record.ok && a.failureTTL > 0:
   218  			a.cache.set(key, record, a.failureTTL)
   219  		}
   220  
   221  		return record, nil
   222  	})
   223  
   224  	select {
   225  	case result := <-c:
   226  		// we always set Val and never set Err
   227  		return result.Val.(*cacheRecord)
   228  	case <-ctx.Done():
   229  		// fake a record on context cancel
   230  		return &cacheRecord{err: ctx.Err()}
   231  	}
   232  }
   233  
   234  // keyFunc generates a string key by hashing the inputs.
   235  // This lowers the memory requirement of the cache and keeps tokens out of memory.
   236  func keyFunc(hashPool *sync.Pool, auds []string, token string) string {
   237  	h := hashPool.Get().(hash.Hash)
   238  
   239  	h.Reset()
   240  
   241  	// try to force stack allocation
   242  	var a [4]byte
   243  	b := a[:]
   244  
   245  	writeLengthPrefixedString(h, b, token)
   246  	// encode the length of audiences to avoid ambiguities
   247  	writeLength(h, b, len(auds))
   248  	for _, aud := range auds {
   249  		writeLengthPrefixedString(h, b, aud)
   250  	}
   251  
   252  	key := toString(h.Sum(nil)) // skip base64 encoding to save an allocation
   253  
   254  	hashPool.Put(h)
   255  
   256  	return key
   257  }
   258  
   259  // writeLengthPrefixedString writes s with a length prefix to prevent ambiguities, i.e. "xy" + "z" == "x" + "yz"
   260  // the length of b is assumed to be 4 (b is mutated by this function to store the length of s)
   261  func writeLengthPrefixedString(w io.Writer, b []byte, s string) {
   262  	writeLength(w, b, len(s))
   263  	if _, err := w.Write(toBytes(s)); err != nil {
   264  		panic(err) // Write() on hash never fails
   265  	}
   266  }
   267  
   268  // writeLength encodes length into b and then writes it via the given writer
   269  // the length of b is assumed to be 4
   270  func writeLength(w io.Writer, b []byte, length int) {
   271  	binary.BigEndian.PutUint32(b, uint32(length))
   272  	if _, err := w.Write(b); err != nil {
   273  		panic(err) // Write() on hash never fails
   274  	}
   275  }
   276  
   277  // toBytes performs unholy acts to avoid allocations
   278  func toBytes(s string) []byte {
   279  	// unsafe.StringData is unspecified for the empty string, so we provide a strict interpretation
   280  	if len(s) == 0 {
   281  		return nil
   282  	}
   283  	// Copied from go 1.20.1 os.File.WriteString
   284  	// https://github.com/golang/go/blob/202a1a57064127c3f19d96df57b9f9586145e21c/src/os/file.go#L246
   285  	return unsafe.Slice(unsafe.StringData(s), len(s))
   286  }
   287  
   288  // toString performs unholy acts to avoid allocations
   289  func toString(b []byte) string {
   290  	// unsafe.SliceData relies on cap whereas we want to rely on len
   291  	if len(b) == 0 {
   292  		return ""
   293  	}
   294  	// Copied from go 1.20.1 strings.Builder.String
   295  	// https://github.com/golang/go/blob/202a1a57064127c3f19d96df57b9f9586145e21c/src/strings/builder.go#L48
   296  	return unsafe.String(unsafe.SliceData(b), len(b))
   297  }
   298  
   299  // simple recorder that only appends warning
   300  type recorder struct {
   301  	mu       sync.Mutex
   302  	warnings []*cacheWarning
   303  }
   304  
   305  // AddWarning adds a warning to recorder.
   306  func (r *recorder) AddWarning(agent, text string) {
   307  	r.mu.Lock()
   308  	defer r.mu.Unlock()
   309  	r.warnings = append(r.warnings, &cacheWarning{agent: agent, text: text})
   310  }
   311  
   312  func (r *recorder) extractWarnings() []*cacheWarning {
   313  	r.mu.Lock()
   314  	defer r.mu.Unlock()
   315  	warnings := r.warnings
   316  	r.warnings = nil
   317  	return warnings
   318  }