golang.org/x/build@v0.0.0-20240506185731-218518f32b70/internal/cloud/aws_test.go (about)

     1  // Copyright 2020 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package cloud
     6  
     7  import (
     8  	"context"
     9  	"encoding/base64"
    10  	"encoding/json"
    11  	"errors"
    12  	"fmt"
    13  	"sync"
    14  	"testing"
    15  	"time"
    16  
    17  	"github.com/aws/aws-sdk-go/aws"
    18  	"github.com/aws/aws-sdk-go/aws/request"
    19  	"github.com/aws/aws-sdk-go/service/ec2"
    20  	"github.com/aws/aws-sdk-go/service/servicequotas"
    21  	"github.com/google/go-cmp/cmp"
    22  )
    23  
// awsClient is the union of the vmClient and quotaClient interfaces. It
// exists so that a single type can be checked against both at compile time.
type awsClient interface {
	vmClient
	quotaClient
}
    28  
// Compile-time assertion that *fakeEC2Client implements awsClient.
var _ awsClient = (*fakeEC2Client)(nil)
    30  
// fakeEC2Client is an in-memory fake of the EC2 and service quota clients.
// It stores instances, instance types and quota values that the methods
// below read and modify in place of calls to AWS.
type fakeEC2Client struct {
	// mu is locked by the methods that read or modify instances and
	// instanceTypes.
	mu sync.RWMutex
	// instances maps an instance ID to its *ec2.Instance. instanceTypes is
	// the list served by DescribeInstanceTypesPagesWithContext, and
	// serviceQuota maps "<service code>-<quota code>" to a quota value.
	instances     map[string]*ec2.Instance
	instanceTypes []*ec2.InstanceTypeInfo
	serviceQuota  map[string]float64
}
    38  
    39  func newFakeAWSClient() *fakeEC2Client {
    40  	return &fakeEC2Client{
    41  		instances:     make(map[string]*ec2.Instance),
    42  		instanceTypes: []*ec2.InstanceTypeInfo{},
    43  		serviceQuota:  make(map[string]float64),
    44  	}
    45  }
    46  
