golang.org/x/build@v0.0.0-20240506185731-218518f32b70/internal/cloud/aws_interceptor_test.go (about)

     1  // Copyright 2020 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package cloud
     6  
     7  import (
     8  	"context"
     9  	"errors"
    10  	"sync/atomic"
    11  	"testing"
    12  
    13  	"github.com/aws/aws-sdk-go/aws"
    14  	"github.com/aws/aws-sdk-go/aws/request"
    15  	"github.com/aws/aws-sdk-go/service/ec2"
    16  )
    17  
// Compile-time assertion that *fakeRateLimiter satisfies the rateLimiter
// interface, so it can be installed in EC2RateLimitInterceptor's limiter
// fields by the tests in this file.
var _ rateLimiter = (*fakeRateLimiter)(nil)
    19  
    20  var rateExceededErr = errors.New("rate limit exceeded")
    21  
    22  type fakeRateLimiter struct {
    23  	waitCalledCount int64
    24  	waitCallLimit   int64
    25  }
    26  
    27  func newFakeRateLimiter(limit int64) *fakeRateLimiter {
    28  	return &fakeRateLimiter{waitCallLimit: limit}
    29  }
    30  
    31  func (frl *fakeRateLimiter) Wait(ctx context.Context) (err error) {
    32  	return frl.WaitN(ctx, 1)
    33  }
    34  
    35  func (frl *fakeRateLimiter) WaitN(ctx context.Context, n int) (err error) {
    36  	count := atomic.AddInt64(&frl.waitCalledCount, int64(n))
    37  	if count > frl.waitCallLimit {
    38  		return rateExceededErr
    39  	}
    40  	return nil
    41  }
    42  
    43  func (frl *fakeRateLimiter) called() bool {
    44  	if atomic.LoadInt64(&frl.waitCalledCount) > 0 {
    45  		return true
    46  	}
    47  	return false
    48  }
    49  
    50  type noopEC2Client struct {
    51  	t *testing.T
    52  }
    53  
    54  func (f *noopEC2Client) DescribeInstancesPagesWithContext(ctx context.Context, input *ec2.DescribeInstancesInput, fn func(*ec2.DescribeInstancesOutput, bool) bool, opt ...request.Option) error {
    55  	if ctx == nil || input == nil || fn == nil || len(opt) != 1 {
    56  		f.t.Fatal("DescribeInstancesPagesWithContext params not passed down")
    57  	}
    58  	return nil
    59  }
    60  
    61  func (f *noopEC2Client) DescribeInstancesWithContext(ctx context.Context, input *ec2.DescribeInstancesInput, opt ...request.Option) (*ec2.DescribeInstancesOutput, error) {
    62  	if ctx == nil || input == nil || len(opt) != 1 {
    63  		f.t.Fatal("DescribeInstancesWithContext params not passed down")
    64  	}
    65  	return nil, nil
    66  }
    67  
    68  func (f *noopEC2Client) RunInstancesWithContext(ctx context.Context, input *ec2.RunInstancesInput, opts ...request.Option) (*ec2.Reservation, error) {
    69  	if ctx == nil || input == nil || len(opts) != 1 {
    70  		f.t.Fatal("RunInstancesWithContext params not passed down")
    71  	}
    72  	if ctx.Err() != nil {
    73  		f.t.Fatalf("context.Err() = %s; want no error", ctx.Err())
    74  	}
    75  	return nil, nil
    76  }
    77  
    78  func (f *noopEC2Client) TerminateInstancesWithContext(ctx context.Context, input *ec2.TerminateInstancesInput, opts ...request.Option) (*ec2.TerminateInstancesOutput, error) {
    79  	if ctx == nil || input == nil || len(opts) != 1 {
    80  		f.t.Fatal("TerminateInstancesWithContext params not passed down")
    81  	}
    82  	if ctx.Err() != nil {
    83  		f.t.Fatalf("context.Err() = %s; want no error", ctx.Err())
    84  	}
    85  	return nil, nil
    86  }
    87  
    88  func (f *noopEC2Client) WaitUntilInstanceRunningWithContext(ctx context.Context, input *ec2.DescribeInstancesInput, opt ...request.WaiterOption) error {
    89  	if ctx == nil || input == nil || len(opt) != 1 {
    90  		f.t.Fatal("WaitUntilInstanceRunningWithContext params not passed down")
    91  	}
    92  	return nil
    93  }
    94  
    95  func (f *noopEC2Client) DescribeInstanceTypesPagesWithContext(ctx context.Context, input *ec2.DescribeInstanceTypesInput, fn func(*ec2.DescribeInstanceTypesOutput, bool) bool, opt ...request.Option) error {
    96  	if ctx == nil || input == nil || fn == nil || len(opt) != 1 {
    97  		f.t.Fatal("DescribeInstancesPagesWithContext params not passed down")
    98  	}
    99  	return nil
   100  }
   101  
   102  func TestEC2RateLimitInterceptorDescribeInstancesPagesWithContext(t *testing.T) {
   103  	rate := newFakeRateLimiter(1)
   104  	i := &EC2RateLimitInterceptor{
   105  		next:            &noopEC2Client{t: t},
   106  		nonMutatingRate: rate,
   107  	}
   108  	fn := func() error {
   109  		return i.DescribeInstancesPagesWithContext(context.Background(), &ec2.DescribeInstancesInput{}, func(*ec2.DescribeInstancesOutput, bool) bool { return true }, request.WithAppendUserAgent("test-agent"))
   110  	}
   111  	if err := fn(); err != nil {
   112  		t.Fatalf("DescribeInstancesPagesWithContext(...) = nil, %s; want no error", err)
   113  	}
   114  	if !rate.called() {
   115  		t.Error("rateLimiter.Wait() was never called")
   116  	}
   117  	if err := fn(); err != rateExceededErr {
   118  		t.Errorf("DescribeInstancesPagesWithContext(...) = %s; want %s", err, rateExceededErr)
   119  	}
   120  }
   121  
   122  func TestEC2RateLimitInterceptorDescribeInstancesWithContext(t *testing.T) {
   123  	rate := newFakeRateLimiter(1)
   124  	i := &EC2RateLimitInterceptor{
   125  		next:            &noopEC2Client{t: t},
   126  		nonMutatingRate: rate,
   127  	}
   128  	fn := func() error {
   129  		_, err := i.DescribeInstancesWithContext(context.Background(), &ec2.DescribeInstancesInput{}, request.WithAppendUserAgent("test-agent"))
   130  		return err
   131  	}
   132  	if err := fn(); err != nil {
   133  		t.Fatalf("DescribeInstancesWithContext(...) = nil, %s; want no error", err)
   134  	}
   135  	if !rate.called() {
   136  		t.Errorf("rateLimiter.Wait() was never called")
   137  	}
   138  	if err := fn(); err != rateExceededErr {
   139  		t.Errorf("DescribeInstancesWithContext(...) = nil, %s; want nil, %s", err, rateExceededErr)
   140  	}
   141  }
   142  
   143  func TestEC2RateLimitInterceptorRunInstancesWithContext(t *testing.T) {
   144  	rate := newFakeRateLimiter(1)
   145  	resource := newFakeRateLimiter(1)
   146  	i := &EC2RateLimitInterceptor{
   147  		next:                 &noopEC2Client{t: t},
   148  		runInstancesRate:     rate,
   149  		runInstancesResource: resource,
   150  	}
   151  	fn := func() error {
   152  		_, err := i.RunInstancesWithContext(context.Background(), &ec2.RunInstancesInput{
   153  			MaxCount: aws.Int64(1),
   154  		}, request.WithAppendUserAgent("test-agent"))
   155  		return err
   156  	}
   157  	if err := fn(); err != nil {
   158  		t.Fatalf("RunInstancesWithContext(...) = nil, %s; want no error", err)
   159  	}
   160  	if !rate.called() || !resource.called() {
   161  		t.Errorf("rateLimiter.Wait() was never called; rate=%t, resource=%t", rate.called(), resource.called())
   162  	}
   163  	if err := fn(); err != rateExceededErr {
   164  		t.Errorf("RunInstancesWithContext(...) = nil, %s; want nil, %s", err, rateExceededErr)
   165  	}
   166  }
   167  
   168  func TestEC2RateLimitInterceptorTerminateInstancesWithContext(t *testing.T) {
   169  	rate := newFakeRateLimiter(1)
   170  	resource := newFakeRateLimiter(1)
   171  	i := &EC2RateLimitInterceptor{
   172  		next:                      &noopEC2Client{t: t},
   173  		mutatingRate:              rate,
   174  		terminateInstanceResource: resource,
   175  	}
   176  	fn := func() error {
   177  		_, err := i.TerminateInstancesWithContext(context.Background(), &ec2.TerminateInstancesInput{
   178  			InstanceIds: []*string{aws.String("foo")},
   179  		}, request.WithAppendUserAgent("test-agent"))
   180  		return err
   181  	}
   182  	if err := fn(); err != nil {
   183  		t.Fatalf("TerminateInstancesWithContext(...) = nil, %s; want no error", err)
   184  	}
   185  	if !rate.called() || !resource.called() {
   186  		t.Errorf("rateLimiter.Wait() was never called; rate=%t, resource=%t", rate.called(), resource.called())
   187  	}
   188  	if err := fn(); err != rateExceededErr {
   189  		t.Errorf("TerminateInstancesWithContext(...) = nil, %s; want nil, %s", err, rateExceededErr)
   190  	}
   191  }
   192  
   193  func TestEC2RateLimitInterceptorWaitUntilInstanceRunningWithContext(t *testing.T) {
   194  	rate := newFakeRateLimiter(1)
   195  	i := &EC2RateLimitInterceptor{
   196  		next:            &noopEC2Client{t: t},
   197  		nonMutatingRate: rate,
   198  	}
   199  	fn := func() error {
   200  		return i.WaitUntilInstanceRunningWithContext(context.Background(), &ec2.DescribeInstancesInput{}, request.WithWaiterMaxAttempts(1))
   201  	}
   202  	if err := fn(); err != nil {
   203  		t.Fatalf("WaitUntilInstanceRunningWithContext(...) = nil, %s; want no error", err)
   204  	}
   205  	if !rate.called() {
   206  		t.Errorf("rateLimiter.Wait() was never called")
   207  	}
   208  	if err := fn(); err != rateExceededErr {
   209  		t.Errorf("WaitUntilInstanceRunningWithContext(...) = nil, %s; want nil, %s", err, rateExceededErr)
   210  	}
   211  }
   212  
   213  func TestEC2RateLimitInterceptorDescribeInstanceTypesPagesWithContext(t *testing.T) {
   214  	rate := newFakeRateLimiter(1)
   215  	i := &EC2RateLimitInterceptor{
   216  		next:            &noopEC2Client{t: t},
   217  		nonMutatingRate: rate,
   218  	}
   219  	fn := func() error {
   220  		return i.DescribeInstanceTypesPagesWithContext(context.Background(), &ec2.DescribeInstanceTypesInput{}, func(*ec2.DescribeInstanceTypesOutput, bool) bool { return true }, request.WithAppendUserAgent("test-agent"))
   221  	}
   222  	if err := fn(); err != nil {
   223  		t.Fatalf("DescribeInstanceTypesPagesWithContext(...) = nil, %s; want no error", err)
   224  	}
   225  	if !rate.called() {
   226  		t.Errorf("rateLimiter.Wait() was never called")
   227  	}
   228  	if err := fn(); err != rateExceededErr {
   229  		t.Errorf("DescribeInstanceTypesPagesWithContext(...) = nil, %s; want nil, %s", err, rateExceededErr)
   230  	}
   231  }