github.com/aavshr/aws-sdk-go@v1.41.3/service/neptune/customizations_test.go (about)

     1  //go:build go1.9
     2  // +build go1.9
     3  
     4  package neptune
     5  
     6  import (
     7  	"fmt"
     8  	"io/ioutil"
     9  	"net/url"
    10  	"regexp"
    11  	"testing"
    12  	"time"
    13  
    14  	"github.com/aavshr/aws-sdk-go/aws"
    15  	"github.com/aavshr/aws-sdk-go/aws/request"
    16  	"github.com/aavshr/aws-sdk-go/awstesting"
    17  	"github.com/aavshr/aws-sdk-go/awstesting/unit"
    18  )
    19  
    20  func TestCopyDBClusterSnapshotRequestNoPanic(t *testing.T) {
    21  	svc := New(unit.Session, &aws.Config{Region: aws.String("us-west-2")})
    22  
    23  	f := func() {
    24  		// Doesn't panic on nil input
    25  		req, _ := svc.CopyDBClusterSnapshotRequest(nil)
    26  		req.Sign()
    27  	}
    28  	if paniced, p := awstesting.DidPanic(f); paniced {
    29  		t.Errorf("expect no panic, got %v", p)
    30  	}
    31  }
    32  
    33  func TestPresignCrossRegionRequest(t *testing.T) {
    34  	const targetRegion = "us-west-2"
    35  
    36  	svc := New(unit.Session, &aws.Config{Region: aws.String(targetRegion)})
    37  
    38  	const regexPattern = `^https://rds.us-west-1\.amazonaws\.com/\?Action=%s.+?DestinationRegion=%s.+`
    39  
    40  	cases := map[string]struct {
    41  		Req    *request.Request
    42  		Assert func(*testing.T, string)
    43  	}{
    44  		opCopyDBClusterSnapshot: {
    45  			Req: func() *request.Request {
    46  				req, _ := svc.CopyDBClusterSnapshotRequest(
    47  					&CopyDBClusterSnapshotInput{
    48  						SourceRegion:                      aws.String("us-west-1"),
    49  						SourceDBClusterSnapshotIdentifier: aws.String("foo"),
    50  						TargetDBClusterSnapshotIdentifier: aws.String("bar"),
    51  					})
    52  				return req
    53  			}(),
    54  			Assert: assertAsRegexMatch(fmt.Sprintf(regexPattern,
    55  				opCopyDBClusterSnapshot, targetRegion)),
    56  		},
    57  		opCreateDBCluster: {
    58  			Req: func() *request.Request {
    59  				req, _ := svc.CreateDBClusterRequest(
    60  					&CreateDBClusterInput{
    61  						SourceRegion:        aws.String("us-west-1"),
    62  						DBClusterIdentifier: aws.String("foo"),
    63  						Engine:              aws.String("bar"),
    64  					})
    65  				return req
    66  			}(),
    67  			Assert: assertAsRegexMatch(fmt.Sprintf(regexPattern,
    68  				opCreateDBCluster, targetRegion)),
    69  		},
    70  		opCopyDBClusterSnapshot + " same region": {
    71  			Req: func() *request.Request {
    72  				req, _ := svc.CopyDBClusterSnapshotRequest(
    73  					&CopyDBClusterSnapshotInput{
    74  						SourceRegion:                      aws.String("us-west-2"),
    75  						SourceDBClusterSnapshotIdentifier: aws.String("foo"),
    76  						TargetDBClusterSnapshotIdentifier: aws.String("bar"),
    77  					})
    78  				return req
    79  			}(),
    80  			Assert: assertAsEmpty(),
    81  		},
    82  		opCreateDBCluster + " same region": {
    83  			Req: func() *request.Request {
    84  				req, _ := svc.CreateDBClusterRequest(
    85  					&CreateDBClusterInput{
    86  						SourceRegion:        aws.String("us-west-2"),
    87  						DBClusterIdentifier: aws.String("foo"),
    88  						Engine:              aws.String("bar"),
    89  					})
    90  				return req
    91  			}(),
    92  			Assert: assertAsEmpty(),
    93  		},
    94  		opCopyDBClusterSnapshot + " presignURL set": {
    95  			Req: func() *request.Request {
    96  				req, _ := svc.CopyDBClusterSnapshotRequest(
    97  					&CopyDBClusterSnapshotInput{
    98  						SourceRegion:                      aws.String("us-west-1"),
    99  						SourceDBClusterSnapshotIdentifier: aws.String("foo"),
   100  						TargetDBClusterSnapshotIdentifier: aws.String("bar"),
   101  						PreSignedUrl:                      aws.String("mockPresignedURL"),
   102  					})
   103  				return req
   104  			}(),
   105  			Assert: assertAsEqual("mockPresignedURL"),
   106  		},
   107  		opCreateDBCluster + " presignURL set": {
   108  			Req: func() *request.Request {
   109  				req, _ := svc.CreateDBClusterRequest(
   110  					&CreateDBClusterInput{
   111  						SourceRegion:        aws.String("us-west-1"),
   112  						DBClusterIdentifier: aws.String("foo"),
   113  						Engine:              aws.String("bar"),
   114  						PreSignedUrl:        aws.String("mockPresignedURL"),
   115  					})
   116  				return req
   117  			}(),
   118  			Assert: assertAsEqual("mockPresignedURL"),
   119  		},
   120  	}
   121  
   122  	for name, c := range cases {
   123  		t.Run(name, func(t *testing.T) {
   124  			if err := c.Req.Sign(); err != nil {
   125  				t.Fatalf("expect no error, got %v", err)
   126  			}
   127  			b, _ := ioutil.ReadAll(c.Req.HTTPRequest.Body)
   128  			q, _ := url.ParseQuery(string(b))
   129  
   130  			u, _ := url.QueryUnescape(q.Get("PreSignedUrl"))
   131  
   132  			c.Assert(t, u)
   133  		})
   134  	}
   135  }
   136  
   137  func TestPresignWithSourceNotSet(t *testing.T) {
   138  	reqs := map[string]*request.Request{}
   139  	svc := New(unit.Session, &aws.Config{Region: aws.String("us-west-2")})
   140  
   141  	reqs[opCopyDBClusterSnapshot], _ = svc.CopyDBClusterSnapshotRequest(&CopyDBClusterSnapshotInput{
   142  		SourceDBClusterSnapshotIdentifier: aws.String("foo"),
   143  		TargetDBClusterSnapshotIdentifier: aws.String("bar"),
   144  	})
   145  
   146  	for _, req := range reqs {
   147  		_, err := req.Presign(5 * time.Minute)
   148  		if err != nil {
   149  			t.Fatal(err)
   150  		}
   151  	}
   152  }
   153  
   154  func assertAsRegexMatch(exp string) func(*testing.T, string) {
   155  	return func(t *testing.T, v string) {
   156  		t.Helper()
   157  
   158  		if re, a := regexp.MustCompile(exp), v; !re.MatchString(a) {
   159  			t.Errorf("expect %s to match %s", re, a)
   160  		}
   161  	}
   162  }
   163  
   164  func assertAsEmpty() func(*testing.T, string) {
   165  	return func(t *testing.T, v string) {
   166  		t.Helper()
   167  
   168  		if len(v) != 0 {
   169  			t.Errorf("expect empty, got %v", v)
   170  		}
   171  	}
   172  }
   173  
   174  func assertAsEqual(expect string) func(*testing.T, string) {
   175  	return func(t *testing.T, v string) {
   176  		t.Helper()
   177  
   178  		if e, a := expect, v; e != a {
   179  			t.Errorf("expect %v, got %v", e, a)
   180  		}
   181  	}
   182  }