// filterFunc is a predicate over an instance. It returns true when the
// instance should be included in a result set and false when it should be
// dropped.
type filterFunc func(*ec2.Instance) bool
    49  
    50  // createFilter returns filtering functions for a subset of `ec2.Filter`.
    51  // The response in the function returned indicates whether the instance
    52  // should be included.
    53  func createFilter(f *ec2.Filter) filterFunc {
    54  	if *f.Name == "instance-state-name" {
    55  		states := aws.StringValueSlice(f.Values)
    56  		return func(i *ec2.Instance) bool {
    57  			for _, s := range states {
    58  				if *i.State.Name == s {
    59  					return true
    60  				}
    61  			}
    62  			return false
    63  		}
    64  	}
    65  	// return noop filter for unsupported filters
    66  	return func(i *ec2.Instance) bool { return true }
    67  }
    68  
    69  // createFilters creates a filtering function for a subset of `ec2.Filter`.
    70  // The response for the returned function indicates whether the instance
    71  // should be included after all of the supplied filters have been evaluated.
    72  func createFilters(fs []*ec2.Filter) filterFunc {
    73  	if len(fs) == 0 {
    74  		// return noop filter for unsupported filters
    75  		return func(i *ec2.Instance) bool { return true }
    76  	}
    77  	filters := make([]filterFunc, 0, len(fs))
    78  	for _, f := range fs {
    79  		filters = append(filters, createFilter(f))
    80  	}
    81  	return func(i *ec2.Instance) bool {
    82  		for _, fn := range filters {
    83  			if !fn(i) {
    84  				return false
    85  			}
    86  		}
    87  		return true
    88  	}
    89  }
    90  
    91  func (f *fakeEC2Client) DescribeInstancesPagesWithContext(ctx context.Context, input *ec2.DescribeInstancesInput, fn func(*ec2.DescribeInstancesOutput, bool) bool, opt ...request.Option) error {
    92  	if input == nil || fn == nil {
    93  		return errors.New("invalid input")
    94  	}
    95  	filters := createFilters(input.Filters)
    96  	f.mu.RLock()
    97  	defer f.mu.RUnlock()
    98  	insts := make([]*ec2.Instance, 0, len(f.instances))
    99  	for _, inst := range f.instances {
   100  		if !filters(inst) {
   101  			continue
   102  		}
   103  		insts = append(insts, inst)
   104  	}
   105  	for it, inst := range insts {
   106  		fn(&ec2.DescribeInstancesOutput{
   107  			Reservations: []*ec2.Reservation{
   108  				&ec2.Reservation{
   109  					Instances: []*ec2.Instance{
   110  						inst,
   111  					},
   112  				},
   113  			},
   114  		}, it == len(insts)-1)
   115  	}
   116  	return nil
   117  }
   118  
   119  func (f *fakeEC2Client) DescribeInstancesWithContext(ctx context.Context, input *ec2.DescribeInstancesInput, opt ...request.Option) (*ec2.DescribeInstancesOutput, error) {
   120  	if ctx == nil || input == nil || len(input.InstanceIds) == 0 {
   121  		return nil, request.ErrInvalidParams{}
   122  	}
   123  	filters := createFilters(input.Filters)
   124  	instances := make([]*ec2.Instance, 0, len(input.InstanceIds))
   125  	f.mu.RLock()
   126  	defer f.mu.RUnlock()
   127  	for _, id := range aws.StringValueSlice(input.InstanceIds) {
   128  		inst, ok := f.instances[id]
   129  		if !ok {
   130  			return nil, errors.New("instance not found")
   131  		}
   132  		if !filters(inst) {
   133  			continue
   134  		}
   135  		instances = append(instances, inst)
   136  	}
   137  	return &ec2.DescribeInstancesOutput{
   138  		Reservations: []*ec2.Reservation{
   139  			&ec2.Reservation{
   140  				Instances: instances,
   141  			},
   142  		},
   143  	}, nil
   144  }
   145  
   146  func (f *fakeEC2Client) RunInstancesWithContext(ctx context.Context, input *ec2.RunInstancesInput, opts ...request.Option) (*ec2.Reservation, error) {
   147  	if ctx == nil || input == nil {
   148  		return nil, request.ErrInvalidParams{}
   149  	}
   150  	if input.ImageId == nil || aws.StringValue(input.ImageId) == "" ||
   151  		input.InstanceType == nil || aws.StringValue(input.InstanceType) == "" ||
   152  		input.MinCount == nil || aws.Int64Value(input.MinCount) == 0 ||
   153  		input.Placement == nil || aws.StringValue(input.Placement.AvailabilityZone) == "" {
   154  		return nil, errors.New("invalid instance configuration")
   155  	}
   156  	instCount := int(aws.Int64Value(input.MaxCount))
   157  	instances := make([]*ec2.Instance, 0, instCount)
   158  	f.mu.Lock()
   159  	defer f.mu.Unlock()
   160  	for i := 0; i < instCount; i++ {
   161  		inst := &ec2.Instance{
   162  			CpuOptions: &ec2.CpuOptions{
   163  				CoreCount: aws.Int64(4),
   164  			},
   165  			ImageId:          input.ImageId,
   166  			InstanceType:     input.InstanceType,
   167  			InstanceId:       aws.String(fmt.Sprintf("instance-%s", randHex(10))),
   168  			Placement:        input.Placement,
   169  			PrivateIpAddress: aws.String(randIPv4()),
   170  			PublicIpAddress:  aws.String(randIPv4()),
   171  			State: &ec2.InstanceState{
   172  				Name: aws.String("running"),
   173  			},
   174  			Tags:           []*ec2.Tag{},
   175  			KeyName:        input.KeyName,
   176  			SecurityGroups: []*ec2.GroupIdentifier{},
   177  			LaunchTime:     aws.Time(time.Now()),
   178  		}
   179  		for _, id := range input.SecurityGroups {
   180  			inst.SecurityGroups = append(inst.SecurityGroups, &ec2.GroupIdentifier{
   181  				GroupId: id,
   182  			})
   183  		}
   184  		for _, tagSpec := range input.TagSpecifications {
   185  			for _, tag := range tagSpec.Tags {
   186  				inst.Tags = append(inst.Tags, tag)
   187  			}
   188  		}
   189  		f.instances[*inst.InstanceId] = inst
   190  		instances = append(instances, inst)
   191  	}
   192  	return &ec2.Reservation{
   193  		Instances:     instances,
   194  		ReservationId: aws.String(fmt.Sprintf("reservation-%s", randHex(10))),
   195  	}, nil
   196  }
   197  
   198  func (f *fakeEC2Client) TerminateInstancesWithContext(ctx context.Context, input *ec2.TerminateInstancesInput, opts ...request.Option) (*ec2.TerminateInstancesOutput, error) {
   199  	if ctx == nil || input == nil || len(input.InstanceIds) == 0 {
   200  		return nil, request.ErrInvalidParams{}
   201  	}
   202  	isc := make([]*ec2.InstanceStateChange, 0, len(input.InstanceIds))
   203  	f.mu.Lock()
   204  	defer f.mu.Unlock()
   205  	for _, id := range input.InstanceIds {
   206  		if *id == "" {
   207  			return nil, errors.New("invalid instance id")
   208  		}
   209  		var prevState string
   210  		inst, ok := f.instances[*id]
   211  		if !ok {
   212  			return nil, errors.New("instance not found")
   213  		}
   214  		prevState = *inst.State.Name
   215  		inst.State.Name = aws.String(ec2.InstanceStateNameTerminated)
   216  		isc = append(isc, &ec2.InstanceStateChange{
   217  			CurrentState: &ec2.InstanceState{
   218  				Name: aws.String(prevState),
   219  			},
   220  			InstanceId: id,
   221  			PreviousState: &ec2.InstanceState{
   222  				Code: nil,
   223  				Name: aws.String(ec2.InstanceStateNameTerminated),
   224  			},
   225  		})
   226  	}
   227  	return &ec2.TerminateInstancesOutput{
   228  		TerminatingInstances: isc,
   229  	}, nil
   230  }
   231  
   232  func (f *fakeEC2Client) WaitUntilInstanceRunningWithContext(ctx context.Context, input *ec2.DescribeInstancesInput, opt ...request.WaiterOption) error {
   233  	if ctx == nil || input == nil || len(input.InstanceIds) == 0 {
   234  		return request.ErrInvalidParams{}
   235  	}
   236  	f.mu.Lock()
   237  	defer f.mu.Unlock()
   238  	for _, id := range input.InstanceIds {
   239  		inst, ok := f.instances[*id]
   240  		if !ok {
   241  			return fmt.Errorf("instance %s not found", *id)
   242  		}
   243  		inst.State = &ec2.InstanceState{
   244  			Name: aws.String("running"),
   245  		}
   246  	}
   247  	return nil
   248  }
   249  
   250  func (f *fakeEC2Client) DescribeInstanceTypesPagesWithContext(ctx context.Context, input *ec2.DescribeInstanceTypesInput, fn func(*ec2.DescribeInstanceTypesOutput, bool) bool, opt ...request.Option) error {
   251  	if ctx == nil || input == nil || fn == nil {
   252  		return errors.New("invalid input")
   253  	}
   254  	f.mu.RLock()
   255  	defer f.mu.RUnlock()
   256  	for it, its := range f.instanceTypes {
   257  		fn(&ec2.DescribeInstanceTypesOutput{
   258  			InstanceTypes: []*ec2.InstanceTypeInfo{its},
   259  		}, it == len(f.instanceTypes)-1)
   260  	}
   261  	return nil
   262  }
   263  
   264  func (f *fakeEC2Client) GetServiceQuota(input *servicequotas.GetServiceQuotaInput) (*servicequotas.GetServiceQuotaOutput, error) {
   265  	if input == nil || input.QuotaCode == nil || input.ServiceCode == nil {
   266  		return nil, request.ErrInvalidParams{}
   267  	}
   268  	v, ok := f.serviceQuota[aws.StringValue(input.ServiceCode)+"-"+aws.StringValue(input.QuotaCode)]
   269  	if !ok {
   270  		return nil, errors.New("quota not found")
   271  	}
   272  	return &servicequotas.GetServiceQuotaOutput{
   273  		Quota: &servicequotas.ServiceQuota{
   274  			Value: aws.Float64(v),
   275  		},
   276  	}, nil
   277  }
   278  
