k8s.io/apiserver@v0.31.1/pkg/storage/value/encrypt/envelope/kmsv2/cache_test.go (about)

     1  /*
     2  Copyright 2023 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  // Package kmsv2 transforms values for storage at rest using an Envelope v2 provider
    18  package kmsv2
    19  
    20  import (
    21  	"crypto/rand"
    22  	"crypto/sha256"
    23  	"fmt"
    24  	"sync"
    25  	"sync/atomic"
    26  	"testing"
    27  	"time"
    28  
    29  	"k8s.io/apimachinery/pkg/util/sets"
    30  	"k8s.io/apiserver/pkg/storage/value"
    31  	testingclock "k8s.io/utils/clock/testing"
    32  )
    33  
    34  func TestSimpleCacheSetError(t *testing.T) {
    35  	fakeClock := testingclock.NewFakeClock(time.Now())
    36  	cache := newSimpleCache(fakeClock, time.Second, "providerName")
    37  
    38  	tests := []struct {
    39  		name        string
    40  		key         []byte
    41  		transformer value.Transformer
    42  	}{
    43  		{
    44  			name:        "empty key",
    45  			key:         []byte{},
    46  			transformer: &envelopeTransformer{},
    47  		},
    48  		{
    49  			name:        "nil transformer",
    50  			key:         []byte("key"),
    51  			transformer: nil,
    52  		},
    53  	}
    54  
    55  	for _, test := range tests {
    56  		t.Run(test.name, func(t *testing.T) {
    57  			defer func() {
    58  				if r := recover(); r == nil {
    59  					t.Errorf("The code did not panic")
    60  				}
    61  			}()
    62  			cache.set(test.key, test.transformer)
    63  		})
    64  	}
    65  }
    66  
    67  func TestKeyFunc(t *testing.T) {
    68  	fakeClock := testingclock.NewFakeClock(time.Now())
    69  	cache := newSimpleCache(fakeClock, time.Second, "providerName")
    70  
    71  	t.Run("AllocsPerRun test", func(t *testing.T) {
    72  		key, err := generateKey(encryptedDEKSourceMaxSize) // simulate worst case EDEK
    73  		if err != nil {
    74  			t.Fatal(err)
    75  		}
    76  
    77  		f := func() {
    78  			out := cache.keyFunc(key)
    79  			if len(out) != sha256.Size {
    80  				t.Errorf("Expected %d bytes, got %d", sha256.Size, len(out))
    81  			}
    82  		}
    83  
    84  		// prime the key func
    85  		var wg sync.WaitGroup
    86  		for i := 0; i < 100; i++ {
    87  			wg.Add(1)
    88  			go func() {
    89  				f()
    90  				wg.Done()
    91  			}()
    92  		}
    93  		wg.Wait()
    94  
    95  		allocs := testing.AllocsPerRun(100, f)
    96  		if allocs > 1 {
    97  			t.Errorf("Expected 1 allocations, got %v", allocs)
    98  		}
    99  	})
   100  }
   101  
   102  func TestSimpleCache(t *testing.T) {
   103  	fakeClock := testingclock.NewFakeClock(time.Now())
   104  	cache := newSimpleCache(fakeClock, 5*time.Second, "providerName")
   105  	transformer := &envelopeTransformer{}
   106  
   107  	wg := sync.WaitGroup{}
   108  	for i := 0; i < 10; i++ {
   109  		k := fmt.Sprintf("key-%d", i)
   110  		wg.Add(1)
   111  		go func(key string) {
   112  			defer wg.Done()
   113  			cache.set([]byte(key), transformer)
   114  		}(k)
   115  	}
   116  	wg.Wait()
   117  
   118  	if cache.cache.Len() != 10 {
   119  		t.Fatalf("Expected 10 items in the cache, got %v", cache.cache.Len())
   120  	}
   121  
   122  	for i := 0; i < 10; i++ {
   123  		k := fmt.Sprintf("key-%d", i)
   124  		if cache.get([]byte(k)) != transformer {
   125  			t.Fatalf("Expected to get the transformer for key %v", k)
   126  		}
   127  	}
   128  
   129  	// Wait for the cache to expire
   130  	fakeClock.Step(6 * time.Second)
   131  
   132  	// expired reads still work until GC runs on write
   133  	for i := 0; i < 10; i++ {
   134  		k := fmt.Sprintf("key-%d", i)
   135  		if cache.get([]byte(k)) != transformer {
   136  			t.Fatalf("Expected to get the transformer for key %v", k)
   137  		}
   138  	}
   139  
   140  	// run GC by performing a write
   141  	cache.set([]byte("some-other-unrelated-key"), transformer)
   142  
   143  	for i := 0; i < 10; i++ {
   144  		k := fmt.Sprintf("key-%d", i)
   145  		if cache.get([]byte(k)) != nil {
   146  			t.Fatalf("Expected to get nil for key %v", k)
   147  		}
   148  	}
   149  }
   150  
   151  func generateKey(length int) (key []byte, err error) {
   152  	key = make([]byte, length)
   153  	if _, err = rand.Read(key); err != nil {
   154  		return nil, err
   155  	}
   156  	return key, nil
   157  }
   158  
   159  func TestMetrics(t *testing.T) {
   160  	fakeClock := testingclock.NewFakeClock(time.Now())
   161  	cache := newSimpleCache(fakeClock, 5*time.Second, "panda")
   162  	var record sync.Map
   163  	var cacheSize atomic.Uint64
   164  	cache.recordCacheSize = func(providerName string, size int) {
   165  		if providerName != "panda" {
   166  			t.Errorf(`expected "panda" as provider name, got %q`, providerName)
   167  		}
   168  		if _, loaded := record.LoadOrStore(size, nil); loaded {
   169  			t.Errorf("detected duplicated cache size metric for %d", size)
   170  		}
   171  		newSize := uint64(size)
   172  		oldSize := cacheSize.Swap(newSize)
   173  		if oldSize > newSize {
   174  			t.Errorf("cache size decreased from %d to %d", oldSize, newSize)
   175  		}
   176  	}
   177  	transformer := &envelopeTransformer{}
   178  
   179  	want := sets.NewInt()
   180  	startCh := make(chan struct{})
   181  	wg := sync.WaitGroup{}
   182  	for i := 0; i < 100; i++ {
   183  		want.Insert(i + 1)
   184  		k := fmt.Sprintf("key-%d", i)
   185  		wg.Add(1)
   186  		go func(key string) {
   187  			defer wg.Done()
   188  			<-startCh
   189  			cache.set([]byte(key), transformer)
   190  		}(k)
   191  	}
   192  	close(startCh)
   193  	wg.Wait()
   194  
   195  	got := sets.NewInt()
   196  	record.Range(func(key, value any) bool {
   197  		got.Insert(key.(int))
   198  		if value != nil {
   199  			t.Errorf("expected value to be nil but got %v", value)
   200  		}
   201  		return true
   202  	})
   203  	if !want.Equal(got) {
   204  		t.Errorf("cache size entries missing values: %v", want.SymmetricDifference(got).List())
   205  	}
   206  }