github.com/grafana/pyroscope@v1.18.0/pkg/objstore/providers/s3/config_test.go (about) 1 // SPDX-License-Identifier: AGPL-3.0-only 2 // Provenance-includes-location: https://github.com/cortexproject/cortex/blob/master/pkg/storage/bucket/s3/config_test.go 3 // Provenance-includes-license: Apache-2.0 4 // Provenance-includes-copyright: The Cortex Authors. 5 6 package s3 7 8 import ( 9 "context" 10 "encoding/base64" 11 "io" 12 "net/http" 13 "net/http/httptest" 14 "os" 15 "testing" 16 "time" 17 18 "github.com/go-kit/log" 19 20 "github.com/grafana/dskit/flagext" 21 "github.com/stretchr/testify/assert" 22 "github.com/stretchr/testify/require" 23 ) 24 25 func TestSSEConfig_Validate(t *testing.T) { 26 tests := map[string]struct { 27 setup func() *SSEConfig 28 expected error 29 }{ 30 "should pass with default config": { 31 setup: func() *SSEConfig { 32 cfg := &SSEConfig{} 33 flagext.DefaultValues(cfg) 34 35 return cfg 36 }, 37 }, 38 "should fail on invalid SSE type": { 39 setup: func() *SSEConfig { 40 return &SSEConfig{ 41 Type: "unknown", 42 } 43 }, 44 expected: errUnsupportedSSEType, 45 }, 46 "should fail on invalid SSE KMS encryption context": { 47 setup: func() *SSEConfig { 48 return &SSEConfig{ 49 Type: SSEKMS, 50 KMSEncryptionContext: "!{}!", 51 } 52 }, 53 expected: errInvalidSSEContext, 54 }, 55 "should pass on valid SSE KMS encryption context": { 56 setup: func() *SSEConfig { 57 return &SSEConfig{ 58 Type: SSEKMS, 59 KMSEncryptionContext: `{"department": "10103.0"}`, 60 } 61 }, 62 }, 63 } 64 65 for testName, testData := range tests { 66 t.Run(testName, func(t *testing.T) { 67 assert.Equal(t, testData.expected, testData.setup().Validate()) 68 }) 69 } 70 } 71 72 func TestSSEConfig_BuildMinioConfig(t *testing.T) { 73 tests := map[string]struct { 74 cfg *SSEConfig 75 expectedType string 76 expectedKeyID string 77 expectedContext string 78 }{ 79 "SSE KMS without encryption context": { 80 cfg: &SSEConfig{ 81 Type: SSEKMS, 82 KMSKeyID: "test-key", 83 }, 84 expectedType: "aws:kms", 85 expectedKeyID: "test-key", 86 expectedContext: "", 87 }, 88 "SSE KMS with encryption context": { 89 cfg: &SSEConfig{ 90 Type: SSEKMS, 91 KMSKeyID: "test-key", 92 KMSEncryptionContext: "{\"department\":\"10103.0\"}", 93 }, 94 expectedType: "aws:kms", 95 expectedKeyID: "test-key", 96 expectedContext: "{\"department\":\"10103.0\"}", 97 }, 98 } 99 100 for testName, testData := range tests { 101 t.Run(testName, func(t *testing.T) { 102 sse, err := testData.cfg.BuildMinioConfig() 103 require.NoError(t, err) 104 105 headers := http.Header{} 106 sse.Marshal(headers) 107 108 assert.Equal(t, testData.expectedType, headers.Get("x-amz-server-side-encryption")) 109 assert.Equal(t, testData.expectedKeyID, headers.Get("x-amz-server-side-encryption-aws-kms-key-id")) 110 assert.Equal(t, base64.StdEncoding.EncodeToString([]byte(testData.expectedContext)), headers.Get("x-amz-server-side-encryption-context")) 111 }) 112 } 113 } 114 115 func TestParseKMSEncryptionContext(t *testing.T) { 116 actual, err := parseKMSEncryptionContext("") 117 assert.NoError(t, err) 118 assert.Equal(t, map[string]string(nil), actual) 119 120 expected := map[string]string{ 121 "department": "10103.0", 122 } 123 actual, err = parseKMSEncryptionContext(`{"department": "10103.0"}`) 124 assert.NoError(t, err) 125 assert.Equal(t, expected, actual) 126 } 127 128 func TestConfig_Validate(t *testing.T) { 129 tests := map[string]struct { 130 setup func() *Config 131 expected error 132 }{ 133 "should pass with default config": { 134 setup: func() *Config { 135 cfg := &Config{} 136 flagext.DefaultValues(cfg) 137 138 return cfg 139 }, 140 }, 141 "should fail on invalid bucket lookup style": { 142 setup: func() *Config { 143 cfg := &Config{} 144 flagext.DefaultValues(cfg) 145 cfg.BucketLookupType = "invalid" 146 return cfg 147 }, 148 expected: errUnsupportedBucketLookupType, 149 }, 150 "should fail if force-path-style conflicts with bucket-lookup-type": { 151 setup: func() *Config { 152 cfg := &Config{} 153 flagext.DefaultValues(cfg) 154 cfg.ForcePathStyle = true 155 cfg.BucketLookupType = VirtualHostedStyleLookup 156 return cfg 157 }, 158 expected: errBucketLookupConfigConflict, 159 }, 160 } 161 162 for testName, testData := range tests { 163 t.Run(testName, func(t *testing.T) { 164 assert.Equal(t, testData.expected, testData.setup().Validate()) 165 }) 166 } 167 } 168 169 type testRoundTripper struct { 170 roundTrip func(r *http.Request) (*http.Response, error) 171 } 172 173 func (t *testRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) { 174 return t.roundTrip(r) 175 } 176 177 func handleSTSRequest(t *testing.T, r *http.Request, w http.ResponseWriter) { 178 body, err := io.ReadAll(r.Body) 179 require.NoError(t, err) 180 181 require.Contains(t, string(body), "RoleArn=arn%3Ahello-world") 182 require.Contains(t, string(body), "WebIdentityToken=my-web-token") 183 require.Contains(t, string(body), "Action=AssumeRoleWithWebIdentity") 184 185 w.WriteHeader(200) 186 _, err = w.Write([]byte(`<?xml version="1.0" encoding="UTF-8"?> 187 <AssumeRoleWithWebIdentityResponse xmlns="https://sts.amazonaws.com/doc/2011-06-15/"> 188 <AssumeRoleWithWebIdentityResult> 189 <Credentials> 190 <AccessKeyId>test-key</AccessKeyId> 191 <SecretAccessKey>test-secret</SecretAccessKey> 192 <SessionToken>test-token</SessionToken> 193 <Expiration>` + time.Now().Add(time.Hour).Format(time.RFC3339) + `</Expiration> 194 </Credentials> 195 </AssumeRoleWithWebIdentityResult> 196 <ResponseMetadata> 197 <RequestId>test-request-id</RequestId> 198 </ResponseMetadata> 199 </AssumeRoleWithWebIdentityResponse>`)) 200 require.NoError(t, err) 201 202 } 203 204 func overrideEnv(t testing.TB, kv ...string) { 205 old := make([]string, len(kv)) 206 for i := 0; i < len(kv); i += 2 { 207 k := kv[i] 208 v := kv[i+1] 209 old[i] = k 210 old[i+1] = os.Getenv(k) 211 os.Setenv(k, v) 212 } 213 t.Cleanup(func() { 214 for i := 0; i < len(old); i += 2 { 215 os.Setenv(old[i], old[i+1]) 216 } 217 }) 218 } 219 220 func TestAWSSTSWebIdentity(t *testing.T) { 221 logger := log.NewNopLogger() 222 tmpDir := t.TempDir() 223 224 // override env variables, will be cleaned up by t.Cleanup 225 overrideEnv(t, 226 "AWS_WEB_IDENTITY_TOKEN_FILE", tmpDir+"/token", 227 "AWS_ROLE_ARN", "arn:hello-world", 228 "AWS_DEFAULT_REGION", "eu-central-1", 229 "AWS_CONFIG_FILE", "/dev/null", // dont accidentally use real config 230 "AWS_ACCESS_KEY_ID", "", // dont use real credentials 231 "AWS_SECRET_ACCESS_KEY", "", // dont use real credentials 232 ) 233 234 rt := &testRoundTripper{ 235 roundTrip: func(r *http.Request) (*http.Response, error) { 236 w := httptest.NewRecorder() 237 if r.Body != nil { 238 defer r.Body.Close() 239 } 240 switch r.URL.String() { 241 case "https://sts.amazonaws.com": 242 handleSTSRequest(t, r, w) 243 case "https://eu-central-1.amazonaws.com/pyroscope-test-bucket/test": 244 assert.Equal(t, "GET", r.Method) 245 assert.Contains(t, r.Header.Get("Authorization"), "AWS4-HMAC-SHA256 Credential=test-key") 246 w.Header().Set("Last-Modified", time.Now().Format("Mon, 2 Jan 2006 15:04:05 GMT")) 247 w.WriteHeader(200) 248 _, err := w.Write([]byte("test")) 249 require.NoError(t, err) 250 default: 251 w.WriteHeader(404) 252 _, err := w.Write([]byte("unexpected")) 253 require.NoError(t, err) 254 t.Errorf("unexpected request: %s", r.URL.Host) 255 t.FailNow() 256 } 257 return w.Result(), nil 258 }, 259 } 260 oldDefaultTransport := http.DefaultTransport 261 oldDefaultClient := http.DefaultClient 262 http.DefaultTransport = rt 263 http.DefaultClient = &http.Client{ 264 Transport: rt, 265 } 266 // restore default transport and client 267 t.Cleanup(func() { 268 http.DefaultTransport = oldDefaultTransport 269 http.DefaultClient = oldDefaultClient 270 }) 271 272 // mock a web token 273 err := os.WriteFile(tmpDir+"/token", []byte("my-web-token"), 0644) 274 require.NoError(t, err) 275 276 cfg := Config{ 277 SignatureVersion: SignatureVersionV4, 278 BucketName: "pyroscope-test-bucket", 279 Region: "eu-central-1", 280 Endpoint: "eu-central-1.amazonaws.com", 281 BucketLookupType: AutoLookup, 282 } 283 284 cfg.HTTP.Transport = rt 285 r, err := NewBucketClient(cfg, "test", logger) 286 require.NoError(t, err) 287 288 _, err = r.Get(context.Background(), "test") 289 require.NoError(t, err) 290 291 }