// option customizes a fakeEC2Client. fakeClient applies each supplied option
// to a freshly created fake.
type option func(*fakeEC2Client)
   280  
   281  func WithServiceQuota(service, quota string, value float64) option {
   282  	return func(c *fakeEC2Client) {
   283  		c.serviceQuota[service+"-"+quota] = value
   284  	}
   285  }
   286  
   287  func WithInstanceType(name, arch string, numCPU int64) option {
   288  	return func(c *fakeEC2Client) {
   289  		c.instanceTypes = append(c.instanceTypes, &ec2.InstanceTypeInfo{
   290  			InstanceType: aws.String(name),
   291  			ProcessorInfo: &ec2.ProcessorInfo{
   292  				SupportedArchitectures: []*string{aws.String(arch)},
   293  			},
   294  			VCpuInfo: &ec2.VCpuInfo{
   295  				DefaultVCpus: aws.Int64(numCPU),
   296  			},
   297  		})
   298  	}
   299  }
   300  
   301  func fakeClient(opts ...option) *AWSClient {
   302  	fc := newFakeAWSClient()
   303  	for _, opt := range opts {
   304  		opt(fc)
   305  	}
   306  	return &AWSClient{
   307  		ec2Client:   fc,
   308  		quotaClient: fc,
   309  	}
   310  }
   311  
   312  func fakeClientWithInstances(t *testing.T, count int, opts ...option) (*AWSClient, []*Instance) {
   313  	c := fakeClient(opts...)
   314  	ctx := context.Background()
   315  	insts := make([]*Instance, 0, count)
   316  	for i := 0; i < count; i++ {
   317  		inst, err := c.CreateInstance(ctx, randomVMConfig())
   318  		if err != nil {
   319  			t.Fatalf("unable to create instance: %s", err)
   320  		}
   321  		insts = append(insts, inst)
   322  	}
   323  	return c, insts
   324  }
   325  
   326  func randomVMConfig() *EC2VMConfiguration {
   327  	return &EC2VMConfiguration{
   328  		Description:    fmt.Sprintf("description-" + randHex(4)),
   329  		ImageID:        fmt.Sprintf("image-" + randHex(4)),
   330  		Name:           fmt.Sprintf("name-" + randHex(4)),
   331  		SSHKeyID:       fmt.Sprintf("ssh-key-id-" + randHex(4)),
   332  		SecurityGroups: []string{fmt.Sprintf("sg-" + randHex(4))},
   333  		Tags: map[string]string{
   334  			fmt.Sprintf("tag-key-" + randHex(4)): fmt.Sprintf("tag-value-" + randHex(4)),
   335  		},
   336  		Type:     fmt.Sprintf("type-" + randHex(4)),
   337  		UserData: fmt.Sprintf("user-data-" + randHex(4)),
   338  		Zone:     fmt.Sprintf("zone-" + randHex(4)),
   339  	}
   340  }
   341  
   342  func TestRunningInstances(t *testing.T) {
   343  	t.Run("query-all-instances", func(t *testing.T) {
   344  		c, wantInsts := fakeClientWithInstances(t, 4)
   345  		gotInsts, gotErr := c.RunningInstances(context.Background())
   346  		if gotErr != nil {
   347  			t.Fatalf("Instances(ctx) = %+v, %s; want nil, nil", gotInsts, gotErr)
   348  		}
   349  		if len(gotInsts) != len(wantInsts) {
   350  			t.Errorf("got instance count %d: want %d", len(gotInsts), len(wantInsts))
   351  		}
   352  	})
   353  	t.Run("query-with-a-terminated-instance", func(t *testing.T) {
   354  		ctx := context.Background()
   355  		c, wantInsts := fakeClientWithInstances(t, 4)
   356  		gotErr := c.DestroyInstances(ctx, wantInsts[0].ID)
   357  		if gotErr != nil {
   358  			t.Fatalf("unable to destroy instance: %s", gotErr)
   359  		}
   360  		gotInsts, gotErr := c.RunningInstances(ctx)
   361  		if gotErr != nil {
   362  			t.Fatalf("Instances(ctx) = %+v, %s; want nil, nil", gotInsts, gotErr)
   363  		}
   364  		if len(gotInsts) != len(wantInsts)-1 {
   365  			t.Errorf("got instance count %d: want %d", len(gotInsts), len(wantInsts)-1)
   366  		}
   367  	})
   368  }
   369  
   370  func TestInstanceTypesARM(t *testing.T) {
   371  	opts := []option{
   372  		WithInstanceType("zz.large", "x86_64", 10),
   373  		WithInstanceType("aa.xlarge", "arm64", 20),
   374  	}
   375  
   376  	t.Run("query-arm64-instances", func(t *testing.T) {
   377  		c := fakeClient(opts...)
   378  		gotInstTypes, gotErr := c.InstanceTypesARM(context.Background())
   379  		if gotErr != nil {
   380  			t.Fatalf("InstanceTypesArm(ctx) = %+v, %s; want nil, nil", gotInstTypes, gotErr)
   381  		}
   382  		if len(gotInstTypes) != 1 {
   383  			t.Errorf("got instance type count %d: want %d", len(gotInstTypes), 1)
   384  		}
   385  	})
   386  	t.Run("nil-request", func(t *testing.T) {
   387  		c := fakeClient(opts...)
   388  		gotInstTypes, gotErr := c.InstanceTypesARM(nil)
   389  		if gotErr == nil {
   390  			t.Fatalf("InstanceTypesArm(nil) = %+v, %s; want nil, error", gotInstTypes, gotErr)
   391  		}
   392  	})
   393  }
   394  
   395  func TestQuota(t *testing.T) {
   396  	t.Run("on-demand-vcpu", func(t *testing.T) {
   397  		wantQuota := int64(384)
   398  		c := fakeClient(WithServiceQuota(QuotaServiceEC2, QuotaCodeCPUOnDemand, float64(wantQuota)))
   399  		gotQuota, gotErr := c.Quota(context.Background(), QuotaServiceEC2, QuotaCodeCPUOnDemand)
   400  		if gotErr != nil || wantQuota != gotQuota {
   401  			t.Fatalf("Quota(ctx, %s, %s) = %+v, %s; want %d, nil", QuotaServiceEC2, QuotaCodeCPUOnDemand, gotQuota, gotErr, wantQuota)
   402  		}
   403  	})
   404  	t.Run("nil-request", func(t *testing.T) {
   405  		wantQuota := int64(384)
   406  		c := fakeClient(WithServiceQuota(QuotaServiceEC2, QuotaCodeCPUOnDemand, float64(wantQuota)))
   407  		gotQuota, gotErr := c.Quota(context.Background(), "", "")
   408  		if gotErr == nil || gotQuota != 0 {
   409  			t.Fatalf("Quota(ctx, %s, %s) = %+v, %s; want 0, error", QuotaServiceEC2, QuotaCodeCPUOnDemand, gotQuota, gotErr)
   410  		}
   411  	})
   412  }
   413  
   414  func TestInstance(t *testing.T) {
   415  	t.Run("query-instance", func(t *testing.T) {
   416  		c, wantInsts := fakeClientWithInstances(t, 1)
   417  		wantInst := wantInsts[0]
   418  		gotInst, gotErr := c.Instance(context.Background(), wantInst.ID)
   419  		if gotErr != nil || gotInst == nil || gotInst.ID != wantInst.ID {
   420  			t.Errorf("Instance(ctx, %s) = %+v, %s; want no error", wantInst.ID, gotInst, gotErr)
   421  		}
   422  	})
   423  	t.Run("query-terminated-instance", func(t *testing.T) {
   424  		c, wantInsts := fakeClientWithInstances(t, 1)
   425  		wantInst := wantInsts[0]
   426  		ctx := context.Background()
   427  		gotErr := c.DestroyInstances(ctx, wantInst.ID)
   428  		if gotErr != nil {
   429  			t.Fatalf("unable to destroy instance: %s", gotErr)
   430  		}
   431  		gotInst, gotErr := c.Instance(ctx, wantInst.ID)
   432  		if gotErr != nil || gotInst == nil || gotInst.ID != wantInst.ID {
   433  			t.Errorf("Instance(ctx, %s) = %+v, %s; want no error", wantInst.ID, gotInst, gotErr)
   434  		}
   435  	})
   436  }
   437  
   438  func TestCreateInstance(t *testing.T) {
   439  	ud := &EC2UserData{
   440  		BuildletBinaryURL: "b-url",
   441  		BuildletHostType:  "b-host-type",
   442  		BuildletImageURL:  "b-image-url",
   443  		BuildletName:      "b-name",
   444  		Metadata: map[string]string{
   445  			"tag-a": "value-b",
   446  		},
   447  		TLSCert:     "cert-a",
   448  		TLSKey:      "key-a",
   449  		TLSPassword: "pass-a",
   450  	}
   451  	config := &EC2VMConfiguration{
   452  		Description:    "description-a",
   453  		ImageID:        "my-image",
   454  		Name:           "my-instance",
   455  		SSHKeyID:       "my-key",
   456  		SecurityGroups: []string{"test-key"},
   457  		Tags: map[string]string{
   458  			"tag-1": "value-1",
   459  		},
   460  		Type:     "xby.large",
   461  		UserData: ud.EncodedString(),
   462  		Zone:     "us-west-14",
   463  	}
   464  	c := fakeClient()
   465  	gotInst, gotErr := c.CreateInstance(context.Background(), config)
   466  	if gotErr != nil {
   467  		t.Errorf("CreateInstance(ctx, %v) = %+v, %s; want no error", config, gotInst, gotErr)
   468  	}
   469  	if gotInst.Description != config.Description {
   470  		t.Errorf("Instance.Description = %s; want %s", gotInst.Description, config.Description)
   471  	}
   472  	if gotInst.ImageID != config.ImageID {
   473  		t.Errorf("Instance.ImageID = %s; want %s", gotInst.ImageID, config.ImageID)
   474  	}
   475  	if gotInst.Name != config.Name {
   476  		t.Errorf("Instance.Name = %s; want %s", gotInst.Name, config.Name)
   477  	}
   478  	if gotInst.SSHKeyID != config.SSHKeyID {
   479  		t.Errorf("Instance.SSHKeyID = %s; want %s", gotInst.SSHKeyID, config.SSHKeyID)
   480  	}
   481  	if !cmp.Equal(gotInst.SecurityGroups, config.SecurityGroups) {
   482  		t.Errorf("Instance.SecruityGroups = %v; want %v", gotInst.SecurityGroups, config.SecurityGroups)
   483  	}
   484  	if !cmp.Equal(gotInst.Tags, config.Tags) {
   485  		t.Errorf("Instance.Tags = %v want %v", gotInst.Tags, config.Tags)
   486  	}
   487  	if gotInst.Type != config.Type {
   488  		t.Errorf("Instance.Type = %s; want %s", gotInst.Type, config.Type)
   489  	}
   490  	if gotInst.Zone != config.Zone {
   491  		t.Errorf("Instance.Zone = %s; want %s", gotInst.Zone, config.Zone)
   492  	}
   493  }
   494  
   495  func TestCreateInstanceError(t *testing.T) {
   496  	testCases := []struct {
   497  		desc     string
   498  		vmConfig *EC2VMConfiguration
   499  	}{
   500  		{
   501  			desc:     "missing-vmConfig",
   502  			vmConfig: nil,
   503  		},
   504  		{
   505  			desc: "missing-image-type",
   506  			vmConfig: &EC2VMConfiguration{
   507  				Type: "type-a",
   508  				Zone: "eu-15",
   509  			},
   510  		},
   511  		{
   512  			desc: "missing-vm-type",
   513  			vmConfig: &EC2VMConfiguration{
   514  				ImageID: "ami-15",
   515  				Zone:    "eu-15",
   516  			},
   517  		},
   518  		{
   519  			desc: "missing-zone",
   520  			vmConfig: &EC2VMConfiguration{
   521  				ImageID: "ami-15",
   522  				Type:    "abc.large",
   523  			},
   524  		},
   525  	}
   526  	for _, tc := range testCases {
   527  		t.Run(tc.desc, func(t *testing.T) {
   528  			c := fakeClient()
   529  			gotInst, gotErr := c.CreateInstance(context.Background(), tc.vmConfig)
   530  			if gotErr == nil || gotInst != nil {
   531  				t.Errorf("CreateInstance(ctx, %+v) = %+v, %s; want error", tc.vmConfig, gotInst, gotErr)
   532  			}
   533  		})
   534  	}
   535  }
   536  
   537  func TestDestroyInstances(t *testing.T) {
   538  	testCases := []struct {
   539  		desc    string
   540  		ctx     context.Context
   541  		vmCount int
   542  		wantErr bool
   543  	}{
   544  		{"baseline request", context.Background(), 1, false},
   545  		{"nil context", nil, 1, true},
   546  		{"missing vmID", context.Background(), 0, true},
   547  	}
   548  	for _, tc := range testCases {
   549  		t.Run(tc.desc, func(t *testing.T) {
   550  			c, insts := fakeClientWithInstances(t, tc.vmCount)
   551  			instIDs := make([]string, 0, tc.vmCount)
   552  			for _, inst := range insts {
   553  				instIDs = append(instIDs, inst.ID)
   554  			}
   555  			gotErr := c.DestroyInstances(tc.ctx, instIDs...)
   556  			if (gotErr != nil) != tc.wantErr {
   557  				t.Errorf("DestroyVM(%v, %+v) = %v; want error %t", tc.ctx, instIDs, gotErr, tc.wantErr)
   558  			}
   559  		})
   560  	}
   561  }
   562  
   563  func TestWaitUntilInstanceRunning(t *testing.T) {
   564  	c, wantInsts := fakeClientWithInstances(t, 1)
   565  	wantInst := wantInsts[0]
   566  	ctx := context.Background()
   567  	gotErr := c.WaitUntilInstanceRunning(ctx, wantInst.ID)
   568  	if gotErr != nil {
   569  		t.Errorf("WaitUntilVMExists(%v, %v) failed with error %s", ctx, wantInst.ID, gotErr)
   570  	}
   571  }
   572  
   573  func TestWaitUntilInstanceRunningErr(t *testing.T) {
   574  	testCases := []struct {
   575  		desc    string
   576  		ctx     context.Context
   577  		vmCount int
   578  	}{
   579  		{"nil-context", nil, 1},
   580  		{"missing vmID", context.Background(), 0},
   581  	}
   582  	for _, tc := range testCases {
   583  		t.Run(tc.desc, func(t *testing.T) {
   584  			c, wantInsts := fakeClientWithInstances(t, tc.vmCount)
   585  			ctx := context.Background()
   586  			wantID := ""
   587  			if len(wantInsts) > 0 {
   588  				wantID = wantInsts[0].ID
   589  			}
   590  			gotErr := c.WaitUntilInstanceRunning(tc.ctx, wantID)
   591  			if gotErr == nil {
   592  				t.Errorf("WaitUntilVMExists(%v, %v) = %s: want error", ctx, wantID, gotErr)
   593  			}
   594  		})
   595  	}
   596  }
   597  
   598  func TestEC2ToInstance(t *testing.T) {
   599  	wantCreationTime := time.Unix(1, 1)
   600  	wantDescription := "my-desc"
   601  	wantID := "inst-55"
   602  	wantIPExt := "1.1.1.1"
   603  	wantIPInt := "2.2.2.2"
   604  	wantImage := "ami-56"
   605  	wantKey := "my-key"
   606  	wantName := "my-name"
   607  	wantSecurityGroup := "22"
   608  	wantTagKey := "tag1"
   609  	wantTagValue := "taggy1"
   610  	wantType := "type-1"
   611  	wantZone := "us-east-22"
   612  	wantState := "running"
   613  	var wantCPUCount int64 = 66
   614  
   615  	ei := &ec2.Instance{
   616  		CpuOptions: &ec2.CpuOptions{
   617  			CoreCount: aws.Int64(wantCPUCount),
   618  		},
   619  		ImageId:      aws.String(wantImage),
   620  		InstanceId:   aws.String(wantID),
   621  		InstanceType: aws.String(wantType),
   622  		KeyName:      aws.String(wantKey),
   623  		LaunchTime:   aws.Time(wantCreationTime),
   624  		Placement: &ec2.Placement{
   625  			AvailabilityZone: aws.String(wantZone),
   626  		},
   627  		PrivateIpAddress: aws.String(wantIPInt),
   628  		PublicIpAddress:  aws.String(wantIPExt),
   629  		SecurityGroups: []*ec2.GroupIdentifier{
   630  			&ec2.GroupIdentifier{
   631  				GroupId: aws.String(wantSecurityGroup),
   632  			},
   633  		},
   634  		State: &ec2.InstanceState{
   635  			Name: aws.String(wantState),
   636  		},
   637  		Tags: []*ec2.Tag{
   638  			&ec2.Tag{
   639  				Key:   aws.String(tagName),
   640  				Value: aws.String(wantName),
   641  			},
   642  			&ec2.Tag{
   643  				Key:   aws.String(tagDescription),
   644  				Value: aws.String(wantDescription),
   645  			},
   646  			&ec2.Tag{
   647  				Key:   aws.String(wantTagKey),
   648  				Value: aws.String(wantTagValue),
   649  			},
   650  		},
   651  	}
   652  	gotInst := ec2ToInstance(ei)
   653  	if gotInst.CPUCount != wantCPUCount {
   654  		t.Errorf("CPUCount %d; want %d", gotInst.CPUCount, wantCPUCount)
   655  	}
   656  	if gotInst.CreatedAt != wantCreationTime {
   657  		t.Errorf("CreatedAt %s; want %s", gotInst.CreatedAt, wantCreationTime)
   658  	}
   659  	if gotInst.Description != wantDescription {
   660  		t.Errorf("Description %s; want %s", gotInst.Description, wantDescription)
   661  	}
   662  	if gotInst.ID != wantID {
   663  		t.Errorf("ID %s; want %s", gotInst.ID, wantID)
   664  	}
   665  	if gotInst.IPAddressExternal != wantIPExt {
   666  		t.Errorf("IPAddressExternal %s; want %s", gotInst.IPAddressExternal, wantIPExt)
   667  	}
   668  	if gotInst.IPAddressInternal != wantIPInt {
   669  		t.Errorf("IPAddressInternal %s; want %s", gotInst.IPAddressInternal, wantIPInt)
   670  	}
   671  	if gotInst.ImageID != wantImage {
   672  		t.Errorf("Image %s; want %s", gotInst.ImageID, wantImage)
   673  	}
   674  	if gotInst.Name != wantName {
   675  		t.Errorf("Name %s; want %s", gotInst.Name, wantName)
   676  	}
   677  	if gotInst.SSHKeyID != wantKey {
   678  		t.Errorf("SSHKeyID %s; want %s", gotInst.SSHKeyID, wantKey)
   679  	}
   680  	found := false
   681  	for _, sg := range gotInst.SecurityGroups {
   682  		if sg == wantSecurityGroup {
   683  			found = true
   684  			break
   685  		}
   686  	}
   687  	if !found {
   688  		t.Errorf("SecurityGroups not found")
   689  	}
   690  	if gotInst.State != wantState {
   691  		t.Errorf("State %s; want %s", gotInst.State, wantState)
   692  	}
   693  	if gotInst.Type != wantType {
   694  		t.Errorf("Type %s; want %s", gotInst.Type, wantType)
   695  	}
   696  	if gotInst.Zone != wantZone {
   697  		t.Errorf("Zone %s; want %s", gotInst.Zone, wantZone)
   698  	}
   699  	gotValue, ok := gotInst.Tags[wantTagKey]
   700  	if !ok || gotValue != wantTagValue {
   701  		t.Errorf("Tags[%s] = %s, %t; want %s, %t", wantTagKey, gotValue, ok, wantTagValue, true)
   702  	}
   703  }
   704  
   705  func TestVMConfig(t *testing.T) {
   706  	wantDescription := "desc"
   707  	wantImage := "ami-56"
   708  	wantName := "my-instance"
   709  	wantKey := "my-key"
   710  	wantSecurityGroups := []string{"22"}
   711  	wantTags := map[string]string{
   712  		"tag1": "taggy1",
   713  		"tag2": "taggy2",
   714  	}
   715  	wantType := "type-1"
   716  	wantUserData := "user-data-x"
   717  	wantZone := "us-east-22"
   718  
   719  	rii := vmConfig(&EC2VMConfiguration{
   720  		Description:    wantDescription,
   721  		ImageID:        wantImage,
   722  		Name:           wantName,
   723  		SSHKeyID:       wantKey,
   724  		SecurityGroups: wantSecurityGroups,
   725  		Tags:           wantTags,
   726  		Type:           wantType,
   727  		UserData:       wantUserData,
   728  		Zone:           wantZone,
   729  	})
   730  
   731  	if *rii.ImageId != wantImage {
   732  		t.Errorf("image id %s; want %s", *rii.ImageId, wantImage)
   733  	}
   734  	if *rii.InstanceType != wantType {
   735  		t.Errorf("image id %s; want %s", *rii.ImageId, wantImage)
   736  	}
   737  	if *rii.MinCount != 1 {
   738  		t.Errorf("MinCount %d; want %d", *rii.MinCount, 1)
   739  	}
   740  	if *rii.MaxCount != 1 {
   741  		t.Errorf("MaxCount %d; want %d", *rii.MaxCount, 1)
   742  	}
   743  	if *rii.Placement.AvailabilityZone != wantZone {
   744  		t.Errorf("AvailabilityZone %s; want %s", *rii.Placement.AvailabilityZone, wantZone)
   745  	}
   746  	if !cmp.Equal(*rii.KeyName, wantKey) {
   747  		t.Errorf("SSHKeyID %+v; want %+v", *rii.KeyName, wantKey)
   748  	}
   749  	if *rii.InstanceInitiatedShutdownBehavior != ec2.ShutdownBehaviorTerminate {
   750  		t.Errorf("Shutdown Behavior %s; want %s", *rii.InstanceInitiatedShutdownBehavior, ec2.ShutdownBehaviorTerminate)
   751  	}
   752  	if *rii.UserData != wantUserData {
   753  		t.Errorf("UserData %s; want %s", *rii.UserData, wantUserData)
   754  	}
   755  	contains := func(tagSpec []*ec2.TagSpecification, key, value string) bool {
   756  		for _, ts := range tagSpec {
   757  			for _, t := range ts.Tags {
   758  				if *t.Key == key && *t.Value == value {
   759  					return true
   760  				}
   761  			}
   762  		}
   763  		return false
   764  	}
   765  	if !contains(rii.TagSpecifications, tagName, wantName) {
   766  		t.Errorf("want Tag Key: %s, Value: %s", tagName, wantName)
   767  	}
   768  	if !contains(rii.TagSpecifications, tagDescription, wantDescription) {
   769  		t.Errorf("want Tag Key: %s, Value: %s", tagDescription, wantDescription)
   770  	}
   771  	for k, v := range wantTags {
   772  		if !contains(rii.TagSpecifications, k, v) {
   773  			t.Errorf("want Tag Key: %s, Value: %s", k, v)
   774  		}
   775  	}
   776  	if !cmp.Equal(aws.StringValueSlice(rii.SecurityGroups), wantSecurityGroups) {
   777  		t.Errorf("SecurityGroups %v; want %v", aws.StringValueSlice(rii.SecurityGroups), wantSecurityGroups)
   778  	}
   779  }
   780  
   781  func TestEncodedString(t *testing.T) {
   782  	ud := EC2UserData{
   783  		BuildletBinaryURL: "binary_url_b",
   784  		BuildletHostType:  "host_type_a",
   785  		BuildletImageURL:  "image_url_c",
   786  		BuildletName:      "name_d",
   787  		Metadata: map[string]string{
   788  			"key": "value",
   789  		},
   790  		TLSCert:     "x",
   791  		TLSKey:      "y",
   792  		TLSPassword: "z",
   793  	}
   794  	jsonUserData, err := json.Marshal(ud)
   795  	if err != nil {
   796  		t.Fatalf("unable to marshal user data to json: %s", err)
   797  	}
   798  	wantUD := base64.StdEncoding.EncodeToString([]byte(jsonUserData))
   799  	if ud.EncodedString() != wantUD {
   800  		t.Errorf("EncodedString() = %s; want %s", ud.EncodedString(), wantUD)
   801  	}
   802  }