github.com/aavshr/aws-sdk-go@v1.41.3/aws/session/custom_ca_bundle_test.go (about)

     1  package session
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"net"
     7  	"net/http"
     8  	"os"
     9  	"strings"
    10  	"testing"
    11  	"time"
    12  
    13  	"github.com/aavshr/aws-sdk-go/aws"
    14  	"github.com/aavshr/aws-sdk-go/aws/awserr"
    15  	"github.com/aavshr/aws-sdk-go/aws/credentials"
    16  	"github.com/aavshr/aws-sdk-go/awstesting"
    17  )
    18  
    19  var TLSBundleCertFile string
    20  var TLSBundleKeyFile string
    21  var TLSBundleCAFile string
    22  
    23  func TestMain(m *testing.M) {
    24  	var err error
    25  
    26  	TLSBundleCertFile, TLSBundleKeyFile, TLSBundleCAFile, err = awstesting.CreateTLSBundleFiles()
    27  	if err != nil {
    28  		panic(err)
    29  	}
    30  
    31  	fmt.Println("TestMain", TLSBundleCertFile, TLSBundleKeyFile)
    32  
    33  	code := m.Run()
    34  
    35  	err = awstesting.CleanupTLSBundleFiles(TLSBundleCertFile, TLSBundleKeyFile, TLSBundleCAFile)
    36  	if err != nil {
    37  		panic(err)
    38  	}
    39  
    40  	os.Exit(code)
    41  }
    42  
    43  func TestNewSession_WithCustomCABundle_Env(t *testing.T) {
    44  	restoreEnvFn := initSessionTestEnv()
    45  	defer restoreEnvFn()
    46  
    47  	endpoint, err := awstesting.CreateTLSServer(TLSBundleCertFile, TLSBundleKeyFile, nil)
    48  	if err != nil {
    49  		t.Fatalf("expect no error, got %v", err)
    50  	}
    51  
    52  	os.Setenv("AWS_CA_BUNDLE", TLSBundleCAFile)
    53  
    54  	s, err := NewSession(&aws.Config{
    55  		HTTPClient:  &http.Client{},
    56  		Endpoint:    aws.String(endpoint),
    57  		Region:      aws.String("mock-region"),
    58  		Credentials: credentials.AnonymousCredentials,
    59  	})
    60  	if err != nil {
    61  		t.Fatalf("expect no error, got %v", err)
    62  	}
    63  	if s == nil {
    64  		t.Fatalf("expect session to be created, got none")
    65  	}
    66  
    67  	req, _ := http.NewRequest("GET", *s.Config.Endpoint, nil)
    68  	resp, err := s.Config.HTTPClient.Do(req)
    69  	if err != nil {
    70  		t.Fatalf("expect no error, got %v", err)
    71  	}
    72  	if e, a := http.StatusOK, resp.StatusCode; e != a {
    73  		t.Errorf("expect %d status code, got %d", e, a)
    74  	}
    75  }
    76  
    77  func TestNewSession_WithCustomCABundle_EnvNotExists(t *testing.T) {
    78  	restoreEnvFn := initSessionTestEnv()
    79  	defer restoreEnvFn()
    80  
    81  	os.Setenv("AWS_CA_BUNDLE", "file-not-exists")
    82  
    83  	s, err := NewSession()
    84  	if err == nil {
    85  		t.Fatalf("expect error, got none")
    86  	}
    87  	if e, a := "LoadCustomCABundleError", err.(awserr.Error).Code(); e != a {
    88  		t.Errorf("expect %s error code, got %s", e, a)
    89  	}
    90  	if s != nil {
    91  		t.Errorf("expect nil session, got %v", s)
    92  	}
    93  }
    94  
    95  func TestNewSession_WithCustomCABundle_Option(t *testing.T) {
    96  	restoreEnvFn := initSessionTestEnv()
    97  	defer restoreEnvFn()
    98  
    99  	endpoint, err := awstesting.CreateTLSServer(TLSBundleCertFile, TLSBundleKeyFile, nil)
   100  	if err != nil {
   101  		t.Fatalf("expect no error, got %v", err)
   102  	}
   103  
   104  	s, err := NewSessionWithOptions(Options{
   105  		Config: aws.Config{
   106  			HTTPClient:  &http.Client{},
   107  			Endpoint:    aws.String(endpoint),
   108  			Region:      aws.String("mock-region"),
   109  			Credentials: credentials.AnonymousCredentials,
   110  		},
   111  		CustomCABundle: bytes.NewReader(awstesting.TLSBundleCA),
   112  	})
   113  	if err != nil {
   114  		t.Fatalf("expect no error, got %v", err)
   115  	}
   116  	if s == nil {
   117  		t.Fatalf("expect session to be created, got none")
   118  	}
   119  
   120  	req, _ := http.NewRequest("GET", *s.Config.Endpoint, nil)
   121  	resp, err := s.Config.HTTPClient.Do(req)
   122  	if err != nil {
   123  		t.Fatalf("expect no error, got %v", err)
   124  	}
   125  	if e, a := http.StatusOK, resp.StatusCode; e != a {
   126  		t.Errorf("expect %d status code, got %d", e, a)
   127  	}
   128  }
   129  
   130  func TestNewSession_WithCustomCABundle_HTTPProxyAvailable(t *testing.T) {
   131  	restoreEnvFn := initSessionTestEnv()
   132  	defer restoreEnvFn()
   133  
   134  	s, err := NewSessionWithOptions(Options{
   135  		Config: aws.Config{
   136  			HTTPClient:  &http.Client{},
   137  			Region:      aws.String("mock-region"),
   138  			Credentials: credentials.AnonymousCredentials,
   139  		},
   140  		CustomCABundle: bytes.NewReader(awstesting.TLSBundleCA),
   141  	})
   142  	if err != nil {
   143  		t.Fatalf("expect no error, got %v", err)
   144  	}
   145  	if s == nil {
   146  		t.Fatalf("expect session to be created, got none")
   147  	}
   148  
   149  	tr := s.Config.HTTPClient.Transport.(*http.Transport)
   150  	if tr.Proxy == nil {
   151  		t.Fatalf("expect transport proxy, was nil")
   152  	}
   153  	if tr.TLSClientConfig.RootCAs == nil {
   154  		t.Fatalf("expect TLS config to have root CAs")
   155  	}
   156  }
   157  
   158  func TestNewSession_WithCustomCABundle_OptionPriority(t *testing.T) {
   159  	restoreEnvFn := initSessionTestEnv()
   160  	defer restoreEnvFn()
   161  
   162  	endpoint, err := awstesting.CreateTLSServer(TLSBundleCertFile, TLSBundleKeyFile, nil)
   163  	if err != nil {
   164  		t.Fatalf("expect no error, got %v", err)
   165  	}
   166  
   167  	os.Setenv("AWS_CA_BUNDLE", "file-not-exists")
   168  
   169  	s, err := NewSessionWithOptions(Options{
   170  		Config: aws.Config{
   171  			HTTPClient:  &http.Client{},
   172  			Endpoint:    aws.String(endpoint),
   173  			Region:      aws.String("mock-region"),
   174  			Credentials: credentials.AnonymousCredentials,
   175  		},
   176  		CustomCABundle: bytes.NewReader(awstesting.TLSBundleCA),
   177  	})
   178  	if err != nil {
   179  		t.Fatalf("expect no error, got %v", err)
   180  	}
   181  	if s == nil {
   182  		t.Fatalf("expect session to be created, got none")
   183  	}
   184  
   185  	req, _ := http.NewRequest("GET", *s.Config.Endpoint, nil)
   186  	resp, err := s.Config.HTTPClient.Do(req)
   187  	if err != nil {
   188  		t.Fatalf("expect no error, got %v", err)
   189  	}
   190  	if e, a := http.StatusOK, resp.StatusCode; e != a {
   191  		t.Errorf("expect %d status code, got %d", e, a)
   192  	}
   193  }
   194  
   195  type mockRoundTripper struct{}
   196  
   197  func (m *mockRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) {
   198  	return nil, nil
   199  }
   200  
   201  func TestNewSession_WithCustomCABundle_UnsupportedTransport(t *testing.T) {
   202  	restoreEnvFn := initSessionTestEnv()
   203  	defer restoreEnvFn()
   204  
   205  	s, err := NewSessionWithOptions(Options{
   206  		Config: aws.Config{
   207  			HTTPClient: &http.Client{
   208  				Transport: &mockRoundTripper{},
   209  			},
   210  		},
   211  		CustomCABundle: bytes.NewReader(awstesting.TLSBundleCA),
   212  	})
   213  	if err == nil {
   214  		t.Fatalf("expect error, got none")
   215  	}
   216  	if e, a := "LoadCustomCABundleError", err.(awserr.Error).Code(); e != a {
   217  		t.Errorf("expect %s error code, got %s", e, a)
   218  	}
   219  	if s != nil {
   220  		t.Errorf("expect nil session, got %v", s)
   221  	}
   222  	aerrMsg := err.(awserr.Error).Message()
   223  	if e, a := "transport unsupported type", aerrMsg; !strings.Contains(a, e) {
   224  		t.Errorf("expect %s to be in %s", e, a)
   225  	}
   226  }
   227  
   228  func TestNewSession_WithCustomCABundle_TransportSet(t *testing.T) {
   229  	restoreEnvFn := initSessionTestEnv()
   230  	defer restoreEnvFn()
   231  
   232  	endpoint, err := awstesting.CreateTLSServer(TLSBundleCertFile, TLSBundleKeyFile, nil)
   233  	if err != nil {
   234  		t.Fatalf("expect no error, got %v", err)
   235  	}
   236  
   237  	s, err := NewSessionWithOptions(Options{
   238  		Config: aws.Config{
   239  			Endpoint:    aws.String(endpoint),
   240  			Region:      aws.String("mock-region"),
   241  			Credentials: credentials.AnonymousCredentials,
   242  			HTTPClient: &http.Client{
   243  				Transport: &http.Transport{
   244  					Proxy: http.ProxyFromEnvironment,
   245  					Dial: (&net.Dialer{
   246  						Timeout:   30 * time.Second,
   247  						KeepAlive: 30 * time.Second,
   248  						DualStack: true,
   249  					}).Dial,
   250  					TLSHandshakeTimeout: 2 * time.Second,
   251  				},
   252  			},
   253  		},
   254  		CustomCABundle: bytes.NewReader(awstesting.TLSBundleCA),
   255  	})
   256  	if err != nil {
   257  		t.Fatalf("expect no error, got %v", err)
   258  	}
   259  	if s == nil {
   260  		t.Fatalf("expect session to be created, got none")
   261  	}
   262  
   263  	req, _ := http.NewRequest("GET", *s.Config.Endpoint, nil)
   264  	resp, err := s.Config.HTTPClient.Do(req)
   265  	if err != nil {
   266  		t.Fatalf("expect no error, got %v", err)
   267  	}
   268  	if e, a := http.StatusOK, resp.StatusCode; e != a {
   269  		t.Errorf("expect %d status code, got %d", e, a)
   270  	}
   271  }