github.com/opentofu/opentofu@v1.7.1/internal/encryption/keyprovider/openbao/compliance_test.go (about)

     1  package openbao
     2  
     3  import (
     4  	"context"
     5  	"crypto/rand"
     6  	"encoding/base64"
     7  	"fmt"
     8  	"net/url"
     9  	"os"
    10  	"testing"
    11  
    12  	openbao "github.com/openbao/openbao/api"
    13  	"github.com/opentofu/opentofu/internal/encryption/keyprovider/compliancetest"
    14  )
    15  
    16  func getBaoKeyName() string {
    17  	// Acceptance tests are disabled, running with mock.
    18  	if os.Getenv("TF_ACC") == "" {
    19  		return ""
    20  	}
    21  	return os.Getenv("TF_ACC_BAO_KEY_NAME")
    22  }
    23  
    24  const defaultTestKeyName = "test-key"
    25  
    26  func TestKeyProvider(t *testing.T) {
    27  	testKeyName := getBaoKeyName()
    28  
    29  	if testKeyName == "" {
    30  		testKeyName = defaultTestKeyName
    31  
    32  		mock := prepareClientMockForKeyProviderTest(t, testKeyName)
    33  
    34  		injectMock(mock)
    35  
    36  		t.Cleanup(func() {
    37  			injectDefaultClient()
    38  		})
    39  	}
    40  
    41  	compliancetest.ComplianceTest(
    42  		t,
    43  		compliancetest.TestConfiguration[*descriptor, *Config, *keyMeta, *keyProvider]{
    44  			Descriptor: New().(*descriptor),
    45  			HCLParseTestCases: map[string]compliancetest.HCLParseTestCase[*Config, *keyProvider]{
    46  				"success": {
    47  					HCL: fmt.Sprintf(`key_provider "openbao" "foo" {
    48  							key_name = "%s"
    49  						}`, testKeyName),
    50  					ValidHCL:   true,
    51  					ValidBuild: true,
    52  				},
    53  				"success-full-creds": {
    54  					HCL: fmt.Sprintf(`key_provider "openbao" "foo" {
    55  							token = "s.dummytoken"
    56  							address = "http://127.0.0.1:8201"
    57  							key_name = "%s"
    58  						}`, testKeyName),
    59  					ValidHCL:   true,
    60  					ValidBuild: true,
    61  				},
    62  				"empty": {
    63  					HCL:        `key_provider "openbao" "foo" {}`,
    64  					ValidHCL:   false,
    65  					ValidBuild: false,
    66  				},
    67  				"empty-key-name": {
    68  					HCL: `key_provider "openbao" "foo" {
    69  							key_name = ""
    70  						}`,
    71  					ValidHCL:   true,
    72  					ValidBuild: false,
    73  				},
    74  				"invalid-key-length": {
    75  					HCL: fmt.Sprintf(`key_provider "openbao" "foo" {
    76  							key_name = "%s"
    77  							key_length = 17
    78  						}`, testKeyName),
    79  					ValidHCL:   true,
    80  					ValidBuild: false,
    81  				},
    82  				"no-key-name": {
    83  					HCL: `key_provider "openbao" "foo" {
    84  							key_length = 16
    85  						}`,
    86  					ValidHCL:   false,
    87  					ValidBuild: false,
    88  				},
    89  				"unknown-property": {
    90  					HCL: fmt.Sprintf(`key_provider "openbao" "foo" {
    91  							key_name = "%s"
    92  							key_length = 16
    93  							unknown_property = "foo"
    94  						}`, testKeyName),
    95  					ValidHCL:   false,
    96  					ValidBuild: false,
    97  				},
    98  				"transit-path": {
    99  					HCL: fmt.Sprintf(`key_provider "openbao" "foo" {
   100  							key_name = "%s"
   101  							key_length = 16
   102  							transit_engine_path = "foo"
   103  						}`, testKeyName),
   104  					ValidHCL:   true,
   105  					ValidBuild: true,
   106  				},
   107  			},
   108  			ConfigStructTestCases: map[string]compliancetest.ConfigStructTestCase[*Config, *keyProvider]{
   109  				"success": {
   110  					Config: &Config{
   111  						KeyName:           testKeyName,
   112  						KeyLength:         16,
   113  						TransitEnginePath: "/pki",
   114  					},
   115  					ValidBuild: true,
   116  					Validate: func(p *keyProvider) error {
   117  						if p.keyName != testKeyName {
   118  							return fmt.Errorf("key names don't match: %v and %v", p.keyName, testKeyName)
   119  						}
   120  						if p.keyLength != 16 {
   121  							return fmt.Errorf("invalid key length: %v", p.keyLength)
   122  						}
   123  						if p.svc.transitPath != "/pki" {
   124  							return fmt.Errorf("invalid transit path: %v", p.svc.transitPath)
   125  						}
   126  						return nil
   127  					},
   128  				},
   129  				"success-default-values": {
   130  					Config: &Config{
   131  						KeyName: testKeyName,
   132  					},
   133  					ValidBuild: true,
   134  					Validate: func(p *keyProvider) error {
   135  						if p.keyName != testKeyName {
   136  							return fmt.Errorf("key names don't match: %v and %v", p.keyName, testKeyName)
   137  						}
   138  						if p.keyLength != 32 {
   139  							return fmt.Errorf("invalid default key length: %v", p.keyLength)
   140  						}
   141  						if p.svc.transitPath != "/transit" {
   142  							return fmt.Errorf("invalid default transit path: %v; expected: '/transit'", p.svc.transitPath)
   143  						}
   144  						return nil
   145  					},
   146  				},
   147  				"empty": {
   148  					Config:     &Config{},
   149  					ValidBuild: false,
   150  					Validate:   nil,
   151  				},
   152  			},
   153  			MetadataStructTestCases: map[string]compliancetest.MetadataStructTestCase[*Config, *keyMeta]{
   154  				"empty": {
   155  					ValidConfig: &Config{
   156  						KeyName: testKeyName,
   157  					},
   158  					Meta:      &keyMeta{},
   159  					IsPresent: false,
   160  					IsValid:   false,
   161  				},
   162  			},
   163  			ProvideTestCase: compliancetest.ProvideTestCase[*Config, *keyMeta]{
   164  				ValidConfig: &Config{
   165  					KeyName: testKeyName,
   166  				},
   167  				ValidateKeys: func(dec []byte, enc []byte) error {
   168  					if len(dec) == 0 {
   169  						return fmt.Errorf("decryption key is empty")
   170  					}
   171  					if len(enc) == 0 {
   172  						return fmt.Errorf("encryption key is empty")
   173  					}
   174  					return nil
   175  				},
   176  				ValidateMetadata: func(meta *keyMeta) error {
   177  					if len(meta.Ciphertext) == 0 {
   178  						return fmt.Errorf("ciphertext is empty")
   179  					}
   180  					return nil
   181  				},
   182  			},
   183  		},
   184  	)
   185  }
   186  
   187  // Mocking is a bit complicated due to how openbao/api package is structured,
   188  // but in order to test cover as much as we can, it has to have some logic in here.
   189  
   190  func prepareClientMockForKeyProviderTest(t *testing.T, testKeyName string) mockClientFunc {
   191  	escapedTestKeyName := url.PathEscape(testKeyName)
   192  
   193  	// Mock uses default transit engine path: "/transit".
   194  	generateDataKeyPath := fmt.Sprintf("/transit/datakey/plaintext/%s", escapedTestKeyName)
   195  	decryptPath := fmt.Sprintf("/transit/decrypt/%s", escapedTestKeyName)
   196  
   197  	return func(ctx context.Context, path string, data map[string]interface{}) (*openbao.Secret, error) {
   198  		switch path {
   199  		case generateDataKeyPath:
   200  			bits, ok := data["bits"].(int)
   201  			if !ok {
   202  				t.Fatalf("Invalid bits in data suplied to mock: not a number")
   203  			}
   204  
   205  			plaintext := make([]byte, int(bits)/8)
   206  			if _, err := rand.Read(plaintext); err != nil {
   207  				panic(fmt.Errorf("generating random data key in mock: %w", err))
   208  			}
   209  
   210  			s := &openbao.Secret{
   211  				Data: map[string]interface{}{
   212  					"plaintext":  base64.StdEncoding.EncodeToString(plaintext),
   213  					"ciphertext": string(append([]byte(testKeyName), plaintext...)),
   214  				},
   215  			}
   216  
   217  			return s, nil
   218  
   219  		case decryptPath:
   220  			ciphertext, ok := data["ciphertext"].(string)
   221  			if !ok {
   222  				t.Fatalf("Invalid ciphertext in data suplied to mock: not an string")
   223  			}
   224  
   225  			plaintext := []byte(ciphertext[len(testKeyName):])
   226  
   227  			s := &openbao.Secret{
   228  				Data: map[string]interface{}{
   229  					"plaintext": base64.StdEncoding.EncodeToString(plaintext),
   230  				},
   231  			}
   232  
   233  			return s, nil
   234  
   235  		default:
   236  			t.Fatalf("Invalid path suplied to mock: %s", path)
   237  		}
   238  
   239  		// unreachable code
   240  		return nil, nil
   241  	}
   242  }