k8s.io/apiserver@v0.31.1/pkg/storage/value/encrypt/envelope/kmsv2/grpc_service_unix_test.go (about)

     1  //go:build !windows
     2  // +build !windows
     3  
     4  /*
     5  Copyright 2022 The Kubernetes Authors.
     6  
     7  Licensed under the Apache License, Version 2.0 (the "License");
     8  you may not use this file except in compliance with the License.
     9  You may obtain a copy of the License at
    10  
    11      http://www.apache.org/licenses/LICENSE-2.0
    12  
    13  Unless required by applicable law or agreed to in writing, software
    14  distributed under the License is distributed on an "AS IS" BASIS,
    15  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    16  See the License for the specific language governing permissions and
    17  limitations under the License.
    18  */
    19  
// Package kmsv2 transforms values for storage at rest using an Envelope v2 provider
    21  package kmsv2
    22  
    23  import (
    24  	"context"
    25  	"fmt"
    26  	"reflect"
    27  	"sync"
    28  	"testing"
    29  	"time"
    30  
    31  	"k8s.io/apimachinery/pkg/util/uuid"
    32  	"k8s.io/apiserver/pkg/storage/value/encrypt/envelope/metrics"
    33  	mock "k8s.io/apiserver/pkg/storage/value/encrypt/envelope/testing/v2"
    34  	"k8s.io/component-base/metrics/testutil"
    35  	kmsservice "k8s.io/kms/pkg/service"
    36  )
    37  
    38  const (
    39  	testProviderName = "providerName"
    40  )
    41  
    42  type testSocket struct {
    43  	path     string
    44  	endpoint string
    45  }
    46  
    47  // newEndpoint constructs a unique name for a Linux Abstract Socket to be used in a test.
    48  // This package uses Linux Domain Sockets to remove the need for clean-up of socket files.
    49  func newEndpoint() *testSocket {
    50  	p := fmt.Sprintf("@%s.sock", uuid.NewUUID())
    51  
    52  	return &testSocket{
    53  		path:     p,
    54  		endpoint: fmt.Sprintf("unix:///%s", p),
    55  	}
    56  }
    57  
    58  // TestKMSPluginLateStart tests the scenario where kms-plugin pod/container starts after kube-apiserver pod/container.
    59  // Since the Dial to kms-plugin is non-blocking we expect the construction of gRPC service to succeed even when
    60  // kms-plugin is not yet up - dialing happens in the background.
    61  func TestKMSPluginLateStart(t *testing.T) {
    62  	t.Parallel()
    63  	callTimeout := 3 * time.Second
    64  	s := newEndpoint()
    65  
    66  	ctx := testContext(t)
    67  
    68  	service, err := NewGRPCService(ctx, s.endpoint, testProviderName, callTimeout)
    69  	if err != nil {
    70  		t.Fatalf("failed to create envelope service, error: %v", err)
    71  	}
    72  	defer destroyService(service)
    73  
    74  	time.Sleep(callTimeout / 2)
    75  	_ = mock.NewBase64Plugin(t, s.path)
    76  
    77  	data := []byte("test data")
    78  	uid := string(uuid.NewUUID())
    79  	_, err = service.Encrypt(ctx, uid, data)
    80  	if err != nil {
    81  		t.Fatalf("failed when execute encrypt, error: %v", err)
    82  	}
    83  }
    84  
    85  func TestTimeouts(t *testing.T) {
    86  	t.Parallel()
    87  	var testCases = []struct {
    88  		desc               string
    89  		callTimeout        time.Duration
    90  		pluginDelay        time.Duration
    91  		kubeAPIServerDelay time.Duration
    92  		wantErr            string
    93  	}{
    94  		{
    95  			desc:        "timeout zero - expect failure when call from kube-apiserver arrives before plugin starts",
    96  			callTimeout: 0 * time.Second,
    97  			pluginDelay: 3 * time.Second,
    98  			wantErr:     "rpc error: code = DeadlineExceeded desc = context deadline exceeded",
    99  		},
   100  		{
   101  			desc:               "timeout zero but kms-plugin already up - still failure - zero timeout is an invalid value",
   102  			callTimeout:        0 * time.Second,
   103  			pluginDelay:        0 * time.Second,
   104  			kubeAPIServerDelay: 2 * time.Second,
   105  			wantErr:            "rpc error: code = DeadlineExceeded desc = context deadline exceeded",
   106  		},
   107  		{
   108  			desc:        "timeout greater than kms-plugin delay - expect success",
   109  			callTimeout: 6 * time.Second,
   110  			pluginDelay: 3 * time.Second,
   111  		},
   112  		{
   113  			desc:        "timeout less than kms-plugin delay - expect failure",
   114  			callTimeout: 3 * time.Second,
   115  			pluginDelay: 6 * time.Second,
   116  			wantErr:     "rpc error: code = DeadlineExceeded desc = context deadline exceeded",
   117  		},
   118  	}
   119  
   120  	for _, tt := range testCases {
   121  		tt := tt
   122  		t.Run(tt.desc, func(t *testing.T) {
   123  			t.Parallel()
   124  			var (
   125  				service         kmsservice.Service
   126  				err             error
   127  				data            = []byte("test data")
   128  				uid             = string(uuid.NewUUID())
   129  				kubeAPIServerWG sync.WaitGroup
   130  				kmsPluginWG     sync.WaitGroup
   131  				testCompletedWG sync.WaitGroup
   132  				socketName      = newEndpoint()
   133  			)
   134  
   135  			testCompletedWG.Add(1)
   136  			defer testCompletedWG.Done()
   137  
   138  			ctx := testContext(t)
   139  
   140  			kubeAPIServerWG.Add(1)
   141  			go func() {
   142  				// Simulating late start of kube-apiserver - plugin is up before kube-apiserver, if requested by the testcase.
   143  				time.Sleep(tt.kubeAPIServerDelay)
   144  
   145  				service, err = NewGRPCService(ctx, socketName.endpoint, testProviderName, tt.callTimeout)
   146  				if err != nil {
   147  					t.Errorf("failed to create envelope service, error: %v", err)
   148  					return
   149  				}
   150  				defer destroyService(service)
   151  				kubeAPIServerWG.Done()
   152  				// Keeping kube-apiserver up to process requests.
   153  				testCompletedWG.Wait()
   154  			}()
   155  
   156  			kmsPluginWG.Add(1)
   157  			go func() {
   158  				// Simulating delayed start of kms-plugin, kube-apiserver is up before the plugin, if requested by the testcase.
   159  				time.Sleep(tt.pluginDelay)
   160  
   161  				_ = mock.NewBase64Plugin(t, socketName.path)
   162  				kmsPluginWG.Done()
   163  				// Keeping plugin up to process requests.
   164  				testCompletedWG.Wait()
   165  			}()
   166  
   167  			kubeAPIServerWG.Wait()
   168  			if t.Failed() {
   169  				return
   170  			}
   171  			_, err = service.Encrypt(ctx, uid, data)
   172  
   173  			if err == nil && tt.wantErr != "" {
   174  				t.Fatalf("got nil, want %s", tt.wantErr)
   175  			}
   176  
   177  			if err != nil && tt.wantErr == "" {
   178  				t.Fatalf("got %q, want nil", err.Error())
   179  			}
   180  
   181  			// Collecting kms-plugin - allowing plugin to clean-up.
   182  			kmsPluginWG.Wait()
   183  		})
   184  	}
   185  }
   186  
   187  // TestIntermittentConnectionLoss tests the scenario where the connection with kms-plugin is intermittently lost.
   188  func TestIntermittentConnectionLoss(t *testing.T) {
   189  	t.Parallel()
   190  	var (
   191  		wg1        sync.WaitGroup
   192  		wg2        sync.WaitGroup
   193  		timeout    = 30 * time.Second
   194  		blackOut   = 1 * time.Second
   195  		data       = []byte("test data")
   196  		uid        = string(uuid.NewUUID())
   197  		endpoint   = newEndpoint()
   198  		encryptErr error
   199  	)
   200  	// Start KMS Plugin
   201  	f := mock.NewBase64Plugin(t, endpoint.path)
   202  
   203  	ctx := testContext(t)
   204  
   205  	//  connect to kms plugin
   206  	service, err := NewGRPCService(ctx, endpoint.endpoint, testProviderName, timeout)
   207  	if err != nil {
   208  		t.Fatalf("failed to create envelope service, error: %v", err)
   209  	}
   210  	defer destroyService(service)
   211  
   212  	_, err = service.Encrypt(ctx, uid, data)
   213  	if err != nil {
   214  		t.Fatalf("failed when execute encrypt, error: %v", err)
   215  	}
   216  	t.Log("Connected to KMSPlugin")
   217  
   218  	// Stop KMS Plugin - simulating connection loss
   219  	t.Log("KMS Plugin is stopping")
   220  	f.CleanUp()
   221  	time.Sleep(2 * time.Second)
   222  
   223  	wg1.Add(1)
   224  	wg2.Add(1)
   225  	go func() {
   226  		defer wg2.Done()
   227  		// Call service to encrypt data.
   228  		t.Log("Sending encrypt request")
   229  		wg1.Done()
   230  		_, err := service.Encrypt(ctx, uid, data)
   231  		if err != nil {
   232  			encryptErr = fmt.Errorf("failed when executing encrypt, error: %v", err)
   233  		}
   234  	}()
   235  
   236  	wg1.Wait()
   237  	time.Sleep(blackOut)
   238  	// Start KMS Plugin
   239  	_ = mock.NewBase64Plugin(t, endpoint.path)
   240  	t.Log("Restarted KMS Plugin")
   241  
   242  	wg2.Wait()
   243  
   244  	if encryptErr != nil {
   245  		t.Error(encryptErr)
   246  	}
   247  }
   248  
   249  // Normal encryption and decryption operation.
   250  func TestGRPCService(t *testing.T) {
   251  	t.Parallel()
   252  	// Start a test gRPC server.
   253  	endpoint := newEndpoint()
   254  	_ = mock.NewBase64Plugin(t, endpoint.path)
   255  
   256  	ctx := testContext(t)
   257  
   258  	// Create the gRPC client service.
   259  	service, err := NewGRPCService(ctx, endpoint.endpoint, testProviderName, 1*time.Second)
   260  	if err != nil {
   261  		t.Fatalf("failed to create envelope service, error: %v", err)
   262  	}
   263  	defer destroyService(service)
   264  
   265  	// Call service to encrypt data.
   266  	data := []byte("test data")
   267  	uid := string(uuid.NewUUID())
   268  	resp, err := service.Encrypt(ctx, uid, data)
   269  	if err != nil {
   270  		t.Fatalf("failed when execute encrypt, error: %v", err)
   271  	}
   272  
   273  	keyID := "1"
   274  	// Call service to decrypt data.
   275  	result, err := service.Decrypt(ctx, uid, &kmsservice.DecryptRequest{Ciphertext: resp.Ciphertext, KeyID: keyID})
   276  	if err != nil {
   277  		t.Fatalf("failed when execute decrypt, error: %v", err)
   278  	}
   279  
   280  	if !reflect.DeepEqual(data, result) {
   281  		t.Errorf("expect: %v, but: %v", data, result)
   282  	}
   283  }
   284  
   285  // Normal encryption and decryption operation by multiple go-routines.
   286  func TestGRPCServiceConcurrentAccess(t *testing.T) {
   287  	t.Parallel()
   288  	// Start a test gRPC server.
   289  	endpoint := newEndpoint()
   290  	_ = mock.NewBase64Plugin(t, endpoint.path)
   291  
   292  	ctx := testContext(t)
   293  
   294  	// Create the gRPC client service.
   295  	service, err := NewGRPCService(ctx, endpoint.endpoint, testProviderName, 15*time.Second)
   296  	if err != nil {
   297  		t.Fatalf("failed to create envelope service, error: %v", err)
   298  	}
   299  	defer destroyService(service)
   300  
   301  	var wg sync.WaitGroup
   302  	n := 100
   303  	wg.Add(n)
   304  	for i := 0; i < n; i++ {
   305  		go func() {
   306  			defer wg.Done()
   307  			// Call service to encrypt data.
   308  			data := []byte("test data")
   309  			uid := string(uuid.NewUUID())
   310  			resp, err := service.Encrypt(ctx, uid, data)
   311  			if err != nil {
   312  				t.Errorf("failed when execute encrypt, error: %v", err)
   313  			}
   314  
   315  			keyID := "1"
   316  			// Call service to decrypt data.
   317  			result, err := service.Decrypt(ctx, uid, &kmsservice.DecryptRequest{Ciphertext: resp.Ciphertext, KeyID: keyID})
   318  			if err != nil {
   319  				t.Errorf("failed when execute decrypt, error: %v", err)
   320  			}
   321  
   322  			if !reflect.DeepEqual(data, result) {
   323  				t.Errorf("expect: %v, but: %v", data, result)
   324  			}
   325  		}()
   326  	}
   327  
   328  	wg.Wait()
   329  }
   330  
   331  func destroyService(service kmsservice.Service) {
   332  	if service != nil {
   333  		s := service.(*gRPCService)
   334  		s.connection.Close()
   335  	}
   336  }
   337  
   338  // Test all those invalid configuration for KMS provider.
   339  func TestInvalidConfiguration(t *testing.T) {
   340  	t.Parallel()
   341  	// Start a test gRPC server.
   342  	_ = mock.NewBase64Plugin(t, newEndpoint().path)
   343  
   344  	ctx := testContext(t)
   345  
   346  	invalidConfigs := []struct {
   347  		name     string
   348  		endpoint string
   349  	}{
   350  		{"emptyConfiguration", ""},
   351  		{"invalidScheme", "tcp://localhost:6060"},
   352  	}
   353  
   354  	for _, testCase := range invalidConfigs {
   355  		t.Run(testCase.name, func(t *testing.T) {
   356  			_, err := NewGRPCService(ctx, testCase.endpoint, testProviderName, 1*time.Second)
   357  			if err == nil {
   358  				t.Fatalf("should fail to create envelope service for %s.", testCase.name)
   359  			}
   360  		})
   361  	}
   362  }
   363  
   364  func TestKMSOperationsMetric(t *testing.T) {
   365  	endpoint := newEndpoint()
   366  	_ = mock.NewBase64Plugin(t, endpoint.path)
   367  
   368  	ctx := testContext(t)
   369  
   370  	service, err := NewGRPCService(ctx, endpoint.endpoint, testProviderName, 15*time.Second)
   371  	if err != nil {
   372  		t.Fatalf("failed to create envelope service, error: %v", err)
   373  	}
   374  	defer destroyService(service)
   375  	metrics.RegisterMetrics()
   376  	metrics.KMSOperationsLatencyMetric.Reset() // support running `go test -count X`
   377  
   378  	testCases := []struct {
   379  		name        string
   380  		operation   func()
   381  		labelValues []string
   382  		wantCount   uint64
   383  	}{
   384  		{
   385  			name: "encrypt",
   386  			operation: func() {
   387  				if _, err = service.Encrypt(ctx, "1", []byte("test data")); err != nil {
   388  					t.Fatalf("failed when execute encrypt, error: %v", err)
   389  				}
   390  			},
   391  			labelValues: []string{testProviderName, "/v2.KeyManagementService/Encrypt", "OK"},
   392  			wantCount:   1,
   393  		},
   394  		{
   395  			name: "decrypt",
   396  			operation: func() {
   397  				if _, err = service.Decrypt(ctx, "1", &kmsservice.DecryptRequest{Ciphertext: []byte("testdata"), KeyID: "1"}); err != nil {
   398  					t.Fatalf("failed when execute decrypt, error: %v", err)
   399  				}
   400  			},
   401  			labelValues: []string{testProviderName, "/v2.KeyManagementService/Decrypt", "OK"},
   402  			wantCount:   1,
   403  		},
   404  		{
   405  			name: "status",
   406  			operation: func() {
   407  				if _, err = service.Status(ctx); err != nil {
   408  					t.Fatalf("failed when execute status, error: %v", err)
   409  				}
   410  			},
   411  			labelValues: []string{testProviderName, "/v2.KeyManagementService/Status", "OK"},
   412  			wantCount:   1,
   413  		},
   414  		{
   415  			name: "multiple status calls",
   416  			operation: func() {
   417  				for i := 0; i < 10; i++ {
   418  					if _, err = service.Status(ctx); err != nil {
   419  						t.Fatalf("failed when execute status, error: %v", err)
   420  					}
   421  				}
   422  			},
   423  			labelValues: []string{testProviderName, "/v2.KeyManagementService/Status", "OK"},
   424  			wantCount:   10,
   425  		},
   426  	}
   427  
   428  	for _, tt := range testCases {
   429  		t.Run(tt.name, func(t *testing.T) {
   430  			tt.operation()
   431  			defer metrics.KMSOperationsLatencyMetric.Reset()
   432  			sampleSum, err := testutil.GetHistogramMetricValue(metrics.KMSOperationsLatencyMetric.WithLabelValues(tt.labelValues...))
   433  			if err != nil {
   434  				t.Fatalf("failed to get metric value, error: %v", err)
   435  			}
   436  			// apiserver_envelope_encryption_kms_operations_latency_seconds_sum{grpc_status_code="OK",method_name="/v2alpha1.KeyManagementService/Encrypt",provider_name="providerName"} 0.000881432
   437  			if sampleSum == 0 {
   438  				t.Fatalf("expected metric value to be greater than 0, got %v", sampleSum)
   439  			}
   440  			count, err := testutil.GetHistogramMetricCount(metrics.KMSOperationsLatencyMetric.WithLabelValues(tt.labelValues...))
   441  			if err != nil {
   442  				t.Fatalf("failed to get metric count, error: %v", err)
   443  			}
   444  			// apiserver_envelope_encryption_kms_operations_latency_seconds_count{grpc_status_code="OK",method_name="/v2alpha1.KeyManagementService/Encrypt",provider_name="providerName"} 1
   445  			if count != tt.wantCount {
   446  				t.Fatalf("expected metric count to be %v, got %v", tt.wantCount, count)
   447  			}
   448  		})
   449  	}
   450  }
   451  
   452  func testContext(t *testing.T) context.Context {
   453  	ctx, cancel := context.WithCancel(context.Background())
   454  	t.Cleanup(cancel)
   455  	return ctx
   456  }