github.com/aavshr/aws-sdk-go@v1.41.3/aws/ec2metadata/service_test.go (about)

     1  //go:build go1.7
     2  // +build go1.7
     3  
     4  package ec2metadata_test
     5  
     6  import (
     7  	"net/http"
     8  	"net/http/httptest"
     9  	"os"
    10  	"strings"
    11  	"sync"
    12  	"testing"
    13  	"time"
    14  
    15  	"github.com/aavshr/aws-sdk-go/aws"
    16  	"github.com/aavshr/aws-sdk-go/aws/awserr"
    17  	"github.com/aavshr/aws-sdk-go/aws/ec2metadata"
    18  	"github.com/aavshr/aws-sdk-go/aws/request"
    19  	"github.com/aavshr/aws-sdk-go/awstesting/unit"
    20  	"github.com/aavshr/aws-sdk-go/internal/sdktesting"
    21  )
    22  
    23  func TestClientOverrideDefaultHTTPClientTimeout(t *testing.T) {
    24  	svc := ec2metadata.New(unit.Session)
    25  
    26  	if e, a := http.DefaultClient, svc.Config.HTTPClient; e == a {
    27  		t.Errorf("expect %v, not to equal %v", e, a)
    28  	}
    29  
    30  	if e, a := 1*time.Second, svc.Config.HTTPClient.Timeout; e != a {
    31  		t.Errorf("expect %v to be %v", e, a)
    32  	}
    33  }
    34  
    35  func TestClientNotOverrideDefaultHTTPClientTimeout(t *testing.T) {
    36  	http.DefaultClient.Transport = &http.Transport{}
    37  	defer func() {
    38  		http.DefaultClient.Transport = nil
    39  	}()
    40  
    41  	svc := ec2metadata.New(unit.Session)
    42  
    43  	if e, a := http.DefaultClient, svc.Config.HTTPClient; e != a {
    44  		t.Errorf("expect %v, got %v", e, a)
    45  	}
    46  
    47  	tr := svc.Config.HTTPClient.Transport.(*http.Transport)
    48  	if tr == nil {
    49  		t.Fatalf("expect transport not to be nil")
    50  	}
    51  	if tr.Dial != nil {
    52  		t.Errorf("expect dial to be nil, was not")
    53  	}
    54  }
    55  
    56  func TestClientDisableOverrideDefaultHTTPClientTimeout(t *testing.T) {
    57  	svc := ec2metadata.New(unit.Session, aws.NewConfig().WithEC2MetadataDisableTimeoutOverride(true))
    58  
    59  	if e, a := http.DefaultClient, svc.Config.HTTPClient; e != a {
    60  		t.Errorf("expect %v, got %v", e, a)
    61  	}
    62  }
    63  
    64  func TestClientOverrideDefaultHTTPClientTimeoutRace(t *testing.T) {
    65  	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    66  		w.Write([]byte("us-east-1a"))
    67  	}))
    68  	defer server.Close()
    69  
    70  	cfg := aws.NewConfig().WithEndpoint(server.URL)
    71  	runEC2MetadataClients(t, cfg, 50)
    72  }
    73  
    74  func TestClientOverrideDefaultHTTPClientTimeoutRaceWithTransport(t *testing.T) {
    75  	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    76  		w.Write([]byte("us-east-1a"))
    77  	}))
    78  	defer server.Close()
    79  
    80  	cfg := aws.NewConfig().WithEndpoint(server.URL).WithHTTPClient(&http.Client{
    81  		Transport: &http.Transport{
    82  			DisableKeepAlives: true,
    83  		},
    84  	})
    85  
    86  	runEC2MetadataClients(t, cfg, 50)
    87  }
    88  
    89  func TestClientDisableIMDS(t *testing.T) {
    90  	restoreEnvFn := sdktesting.StashEnv()
    91  	defer restoreEnvFn()
    92  
    93  	os.Setenv("AWS_EC2_METADATA_DISABLED", "true")
    94  
    95  	svc := ec2metadata.New(unit.Session)
    96  	resp, err := svc.GetUserData()
    97  	if err == nil {
    98  		t.Fatalf("expect error, got none")
    99  	}
   100  	if len(resp) != 0 {
   101  		t.Errorf("expect no response, got %v", resp)
   102  	}
   103  
   104  	aerr := err.(awserr.Error)
   105  	if e, a := request.CanceledErrorCode, aerr.Code(); e != a {
   106  		t.Errorf("expect %v error code, got %v", e, a)
   107  	}
   108  	if e, a := "AWS_EC2_METADATA_DISABLED", aerr.Message(); !strings.Contains(a, e) {
   109  		t.Errorf("expect %v in error message, got %v", e, a)
   110  	}
   111  }
   112  
   113  func TestClientStripPath(t *testing.T) {
   114  	cases := map[string]struct {
   115  		Endpoint string
   116  		Expect   string
   117  	}{
   118  		"no change": {
   119  			Endpoint: "http://example.aws",
   120  			Expect:   "http://example.aws",
   121  		},
   122  		"strip path": {
   123  			Endpoint: "http://example.aws/foo",
   124  			Expect:   "http://example.aws",
   125  		},
   126  	}
   127  
   128  	for name, c := range cases {
   129  		t.Run(name, func(t *testing.T) {
   130  			restoreEnvFn := sdktesting.StashEnv()
   131  			defer restoreEnvFn()
   132  
   133  			svc := ec2metadata.New(unit.Session, &aws.Config{
   134  				Endpoint: aws.String(c.Endpoint),
   135  			})
   136  
   137  			if e, a := c.Expect, svc.ClientInfo.Endpoint; e != a {
   138  				t.Errorf("expect %v endpoint, got %v", e, a)
   139  			}
   140  		})
   141  	}
   142  }
   143  
   144  func runEC2MetadataClients(t *testing.T, cfg *aws.Config, atOnce int) {
   145  	var wg sync.WaitGroup
   146  	wg.Add(atOnce)
   147  	svc := ec2metadata.New(unit.Session, cfg)
   148  	for i := 0; i < atOnce; i++ {
   149  		go func() {
   150  			defer wg.Done()
   151  			_, err := svc.GetUserData()
   152  			if err != nil {
   153  				t.Errorf("expect no error, got %v", err)
   154  			}
   155  		}()
   156  	}
   157  	wg.Wait()
   158  }