github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/ccl/storageccl/engineccl/pebble_key_manager_test.go (about)

     1  // Copyright 2019 The Cockroach Authors.
     2  //
     3  // Licensed as a CockroachDB Enterprise file under the Cockroach Community
     4  // License (the "License"); you may not use this file except in compliance with
     5  // the License. You may obtain a copy of the License at
     6  //
     7  //     https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt
     8  
     9  package engineccl
    10  
    11  import (
    12  	"bytes"
    13  	"context"
    14  	"fmt"
    15  	"io"
    16  	"strconv"
    17  	"strings"
    18  	"testing"
    19  	"time"
    20  
    21  	"github.com/cockroachdb/cockroach/pkg/ccl/storageccl/engineccl/enginepbccl"
    22  	"github.com/cockroachdb/cockroach/pkg/util/leaktest"
    23  	"github.com/cockroachdb/cockroach/pkg/util/protoutil"
    24  	"github.com/cockroachdb/cockroach/pkg/util/timeutil"
    25  	"github.com/cockroachdb/datadriven"
    26  	"github.com/cockroachdb/pebble/vfs"
    27  	"github.com/gogo/protobuf/proto"
    28  	"github.com/stretchr/testify/require"
    29  )
    30  
    31  func writeToFile(t *testing.T, fs vfs.FS, filename string, b []byte) {
    32  	f, err := fs.Create(filename)
    33  	require.NoError(t, err)
    34  	breader := bytes.NewReader(b)
    35  	_, err = io.Copy(f, breader)
    36  	require.NoError(t, err)
    37  	err = f.Close()
    38  	require.NoError(t, err)
    39  }
    40  
