github.com/aldelo/common@v1.5.1/tlsconfig/tlsconfig.go (about)

     1  package tlsconfig
     2  
     3  /*
     4   * Copyright 2020-2023 Aldelo, LP
     5   *
     6   * Licensed under the Apache License, Version 2.0 (the "License");
     7   * you may not use this file except in compliance with the License.
     8   * You may obtain a copy of the License at
     9   *
    10   *     http://www.apache.org/licenses/LICENSE-2.0
    11   *
    12   * Unless required by applicable law or agreed to in writing, software
    13   * distributed under the License is distributed on an "AS IS" BASIS,
    14   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    15   * See the License for the specific language governing permissions and
    16   * limitations under the License.
    17   */
    18  
    19  import (
    20  	"crypto/tls"
    21  	"crypto/x509"
    22  	"fmt"
    23  	"io/ioutil"
    24  	"strings"
    25  )
    26  
    27  type TlsConfig struct{}
    28  
    29  // GetServerTlsConfig returns *tls.config configured for server TLS or mTLS based on parameters
    30  //
    31  // serverCertPemPath = (required) path and file name to the server cert pem (unencrypted version)
    32  // serverKeyPemPath = (required) path and file name to the server key pem (unencrypted version)
    33  // clientCaCertPath = (optional) one or more client ca cert path and file name, in case tls.config is for mTLS
    34  func (t *TlsConfig) GetServerTlsConfig(serverCertPemPath string,
    35  	serverKeyPemPath string,
    36  	clientCaCertPemPath []string) (*tls.Config, error) {
    37  	if len(strings.TrimSpace(serverCertPemPath)) == 0 || len(strings.TrimSpace(serverKeyPemPath)) == 0 {
    38  		return nil, fmt.Errorf("Server TLS Config Requires Server Certificate and Key Pem Path")
    39  	}
    40  
    41  	// create server cert
    42  	serverCert, err := tls.LoadX509KeyPair(serverCertPemPath, serverKeyPemPath)
    43  
    44  	if err != nil {
    45  		return nil, fmt.Errorf("Load X509 Key Pair Failed: %s", err.Error())
    46  	}
    47  
    48  	// if client ca cert pem defined, prep for mTLS
    49  	certPool, _ := x509.SystemCertPool()
    50  	if certPool == nil {
    51  		certPool = x509.NewCertPool()
    52  	}
    53  
    54  	certPoolCount := 0
    55  
    56  	clientCaCertPemPath = _stringSliceExtractUnique(clientCaCertPemPath)
    57  
    58  	if len(clientCaCertPemPath) > 0 {
    59  		for _, v := range clientCaCertPemPath {
    60  			if len(strings.TrimSpace(v)) > 0 {
    61  				if clientCa, e := ioutil.ReadFile(v); e != nil {
    62  					return nil, fmt.Errorf("Read Client CA Pem Failed: (%s) %s", v, e.Error())
    63  				} else {
    64  					if !certPool.AppendCertsFromPEM(clientCa) {
    65  						// fail to add client ca to cert pool
    66  						return nil, fmt.Errorf("Append Client CA From Pem Failed: %s", v)
    67  					} else {
    68  						// client ca pem appended to ca pool
    69  						certPoolCount++
    70  					}
    71  				}
    72  			}
    73  		}
    74  	}
    75  
    76  	config := &tls.Config{
    77  		Certificates: []tls.Certificate{
    78  			serverCert,
    79  		},
    80  		MinVersion: tls.VersionTLS12,
    81  		CurvePreferences: []tls.CurveID{
    82  			tls.CurveP521,
    83  			tls.CurveP384,
    84  			tls.CurveP256,
    85  		},
    86  		PreferServerCipherSuites: true,
    87  		CipherSuites: []uint16{
    88  			tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
    89  			tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
    90  			tls.TLS_RSA_WITH_AES_256_GCM_SHA384,
    91  			tls.TLS_RSA_WITH_AES_256_CBC_SHA,
    92  		},
    93  	}
    94  
    95  	if certPoolCount > 0 {
    96  		config.ClientAuth = tls.RequireAndVerifyClientCert
    97  		config.ClientCAs = certPool
    98  	} else {
    99  		config.ClientAuth = tls.NoClientCert
   100  	}
   101  
   102  	return config, nil
   103  }
   104  
   105  // GetClientTlsConfig returns *tls.config configured for server TLS or mTLS based on parameters
   106  //
   107  // serverCaCertPath = (required) one or more server ca cert path and file name, required for both server TLS or mTLS
   108  // clientCertPemPath = (optional) for mTLS setup, path and file name to the client cert pem (unencrypted version)
   109  // clientKeyPemPath = (optional) for mTLS setup, path and file name to the client key pem (unencrypted version)
   110  func (t *TlsConfig) GetClientTlsConfig(serverCaCertPemPath []string,
   111  	clientCertPemPath string,
   112  	clientKeyPemPath string) (*tls.Config, error) {
   113  	if len(serverCaCertPemPath) == 0 {
   114  		return nil, fmt.Errorf("Client TLS Config Requires Server CA Certificate Pem Path")
   115  	}
   116  
   117  	certPool, _ := x509.SystemCertPool()
   118  	if certPool == nil {
   119  		certPool = x509.NewCertPool()
   120  	}
   121  
   122  	// get unique server ca cert pem files
   123  	serverCaCertPemPath = _stringSliceExtractUnique(serverCaCertPemPath)
   124  
   125  	for _, v := range serverCaCertPemPath {
   126  		if len(strings.TrimSpace(v)) > 0 {
   127  			if serverCa, e := ioutil.ReadFile(v); e != nil {
   128  				return nil, fmt.Errorf("Read Server CA Pem Failed: (%s) %s", v, e.Error())
   129  			} else {
   130  				if !certPool.AppendCertsFromPEM(serverCa) {
   131  					// fail to add server ca to cert pool
   132  					return nil, fmt.Errorf("Append Server CA From Pem Failed: %s", v)
   133  				}
   134  			}
   135  		}
   136  	}
   137  
   138  	config := &tls.Config{
   139  		RootCAs: certPool,
   140  	}
   141  
   142  	// for mTls set client cert
   143  	if len(strings.TrimSpace(clientCertPemPath)) > 0 && len(strings.TrimSpace(clientKeyPemPath)) > 0 {
   144  		if clientCert, e := tls.LoadX509KeyPair(clientCertPemPath, clientKeyPemPath); e != nil {
   145  			return nil, fmt.Errorf("Load X509 Key Pair Failed: %s", e.Error())
   146  		} else {
   147  			config.Certificates = []tls.Certificate{
   148  				clientCert,
   149  			}
   150  		}
   151  	}
   152  
   153  	return config, nil
   154  }
   155  
   156  // ---------------------------------------------------------------------------------------------------------------------
// private helpers duplicated here rather than imported from the common package, because importing common would create a circular reference with the rest package
   158  // ---------------------------------------------------------------------------------------------------------------------
   159  
   160  // _stringSliceContains checks if value is contained within the strSlice
   161  func _stringSliceContains(strSlice *[]string, value string) bool {
   162  	if strSlice == nil {
   163  		return false
   164  	} else {
   165  		for _, v := range *strSlice {
   166  			if strings.ToLower(v) == strings.ToLower(value) {
   167  				return true
   168  			}
   169  		}
   170  
   171  		return false
   172  	}
   173  }
   174  
   175  // _stringSliceExtractUnique returns unique string slice elements
   176  func _stringSliceExtractUnique(strSlice []string) (result []string) {
   177  	if strSlice == nil {
   178  		return []string{}
   179  	} else if len(strSlice) <= 1 {
   180  		return strSlice
   181  	} else {
   182  		for _, v := range strSlice {
   183  			if !_stringSliceContains(&result, v) {
   184  				result = append(result, v)
   185  			}
   186  		}
   187  
   188  		return result
   189  	}
   190  }