go.temporal.io/server@v1.23.0/common/auth/tls_config_helper_test.go (about)

     1  // The MIT License
     2  //
     3  // Copyright (c) 2020 Temporal Technologies Inc.  All rights reserved.
     4  //
     5  // Copyright (c) 2020 Uber Technologies, Inc.
     6  //
     7  // Permission is hereby granted, free of charge, to any person obtaining a copy
     8  // of this software and associated documentation files (the "Software"), to deal
     9  // in the Software without restriction, including without limitation the rights
    10  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    11  // copies of the Software, and to permit persons to whom the Software is
    12  // furnished to do so, subject to the following conditions:
    13  //
    14  // The above copyright notice and this permission notice shall be included in
    15  // all copies or substantial portions of the Software.
    16  //
    17  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    18  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    19  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    20  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    21  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    22  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    23  // THE SOFTWARE.
    24  
    25  package auth
    26  
    27  import (
    28  	"crypto/tls"
    29  	"crypto/x509"
    30  	"encoding/base64"
    31  	"fmt"
    32  	"io"
    33  	"net/http"
    34  	"net/http/httptest"
    35  	"os"
    36  	"testing"
    37  
    38  	"github.com/golang/mock/gomock"
    39  	"github.com/stretchr/testify/assert"
    40  )
    41  
    42  var validBase64CaData, invalidBase64CaData, validBase64Certificate, invalidBase64Certificate, validBase64Key, invalidBase64Key string
    43  
    44  func readFile(path string) string {
    45  	file, err := os.Open("testdata/" + path)
    46  	if err != nil {
    47  		panic(err)
    48  	}
    49  	defer func() {
    50  		if err := file.Close(); err != nil {
    51  			panic(err)
    52  		}
    53  	}()
    54  	data, err := io.ReadAll(file)
    55  	if err != nil {
    56  		panic(err)
    57  	}
    58  	return base64.StdEncoding.EncodeToString(data)
    59  }
    60  
    61  func init() {
    62  	validBase64CaData = readFile("ca.crt")
    63  	invalidBase64CaData = readFile("invalid_ca.crt")
    64  	validBase64Certificate = readFile("localhost.crt")
    65  	invalidBase64Certificate = readFile("invalid_localhost.crt")
    66  	validBase64Key = readFile("localhost.key")
    67  	invalidBase64Key = readFile("invalid_localhost.key")
    68  }
    69  
    70  // test if the input is valid
    71  func Test_NewTLSConfig(t *testing.T) {
    72  	tests := map[string]struct {
    73  		cfg    *TLS
    74  		cfgErr string
    75  	}{
    76  		"emptyConfig": {
    77  			cfg: &TLS{},
    78  		},
    79  		"caData_good": {
    80  			cfg: &TLS{
    81  				Enabled: true,
    82  				CaData:  validBase64CaData,
    83  			},
    84  		},
    85  		"caData_badBase64": {
    86  			cfg:    &TLS{Enabled: true, CaData: "this isn't base64"},
    87  			cfgErr: "illegal base64 data at input byte",
    88  		},
    89  		"caData_badPEM": {
    90  			cfg:    &TLS{Enabled: true, CaData: "dGhpcyBpc24ndCBhIFBFTSBjZXJ0"},
    91  			cfgErr: "unable to parse certs as PEM",
    92  		},
    93  		"clientCert_badbase64cert": {
    94  			cfg: &TLS{
    95  				Enabled:  true,
    96  				CertData: "this ain't base64",
    97  				KeyData:  validBase64Key,
    98  			},
    99  			cfgErr: "illegal base64 data at input byte",
   100  		},
   101  		"clientCert_badbase64key": {
   102  			cfg: &TLS{
   103  				Enabled:  true,
   104  				CertData: validBase64Certificate,
   105  				KeyData:  "this ain't base64",
   106  			},
   107  			cfgErr: "illegal base64 data at input byte",
   108  		},
   109  		"clientCert_missingprivatekey": {
   110  			cfg: &TLS{
   111  				Enabled:  true,
   112  				CertData: validBase64Certificate,
   113  				KeyData:  "",
   114  			},
   115  			cfgErr: "unable to config TLS: cert or key is missing",
   116  		},
   117  		"clientCert_duplicate_cert": {
   118  			cfg: &TLS{
   119  				Enabled:  true,
   120  				CertData: validBase64Certificate,
   121  				CertFile: "/a/b/c",
   122  			},
   123  			cfgErr: "only one of certData or certFile properties should be specified",
   124  		},
   125  		"clientCert_duplicate_key": {
   126  			cfg: &TLS{
   127  				Enabled: true,
   128  				KeyData: validBase64Key,
   129  				KeyFile: "/a/b/c",
   130  			},
   131  			cfgErr: "only one of keyData or keyFile properties should be specified",
   132  		},
   133  		"clientCert_duplicate_ca": {
   134  			cfg: &TLS{
   135  				Enabled: true,
   136  				CaData:  validBase64CaData,
   137  				CaFile:  "/a/b/c",
   138  			},
   139  			cfgErr: "only one of caData or caFile properties should be specified",
   140  		},
   141  	}
   142  
   143  	for name, tc := range tests {
   144  		t.Run(name, func(t *testing.T) {
   145  			ctrl := gomock.NewController(t)
   146  			_, err := NewTLSConfig(tc.cfg)
   147  			if tc.cfgErr != "" {
   148  				assert.ErrorContains(t, err, tc.cfgErr)
   149  			} else {
   150  				assert.NoError(t, err)
   151  			}
   152  
   153  			ctrl.Finish()
   154  		})
   155  	}
   156  }
   157  
   158  func Test_ConnectToTLSServerWithCA(t *testing.T) {
   159  	// setup server
   160  	h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   161  		fmt.Fprintln(w, "Hello World")
   162  	})
   163  	ts := httptest.NewUnstartedServer(h)
   164  	certBytes, err := os.ReadFile("./testdata/localhost.crt")
   165  	if err != nil {
   166  		panic(fmt.Errorf("unable to decode certificate %w", err))
   167  	}
   168  	keyBytes, err := os.ReadFile("./testdata/localhost.key")
   169  	if err != nil {
   170  		panic(fmt.Errorf("unable to decode key %w", err))
   171  	}
   172  	cert, err := tls.X509KeyPair(certBytes, keyBytes)
   173  	if err != nil {
   174  		panic(fmt.Errorf("unable to load certificate %w", err))
   175  	}
   176  	ts.TLS = &tls.Config{
   177  		Certificates: []tls.Certificate{cert},
   178  	}
   179  	ts.StartTLS()
   180  
   181  	tests := map[string]struct {
   182  		cfg           *TLS
   183  		connectionErr string
   184  	}{
   185  		"caData_good": {
   186  			cfg: &TLS{
   187  				Enabled: true,
   188  				CaData:  validBase64CaData,
   189  			},
   190  		},
   191  		"caData_signedByWrongCA": {
   192  			cfg: &TLS{
   193  				Enabled:                true,
   194  				EnableHostVerification: true,
   195  				CaData:                 invalidBase64CaData,
   196  			},
   197  			connectionErr: "x509: certificate signed by unknown authority",
   198  		},
   199  		"caData_signedByWrongCAButNotEnableHostVerification": {
   200  			cfg: &TLS{
   201  				Enabled:                true,
   202  				EnableHostVerification: false,
   203  				CaData:                 invalidBase64CaData,
   204  			},
   205  		},
   206  		"caFile_good": {
   207  			cfg: &TLS{
   208  				Enabled:                true,
   209  				EnableHostVerification: true,
   210  				CaFile:                 "testdata/ca.crt",
   211  			},
   212  		},
   213  		"caFile_signedByWrongCA": {
   214  			cfg: &TLS{
   215  				Enabled:                true,
   216  				EnableHostVerification: true,
   217  				CaFile:                 "testdata/invalid_ca.crt",
   218  			},
   219  			connectionErr: "x509: certificate signed by unknown authority",
   220  		},
   221  		"caFile_signedByWrongCANotEnableHostVerification": {
   222  			cfg: &TLS{
   223  				Enabled:                true,
   224  				EnableHostVerification: false,
   225  				CaFile:                 "testdata/invalid_ca.crt",
   226  			},
   227  		},
   228  		"certData_good": {
   229  			cfg: &TLS{
   230  				Enabled:                true,
   231  				EnableHostVerification: true,
   232  				CaData:                 validBase64Certificate,
   233  			},
   234  		},
   235  	}
   236  
   237  	for name, tc := range tests {
   238  		t.Run(name, func(t *testing.T) {
   239  			ctrl := gomock.NewController(t)
   240  			tlsConfig, err := NewTLSConfig(tc.cfg)
   241  			if err != nil {
   242  				panic(err)
   243  			}
   244  			cl := &http.Client{Transport: &http.Transport{TLSClientConfig: tlsConfig}}
   245  			resp, err := cl.Get(ts.URL)
   246  			if tc.connectionErr != "" {
   247  				assert.ErrorContains(t, err, tc.connectionErr)
   248  			} else {
   249  				assert.NoError(t, err)
   250  				assert.Equal(t, 200, resp.StatusCode)
   251  			}
   252  
   253  			ctrl.Finish()
   254  		})
   255  	}
   256  }
   257  
   258  func Test_ConnectToTLSServerWithClientCertificate(t *testing.T) {
   259  	// setup server
   260  	h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   261  		fmt.Fprintln(w, "Hello World")
   262  	})
   263  	ts := httptest.NewUnstartedServer(h)
   264  	certBytes, err := os.ReadFile("./testdata/localhost.crt")
   265  	if err != nil {
   266  		panic(fmt.Errorf("unable to decode certificate %w", err))
   267  	}
   268  	keyBytes, err := os.ReadFile("./testdata/localhost.key")
   269  	if err != nil {
   270  		panic(fmt.Errorf("unable to decode key %w", err))
   271  	}
   272  	cert, err := tls.X509KeyPair(certBytes, keyBytes)
   273  	if err != nil {
   274  		panic(fmt.Errorf("unable to load certificate %w", err))
   275  	}
   276  	caBytes, _ := os.ReadFile("testdata/ca.crt")
   277  	caCertPool := x509.NewCertPool()
   278  	caCertPool.AppendCertsFromPEM(caBytes)
   279  	ts.TLS = &tls.Config{
   280  		ClientCAs:    caCertPool,
   281  		Certificates: []tls.Certificate{cert},
   282  		ClientAuth:   tls.RequireAndVerifyClientCert,
   283  	}
   284  	ts.StartTLS()
   285  
   286  	tests := map[string]struct {
   287  		cfg           *TLS
   288  		connectionErr string
   289  	}{
   290  		"clientData_good": {
   291  			cfg: &TLS{
   292  				Enabled:                true,
   293  				EnableHostVerification: true,
   294  				CaData:                 validBase64CaData,
   295  				CertData:               validBase64Certificate,
   296  				KeyData:                validBase64Key,
   297  			},
   298  		},
   299  		"clientData_certNotProvided": {
   300  			cfg: &TLS{
   301  				Enabled:                true,
   302  				EnableHostVerification: true,
   303  				CaData:                 validBase64CaData,
   304  			},
   305  			connectionErr: "certificate required",
   306  		},
   307  		"clientData_certInvalid": {
   308  			cfg: &TLS{
   309  				Enabled:                true,
   310  				EnableHostVerification: true,
   311  				CaData:                 validBase64CaData,
   312  				CertData:               invalidBase64Certificate,
   313  				KeyData:                invalidBase64Key,
   314  			},
   315  			connectionErr: "certificate required",
   316  		},
   317  		"certFile_good": {
   318  			cfg: &TLS{
   319  				Enabled:                true,
   320  				EnableHostVerification: true,
   321  				CaData:                 validBase64CaData,
   322  				CertFile:               "testdata/localhost.crt",
   323  				KeyFile:                "testdata/localhost.key",
   324  			},
   325  		},
   326  		"clientFile_certInvalid": {
   327  			cfg: &TLS{
   328  				Enabled:                true,
   329  				EnableHostVerification: true,
   330  				CaData:                 validBase64CaData,
   331  				CertFile:               "testdata/invalid_localhost.crt",
   332  				KeyFile:                "testdata/invalid_localhost.key",
   333  			},
   334  			connectionErr: "certificate required",
   335  		},
   336  	}
   337  
   338  	for name, tc := range tests {
   339  		t.Run(name, func(t *testing.T) {
   340  			ctrl := gomock.NewController(t)
   341  			tlsConfig, err := NewTLSConfig(tc.cfg)
   342  			if err != nil {
   343  				panic(err)
   344  			}
   345  			cl := &http.Client{Transport: &http.Transport{TLSClientConfig: tlsConfig}}
   346  			resp, err := cl.Get(ts.URL)
   347  			if tc.connectionErr != "" {
   348  				assert.ErrorContains(t, err, tc.connectionErr)
   349  			} else {
   350  				assert.NoError(t, err)
   351  				assert.Equal(t, 200, resp.StatusCode)
   352  			}
   353  
   354  			ctrl.Finish()
   355  		})
   356  	}
   357  }