github.com/aavshr/aws-sdk-go@v1.41.3/aws/ec2metadata/service_test.go (about) 1 //go:build go1.7 2 // +build go1.7 3 4 package ec2metadata_test 5 6 import ( 7 "net/http" 8 "net/http/httptest" 9 "os" 10 "strings" 11 "sync" 12 "testing" 13 "time" 14 15 "github.com/aavshr/aws-sdk-go/aws" 16 "github.com/aavshr/aws-sdk-go/aws/awserr" 17 "github.com/aavshr/aws-sdk-go/aws/ec2metadata" 18 "github.com/aavshr/aws-sdk-go/aws/request" 19 "github.com/aavshr/aws-sdk-go/awstesting/unit" 20 "github.com/aavshr/aws-sdk-go/internal/sdktesting" 21 ) 22 23 func TestClientOverrideDefaultHTTPClientTimeout(t *testing.T) { 24 svc := ec2metadata.New(unit.Session) 25 26 if e, a := http.DefaultClient, svc.Config.HTTPClient; e == a { 27 t.Errorf("expect %v, not to equal %v", e, a) 28 } 29 30 if e, a := 1*time.Second, svc.Config.HTTPClient.Timeout; e != a { 31 t.Errorf("expect %v to be %v", e, a) 32 } 33 } 34 35 func TestClientNotOverrideDefaultHTTPClientTimeout(t *testing.T) { 36 http.DefaultClient.Transport = &http.Transport{} 37 defer func() { 38 http.DefaultClient.Transport = nil 39 }() 40 41 svc := ec2metadata.New(unit.Session) 42 43 if e, a := http.DefaultClient, svc.Config.HTTPClient; e != a { 44 t.Errorf("expect %v, got %v", e, a) 45 } 46 47 tr := svc.Config.HTTPClient.Transport.(*http.Transport) 48 if tr == nil { 49 t.Fatalf("expect transport not to be nil") 50 } 51 if tr.Dial != nil { 52 t.Errorf("expect dial to be nil, was not") 53 } 54 } 55 56 func TestClientDisableOverrideDefaultHTTPClientTimeout(t *testing.T) { 57 svc := ec2metadata.New(unit.Session, aws.NewConfig().WithEC2MetadataDisableTimeoutOverride(true)) 58 59 if e, a := http.DefaultClient, svc.Config.HTTPClient; e != a { 60 t.Errorf("expect %v, got %v", e, a) 61 } 62 } 63 64 func TestClientOverrideDefaultHTTPClientTimeoutRace(t *testing.T) { 65 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 66 w.Write([]byte("us-east-1a")) 67 })) 68 defer server.Close() 69 70 cfg := aws.NewConfig().WithEndpoint(server.URL) 71 runEC2MetadataClients(t, cfg, 50) 72 } 73 74 func TestClientOverrideDefaultHTTPClientTimeoutRaceWithTransport(t *testing.T) { 75 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 76 w.Write([]byte("us-east-1a")) 77 })) 78 defer server.Close() 79 80 cfg := aws.NewConfig().WithEndpoint(server.URL).WithHTTPClient(&http.Client{ 81 Transport: &http.Transport{ 82 DisableKeepAlives: true, 83 }, 84 }) 85 86 runEC2MetadataClients(t, cfg, 50) 87 } 88 89 func TestClientDisableIMDS(t *testing.T) { 90 restoreEnvFn := sdktesting.StashEnv() 91 defer restoreEnvFn() 92 93 os.Setenv("AWS_EC2_METADATA_DISABLED", "true") 94 95 svc := ec2metadata.New(unit.Session) 96 resp, err := svc.GetUserData() 97 if err == nil { 98 t.Fatalf("expect error, got none") 99 } 100 if len(resp) != 0 { 101 t.Errorf("expect no response, got %v", resp) 102 } 103 104 aerr := err.(awserr.Error) 105 if e, a := request.CanceledErrorCode, aerr.Code(); e != a { 106 t.Errorf("expect %v error code, got %v", e, a) 107 } 108 if e, a := "AWS_EC2_METADATA_DISABLED", aerr.Message(); !strings.Contains(a, e) { 109 t.Errorf("expect %v in error message, got %v", e, a) 110 } 111 } 112 113 func TestClientStripPath(t *testing.T) { 114 cases := map[string]struct { 115 Endpoint string 116 Expect string 117 }{ 118 "no change": { 119 Endpoint: "http://example.aws", 120 Expect: "http://example.aws", 121 }, 122 "strip path": { 123 Endpoint: "http://example.aws/foo", 124 Expect: "http://example.aws", 125 }, 126 } 127 128 for name, c := range cases { 129 t.Run(name, func(t *testing.T) { 130 restoreEnvFn := sdktesting.StashEnv() 131 defer restoreEnvFn() 132 133 svc := ec2metadata.New(unit.Session, &aws.Config{ 134 Endpoint: aws.String(c.Endpoint), 135 }) 136 137 if e, a := c.Expect, svc.ClientInfo.Endpoint; e != a { 138 t.Errorf("expect %v endpoint, got %v", e, a) 139 } 140 }) 141 } 142 } 143 144 func runEC2MetadataClients(t *testing.T, cfg *aws.Config, atOnce int) { 145 var wg sync.WaitGroup 146 wg.Add(atOnce) 147 svc := ec2metadata.New(unit.Session, cfg) 148 for i := 0; i < atOnce; i++ { 149 go func() { 150 defer wg.Done() 151 _, err := svc.GetUserData() 152 if err != nil { 153 t.Errorf("expect no error, got %v", err) 154 } 155 }() 156 } 157 wg.Wait() 158 }