github.com/muhammadn/cortex@v1.9.1-0.20220510110439-46bb7000d03d/pkg/storage/bucket/sse_bucket_client_test.go (about)

     1  package bucket
     2  
import (
	"context"
	"encoding/base64"
	"net/http"
	"net/http/httptest"
	"strings"
	"sync"
	"testing"

	"github.com/go-kit/log"
	"github.com/grafana/dskit/flagext"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"github.com/thanos-io/thanos/pkg/objstore"

	"github.com/cortexproject/cortex/pkg/storage/bucket/s3"
)
    19  
    20  func TestSSEBucketClient_Upload_ShouldInjectCustomSSEConfig(t *testing.T) {
    21  	tests := map[string]struct {
    22  		withExpectedErrs bool
    23  	}{
    24  		"default client": {
    25  			withExpectedErrs: false,
    26  		},
    27  		"client with expected errors": {
    28  			withExpectedErrs: true,
    29  		},
    30  	}
    31  
    32  	for testName, testData := range tests {
    33  		t.Run(testName, func(t *testing.T) {
    34  			const (
    35  				kmsKeyID             = "ABC"
    36  				kmsEncryptionContext = "{\"department\":\"10103.0\"}"
    37  			)
    38  
    39  			var req *http.Request
    40  
    41  			// Start a fake HTTP server which simulate S3.
    42  			srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    43  				// Keep track of the received request.
    44  				req = r
    45  
    46  				w.WriteHeader(http.StatusOK)
    47  			}))
    48  			defer srv.Close()
    49  
    50  			s3Cfg := s3.Config{
    51  				Endpoint:        srv.Listener.Addr().String(),
    52  				Region:          "test",
    53  				BucketName:      "test-bucket",
    54  				SecretAccessKey: flagext.Secret{Value: "test"},
    55  				AccessKeyID:     "test",
    56  				Insecure:        true,
    57  			}
    58  
    59  			s3Client, err := s3.NewBucketClient(s3Cfg, "test", log.NewNopLogger())
    60  			require.NoError(t, err)
    61  
    62  			// Configure the config provider with NO KMS key ID.
    63  			cfgProvider := &mockTenantConfigProvider{}
    64  
    65  			var sseBkt objstore.Bucket
    66  			if testData.withExpectedErrs {
    67  				sseBkt = NewSSEBucketClient("user-1", s3Client, cfgProvider).WithExpectedErrs(s3Client.IsObjNotFoundErr)
    68  			} else {
    69  				sseBkt = NewSSEBucketClient("user-1", s3Client, cfgProvider)
    70  			}
    71  
    72  			err = sseBkt.Upload(context.Background(), "test", strings.NewReader("test"))
    73  			require.NoError(t, err)
    74  
    75  			// Ensure NO KMS header has been injected.
    76  			assert.Equal(t, "", req.Header.Get("x-amz-server-side-encryption"))
    77  			assert.Equal(t, "", req.Header.Get("x-amz-server-side-encryption-aws-kms-key-id"))
    78  			assert.Equal(t, "", req.Header.Get("x-amz-server-side-encryption-context"))
    79  
    80  			// Configure the config provider with a KMS key ID and without encryption context.
    81  			cfgProvider.s3SseType = s3.SSEKMS
    82  			cfgProvider.s3KmsKeyID = kmsKeyID
    83  
    84  			err = sseBkt.Upload(context.Background(), "test", strings.NewReader("test"))
    85  			require.NoError(t, err)
    86  
    87  			// Ensure the KMS header has been injected.
    88  			assert.Equal(t, "aws:kms", req.Header.Get("x-amz-server-side-encryption"))
    89  			assert.Equal(t, kmsKeyID, req.Header.Get("x-amz-server-side-encryption-aws-kms-key-id"))
    90  			assert.Equal(t, "", req.Header.Get("x-amz-server-side-encryption-context"))
    91  
    92  			// Configure the config provider with a KMS key ID and encryption context.
    93  			cfgProvider.s3SseType = s3.SSEKMS
    94  			cfgProvider.s3KmsKeyID = kmsKeyID
    95  			cfgProvider.s3KmsEncryptionContext = kmsEncryptionContext
    96  
    97  			err = sseBkt.Upload(context.Background(), "test", strings.NewReader("test"))
    98  			require.NoError(t, err)
    99  
   100  			// Ensure the KMS header has been injected.
   101  			assert.Equal(t, "aws:kms", req.Header.Get("x-amz-server-side-encryption"))
   102  			assert.Equal(t, kmsKeyID, req.Header.Get("x-amz-server-side-encryption-aws-kms-key-id"))
   103  			assert.Equal(t, base64.StdEncoding.EncodeToString([]byte(kmsEncryptionContext)), req.Header.Get("x-amz-server-side-encryption-context"))
   104  		})
   105  	}
   106  }
   107  
   108  type mockTenantConfigProvider struct {
   109  	s3SseType              string
   110  	s3KmsKeyID             string
   111  	s3KmsEncryptionContext string
   112  }
   113  
   114  func (m *mockTenantConfigProvider) S3SSEType(_ string) string {
   115  	return m.s3SseType
   116  }
   117  
   118  func (m *mockTenantConfigProvider) S3SSEKMSKeyID(_ string) string {
   119  	return m.s3KmsKeyID
   120  }
   121  
   122  func (m *mockTenantConfigProvider) S3SSEKMSEncryptionContext(_ string) string {
   123  	return m.s3KmsEncryptionContext
   124  }