github.com/gravitational/teleport/api@v0.0.0-20240507183017-3110591cbafc/utils/tlsutils/tlsutils.go (about) 1 /* 2 Copyright 2021 Gravitational, Inc. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 // Package tlsutils contains utilities for TLS configuration and formats. 18 package tlsutils 19 20 import ( 21 "context" 22 "crypto/tls" 23 "crypto/x509" 24 "encoding/pem" 25 "net" 26 "strings" 27 28 "github.com/gravitational/trace" 29 ) 30 31 // ParseCertificatePEM parses PEM-encoded x509 certificate. 32 func ParseCertificatePEM(bytes []byte) (*x509.Certificate, error) { 33 block, _ := pem.Decode(bytes) 34 if block == nil { 35 return nil, trace.BadParameter("expected PEM-encoded block") 36 } 37 cert, err := x509.ParseCertificate(block.Bytes) 38 if err != nil { 39 return nil, trace.BadParameter(err.Error()) 40 } 41 return cert, nil 42 } 43 44 // ContextDialer represents network dialer interface that uses context 45 type ContextDialer interface { 46 // DialContext is a function that dials the specified address 47 DialContext(ctx context.Context, network, addr string) (net.Conn, error) 48 } 49 50 // TLSDial dials and establishes TLS connection using custom dialer 51 // is similar to tls.DialWithDialer 52 // Note: function taken from lib/utils/tlsdial.go 53 func TLSDial(ctx context.Context, dialer ContextDialer, network, addr string, tlsConfig *tls.Config) (*tls.Conn, error) { 54 if tlsConfig == nil { 55 return nil, trace.BadParameter("tls config must be specified") 56 } 57 58 plainConn, err := dialer.DialContext(ctx, network, addr) 59 if err != nil { 60 return nil, trace.Wrap(err) 61 } 62 63 colonPos := strings.LastIndex(addr, ":") 64 if colonPos == -1 { 65 colonPos = len(addr) 66 } 67 hostname := addr[:colonPos] 68 69 // If no ServerName is set, infer the ServerName 70 // from the hostname we're connecting to. 71 if tlsConfig.ServerName == "" { 72 // Make a copy to avoid polluting argument or default. 73 tlsConfig = tlsConfig.Clone() 74 tlsConfig.ServerName = hostname 75 } 76 77 conn := tls.Client(plainConn, tlsConfig) 78 err = conn.HandshakeContext(ctx) 79 if err != nil { 80 plainConn.Close() 81 return nil, trace.Wrap(err) 82 } 83 84 if tlsConfig.InsecureSkipVerify { 85 return conn, nil 86 } 87 88 if err := conn.VerifyHostname(tlsConfig.ServerName); err != nil { 89 plainConn.Close() 90 return nil, trace.Wrap(err) 91 } 92 93 return conn, nil 94 }