k8s.io/apiserver@v0.31.1/pkg/storage/value/encrypt/envelope/kmsv2/envelope_test.go (about)

     1  //go:build !windows
     2  // +build !windows
     3  
     4  /*
     5  Copyright 2022 The Kubernetes Authors.
     6  
     7  Licensed under the Apache License, Version 2.0 (the "License");
     8  you may not use this file except in compliance with the License.
     9  You may obtain a copy of the License at
    10  
    11      http://www.apache.org/licenses/LICENSE-2.0
    12  
    13  Unless required by applicable law or agreed to in writing, software
    14  distributed under the License is distributed on an "AS IS" BASIS,
    15  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    16  See the License for the specific language governing permissions and
    17  limitations under the License.
    18  */
    19  
// Package kmsv2 transforms values for storage at rest using an Envelope v2 provider
    21  package kmsv2
    22  
    23  import (
    24  	"bytes"
    25  	"context"
    26  	"encoding/base64"
    27  	"flag"
    28  	"fmt"
    29  	"reflect"
    30  	"regexp"
    31  	"strconv"
    32  	"strings"
    33  	"sync"
    34  	"sync/atomic"
    35  	"testing"
    36  	"time"
    37  
    38  	"github.com/gogo/protobuf/proto"
    39  	"go.opentelemetry.io/otel/sdk/trace"
    40  	"go.opentelemetry.io/otel/sdk/trace/tracetest"
    41  
    42  	utilrand "k8s.io/apimachinery/pkg/util/rand"
    43  	"k8s.io/apimachinery/pkg/util/uuid"
    44  	genericapirequest "k8s.io/apiserver/pkg/endpoints/request"
    45  	"k8s.io/apiserver/pkg/storage/value"
    46  	kmstypes "k8s.io/apiserver/pkg/storage/value/encrypt/envelope/kmsv2/v2"
    47  	"k8s.io/apiserver/pkg/storage/value/encrypt/envelope/metrics"
    48  	"k8s.io/component-base/metrics/legacyregistry"
    49  	"k8s.io/component-base/metrics/testutil"
    50  	"k8s.io/klog/v2"
    51  	kmsservice "k8s.io/kms/pkg/service"
    52  	"k8s.io/utils/clock"
    53  	testingclock "k8s.io/utils/clock/testing"
    54  )
    55  
    56  const (
    57  	testText            = "abcdefghijklmnopqrstuvwxyz"
    58  	testContextText     = "0123456789"
    59  	testKeyHash         = "sha256:6b86b273ff34fce19d6b804eff5a3f5747ada4eaa22f1d49c01e52ddb7875b4b"
    60  	testKeyVersion      = "1"
    61  	testAPIServerID     = "testAPIServerID"
    62  	testAPIServerIDHash = "sha256:14f9d63e669337ac6bfda2e2162915ee6a6067743eddd4e5c374b572f951ff37"
    63  )
    64  
    65  // testEnvelopeService is a mock Envelope service which can be used to simulate remote Envelope services
    66  // for testing of Envelope based encryption providers.
    67  type testEnvelopeService struct {
    68  	annotations  map[string][]byte
    69  	disabled     bool
    70  	keyVersion   string
    71  	ciphertext   []byte
    72  	decryptCalls int32
    73  }
    74  
    75  func (t *testEnvelopeService) Decrypt(ctx context.Context, uid string, req *kmsservice.DecryptRequest) ([]byte, error) {
    76  	atomic.AddInt32(&t.decryptCalls, 1)
    77  	if t.disabled {
    78  		return nil, fmt.Errorf("Envelope service was disabled")
    79  	}
    80  	if len(uid) == 0 {
    81  		return nil, fmt.Errorf("uid is required")
    82  	}
    83  	if len(req.KeyID) == 0 {
    84  		return nil, fmt.Errorf("keyID is required")
    85  	}
    86  	return base64.StdEncoding.DecodeString(string(req.Ciphertext))
    87  }
    88  
    89  func (t *testEnvelopeService) Encrypt(ctx context.Context, uid string, data []byte) (*kmsservice.EncryptResponse, error) {
    90  	if t.disabled {
    91  		return nil, fmt.Errorf("Envelope service was disabled")
    92  	}
    93  	if len(uid) == 0 {
    94  		return nil, fmt.Errorf("uid is required")
    95  	}
    96  	annotations := make(map[string][]byte)
    97  	if t.annotations != nil {
    98  		for k, v := range t.annotations {
    99  			annotations[k] = v
   100  		}
   101  	} else {
   102  		annotations["local-kek.kms.kubernetes.io"] = []byte("encrypted-local-kek")
   103  	}
   104  
   105  	ciphertext := t.ciphertext
   106  	if ciphertext == nil {
   107  		ciphertext = []byte(base64.StdEncoding.EncodeToString(data))
   108  	}
   109  
   110  	return &kmsservice.EncryptResponse{Ciphertext: ciphertext, KeyID: t.keyVersion, Annotations: annotations}, nil
   111  }
   112  
   113  func (t *testEnvelopeService) Status(ctx context.Context) (*kmsservice.StatusResponse, error) {
   114  	if t.disabled {
   115  		return nil, fmt.Errorf("Envelope service was disabled")
   116  	}
   117  	return &kmsservice.StatusResponse{KeyID: t.keyVersion}, nil
   118  }
   119  
   120  func (t *testEnvelopeService) SetDisabledStatus(status bool) {
   121  	t.disabled = status
   122  }
   123  
   124  func (t *testEnvelopeService) SetAnnotations(annotations map[string][]byte) {
   125  	t.annotations = annotations
   126  }
   127  
   128  func (t *testEnvelopeService) SetCiphertext(ciphertext []byte) {
   129  	t.ciphertext = ciphertext
   130  }
   131  
   132  func (t *testEnvelopeService) Rotate() {
   133  	i, _ := strconv.Atoi(t.keyVersion)
   134  	t.keyVersion = strconv.FormatInt(int64(i+1), 10)
   135  }
   136  
   137  func newTestEnvelopeService() *testEnvelopeService {
   138  	return &testEnvelopeService{
   139  		keyVersion: testKeyVersion,
   140  	}
   141  }
   142  
   143  // Throw error if Envelope transformer tries to contact Envelope without hitting cache.
   144  func TestEnvelopeCaching(t *testing.T) {
   145  	testCases := []struct {
   146  		desc                     string
   147  		cacheTTL                 time.Duration
   148  		simulateKMSPluginFailure bool
   149  		expectedError            string
   150  		expectedDecryptCalls     int
   151  	}{
   152  		{
   153  			desc:                     "entry in cache should withstand plugin failure",
   154  			cacheTTL:                 5 * time.Minute,
   155  			simulateKMSPluginFailure: true,
   156  			expectedDecryptCalls:     0, // should not hit KMS plugin
   157  		},
   158  		{
   159  			desc:                     "cache entry expired should not withstand plugin failure",
   160  			cacheTTL:                 1 * time.Millisecond,
   161  			simulateKMSPluginFailure: true,
   162  			expectedError:            "failed to decrypt DEK, error: Envelope service was disabled",
   163  			expectedDecryptCalls:     10, // should hit KMS plugin for each read after cache entry expired and fail
   164  		},
   165  		{
   166  			desc:                     "cache entry expired should work after cache refresh",
   167  			cacheTTL:                 1 * time.Millisecond,
   168  			simulateKMSPluginFailure: false,
   169  			expectedDecryptCalls:     1, // should hit KMS plugin just for the 1st read after cache entry expired
   170  		},
   171  	}
   172  
   173  	for _, tt := range testCases {
   174  		t.Run(tt.desc, func(t *testing.T) {
   175  			ctx := testContext(t)
   176  
   177  			envelopeService := newTestEnvelopeService()
   178  			fakeClock := testingclock.NewFakeClock(time.Now())
   179  
   180  			useSeed := randomBool()
   181  
   182  			state, err := testStateFunc(ctx, envelopeService, fakeClock, useSeed)()
   183  			if err != nil {
   184  				t.Fatal(err)
   185  			}
   186  
   187  			transformer := newEnvelopeTransformerWithClock(envelopeService, testProviderName,
   188  				func() (State, error) { return state, nil }, testAPIServerID,
   189  				tt.cacheTTL, fakeClock)
   190  
   191  			dataCtx := value.DefaultContext(testContextText)
   192  			originalText := []byte(testText)
   193  
   194  			transformedData, err := transformer.TransformToStorage(ctx, originalText, dataCtx)
   195  			if err != nil {
   196  				t.Fatalf("envelopeTransformer: error while transforming data to storage: %s", err)
   197  			}
   198  			untransformedData, _, err := transformer.TransformFromStorage(ctx, transformedData, dataCtx)
   199  			if err != nil {
   200  				t.Fatalf("could not decrypt Envelope transformer's encrypted data even once: %v", err)
   201  			}
   202  			if !bytes.Equal(untransformedData, originalText) {
   203  				t.Fatalf("envelopeTransformer transformed data incorrectly. Expected: %v, got %v", originalText, untransformedData)
   204  			}
   205  
   206  			// advance the clock to allow cache entries to expire depending on TTL
   207  			fakeClock.Step(2 * time.Minute)
   208  			// force GC to run by performing a write
   209  			transformer.(*envelopeTransformer).cache.set([]byte("some-other-unrelated-key"), &envelopeTransformer{})
   210  
   211  			state, err = testStateFunc(ctx, envelopeService, fakeClock, useSeed)()
   212  			if err != nil {
   213  				t.Fatal(err)
   214  			}
   215  			envelopeService.SetDisabledStatus(tt.simulateKMSPluginFailure)
   216  
   217  			for i := 0; i < 10; i++ {
   218  				// Subsequent reads for the same data should work fine due to caching.
   219  				untransformedData, _, err = transformer.TransformFromStorage(ctx, transformedData, dataCtx)
   220  				if tt.expectedError != "" {
   221  					if err == nil {
   222  						t.Fatalf("expected error: %v, got nil", tt.expectedError)
   223  					}
   224  					if err.Error() != tt.expectedError {
   225  						t.Fatalf("expected error: %v, got: %v", tt.expectedError, err)
   226  					}
   227  				} else {
   228  					if err != nil {
   229  						t.Fatalf("unexpected error: %v", err)
   230  					}
   231  					if !bytes.Equal(untransformedData, originalText) {
   232  						t.Fatalf("envelopeTransformer transformed data incorrectly. Expected: %v, got %v", originalText, untransformedData)
   233  					}
   234  				}
   235  			}
   236  			if int(envelopeService.decryptCalls) != tt.expectedDecryptCalls {
   237  				t.Fatalf("expected %d decrypt calls, got %d", tt.expectedDecryptCalls, envelopeService.decryptCalls)
   238  			}
   239  		})
   240  	}
   241  }
   242  
   243  func testStateFunc(ctx context.Context, envelopeService kmsservice.Service, clock clock.Clock, useSeed bool) func() (State, error) {
   244  	return func() (State, error) {
   245  		transformer, encObject, cacheKey, errGen := GenerateTransformer(ctx, string(uuid.NewUUID()), envelopeService, useSeed)
   246  		if errGen != nil {
   247  			return State{}, errGen
   248  		}
   249  		return State{
   250  			Transformer:         transformer,
   251  			EncryptedObject:     *encObject,
   252  			UID:                 "panda",
   253  			ExpirationTimestamp: clock.Now().Add(time.Hour),
   254  			CacheKey:            cacheKey,
   255  		}, nil
   256  	}
   257  }
   258  
   259  // TestEnvelopeTransformerStaleness validates that staleness checks on read honor the data returned from the StateFunc.
   260  func TestEnvelopeTransformerStaleness(t *testing.T) {
   261  	t.Parallel()
   262  	testCases := []struct {
   263  		desc          string
   264  		expectedStale bool
   265  		testErr       error
   266  		testKeyID     string
   267  		useSeedWrite  bool
   268  		useSeedRead   bool
   269  	}{
   270  		{
   271  			desc:          "stateFunc returns err",
   272  			expectedStale: false,
   273  			testErr:       fmt.Errorf("failed to perform status section of the healthz check for KMS Provider"),
   274  			testKeyID:     "",
   275  		},
   276  		{
   277  			desc:          "stateFunc returns same keyID, not using seed",
   278  			expectedStale: false,
   279  			testErr:       nil,
   280  			testKeyID:     testKeyVersion,
   281  		},
   282  		{
   283  			desc:          "stateFunc returns same keyID, using seed",
   284  			expectedStale: false,
   285  			testErr:       nil,
   286  			testKeyID:     testKeyVersion,
   287  			useSeedWrite:  true,
   288  			useSeedRead:   true,
   289  		},
   290  		{
   291  			desc:          "stateFunc returns same keyID, migrating away from seed",
   292  			expectedStale: true,
   293  			testErr:       nil,
   294  			testKeyID:     testKeyVersion,
   295  			useSeedWrite:  true,
   296  			useSeedRead:   false,
   297  		},
   298  		{
   299  			desc:          "stateFunc returns same keyID, migrating to seed",
   300  			expectedStale: true,
   301  			testErr:       nil,
   302  			testKeyID:     testKeyVersion,
   303  			useSeedWrite:  false,
   304  			useSeedRead:   true,
   305  		},
   306  		{
   307  			desc:          "stateFunc returns different keyID",
   308  			expectedStale: true,
   309  			testErr:       nil,
   310  			testKeyID:     "2",
   311  		},
   312  	}
   313  
   314  	for _, tt := range testCases {
   315  		tt := tt
   316  		t.Run(tt.desc, func(t *testing.T) {
   317  			t.Parallel()
   318  
   319  			ctx := testContext(t)
   320  
   321  			envelopeService := newTestEnvelopeService()
   322  			state, err := testStateFunc(ctx, envelopeService, clock.RealClock{}, tt.useSeedWrite)()
   323  			if err != nil {
   324  				t.Fatal(err)
   325  			}
   326  			var stateErr error
   327  
   328  			transformer := NewEnvelopeTransformer(envelopeService, testProviderName,
   329  				func() (State, error) { return state, stateErr }, testAPIServerID,
   330  			)
   331  
   332  			dataCtx := value.DefaultContext(testContextText)
   333  			originalText := []byte(testText)
   334  
   335  			transformedData, err := transformer.TransformToStorage(ctx, originalText, dataCtx)
   336  			if err != nil {
   337  				t.Fatalf("envelopeTransformer: error while transforming data (%v) to storage: %s", originalText, err)
   338  			}
   339  
   340  			// inject test data before performing a read
   341  			state.EncryptedObject.KeyID = tt.testKeyID
   342  			if tt.useSeedRead {
   343  				state.EncryptedObject.EncryptedDEKSourceType = kmstypes.EncryptedDEKSourceType_HKDF_SHA256_XNONCE_AES_GCM_SEED
   344  			} else {
   345  				state.EncryptedObject.EncryptedDEKSourceType = kmstypes.EncryptedDEKSourceType_AES_GCM_KEY
   346  			}
   347  			stateErr = tt.testErr
   348  
   349  			_, stale, err := transformer.TransformFromStorage(ctx, transformedData, dataCtx)
   350  			if tt.testErr != nil {
   351  				if err == nil {
   352  					t.Fatalf("envelopeTransformer: expected error: %v, got nil", tt.testErr)
   353  				}
   354  				if err.Error() != tt.testErr.Error() {
   355  					t.Fatalf("envelopeTransformer: expected error: %v, got: %v", tt.testErr, err)
   356  				}
   357  			} else {
   358  				if err != nil {
   359  					t.Fatalf("envelopeTransformer: unexpected error: %v", err)
   360  				}
   361  				if stale != tt.expectedStale {
   362  					t.Fatalf("envelopeTransformer TransformFromStorage determined keyID staleness incorrectly, expected: %v, got %v", tt.expectedStale, stale)
   363  				}
   364  			}
   365  		})
   366  	}
   367  }
   368  
   369  func TestEnvelopeTransformerStateFunc(t *testing.T) {
   370  	t.Parallel()
   371  
   372  	ctx := testContext(t)
   373  
   374  	useSeed := randomBool()
   375  
   376  	envelopeService := newTestEnvelopeService()
   377  	state, err := testStateFunc(ctx, envelopeService, clock.RealClock{}, useSeed)()
   378  	if err != nil {
   379  		t.Fatal(err)
   380  	}
   381  
   382  	// start with a broken state
   383  	stateErr := fmt.Errorf("some state error")
   384  
   385  	transformer := NewEnvelopeTransformer(envelopeService, testProviderName,
   386  		func() (State, error) { return state, stateErr }, testAPIServerID,
   387  	)
   388  
   389  	dataCtx := value.DefaultContext(testContextText)
   390  	originalText := []byte(testText)
   391  
   392  	t.Run("nothing works when the state is broken", func(t *testing.T) {
   393  		_, err := transformer.TransformToStorage(ctx, originalText, dataCtx)
   394  		if err != stateErr {
   395  			t.Fatalf("expected state error, got: %v", err)
   396  		}
   397  		o := &kmstypes.EncryptedObject{
   398  			EncryptedData:      []byte{1},
   399  			KeyID:              "2",
   400  			EncryptedDEKSource: []byte{3},
   401  			Annotations:        nil,
   402  		}
   403  		if useSeed {
   404  			o.EncryptedDEKSourceType = kmstypes.EncryptedDEKSourceType_HKDF_SHA256_XNONCE_AES_GCM_SEED
   405  		} else {
   406  			o.EncryptedDEKSourceType = kmstypes.EncryptedDEKSourceType_AES_GCM_KEY
   407  		}
   408  		data, err := proto.Marshal(o)
   409  		if err != nil {
   410  			t.Fatal(err)
   411  		}
   412  		_, _, err = transformer.TransformFromStorage(ctx, data, dataCtx)
   413  		if err != stateErr {
   414  			t.Fatalf("expected state error, got: %v", err)
   415  		}
   416  	})
   417  
   418  	// fix the state
   419  	stateErr = nil
   420  
   421  	var encryptedData []byte
   422  
   423  	t.Run("everything works when the state is fixed", func(t *testing.T) {
   424  		encryptedData, err = transformer.TransformToStorage(ctx, originalText, dataCtx)
   425  		if err != nil {
   426  			t.Fatal(err)
   427  		}
   428  		_, _, err = transformer.TransformFromStorage(ctx, encryptedData, dataCtx)
   429  		if err != nil {
   430  			t.Fatal(err)
   431  		}
   432  	})
   433  
   434  	// break the plugin
   435  	envelopeService.SetDisabledStatus(true)
   436  
   437  	t.Run("everything works even when the plugin is down but the state is valid", func(t *testing.T) {
   438  		data, err := transformer.TransformToStorage(ctx, originalText, dataCtx)
   439  		if err != nil {
   440  			t.Fatal(err)
   441  		}
   442  		_, _, err = transformer.TransformFromStorage(ctx, data, dataCtx)
   443  		if err != nil {
   444  			t.Fatal(err)
   445  		}
   446  	})
   447  
   448  	// make the state invalid
   449  	state.ExpirationTimestamp = time.Now().Add(-time.Hour)
   450  
   451  	t.Run("writes fail when the plugin is down and the state is invalid", func(t *testing.T) {
   452  		_, err := transformer.TransformToStorage(ctx, originalText, dataCtx)
   453  		if !strings.Contains(errString(err), `encryptedDEKSource with keyID hash "sha256:6b86b273ff34fce19d6b804eff5a3f5747ada4eaa22f1d49c01e52ddb7875b4b" expired at`) {
   454  			t.Fatalf("expected expiration error, got: %v", err)
   455  		}
   456  	})
   457  
   458  	t.Run("reads succeed when the plugin is down and the state is invalid", func(t *testing.T) {
   459  		_, _, err = transformer.TransformFromStorage(ctx, encryptedData, dataCtx)
   460  		if err != nil {
   461  			t.Fatal(err)
   462  		}
   463  	})
   464  
   465  	t.Run("reads for a different DEK fail when the plugin is down and the state is invalid", func(t *testing.T) {
   466  		obj := &kmstypes.EncryptedObject{}
   467  		if err := proto.Unmarshal(encryptedData, obj); err != nil {
   468  			t.Fatal(err)
   469  		}
   470  
   471  		obj.EncryptedDEKSource = append(obj.EncryptedDEKSource, 1) // skip StateFunc transformer
   472  
   473  		data, err := proto.Marshal(obj)
   474  		if err != nil {
   475  			t.Fatal(err)
   476  		}
   477  
   478  		_, _, err = transformer.TransformFromStorage(ctx, data, dataCtx)
   479  		if errString(err) != "failed to decrypt DEK, error: Envelope service was disabled" {
   480  			t.Fatal(err)
   481  		}
   482  	})
   483  }
   484  
   485  func TestTransformToStorageError(t *testing.T) {
   486  	t.Parallel()
   487  	testCases := []struct {
   488  		name        string
   489  		annotations map[string][]byte
   490  	}{
   491  		{
   492  			name: "invalid annotation key",
   493  			annotations: map[string][]byte{
   494  				"http://foo.example.com": []byte("bar"),
   495  			},
   496  		},
   497  		{
   498  			name: "annotation value size too large",
   499  			annotations: map[string][]byte{
   500  				"simple": []byte(strings.Repeat("a", 32*1024)),
   501  			},
   502  		},
   503  		{
   504  			name: "annotations size too large",
   505  			annotations: map[string][]byte{
   506  				"simple":  []byte(strings.Repeat("a", 31*1024)),
   507  				"simple2": []byte(strings.Repeat("a", 1024)),
   508  			},
   509  		},
   510  	}
   511  
   512  	for _, tt := range testCases {
   513  		tt := tt
   514  		t.Run(tt.name, func(t *testing.T) {
   515  			t.Parallel()
   516  
   517  			ctx := testContext(t)
   518  
   519  			envelopeService := newTestEnvelopeService()
   520  			envelopeService.SetAnnotations(tt.annotations)
   521  			transformer := NewEnvelopeTransformer(envelopeService, testProviderName,
   522  				testStateFunc(ctx, envelopeService, clock.RealClock{}, randomBool()),
   523  				testAPIServerID,
   524  			)
   525  			dataCtx := value.DefaultContext(testContextText)
   526  
   527  			_, err := transformer.TransformToStorage(ctx, []byte(testText), dataCtx)
   528  			if err == nil {
   529  				t.Fatalf("expected error, got nil")
   530  			}
   531  			if !strings.Contains(err.Error(), "failed to validate annotations") {
   532  				t.Fatalf("expected error to contain 'failed to validate annotations', got %v", err)
   533  			}
   534  		})
   535  	}
   536  }
   537  
   538  func TestEncodeDecode(t *testing.T) {
   539  	transformer := &envelopeTransformer{}
   540  
   541  	obj := &kmstypes.EncryptedObject{
   542  		EncryptedData:      []byte{0x01, 0x02, 0x03},
   543  		KeyID:              "1",
   544  		EncryptedDEKSource: []byte{0x04, 0x05, 0x06},
   545  	}
   546  
   547  	data, err := transformer.doEncode(obj)
   548  	if err != nil {
   549  		t.Fatalf("envelopeTransformer: error while encoding data: %s", err)
   550  	}
   551  	got, err := transformer.doDecode(data)
   552  	if err != nil {
   553  		t.Fatalf("envelopeTransformer: error while decoding data: %s", err)
   554  	}
   555  	// reset internal field modified by marshaling obj
   556  	obj.XXX_sizecache = 0
   557  	if !reflect.DeepEqual(got, obj) {
   558  		t.Fatalf("envelopeTransformer: decoded data does not match original data. Got: %v, want %v", got, obj)
   559  	}
   560  }
   561  
   562  func TestValidateEncryptedObject(t *testing.T) {
   563  	t.Parallel()
   564  	testCases := []struct {
   565  		desc          string
   566  		originalData  *kmstypes.EncryptedObject
   567  		expectedError error
   568  	}{
   569  		{
   570  			desc:          "encrypted object is nil",
   571  			originalData:  nil,
   572  			expectedError: fmt.Errorf("encrypted object is nil"),
   573  		},
   574  		{
   575  			desc: "encrypted data is nil",
   576  			originalData: &kmstypes.EncryptedObject{
   577  				KeyID:              "1",
   578  				EncryptedDEKSource: []byte{0x01, 0x02, 0x03},
   579  			},
   580  			expectedError: fmt.Errorf("encrypted data is empty"),
   581  		},
   582  		{
   583  			desc: "encrypted data is []byte{}",
   584  			originalData: &kmstypes.EncryptedObject{
   585  				EncryptedDEKSource: []byte{0x01, 0x02, 0x03},
   586  				EncryptedData:      []byte{},
   587  			},
   588  			expectedError: fmt.Errorf("encrypted data is empty"),
   589  		},
   590  		{
   591  			desc: "invalid dek source type",
   592  			originalData: &kmstypes.EncryptedObject{
   593  				EncryptedDEKSource:     []byte{0x01, 0x02, 0x03},
   594  				EncryptedData:          []byte{0},
   595  				EncryptedDEKSourceType: 55,
   596  			},
   597  			expectedError: fmt.Errorf("unknown encryptedDEKSourceType: 55"),
   598  		},
   599  		{
   600  			desc: "empty dek source",
   601  			originalData: &kmstypes.EncryptedObject{
   602  				EncryptedData:          []byte{0},
   603  				EncryptedDEKSourceType: 1,
   604  				KeyID:                  "1",
   605  			},
   606  			expectedError: fmt.Errorf("failed to validate encrypted DEK source: encrypted DEK source is empty"),
   607  		},
   608  		{
   609  			desc: "empty key ID",
   610  			originalData: &kmstypes.EncryptedObject{
   611  				EncryptedDEKSource:     []byte{0x01, 0x02, 0x03},
   612  				EncryptedData:          []byte{0},
   613  				EncryptedDEKSourceType: 1,
   614  			},
   615  			expectedError: fmt.Errorf("failed to validate key id: keyID is empty"),
   616  		},
   617  		{
   618  			desc: "invalid annotations",
   619  			originalData: &kmstypes.EncryptedObject{
   620  				EncryptedDEKSource:     []byte{0x01, 0x02, 0x03},
   621  				EncryptedData:          []byte{0},
   622  				EncryptedDEKSourceType: 1,
   623  				KeyID:                  "1",
   624  				Annotations:            map[string][]byte{"@": nil},
   625  			},
   626  			expectedError: fmt.Errorf(`failed to validate annotations: annotations: Invalid value: "@": a lowercase RFC 1123 subdomain must consist of lower case alphanumeric characters, '-' or '.', and must start and end with an alphanumeric character (e.g. 'example.com', regex used for validation is '[a-z0-9]([-a-z0-9]*[a-z0-9])?(\.[a-z0-9]([-a-z0-9]*[a-z0-9])?)*')`),
   627  		},
   628  	}
   629  
   630  	for _, tt := range testCases {
   631  		t.Run(tt.desc, func(t *testing.T) {
   632  			err := ValidateEncryptedObject(tt.originalData)
   633  			if err == nil {
   634  				t.Fatalf("envelopeTransformer: expected error while decoding data, got nil")
   635  			}
   636  
   637  			if err.Error() != tt.expectedError.Error() {
   638  				t.Fatalf("doDecode() error: expected %v, got %v", tt.expectedError, err)
   639  			}
   640  		})
   641  	}
   642  }
   643  
   644  func TestValidateAnnotations(t *testing.T) {
   645  	t.Parallel()
   646  	successCases := []map[string][]byte{
   647  		{"a.com": []byte("bar")},
   648  		{"k8s.io": []byte("bar")},
   649  		{"dev.k8s.io": []byte("bar")},
   650  		{"dev.k8s.io.": []byte("bar")},
   651  		{"foo.example.com": []byte("bar")},
   652  		{"this.is.a.really.long.fqdn": []byte("bar")},
   653  		{"bbc.co.uk": []byte("bar")},
   654  		{"10.0.0.1": []byte("bar")}, // DNS labels can start with numbers and there is no requirement for letters.
   655  		{"hyphens-are-good.k8s.io": []byte("bar")},
   656  		{strings.Repeat("a", 63) + ".k8s.io": []byte("bar")},
   657  		{strings.Repeat("a", 63) + "." + strings.Repeat("b", 63) + "." + strings.Repeat("c", 63) + "." + strings.Repeat("d", 54) + ".k8s.io": []byte("bar")},
   658  	}
   659  	t.Run("success", func(t *testing.T) {
   660  		for i := range successCases {
   661  			i := i
   662  			t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
   663  				t.Parallel()
   664  				if err := validateAnnotations(successCases[i]); err != nil {
   665  					t.Errorf("case[%d] expected success, got %#v", i, err)
   666  				}
   667  			})
   668  		}
   669  	})
   670  
   671  	atleastTwoSegmentsErrorMsg := "should be a domain with at least two segments separated by dots"
   672  	moreThan63CharsErrorMsg := "must be no more than 63 characters"
   673  	moreThan253CharsErrorMsg := "must be no more than 253 characters"
   674  	dns1123SubdomainErrorMsg := "a lowercase RFC 1123 subdomain must consist of lower case alphanumeric characters, '-' or '.', and must start and end with an alphanumeric character"
   675  
   676  	annotationsNameErrorCases := []struct {
   677  		annotations map[string][]byte
   678  		expect      string
   679  	}{
   680  		{map[string][]byte{".": []byte("bar")}, dns1123SubdomainErrorMsg},
   681  		{map[string][]byte{"...": []byte("bar")}, dns1123SubdomainErrorMsg},
   682  		{map[string][]byte{".io": []byte("bar")}, dns1123SubdomainErrorMsg},
   683  		{map[string][]byte{"com": []byte("bar")}, atleastTwoSegmentsErrorMsg},
   684  		{map[string][]byte{".com": []byte("bar")}, dns1123SubdomainErrorMsg},
   685  		{map[string][]byte{"Dev.k8s.io": []byte("bar")}, dns1123SubdomainErrorMsg},
   686  		{map[string][]byte{".foo.example.com": []byte("bar")}, dns1123SubdomainErrorMsg},
   687  		{map[string][]byte{"*.example.com": []byte("bar")}, dns1123SubdomainErrorMsg},
   688  		{map[string][]byte{"*.bar.com": []byte("bar")}, dns1123SubdomainErrorMsg},
   689  		{map[string][]byte{"*.foo.bar.com": []byte("bar")}, dns1123SubdomainErrorMsg},
   690  		{map[string][]byte{"underscores_are_bad.k8s.io": []byte("bar")}, dns1123SubdomainErrorMsg},
   691  		{map[string][]byte{"foo@bar.example.com": []byte("bar")}, dns1123SubdomainErrorMsg},
   692  		{map[string][]byte{"http://foo.example.com": []byte("bar")}, dns1123SubdomainErrorMsg},
   693  		{map[string][]byte{strings.Repeat("a", 64) + ".k8s.io": []byte("bar")}, moreThan63CharsErrorMsg},
   694  		{map[string][]byte{strings.Repeat("a", 63) + "." + strings.Repeat("b", 63) + "." + strings.Repeat("c", 63) + "." + strings.Repeat("d", 55) + ".k8s.io": []byte("bar")}, moreThan253CharsErrorMsg},
   695  	}
   696  
   697  	t.Run("name error", func(t *testing.T) {
   698  		for i := range annotationsNameErrorCases {
   699  			i := i
   700  			t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
   701  				t.Parallel()
   702  				err := validateAnnotations(annotationsNameErrorCases[i].annotations)
   703  				if err == nil {
   704  					t.Errorf("case[%d]: expected failure", i)
   705  				} else {
   706  					if !strings.Contains(err.Error(), annotationsNameErrorCases[i].expect) {
   707  						t.Errorf("case[%d]: error details do not include %q: %q", i, annotationsNameErrorCases[i].expect, err)
   708  					}
   709  				}
   710  			})
   711  		}
   712  	})
   713  
   714  	maxSizeErrMsg := "which exceeds the max size of"
   715  	annotationsSizeErrorCases := []struct {
   716  		annotations map[string][]byte
   717  		expect      string
   718  	}{
   719  		{map[string][]byte{"simple": []byte(strings.Repeat("a", 33*1024))}, maxSizeErrMsg},
   720  		{map[string][]byte{"simple": []byte(strings.Repeat("a", 32*1024))}, maxSizeErrMsg},
   721  		{map[string][]byte{"simple": []byte(strings.Repeat("a", 64*1024))}, maxSizeErrMsg},
   722  		{map[string][]byte{"simple": []byte(strings.Repeat("a", 31*1024)), "simple2": []byte(strings.Repeat("a", 1024))}, maxSizeErrMsg},
   723  		{map[string][]byte{strings.Repeat("a", 253): []byte(strings.Repeat("a", 32*1024))}, maxSizeErrMsg},
   724  	}
   725  	t.Run("size error", func(t *testing.T) {
   726  		for i := range annotationsSizeErrorCases {
   727  			i := i
   728  			t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
   729  				t.Parallel()
   730  				err := validateAnnotations(annotationsSizeErrorCases[i].annotations)
   731  				if err == nil {
   732  					t.Errorf("case[%d]: expected failure", i)
   733  				} else {
   734  					if !strings.Contains(err.Error(), annotationsSizeErrorCases[i].expect) {
   735  						t.Errorf("case[%d]: error details do not include %q: %q", i, annotationsSizeErrorCases[i].expect, err)
   736  					}
   737  				}
   738  			})
   739  		}
   740  	})
   741  }
   742  
   743  func TestValidateKeyID(t *testing.T) {
   744  	t.Parallel()
   745  	testCases := []struct {
   746  		name              string
   747  		keyID             string
   748  		expectedError     string
   749  		expectedErrorCode string
   750  	}{
   751  		{
   752  			name:              "valid key ID",
   753  			keyID:             "1234",
   754  			expectedError:     "",
   755  			expectedErrorCode: "ok",
   756  		},
   757  		{
   758  			name:              "empty key ID",
   759  			keyID:             "",
   760  			expectedError:     "keyID is empty",
   761  			expectedErrorCode: "empty",
   762  		},
   763  		{
   764  			name:              "keyID size is greater than 1 kB",
   765  			keyID:             strings.Repeat("a", 1024+1),
   766  			expectedError:     "which exceeds the max size of",
   767  			expectedErrorCode: "too_long",
   768  		},
   769  	}
   770  
   771  	for _, tt := range testCases {
   772  		tt := tt
   773  		t.Run(tt.name, func(t *testing.T) {
   774  			t.Parallel()
   775  			errCode, err := ValidateKeyID(tt.keyID)
   776  			if tt.expectedError != "" {
   777  				if err == nil {
   778  					t.Fatalf("expected error %q, got nil", tt.expectedError)
   779  				}
   780  				if !strings.Contains(err.Error(), tt.expectedError) {
   781  					t.Fatalf("expected error %q, got %q", tt.expectedError, err)
   782  				}
   783  			} else {
   784  				if err != nil {
   785  					t.Fatalf("expected no error, got %q", err)
   786  				}
   787  			}
   788  			if tt.expectedErrorCode != string(errCode) {
   789  				t.Fatalf("expected %s errCode, got %s", tt.expectedErrorCode, string(errCode))
   790  			}
   791  		})
   792  	}
   793  }
   794  
   795  func TestValidateEncryptedDEKSource(t *testing.T) {
   796  	t.Parallel()
   797  	testCases := []struct {
   798  		name               string
   799  		encryptedDEKSource []byte
   800  		expectedError      string
   801  	}{
   802  		{
   803  			name:               "encrypted DEK source is nil",
   804  			encryptedDEKSource: nil,
   805  			expectedError:      "encrypted DEK source is empty",
   806  		},
   807  		{
   808  			name:               "encrypted DEK source is empty",
   809  			encryptedDEKSource: []byte{},
   810  			expectedError:      "encrypted DEK source is empty",
   811  		},
   812  		{
   813  			name:               "encrypted DEK source size is greater than 1 kB",
   814  			encryptedDEKSource: bytes.Repeat([]byte("a"), 1024+1),
   815  			expectedError:      "which exceeds the max size of",
   816  		},
   817  		{
   818  			name:               "valid encrypted DEK source",
   819  			encryptedDEKSource: []byte{0x01, 0x02, 0x03},
   820  			expectedError:      "",
   821  		},
   822  	}
   823  
   824  	for _, tt := range testCases {
   825  		tt := tt
   826  		t.Run(tt.name, func(t *testing.T) {
   827  			t.Parallel()
   828  			err := validateEncryptedDEKSource(tt.encryptedDEKSource)
   829  			if tt.expectedError != "" {
   830  				if err == nil {
   831  					t.Fatalf("expected error %q, got nil", tt.expectedError)
   832  				}
   833  				if !strings.Contains(err.Error(), tt.expectedError) {
   834  					t.Fatalf("expected error %q, got %q", tt.expectedError, err)
   835  				}
   836  			} else {
   837  				if err != nil {
   838  					t.Fatalf("expected no error, got %q", err)
   839  				}
   840  			}
   841  		})
   842  	}
   843  }
   844  
   845  func TestEnvelopeMetrics(t *testing.T) {
   846  	envelopeService := newTestEnvelopeService()
   847  	transformer := NewEnvelopeTransformer(envelopeService, testProviderName,
   848  		testStateFunc(testContext(t), envelopeService, clock.RealClock{}, randomBool()),
   849  		testAPIServerID,
   850  	)
   851  
   852  	dataCtx := value.DefaultContext(testContextText)
   853  
   854  	kmsv2Transformer := value.PrefixTransformer{Prefix: []byte("k8s:enc:kms:v2:"), Transformer: transformer}
   855  
   856  	testCases := []struct {
   857  		desc                  string
   858  		keyVersionFromEncrypt string
   859  		prefix                value.Transformer
   860  		metrics               []string
   861  		want                  string
   862  	}{
   863  		{
   864  			desc:                  "keyIDHash total",
   865  			keyVersionFromEncrypt: testKeyVersion,
   866  			prefix:                value.NewPrefixTransformers(nil, kmsv2Transformer),
   867  			metrics: []string{
   868  				"apiserver_envelope_encryption_key_id_hash_total",
   869  			},
   870  			want: fmt.Sprintf(`
   871  				# HELP apiserver_envelope_encryption_key_id_hash_total [ALPHA] Number of times a keyID is used split by transformation type, provider, and apiserver identity.
   872  				# TYPE apiserver_envelope_encryption_key_id_hash_total counter
   873  				apiserver_envelope_encryption_key_id_hash_total{apiserver_id_hash="%s",key_id_hash="%s",provider_name="%s",transformation_type="%s"} 1
   874  				apiserver_envelope_encryption_key_id_hash_total{apiserver_id_hash="%s",key_id_hash="%s",provider_name="%s",transformation_type="%s"} 1
   875  				`, testAPIServerIDHash, testKeyHash, testProviderName, metrics.FromStorageLabel, testAPIServerIDHash, testKeyHash, testProviderName, metrics.ToStorageLabel),
   876  		},
   877  	}
   878  
   879  	metrics.KeyIDHashTotal.Reset()
   880  	metrics.InvalidKeyIDFromStatusTotal.Reset()
   881  
   882  	for _, tt := range testCases {
   883  		t.Run(tt.desc, func(t *testing.T) {
   884  			defer metrics.KeyIDHashTotal.Reset()
   885  			defer metrics.InvalidKeyIDFromStatusTotal.Reset()
   886  			ctx := testContext(t)
   887  			envelopeService.keyVersion = tt.keyVersionFromEncrypt
   888  			transformedData, err := tt.prefix.TransformToStorage(ctx, []byte(testText), dataCtx)
   889  			if err != nil {
   890  				t.Fatal(err)
   891  			}
   892  			if _, _, err := tt.prefix.TransformFromStorage(ctx, transformedData, dataCtx); err != nil {
   893  				t.Fatal(err)
   894  			}
   895  
   896  			if err := testutil.GatherAndCompare(legacyregistry.DefaultGatherer, strings.NewReader(tt.want), tt.metrics...); err != nil {
   897  				t.Fatal(err)
   898  			}
   899  		})
   900  	}
   901  }
   902  
   903  // TestEnvelopeMetricsCache validates the correctness of the apiserver_envelope_encryption_dek_source_cache_size metric
   904  // and asserts that all of the associated logic is go routine safe.
   905  // 1. Multiple transformers are created, which should result in unique cache size for each provider
   906  // 2. A transformer with known number of states was created to encrypt, then on restart, another transformer
   907  // was created, which should result in expected number of cache keys for all the decryption calls for each
   908  // state used previously for encryption.
   909  func TestEnvelopeMetricsCache(t *testing.T) {
   910  	envelopeService := newTestEnvelopeService()
   911  	envelopeService.keyVersion = testKeyVersion
   912  	state, err := testStateFunc(testContext(t), envelopeService, clock.RealClock{}, randomBool())()
   913  	if err != nil {
   914  		t.Fatal(err)
   915  	}
   916  	ctx := testContext(t)
   917  	dataCtx := value.DefaultContext(testContextText)
   918  	provider1 := "one"
   919  	provider2 := "two"
   920  	numOfStates := 10
   921  
   922  	testCases := []struct {
   923  		desc    string
   924  		metrics []string
   925  		want    string
   926  	}{
   927  		{
   928  			desc: "dek source cache size",
   929  			metrics: []string{
   930  				"apiserver_envelope_encryption_dek_source_cache_size",
   931  			},
   932  			want: fmt.Sprintf(`
   933  				# HELP apiserver_envelope_encryption_dek_source_cache_size [ALPHA] Number of records in data encryption key (DEK) source cache. On a restart, this value is an approximation of the number of decrypt RPC calls the server will make to the KMS plugin.
   934  				# TYPE apiserver_envelope_encryption_dek_source_cache_size gauge
   935          		apiserver_envelope_encryption_dek_source_cache_size{provider_name="%s"} %d
   936          		apiserver_envelope_encryption_dek_source_cache_size{provider_name="%s"} 1
   937  				`, provider1, numOfStates, provider2),
   938  		},
   939  	}
   940  	transformer1 := NewEnvelopeTransformer(envelopeService, provider1, func() (State, error) {
   941  		// return different states to ensure we get expected number of cache keys after restart on decryption
   942  		return testStateFunc(ctx, envelopeService, clock.RealClock{}, randomBool())()
   943  	}, testAPIServerID)
   944  	transformer2 := NewEnvelopeTransformer(envelopeService, provider2, func() (State, error) { return state, nil }, testAPIServerID)
   945  	// used for restart
   946  	transformer3 := NewEnvelopeTransformer(envelopeService, provider1, func() (State, error) { return state, nil }, testAPIServerID)
   947  	var transformedDatas [][]byte
   948  	for j := 0; j < numOfStates; j++ {
   949  		transformedData, err := transformer1.TransformToStorage(ctx, []byte(testText), dataCtx)
   950  		if err != nil {
   951  			t.Fatal(err)
   952  		}
   953  		transformedDatas = append(transformedDatas, transformedData)
   954  	}
   955  
   956  	for _, tt := range testCases {
   957  		t.Run(tt.desc, func(t *testing.T) {
   958  			metrics.DekSourceCacheSize.Reset()
   959  			var wg sync.WaitGroup
   960  			wg.Add(2 * numOfStates)
   961  			for i := 0; i < numOfStates; i++ {
   962  				i := i
   963  				go func() {
   964  					defer wg.Done()
   965  					// mimick a restart, the server will make decrypt RPC calls to the KMS plugin
   966  					// check cache metrics for the decrypt / read flow, which should repopulate the cache
   967  					if _, _, err := transformer3.TransformFromStorage(ctx, transformedDatas[i], dataCtx); err != nil {
   968  						panic(err)
   969  					}
   970  				}()
   971  				go func() {
   972  					defer wg.Done()
   973  					// check cache metrics for the encrypt / write flow
   974  					_, err := transformer2.TransformToStorage(ctx, []byte(testText), dataCtx)
   975  					if err != nil {
   976  						panic(err)
   977  					}
   978  				}()
   979  			}
   980  			wg.Wait()
   981  			if err := testutil.GatherAndCompare(legacyregistry.DefaultGatherer, strings.NewReader(tt.want), tt.metrics...); err != nil {
   982  				t.Fatal(err)
   983  			}
   984  		})
   985  	}
   986  }
   987  
   988  var flagOnce sync.Once // support running `go test -count X`
   989  
   990  func TestEnvelopeLogging(t *testing.T) {
   991  	flagOnce.Do(func() {
   992  		klog.InitFlags(nil)
   993  	})
   994  	flag.Set("v", "6")
   995  	flag.Parse()
   996  
   997  	testCases := []struct {
   998  		desc     string
   999  		ctx      context.Context
  1000  		wantLogs []string
  1001  	}{
  1002  		{
  1003  			desc: "no request info in context",
  1004  			ctx:  testContext(t),
  1005  			wantLogs: []string{
  1006  				`"encrypting content using envelope service" uid="UID"`,
  1007  				`"encrypting content using DEK" uid="UID" key="0123456789" group="" version="" resource="" subresource="" verb="" namespace="" name=""`,
  1008  				`"decrypting content using envelope service" uid="UID" key="0123456789" group="" version="" resource="" subresource="" verb="" namespace="" name=""`,
  1009  			},
  1010  		},
  1011  		{
  1012  			desc: "request info in context",
  1013  			ctx: genericapirequest.WithRequestInfo(testContext(t), &genericapirequest.RequestInfo{
  1014  				APIGroup:    "awesome.bears.com",
  1015  				APIVersion:  "v1",
  1016  				Resource:    "pandas",
  1017  				Subresource: "status",
  1018  				Namespace:   "kube-system",
  1019  				Name:        "panda",
  1020  				Verb:        "update",
  1021  			}),
  1022  			wantLogs: []string{
  1023  				`"encrypting content using envelope service" uid="UID"`,
  1024  				`"encrypting content using DEK" uid="UID" key="0123456789" group="awesome.bears.com" version="v1" resource="pandas" subresource="status" verb="update" namespace="kube-system" name="panda"`,
  1025  				`"decrypting content using envelope service" uid="UID" key="0123456789" group="awesome.bears.com" version="v1" resource="pandas" subresource="status" verb="update" namespace="kube-system" name="panda"`,
  1026  			},
  1027  		},
  1028  	}
  1029  
  1030  	for _, tc := range testCases {
  1031  		tc := tc
  1032  		t.Run(tc.desc, func(t *testing.T) {
  1033  			var buf bytes.Buffer
  1034  			klog.SetOutput(&buf)
  1035  			klog.LogToStderr(false)
  1036  			defer klog.LogToStderr(true)
  1037  
  1038  			envelopeService := newTestEnvelopeService()
  1039  			fakeClock := testingclock.NewFakeClock(time.Now())
  1040  			transformer := newEnvelopeTransformerWithClock(envelopeService, testProviderName,
  1041  				testStateFunc(tc.ctx, envelopeService, clock.RealClock{}, randomBool()), testAPIServerID, 1*time.Second, fakeClock)
  1042  
  1043  			dataCtx := value.DefaultContext([]byte(testContextText))
  1044  			originalText := []byte(testText)
  1045  
  1046  			transformedData, err := transformer.TransformToStorage(tc.ctx, originalText, dataCtx)
  1047  			if err != nil {
  1048  				t.Fatalf("envelopeTransformer: error while transforming data to storage: %v", err)
  1049  			}
  1050  
  1051  			// advance the clock to trigger cache to expire, so we make a decrypt call that will log
  1052  			fakeClock.Step(2 * time.Second)
  1053  			// force GC to run by performing a write
  1054  			transformer.(*envelopeTransformer).cache.set([]byte("some-other-unrelated-key"), &envelopeTransformer{})
  1055  
  1056  			_, _, err = transformer.TransformFromStorage(tc.ctx, transformedData, dataCtx)
  1057  			if err != nil {
  1058  				t.Fatalf("could not decrypt Envelope transformer's encrypted data even once: %v", err)
  1059  			}
  1060  
  1061  			klog.Flush()
  1062  			klog.SetOutput(&bytes.Buffer{}) // prevent further writes into buf
  1063  			capturedOutput := buf.String()
  1064  
  1065  			// replace the uid with a constant to make the test output stable and assertable
  1066  			capturedOutput = regexp.MustCompile(`uid="[^"]+"`).ReplaceAllString(capturedOutput, `uid="UID"`)
  1067  
  1068  			for _, wantLog := range tc.wantLogs {
  1069  				if !strings.Contains(capturedOutput, wantLog) {
  1070  					t.Errorf("expected log %q, got %q", wantLog, capturedOutput)
  1071  				}
  1072  			}
  1073  		})
  1074  	}
  1075  }
  1076  
  1077  func TestCacheNotCorrupted(t *testing.T) {
  1078  	ctx := testContext(t)
  1079  
  1080  	envelopeService := newTestEnvelopeService()
  1081  	envelopeService.SetAnnotations(map[string][]byte{
  1082  		"encrypted-dek.kms.kubernetes.io": []byte("encrypted-dek-0"),
  1083  	})
  1084  
  1085  	fakeClock := testingclock.NewFakeClock(time.Now())
  1086  
  1087  	state, err := testStateFunc(ctx, envelopeService, fakeClock, randomBool())()
  1088  	if err != nil {
  1089  		t.Fatal(err)
  1090  	}
  1091  
  1092  	transformer := newEnvelopeTransformerWithClock(envelopeService, testProviderName,
  1093  		func() (State, error) { return state, nil }, testAPIServerID,
  1094  		1*time.Second, fakeClock)
  1095  
  1096  	dataCtx := value.DefaultContext(testContextText)
  1097  	originalText := []byte(testText)
  1098  
  1099  	transformedData1, err := transformer.TransformToStorage(ctx, originalText, dataCtx)
  1100  	if err != nil {
  1101  		t.Fatalf("envelopeTransformer: error while transforming data to storage: %s", err)
  1102  	}
  1103  
  1104  	// this is to mimic a plugin that sets a static response for ciphertext
  1105  	// but uses the annotation field to send the actual encrypted DEK source.
  1106  	envelopeService.SetCiphertext(state.EncryptedObject.EncryptedDEKSource)
  1107  	// for this plugin, it indicates a change in the remote key ID as the returned
  1108  	// encrypted DEK source is different.
  1109  	envelopeService.SetAnnotations(map[string][]byte{
  1110  		"encrypted-dek.kms.kubernetes.io": []byte("encrypted-dek-1"),
  1111  	})
  1112  
  1113  	state, err = testStateFunc(ctx, envelopeService, fakeClock, randomBool())()
  1114  	if err != nil {
  1115  		t.Fatal(err)
  1116  	}
  1117  
  1118  	transformer = newEnvelopeTransformerWithClock(envelopeService, testProviderName,
  1119  		func() (State, error) { return state, nil }, testAPIServerID,
  1120  		1*time.Second, fakeClock)
  1121  
  1122  	transformedData2, err := transformer.TransformToStorage(ctx, originalText, dataCtx)
  1123  	if err != nil {
  1124  		t.Fatalf("envelopeTransformer: error while transforming data to storage: %s", err)
  1125  	}
  1126  
  1127  	if _, _, err := transformer.TransformFromStorage(ctx, transformedData1, dataCtx); err != nil {
  1128  		t.Fatal(err)
  1129  	}
  1130  	if _, _, err := transformer.TransformFromStorage(ctx, transformedData2, dataCtx); err != nil {
  1131  		t.Fatal(err)
  1132  	}
  1133  }
  1134  
  1135  func TestGenerateCacheKey(t *testing.T) {
  1136  	encryptedDEKSource1 := []byte{1, 2, 3}
  1137  	keyID1 := "id1"
  1138  	annotations1 := map[string][]byte{"a": {4, 5}, "b": {6, 7}}
  1139  	encryptedDEKSourceType1 := kmstypes.EncryptedDEKSourceType_AES_GCM_KEY
  1140  
  1141  	encryptedDEKSource2 := []byte{4, 5, 6}
  1142  	keyID2 := "id2"
  1143  	annotations2 := map[string][]byte{"x": {9, 10}, "y": {11, 12}}
  1144  	encryptedDEKSourceType2 := kmstypes.EncryptedDEKSourceType_HKDF_SHA256_XNONCE_AES_GCM_SEED
  1145  
  1146  	// generate all possible combinations of the above
  1147  	testCases := []struct {
  1148  		encryptedDEKSourceType kmstypes.EncryptedDEKSourceType
  1149  		encryptedDEKSource     []byte
  1150  		keyID                  string
  1151  		annotations            map[string][]byte
  1152  	}{
  1153  		{encryptedDEKSourceType1, encryptedDEKSource1, keyID1, annotations1},
  1154  		{encryptedDEKSourceType1, encryptedDEKSource1, keyID1, annotations2},
  1155  		{encryptedDEKSourceType1, encryptedDEKSource1, keyID2, annotations1},
  1156  		{encryptedDEKSourceType1, encryptedDEKSource1, keyID2, annotations2},
  1157  		{encryptedDEKSourceType1, encryptedDEKSource2, keyID1, annotations1},
  1158  		{encryptedDEKSourceType1, encryptedDEKSource2, keyID1, annotations2},
  1159  		{encryptedDEKSourceType1, encryptedDEKSource2, keyID2, annotations1},
  1160  		{encryptedDEKSourceType1, encryptedDEKSource2, keyID2, annotations2},
  1161  
  1162  		{encryptedDEKSourceType2, encryptedDEKSource1, keyID1, annotations1},
  1163  		{encryptedDEKSourceType2, encryptedDEKSource1, keyID1, annotations2},
  1164  		{encryptedDEKSourceType2, encryptedDEKSource1, keyID2, annotations1},
  1165  		{encryptedDEKSourceType2, encryptedDEKSource1, keyID2, annotations2},
  1166  		{encryptedDEKSourceType2, encryptedDEKSource2, keyID1, annotations1},
  1167  		{encryptedDEKSourceType2, encryptedDEKSource2, keyID1, annotations2},
  1168  		{encryptedDEKSourceType2, encryptedDEKSource2, keyID2, annotations1},
  1169  		{encryptedDEKSourceType2, encryptedDEKSource2, keyID2, annotations2},
  1170  	}
  1171  
  1172  	for _, tc := range testCases {
  1173  		tc := tc
  1174  		for _, tc2 := range testCases {
  1175  			tc2 := tc2
  1176  			t.Run(fmt.Sprintf("%+v-%+v", tc, tc2), func(t *testing.T) {
  1177  				key1, err1 := generateCacheKey(tc.encryptedDEKSourceType, tc.encryptedDEKSource, tc.keyID, tc.annotations)
  1178  				key2, err2 := generateCacheKey(tc2.encryptedDEKSourceType, tc2.encryptedDEKSource, tc2.keyID, tc2.annotations)
  1179  				if err1 != nil || err2 != nil {
  1180  					t.Errorf("generateCacheKey() want err=nil, got err1=%q, err2=%q", errString(err1), errString(err2))
  1181  				}
  1182  				if bytes.Equal(key1, key2) != reflect.DeepEqual(tc, tc2) {
  1183  					t.Errorf("expected %v, got %v", reflect.DeepEqual(tc, tc2), bytes.Equal(key1, key2))
  1184  				}
  1185  			})
  1186  		}
  1187  	}
  1188  }
  1189  
  1190  func TestGenerateTransformer(t *testing.T) {
  1191  	t.Parallel()
  1192  	testCases := []struct {
  1193  		name            string
  1194  		envelopeService func() kmsservice.Service
  1195  		expectedErr     string
  1196  	}{
  1197  		{
  1198  			name: "encrypt call fails",
  1199  			envelopeService: func() kmsservice.Service {
  1200  				envelopeService := newTestEnvelopeService()
  1201  				envelopeService.SetDisabledStatus(true)
  1202  				return envelopeService
  1203  			},
  1204  			expectedErr: "Envelope service was disabled",
  1205  		},
  1206  		{
  1207  			name: "invalid key ID",
  1208  			envelopeService: func() kmsservice.Service {
  1209  				envelopeService := newTestEnvelopeService()
  1210  				envelopeService.keyVersion = ""
  1211  				return envelopeService
  1212  			},
  1213  			expectedErr: "failed to validate key id: keyID is empty",
  1214  		},
  1215  		{
  1216  			name: "invalid encrypted DEK",
  1217  			envelopeService: func() kmsservice.Service {
  1218  				envelopeService := newTestEnvelopeService()
  1219  				envelopeService.SetCiphertext([]byte{})
  1220  				return envelopeService
  1221  			},
  1222  			expectedErr: "failed to validate encrypted DEK source: encrypted DEK source is empty",
  1223  		},
  1224  		{
  1225  			name: "invalid annotations",
  1226  			envelopeService: func() kmsservice.Service {
  1227  				envelopeService := newTestEnvelopeService()
  1228  				envelopeService.SetAnnotations(map[string][]byte{"invalid": {}})
  1229  				return envelopeService
  1230  			},
  1231  			expectedErr: "failed to validate annotations: annotations: Invalid value: \"invalid\": should be a domain with at least two segments separated by dots",
  1232  		},
  1233  		{
  1234  			name: "success",
  1235  			envelopeService: func() kmsservice.Service {
  1236  				return newTestEnvelopeService()
  1237  			},
  1238  			expectedErr: "",
  1239  		},
  1240  	}
  1241  
  1242  	for _, tc := range testCases {
  1243  		tc := tc
  1244  		t.Run(tc.name, func(t *testing.T) {
  1245  			t.Parallel()
  1246  
  1247  			transformer, encObject, cacheKey, err := GenerateTransformer(testContext(t), "panda", tc.envelopeService(), randomBool())
  1248  			if tc.expectedErr == "" {
  1249  				if err != nil {
  1250  					t.Errorf("expected no error, got %q", errString(err))
  1251  				}
  1252  				if transformer == nil {
  1253  					t.Error("expected transformer, got nil")
  1254  				}
  1255  				if encObject == nil {
  1256  					t.Error("expected encrypt response, got nil")
  1257  				}
  1258  				if cacheKey == nil {
  1259  					t.Error("expected cache key, got nil")
  1260  				}
  1261  			} else {
  1262  				if err == nil || !strings.Contains(err.Error(), tc.expectedErr) {
  1263  					t.Errorf("expected error %q, got %q", tc.expectedErr, errString(err))
  1264  				}
  1265  			}
  1266  		})
  1267  	}
  1268  }
  1269  
  1270  func TestEnvelopeTracing_TransformToStorage(t *testing.T) {
  1271  	testCases := []struct {
  1272  		desc     string
  1273  		expected []string
  1274  	}{
  1275  		{
  1276  			desc: "encrypt",
  1277  			expected: []string{
  1278  				"About to encrypt data using DEK",
  1279  				"Data encryption succeeded",
  1280  				"About to encode encrypted object",
  1281  				"Encoded encrypted object",
  1282  			},
  1283  		},
  1284  	}
  1285  
  1286  	for _, tc := range testCases {
  1287  		t.Run(tc.desc, func(t *testing.T) {
  1288  			fakeRecorder := tracetest.NewSpanRecorder()
  1289  			otelTracer := trace.NewTracerProvider(trace.WithSpanProcessor(fakeRecorder)).Tracer("test")
  1290  
  1291  			ctx := testContext(t)
  1292  			ctx, span := otelTracer.Start(ctx, "parent")
  1293  			defer span.End()
  1294  
  1295  			envelopeService := newTestEnvelopeService()
  1296  			fakeClock := testingclock.NewFakeClock(time.Now())
  1297  			state, err := testStateFunc(ctx, envelopeService, clock.RealClock{}, randomBool())()
  1298  			if err != nil {
  1299  				t.Fatal(err)
  1300  			}
  1301  
  1302  			transformer := newEnvelopeTransformerWithClock(envelopeService, testProviderName,
  1303  				func() (State, error) { return state, nil }, testAPIServerID, 1*time.Second, fakeClock)
  1304  
  1305  			dataCtx := value.DefaultContext([]byte(testContextText))
  1306  			originalText := []byte(testText)
  1307  
  1308  			if _, err := transformer.TransformToStorage(ctx, originalText, dataCtx); err != nil {
  1309  				t.Fatalf("envelopeTransformer: error while transforming data to storage: %v", err)
  1310  			}
  1311  
  1312  			output := fakeRecorder.Ended()
  1313  			if len(output) != 1 {
  1314  				t.Fatalf("expected 1 span, got %d", len(output))
  1315  			}
  1316  			out := output[0]
  1317  			validateTraceSpan(t, out, "TransformToStorage with envelopeTransformer", testProviderName, testAPIServerID, tc.expected)
  1318  		})
  1319  	}
  1320  }
  1321  
  1322  func TestEnvelopeTracing_TransformFromStorage(t *testing.T) {
  1323  	testCases := []struct {
  1324  		desc                     string
  1325  		cacheTTL                 time.Duration
  1326  		simulateKMSPluginFailure bool
  1327  		expected                 []string
  1328  	}{
  1329  		{
  1330  			desc:     "decrypt",
  1331  			cacheTTL: 5 * time.Second,
  1332  			expected: []string{
  1333  				"About to decode encrypted object",
  1334  				"Decoded encrypted object",
  1335  				"About to decrypt data using DEK",
  1336  				"Data decryption succeeded",
  1337  			},
  1338  		},
  1339  		{
  1340  			desc:     "decrypt with cache miss",
  1341  			cacheTTL: 1 * time.Second,
  1342  			expected: []string{
  1343  				"About to decode encrypted object",
  1344  				"Decoded encrypted object",
  1345  				"About to decrypt DEK using remote service",
  1346  				"DEK decryption succeeded",
  1347  				"About to decrypt data using DEK",
  1348  				"Data decryption succeeded",
  1349  			},
  1350  		},
  1351  		{
  1352  			desc:                     "decrypt with cache miss, simulate KMS plugin failure",
  1353  			cacheTTL:                 1 * time.Second,
  1354  			simulateKMSPluginFailure: true,
  1355  			expected: []string{
  1356  				"About to decode encrypted object",
  1357  				"Decoded encrypted object",
  1358  				"About to decrypt DEK using remote service",
  1359  				"DEK decryption failed",
  1360  				"exception",
  1361  			},
  1362  		},
  1363  	}
  1364  
  1365  	for _, tc := range testCases {
  1366  		t.Run(tc.desc, func(t *testing.T) {
  1367  			fakeRecorder := tracetest.NewSpanRecorder()
  1368  			otelTracer := trace.NewTracerProvider(trace.WithSpanProcessor(fakeRecorder)).Tracer("test")
  1369  
  1370  			ctx := testContext(t)
  1371  
  1372  			envelopeService := newTestEnvelopeService()
  1373  			fakeClock := testingclock.NewFakeClock(time.Now())
  1374  			state, err := testStateFunc(ctx, envelopeService, clock.RealClock{}, randomBool())()
  1375  			if err != nil {
  1376  				t.Fatal(err)
  1377  			}
  1378  
  1379  			transformer := newEnvelopeTransformerWithClock(envelopeService, testProviderName,
  1380  				func() (State, error) { return state, nil }, testAPIServerID, tc.cacheTTL, fakeClock)
  1381  
  1382  			dataCtx := value.DefaultContext([]byte(testContextText))
  1383  			originalText := []byte(testText)
  1384  
  1385  			transformedData, _ := transformer.TransformToStorage(ctx, originalText, dataCtx)
  1386  
  1387  			// advance the clock to allow cache entries to expire depending on TTL
  1388  			fakeClock.Step(2 * time.Second)
  1389  			// force GC to run by performing a write
  1390  			transformer.(*envelopeTransformer).cache.set([]byte("some-other-unrelated-key"), &envelopeTransformer{})
  1391  
  1392  			envelopeService.SetDisabledStatus(tc.simulateKMSPluginFailure)
  1393  
  1394  			// start recording only for the decrypt call
  1395  			ctx, span := otelTracer.Start(ctx, "parent")
  1396  			defer span.End()
  1397  
  1398  			_, _, _ = transformer.TransformFromStorage(ctx, transformedData, dataCtx)
  1399  
  1400  			output := fakeRecorder.Ended()
  1401  			validateTraceSpan(t, output[0], "TransformFromStorage with envelopeTransformer", testProviderName, testAPIServerID, tc.expected)
  1402  		})
  1403  	}
  1404  }
  1405  
  1406  func validateTraceSpan(t *testing.T, span trace.ReadOnlySpan, spanName, providerName, apiserverID string, expected []string) {
  1407  	t.Helper()
  1408  
  1409  	if span.Name() != spanName {
  1410  		t.Fatalf("expected span name %q, got %q", spanName, span.Name())
  1411  	}
  1412  	attrs := span.Attributes()
  1413  	if len(attrs) != 1 {
  1414  		t.Fatalf("expected 1 attributes, got %d", len(attrs))
  1415  	}
  1416  	if attrs[0].Key != "transformer.provider.name" && attrs[0].Value.AsString() != providerName {
  1417  		t.Errorf("expected providerName %q, got %q", providerName, attrs[0].Value.AsString())
  1418  	}
  1419  	if len(span.Events()) != len(expected) {
  1420  		t.Fatalf("expected %d events, got %d", len(expected), len(span.Events()))
  1421  	}
  1422  	for i, event := range span.Events() {
  1423  		if event.Name != expected[i] {
  1424  			t.Errorf("expected event %q, got %q", expected[i], event.Name)
  1425  		}
  1426  	}
  1427  }
  1428  
  1429  func errString(err error) string {
  1430  	if err == nil {
  1431  		return ""
  1432  	}
  1433  
  1434  	return err.Error()
  1435  }
  1436  
  1437  func randomBool() bool { return utilrand.Int()%2 == 1 }