k8s.io/apiserver@v0.29.3/pkg/storage/value/encrypt/envelope/kmsv2/grpc_service_unix_test.go (about)

     1  /*
     2  Copyright 2022 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
// Package kmsv2 transforms values for storage at rest using an Envelope v2 provider
    18  package kmsv2
    19  
    20  import (
    21  	"context"
    22  	"fmt"
    23  	"reflect"
    24  	"sync"
    25  	"testing"
    26  	"time"
    27  
    28  	"k8s.io/apimachinery/pkg/util/uuid"
    29  	"k8s.io/apiserver/pkg/storage/value/encrypt/envelope/metrics"
    30  	mock "k8s.io/apiserver/pkg/storage/value/encrypt/envelope/testing/v2"
    31  	"k8s.io/component-base/metrics/testutil"
    32  	kmsservice "k8s.io/kms/pkg/service"
    33  )
    34  
    35  const (
    36  	testProviderName = "providerName"
    37  )
    38  
    39  type testSocket struct {
    40  	path     string
    41  	endpoint string
    42  }
    43  
    44  // newEndpoint constructs a unique name for a Linux Abstract Socket to be used in a test.
    45  // This package uses Linux Domain Sockets to remove the need for clean-up of socket files.
    46  func newEndpoint() *testSocket {
    47  	p := fmt.Sprintf("@%s.sock", uuid.NewUUID())
    48  
    49  	return &testSocket{
    50  		path:     p,
    51  		endpoint: fmt.Sprintf("unix:///%s", p),
    52  	}
    53  }
    54  
    55  // TestKMSPluginLateStart tests the scenario where kms-plugin pod/container starts after kube-apiserver pod/container.
    56  // Since the Dial to kms-plugin is non-blocking we expect the construction of gRPC service to succeed even when
    57  // kms-plugin is not yet up - dialing happens in the background.
    58  func TestKMSPluginLateStart(t *testing.T) {
    59  	t.Parallel()
    60  	callTimeout := 3 * time.Second
    61  	s := newEndpoint()
    62  
    63  	ctx := testContext(t)
    64  
    65  	service, err := NewGRPCService(ctx, s.endpoint, testProviderName, callTimeout)
    66  	if err != nil {
    67  		t.Fatalf("failed to create envelope service, error: %v", err)
    68  	}
    69  	defer destroyService(service)
    70  
    71  	time.Sleep(callTimeout / 2)
    72  	_ = mock.NewBase64Plugin(t, s.path)
    73  
    74  	data := []byte("test data")
    75  	uid := string(uuid.NewUUID())
    76  	_, err = service.Encrypt(ctx, uid, data)
    77  	if err != nil {
    78  		t.Fatalf("failed when execute encrypt, error: %v", err)
    79  	}
    80  }
    81  
    82  func TestTimeouts(t *testing.T) {
    83  	t.Parallel()
    84  	var testCases = []struct {
    85  		desc               string
    86  		callTimeout        time.Duration
    87  		pluginDelay        time.Duration
    88  		kubeAPIServerDelay time.Duration
    89  		wantErr            string
    90  	}{
    91  		{
    92  			desc:        "timeout zero - expect failure when call from kube-apiserver arrives before plugin starts",
    93  			callTimeout: 0 * time.Second,
    94  			pluginDelay: 3 * time.Second,
    95  			wantErr:     "rpc error: code = DeadlineExceeded desc = context deadline exceeded",
    96  		},
    97  		{
    98  			desc:               "timeout zero but kms-plugin already up - still failure - zero timeout is an invalid value",
    99  			callTimeout:        0 * time.Second,
   100  			pluginDelay:        0 * time.Second,
   101  			kubeAPIServerDelay: 2 * time.Second,
   102  			wantErr:            "rpc error: code = DeadlineExceeded desc = context deadline exceeded",
   103  		},
   104  		{
   105  			desc:        "timeout greater than kms-plugin delay - expect success",
   106  			callTimeout: 6 * time.Second,
   107  			pluginDelay: 3 * time.Second,
   108  		},
   109  		{
   110  			desc:        "timeout less than kms-plugin delay - expect failure",
   111  			callTimeout: 3 * time.Second,
   112  			pluginDelay: 6 * time.Second,
   113  			wantErr:     "rpc error: code = DeadlineExceeded desc = context deadline exceeded",
   114  		},
   115  	}
   116  
   117  	for _, tt := range testCases {
   118  		tt := tt
   119  		t.Run(tt.desc, func(t *testing.T) {
   120  			t.Parallel()
   121  			var (
   122  				service         kmsservice.Service
   123  				err             error
   124  				data            = []byte("test data")
   125  				uid             = string(uuid.NewUUID())
   126  				kubeAPIServerWG sync.WaitGroup
   127  				kmsPluginWG     sync.WaitGroup
   128  				testCompletedWG sync.WaitGroup
   129  				socketName      = newEndpoint()
   130  			)
   131  
   132  			testCompletedWG.Add(1)
   133  			defer testCompletedWG.Done()
   134  
   135  			ctx := testContext(t)
   136  
   137  			kubeAPIServerWG.Add(1)
   138  			go func() {
   139  				// Simulating late start of kube-apiserver - plugin is up before kube-apiserver, if requested by the testcase.
   140  				time.Sleep(tt.kubeAPIServerDelay)
   141  
   142  				service, err = NewGRPCService(ctx, socketName.endpoint, testProviderName, tt.callTimeout)
   143  				if err != nil {
   144  					t.Errorf("failed to create envelope service, error: %v", err)
   145  					return
   146  				}
   147  				defer destroyService(service)
   148  				kubeAPIServerWG.Done()
   149  				// Keeping kube-apiserver up to process requests.
   150  				testCompletedWG.Wait()
   151  			}()
   152  
   153  			kmsPluginWG.Add(1)
   154  			go func() {
   155  				// Simulating delayed start of kms-plugin, kube-apiserver is up before the plugin, if requested by the testcase.
   156  				time.Sleep(tt.pluginDelay)
   157  
   158  				_ = mock.NewBase64Plugin(t, socketName.path)
   159  				kmsPluginWG.Done()
   160  				// Keeping plugin up to process requests.
   161  				testCompletedWG.Wait()
   162  			}()
   163  
   164  			kubeAPIServerWG.Wait()
   165  			if t.Failed() {
   166  				return
   167  			}
   168  			_, err = service.Encrypt(ctx, uid, data)
   169  
   170  			if err == nil && tt.wantErr != "" {
   171  				t.Fatalf("got nil, want %s", tt.wantErr)
   172  			}
   173  
   174  			if err != nil && tt.wantErr == "" {
   175  				t.Fatalf("got %q, want nil", err.Error())
   176  			}
   177  
   178  			// Collecting kms-plugin - allowing plugin to clean-up.
   179  			kmsPluginWG.Wait()
   180  		})
   181  	}
   182  }
   183  
   184  // TestIntermittentConnectionLoss tests the scenario where the connection with kms-plugin is intermittently lost.
   185  func TestIntermittentConnectionLoss(t *testing.T) {
   186  	t.Parallel()
   187  	var (
   188  		wg1        sync.WaitGroup
   189  		wg2        sync.WaitGroup
   190  		timeout    = 30 * time.Second
   191  		blackOut   = 1 * time.Second
   192  		data       = []byte("test data")
   193  		uid        = string(uuid.NewUUID())
   194  		endpoint   = newEndpoint()
   195  		encryptErr error
   196  	)
   197  	// Start KMS Plugin
   198  	f := mock.NewBase64Plugin(t, endpoint.path)
   199  
   200  	ctx := testContext(t)
   201  
   202  	//  connect to kms plugin
   203  	service, err := NewGRPCService(ctx, endpoint.endpoint, testProviderName, timeout)
   204  	if err != nil {
   205  		t.Fatalf("failed to create envelope service, error: %v", err)
   206  	}
   207  	defer destroyService(service)
   208  
   209  	_, err = service.Encrypt(ctx, uid, data)
   210  	if err != nil {
   211  		t.Fatalf("failed when execute encrypt, error: %v", err)
   212  	}
   213  	t.Log("Connected to KMSPlugin")
   214  
   215  	// Stop KMS Plugin - simulating connection loss
   216  	t.Log("KMS Plugin is stopping")
   217  	f.CleanUp()
   218  	time.Sleep(2 * time.Second)
   219  
   220  	wg1.Add(1)
   221  	wg2.Add(1)
   222  	go func() {
   223  		defer wg2.Done()
   224  		// Call service to encrypt data.
   225  		t.Log("Sending encrypt request")
   226  		wg1.Done()
   227  		_, err := service.Encrypt(ctx, uid, data)
   228  		if err != nil {
   229  			encryptErr = fmt.Errorf("failed when executing encrypt, error: %v", err)
   230  		}
   231  	}()
   232  
   233  	wg1.Wait()
   234  	time.Sleep(blackOut)
   235  	// Start KMS Plugin
   236  	_ = mock.NewBase64Plugin(t, endpoint.path)
   237  	t.Log("Restarted KMS Plugin")
   238  
   239  	wg2.Wait()
   240  
   241  	if encryptErr != nil {
   242  		t.Error(encryptErr)
   243  	}
   244  }
   245  
   246  // Normal encryption and decryption operation.
   247  func TestGRPCService(t *testing.T) {
   248  	t.Parallel()
   249  	// Start a test gRPC server.
   250  	endpoint := newEndpoint()
   251  	_ = mock.NewBase64Plugin(t, endpoint.path)
   252  
   253  	ctx := testContext(t)
   254  
   255  	// Create the gRPC client service.
   256  	service, err := NewGRPCService(ctx, endpoint.endpoint, testProviderName, 1*time.Second)
   257  	if err != nil {
   258  		t.Fatalf("failed to create envelope service, error: %v", err)
   259  	}
   260  	defer destroyService(service)
   261  
   262  	// Call service to encrypt data.
   263  	data := []byte("test data")
   264  	uid := string(uuid.NewUUID())
   265  	resp, err := service.Encrypt(ctx, uid, data)
   266  	if err != nil {
   267  		t.Fatalf("failed when execute encrypt, error: %v", err)
   268  	}
   269  
   270  	keyID := "1"
   271  	// Call service to decrypt data.
   272  	result, err := service.Decrypt(ctx, uid, &kmsservice.DecryptRequest{Ciphertext: resp.Ciphertext, KeyID: keyID})
   273  	if err != nil {
   274  		t.Fatalf("failed when execute decrypt, error: %v", err)
   275  	}
   276  
   277  	if !reflect.DeepEqual(data, result) {
   278  		t.Errorf("expect: %v, but: %v", data, result)
   279  	}
   280  }
   281  
   282  // Normal encryption and decryption operation by multiple go-routines.
   283  func TestGRPCServiceConcurrentAccess(t *testing.T) {
   284  	t.Parallel()
   285  	// Start a test gRPC server.
   286  	endpoint := newEndpoint()
   287  	_ = mock.NewBase64Plugin(t, endpoint.path)
   288  
   289  	ctx := testContext(t)
   290  
   291  	// Create the gRPC client service.
   292  	service, err := NewGRPCService(ctx, endpoint.endpoint, testProviderName, 15*time.Second)
   293  	if err != nil {
   294  		t.Fatalf("failed to create envelope service, error: %v", err)
   295  	}
   296  	defer destroyService(service)
   297  
   298  	var wg sync.WaitGroup
   299  	n := 100
   300  	wg.Add(n)
   301  	for i := 0; i < n; i++ {
   302  		go func() {
   303  			defer wg.Done()
   304  			// Call service to encrypt data.
   305  			data := []byte("test data")
   306  			uid := string(uuid.NewUUID())
   307  			resp, err := service.Encrypt(ctx, uid, data)
   308  			if err != nil {
   309  				t.Errorf("failed when execute encrypt, error: %v", err)
   310  			}
   311  
   312  			keyID := "1"
   313  			// Call service to decrypt data.
   314  			result, err := service.Decrypt(ctx, uid, &kmsservice.DecryptRequest{Ciphertext: resp.Ciphertext, KeyID: keyID})
   315  			if err != nil {
   316  				t.Errorf("failed when execute decrypt, error: %v", err)
   317  			}
   318  
   319  			if !reflect.DeepEqual(data, result) {
   320  				t.Errorf("expect: %v, but: %v", data, result)
   321  			}
   322  		}()
   323  	}
   324  
   325  	wg.Wait()
   326  }
   327  
   328  func destroyService(service kmsservice.Service) {
   329  	if service != nil {
   330  		s := service.(*gRPCService)
   331  		s.connection.Close()
   332  	}
   333  }
   334  
   335  // Test all those invalid configuration for KMS provider.
   336  func TestInvalidConfiguration(t *testing.T) {
   337  	t.Parallel()
   338  	// Start a test gRPC server.
   339  	_ = mock.NewBase64Plugin(t, newEndpoint().path)
   340  
   341  	ctx := testContext(t)
   342  
   343  	invalidConfigs := []struct {
   344  		name     string
   345  		endpoint string
   346  	}{
   347  		{"emptyConfiguration", ""},
   348  		{"invalidScheme", "tcp://localhost:6060"},
   349  	}
   350  
   351  	for _, testCase := range invalidConfigs {
   352  		t.Run(testCase.name, func(t *testing.T) {
   353  			_, err := NewGRPCService(ctx, testCase.endpoint, testProviderName, 1*time.Second)
   354  			if err == nil {
   355  				t.Fatalf("should fail to create envelope service for %s.", testCase.name)
   356  			}
   357  		})
   358  	}
   359  }
   360  
   361  func TestKMSOperationsMetric(t *testing.T) {
   362  	endpoint := newEndpoint()
   363  	_ = mock.NewBase64Plugin(t, endpoint.path)
   364  
   365  	ctx := testContext(t)
   366  
   367  	service, err := NewGRPCService(ctx, endpoint.endpoint, testProviderName, 15*time.Second)
   368  	if err != nil {
   369  		t.Fatalf("failed to create envelope service, error: %v", err)
   370  	}
   371  	defer destroyService(service)
   372  	metrics.RegisterMetrics()
   373  	metrics.KMSOperationsLatencyMetric.Reset() // support running `go test -count X`
   374  
   375  	testCases := []struct {
   376  		name        string
   377  		operation   func()
   378  		labelValues []string
   379  		wantCount   uint64
   380  	}{
   381  		{
   382  			name: "encrypt",
   383  			operation: func() {
   384  				if _, err = service.Encrypt(ctx, "1", []byte("test data")); err != nil {
   385  					t.Fatalf("failed when execute encrypt, error: %v", err)
   386  				}
   387  			},
   388  			labelValues: []string{testProviderName, "/v2.KeyManagementService/Encrypt", "OK"},
   389  			wantCount:   1,
   390  		},
   391  		{
   392  			name: "decrypt",
   393  			operation: func() {
   394  				if _, err = service.Decrypt(ctx, "1", &kmsservice.DecryptRequest{Ciphertext: []byte("testdata"), KeyID: "1"}); err != nil {
   395  					t.Fatalf("failed when execute decrypt, error: %v", err)
   396  				}
   397  			},
   398  			labelValues: []string{testProviderName, "/v2.KeyManagementService/Decrypt", "OK"},
   399  			wantCount:   1,
   400  		},
   401  		{
   402  			name: "status",
   403  			operation: func() {
   404  				if _, err = service.Status(ctx); err != nil {
   405  					t.Fatalf("failed when execute status, error: %v", err)
   406  				}
   407  			},
   408  			labelValues: []string{testProviderName, "/v2.KeyManagementService/Status", "OK"},
   409  			wantCount:   1,
   410  		},
   411  		{
   412  			name: "multiple status calls",
   413  			operation: func() {
   414  				for i := 0; i < 10; i++ {
   415  					if _, err = service.Status(ctx); err != nil {
   416  						t.Fatalf("failed when execute status, error: %v", err)
   417  					}
   418  				}
   419  			},
   420  			labelValues: []string{testProviderName, "/v2.KeyManagementService/Status", "OK"},
   421  			wantCount:   10,
   422  		},
   423  	}
   424  
   425  	for _, tt := range testCases {
   426  		t.Run(tt.name, func(t *testing.T) {
   427  			tt.operation()
   428  			defer metrics.KMSOperationsLatencyMetric.Reset()
   429  			sampleSum, err := testutil.GetHistogramMetricValue(metrics.KMSOperationsLatencyMetric.WithLabelValues(tt.labelValues...))
   430  			if err != nil {
   431  				t.Fatalf("failed to get metric value, error: %v", err)
   432  			}
   433  			// apiserver_envelope_encryption_kms_operations_latency_seconds_sum{grpc_status_code="OK",method_name="/v2alpha1.KeyManagementService/Encrypt",provider_name="providerName"} 0.000881432
   434  			if sampleSum == 0 {
   435  				t.Fatalf("expected metric value to be greater than 0, got %v", sampleSum)
   436  			}
   437  			count, err := testutil.GetHistogramMetricCount(metrics.KMSOperationsLatencyMetric.WithLabelValues(tt.labelValues...))
   438  			if err != nil {
   439  				t.Fatalf("failed to get metric count, error: %v", err)
   440  			}
   441  			// apiserver_envelope_encryption_kms_operations_latency_seconds_count{grpc_status_code="OK",method_name="/v2alpha1.KeyManagementService/Encrypt",provider_name="providerName"} 1
   442  			if count != tt.wantCount {
   443  				t.Fatalf("expected metric count to be %v, got %v", tt.wantCount, count)
   444  			}
   445  		})
   446  	}
   447  }
   448  
   449  func testContext(t *testing.T) context.Context {
   450  	ctx, cancel := context.WithCancel(context.Background())
   451  	t.Cleanup(cancel)
   452  	return ctx
   453  }