github.com/juju/juju@v0.0.0-20240430160146-1752b71fcf00/provider/ec2/instance_test.go (about)

     1  // Copyright 2023 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package ec2
     5  
     6  import (
     7  	stdcontext "context"
     8  
     9  	"github.com/aws/aws-sdk-go-v2/aws"
    10  	"github.com/aws/aws-sdk-go-v2/service/ec2"
    11  	"github.com/aws/aws-sdk-go-v2/service/ec2/types"
    12  	jc "github.com/juju/testing/checkers"
    13  	gc "gopkg.in/check.v1"
    14  
    15  	"github.com/juju/juju/environs/context"
    16  )
    17  
// fetchInstanceClientFunc is a function type with the same signature as the
// DescribeInstanceTypes method declared on it in this file. Converting a plain
// closure to this type lets the closure be passed where a value with that
// method is expected, as the test in this file does with FetchInstanceTypeInfo.
type fetchInstanceClientFunc func(stdcontext.Context, *ec2.DescribeInstanceTypesInput, ...func(*ec2.Options)) (*ec2.DescribeInstanceTypesOutput, error)
    19  
// instanceSuite is the gocheck suite that holds the tests in this file. It
// has no fields, so tests share no state through it.
type instanceSuite struct{}

// Register instanceSuite with the gocheck runner.
var _ = gc.Suite(&instanceSuite{})
    23  
    24  func (f fetchInstanceClientFunc) DescribeInstanceTypes(
    25  	c stdcontext.Context,
    26  	i *ec2.DescribeInstanceTypesInput,
    27  	o ...func(*ec2.Options),
    28  ) (*ec2.DescribeInstanceTypesOutput, error) {
    29  	return f(c, i, o...)
    30  }
    31  
    32  func (s *instanceSuite) TestFetchInstanceTypeInfoPagnation(c *gc.C) {
    33  	callCount := 0
    34  	client := func(
    35  		_ stdcontext.Context,
    36  		i *ec2.DescribeInstanceTypesInput,
    37  		o ...func(*ec2.Options),
    38  	) (*ec2.DescribeInstanceTypesOutput, error) {
    39  		if callCount != 0 {
    40  			c.Assert(*i.NextToken, gc.Equals, "next")
    41  		}
    42  		c.Assert(*i.MaxResults, gc.Equals, int32(100))
    43  
    44  		callCount++
    45  		nextToken := aws.String("next")
    46  		// Let 6 calls happen
    47  		if callCount == 6 {
    48  			nextToken = nil
    49  		}
    50  
    51  		return &ec2.DescribeInstanceTypesOutput{
    52  			InstanceTypes: make([]types.InstanceTypeInfo, 100),
    53  			NextToken:     nextToken,
    54  		}, nil
    55  	}
    56  
    57  	res, err := FetchInstanceTypeInfo(
    58  		context.NewCloudCallContext(stdcontext.Background()),
    59  		fetchInstanceClientFunc(client),
    60  	)
    61  	c.Assert(err, jc.ErrorIsNil)
    62  	c.Assert(len(res), gc.Equals, 600)
    63  }