github.com/aavshr/aws-sdk-go@v1.41.3/aws/session/custom_ca_bundle_test.go (about) 1 package session 2 3 import ( 4 "bytes" 5 "fmt" 6 "net" 7 "net/http" 8 "os" 9 "strings" 10 "testing" 11 "time" 12 13 "github.com/aavshr/aws-sdk-go/aws" 14 "github.com/aavshr/aws-sdk-go/aws/awserr" 15 "github.com/aavshr/aws-sdk-go/aws/credentials" 16 "github.com/aavshr/aws-sdk-go/awstesting" 17 ) 18 19 var TLSBundleCertFile string 20 var TLSBundleKeyFile string 21 var TLSBundleCAFile string 22 23 func TestMain(m *testing.M) { 24 var err error 25 26 TLSBundleCertFile, TLSBundleKeyFile, TLSBundleCAFile, err = awstesting.CreateTLSBundleFiles() 27 if err != nil { 28 panic(err) 29 } 30 31 fmt.Println("TestMain", TLSBundleCertFile, TLSBundleKeyFile) 32 33 code := m.Run() 34 35 err = awstesting.CleanupTLSBundleFiles(TLSBundleCertFile, TLSBundleKeyFile, TLSBundleCAFile) 36 if err != nil { 37 panic(err) 38 } 39 40 os.Exit(code) 41 } 42 43 func TestNewSession_WithCustomCABundle_Env(t *testing.T) { 44 restoreEnvFn := initSessionTestEnv() 45 defer restoreEnvFn() 46 47 endpoint, err := awstesting.CreateTLSServer(TLSBundleCertFile, TLSBundleKeyFile, nil) 48 if err != nil { 49 t.Fatalf("expect no error, got %v", err) 50 } 51 52 os.Setenv("AWS_CA_BUNDLE", TLSBundleCAFile) 53 54 s, err := NewSession(&aws.Config{ 55 HTTPClient: &http.Client{}, 56 Endpoint: aws.String(endpoint), 57 Region: aws.String("mock-region"), 58 Credentials: credentials.AnonymousCredentials, 59 }) 60 if err != nil { 61 t.Fatalf("expect no error, got %v", err) 62 } 63 if s == nil { 64 t.Fatalf("expect session to be created, got none") 65 } 66 67 req, _ := http.NewRequest("GET", *s.Config.Endpoint, nil) 68 resp, err := s.Config.HTTPClient.Do(req) 69 if err != nil { 70 t.Fatalf("expect no error, got %v", err) 71 } 72 if e, a := http.StatusOK, resp.StatusCode; e != a { 73 t.Errorf("expect %d status code, got %d", e, a) 74 } 75 } 76 77 func TestNewSession_WithCustomCABundle_EnvNotExists(t *testing.T) { 78 restoreEnvFn := initSessionTestEnv() 79 defer restoreEnvFn() 80 81 os.Setenv("AWS_CA_BUNDLE", "file-not-exists") 82 83 s, err := NewSession() 84 if err == nil { 85 t.Fatalf("expect error, got none") 86 } 87 if e, a := "LoadCustomCABundleError", err.(awserr.Error).Code(); e != a { 88 t.Errorf("expect %s error code, got %s", e, a) 89 } 90 if s != nil { 91 t.Errorf("expect nil session, got %v", s) 92 } 93 } 94 95 func TestNewSession_WithCustomCABundle_Option(t *testing.T) { 96 restoreEnvFn := initSessionTestEnv() 97 defer restoreEnvFn() 98 99 endpoint, err := awstesting.CreateTLSServer(TLSBundleCertFile, TLSBundleKeyFile, nil) 100 if err != nil { 101 t.Fatalf("expect no error, got %v", err) 102 } 103 104 s, err := NewSessionWithOptions(Options{ 105 Config: aws.Config{ 106 HTTPClient: &http.Client{}, 107 Endpoint: aws.String(endpoint), 108 Region: aws.String("mock-region"), 109 Credentials: credentials.AnonymousCredentials, 110 }, 111 CustomCABundle: bytes.NewReader(awstesting.TLSBundleCA), 112 }) 113 if err != nil { 114 t.Fatalf("expect no error, got %v", err) 115 } 116 if s == nil { 117 t.Fatalf("expect session to be created, got none") 118 } 119 120 req, _ := http.NewRequest("GET", *s.Config.Endpoint, nil) 121 resp, err := s.Config.HTTPClient.Do(req) 122 if err != nil { 123 t.Fatalf("expect no error, got %v", err) 124 } 125 if e, a := http.StatusOK, resp.StatusCode; e != a { 126 t.Errorf("expect %d status code, got %d", e, a) 127 } 128 } 129 130 func TestNewSession_WithCustomCABundle_HTTPProxyAvailable(t *testing.T) { 131 restoreEnvFn := initSessionTestEnv() 132 defer restoreEnvFn() 133 134 s, err := NewSessionWithOptions(Options{ 135 Config: aws.Config{ 136 HTTPClient: &http.Client{}, 137 Region: aws.String("mock-region"), 138 Credentials: credentials.AnonymousCredentials, 139 }, 140 CustomCABundle: bytes.NewReader(awstesting.TLSBundleCA), 141 }) 142 if err != nil { 143 t.Fatalf("expect no error, got %v", err) 144 } 145 if s == nil { 146 t.Fatalf("expect session to be created, got none") 147 } 148 149 tr := s.Config.HTTPClient.Transport.(*http.Transport) 150 if tr.Proxy == nil { 151 t.Fatalf("expect transport proxy, was nil") 152 } 153 if tr.TLSClientConfig.RootCAs == nil { 154 t.Fatalf("expect TLS config to have root CAs") 155 } 156 } 157 158 func TestNewSession_WithCustomCABundle_OptionPriority(t *testing.T) { 159 restoreEnvFn := initSessionTestEnv() 160 defer restoreEnvFn() 161 162 endpoint, err := awstesting.CreateTLSServer(TLSBundleCertFile, TLSBundleKeyFile, nil) 163 if err != nil { 164 t.Fatalf("expect no error, got %v", err) 165 } 166 167 os.Setenv("AWS_CA_BUNDLE", "file-not-exists") 168 169 s, err := NewSessionWithOptions(Options{ 170 Config: aws.Config{ 171 HTTPClient: &http.Client{}, 172 Endpoint: aws.String(endpoint), 173 Region: aws.String("mock-region"), 174 Credentials: credentials.AnonymousCredentials, 175 }, 176 CustomCABundle: bytes.NewReader(awstesting.TLSBundleCA), 177 }) 178 if err != nil { 179 t.Fatalf("expect no error, got %v", err) 180 } 181 if s == nil { 182 t.Fatalf("expect session to be created, got none") 183 } 184 185 req, _ := http.NewRequest("GET", *s.Config.Endpoint, nil) 186 resp, err := s.Config.HTTPClient.Do(req) 187 if err != nil { 188 t.Fatalf("expect no error, got %v", err) 189 } 190 if e, a := http.StatusOK, resp.StatusCode; e != a { 191 t.Errorf("expect %d status code, got %d", e, a) 192 } 193 } 194 195 type mockRoundTripper struct{} 196 197 func (m *mockRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) { 198 return nil, nil 199 } 200 201 func TestNewSession_WithCustomCABundle_UnsupportedTransport(t *testing.T) { 202 restoreEnvFn := initSessionTestEnv() 203 defer restoreEnvFn() 204 205 s, err := NewSessionWithOptions(Options{ 206 Config: aws.Config{ 207 HTTPClient: &http.Client{ 208 Transport: &mockRoundTripper{}, 209 }, 210 }, 211 CustomCABundle: bytes.NewReader(awstesting.TLSBundleCA), 212 }) 213 if err == nil { 214 t.Fatalf("expect error, got none") 215 } 216 if e, a := "LoadCustomCABundleError", err.(awserr.Error).Code(); e != a { 217 t.Errorf("expect %s error code, got %s", e, a) 218 } 219 if s != nil { 220 t.Errorf("expect nil session, got %v", s) 221 } 222 aerrMsg := err.(awserr.Error).Message() 223 if e, a := "transport unsupported type", aerrMsg; !strings.Contains(a, e) { 224 t.Errorf("expect %s to be in %s", e, a) 225 } 226 } 227 228 func TestNewSession_WithCustomCABundle_TransportSet(t *testing.T) { 229 restoreEnvFn := initSessionTestEnv() 230 defer restoreEnvFn() 231 232 endpoint, err := awstesting.CreateTLSServer(TLSBundleCertFile, TLSBundleKeyFile, nil) 233 if err != nil { 234 t.Fatalf("expect no error, got %v", err) 235 } 236 237 s, err := NewSessionWithOptions(Options{ 238 Config: aws.Config{ 239 Endpoint: aws.String(endpoint), 240 Region: aws.String("mock-region"), 241 Credentials: credentials.AnonymousCredentials, 242 HTTPClient: &http.Client{ 243 Transport: &http.Transport{ 244 Proxy: http.ProxyFromEnvironment, 245 Dial: (&net.Dialer{ 246 Timeout: 30 * time.Second, 247 KeepAlive: 30 * time.Second, 248 DualStack: true, 249 }).Dial, 250 TLSHandshakeTimeout: 2 * time.Second, 251 }, 252 }, 253 }, 254 CustomCABundle: bytes.NewReader(awstesting.TLSBundleCA), 255 }) 256 if err != nil { 257 t.Fatalf("expect no error, got %v", err) 258 } 259 if s == nil { 260 t.Fatalf("expect session to be created, got none") 261 } 262 263 req, _ := http.NewRequest("GET", *s.Config.Endpoint, nil) 264 resp, err := s.Config.HTTPClient.Do(req) 265 if err != nil { 266 t.Fatalf("expect no error, got %v", err) 267 } 268 if e, a := http.StatusOK, resp.StatusCode; e != a { 269 t.Errorf("expect %d status code, got %d", e, a) 270 } 271 }