github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/libraries/utils/awsrefreshcreds/creds_test.go (about) 1 // Copyright 2023 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package awsrefreshcreds 16 17 import ( 18 "os" 19 "path/filepath" 20 "testing" 21 "time" 22 23 "github.com/aws/aws-sdk-go/aws/credentials" 24 "github.com/stretchr/testify/assert" 25 "github.com/stretchr/testify/require" 26 ) 27 28 type staticProvider struct { 29 v credentials.Value 30 } 31 32 func (p *staticProvider) Retrieve() (credentials.Value, error) { 33 return p.v, nil 34 } 35 36 func (p *staticProvider) IsExpired() bool { 37 return false 38 } 39 40 func TestRefreshingCredentialsProvider(t *testing.T) { 41 var sp staticProvider 42 sp.v.AccessKeyID = "ExampleOne" 43 rp := NewRefreshingCredentialsProvider(&sp, time.Minute) 44 45 n := time.Now() 46 origNow := now 47 t.Cleanup(func() { 48 now = origNow 49 }) 50 now = func() time.Time { return n } 51 52 v, err := rp.Retrieve() 53 assert.NoError(t, err) 54 assert.Equal(t, "ExampleOne", v.AccessKeyID) 55 assert.False(t, rp.IsExpired()) 56 57 sp.v.AccessKeyID = "ExampleTwo" 58 59 now = func() time.Time { return n.Add(30 * time.Second) } 60 61 v, err = rp.Retrieve() 62 assert.NoError(t, err) 63 assert.Equal(t, "ExampleTwo", v.AccessKeyID) 64 assert.False(t, rp.IsExpired()) 65 66 now = func() time.Time { return n.Add(91 * time.Second) } 67 assert.True(t, rp.IsExpired()) 68 v, err = rp.Retrieve() 69 assert.NoError(t, err) 70 assert.Equal(t, "ExampleTwo", v.AccessKeyID) 71 assert.False(t, rp.IsExpired()) 72 } 73 74 func TestRefreshingCredentialsProviderShared(t *testing.T) { 75 d := t.TempDir() 76 77 onecontents := ` 78 [backup] 79 aws_access_key_id = AKIAAAAAAAAAAAAAAAAA 80 aws_secret_access_key = oF8x/JQEGchAAAAAAAAAAAAAAAAAAAAAAAAAAAAA 81 ` 82 83 twocontents := ` 84 [backup] 85 aws_access_key_id = AKIZZZZZZZZZZZZZZZZZ 86 aws_secret_access_key = oF8x/JQEGchZZZZZZZZZZZZZZZZZZZZZZZZZZZZZ 87 ` 88 89 configpath := filepath.Join(d, "config") 90 91 require.NoError(t, os.WriteFile(configpath, []byte(onecontents), 0700)) 92 93 n := time.Now() 94 origNow := now 95 t.Cleanup(func() { 96 now = origNow 97 }) 98 now = func() time.Time { return n } 99 100 creds := credentials.NewCredentials( 101 NewRefreshingCredentialsProvider(&credentials.SharedCredentialsProvider{ 102 Filename: configpath, 103 Profile: "backup", 104 }, time.Minute), 105 ) 106 107 v, err := creds.Get() 108 assert.NoError(t, err) 109 assert.Equal(t, "AKIAAAAAAAAAAAAAAAAA", v.AccessKeyID) 110 111 require.NoError(t, os.WriteFile(configpath, []byte(twocontents), 0700)) 112 113 now = func() time.Time { return n.Add(61 * time.Second) } 114 v, err = creds.Get() 115 assert.NoError(t, err) 116 assert.Equal(t, "AKIZZZZZZZZZZZZZZZZZ", v.AccessKeyID) 117 }