github.com/gravitational/teleport/api@v0.0.0-20240507183017-3110591cbafc/utils/tlsutils/tlsutils.go (about)

     1  /*
     2  Copyright 2021 Gravitational, Inc.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  // Package tlsutils contains utilities for TLS configuration and formats.
    18  package tlsutils
    19  
    20  import (
    21  	"context"
    22  	"crypto/tls"
    23  	"crypto/x509"
    24  	"encoding/pem"
    25  	"net"
    26  	"strings"
    27  
    28  	"github.com/gravitational/trace"
    29  )
    30  
    31  // ParseCertificatePEM parses PEM-encoded x509 certificate.
    32  func ParseCertificatePEM(bytes []byte) (*x509.Certificate, error) {
    33  	block, _ := pem.Decode(bytes)
    34  	if block == nil {
    35  		return nil, trace.BadParameter("expected PEM-encoded block")
    36  	}
    37  	cert, err := x509.ParseCertificate(block.Bytes)
    38  	if err != nil {
    39  		return nil, trace.BadParameter(err.Error())
    40  	}
    41  	return cert, nil
    42  }
    43  
    44  // ContextDialer represents network dialer interface that uses context
    45  type ContextDialer interface {
    46  	// DialContext is a function that dials the specified address
    47  	DialContext(ctx context.Context, network, addr string) (net.Conn, error)
    48  }
    49  
    50  // TLSDial dials and establishes TLS connection using custom dialer
    51  // is similar to tls.DialWithDialer
    52  // Note: function taken from lib/utils/tlsdial.go
    53  func TLSDial(ctx context.Context, dialer ContextDialer, network, addr string, tlsConfig *tls.Config) (*tls.Conn, error) {
    54  	if tlsConfig == nil {
    55  		return nil, trace.BadParameter("tls config must be specified")
    56  	}
    57  
    58  	plainConn, err := dialer.DialContext(ctx, network, addr)
    59  	if err != nil {
    60  		return nil, trace.Wrap(err)
    61  	}
    62  
    63  	colonPos := strings.LastIndex(addr, ":")
    64  	if colonPos == -1 {
    65  		colonPos = len(addr)
    66  	}
    67  	hostname := addr[:colonPos]
    68  
    69  	// If no ServerName is set, infer the ServerName
    70  	// from the hostname we're connecting to.
    71  	if tlsConfig.ServerName == "" {
    72  		// Make a copy to avoid polluting argument or default.
    73  		tlsConfig = tlsConfig.Clone()
    74  		tlsConfig.ServerName = hostname
    75  	}
    76  
    77  	conn := tls.Client(plainConn, tlsConfig)
    78  	err = conn.HandshakeContext(ctx)
    79  	if err != nil {
    80  		plainConn.Close()
    81  		return nil, trace.Wrap(err)
    82  	}
    83  
    84  	if tlsConfig.InsecureSkipVerify {
    85  		return conn, nil
    86  	}
    87  
    88  	if err := conn.VerifyHostname(tlsConfig.ServerName); err != nil {
    89  		plainConn.Close()
    90  		return nil, trace.Wrap(err)
    91  	}
    92  
    93  	return conn, nil
    94  }