github.com/dannyzhou2015/migrate/v4@v4.15.2/source/aws_s3/s3_test.go (about) 1 package awss3 2 3 import ( 4 "errors" 5 "io/ioutil" 6 "strings" 7 "testing" 8 9 "github.com/aws/aws-sdk-go/aws" 10 "github.com/aws/aws-sdk-go/service/s3" 11 st "github.com/dannyzhou2015/migrate/v4/source/testing" 12 "github.com/stretchr/testify/assert" 13 ) 14 15 func Test(t *testing.T) { 16 s3Client := fakeS3{ 17 bucket: "some-bucket", 18 objects: map[string]string{ 19 "staging/migrations/1_foobar.up.sql": "1 up", 20 "staging/migrations/1_foobar.down.sql": "1 down", 21 "prod/migrations/1_foobar.up.sql": "1 up", 22 "prod/migrations/1_foobar.down.sql": "1 down", 23 "prod/migrations/3_foobar.up.sql": "3 up", 24 "prod/migrations/4_foobar.up.sql": "4 up", 25 "prod/migrations/4_foobar.down.sql": "4 down", 26 "prod/migrations/5_foobar.down.sql": "5 down", 27 "prod/migrations/7_foobar.up.sql": "7 up", 28 "prod/migrations/7_foobar.down.sql": "7 down", 29 "prod/migrations/not-a-migration.txt": "", 30 "prod/migrations/0-random-stuff/whatever.txt": "", 31 }, 32 } 33 driver, err := WithInstance(&s3Client, &Config{ 34 Bucket: "some-bucket", 35 Prefix: "prod/migrations/", 36 }) 37 if err != nil { 38 t.Fatal(err) 39 } 40 st.Test(t, driver) 41 } 42 43 func TestParseURI(t *testing.T) { 44 tests := []struct { 45 name string 46 uri string 47 config *Config 48 }{ 49 { 50 "with prefix, no trailing slash", 51 "s3://migration-bucket/production", 52 &Config{ 53 Bucket: "migration-bucket", 54 Prefix: "production/", 55 }, 56 }, 57 { 58 "without prefix, no trailing slash", 59 "s3://migration-bucket", 60 &Config{ 61 Bucket: "migration-bucket", 62 }, 63 }, 64 { 65 "with prefix, trailing slash", 66 "s3://migration-bucket/production/", 67 &Config{ 68 Bucket: "migration-bucket", 69 Prefix: "production/", 70 }, 71 }, 72 { 73 "without prefix, trailing slash", 74 "s3://migration-bucket/", 75 &Config{ 76 Bucket: "migration-bucket", 77 }, 78 }, 79 } 80 for _, test := range tests { 81 t.Run(test.name, func(t *testing.T) { 82 actual, err := parseURI(test.uri) 83 if err != nil { 84 t.Fatal(err) 85 } 86 assert.Equal(t, test.config, actual) 87 }) 88 } 89 } 90 91 type fakeS3 struct { 92 s3.S3 93 bucket string 94 objects map[string]string 95 } 96 97 func (s *fakeS3) ListObjects(input *s3.ListObjectsInput) (*s3.ListObjectsOutput, error) { 98 bucket := aws.StringValue(input.Bucket) 99 if bucket != s.bucket { 100 return nil, errors.New("bucket not found") 101 } 102 prefix := aws.StringValue(input.Prefix) 103 delimiter := aws.StringValue(input.Delimiter) 104 var output s3.ListObjectsOutput 105 for name := range s.objects { 106 if strings.HasPrefix(name, prefix) { 107 if delimiter == "" || !strings.Contains(strings.Replace(name, prefix, "", 1), delimiter) { 108 output.Contents = append(output.Contents, &s3.Object{ 109 Key: aws.String(name), 110 }) 111 } 112 } 113 } 114 return &output, nil 115 } 116 117 func (s *fakeS3) GetObject(input *s3.GetObjectInput) (*s3.GetObjectOutput, error) { 118 bucket := aws.StringValue(input.Bucket) 119 if bucket != s.bucket { 120 return nil, errors.New("bucket not found") 121 } 122 if data, ok := s.objects[aws.StringValue(input.Key)]; ok { 123 body := ioutil.NopCloser(strings.NewReader(data)) 124 return &s3.GetObjectOutput{Body: body}, nil 125 } 126 return nil, errors.New("object not found") 127 }