github.com/aavshr/aws-sdk-go@v1.41.3/aws/session/client_tls_cert_test.go (about)

     1  //go:build go1.9
     2  // +build go1.9
     3  
     4  package session
     5  
     6  import (
     7  	"crypto/x509"
     8  	"io"
     9  	"net/http"
    10  	"os"
    11  	"strings"
    12  	"testing"
    13  	"time"
    14  
    15  	"github.com/aavshr/aws-sdk-go/awstesting"
    16  )
    17  
    18  func TestNewSession_WithClientTLSCert(t *testing.T) {
    19  	type testCase struct {
    20  		// Params
    21  		setup     func(certFilename, keyFilename string) (Options, func(), error)
    22  		ExpectErr string
    23  	}
    24  
    25  	cases := map[string]testCase{
    26  		"env": {
    27  			setup: func(certFilename, keyFilename string) (Options, func(), error) {
    28  				os.Setenv(useClientTLSCert[0], certFilename)
    29  				os.Setenv(useClientTLSKey[0], keyFilename)
    30  				return Options{}, func() {}, nil
    31  			},
    32  		},
    33  		"env file not found": {
    34  			setup: func(certFilename, keyFilename string) (Options, func(), error) {
    35  				os.Setenv(useClientTLSCert[0], "some-cert-file-not-exists")
    36  				os.Setenv(useClientTLSKey[0], "some-key-file-not-exists")
    37  				return Options{}, func() {}, nil
    38  			},
    39  			ExpectErr: "LoadClientTLSCertError",
    40  		},
    41  		"env cert file only": {
    42  			setup: func(certFilename, keyFilename string) (Options, func(), error) {
    43  				os.Setenv(useClientTLSCert[0], certFilename)
    44  				return Options{}, func() {}, nil
    45  			},
    46  			ExpectErr: "must both be provided",
    47  		},
    48  		"env key file only": {
    49  			setup: func(certFilename, keyFilename string) (Options, func(), error) {
    50  				os.Setenv(useClientTLSKey[0], keyFilename)
    51  				return Options{}, func() {}, nil
    52  			},
    53  			ExpectErr: "must both be provided",
    54  		},
    55  
    56  		"session options": {
    57  			setup: func(certFilename, keyFilename string) (Options, func(), error) {
    58  				certFile, err := os.Open(certFilename)
    59  				if err != nil {
    60  					return Options{}, nil, err
    61  				}
    62  				keyFile, err := os.Open(keyFilename)
    63  				if err != nil {
    64  					return Options{}, nil, err
    65  				}
    66  
    67  				return Options{
    68  						ClientTLSCert: certFile,
    69  						ClientTLSKey:  keyFile,
    70  					}, func() {
    71  						certFile.Close()
    72  						keyFile.Close()
    73  					}, nil
    74  			},
    75  		},
    76  		"session cert load error": {
    77  			setup: func(certFilename, keyFilename string) (Options, func(), error) {
    78  				certFile, err := os.Open(certFilename)
    79  				if err != nil {
    80  					return Options{}, nil, err
    81  				}
    82  				keyFile, err := os.Open(keyFilename)
    83  				if err != nil {
    84  					return Options{}, nil, err
    85  				}
    86  
    87  				stat, _ := certFile.Stat()
    88  				return Options{
    89  						ClientTLSCert: io.LimitReader(certFile, stat.Size()/2),
    90  						ClientTLSKey:  keyFile,
    91  					}, func() {
    92  						certFile.Close()
    93  						keyFile.Close()
    94  					}, nil
    95  			},
    96  			ExpectErr: "unable to load x509 key pair",
    97  		},
    98  		"session key load error": {
    99  			setup: func(certFilename, keyFilename string) (Options, func(), error) {
   100  				certFile, err := os.Open(certFilename)
   101  				if err != nil {
   102  					return Options{}, nil, err
   103  				}
   104  				keyFile, err := os.Open(keyFilename)
   105  				if err != nil {
   106  					return Options{}, nil, err
   107  				}
   108  
   109  				stat, _ := keyFile.Stat()
   110  				return Options{
   111  						ClientTLSCert: certFile,
   112  						ClientTLSKey:  io.LimitReader(keyFile, stat.Size()/2),
   113  					}, func() {
   114  						certFile.Close()
   115  						keyFile.Close()
   116  					}, nil
   117  			},
   118  			ExpectErr: "unable to load x509 key pair",
   119  		},
   120  	}
   121  
   122  	for name, c := range cases {
   123  		t.Run(name, func(t *testing.T) {
   124  			// Asserts
   125  			restoreEnvFn := initSessionTestEnv()
   126  			defer restoreEnvFn()
   127  
   128  			certFilename, keyFilename, err := awstesting.CreateClientTLSCertFiles()
   129  			if err != nil {
   130  				t.Fatalf("failed to create client certificate files, %v", err)
   131  			}
   132  			defer func() {
   133  				if err := awstesting.CleanupTLSBundleFiles(certFilename, keyFilename); err != nil {
   134  					t.Errorf("failed to cleanup client TLS cert files, %v", err)
   135  				}
   136  			}()
   137  
   138  			opts, cleanup, err := c.setup(certFilename, keyFilename)
   139  			if err != nil {
   140  				t.Fatalf("test case failed setup, %v", err)
   141  			}
   142  			if cleanup != nil {
   143  				defer cleanup()
   144  			}
   145  
   146  			server, err := awstesting.NewTLSClientCertServer(http.HandlerFunc(
   147  				func(w http.ResponseWriter, r *http.Request) {
   148  					w.WriteHeader(200)
   149  				}))
   150  			if err != nil {
   151  				t.Fatalf("failed to load session, %v", err)
   152  			}
   153  			server.StartTLS()
   154  			defer server.Close()
   155  
   156  			// Give server change to start
   157  			time.Sleep(time.Second)
   158  
   159  			// Load SDK session with options configured.
   160  			sess, err := NewSessionWithOptions(opts)
   161  			if len(c.ExpectErr) != 0 {
   162  				if err == nil {
   163  					t.Fatalf("expect error, got none")
   164  				}
   165  				if e, a := c.ExpectErr, err.Error(); !strings.Contains(a, e) {
   166  					t.Fatalf("expect error to contain %v, got %v", e, a)
   167  				}
   168  				return
   169  			}
   170  			if err != nil {
   171  				t.Fatalf("expect no error, got %v", err)
   172  			}
   173  
   174  			// Clients need to add ca bundle for test service.
   175  			p := x509.NewCertPool()
   176  			p.AddCert(server.Certificate())
   177  			client := sess.Config.HTTPClient
   178  			client.Transport.(*http.Transport).TLSClientConfig.RootCAs = p
   179  
   180  			// Send request
   181  			req, _ := http.NewRequest("GET", server.URL, nil)
   182  			resp, err := client.Do(req)
   183  			if err != nil {
   184  				t.Fatalf("failed to send request, %v", err)
   185  			}
   186  
   187  			if e, a := 200, resp.StatusCode; e != a {
   188  				t.Errorf("expect %v status code, got %v", e, a)
   189  			}
   190  		})
   191  	}
   192  }