github.com/aavshr/aws-sdk-go@v1.41.3/service/s3/cust_integ_shared_test.go (about)

     1  //go:build integration
     2  // +build integration
     3  
     4  package s3_test
     5  
     6  import (
     7  	"bytes"
     8  	"context"
     9  	"crypto/tls"
    10  	"flag"
    11  	"fmt"
    12  	"io"
    13  	"io/ioutil"
    14  	"net/http"
    15  	"os"
    16  	"reflect"
    17  	"strings"
    18  	"testing"
    19  	"time"
    20  
    21  	"github.com/aavshr/aws-sdk-go/aws"
    22  	"github.com/aavshr/aws-sdk-go/aws/arn"
    23  	"github.com/aavshr/aws-sdk-go/aws/endpoints"
    24  	"github.com/aavshr/aws-sdk-go/aws/request"
    25  	"github.com/aavshr/aws-sdk-go/awstesting/integration"
    26  	"github.com/aavshr/aws-sdk-go/awstesting/integration/s3integ"
    27  	"github.com/aavshr/aws-sdk-go/service/s3"
    28  	"github.com/aavshr/aws-sdk-go/service/s3control"
    29  	"github.com/aavshr/aws-sdk-go/service/sts"
    30  )
    31  
    32  const integBucketPrefix = "aws-sdk-go-integration"
    33  
    34  var integMetadata = struct {
    35  	AccountID string
    36  	Region    string
    37  	Buckets   struct {
    38  		Source struct {
    39  			Name string
    40  			ARN  string
    41  		}
    42  		Target struct {
    43  			Name string
    44  			ARN  string
    45  		}
    46  	}
    47  
    48  	AccessPoints struct {
    49  		Source struct {
    50  			Name string
    51  			ARN  string
    52  		}
    53  		Target struct {
    54  			Name string
    55  			ARN  string
    56  		}
    57  	}
    58  }{}
    59  
    60  var s3Svc *s3.S3
    61  var s3ControlSvc *s3control.S3Control
    62  var stsSvc *sts.STS
    63  var httpClient *http.Client
    64  
    65  // TODO: (Westeros) Remove Custom Resolver Usage Before Launch
    66  type customS3Resolver struct {
    67  	endpoint string
    68  	withTLS  bool
    69  	region   string
    70  }
    71  
    72  func (r customS3Resolver) EndpointFor(service, _ string, opts ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) {
    73  	switch strings.ToLower(service) {
    74  	case "s3-control":
    75  	case "s3":
    76  	default:
    77  		return endpoints.ResolvedEndpoint{}, fmt.Errorf("unsupported in custom resolver")
    78  	}
    79  
    80  	return endpoints.ResolvedEndpoint{
    81  		PartitionID:   "aws",
    82  		SigningRegion: r.region,
    83  		SigningName:   "s3",
    84  		SigningMethod: "s3v4",
    85  		URL:           endpoints.AddScheme(r.endpoint, r.withTLS),
    86  	}, nil
    87  }
    88  
    89  func TestMain(m *testing.M) {
    90  	var result int
    91  	defer func() {
    92  		if r := recover(); r != nil {
    93  			fmt.Fprintln(os.Stderr, "S3 integration tests paniced,", r)
    94  			result = 1
    95  		}
    96  		os.Exit(result)
    97  	}()
    98  
    99  	var verifyTLS bool
   100  	var s3Endpoint, s3ControlEndpoint string
   101  	var s3EnableTLS, s3ControlEnableTLS bool
   102  
   103  	flag.StringVar(&s3Endpoint, "s3-endpoint", "", "integration endpoint for S3")
   104  	flag.BoolVar(&s3EnableTLS, "s3-tls", true, "enable TLS for S3 endpoint")
   105  
   106  	flag.StringVar(&s3ControlEndpoint, "s3-control-endpoint", "", "integration endpoint for S3")
   107  	flag.BoolVar(&s3ControlEnableTLS, "s3-control-tls", true, "enable TLS for S3 control endpoint")
   108  
   109  	flag.StringVar(&integMetadata.AccountID, "account", "", "integration account id")
   110  	flag.BoolVar(&verifyTLS, "verify-tls", true, "verify server TLS certificate")
   111  	flag.Parse()
   112  
   113  	httpClient = &http.Client{
   114  		Transport: &http.Transport{
   115  			TLSClientConfig: &tls.Config{InsecureSkipVerify: verifyTLS},
   116  		}}
   117  
   118  	sess := integration.SessionWithDefaultRegion("us-west-2").Copy(&aws.Config{
   119  		HTTPClient: httpClient,
   120  	})
   121  
   122  	var s3EndpointResolver endpoints.Resolver
   123  	if len(s3Endpoint) != 0 {
   124  		s3EndpointResolver = customS3Resolver{
   125  			endpoint: s3Endpoint,
   126  			withTLS:  s3EnableTLS,
   127  			region:   aws.StringValue(sess.Config.Region),
   128  		}
   129  	}
   130  	s3Svc = s3.New(sess, &aws.Config{
   131  		DisableSSL:       aws.Bool(!s3EnableTLS),
   132  		EndpointResolver: s3EndpointResolver,
   133  	})
   134  
   135  	var s3ControlEndpointResolver endpoints.Resolver
   136  	if len(s3Endpoint) != 0 {
   137  		s3ControlEndpointResolver = customS3Resolver{
   138  			endpoint: s3ControlEndpoint,
   139  			withTLS:  s3ControlEnableTLS,
   140  			region:   aws.StringValue(sess.Config.Region),
   141  		}
   142  	}
   143  	s3ControlSvc = s3control.New(sess, &aws.Config{
   144  		DisableSSL:       aws.Bool(!s3ControlEnableTLS),
   145  		EndpointResolver: s3ControlEndpointResolver,
   146  	})
   147  	stsSvc = sts.New(sess)
   148  
   149  	var err error
   150  	integMetadata.AccountID, err = getAccountID()
   151  	if err != nil {
   152  		fmt.Fprintf(os.Stderr, "failed to get integration aws account id: %v\n", err)
   153  		result = 1
   154  		return
   155  	}
   156  
   157  	bucketCleanup, err := setupBuckets()
   158  	defer bucketCleanup()
   159  	if err != nil {
   160  		fmt.Fprintf(os.Stderr, "failed to setup integration test buckets: %v\n", err)
   161  		result = 1
   162  		return
   163  	}
   164  
   165  	accessPointsCleanup, err := setupAccessPoints()
   166  	defer accessPointsCleanup()
   167  	if err != nil {
   168  		fmt.Fprintf(os.Stderr, "failed to setup integration test access points: %v\n", err)
   169  		result = 1
   170  		return
   171  	}
   172  
   173  	result = m.Run()
   174  }
   175  
   176  func getAccountID() (string, error) {
   177  	if len(integMetadata.AccountID) != 0 {
   178  		return integMetadata.AccountID, nil
   179  	}
   180  
   181  	output, err := stsSvc.GetCallerIdentity(nil)
   182  	if err != nil {
   183  		return "", fmt.Errorf("faield to get sts caller identity")
   184  	}
   185  
   186  	return *output.Account, nil
   187  }
   188  
   189  func setupBuckets() (func(), error) {
   190  	var cleanups []func()
   191  
   192  	cleanup := func() {
   193  		for i := range cleanups {
   194  			cleanups[i]()
   195  		}
   196  	}
   197  
   198  	bucketCreates := []struct {
   199  		name *string
   200  		arn  *string
   201  	}{
   202  		{name: &integMetadata.Buckets.Source.Name, arn: &integMetadata.Buckets.Source.ARN},
   203  		{name: &integMetadata.Buckets.Target.Name, arn: &integMetadata.Buckets.Target.ARN},
   204  	}
   205  
   206  	for _, bucket := range bucketCreates {
   207  		*bucket.name = s3integ.GenerateBucketName()
   208  
   209  		if err := s3integ.SetupBucket(s3Svc, *bucket.name); err != nil {
   210  			return cleanup, err
   211  		}
   212  
   213  		// Compute ARN
   214  		bARN := arn.ARN{
   215  			Partition: "aws",
   216  			Service:   "s3",
   217  			Region:    s3Svc.SigningRegion,
   218  			AccountID: integMetadata.AccountID,
   219  			Resource:  fmt.Sprintf("bucket_name:%s", *bucket.name),
   220  		}.String()
   221  
   222  		*bucket.arn = bARN
   223  
   224  		bucketName := *bucket.name
   225  		cleanups = append(cleanups, func() {
   226  			if err := s3integ.CleanupBucket(s3Svc, bucketName); err != nil {
   227  				fmt.Fprintln(os.Stderr, err)
   228  			}
   229  		})
   230  	}
   231  
   232  	return cleanup, nil
   233  }
   234  
   235  func setupAccessPoints() (func(), error) {
   236  	var cleanups []func()
   237  
   238  	cleanup := func() {
   239  		for i := range cleanups {
   240  			cleanups[i]()
   241  		}
   242  	}
   243  
   244  	creates := []struct {
   245  		bucket string
   246  		name   *string
   247  		arn    *string
   248  	}{
   249  		{bucket: integMetadata.Buckets.Source.Name, name: &integMetadata.AccessPoints.Source.Name, arn: &integMetadata.AccessPoints.Source.ARN},
   250  		{bucket: integMetadata.Buckets.Target.Name, name: &integMetadata.AccessPoints.Target.Name, arn: &integMetadata.AccessPoints.Target.ARN},
   251  	}
   252  
   253  	for _, ap := range creates {
   254  		*ap.name = integration.UniqueID()
   255  
   256  		err := s3integ.SetupAccessPoint(s3ControlSvc, integMetadata.AccountID, ap.bucket, *ap.name)
   257  		if err != nil {
   258  			return cleanup, err
   259  		}
   260  
   261  		// Compute ARN
   262  		apARN := arn.ARN{
   263  			Partition: "aws",
   264  			Service:   "s3",
   265  			Region:    s3ControlSvc.SigningRegion,
   266  			AccountID: integMetadata.AccountID,
   267  			Resource:  fmt.Sprintf("accesspoint/%s", *ap.name),
   268  		}.String()
   269  
   270  		*ap.arn = apARN
   271  
   272  		apName := *ap.name
   273  		cleanups = append(cleanups, func() {
   274  			err := s3integ.CleanupAccessPoint(s3ControlSvc, integMetadata.AccountID, apName)
   275  			if err != nil {
   276  				fmt.Fprintln(os.Stderr, err)
   277  			}
   278  		})
   279  	}
   280  
   281  	return cleanup, nil
   282  }
   283  
   284  func putTestFile(t *testing.T, filename, key string, opts ...request.Option) {
   285  	f, err := os.Open(filename)
   286  	if err != nil {
   287  		t.Fatalf("failed to open testfile, %v", err)
   288  	}
   289  	defer f.Close()
   290  
   291  	putTestContent(t, f, key, opts...)
   292  }
   293  
   294  func putTestContent(t *testing.T, reader io.ReadSeeker, key string, opts ...request.Option) {
   295  	t.Logf("uploading test file %s/%s", integMetadata.Buckets.Source.Name, key)
   296  	_, err := s3Svc.PutObjectWithContext(context.Background(),
   297  		&s3.PutObjectInput{
   298  			Bucket: &integMetadata.Buckets.Source.Name,
   299  			Key:    aws.String(key),
   300  			Body:   reader,
   301  		}, opts...)
   302  	if err != nil {
   303  		t.Errorf("expect no error, got %v", err)
   304  	}
   305  }
   306  
   307  func testWriteToObject(t *testing.T, bucket string, opts ...request.Option) {
   308  	key := integration.UniqueID()
   309  
   310  	_, err := s3Svc.PutObjectWithContext(context.Background(),
   311  		&s3.PutObjectInput{
   312  			Bucket: &bucket,
   313  			Key:    &key,
   314  			Body:   bytes.NewReader([]byte("hello world")),
   315  		}, opts...)
   316  	if err != nil {
   317  		t.Fatalf("expect no error, got %v", err)
   318  	}
   319  
   320  	resp, err := s3Svc.GetObjectWithContext(context.Background(),
   321  		&s3.GetObjectInput{
   322  			Bucket: &bucket,
   323  			Key:    &key,
   324  		}, opts...)
   325  	if err != nil {
   326  		t.Fatalf("expect no error, got %v", err)
   327  	}
   328  
   329  	b, _ := ioutil.ReadAll(resp.Body)
   330  	if e, a := []byte("hello world"), b; !bytes.Equal(e, a) {
   331  		t.Errorf("expect %v, got %v", e, a)
   332  	}
   333  }
   334  
   335  func testPresignedGetPut(t *testing.T, bucket string, opts ...request.Option) {
   336  	key := integration.UniqueID()
   337  
   338  	putreq, _ := s3Svc.PutObjectRequest(&s3.PutObjectInput{
   339  		Bucket: &bucket,
   340  		Key:    &key,
   341  	})
   342  	putreq.ApplyOptions(opts...)
   343  	var err error
   344  
   345  	// Presign a PUT request
   346  	var puturl string
   347  	puturl, err = putreq.Presign(5 * time.Minute)
   348  	if err != nil {
   349  		t.Fatalf("expect no error, got %v", err)
   350  	}
   351  
   352  	// PUT to the presigned URL with a body
   353  	var puthttpreq *http.Request
   354  	buf := bytes.NewReader([]byte("hello world"))
   355  	puthttpreq, err = http.NewRequest("PUT", puturl, buf)
   356  	if err != nil {
   357  		t.Fatalf("expect no error, got %v", err)
   358  	}
   359  
   360  	var putresp *http.Response
   361  	putresp, err = httpClient.Do(puthttpreq)
   362  	if err != nil {
   363  		t.Errorf("expect put with presign url no error, got %v", err)
   364  	}
   365  	if e, a := 200, putresp.StatusCode; e != a {
   366  		t.Fatalf("expect %v, got %v", e, a)
   367  	}
   368  
   369  	// Presign a GET on the same URL
   370  	getreq, _ := s3Svc.GetObjectRequest(&s3.GetObjectInput{
   371  		Bucket: &bucket,
   372  		Key:    &key,
   373  	})
   374  	getreq.ApplyOptions(opts...)
   375  
   376  	var geturl string
   377  	geturl, err = getreq.Presign(300 * time.Second)
   378  	if err != nil {
   379  		t.Fatalf("expect no error, got %v", err)
   380  	}
   381  
   382  	// Get the body
   383  	var getresp *http.Response
   384  	getresp, err = httpClient.Get(geturl)
   385  	if err != nil {
   386  		t.Fatalf("expect no error, got %v", err)
   387  	}
   388  
   389  	var b []byte
   390  	defer getresp.Body.Close()
   391  	b, err = ioutil.ReadAll(getresp.Body)
   392  	if e, a := "hello world", string(b); e != a {
   393  		t.Fatalf("expect %v, got %v", e, a)
   394  	}
   395  }
   396  
   397  func testCopyObject(t *testing.T, sourceBucket string, targetBucket string, opts ...request.Option) {
   398  	key := integration.UniqueID()
   399  
   400  	_, err := s3Svc.PutObjectWithContext(context.Background(),
   401  		&s3.PutObjectInput{
   402  			Bucket: &sourceBucket,
   403  			Key:    &key,
   404  			Body:   bytes.NewReader([]byte("hello world")),
   405  		}, opts...)
   406  	if err != nil {
   407  		t.Fatalf("expect no error, got %v", err)
   408  	}
   409  
   410  	_, err = s3Svc.CopyObjectWithContext(context.Background(),
   411  		&s3.CopyObjectInput{
   412  			Bucket:     &targetBucket,
   413  			CopySource: aws.String("/" + sourceBucket + "/" + key),
   414  			Key:        &key,
   415  		}, opts...)
   416  	if err != nil {
   417  		t.Fatalf("expect no error, got %v", err)
   418  	}
   419  
   420  	resp, err := s3Svc.GetObjectWithContext(context.Background(),
   421  		&s3.GetObjectInput{
   422  			Bucket: &targetBucket,
   423  			Key:    &key,
   424  		}, opts...)
   425  	if err != nil {
   426  		t.Fatalf("expect no error, got %v", err)
   427  	}
   428  
   429  	b, _ := ioutil.ReadAll(resp.Body)
   430  	if e, a := []byte("hello world"), b; !reflect.DeepEqual(e, a) {
   431  		t.Errorf("expect %v, got %v", e, a)
   432  	}
   433  }