github.com/grafana/pyroscope@v1.18.0/pkg/objstore/sse_bucket_client_test.go (about)

     1  // SPDX-License-Identifier: AGPL-3.0-only
     2  // Provenance-includes-location: https://github.com/cortexproject/cortex/blob/master/pkg/storage/bucket/sse_bucket_client_test.go
     3  // Provenance-includes-license: Apache-2.0
     4  // Provenance-includes-copyright: The Cortex Authors.
     5  
     6  package objstore
     7  
import (
	"context"
	"encoding/base64"
	"net/http"
	"net/http/httptest"
	"strings"
	"sync"
	"testing"

	"github.com/go-kit/log"
	"github.com/grafana/dskit/flagext"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"github.com/thanos-io/objstore"

	"github.com/grafana/pyroscope/pkg/objstore/providers/s3"
)
    24  
    25  func TestSSEBucketClient_Upload_ShouldInjectCustomSSEConfig(t *testing.T) {
    26  	tests := map[string]struct {
    27  		withExpectedErrs bool
    28  	}{
    29  		"default client": {
    30  			withExpectedErrs: false,
    31  		},
    32  		"client with expected errors": {
    33  			withExpectedErrs: true,
    34  		},
    35  	}
    36  
    37  	for testName, testData := range tests {
    38  		t.Run(testName, func(t *testing.T) {
    39  			const (
    40  				kmsKeyID             = "ABC"
    41  				kmsEncryptionContext = "{\"department\":\"10103.0\"}"
    42  			)
    43  
    44  			var req *http.Request
    45  
    46  			// Start a fake HTTP server which simulate S3.
    47  			srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    48  				// Keep track of the received request.
    49  				req = r
    50  
    51  				w.WriteHeader(http.StatusOK)
    52  			}))
    53  			defer srv.Close()
    54  
    55  			s3Cfg := s3.Config{
    56  				Endpoint:        srv.Listener.Addr().String(),
    57  				Region:          "test",
    58  				BucketName:      "test-bucket",
    59  				SecretAccessKey: flagext.SecretWithValue("test"),
    60  				AccessKeyID:     "test",
    61  				Insecure:        true,
    62  			}
    63  
    64  			s3Client, err := s3.NewBucketClient(s3Cfg, "test", log.NewNopLogger())
    65  			require.NoError(t, err)
    66  
    67  			// Configure the config provider with NO KMS key ID.
    68  			cfgProvider := &mockTenantConfigProvider{}
    69  
    70  			var sseBkt objstore.Bucket
    71  			if testData.withExpectedErrs {
    72  				sseBkt = NewSSEBucketClient("user-1", NewBucket(s3Client), cfgProvider).WithExpectedErrs(s3Client.IsObjNotFoundErr)
    73  			} else {
    74  				sseBkt = NewSSEBucketClient("user-1", NewBucket(s3Client), cfgProvider)
    75  			}
    76  
    77  			err = sseBkt.Upload(context.Background(), "test", strings.NewReader("test"))
    78  			require.NoError(t, err)
    79  
    80  			// Ensure NO KMS header has been injected.
    81  			assert.Equal(t, "", req.Header.Get("x-amz-server-side-encryption"))
    82  			assert.Equal(t, "", req.Header.Get("x-amz-server-side-encryption-aws-kms-key-id"))
    83  			assert.Equal(t, "", req.Header.Get("x-amz-server-side-encryption-context"))
    84  
    85  			// Configure the config provider with a KMS key ID and without encryption context.
    86  			cfgProvider.s3SseType = s3.SSEKMS
    87  			cfgProvider.s3KmsKeyID = kmsKeyID
    88  
    89  			err = sseBkt.Upload(context.Background(), "test", strings.NewReader("test"))
    90  			require.NoError(t, err)
    91  
    92  			// Ensure the KMS header has been injected.
    93  			assert.Equal(t, "aws:kms", req.Header.Get("x-amz-server-side-encryption"))
    94  			assert.Equal(t, kmsKeyID, req.Header.Get("x-amz-server-side-encryption-aws-kms-key-id"))
    95  			assert.Equal(t, "", req.Header.Get("x-amz-server-side-encryption-context"))
    96  
    97  			// Configure the config provider with a KMS key ID and encryption context.
    98  			cfgProvider.s3SseType = s3.SSEKMS
    99  			cfgProvider.s3KmsKeyID = kmsKeyID
   100  			cfgProvider.s3KmsEncryptionContext = kmsEncryptionContext
   101  
   102  			err = sseBkt.Upload(context.Background(), "test", strings.NewReader("test"))
   103  			require.NoError(t, err)
   104  
   105  			// Ensure the KMS header has been injected.
   106  			assert.Equal(t, "aws:kms", req.Header.Get("x-amz-server-side-encryption"))
   107  			assert.Equal(t, kmsKeyID, req.Header.Get("x-amz-server-side-encryption-aws-kms-key-id"))
   108  			assert.Equal(t, base64.StdEncoding.EncodeToString([]byte(kmsEncryptionContext)), req.Header.Get("x-amz-server-side-encryption-context"))
   109  		})
   110  	}
   111  }
   112  
   113  type mockTenantConfigProvider struct {
   114  	s3SseType              string
   115  	s3KmsKeyID             string
   116  	s3KmsEncryptionContext string
   117  }
   118  
   119  func (m *mockTenantConfigProvider) S3SSEType(_ string) string {
   120  	return m.s3SseType
   121  }
   122  
   123  func (m *mockTenantConfigProvider) S3SSEKMSKeyID(_ string) string {
   124  	return m.s3KmsKeyID
   125  }
   126  
   127  func (m *mockTenantConfigProvider) S3SSEKMSEncryptionContext(_ string) string {
   128  	return m.s3KmsEncryptionContext
   129  }