// Key file fixtures shared by the tests below.
//
// Each keyFileNNN string is the full contents of a key file: a key-ID prefix
// made of one repeated digit, followed by the raw key bytes (the matching
// keyNNN constant). TestStoreKeyManager expects 16.key/24.key/32.key to load
// as AES128_CTR/AES192_CTR/AES256_CTR keys with these IDs and key bytes.
// NOTE(review): this assumes the prefix length equals KeyIDLength (defined
// elsewhere, presumably 32 bytes) — confirm if KeyIDLength ever changes.
const (
	keyFile128 = "111111111111111111111111111111111234567890123456"
	keyFile192 = "22222222222222222222222222222222123456789012345678901234"
	keyFile256 = "3333333333333333333333333333333312345678901234567890123456789012"
	// Raw key material: 16, 24 and 32 bytes respectively.
	key128 = "1234567890123456"
	key192 = "123456789012345678901234"
	key256 = "12345678901234567890123456789012"

	// Hex encoding of the first KeyIDLength bytes of each key file, i.e. the
	// expected key ID (0x31 is '1', 0x32 is '2', 0x33 is '3').
	keyID128 = "3131313131313131313131313131313131313131313131313131313131313131"
	keyID192 = "3232323232323232323232323232323232323232323232323232323232323232"
	keyID256 = "3333333333333333333333333333333333333333333333333333333333333333"
)
    54  
    55  func TestStoreKeyManagerLoadErrors(t *testing.T) {
    56  	defer leaktest.AfterTest(t)()
    57  
    58  	memFS := vfs.NewMem()
    59  
    60  	type KeyFiles struct {
    61  		filename string
    62  		contents string
    63  	}
    64  	keys := []KeyFiles{
    65  		{"empty.key", ""},
    66  		{"noid_8.key", "12345678"},
    67  		{"noid_16.key", "1234567890123456"},
    68  		{"noid_24.key", "123456789012345678901234"},
    69  		{"noid_32.key", "12345678901234567890123456789012"},
    70  		{"16.key", keyFile128},
    71  		{"24.key", keyFile192},
    72  		{"32.key", keyFile256},
    73  	}
    74  	for _, k := range keys {
    75  		writeToFile(t, memFS, k.filename, []byte(k.contents))
    76  	}
    77  
    78  	type Result int
    79  	const (
    80  		Ok Result = iota
    81  		Err
    82  	)
    83  	type TestCase struct {
    84  		activeFile string
    85  		oldFile    string
    86  		result     Result
    87  	}
    88  	testCases := []TestCase{
    89  		{"", "", Err},
    90  		{"missing_new.key", "missing_old.key", Err},
    91  		{"plain", "missing_old.key", Err},
    92  		{"plain", "plain", Ok},
    93  		{"empty.key", "plain", Err},
    94  		{"noid_8.key", "plain", Err},
    95  		{"noid_16.key", "plain", Err},
    96  		{"noid_24.key", "plain", Err},
    97  		{"noid_32.key", "plain", Err},
    98  		{"16.key", "plain", Ok},
    99  		{"24.key", "plain", Ok},
   100  		{"32.key", "plain", Ok},
   101  		{"16.key", "noid_8.key", Err},
   102  		{"16.key", "32.key", Ok},
   103  	}
   104  
   105  	for _, tc := range testCases {
   106  		t.Run("", func(t *testing.T) {
   107  			skm := &StoreKeyManager{fs: memFS, activeKeyFilename: tc.activeFile, oldKeyFilename: tc.oldFile}
   108  			actual := Ok
   109  			if err := skm.Load(context.Background()); err != nil {
   110  				actual = Err
   111  			}
   112  			require.Equal(t, tc.result, actual)
   113  		})
   114  	}
   115  }
   116  
   117  func TestStoreKeyManager(t *testing.T) {
   118  	defer leaktest.AfterTest(t)()
   119  
   120  	memFS := vfs.NewMem()
   121  
   122  	type KeyFiles struct {
   123  		filename string
   124  		contents string
   125  	}
   126  	keys := []KeyFiles{
   127  		{"16.key", keyFile128},
   128  		{"24.key", keyFile192},
   129  		{"32.key", keyFile256},
   130  	}
   131  
   132  	kmTimeNow = func() time.Time { return timeutil.Unix(5, 0) }
   133  
   134  	keyPlain := &enginepbccl.SecretKey{}
   135  	require.NoError(t, proto.UnmarshalText(
   136  		"info {encryption_type: Plaintext, key_id: \"plain\" creation_time: 5 source: \"plain\"}",
   137  		keyPlain))
   138  	key16 := &enginepbccl.SecretKey{}
   139  	require.NoError(t, proto.UnmarshalText(fmt.Sprintf(
   140  		"info {encryption_type: AES128_CTR, key_id: \"%s\" creation_time: 5 source: \"16.key\"} key: \"%s\"",
   141  		keyID128, key128), key16))
   142  	key24 := &enginepbccl.SecretKey{}
   143  	require.NoError(t, proto.UnmarshalText(fmt.Sprintf(
   144  		"info {encryption_type: AES192_CTR, key_id: \"%s\" creation_time: 5 source: \"24.key\"} key: \"%s\"",
   145  		keyID192, key192), key24))
   146  	key32 := &enginepbccl.SecretKey{}
   147  	require.NoError(t, proto.UnmarshalText(fmt.Sprintf(
   148  		"info {encryption_type: AES256_CTR, key_id: \"%s\" creation_time: 5 source: \"32.key\"} key: \"%s\"",
   149  		keyID256, key256), key32))
   150  
   151  	for _, k := range keys {
   152  		writeToFile(t, memFS, k.filename, []byte(k.contents))
   153  	}
   154  
   155  	{
   156  		skm := &StoreKeyManager{fs: memFS, activeKeyFilename: "plain", oldKeyFilename: "plain"}
   157  		require.NoError(t, skm.Load(context.Background()))
   158  		key, err := skm.ActiveKey(context.Background())
   159  		require.NoError(t, err)
   160  		require.Equal(t, keyPlain.String(), key.String())
   161  		key, err = skm.GetKey("plain")
   162  		require.NoError(t, err)
   163  		require.Equal(t, keyPlain.String(), key.String())
   164  		_, err = skm.GetKey("x")
   165  		require.Error(t, err)
   166  	}
   167  	{
   168  		skm := &StoreKeyManager{fs: memFS, activeKeyFilename: "16.key", oldKeyFilename: "24.key"}
   169  		require.NoError(t, skm.Load(context.Background()))
   170  		key, err := skm.ActiveKey(context.Background())
   171  		require.NoError(t, err)
   172  		require.Equal(t, key16.String(), key.String())
   173  		key, err = skm.GetKey(keyID128)
   174  		require.NoError(t, err)
   175  		require.Equal(t, key16.String(), key.String())
   176  		key, err = skm.GetKey(keyID192)
   177  		require.NoError(t, err)
   178  		require.Equal(t, key24.String(), key.String())
   179  		_, err = skm.GetKey("plain")
   180  		require.Error(t, err)
   181  	}
   182  	{
   183  		skm := &StoreKeyManager{fs: memFS, activeKeyFilename: "32.key", oldKeyFilename: "plain"}
   184  		require.NoError(t, skm.Load(context.Background()))
   185  		key, err := skm.ActiveKey(context.Background())
   186  		require.NoError(t, err)
   187  		require.Equal(t, key32.String(), key.String())
   188  		key, err = skm.GetKey(keyID256)
   189  		require.NoError(t, err)
   190  		require.Equal(t, key32.String(), key.String())
   191  		key, err = skm.GetKey("plain")
   192  		require.NoError(t, err)
   193  		require.Equal(t, keyPlain.String(), key.String())
   194  	}
   195  }
   196  
   197  func setActiveStoreKeyInProto(dkr *enginepbccl.DataKeysRegistry, id string) {
   198  	dkr.StoreKeys[id] = &enginepbccl.KeyInfo{
   199  		EncryptionType: enginepbccl.EncryptionType_AES128_CTR,
   200  		KeyId:          id,
   201  	}
   202  	dkr.ActiveStoreKeyId = id
   203  }
   204  
   205  func setActiveDataKeyInProto(dkr *enginepbccl.DataKeysRegistry, id string) {
   206  	dkr.DataKeys[id] = &enginepbccl.SecretKey{
   207  		Info: &enginepbccl.KeyInfo{
   208  			EncryptionType: enginepbccl.EncryptionType_AES192_CTR, KeyId: id},
   209  		Key: []byte("some key"),
   210  	}
   211  	dkr.ActiveDataKeyId = id
   212  }
   213  
   214  func setActiveStoreKey(dkm *DataKeyManager, id string, kind enginepbccl.EncryptionType) string {
   215  	err := dkm.SetActiveStoreKeyInfo(context.Background(), &enginepbccl.KeyInfo{
   216  		EncryptionType: kind,
   217  		KeyId:          id,
   218  	})
   219  	if err != nil {
   220  		return err.Error()
   221  	}
   222  	return ""
   223  }
   224  
   225  func CompareKeys(last, curr *enginepbccl.SecretKey) string {
   226  	if (last == nil && curr == nil) || (last != nil && curr == nil) || (last == nil && curr != nil) ||
   227  		(last.Info.KeyId == curr.Info.KeyId) {
   228  		return "same\n"
   229  	}
   230  	return "different\n"
   231  }
   232  
   233  func TestDataKeyManager(t *testing.T) {
   234  	defer leaktest.AfterTest(t)()
   235  
   236  	memFS := vfs.NewMem()
   237  
   238  	var dkm *DataKeyManager
   239  	var keyMap map[string]*enginepbccl.SecretKey
   240  	var lastActiveDataKey *enginepbccl.SecretKey
   241  
   242  	var unixTime int64
   243  	kmTimeNow = func() time.Time {
   244  		return timeutil.Unix(unixTime, 0)
   245  	}
   246  
   247  	datadriven.RunTest(t, "testdata/data_key_manager",
   248  		func(t *testing.T, d *datadriven.TestData) string {
   249  			switch d.Cmd {
   250  			case "init":
   251  				data := strings.Split(d.Input, "\n")
   252  				if len(data) < 2 {
   253  					return "insufficient arguments to init"
   254  				}
   255  				data[0] = strings.TrimSpace(data[0])
   256  				data[1] = strings.TrimSpace(data[1])
   257  				period, err := strconv.Atoi(data[1])
   258  				if err != nil {
   259  					return err.Error()
   260  				}
   261  				keyMap = make(map[string]*enginepbccl.SecretKey)
   262  				lastActiveDataKey = nil
   263  				require.NoError(t, memFS.MkdirAll(data[0], 0755))
   264  				dkm = &DataKeyManager{fs: memFS, dbDir: data[0], rotationPeriod: int64(period)}
   265  				dkr := &enginepbccl.DataKeysRegistry{
   266  					StoreKeys: make(map[string]*enginepbccl.KeyInfo),
   267  					DataKeys:  make(map[string]*enginepbccl.SecretKey),
   268  				}
   269  				for i := 2; i < len(data); i++ {
   270  					keyInfo := strings.Split(data[i], " ")
   271  					if len(keyInfo) != 2 {
   272  						return "insufficient parameters: " + data[i]
   273  					}
   274  					keyInfo[0] = strings.TrimSpace(keyInfo[0])
   275  					keyInfo[1] = strings.TrimSpace(keyInfo[1])
   276  					switch keyInfo[0] {
   277  					case "active-store-key":
   278  						setActiveStoreKeyInProto(dkr, keyInfo[1])
   279  					case "active-data-key":
   280  						setActiveDataKeyInProto(dkr, keyInfo[1])
   281  					default:
   282  						return fmt.Sprintf("unknown command: %s", keyInfo[1])
   283  					}
   284  				}
   285  				if len(data) > 2 {
   286  					b, err := protoutil.Marshal(dkr)
   287  					if err != nil {
   288  						return err.Error()
   289  					}
   290  					writeToFile(t, memFS, memFS.PathJoin(data[0], keyRegistryFilename), b)
   291  				}
   292  				return ""
   293  			case "load":
   294  				if err := dkm.Load(context.Background()); err != nil {
   295  					return err.Error()
   296  				}
   297  				return ""
   298  			case "set-active-store-key":
   299  				var id string
   300  				d.ScanArgs(t, "id", &id)
   301  				return setActiveStoreKey(dkm, id, enginepbccl.EncryptionType_AES128_CTR)
   302  			case "set-active-store-key-plain":
   303  				var id string
   304  				d.ScanArgs(t, "id", &id)
   305  				return setActiveStoreKey(dkm, d.CmdArgs[0].Vals[0], enginepbccl.EncryptionType_Plaintext)
   306  			case "check-exposed":
   307  				var val bool
   308  				d.ScanArgs(t, "val", &val)
   309  				for _, key := range dkm.mu.keyRegistry.DataKeys {
   310  					if key.Info.WasExposed != val {
   311  						return fmt.Sprintf(
   312  							"WasExposed: actual: %t, expected: %t\n", key.Info.WasExposed, val)
   313  					}
   314  				}
   315  				return ""
   316  			case "get-active-data-key":
   317  				key, err := dkm.ActiveKey(context.Background())
   318  				if err != nil {
   319  					return err.Error()
   320  				}
   321  				lastActiveDataKey = key
   322  				if key == nil {
   323  					return "none\n"
   324  				}
   325  				keyInfo := &enginepbccl.KeyInfo{}
   326  				proto.Merge(keyInfo, key.Info)
   327  				keyInfo.KeyId = ""
   328  				return strings.TrimSpace(keyInfo.String()) + "\n"
   329  			case "get-active-store-key":
   330  				id := dkm.mu.keyRegistry.ActiveStoreKeyId
   331  				if id == "" {
   332  					return "none\n"
   333  				}
   334  				return id + "\n"
   335  			case "get-store-key":
   336  				var id string
   337  				d.ScanArgs(t, "id", &id)
   338  				if dkm.mu.keyRegistry.StoreKeys != nil && dkm.mu.keyRegistry.StoreKeys[id] != nil {
   339  					return strings.TrimSpace(dkm.mu.keyRegistry.StoreKeys[id].String()) + "\n"
   340  				}
   341  				return "none\n"
   342  			case "record-active-data-key":
   343  				key, err := dkm.ActiveKey(context.Background())
   344  				if err != nil {
   345  					return err.Error()
   346  				}
   347  				if key != nil {
   348  					keyMap[key.Info.KeyId] = key
   349  				}
   350  				return ""
   351  			case "compare-active-data-key":
   352  				key, err := dkm.ActiveKey(context.Background())
   353  				if err != nil {
   354  					return err.Error()
   355  				}
   356  				rv := CompareKeys(lastActiveDataKey, key)
   357  				lastActiveDataKey = key
   358  				return rv
   359  			case "check-all-recorded-data-keys":
   360  				actual := fmt.Sprint(dkm.mu.keyRegistry.DataKeys)
   361  				expected := fmt.Sprint(keyMap)
   362  				require.Equal(t, expected, actual)
   363  				return ""
   364  			case "wait":
   365  				data := strings.Split(d.Input, "\n")
   366  				if len(data) != 1 {
   367  					return "incorrect arguments to wait"
   368  				}
   369  				interval, err := strconv.Atoi(strings.TrimSpace(data[0]))
   370  				if err != nil {
   371  					return err.Error()
   372  				}
   373  				unixTime += int64(interval)
   374  				return ""
   375  			default:
   376  				return fmt.Sprintf("unknown command: %s\n", d.Cmd)
   377  			}
   378  		})
   379  }