github.com/grafana/pyroscope@v1.18.0/pkg/objstore/providers/s3/config_test.go (about)

     1  // SPDX-License-Identifier: AGPL-3.0-only
     2  // Provenance-includes-location: https://github.com/cortexproject/cortex/blob/master/pkg/storage/bucket/s3/config_test.go
     3  // Provenance-includes-license: Apache-2.0
     4  // Provenance-includes-copyright: The Cortex Authors.
     5  
     6  package s3
     7  
     8  import (
     9  	"context"
    10  	"encoding/base64"
    11  	"io"
    12  	"net/http"
    13  	"net/http/httptest"
    14  	"os"
    15  	"testing"
    16  	"time"
    17  
    18  	"github.com/go-kit/log"
    19  
    20  	"github.com/grafana/dskit/flagext"
    21  	"github.com/stretchr/testify/assert"
    22  	"github.com/stretchr/testify/require"
    23  )
    24  
    25  func TestSSEConfig_Validate(t *testing.T) {
    26  	tests := map[string]struct {
    27  		setup    func() *SSEConfig
    28  		expected error
    29  	}{
    30  		"should pass with default config": {
    31  			setup: func() *SSEConfig {
    32  				cfg := &SSEConfig{}
    33  				flagext.DefaultValues(cfg)
    34  
    35  				return cfg
    36  			},
    37  		},
    38  		"should fail on invalid SSE type": {
    39  			setup: func() *SSEConfig {
    40  				return &SSEConfig{
    41  					Type: "unknown",
    42  				}
    43  			},
    44  			expected: errUnsupportedSSEType,
    45  		},
    46  		"should fail on invalid SSE KMS encryption context": {
    47  			setup: func() *SSEConfig {
    48  				return &SSEConfig{
    49  					Type:                 SSEKMS,
    50  					KMSEncryptionContext: "!{}!",
    51  				}
    52  			},
    53  			expected: errInvalidSSEContext,
    54  		},
    55  		"should pass on valid SSE KMS encryption context": {
    56  			setup: func() *SSEConfig {
    57  				return &SSEConfig{
    58  					Type:                 SSEKMS,
    59  					KMSEncryptionContext: `{"department": "10103.0"}`,
    60  				}
    61  			},
    62  		},
    63  	}
    64  
    65  	for testName, testData := range tests {
    66  		t.Run(testName, func(t *testing.T) {
    67  			assert.Equal(t, testData.expected, testData.setup().Validate())
    68  		})
    69  	}
    70  }
    71  
    72  func TestSSEConfig_BuildMinioConfig(t *testing.T) {
    73  	tests := map[string]struct {
    74  		cfg             *SSEConfig
    75  		expectedType    string
    76  		expectedKeyID   string
    77  		expectedContext string
    78  	}{
    79  		"SSE KMS without encryption context": {
    80  			cfg: &SSEConfig{
    81  				Type:     SSEKMS,
    82  				KMSKeyID: "test-key",
    83  			},
    84  			expectedType:    "aws:kms",
    85  			expectedKeyID:   "test-key",
    86  			expectedContext: "",
    87  		},
    88  		"SSE KMS with encryption context": {
    89  			cfg: &SSEConfig{
    90  				Type:                 SSEKMS,
    91  				KMSKeyID:             "test-key",
    92  				KMSEncryptionContext: "{\"department\":\"10103.0\"}",
    93  			},
    94  			expectedType:    "aws:kms",
    95  			expectedKeyID:   "test-key",
    96  			expectedContext: "{\"department\":\"10103.0\"}",
    97  		},
    98  	}
    99  
   100  	for testName, testData := range tests {
   101  		t.Run(testName, func(t *testing.T) {
   102  			sse, err := testData.cfg.BuildMinioConfig()
   103  			require.NoError(t, err)
   104  
   105  			headers := http.Header{}
   106  			sse.Marshal(headers)
   107  
   108  			assert.Equal(t, testData.expectedType, headers.Get("x-amz-server-side-encryption"))
   109  			assert.Equal(t, testData.expectedKeyID, headers.Get("x-amz-server-side-encryption-aws-kms-key-id"))
   110  			assert.Equal(t, base64.StdEncoding.EncodeToString([]byte(testData.expectedContext)), headers.Get("x-amz-server-side-encryption-context"))
   111  		})
   112  	}
   113  }
   114  
   115  func TestParseKMSEncryptionContext(t *testing.T) {
   116  	actual, err := parseKMSEncryptionContext("")
   117  	assert.NoError(t, err)
   118  	assert.Equal(t, map[string]string(nil), actual)
   119  
   120  	expected := map[string]string{
   121  		"department": "10103.0",
   122  	}
   123  	actual, err = parseKMSEncryptionContext(`{"department": "10103.0"}`)
   124  	assert.NoError(t, err)
   125  	assert.Equal(t, expected, actual)
   126  }
   127  
   128  func TestConfig_Validate(t *testing.T) {
   129  	tests := map[string]struct {
   130  		setup    func() *Config
   131  		expected error
   132  	}{
   133  		"should pass with default config": {
   134  			setup: func() *Config {
   135  				cfg := &Config{}
   136  				flagext.DefaultValues(cfg)
   137  
   138  				return cfg
   139  			},
   140  		},
   141  		"should fail on invalid bucket lookup style": {
   142  			setup: func() *Config {
   143  				cfg := &Config{}
   144  				flagext.DefaultValues(cfg)
   145  				cfg.BucketLookupType = "invalid"
   146  				return cfg
   147  			},
   148  			expected: errUnsupportedBucketLookupType,
   149  		},
   150  		"should fail if force-path-style conflicts with bucket-lookup-type": {
   151  			setup: func() *Config {
   152  				cfg := &Config{}
   153  				flagext.DefaultValues(cfg)
   154  				cfg.ForcePathStyle = true
   155  				cfg.BucketLookupType = VirtualHostedStyleLookup
   156  				return cfg
   157  			},
   158  			expected: errBucketLookupConfigConflict,
   159  		},
   160  	}
   161  
   162  	for testName, testData := range tests {
   163  		t.Run(testName, func(t *testing.T) {
   164  			assert.Equal(t, testData.expected, testData.setup().Validate())
   165  		})
   166  	}
   167  }
   168  
   169  type testRoundTripper struct {
   170  	roundTrip func(r *http.Request) (*http.Response, error)
   171  }
   172  
   173  func (t *testRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) {
   174  	return t.roundTrip(r)
   175  }
   176  
   177  func handleSTSRequest(t *testing.T, r *http.Request, w http.ResponseWriter) {
   178  	body, err := io.ReadAll(r.Body)
   179  	require.NoError(t, err)
   180  
   181  	require.Contains(t, string(body), "RoleArn=arn%3Ahello-world")
   182  	require.Contains(t, string(body), "WebIdentityToken=my-web-token")
   183  	require.Contains(t, string(body), "Action=AssumeRoleWithWebIdentity")
   184  
   185  	w.WriteHeader(200)
   186  	_, err = w.Write([]byte(`<?xml version="1.0" encoding="UTF-8"?>
   187  				<AssumeRoleWithWebIdentityResponse xmlns="https://sts.amazonaws.com/doc/2011-06-15/">
   188  				  <AssumeRoleWithWebIdentityResult>
   189  				    <Credentials>
   190  				      <AccessKeyId>test-key</AccessKeyId>
   191  				      <SecretAccessKey>test-secret</SecretAccessKey>
   192  				      <SessionToken>test-token</SessionToken>
   193  				      <Expiration>` + time.Now().Add(time.Hour).Format(time.RFC3339) + `</Expiration>
   194  				    </Credentials>
   195  				  </AssumeRoleWithWebIdentityResult>
   196  				  <ResponseMetadata>
   197  				    <RequestId>test-request-id</RequestId>
   198  				  </ResponseMetadata>
   199  				</AssumeRoleWithWebIdentityResponse>`))
   200  	require.NoError(t, err)
   201  
   202  }
   203  
   204  func overrideEnv(t testing.TB, kv ...string) {
   205  	old := make([]string, len(kv))
   206  	for i := 0; i < len(kv); i += 2 {
   207  		k := kv[i]
   208  		v := kv[i+1]
   209  		old[i] = k
   210  		old[i+1] = os.Getenv(k)
   211  		os.Setenv(k, v)
   212  	}
   213  	t.Cleanup(func() {
   214  		for i := 0; i < len(old); i += 2 {
   215  			os.Setenv(old[i], old[i+1])
   216  		}
   217  	})
   218  }
   219  
   220  func TestAWSSTSWebIdentity(t *testing.T) {
   221  	logger := log.NewNopLogger()
   222  	tmpDir := t.TempDir()
   223  
   224  	// override env variables, will be cleaned up by t.Cleanup
   225  	overrideEnv(t,
   226  		"AWS_WEB_IDENTITY_TOKEN_FILE", tmpDir+"/token",
   227  		"AWS_ROLE_ARN", "arn:hello-world",
   228  		"AWS_DEFAULT_REGION", "eu-central-1",
   229  		"AWS_CONFIG_FILE", "/dev/null", // dont accidentally use real config
   230  		"AWS_ACCESS_KEY_ID", "", // dont use real credentials
   231  		"AWS_SECRET_ACCESS_KEY", "", // dont use real credentials
   232  	)
   233  
   234  	rt := &testRoundTripper{
   235  		roundTrip: func(r *http.Request) (*http.Response, error) {
   236  			w := httptest.NewRecorder()
   237  			if r.Body != nil {
   238  				defer r.Body.Close()
   239  			}
   240  			switch r.URL.String() {
   241  			case "https://sts.amazonaws.com":
   242  				handleSTSRequest(t, r, w)
   243  			case "https://eu-central-1.amazonaws.com/pyroscope-test-bucket/test":
   244  				assert.Equal(t, "GET", r.Method)
   245  				assert.Contains(t, r.Header.Get("Authorization"), "AWS4-HMAC-SHA256 Credential=test-key")
   246  				w.Header().Set("Last-Modified", time.Now().Format("Mon, 2 Jan 2006 15:04:05 GMT"))
   247  				w.WriteHeader(200)
   248  				_, err := w.Write([]byte("test"))
   249  				require.NoError(t, err)
   250  			default:
   251  				w.WriteHeader(404)
   252  				_, err := w.Write([]byte("unexpected"))
   253  				require.NoError(t, err)
   254  				t.Errorf("unexpected request: %s", r.URL.Host)
   255  				t.FailNow()
   256  			}
   257  			return w.Result(), nil
   258  		},
   259  	}
   260  	oldDefaultTransport := http.DefaultTransport
   261  	oldDefaultClient := http.DefaultClient
   262  	http.DefaultTransport = rt
   263  	http.DefaultClient = &http.Client{
   264  		Transport: rt,
   265  	}
   266  	// restore default transport and client
   267  	t.Cleanup(func() {
   268  		http.DefaultTransport = oldDefaultTransport
   269  		http.DefaultClient = oldDefaultClient
   270  	})
   271  
   272  	// mock a web token
   273  	err := os.WriteFile(tmpDir+"/token", []byte("my-web-token"), 0644)
   274  	require.NoError(t, err)
   275  
   276  	cfg := Config{
   277  		SignatureVersion: SignatureVersionV4,
   278  		BucketName:       "pyroscope-test-bucket",
   279  		Region:           "eu-central-1",
   280  		Endpoint:         "eu-central-1.amazonaws.com",
   281  		BucketLookupType: AutoLookup,
   282  	}
   283  
   284  	cfg.HTTP.Transport = rt
   285  	r, err := NewBucketClient(cfg, "test", logger)
   286  	require.NoError(t, err)
   287  
   288  	_, err = r.Get(context.Background(), "test")
   289  	require.NoError(t, err)
   290  
   291  }