github.com/aldelo/common@v1.5.1/tlsconfig/tlsconfig.go (about) 1 package tlsconfig 2 3 /* 4 * Copyright 2020-2023 Aldelo, LP 5 * 6 * Licensed under the Apache License, Version 2.0 (the "License"); 7 * you may not use this file except in compliance with the License. 8 * You may obtain a copy of the License at 9 * 10 * http://www.apache.org/licenses/LICENSE-2.0 11 * 12 * Unless required by applicable law or agreed to in writing, software 13 * distributed under the License is distributed on an "AS IS" BASIS, 14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 * See the License for the specific language governing permissions and 16 * limitations under the License. 17 */ 18 19 import ( 20 "crypto/tls" 21 "crypto/x509" 22 "fmt" 23 "io/ioutil" 24 "strings" 25 ) 26 27 type TlsConfig struct{} 28 29 // GetServerTlsConfig returns *tls.config configured for server TLS or mTLS based on parameters 30 // 31 // serverCertPemPath = (required) path and file name to the server cert pem (unencrypted version) 32 // serverKeyPemPath = (required) path and file name to the server key pem (unencrypted version) 33 // clientCaCertPath = (optional) one or more client ca cert path and file name, in case tls.config is for mTLS 34 func (t *TlsConfig) GetServerTlsConfig(serverCertPemPath string, 35 serverKeyPemPath string, 36 clientCaCertPemPath []string) (*tls.Config, error) { 37 if len(strings.TrimSpace(serverCertPemPath)) == 0 || len(strings.TrimSpace(serverKeyPemPath)) == 0 { 38 return nil, fmt.Errorf("Server TLS Config Requires Server Certificate and Key Pem Path") 39 } 40 41 // create server cert 42 serverCert, err := tls.LoadX509KeyPair(serverCertPemPath, serverKeyPemPath) 43 44 if err != nil { 45 return nil, fmt.Errorf("Load X509 Key Pair Failed: %s", err.Error()) 46 } 47 48 // if client ca cert pem defined, prep for mTLS 49 certPool, _ := x509.SystemCertPool() 50 if certPool == nil { 51 certPool = x509.NewCertPool() 52 } 53 54 certPoolCount := 0 55 56 clientCaCertPemPath = _stringSliceExtractUnique(clientCaCertPemPath) 57 58 if len(clientCaCertPemPath) > 0 { 59 for _, v := range clientCaCertPemPath { 60 if len(strings.TrimSpace(v)) > 0 { 61 if clientCa, e := ioutil.ReadFile(v); e != nil { 62 return nil, fmt.Errorf("Read Client CA Pem Failed: (%s) %s", v, e.Error()) 63 } else { 64 if !certPool.AppendCertsFromPEM(clientCa) { 65 // fail to add client ca to cert pool 66 return nil, fmt.Errorf("Append Client CA From Pem Failed: %s", v) 67 } else { 68 // client ca pem appended to ca pool 69 certPoolCount++ 70 } 71 } 72 } 73 } 74 } 75 76 config := &tls.Config{ 77 Certificates: []tls.Certificate{ 78 serverCert, 79 }, 80 MinVersion: tls.VersionTLS12, 81 CurvePreferences: []tls.CurveID{ 82 tls.CurveP521, 83 tls.CurveP384, 84 tls.CurveP256, 85 }, 86 PreferServerCipherSuites: true, 87 CipherSuites: []uint16{ 88 tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, 89 tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, 90 tls.TLS_RSA_WITH_AES_256_GCM_SHA384, 91 tls.TLS_RSA_WITH_AES_256_CBC_SHA, 92 }, 93 } 94 95 if certPoolCount > 0 { 96 config.ClientAuth = tls.RequireAndVerifyClientCert 97 config.ClientCAs = certPool 98 } else { 99 config.ClientAuth = tls.NoClientCert 100 } 101 102 return config, nil 103 } 104 105 // GetClientTlsConfig returns *tls.config configured for server TLS or mTLS based on parameters 106 // 107 // serverCaCertPath = (required) one or more server ca cert path and file name, required for both server TLS or mTLS 108 // clientCertPemPath = (optional) for mTLS setup, path and file name to the client cert pem (unencrypted version) 109 // clientKeyPemPath = (optional) for mTLS setup, path and file name to the client key pem (unencrypted version) 110 func (t *TlsConfig) GetClientTlsConfig(serverCaCertPemPath []string, 111 clientCertPemPath string, 112 clientKeyPemPath string) (*tls.Config, error) { 113 if len(serverCaCertPemPath) == 0 { 114 return nil, fmt.Errorf("Client TLS Config Requires Server CA Certificate Pem Path") 115 } 116 117 certPool, _ := x509.SystemCertPool() 118 if certPool == nil { 119 certPool = x509.NewCertPool() 120 } 121 122 // get unique server ca cert pem files 123 serverCaCertPemPath = _stringSliceExtractUnique(serverCaCertPemPath) 124 125 for _, v := range serverCaCertPemPath { 126 if len(strings.TrimSpace(v)) > 0 { 127 if serverCa, e := ioutil.ReadFile(v); e != nil { 128 return nil, fmt.Errorf("Read Server CA Pem Failed: (%s) %s", v, e.Error()) 129 } else { 130 if !certPool.AppendCertsFromPEM(serverCa) { 131 // fail to add server ca to cert pool 132 return nil, fmt.Errorf("Append Server CA From Pem Failed: %s", v) 133 } 134 } 135 } 136 } 137 138 config := &tls.Config{ 139 RootCAs: certPool, 140 } 141 142 // for mTls set client cert 143 if len(strings.TrimSpace(clientCertPemPath)) > 0 && len(strings.TrimSpace(clientKeyPemPath)) > 0 { 144 if clientCert, e := tls.LoadX509KeyPair(clientCertPemPath, clientKeyPemPath); e != nil { 145 return nil, fmt.Errorf("Load X509 Key Pair Failed: %s", e.Error()) 146 } else { 147 config.Certificates = []tls.Certificate{ 148 clientCert, 149 } 150 } 151 } 152 153 return config, nil 154 } 155 156 // --------------------------------------------------------------------------------------------------------------------- 157 // private functions to avoid using common namespace (which causes conflict with rest namespace - circular reference) 158 // --------------------------------------------------------------------------------------------------------------------- 159 160 // _stringSliceContains checks if value is contained within the strSlice 161 func _stringSliceContains(strSlice *[]string, value string) bool { 162 if strSlice == nil { 163 return false 164 } else { 165 for _, v := range *strSlice { 166 if strings.ToLower(v) == strings.ToLower(value) { 167 return true 168 } 169 } 170 171 return false 172 } 173 } 174 175 // _stringSliceExtractUnique returns unique string slice elements 176 func _stringSliceExtractUnique(strSlice []string) (result []string) { 177 if strSlice == nil { 178 return []string{} 179 } else if len(strSlice) <= 1 { 180 return strSlice 181 } else { 182 for _, v := range strSlice { 183 if !_stringSliceContains(&result, v) { 184 result = append(result, v) 185 } 186 } 187 188 return result 189 } 190 }