github.com/zmap/zcrypto@v0.0.0-20240512203510-0fef58d9a9db/tls/tls.go (about)

     1  // Copyright 2009 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Package tls partially implements TLS 1.2, as specified in RFC 5246.
     6  package tls
     7  
     8  import (
     9  	"crypto"
    10  	"crypto/ecdsa"
    11  	"crypto/rsa"
    12  	"encoding/pem"
    13  	"errors"
    14  	"io/ioutil"
    15  	"net"
    16  	"strings"
    17  	"time"
    18  
    19  	"github.com/zmap/zcrypto/x509"
    20  )
    21  
    22  // Server returns a new TLS server side connection
    23  // using conn as the underlying transport.
    24  // The configuration config must be non-nil and must have
    25  // at least one certificate.
    26  func Server(conn net.Conn, config *Config) *Conn {
    27  	return &Conn{conn: conn, config: config}
    28  }
    29  
    30  // Client returns a new TLS client side connection
    31  // using conn as the underlying transport.
    32  // The config cannot be nil: users must set either ServerName or
    33  // InsecureSkipVerify in the config.
    34  func Client(conn net.Conn, config *Config) *Conn {
    35  	return &Conn{conn: conn, config: config, isClient: true}
    36  }
    37  
    38  // A listener implements a network listener (net.Listener) for TLS connections.
    39  type listener struct {
    40  	net.Listener
    41  	config *Config
    42  }
    43  
    44  // Accept waits for and returns the next incoming TLS connection.
    45  // The returned connection c is a *tls.Conn.
    46  func (l *listener) Accept() (c net.Conn, err error) {
    47  	c, err = l.Listener.Accept()
    48  	if err != nil {
    49  		return
    50  	}
    51  	c = Server(c, l.config)
    52  	return
    53  }
    54  
    55  // NewListener creates a Listener which accepts connections from an inner
    56  // Listener and wraps each connection with Server.
    57  // The configuration config must be non-nil and must have
    58  // at least one certificate.
    59  func NewListener(inner net.Listener, config *Config) net.Listener {
    60  	l := new(listener)
    61  	l.Listener = inner
    62  	l.config = config
    63  	return l
    64  }
    65  
    66  // Listen creates a TLS listener accepting connections on the
    67  // given network address using net.Listen.
    68  // The configuration config must be non-nil and must have
    69  // at least one certificate.
    70  func Listen(network, laddr string, config *Config) (net.Listener, error) {
    71  	if config == nil || len(config.Certificates) == 0 {
    72  		return nil, errors.New("tls.Listen: no certificates in configuration")
    73  	}
    74  	l, err := net.Listen(network, laddr)
    75  	if err != nil {
    76  		return nil, err
    77  	}
    78  	return NewListener(l, config), nil
    79  }
    80  
    81  type timeoutError struct{}
    82  
    83  func (timeoutError) Error() string   { return "tls: DialWithDialer timed out" }
    84  func (timeoutError) Timeout() bool   { return true }
    85  func (timeoutError) Temporary() bool { return true }
    86  
    87  // DialWithDialer connects to the given network address using dialer.Dial and
    88  // then initiates a TLS handshake, returning the resulting TLS connection. Any
    89  // timeout or deadline given in the dialer apply to connection and TLS
    90  // handshake as a whole.
    91  //
    92  // DialWithDialer interprets a nil configuration as equivalent to the zero
    93  // configuration; see the documentation of Config for the defaults.
    94  func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*Conn, error) {
    95  	// We want the Timeout and Deadline values from dialer to cover the
    96  	// whole process: TCP connection and TLS handshake. This means that we
    97  	// also need to start our own timers now.
    98  	timeout := dialer.Timeout
    99  
   100  	if !dialer.Deadline.IsZero() {
   101  		deadlineTimeout := dialer.Deadline.Sub(time.Now())
   102  		if timeout == 0 || deadlineTimeout < timeout {
   103  			timeout = deadlineTimeout
   104  		}
   105  	}
   106  
   107  	var errChannel chan error
   108  
   109  	if timeout != 0 {
   110  		errChannel = make(chan error, 2)
   111  		time.AfterFunc(timeout, func() {
   112  			errChannel <- timeoutError{}
   113  		})
   114  	}
   115  
   116  	rawConn, err := dialer.Dial(network, addr)
   117  	if err != nil {
   118  		return nil, err
   119  	}
   120  
   121  	colonPos := strings.LastIndex(addr, ":")
   122  	if colonPos == -1 {
   123  		colonPos = len(addr)
   124  	}
   125  	hostname := addr[:colonPos]
   126  
   127  	if config == nil {
   128  		config = defaultConfig()
   129  	}
   130  	// If no ServerName is set, infer the ServerName
   131  	// from the hostname we're connecting to.
   132  	if config.ServerName == "" {
   133  		// Make a copy to avoid polluting argument or default.
   134  		c := *config
   135  		c.ServerName = hostname
   136  		config = &c
   137  	}
   138  
   139  	conn := Client(rawConn, config)
   140  
   141  	if timeout == 0 {
   142  		err = conn.Handshake()
   143  	} else {
   144  		go func() {
   145  			errChannel <- conn.Handshake()
   146  		}()
   147  
   148  		err = <-errChannel
   149  	}
   150  
   151  	if err != nil {
   152  		rawConn.Close()
   153  		return nil, err
   154  	}
   155  
   156  	return conn, nil
   157  }
   158  
   159  // Dial connects to the given network address using net.Dial
   160  // and then initiates a TLS handshake, returning the resulting
   161  // TLS connection.
   162  // Dial interprets a nil configuration as equivalent to
   163  // the zero configuration; see the documentation of Config
   164  // for the defaults.
   165  func Dial(network, addr string, config *Config) (*Conn, error) {
   166  	return DialWithDialer(new(net.Dialer), network, addr, config)
   167  }
   168  
   169  // LoadX509KeyPair reads and parses a public/private key pair from a pair of
   170  // files. The files must contain PEM encoded data.
   171  func LoadX509KeyPair(certFile, keyFile string) (cert Certificate, err error) {
   172  	certPEMBlock, err := ioutil.ReadFile(certFile)
   173  	if err != nil {
   174  		return
   175  	}
   176  	keyPEMBlock, err := ioutil.ReadFile(keyFile)
   177  	if err != nil {
   178  		return
   179  	}
   180  	return X509KeyPair(certPEMBlock, keyPEMBlock)
   181  }
   182  
   183  // X509KeyPair parses a public/private key pair from a pair of
   184  // PEM encoded data.
   185  func X509KeyPair(certPEMBlock, keyPEMBlock []byte) (cert Certificate, err error) {
   186  	var certDERBlock *pem.Block
   187  	for {
   188  		certDERBlock, certPEMBlock = pem.Decode(certPEMBlock)
   189  		if certDERBlock == nil {
   190  			break
   191  		}
   192  		if certDERBlock.Type == "CERTIFICATE" {
   193  			cert.Certificate = append(cert.Certificate, certDERBlock.Bytes)
   194  		}
   195  	}
   196  
   197  	if len(cert.Certificate) == 0 {
   198  		err = errors.New("crypto/tls: failed to parse certificate PEM data")
   199  		return
   200  	}
   201  
   202  	var keyDERBlock *pem.Block
   203  	for {
   204  		keyDERBlock, keyPEMBlock = pem.Decode(keyPEMBlock)
   205  		if keyDERBlock == nil {
   206  			err = errors.New("crypto/tls: failed to parse key PEM data")
   207  			return
   208  		}
   209  		if keyDERBlock.Type == "PRIVATE KEY" || strings.HasSuffix(keyDERBlock.Type, " PRIVATE KEY") {
   210  			break
   211  		}
   212  	}
   213  
   214  	cert.PrivateKey, err = parsePrivateKey(keyDERBlock.Bytes)
   215  	if err != nil {
   216  		return
   217  	}
   218  
   219  	// We don't need to parse the public key for TLS, but we so do anyway
   220  	// to check that it looks sane and matches the private key.
   221  	x509Cert, err := x509.ParseCertificate(cert.Certificate[0])
   222  	if err != nil {
   223  		return
   224  	}
   225  
   226  	switch pub := x509Cert.PublicKey.(type) {
   227  	case *rsa.PublicKey:
   228  		priv, ok := cert.PrivateKey.(*rsa.PrivateKey)
   229  		if !ok {
   230  			err = errors.New("crypto/tls: private key type does not match public key type")
   231  			return
   232  		}
   233  		if pub.N.Cmp(priv.N) != 0 {
   234  			err = errors.New("crypto/tls: private key does not match public key")
   235  			return
   236  		}
   237  	case *x509.AugmentedECDSA:
   238  		priv, ok := cert.PrivateKey.(*ecdsa.PrivateKey)
   239  		if !ok {
   240  			err = errors.New("crypto/tls: private key type does not match pub.Public key type")
   241  			return
   242  
   243  		}
   244  		if pub.Pub.X.Cmp(priv.X) != 0 || pub.Pub.Y.Cmp(priv.Y) != 0 {
   245  			err = errors.New("crypto/tls: private key does not match pub.Public key")
   246  			return
   247  		}
   248  	case *ecdsa.PublicKey:
   249  		priv, ok := cert.PrivateKey.(*ecdsa.PrivateKey)
   250  		if !ok {
   251  			err = errors.New("crypto/tls: private key type does not match public key type")
   252  			return
   253  
   254  		}
   255  		if pub.X.Cmp(priv.X) != 0 || pub.Y.Cmp(priv.Y) != 0 {
   256  			err = errors.New("crypto/tls: private key does not match public key")
   257  			return
   258  		}
   259  	default:
   260  		err = errors.New("crypto/tls: unknown public key algorithm")
   261  		return
   262  	}
   263  
   264  	return
   265  }
   266  
   267  // Attempt to parse the given private key DER block. OpenSSL 0.9.8 generates
   268  // PKCS#1 private keys by default, while OpenSSL 1.0.0 generates PKCS#8 keys.
   269  // OpenSSL ecparam generates SEC1 EC private keys for ECDSA. We try all three.
   270  func parsePrivateKey(der []byte) (crypto.PrivateKey, error) {
   271  	if key, err := x509.ParsePKCS1PrivateKey(der); err == nil {
   272  		return key, nil
   273  	}
   274  	if key, err := x509.ParsePKCS8PrivateKey(der); err == nil {
   275  		switch key := key.(type) {
   276  		case *rsa.PrivateKey, *ecdsa.PrivateKey:
   277  			return key, nil
   278  		default:
   279  			return nil, errors.New("crypto/tls: found unknown private key type in PKCS#8 wrapping")
   280  		}
   281  	}
   282  	if key, err := x509.ParseECPrivateKey(der); err == nil {
   283  		return key, nil
   284  	}
   285  
   286  	return nil, errors.New("crypto/tls: failed to parse private key")
   287  }