k8s.io/apiserver@v0.29.3/pkg/storage/value/encrypt/envelope/kmsv2/envelope_test.go (about)

     1  /*
     2  Copyright 2022 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  // Package kmsv2 transforms values for storage at rest using an Envelope v2 provider
    18  package kmsv2
    19  
    20  import (
    21  	"bytes"
    22  	"context"
    23  	"encoding/base64"
    24  	"flag"
    25  	"fmt"
    26  	"reflect"
    27  	"regexp"
    28  	"strconv"
    29  	"strings"
    30  	"sync"
    31  	"sync/atomic"
    32  	"testing"
    33  	"time"
    34  
    35  	"github.com/gogo/protobuf/proto"
    36  	"go.opentelemetry.io/otel/sdk/trace"
    37  	"go.opentelemetry.io/otel/sdk/trace/tracetest"
    38  
    39  	utilrand "k8s.io/apimachinery/pkg/util/rand"
    40  	"k8s.io/apimachinery/pkg/util/uuid"
    41  	genericapirequest "k8s.io/apiserver/pkg/endpoints/request"
    42  	"k8s.io/apiserver/pkg/storage/value"
    43  	kmstypes "k8s.io/apiserver/pkg/storage/value/encrypt/envelope/kmsv2/v2"
    44  	"k8s.io/apiserver/pkg/storage/value/encrypt/envelope/metrics"
    45  	"k8s.io/component-base/metrics/legacyregistry"
    46  	"k8s.io/component-base/metrics/testutil"
    47  	"k8s.io/klog/v2"
    48  	kmsservice "k8s.io/kms/pkg/service"
    49  	"k8s.io/utils/clock"
    50  	testingclock "k8s.io/utils/clock/testing"
    51  )
    52  
    53  const (
    54  	testText            = "abcdefghijklmnopqrstuvwxyz"
    55  	testContextText     = "0123456789"
    56  	testKeyHash         = "sha256:6b86b273ff34fce19d6b804eff5a3f5747ada4eaa22f1d49c01e52ddb7875b4b"
    57  	testKeyVersion      = "1"
    58  	testAPIServerID     = "testAPIServerID"
    59  	testAPIServerIDHash = "sha256:14f9d63e669337ac6bfda2e2162915ee6a6067743eddd4e5c374b572f951ff37"
    60  )
    61  
    62  // testEnvelopeService is a mock Envelope service which can be used to simulate remote Envelope services
    63  // for testing of Envelope based encryption providers.
    64  type testEnvelopeService struct {
    65  	annotations  map[string][]byte
    66  	disabled     bool
    67  	keyVersion   string
    68  	ciphertext   []byte
    69  	decryptCalls int32
    70  }
    71  
    72  func (t *testEnvelopeService) Decrypt(ctx context.Context, uid string, req *kmsservice.DecryptRequest) ([]byte, error) {
    73  	atomic.AddInt32(&t.decryptCalls, 1)
    74  	if t.disabled {
    75  		return nil, fmt.Errorf("Envelope service was disabled")
    76  	}
    77  	if len(uid) == 0 {
    78  		return nil, fmt.Errorf("uid is required")
    79  	}
    80  	if len(req.KeyID) == 0 {
    81  		return nil, fmt.Errorf("keyID is required")
    82  	}
    83  	return base64.StdEncoding.DecodeString(string(req.Ciphertext))
    84  }
    85  
    86  func (t *testEnvelopeService) Encrypt(ctx context.Context, uid string, data []byte) (*kmsservice.EncryptResponse, error) {
    87  	if t.disabled {
    88  		return nil, fmt.Errorf("Envelope service was disabled")
    89  	}
    90  	if len(uid) == 0 {
    91  		return nil, fmt.Errorf("uid is required")
    92  	}
    93  	annotations := make(map[string][]byte)
    94  	if t.annotations != nil {
    95  		for k, v := range t.annotations {
    96  			annotations[k] = v
    97  		}
    98  	} else {
    99  		annotations["local-kek.kms.kubernetes.io"] = []byte("encrypted-local-kek")
   100  	}
   101  
   102  	ciphertext := t.ciphertext
   103  	if ciphertext == nil {
   104  		ciphertext = []byte(base64.StdEncoding.EncodeToString(data))
   105  	}
   106  
   107  	return &kmsservice.EncryptResponse{Ciphertext: ciphertext, KeyID: t.keyVersion, Annotations: annotations}, nil
   108  }
   109  
   110  func (t *testEnvelopeService) Status(ctx context.Context) (*kmsservice.StatusResponse, error) {
   111  	if t.disabled {
   112  		return nil, fmt.Errorf("Envelope service was disabled")
   113  	}
   114  	return &kmsservice.StatusResponse{KeyID: t.keyVersion}, nil
   115  }
   116  
   117  func (t *testEnvelopeService) SetDisabledStatus(status bool) {
   118  	t.disabled = status
   119  }
   120  
   121  func (t *testEnvelopeService) SetAnnotations(annotations map[string][]byte) {
   122  	t.annotations = annotations
   123  }
   124  
   125  func (t *testEnvelopeService) SetCiphertext(ciphertext []byte) {
   126  	t.ciphertext = ciphertext
   127  }
   128  
   129  func (t *testEnvelopeService) Rotate() {
   130  	i, _ := strconv.Atoi(t.keyVersion)
   131  	t.keyVersion = strconv.FormatInt(int64(i+1), 10)
   132  }
   133  
   134  func newTestEnvelopeService() *testEnvelopeService {
   135  	return &testEnvelopeService{
   136  		keyVersion: testKeyVersion,
   137  	}
   138  }
   139  
   140  // Throw error if Envelope transformer tries to contact Envelope without hitting cache.
   141  func TestEnvelopeCaching(t *testing.T) {
   142  	testCases := []struct {
   143  		desc                     string
   144  		cacheTTL                 time.Duration
   145  		simulateKMSPluginFailure bool
   146  		expectedError            string
   147  		expectedDecryptCalls     int
   148  	}{
   149  		{
   150  			desc:                     "entry in cache should withstand plugin failure",
   151  			cacheTTL:                 5 * time.Minute,
   152  			simulateKMSPluginFailure: true,
   153  			expectedDecryptCalls:     0, // should not hit KMS plugin
   154  		},
   155  		{
   156  			desc:                     "cache entry expired should not withstand plugin failure",
   157  			cacheTTL:                 1 * time.Millisecond,
   158  			simulateKMSPluginFailure: true,
   159  			expectedError:            "failed to decrypt DEK, error: Envelope service was disabled",
   160  			expectedDecryptCalls:     10, // should hit KMS plugin for each read after cache entry expired and fail
   161  		},
   162  		{
   163  			desc:                     "cache entry expired should work after cache refresh",
   164  			cacheTTL:                 1 * time.Millisecond,
   165  			simulateKMSPluginFailure: false,
   166  			expectedDecryptCalls:     1, // should hit KMS plugin just for the 1st read after cache entry expired
   167  		},
   168  	}
   169  
   170  	for _, tt := range testCases {
   171  		t.Run(tt.desc, func(t *testing.T) {
   172  			ctx := testContext(t)
   173  
   174  			envelopeService := newTestEnvelopeService()
   175  			fakeClock := testingclock.NewFakeClock(time.Now())
   176  
   177  			useSeed := randomBool()
   178  
   179  			state, err := testStateFunc(ctx, envelopeService, fakeClock, useSeed)()
   180  			if err != nil {
   181  				t.Fatal(err)
   182  			}
   183  
   184  			transformer := newEnvelopeTransformerWithClock(envelopeService, testProviderName,
   185  				func() (State, error) { return state, nil }, testAPIServerID,
   186  				tt.cacheTTL, fakeClock)
   187  
   188  			dataCtx := value.DefaultContext(testContextText)
   189  			originalText := []byte(testText)
   190  
   191  			transformedData, err := transformer.TransformToStorage(ctx, originalText, dataCtx)
   192  			if err != nil {
   193  				t.Fatalf("envelopeTransformer: error while transforming data to storage: %s", err)
   194  			}
   195  			untransformedData, _, err := transformer.TransformFromStorage(ctx, transformedData, dataCtx)
   196  			if err != nil {
   197  				t.Fatalf("could not decrypt Envelope transformer's encrypted data even once: %v", err)
   198  			}
   199  			if !bytes.Equal(untransformedData, originalText) {
   200  				t.Fatalf("envelopeTransformer transformed data incorrectly. Expected: %v, got %v", originalText, untransformedData)
   201  			}
   202  
   203  			// advance the clock to allow cache entries to expire depending on TTL
   204  			fakeClock.Step(2 * time.Minute)
   205  			// force GC to run by performing a write
   206  			transformer.(*envelopeTransformer).cache.set([]byte("some-other-unrelated-key"), &envelopeTransformer{})
   207  
   208  			state, err = testStateFunc(ctx, envelopeService, fakeClock, useSeed)()
   209  			if err != nil {
   210  				t.Fatal(err)
   211  			}
   212  			envelopeService.SetDisabledStatus(tt.simulateKMSPluginFailure)
   213  
   214  			for i := 0; i < 10; i++ {
   215  				// Subsequent reads for the same data should work fine due to caching.
   216  				untransformedData, _, err = transformer.TransformFromStorage(ctx, transformedData, dataCtx)
   217  				if tt.expectedError != "" {
   218  					if err == nil {
   219  						t.Fatalf("expected error: %v, got nil", tt.expectedError)
   220  					}
   221  					if err.Error() != tt.expectedError {
   222  						t.Fatalf("expected error: %v, got: %v", tt.expectedError, err)
   223  					}
   224  				} else {
   225  					if err != nil {
   226  						t.Fatalf("unexpected error: %v", err)
   227  					}
   228  					if !bytes.Equal(untransformedData, originalText) {
   229  						t.Fatalf("envelopeTransformer transformed data incorrectly. Expected: %v, got %v", originalText, untransformedData)
   230  					}
   231  				}
   232  			}
   233  			if int(envelopeService.decryptCalls) != tt.expectedDecryptCalls {
   234  				t.Fatalf("expected %d decrypt calls, got %d", tt.expectedDecryptCalls, envelopeService.decryptCalls)
   235  			}
   236  		})
   237  	}
   238  }
   239  
   240  func testStateFunc(ctx context.Context, envelopeService kmsservice.Service, clock clock.Clock, useSeed bool) func() (State, error) {
   241  	return func() (State, error) {
   242  		transformer, encObject, cacheKey, errGen := GenerateTransformer(ctx, string(uuid.NewUUID()), envelopeService, useSeed)
   243  		if errGen != nil {
   244  			return State{}, errGen
   245  		}
   246  		return State{
   247  			Transformer:         transformer,
   248  			EncryptedObject:     *encObject,
   249  			UID:                 "panda",
   250  			ExpirationTimestamp: clock.Now().Add(time.Hour),
   251  			CacheKey:            cacheKey,
   252  		}, nil
   253  	}
   254  }
   255  
   256  // TestEnvelopeTransformerStaleness validates that staleness checks on read honor the data returned from the StateFunc.
   257  func TestEnvelopeTransformerStaleness(t *testing.T) {
   258  	t.Parallel()
   259  	testCases := []struct {
   260  		desc          string
   261  		expectedStale bool
   262  		testErr       error
   263  		testKeyID     string
   264  		useSeedWrite  bool
   265  		useSeedRead   bool
   266  	}{
   267  		{
   268  			desc:          "stateFunc returns err",
   269  			expectedStale: false,
   270  			testErr:       fmt.Errorf("failed to perform status section of the healthz check for KMS Provider"),
   271  			testKeyID:     "",
   272  		},
   273  		{
   274  			desc:          "stateFunc returns same keyID, not using seed",
   275  			expectedStale: false,
   276  			testErr:       nil,
   277  			testKeyID:     testKeyVersion,
   278  		},
   279  		{
   280  			desc:          "stateFunc returns same keyID, using seed",
   281  			expectedStale: false,
   282  			testErr:       nil,
   283  			testKeyID:     testKeyVersion,
   284  			useSeedWrite:  true,
   285  			useSeedRead:   true,
   286  		},
   287  		{
   288  			desc:          "stateFunc returns same keyID, migrating away from seed",
   289  			expectedStale: true,
   290  			testErr:       nil,
   291  			testKeyID:     testKeyVersion,
   292  			useSeedWrite:  true,
   293  			useSeedRead:   false,
   294  		},
   295  		{
   296  			desc:          "stateFunc returns same keyID, migrating to seed",
   297  			expectedStale: true,
   298  			testErr:       nil,
   299  			testKeyID:     testKeyVersion,
   300  			useSeedWrite:  false,
   301  			useSeedRead:   true,
   302  		},
   303  		{
   304  			desc:          "stateFunc returns different keyID",
   305  			expectedStale: true,
   306  			testErr:       nil,
   307  			testKeyID:     "2",
   308  		},
   309  	}
   310  
   311  	for _, tt := range testCases {
   312  		tt := tt
   313  		t.Run(tt.desc, func(t *testing.T) {
   314  			t.Parallel()
   315  
   316  			ctx := testContext(t)
   317  
   318  			envelopeService := newTestEnvelopeService()
   319  			state, err := testStateFunc(ctx, envelopeService, clock.RealClock{}, tt.useSeedWrite)()
   320  			if err != nil {
   321  				t.Fatal(err)
   322  			}
   323  			var stateErr error
   324  
   325  			transformer := NewEnvelopeTransformer(envelopeService, testProviderName,
   326  				func() (State, error) { return state, stateErr }, testAPIServerID,
   327  			)
   328  
   329  			dataCtx := value.DefaultContext(testContextText)
   330  			originalText := []byte(testText)
   331  
   332  			transformedData, err := transformer.TransformToStorage(ctx, originalText, dataCtx)
   333  			if err != nil {
   334  				t.Fatalf("envelopeTransformer: error while transforming data (%v) to storage: %s", originalText, err)
   335  			}
   336  
   337  			// inject test data before performing a read
   338  			state.EncryptedObject.KeyID = tt.testKeyID
   339  			if tt.useSeedRead {
   340  				state.EncryptedObject.EncryptedDEKSourceType = kmstypes.EncryptedDEKSourceType_HKDF_SHA256_XNONCE_AES_GCM_SEED
   341  			} else {
   342  				state.EncryptedObject.EncryptedDEKSourceType = kmstypes.EncryptedDEKSourceType_AES_GCM_KEY
   343  			}
   344  			stateErr = tt.testErr
   345  
   346  			_, stale, err := transformer.TransformFromStorage(ctx, transformedData, dataCtx)
   347  			if tt.testErr != nil {
   348  				if err == nil {
   349  					t.Fatalf("envelopeTransformer: expected error: %v, got nil", tt.testErr)
   350  				}
   351  				if err.Error() != tt.testErr.Error() {
   352  					t.Fatalf("envelopeTransformer: expected error: %v, got: %v", tt.testErr, err)
   353  				}
   354  			} else {
   355  				if err != nil {
   356  					t.Fatalf("envelopeTransformer: unexpected error: %v", err)
   357  				}
   358  				if stale != tt.expectedStale {
   359  					t.Fatalf("envelopeTransformer TransformFromStorage determined keyID staleness incorrectly, expected: %v, got %v", tt.expectedStale, stale)
   360  				}
   361  			}
   362  		})
   363  	}
   364  }
   365  
   366  func TestEnvelopeTransformerStateFunc(t *testing.T) {
   367  	t.Parallel()
   368  
   369  	ctx := testContext(t)
   370  
   371  	useSeed := randomBool()
   372  
   373  	envelopeService := newTestEnvelopeService()
   374  	state, err := testStateFunc(ctx, envelopeService, clock.RealClock{}, useSeed)()
   375  	if err != nil {
   376  		t.Fatal(err)
   377  	}
   378  
   379  	// start with a broken state
   380  	stateErr := fmt.Errorf("some state error")
   381  
   382  	transformer := NewEnvelopeTransformer(envelopeService, testProviderName,
   383  		func() (State, error) { return state, stateErr }, testAPIServerID,
   384  	)
   385  
   386  	dataCtx := value.DefaultContext(testContextText)
   387  	originalText := []byte(testText)
   388  
   389  	t.Run("nothing works when the state is broken", func(t *testing.T) {
   390  		_, err := transformer.TransformToStorage(ctx, originalText, dataCtx)
   391  		if err != stateErr {
   392  			t.Fatalf("expected state error, got: %v", err)
   393  		}
   394  		o := &kmstypes.EncryptedObject{
   395  			EncryptedData:      []byte{1},
   396  			KeyID:              "2",
   397  			EncryptedDEKSource: []byte{3},
   398  			Annotations:        nil,
   399  		}
   400  		if useSeed {
   401  			o.EncryptedDEKSourceType = kmstypes.EncryptedDEKSourceType_HKDF_SHA256_XNONCE_AES_GCM_SEED
   402  		} else {
   403  			o.EncryptedDEKSourceType = kmstypes.EncryptedDEKSourceType_AES_GCM_KEY
   404  		}
   405  		data, err := proto.Marshal(o)
   406  		if err != nil {
   407  			t.Fatal(err)
   408  		}
   409  		_, _, err = transformer.TransformFromStorage(ctx, data, dataCtx)
   410  		if err != stateErr {
   411  			t.Fatalf("expected state error, got: %v", err)
   412  		}
   413  	})
   414  
   415  	// fix the state
   416  	stateErr = nil
   417  
   418  	var encryptedData []byte
   419  
   420  	t.Run("everything works when the state is fixed", func(t *testing.T) {
   421  		encryptedData, err = transformer.TransformToStorage(ctx, originalText, dataCtx)
   422  		if err != nil {
   423  			t.Fatal(err)
   424  		}
   425  		_, _, err = transformer.TransformFromStorage(ctx, encryptedData, dataCtx)
   426  		if err != nil {
   427  			t.Fatal(err)
   428  		}
   429  	})
   430  
   431  	// break the plugin
   432  	envelopeService.SetDisabledStatus(true)
   433  
   434  	t.Run("everything works even when the plugin is down but the state is valid", func(t *testing.T) {
   435  		data, err := transformer.TransformToStorage(ctx, originalText, dataCtx)
   436  		if err != nil {
   437  			t.Fatal(err)
   438  		}
   439  		_, _, err = transformer.TransformFromStorage(ctx, data, dataCtx)
   440  		if err != nil {
   441  			t.Fatal(err)
   442  		}
   443  	})
   444  
   445  	// make the state invalid
   446  	state.ExpirationTimestamp = time.Now().Add(-time.Hour)
   447  
   448  	t.Run("writes fail when the plugin is down and the state is invalid", func(t *testing.T) {
   449  		_, err := transformer.TransformToStorage(ctx, originalText, dataCtx)
   450  		if !strings.Contains(errString(err), `encryptedDEKSource with keyID hash "sha256:6b86b273ff34fce19d6b804eff5a3f5747ada4eaa22f1d49c01e52ddb7875b4b" expired at`) {
   451  			t.Fatalf("expected expiration error, got: %v", err)
   452  		}
   453  	})
   454  
   455  	t.Run("reads succeed when the plugin is down and the state is invalid", func(t *testing.T) {
   456  		_, _, err = transformer.TransformFromStorage(ctx, encryptedData, dataCtx)
   457  		if err != nil {
   458  			t.Fatal(err)
   459  		}
   460  	})
   461  
   462  	t.Run("reads for a different DEK fail when the plugin is down and the state is invalid", func(t *testing.T) {
   463  		obj := &kmstypes.EncryptedObject{}
   464  		if err := proto.Unmarshal(encryptedData, obj); err != nil {
   465  			t.Fatal(err)
   466  		}
   467  
   468  		obj.EncryptedDEKSource = append(obj.EncryptedDEKSource, 1) // skip StateFunc transformer
   469  
   470  		data, err := proto.Marshal(obj)
   471  		if err != nil {
   472  			t.Fatal(err)
   473  		}
   474  
   475  		_, _, err = transformer.TransformFromStorage(ctx, data, dataCtx)
   476  		if errString(err) != "failed to decrypt DEK, error: Envelope service was disabled" {
   477  			t.Fatal(err)
   478  		}
   479  	})
   480  }
   481  
   482  func TestTransformToStorageError(t *testing.T) {
   483  	t.Parallel()
   484  	testCases := []struct {
   485  		name        string
   486  		annotations map[string][]byte
   487  	}{
   488  		{
   489  			name: "invalid annotation key",
   490  			annotations: map[string][]byte{
   491  				"http://foo.example.com": []byte("bar"),
   492  			},
   493  		},
   494  		{
   495  			name: "annotation value size too large",
   496  			annotations: map[string][]byte{
   497  				"simple": []byte(strings.Repeat("a", 32*1024)),
   498  			},
   499  		},
   500  		{
   501  			name: "annotations size too large",
   502  			annotations: map[string][]byte{
   503  				"simple":  []byte(strings.Repeat("a", 31*1024)),
   504  				"simple2": []byte(strings.Repeat("a", 1024)),
   505  			},
   506  		},
   507  	}
   508  
   509  	for _, tt := range testCases {
   510  		tt := tt
   511  		t.Run(tt.name, func(t *testing.T) {
   512  			t.Parallel()
   513  
   514  			ctx := testContext(t)
   515  
   516  			envelopeService := newTestEnvelopeService()
   517  			envelopeService.SetAnnotations(tt.annotations)
   518  			transformer := NewEnvelopeTransformer(envelopeService, testProviderName,
   519  				testStateFunc(ctx, envelopeService, clock.RealClock{}, randomBool()),
   520  				testAPIServerID,
   521  			)
   522  			dataCtx := value.DefaultContext(testContextText)
   523  
   524  			_, err := transformer.TransformToStorage(ctx, []byte(testText), dataCtx)
   525  			if err == nil {
   526  				t.Fatalf("expected error, got nil")
   527  			}
   528  			if !strings.Contains(err.Error(), "failed to validate annotations") {
   529  				t.Fatalf("expected error to contain 'failed to validate annotations', got %v", err)
   530  			}
   531  		})
   532  	}
   533  }
   534  
   535  func TestEncodeDecode(t *testing.T) {
   536  	transformer := &envelopeTransformer{}
   537  
   538  	obj := &kmstypes.EncryptedObject{
   539  		EncryptedData:      []byte{0x01, 0x02, 0x03},
   540  		KeyID:              "1",
   541  		EncryptedDEKSource: []byte{0x04, 0x05, 0x06},
   542  	}
   543  
   544  	data, err := transformer.doEncode(obj)
   545  	if err != nil {
   546  		t.Fatalf("envelopeTransformer: error while encoding data: %s", err)
   547  	}
   548  	got, err := transformer.doDecode(data)
   549  	if err != nil {
   550  		t.Fatalf("envelopeTransformer: error while decoding data: %s", err)
   551  	}
   552  	// reset internal field modified by marshaling obj
   553  	obj.XXX_sizecache = 0
   554  	if !reflect.DeepEqual(got, obj) {
   555  		t.Fatalf("envelopeTransformer: decoded data does not match original data. Got: %v, want %v", got, obj)
   556  	}
   557  }
   558  
   559  func TestValidateEncryptedObject(t *testing.T) {
   560  	t.Parallel()
   561  	testCases := []struct {
   562  		desc          string
   563  		originalData  *kmstypes.EncryptedObject
   564  		expectedError error
   565  	}{
   566  		{
   567  			desc:          "encrypted object is nil",
   568  			originalData:  nil,
   569  			expectedError: fmt.Errorf("encrypted object is nil"),
   570  		},
   571  		{
   572  			desc: "encrypted data is nil",
   573  			originalData: &kmstypes.EncryptedObject{
   574  				KeyID:              "1",
   575  				EncryptedDEKSource: []byte{0x01, 0x02, 0x03},
   576  			},
   577  			expectedError: fmt.Errorf("encrypted data is empty"),
   578  		},
   579  		{
   580  			desc: "encrypted data is []byte{}",
   581  			originalData: &kmstypes.EncryptedObject{
   582  				EncryptedDEKSource: []byte{0x01, 0x02, 0x03},
   583  				EncryptedData:      []byte{},
   584  			},
   585  			expectedError: fmt.Errorf("encrypted data is empty"),
   586  		},
   587  		{
   588  			desc: "invalid dek source type",
   589  			originalData: &kmstypes.EncryptedObject{
   590  				EncryptedDEKSource:     []byte{0x01, 0x02, 0x03},
   591  				EncryptedData:          []byte{0},
   592  				EncryptedDEKSourceType: 55,
   593  			},
   594  			expectedError: fmt.Errorf("unknown encryptedDEKSourceType: 55"),
   595  		},
   596  		{
   597  			desc: "empty dek source",
   598  			originalData: &kmstypes.EncryptedObject{
   599  				EncryptedData:          []byte{0},
   600  				EncryptedDEKSourceType: 1,
   601  				KeyID:                  "1",
   602  			},
   603  			expectedError: fmt.Errorf("failed to validate encrypted DEK source: encrypted DEK source is empty"),
   604  		},
   605  		{
   606  			desc: "empty key ID",
   607  			originalData: &kmstypes.EncryptedObject{
   608  				EncryptedDEKSource:     []byte{0x01, 0x02, 0x03},
   609  				EncryptedData:          []byte{0},
   610  				EncryptedDEKSourceType: 1,
   611  			},
   612  			expectedError: fmt.Errorf("failed to validate key id: keyID is empty"),
   613  		},
   614  		{
   615  			desc: "invalid annotations",
   616  			originalData: &kmstypes.EncryptedObject{
   617  				EncryptedDEKSource:     []byte{0x01, 0x02, 0x03},
   618  				EncryptedData:          []byte{0},
   619  				EncryptedDEKSourceType: 1,
   620  				KeyID:                  "1",
   621  				Annotations:            map[string][]byte{"@": nil},
   622  			},
   623  			expectedError: fmt.Errorf(`failed to validate annotations: annotations: Invalid value: "@": a lowercase RFC 1123 subdomain must consist of lower case alphanumeric characters, '-' or '.', and must start and end with an alphanumeric character (e.g. 'example.com', regex used for validation is '[a-z0-9]([-a-z0-9]*[a-z0-9])?(\.[a-z0-9]([-a-z0-9]*[a-z0-9])?)*')`),
   624  		},
   625  	}
   626  
   627  	for _, tt := range testCases {
   628  		t.Run(tt.desc, func(t *testing.T) {
   629  			err := ValidateEncryptedObject(tt.originalData)
   630  			if err == nil {
   631  				t.Fatalf("envelopeTransformer: expected error while decoding data, got nil")
   632  			}
   633  
   634  			if err.Error() != tt.expectedError.Error() {
   635  				t.Fatalf("doDecode() error: expected %v, got %v", tt.expectedError, err)
   636  			}
   637  		})
   638  	}
   639  }
   640  
   641  func TestValidateAnnotations(t *testing.T) {
   642  	t.Parallel()
   643  	successCases := []map[string][]byte{
   644  		{"a.com": []byte("bar")},
   645  		{"k8s.io": []byte("bar")},
   646  		{"dev.k8s.io": []byte("bar")},
   647  		{"dev.k8s.io.": []byte("bar")},
   648  		{"foo.example.com": []byte("bar")},
   649  		{"this.is.a.really.long.fqdn": []byte("bar")},
   650  		{"bbc.co.uk": []byte("bar")},
   651  		{"10.0.0.1": []byte("bar")}, // DNS labels can start with numbers and there is no requirement for letters.
   652  		{"hyphens-are-good.k8s.io": []byte("bar")},
   653  		{strings.Repeat("a", 63) + ".k8s.io": []byte("bar")},
   654  		{strings.Repeat("a", 63) + "." + strings.Repeat("b", 63) + "." + strings.Repeat("c", 63) + "." + strings.Repeat("d", 54) + ".k8s.io": []byte("bar")},
   655  	}
   656  	t.Run("success", func(t *testing.T) {
   657  		for i := range successCases {
   658  			i := i
   659  			t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
   660  				t.Parallel()
   661  				if err := validateAnnotations(successCases[i]); err != nil {
   662  					t.Errorf("case[%d] expected success, got %#v", i, err)
   663  				}
   664  			})
   665  		}
   666  	})
   667  
   668  	atleastTwoSegmentsErrorMsg := "should be a domain with at least two segments separated by dots"
   669  	moreThan63CharsErrorMsg := "must be no more than 63 characters"
   670  	moreThan253CharsErrorMsg := "must be no more than 253 characters"
   671  	dns1123SubdomainErrorMsg := "a lowercase RFC 1123 subdomain must consist of lower case alphanumeric characters, '-' or '.', and must start and end with an alphanumeric character"
   672  
   673  	annotationsNameErrorCases := []struct {
   674  		annotations map[string][]byte
   675  		expect      string
   676  	}{
   677  		{map[string][]byte{".": []byte("bar")}, dns1123SubdomainErrorMsg},
   678  		{map[string][]byte{"...": []byte("bar")}, dns1123SubdomainErrorMsg},
   679  		{map[string][]byte{".io": []byte("bar")}, dns1123SubdomainErrorMsg},
   680  		{map[string][]byte{"com": []byte("bar")}, atleastTwoSegmentsErrorMsg},
   681  		{map[string][]byte{".com": []byte("bar")}, dns1123SubdomainErrorMsg},
   682  		{map[string][]byte{"Dev.k8s.io": []byte("bar")}, dns1123SubdomainErrorMsg},
   683  		{map[string][]byte{".foo.example.com": []byte("bar")}, dns1123SubdomainErrorMsg},
   684  		{map[string][]byte{"*.example.com": []byte("bar")}, dns1123SubdomainErrorMsg},
   685  		{map[string][]byte{"*.bar.com": []byte("bar")}, dns1123SubdomainErrorMsg},
   686  		{map[string][]byte{"*.foo.bar.com": []byte("bar")}, dns1123SubdomainErrorMsg},
   687  		{map[string][]byte{"underscores_are_bad.k8s.io": []byte("bar")}, dns1123SubdomainErrorMsg},
   688  		{map[string][]byte{"foo@bar.example.com": []byte("bar")}, dns1123SubdomainErrorMsg},
   689  		{map[string][]byte{"http://foo.example.com": []byte("bar")}, dns1123SubdomainErrorMsg},
   690  		{map[string][]byte{strings.Repeat("a", 64) + ".k8s.io": []byte("bar")}, moreThan63CharsErrorMsg},
   691  		{map[string][]byte{strings.Repeat("a", 63) + "." + strings.Repeat("b", 63) + "." + strings.Repeat("c", 63) + "." + strings.Repeat("d", 55) + ".k8s.io": []byte("bar")}, moreThan253CharsErrorMsg},
   692  	}
   693  
   694  	t.Run("name error", func(t *testing.T) {
   695  		for i := range annotationsNameErrorCases {
   696  			i := i
   697  			t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
   698  				t.Parallel()
   699  				err := validateAnnotations(annotationsNameErrorCases[i].annotations)
   700  				if err == nil {
   701  					t.Errorf("case[%d]: expected failure", i)
   702  				} else {
   703  					if !strings.Contains(err.Error(), annotationsNameErrorCases[i].expect) {
   704  						t.Errorf("case[%d]: error details do not include %q: %q", i, annotationsNameErrorCases[i].expect, err)
   705  					}
   706  				}
   707  			})
   708  		}
   709  	})
   710  
   711  	maxSizeErrMsg := "which exceeds the max size of"
   712  	annotationsSizeErrorCases := []struct {
   713  		annotations map[string][]byte
   714  		expect      string
   715  	}{
   716  		{map[string][]byte{"simple": []byte(strings.Repeat("a", 33*1024))}, maxSizeErrMsg},
   717  		{map[string][]byte{"simple": []byte(strings.Repeat("a", 32*1024))}, maxSizeErrMsg},
   718  		{map[string][]byte{"simple": []byte(strings.Repeat("a", 64*1024))}, maxSizeErrMsg},
   719  		{map[string][]byte{"simple": []byte(strings.Repeat("a", 31*1024)), "simple2": []byte(strings.Repeat("a", 1024))}, maxSizeErrMsg},
   720  		{map[string][]byte{strings.Repeat("a", 253): []byte(strings.Repeat("a", 32*1024))}, maxSizeErrMsg},
   721  	}
   722  	t.Run("size error", func(t *testing.T) {
   723  		for i := range annotationsSizeErrorCases {
   724  			i := i
   725  			t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
   726  				t.Parallel()
   727  				err := validateAnnotations(annotationsSizeErrorCases[i].annotations)
   728  				if err == nil {
   729  					t.Errorf("case[%d]: expected failure", i)
   730  				} else {
   731  					if !strings.Contains(err.Error(), annotationsSizeErrorCases[i].expect) {
   732  						t.Errorf("case[%d]: error details do not include %q: %q", i, annotationsSizeErrorCases[i].expect, err)
   733  					}
   734  				}
   735  			})
   736  		}
   737  	})
   738  }
   739  
   740  func TestValidateKeyID(t *testing.T) {
   741  	t.Parallel()
   742  	testCases := []struct {
   743  		name              string
   744  		keyID             string
   745  		expectedError     string
   746  		expectedErrorCode string
   747  	}{
   748  		{
   749  			name:              "valid key ID",
   750  			keyID:             "1234",
   751  			expectedError:     "",
   752  			expectedErrorCode: "ok",
   753  		},
   754  		{
   755  			name:              "empty key ID",
   756  			keyID:             "",
   757  			expectedError:     "keyID is empty",
   758  			expectedErrorCode: "empty",
   759  		},
   760  		{
   761  			name:              "keyID size is greater than 1 kB",
   762  			keyID:             strings.Repeat("a", 1024+1),
   763  			expectedError:     "which exceeds the max size of",
   764  			expectedErrorCode: "too_long",
   765  		},
   766  	}
   767  
   768  	for _, tt := range testCases {
   769  		tt := tt
   770  		t.Run(tt.name, func(t *testing.T) {
   771  			t.Parallel()
   772  			errCode, err := ValidateKeyID(tt.keyID)
   773  			if tt.expectedError != "" {
   774  				if err == nil {
   775  					t.Fatalf("expected error %q, got nil", tt.expectedError)
   776  				}
   777  				if !strings.Contains(err.Error(), tt.expectedError) {
   778  					t.Fatalf("expected error %q, got %q", tt.expectedError, err)
   779  				}
   780  			} else {
   781  				if err != nil {
   782  					t.Fatalf("expected no error, got %q", err)
   783  				}
   784  			}
   785  			if tt.expectedErrorCode != string(errCode) {
   786  				t.Fatalf("expected %s errCode, got %s", tt.expectedErrorCode, string(errCode))
   787  			}
   788  		})
   789  	}
   790  }
   791  
   792  func TestValidateEncryptedDEKSource(t *testing.T) {
   793  	t.Parallel()
   794  	testCases := []struct {
   795  		name               string
   796  		encryptedDEKSource []byte
   797  		expectedError      string
   798  	}{
   799  		{
   800  			name:               "encrypted DEK source is nil",
   801  			encryptedDEKSource: nil,
   802  			expectedError:      "encrypted DEK source is empty",
   803  		},
   804  		{
   805  			name:               "encrypted DEK source is empty",
   806  			encryptedDEKSource: []byte{},
   807  			expectedError:      "encrypted DEK source is empty",
   808  		},
   809  		{
   810  			name:               "encrypted DEK source size is greater than 1 kB",
   811  			encryptedDEKSource: bytes.Repeat([]byte("a"), 1024+1),
   812  			expectedError:      "which exceeds the max size of",
   813  		},
   814  		{
   815  			name:               "valid encrypted DEK source",
   816  			encryptedDEKSource: []byte{0x01, 0x02, 0x03},
   817  			expectedError:      "",
   818  		},
   819  	}
   820  
   821  	for _, tt := range testCases {
   822  		tt := tt
   823  		t.Run(tt.name, func(t *testing.T) {
   824  			t.Parallel()
   825  			err := validateEncryptedDEKSource(tt.encryptedDEKSource)
   826  			if tt.expectedError != "" {
   827  				if err == nil {
   828  					t.Fatalf("expected error %q, got nil", tt.expectedError)
   829  				}
   830  				if !strings.Contains(err.Error(), tt.expectedError) {
   831  					t.Fatalf("expected error %q, got %q", tt.expectedError, err)
   832  				}
   833  			} else {
   834  				if err != nil {
   835  					t.Fatalf("expected no error, got %q", err)
   836  				}
   837  			}
   838  		})
   839  	}
   840  }
   841  
   842  func TestEnvelopeMetrics(t *testing.T) {
   843  	envelopeService := newTestEnvelopeService()
   844  	transformer := NewEnvelopeTransformer(envelopeService, testProviderName,
   845  		testStateFunc(testContext(t), envelopeService, clock.RealClock{}, randomBool()),
   846  		testAPIServerID,
   847  	)
   848  
   849  	dataCtx := value.DefaultContext(testContextText)
   850  
   851  	kmsv2Transformer := value.PrefixTransformer{Prefix: []byte("k8s:enc:kms:v2:"), Transformer: transformer}
   852  
   853  	testCases := []struct {
   854  		desc                  string
   855  		keyVersionFromEncrypt string
   856  		prefix                value.Transformer
   857  		metrics               []string
   858  		want                  string
   859  	}{
   860  		{
   861  			desc:                  "keyIDHash total",
   862  			keyVersionFromEncrypt: testKeyVersion,
   863  			prefix:                value.NewPrefixTransformers(nil, kmsv2Transformer),
   864  			metrics: []string{
   865  				"apiserver_envelope_encryption_key_id_hash_total",
   866  			},
   867  			want: fmt.Sprintf(`
   868  				# HELP apiserver_envelope_encryption_key_id_hash_total [ALPHA] Number of times a keyID is used split by transformation type, provider, and apiserver identity.
   869  				# TYPE apiserver_envelope_encryption_key_id_hash_total counter
   870  				apiserver_envelope_encryption_key_id_hash_total{apiserver_id_hash="%s",key_id_hash="%s",provider_name="%s",transformation_type="%s"} 1
   871  				apiserver_envelope_encryption_key_id_hash_total{apiserver_id_hash="%s",key_id_hash="%s",provider_name="%s",transformation_type="%s"} 1
   872  				`, testAPIServerIDHash, testKeyHash, testProviderName, metrics.FromStorageLabel, testAPIServerIDHash, testKeyHash, testProviderName, metrics.ToStorageLabel),
   873  		},
   874  	}
   875  
   876  	metrics.KeyIDHashTotal.Reset()
   877  	metrics.InvalidKeyIDFromStatusTotal.Reset()
   878  
   879  	for _, tt := range testCases {
   880  		t.Run(tt.desc, func(t *testing.T) {
   881  			defer metrics.KeyIDHashTotal.Reset()
   882  			defer metrics.InvalidKeyIDFromStatusTotal.Reset()
   883  			ctx := testContext(t)
   884  			envelopeService.keyVersion = tt.keyVersionFromEncrypt
   885  			transformedData, err := tt.prefix.TransformToStorage(ctx, []byte(testText), dataCtx)
   886  			if err != nil {
   887  				t.Fatal(err)
   888  			}
   889  			if _, _, err := tt.prefix.TransformFromStorage(ctx, transformedData, dataCtx); err != nil {
   890  				t.Fatal(err)
   891  			}
   892  
   893  			if err := testutil.GatherAndCompare(legacyregistry.DefaultGatherer, strings.NewReader(tt.want), tt.metrics...); err != nil {
   894  				t.Fatal(err)
   895  			}
   896  		})
   897  	}
   898  }
   899  
   900  // TestEnvelopeMetricsCache validates the correctness of the apiserver_envelope_encryption_dek_source_cache_size metric
   901  // and asserts that all of the associated logic is go routine safe.
   902  // 1. Multiple transformers are created, which should result in unique cache size for each provider
   903  // 2. A transformer with known number of states was created to encrypt, then on restart, another transformer
   904  // was created, which should result in expected number of cache keys for all the decryption calls for each
   905  // state used previously for encryption.
   906  func TestEnvelopeMetricsCache(t *testing.T) {
   907  	envelopeService := newTestEnvelopeService()
   908  	envelopeService.keyVersion = testKeyVersion
   909  	state, err := testStateFunc(testContext(t), envelopeService, clock.RealClock{}, randomBool())()
   910  	if err != nil {
   911  		t.Fatal(err)
   912  	}
   913  	ctx := testContext(t)
   914  	dataCtx := value.DefaultContext(testContextText)
   915  	provider1 := "one"
   916  	provider2 := "two"
   917  	numOfStates := 10
   918  
   919  	testCases := []struct {
   920  		desc    string
   921  		metrics []string
   922  		want    string
   923  	}{
   924  		{
   925  			desc: "dek source cache size",
   926  			metrics: []string{
   927  				"apiserver_envelope_encryption_dek_source_cache_size",
   928  			},
   929  			want: fmt.Sprintf(`
   930  				# HELP apiserver_envelope_encryption_dek_source_cache_size [ALPHA] Number of records in data encryption key (DEK) source cache. On a restart, this value is an approximation of the number of decrypt RPC calls the server will make to the KMS plugin.
   931  				# TYPE apiserver_envelope_encryption_dek_source_cache_size gauge
   932          		apiserver_envelope_encryption_dek_source_cache_size{provider_name="%s"} %d
   933          		apiserver_envelope_encryption_dek_source_cache_size{provider_name="%s"} 1
   934  				`, provider1, numOfStates, provider2),
   935  		},
   936  	}
   937  	transformer1 := NewEnvelopeTransformer(envelopeService, provider1, func() (State, error) {
   938  		// return different states to ensure we get expected number of cache keys after restart on decryption
   939  		return testStateFunc(ctx, envelopeService, clock.RealClock{}, randomBool())()
   940  	}, testAPIServerID)
   941  	transformer2 := NewEnvelopeTransformer(envelopeService, provider2, func() (State, error) { return state, nil }, testAPIServerID)
   942  	// used for restart
   943  	transformer3 := NewEnvelopeTransformer(envelopeService, provider1, func() (State, error) { return state, nil }, testAPIServerID)
   944  	var transformedDatas [][]byte
   945  	for j := 0; j < numOfStates; j++ {
   946  		transformedData, err := transformer1.TransformToStorage(ctx, []byte(testText), dataCtx)
   947  		if err != nil {
   948  			t.Fatal(err)
   949  		}
   950  		transformedDatas = append(transformedDatas, transformedData)
   951  	}
   952  
   953  	for _, tt := range testCases {
   954  		t.Run(tt.desc, func(t *testing.T) {
   955  			metrics.DekSourceCacheSize.Reset()
   956  			var wg sync.WaitGroup
   957  			wg.Add(2 * numOfStates)
   958  			for i := 0; i < numOfStates; i++ {
   959  				i := i
   960  				go func() {
   961  					defer wg.Done()
   962  					// mimick a restart, the server will make decrypt RPC calls to the KMS plugin
   963  					// check cache metrics for the decrypt / read flow, which should repopulate the cache
   964  					if _, _, err := transformer3.TransformFromStorage(ctx, transformedDatas[i], dataCtx); err != nil {
   965  						panic(err)
   966  					}
   967  				}()
   968  				go func() {
   969  					defer wg.Done()
   970  					// check cache metrics for the encrypt / write flow
   971  					_, err := transformer2.TransformToStorage(ctx, []byte(testText), dataCtx)
   972  					if err != nil {
   973  						panic(err)
   974  					}
   975  				}()
   976  			}
   977  			wg.Wait()
   978  			if err := testutil.GatherAndCompare(legacyregistry.DefaultGatherer, strings.NewReader(tt.want), tt.metrics...); err != nil {
   979  				t.Fatal(err)
   980  			}
   981  		})
   982  	}
   983  }
   984  
   985  var flagOnce sync.Once // support running `go test -count X`
   986  
   987  func TestEnvelopeLogging(t *testing.T) {
   988  	flagOnce.Do(func() {
   989  		klog.InitFlags(nil)
   990  	})
   991  	flag.Set("v", "6")
   992  	flag.Parse()
   993  
   994  	testCases := []struct {
   995  		desc     string
   996  		ctx      context.Context
   997  		wantLogs []string
   998  	}{
   999  		{
  1000  			desc: "no request info in context",
  1001  			ctx:  testContext(t),
  1002  			wantLogs: []string{
  1003  				`"encrypting content using envelope service" uid="UID"`,
  1004  				`"encrypting content using DEK" uid="UID" key="0123456789" group="" version="" resource="" subresource="" verb="" namespace="" name=""`,
  1005  				`"decrypting content using envelope service" uid="UID" key="0123456789" group="" version="" resource="" subresource="" verb="" namespace="" name=""`,
  1006  			},
  1007  		},
  1008  		{
  1009  			desc: "request info in context",
  1010  			ctx: genericapirequest.WithRequestInfo(testContext(t), &genericapirequest.RequestInfo{
  1011  				APIGroup:    "awesome.bears.com",
  1012  				APIVersion:  "v1",
  1013  				Resource:    "pandas",
  1014  				Subresource: "status",
  1015  				Namespace:   "kube-system",
  1016  				Name:        "panda",
  1017  				Verb:        "update",
  1018  			}),
  1019  			wantLogs: []string{
  1020  				`"encrypting content using envelope service" uid="UID"`,
  1021  				`"encrypting content using DEK" uid="UID" key="0123456789" group="awesome.bears.com" version="v1" resource="pandas" subresource="status" verb="update" namespace="kube-system" name="panda"`,
  1022  				`"decrypting content using envelope service" uid="UID" key="0123456789" group="awesome.bears.com" version="v1" resource="pandas" subresource="status" verb="update" namespace="kube-system" name="panda"`,
  1023  			},
  1024  		},
  1025  	}
  1026  
  1027  	for _, tc := range testCases {
  1028  		tc := tc
  1029  		t.Run(tc.desc, func(t *testing.T) {
  1030  			var buf bytes.Buffer
  1031  			klog.SetOutput(&buf)
  1032  			klog.LogToStderr(false)
  1033  			defer klog.LogToStderr(true)
  1034  
  1035  			envelopeService := newTestEnvelopeService()
  1036  			fakeClock := testingclock.NewFakeClock(time.Now())
  1037  			transformer := newEnvelopeTransformerWithClock(envelopeService, testProviderName,
  1038  				testStateFunc(tc.ctx, envelopeService, clock.RealClock{}, randomBool()), testAPIServerID, 1*time.Second, fakeClock)
  1039  
  1040  			dataCtx := value.DefaultContext([]byte(testContextText))
  1041  			originalText := []byte(testText)
  1042  
  1043  			transformedData, err := transformer.TransformToStorage(tc.ctx, originalText, dataCtx)
  1044  			if err != nil {
  1045  				t.Fatalf("envelopeTransformer: error while transforming data to storage: %v", err)
  1046  			}
  1047  
  1048  			// advance the clock to trigger cache to expire, so we make a decrypt call that will log
  1049  			fakeClock.Step(2 * time.Second)
  1050  			// force GC to run by performing a write
  1051  			transformer.(*envelopeTransformer).cache.set([]byte("some-other-unrelated-key"), &envelopeTransformer{})
  1052  
  1053  			_, _, err = transformer.TransformFromStorage(tc.ctx, transformedData, dataCtx)
  1054  			if err != nil {
  1055  				t.Fatalf("could not decrypt Envelope transformer's encrypted data even once: %v", err)
  1056  			}
  1057  
  1058  			klog.Flush()
  1059  			klog.SetOutput(&bytes.Buffer{}) // prevent further writes into buf
  1060  			capturedOutput := buf.String()
  1061  
  1062  			// replace the uid with a constant to make the test output stable and assertable
  1063  			capturedOutput = regexp.MustCompile(`uid="[^"]+"`).ReplaceAllString(capturedOutput, `uid="UID"`)
  1064  
  1065  			for _, wantLog := range tc.wantLogs {
  1066  				if !strings.Contains(capturedOutput, wantLog) {
  1067  					t.Errorf("expected log %q, got %q", wantLog, capturedOutput)
  1068  				}
  1069  			}
  1070  		})
  1071  	}
  1072  }
  1073  
  1074  func TestCacheNotCorrupted(t *testing.T) {
  1075  	ctx := testContext(t)
  1076  
  1077  	envelopeService := newTestEnvelopeService()
  1078  	envelopeService.SetAnnotations(map[string][]byte{
  1079  		"encrypted-dek.kms.kubernetes.io": []byte("encrypted-dek-0"),
  1080  	})
  1081  
  1082  	fakeClock := testingclock.NewFakeClock(time.Now())
  1083  
  1084  	state, err := testStateFunc(ctx, envelopeService, fakeClock, randomBool())()
  1085  	if err != nil {
  1086  		t.Fatal(err)
  1087  	}
  1088  
  1089  	transformer := newEnvelopeTransformerWithClock(envelopeService, testProviderName,
  1090  		func() (State, error) { return state, nil }, testAPIServerID,
  1091  		1*time.Second, fakeClock)
  1092  
  1093  	dataCtx := value.DefaultContext(testContextText)
  1094  	originalText := []byte(testText)
  1095  
  1096  	transformedData1, err := transformer.TransformToStorage(ctx, originalText, dataCtx)
  1097  	if err != nil {
  1098  		t.Fatalf("envelopeTransformer: error while transforming data to storage: %s", err)
  1099  	}
  1100  
  1101  	// this is to mimic a plugin that sets a static response for ciphertext
  1102  	// but uses the annotation field to send the actual encrypted DEK source.
  1103  	envelopeService.SetCiphertext(state.EncryptedObject.EncryptedDEKSource)
  1104  	// for this plugin, it indicates a change in the remote key ID as the returned
  1105  	// encrypted DEK source is different.
  1106  	envelopeService.SetAnnotations(map[string][]byte{
  1107  		"encrypted-dek.kms.kubernetes.io": []byte("encrypted-dek-1"),
  1108  	})
  1109  
  1110  	state, err = testStateFunc(ctx, envelopeService, fakeClock, randomBool())()
  1111  	if err != nil {
  1112  		t.Fatal(err)
  1113  	}
  1114  
  1115  	transformer = newEnvelopeTransformerWithClock(envelopeService, testProviderName,
  1116  		func() (State, error) { return state, nil }, testAPIServerID,
  1117  		1*time.Second, fakeClock)
  1118  
  1119  	transformedData2, err := transformer.TransformToStorage(ctx, originalText, dataCtx)
  1120  	if err != nil {
  1121  		t.Fatalf("envelopeTransformer: error while transforming data to storage: %s", err)
  1122  	}
  1123  
  1124  	if _, _, err := transformer.TransformFromStorage(ctx, transformedData1, dataCtx); err != nil {
  1125  		t.Fatal(err)
  1126  	}
  1127  	if _, _, err := transformer.TransformFromStorage(ctx, transformedData2, dataCtx); err != nil {
  1128  		t.Fatal(err)
  1129  	}
  1130  }
  1131  
  1132  func TestGenerateCacheKey(t *testing.T) {
  1133  	encryptedDEKSource1 := []byte{1, 2, 3}
  1134  	keyID1 := "id1"
  1135  	annotations1 := map[string][]byte{"a": {4, 5}, "b": {6, 7}}
  1136  	encryptedDEKSourceType1 := kmstypes.EncryptedDEKSourceType_AES_GCM_KEY
  1137  
  1138  	encryptedDEKSource2 := []byte{4, 5, 6}
  1139  	keyID2 := "id2"
  1140  	annotations2 := map[string][]byte{"x": {9, 10}, "y": {11, 12}}
  1141  	encryptedDEKSourceType2 := kmstypes.EncryptedDEKSourceType_HKDF_SHA256_XNONCE_AES_GCM_SEED
  1142  
  1143  	// generate all possible combinations of the above
  1144  	testCases := []struct {
  1145  		encryptedDEKSourceType kmstypes.EncryptedDEKSourceType
  1146  		encryptedDEKSource     []byte
  1147  		keyID                  string
  1148  		annotations            map[string][]byte
  1149  	}{
  1150  		{encryptedDEKSourceType1, encryptedDEKSource1, keyID1, annotations1},
  1151  		{encryptedDEKSourceType1, encryptedDEKSource1, keyID1, annotations2},
  1152  		{encryptedDEKSourceType1, encryptedDEKSource1, keyID2, annotations1},
  1153  		{encryptedDEKSourceType1, encryptedDEKSource1, keyID2, annotations2},
  1154  		{encryptedDEKSourceType1, encryptedDEKSource2, keyID1, annotations1},
  1155  		{encryptedDEKSourceType1, encryptedDEKSource2, keyID1, annotations2},
  1156  		{encryptedDEKSourceType1, encryptedDEKSource2, keyID2, annotations1},
  1157  		{encryptedDEKSourceType1, encryptedDEKSource2, keyID2, annotations2},
  1158  
  1159  		{encryptedDEKSourceType2, encryptedDEKSource1, keyID1, annotations1},
  1160  		{encryptedDEKSourceType2, encryptedDEKSource1, keyID1, annotations2},
  1161  		{encryptedDEKSourceType2, encryptedDEKSource1, keyID2, annotations1},
  1162  		{encryptedDEKSourceType2, encryptedDEKSource1, keyID2, annotations2},
  1163  		{encryptedDEKSourceType2, encryptedDEKSource2, keyID1, annotations1},
  1164  		{encryptedDEKSourceType2, encryptedDEKSource2, keyID1, annotations2},
  1165  		{encryptedDEKSourceType2, encryptedDEKSource2, keyID2, annotations1},
  1166  		{encryptedDEKSourceType2, encryptedDEKSource2, keyID2, annotations2},
  1167  	}
  1168  
  1169  	for _, tc := range testCases {
  1170  		tc := tc
  1171  		for _, tc2 := range testCases {
  1172  			tc2 := tc2
  1173  			t.Run(fmt.Sprintf("%+v-%+v", tc, tc2), func(t *testing.T) {
  1174  				key1, err1 := generateCacheKey(tc.encryptedDEKSourceType, tc.encryptedDEKSource, tc.keyID, tc.annotations)
  1175  				key2, err2 := generateCacheKey(tc2.encryptedDEKSourceType, tc2.encryptedDEKSource, tc2.keyID, tc2.annotations)
  1176  				if err1 != nil || err2 != nil {
  1177  					t.Errorf("generateCacheKey() want err=nil, got err1=%q, err2=%q", errString(err1), errString(err2))
  1178  				}
  1179  				if bytes.Equal(key1, key2) != reflect.DeepEqual(tc, tc2) {
  1180  					t.Errorf("expected %v, got %v", reflect.DeepEqual(tc, tc2), bytes.Equal(key1, key2))
  1181  				}
  1182  			})
  1183  		}
  1184  	}
  1185  }
  1186  
  1187  func TestGenerateTransformer(t *testing.T) {
  1188  	t.Parallel()
  1189  	testCases := []struct {
  1190  		name            string
  1191  		envelopeService func() kmsservice.Service
  1192  		expectedErr     string
  1193  	}{
  1194  		{
  1195  			name: "encrypt call fails",
  1196  			envelopeService: func() kmsservice.Service {
  1197  				envelopeService := newTestEnvelopeService()
  1198  				envelopeService.SetDisabledStatus(true)
  1199  				return envelopeService
  1200  			},
  1201  			expectedErr: "Envelope service was disabled",
  1202  		},
  1203  		{
  1204  			name: "invalid key ID",
  1205  			envelopeService: func() kmsservice.Service {
  1206  				envelopeService := newTestEnvelopeService()
  1207  				envelopeService.keyVersion = ""
  1208  				return envelopeService
  1209  			},
  1210  			expectedErr: "failed to validate key id: keyID is empty",
  1211  		},
  1212  		{
  1213  			name: "invalid encrypted DEK",
  1214  			envelopeService: func() kmsservice.Service {
  1215  				envelopeService := newTestEnvelopeService()
  1216  				envelopeService.SetCiphertext([]byte{})
  1217  				return envelopeService
  1218  			},
  1219  			expectedErr: "failed to validate encrypted DEK source: encrypted DEK source is empty",
  1220  		},
  1221  		{
  1222  			name: "invalid annotations",
  1223  			envelopeService: func() kmsservice.Service {
  1224  				envelopeService := newTestEnvelopeService()
  1225  				envelopeService.SetAnnotations(map[string][]byte{"invalid": {}})
  1226  				return envelopeService
  1227  			},
  1228  			expectedErr: "failed to validate annotations: annotations: Invalid value: \"invalid\": should be a domain with at least two segments separated by dots",
  1229  		},
  1230  		{
  1231  			name: "success",
  1232  			envelopeService: func() kmsservice.Service {
  1233  				return newTestEnvelopeService()
  1234  			},
  1235  			expectedErr: "",
  1236  		},
  1237  	}
  1238  
  1239  	for _, tc := range testCases {
  1240  		tc := tc
  1241  		t.Run(tc.name, func(t *testing.T) {
  1242  			t.Parallel()
  1243  
  1244  			transformer, encObject, cacheKey, err := GenerateTransformer(testContext(t), "panda", tc.envelopeService(), randomBool())
  1245  			if tc.expectedErr == "" {
  1246  				if err != nil {
  1247  					t.Errorf("expected no error, got %q", errString(err))
  1248  				}
  1249  				if transformer == nil {
  1250  					t.Error("expected transformer, got nil")
  1251  				}
  1252  				if encObject == nil {
  1253  					t.Error("expected encrypt response, got nil")
  1254  				}
  1255  				if cacheKey == nil {
  1256  					t.Error("expected cache key, got nil")
  1257  				}
  1258  			} else {
  1259  				if err == nil || !strings.Contains(err.Error(), tc.expectedErr) {
  1260  					t.Errorf("expected error %q, got %q", tc.expectedErr, errString(err))
  1261  				}
  1262  			}
  1263  		})
  1264  	}
  1265  }
  1266  
  1267  func TestEnvelopeTracing_TransformToStorage(t *testing.T) {
  1268  	testCases := []struct {
  1269  		desc     string
  1270  		expected []string
  1271  	}{
  1272  		{
  1273  			desc: "encrypt",
  1274  			expected: []string{
  1275  				"About to encrypt data using DEK",
  1276  				"Data encryption succeeded",
  1277  				"About to encode encrypted object",
  1278  				"Encoded encrypted object",
  1279  			},
  1280  		},
  1281  	}
  1282  
  1283  	for _, tc := range testCases {
  1284  		t.Run(tc.desc, func(t *testing.T) {
  1285  			fakeRecorder := tracetest.NewSpanRecorder()
  1286  			otelTracer := trace.NewTracerProvider(trace.WithSpanProcessor(fakeRecorder)).Tracer("test")
  1287  
  1288  			ctx := testContext(t)
  1289  			ctx, span := otelTracer.Start(ctx, "parent")
  1290  			defer span.End()
  1291  
  1292  			envelopeService := newTestEnvelopeService()
  1293  			fakeClock := testingclock.NewFakeClock(time.Now())
  1294  			state, err := testStateFunc(ctx, envelopeService, clock.RealClock{}, randomBool())()
  1295  			if err != nil {
  1296  				t.Fatal(err)
  1297  			}
  1298  
  1299  			transformer := newEnvelopeTransformerWithClock(envelopeService, testProviderName,
  1300  				func() (State, error) { return state, nil }, testAPIServerID, 1*time.Second, fakeClock)
  1301  
  1302  			dataCtx := value.DefaultContext([]byte(testContextText))
  1303  			originalText := []byte(testText)
  1304  
  1305  			if _, err := transformer.TransformToStorage(ctx, originalText, dataCtx); err != nil {
  1306  				t.Fatalf("envelopeTransformer: error while transforming data to storage: %v", err)
  1307  			}
  1308  
  1309  			output := fakeRecorder.Ended()
  1310  			if len(output) != 1 {
  1311  				t.Fatalf("expected 1 span, got %d", len(output))
  1312  			}
  1313  			out := output[0]
  1314  			validateTraceSpan(t, out, "TransformToStorage with envelopeTransformer", testProviderName, testAPIServerID, tc.expected)
  1315  		})
  1316  	}
  1317  }
  1318  
  1319  func TestEnvelopeTracing_TransformFromStorage(t *testing.T) {
  1320  	testCases := []struct {
  1321  		desc                     string
  1322  		cacheTTL                 time.Duration
  1323  		simulateKMSPluginFailure bool
  1324  		expected                 []string
  1325  	}{
  1326  		{
  1327  			desc:     "decrypt",
  1328  			cacheTTL: 5 * time.Second,
  1329  			expected: []string{
  1330  				"About to decode encrypted object",
  1331  				"Decoded encrypted object",
  1332  				"About to decrypt data using DEK",
  1333  				"Data decryption succeeded",
  1334  			},
  1335  		},
  1336  		{
  1337  			desc:     "decrypt with cache miss",
  1338  			cacheTTL: 1 * time.Second,
  1339  			expected: []string{
  1340  				"About to decode encrypted object",
  1341  				"Decoded encrypted object",
  1342  				"About to decrypt DEK using remote service",
  1343  				"DEK decryption succeeded",
  1344  				"About to decrypt data using DEK",
  1345  				"Data decryption succeeded",
  1346  			},
  1347  		},
  1348  		{
  1349  			desc:                     "decrypt with cache miss, simulate KMS plugin failure",
  1350  			cacheTTL:                 1 * time.Second,
  1351  			simulateKMSPluginFailure: true,
  1352  			expected: []string{
  1353  				"About to decode encrypted object",
  1354  				"Decoded encrypted object",
  1355  				"About to decrypt DEK using remote service",
  1356  				"DEK decryption failed",
  1357  				"exception",
  1358  			},
  1359  		},
  1360  	}
  1361  
  1362  	for _, tc := range testCases {
  1363  		t.Run(tc.desc, func(t *testing.T) {
  1364  			fakeRecorder := tracetest.NewSpanRecorder()
  1365  			otelTracer := trace.NewTracerProvider(trace.WithSpanProcessor(fakeRecorder)).Tracer("test")
  1366  
  1367  			ctx := testContext(t)
  1368  
  1369  			envelopeService := newTestEnvelopeService()
  1370  			fakeClock := testingclock.NewFakeClock(time.Now())
  1371  			state, err := testStateFunc(ctx, envelopeService, clock.RealClock{}, randomBool())()
  1372  			if err != nil {
  1373  				t.Fatal(err)
  1374  			}
  1375  
  1376  			transformer := newEnvelopeTransformerWithClock(envelopeService, testProviderName,
  1377  				func() (State, error) { return state, nil }, testAPIServerID, tc.cacheTTL, fakeClock)
  1378  
  1379  			dataCtx := value.DefaultContext([]byte(testContextText))
  1380  			originalText := []byte(testText)
  1381  
  1382  			transformedData, _ := transformer.TransformToStorage(ctx, originalText, dataCtx)
  1383  
  1384  			// advance the clock to allow cache entries to expire depending on TTL
  1385  			fakeClock.Step(2 * time.Second)
  1386  			// force GC to run by performing a write
  1387  			transformer.(*envelopeTransformer).cache.set([]byte("some-other-unrelated-key"), &envelopeTransformer{})
  1388  
  1389  			envelopeService.SetDisabledStatus(tc.simulateKMSPluginFailure)
  1390  
  1391  			// start recording only for the decrypt call
  1392  			ctx, span := otelTracer.Start(ctx, "parent")
  1393  			defer span.End()
  1394  
  1395  			_, _, _ = transformer.TransformFromStorage(ctx, transformedData, dataCtx)
  1396  
  1397  			output := fakeRecorder.Ended()
  1398  			validateTraceSpan(t, output[0], "TransformFromStorage with envelopeTransformer", testProviderName, testAPIServerID, tc.expected)
  1399  		})
  1400  	}
  1401  }
  1402  
  1403  func validateTraceSpan(t *testing.T, span trace.ReadOnlySpan, spanName, providerName, apiserverID string, expected []string) {
  1404  	t.Helper()
  1405  
  1406  	if span.Name() != spanName {
  1407  		t.Fatalf("expected span name %q, got %q", spanName, span.Name())
  1408  	}
  1409  	attrs := span.Attributes()
  1410  	if len(attrs) != 1 {
  1411  		t.Fatalf("expected 1 attributes, got %d", len(attrs))
  1412  	}
  1413  	if attrs[0].Key != "transformer.provider.name" && attrs[0].Value.AsString() != providerName {
  1414  		t.Errorf("expected providerName %q, got %q", providerName, attrs[0].Value.AsString())
  1415  	}
  1416  	if len(span.Events()) != len(expected) {
  1417  		t.Fatalf("expected %d events, got %d", len(expected), len(span.Events()))
  1418  	}
  1419  	for i, event := range span.Events() {
  1420  		if event.Name != expected[i] {
  1421  			t.Errorf("expected event %q, got %q", expected[i], event.Name)
  1422  		}
  1423  	}
  1424  }
  1425  
  1426  func errString(err error) string {
  1427  	if err == nil {
  1428  		return ""
  1429  	}
  1430  
  1431  	return err.Error()
  1432  }
  1433  
  1434  func randomBool() bool { return utilrand.Int()%2 == 1 }