github.com/dhui/migrate@v3.4.0+incompatible/source/aws_s3/s3_test.go (about)

     1  package awss3
     2  
import (
	"errors"
	"io/ioutil"
	"sort"
	"strings"
	"testing"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/service/s3"
	"github.com/golang-migrate/migrate/source"
	st "github.com/golang-migrate/migrate/source/testing"
)
    14  
    15  func Test(t *testing.T) {
    16  	s3Client := fakeS3{
    17  		bucket: "some-bucket",
    18  		objects: map[string]string{
    19  			"staging/migrations/1_foobar.up.sql":          "1 up",
    20  			"staging/migrations/1_foobar.down.sql":        "1 down",
    21  			"prod/migrations/1_foobar.up.sql":             "1 up",
    22  			"prod/migrations/1_foobar.down.sql":           "1 down",
    23  			"prod/migrations/3_foobar.up.sql":             "3 up",
    24  			"prod/migrations/4_foobar.up.sql":             "4 up",
    25  			"prod/migrations/4_foobar.down.sql":           "4 down",
    26  			"prod/migrations/5_foobar.down.sql":           "5 down",
    27  			"prod/migrations/7_foobar.up.sql":             "7 up",
    28  			"prod/migrations/7_foobar.down.sql":           "7 down",
    29  			"prod/migrations/not-a-migration.txt":         "",
    30  			"prod/migrations/0-random-stuff/whatever.txt": "",
    31  		},
    32  	}
    33  	driver := s3Driver{
    34  		bucket:     "some-bucket",
    35  		prefix:     "prod/migrations/",
    36  		migrations: source.NewMigrations(),
    37  		s3client:   &s3Client,
    38  	}
    39  	err := driver.loadMigrations()
    40  	if err != nil {
    41  		t.Fatal(err)
    42  	}
    43  	st.Test(t, &driver)
    44  }
    45  
    46  type fakeS3 struct {
    47  	s3.S3
    48  	bucket  string
    49  	objects map[string]string
    50  }
    51  
    52  func (s *fakeS3) ListObjects(input *s3.ListObjectsInput) (*s3.ListObjectsOutput, error) {
    53  	bucket := aws.StringValue(input.Bucket)
    54  	if bucket != s.bucket {
    55  		return nil, errors.New("bucket not found")
    56  	}
    57  	prefix := aws.StringValue(input.Prefix)
    58  	delimiter := aws.StringValue(input.Delimiter)
    59  	var output s3.ListObjectsOutput
    60  	for name := range s.objects {
    61  		if strings.HasPrefix(name, prefix) {
    62  			if delimiter == "" || !strings.Contains(strings.Replace(name, prefix, "", 1), delimiter) {
    63  				output.Contents = append(output.Contents, &s3.Object{
    64  					Key: aws.String(name),
    65  				})
    66  			}
    67  		}
    68  	}
    69  	return &output, nil
    70  }
    71  
    72  func (s *fakeS3) GetObject(input *s3.GetObjectInput) (*s3.GetObjectOutput, error) {
    73  	bucket := aws.StringValue(input.Bucket)
    74  	if bucket != s.bucket {
    75  		return nil, errors.New("bucket not found")
    76  	}
    77  	if data, ok := s.objects[aws.StringValue(input.Key)]; ok {
    78  		body := ioutil.NopCloser(strings.NewReader(data))
    79  		return &s3.GetObjectOutput{Body: body}, nil
    80  	}
    81  	return nil, errors.New("object not found")
    82  }