github.com/Psiphon-Labs/tls-tris@v0.0.0-20230824155421-58bf6d336a9a/tls.go (about)

     1  // Copyright 2009 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Package tls partially implements TLS 1.2, as specified in RFC 5246.
     6  package tls
     7  
     8  // BUG(agl): The crypto/tls package only implements some countermeasures
     9  // against Lucky13 attacks on CBC-mode encryption, and only on SHA1
    10  // variants. See http://www.isg.rhul.ac.uk/tls/TLStiming.pdf and
    11  // https://www.imperialviolet.org/2013/02/04/luckythirteen.html.
    12  
    13  import (
    14  	"crypto"
    15  	"crypto/ecdsa"
    16  	"crypto/rsa"
    17  	"crypto/x509"
    18  	"encoding/pem"
    19  	"errors"
    20  	"fmt"
    21  	"io/ioutil"
    22  	"net"
    23  	"strings"
    24  	"time"
    25  )
    26  
    27  // Server returns a new TLS server side connection
    28  // using conn as the underlying transport.
    29  // The configuration config must be non-nil and must include
    30  // at least one certificate or else set GetCertificate.
    31  func Server(conn net.Conn, config *Config) *Conn {
    32  
    33  	// [Psiphon]
    34  	// Initialize traffic recording to facilitate playback in the case of
    35  	// passthrough.
    36  	if config.PassthroughAddress != "" {
    37  		conn = newRecorderConn(conn)
    38  	}
    39  
    40  	return &Conn{conn: conn, config: config}
    41  }
    42  
    43  // Client returns a new TLS client side connection
    44  // using conn as the underlying transport.
    45  // The config cannot be nil: users must set either ServerName or
    46  // InsecureSkipVerify in the config.
    47  func Client(conn net.Conn, config *Config) *Conn {
    48  	return &Conn{conn: conn, config: config, isClient: true}
    49  }
    50  
    51  // A listener implements a network listener (net.Listener) for TLS connections.
    52  type listener struct {
    53  	net.Listener
    54  	config *Config
    55  }
    56  
    57  // Accept waits for and returns the next incoming TLS connection.
    58  // The returned connection is of type *Conn.
    59  func (l *listener) Accept() (net.Conn, error) {
    60  	c, err := l.Listener.Accept()
    61  	if err != nil {
    62  		return nil, err
    63  	}
    64  	return Server(c, l.config), nil
    65  }
    66  
    67  // NewListener creates a Listener which accepts connections from an inner
    68  // Listener and wraps each connection with Server.
    69  // The configuration config must be non-nil and must include
    70  // at least one certificate or else set GetCertificate.
    71  func NewListener(inner net.Listener, config *Config) net.Listener {
    72  	l := new(listener)
    73  	l.Listener = inner
    74  	l.config = config
    75  	return l
    76  }
    77  
    78  // Listen creates a TLS listener accepting connections on the
    79  // given network address using net.Listen.
    80  // The configuration config must be non-nil and must include
    81  // at least one certificate or else set GetCertificate.
    82  func Listen(network, laddr string, config *Config) (net.Listener, error) {
    83  	if config == nil || (len(config.Certificates) == 0 && config.GetCertificate == nil) {
    84  		return nil, errors.New("tls: neither Certificates nor GetCertificate set in Config")
    85  	}
    86  	l, err := net.Listen(network, laddr)
    87  	if err != nil {
    88  		return nil, err
    89  	}
    90  	return NewListener(l, config), nil
    91  }
    92  
    93  type timeoutError struct{}
    94  
    95  func (timeoutError) Error() string   { return "tls: DialWithDialer timed out" }
    96  func (timeoutError) Timeout() bool   { return true }
    97  func (timeoutError) Temporary() bool { return true }
    98  
    99  // DialWithDialer connects to the given network address using dialer.Dial and
   100  // then initiates a TLS handshake, returning the resulting TLS connection. Any
   101  // timeout or deadline given in the dialer apply to connection and TLS
   102  // handshake as a whole.
   103  //
   104  // DialWithDialer interprets a nil configuration as equivalent to the zero
   105  // configuration; see the documentation of Config for the defaults.
   106  func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*Conn, error) {
   107  	// We want the Timeout and Deadline values from dialer to cover the
   108  	// whole process: TCP connection and TLS handshake. This means that we
   109  	// also need to start our own timers now.
   110  	timeout := dialer.Timeout
   111  
   112  	if !dialer.Deadline.IsZero() {
   113  		deadlineTimeout := time.Until(dialer.Deadline)
   114  		if timeout == 0 || deadlineTimeout < timeout {
   115  			timeout = deadlineTimeout
   116  		}
   117  	}
   118  
   119  	var errChannel chan error
   120  
   121  	if timeout != 0 {
   122  		errChannel = make(chan error, 2)
   123  		time.AfterFunc(timeout, func() {
   124  			errChannel <- timeoutError{}
   125  		})
   126  	}
   127  
   128  	rawConn, err := dialer.Dial(network, addr)
   129  	if err != nil {
   130  		return nil, err
   131  	}
   132  
   133  	colonPos := strings.LastIndex(addr, ":")
   134  	if colonPos == -1 {
   135  		colonPos = len(addr)
   136  	}
   137  	hostname := addr[:colonPos]
   138  
   139  	if config == nil {
   140  		config = defaultConfig()
   141  	}
   142  	// If no ServerName is set, infer the ServerName
   143  	// from the hostname we're connecting to.
   144  	if config.ServerName == "" {
   145  		// Make a copy to avoid polluting argument or default.
   146  		c := config.Clone()
   147  		c.ServerName = hostname
   148  		config = c
   149  	}
   150  
   151  	conn := Client(rawConn, config)
   152  
   153  	if timeout == 0 {
   154  		err = conn.Handshake()
   155  	} else {
   156  		go func() {
   157  			errChannel <- conn.Handshake()
   158  		}()
   159  
   160  		err = <-errChannel
   161  	}
   162  
   163  	if err != nil {
   164  		rawConn.Close()
   165  		return nil, err
   166  	}
   167  
   168  	return conn, nil
   169  }
   170  
   171  // Dial connects to the given network address using net.Dial
   172  // and then initiates a TLS handshake, returning the resulting
   173  // TLS connection.
   174  // Dial interprets a nil configuration as equivalent to
   175  // the zero configuration; see the documentation of Config
   176  // for the defaults.
   177  func Dial(network, addr string, config *Config) (*Conn, error) {
   178  	return DialWithDialer(new(net.Dialer), network, addr, config)
   179  }
   180  
   181  // LoadX509KeyPair reads and parses a public/private key pair from a pair
   182  // of files. The files must contain PEM encoded data. The certificate file
   183  // may contain intermediate certificates following the leaf certificate to
   184  // form a certificate chain. On successful return, Certificate.Leaf will
   185  // be nil because the parsed form of the certificate is not retained.
   186  func LoadX509KeyPair(certFile, keyFile string) (Certificate, error) {
   187  	certPEMBlock, err := ioutil.ReadFile(certFile)
   188  	if err != nil {
   189  		return Certificate{}, err
   190  	}
   191  	keyPEMBlock, err := ioutil.ReadFile(keyFile)
   192  	if err != nil {
   193  		return Certificate{}, err
   194  	}
   195  	return X509KeyPair(certPEMBlock, keyPEMBlock)
   196  }
   197  
   198  // X509KeyPair parses a public/private key pair from a pair of
   199  // PEM encoded data. On successful return, Certificate.Leaf will be nil because
   200  // the parsed form of the certificate is not retained.
   201  func X509KeyPair(certPEMBlock, keyPEMBlock []byte) (Certificate, error) {
   202  	fail := func(err error) (Certificate, error) { return Certificate{}, err }
   203  
   204  	var cert Certificate
   205  	var skippedBlockTypes []string
   206  	for {
   207  		var certDERBlock *pem.Block
   208  		certDERBlock, certPEMBlock = pem.Decode(certPEMBlock)
   209  		if certDERBlock == nil {
   210  			break
   211  		}
   212  		if certDERBlock.Type == "CERTIFICATE" {
   213  			cert.Certificate = append(cert.Certificate, certDERBlock.Bytes)
   214  		} else {
   215  			skippedBlockTypes = append(skippedBlockTypes, certDERBlock.Type)
   216  		}
   217  	}
   218  
   219  	if len(cert.Certificate) == 0 {
   220  		if len(skippedBlockTypes) == 0 {
   221  			return fail(errors.New("tls: failed to find any PEM data in certificate input"))
   222  		}
   223  		if len(skippedBlockTypes) == 1 && strings.HasSuffix(skippedBlockTypes[0], "PRIVATE KEY") {
   224  			return fail(errors.New("tls: failed to find certificate PEM data in certificate input, but did find a private key; PEM inputs may have been switched"))
   225  		}
   226  		return fail(fmt.Errorf("tls: failed to find \"CERTIFICATE\" PEM block in certificate input after skipping PEM blocks of the following types: %v", skippedBlockTypes))
   227  	}
   228  
   229  	skippedBlockTypes = skippedBlockTypes[:0]
   230  	var keyDERBlock *pem.Block
   231  	for {
   232  		keyDERBlock, keyPEMBlock = pem.Decode(keyPEMBlock)
   233  		if keyDERBlock == nil {
   234  			if len(skippedBlockTypes) == 0 {
   235  				return fail(errors.New("tls: failed to find any PEM data in key input"))
   236  			}
   237  			if len(skippedBlockTypes) == 1 && skippedBlockTypes[0] == "CERTIFICATE" {
   238  				return fail(errors.New("tls: found a certificate rather than a key in the PEM for the private key"))
   239  			}
   240  			return fail(fmt.Errorf("tls: failed to find PEM block with type ending in \"PRIVATE KEY\" in key input after skipping PEM blocks of the following types: %v", skippedBlockTypes))
   241  		}
   242  		if keyDERBlock.Type == "PRIVATE KEY" || strings.HasSuffix(keyDERBlock.Type, " PRIVATE KEY") {
   243  			break
   244  		}
   245  		skippedBlockTypes = append(skippedBlockTypes, keyDERBlock.Type)
   246  	}
   247  
   248  	var err error
   249  	cert.PrivateKey, err = parsePrivateKey(keyDERBlock.Bytes)
   250  	if err != nil {
   251  		return fail(err)
   252  	}
   253  
   254  	// We don't need to parse the public key for TLS, but we so do anyway
   255  	// to check that it looks sane and matches the private key.
   256  	x509Cert, err := x509.ParseCertificate(cert.Certificate[0])
   257  	if err != nil {
   258  		return fail(err)
   259  	}
   260  
   261  	switch pub := x509Cert.PublicKey.(type) {
   262  	case *rsa.PublicKey:
   263  		priv, ok := cert.PrivateKey.(*rsa.PrivateKey)
   264  		if !ok {
   265  			return fail(errors.New("tls: private key type does not match public key type"))
   266  		}
   267  		if pub.N.Cmp(priv.N) != 0 {
   268  			return fail(errors.New("tls: private key does not match public key"))
   269  		}
   270  	case *ecdsa.PublicKey:
   271  		priv, ok := cert.PrivateKey.(*ecdsa.PrivateKey)
   272  		if !ok {
   273  			return fail(errors.New("tls: private key type does not match public key type"))
   274  		}
   275  		if pub.X.Cmp(priv.X) != 0 || pub.Y.Cmp(priv.Y) != 0 {
   276  			return fail(errors.New("tls: private key does not match public key"))
   277  		}
   278  	default:
   279  		return fail(errors.New("tls: unknown public key algorithm"))
   280  	}
   281  
   282  	return cert, nil
   283  }
   284  
   285  // Attempt to parse the given private key DER block. OpenSSL 0.9.8 generates
   286  // PKCS#1 private keys by default, while OpenSSL 1.0.0 generates PKCS#8 keys.
   287  // OpenSSL ecparam generates SEC1 EC private keys for ECDSA. We try all three.
   288  func parsePrivateKey(der []byte) (crypto.PrivateKey, error) {
   289  	if key, err := x509.ParsePKCS1PrivateKey(der); err == nil {
   290  		return key, nil
   291  	}
   292  	if key, err := x509.ParsePKCS8PrivateKey(der); err == nil {
   293  		switch key := key.(type) {
   294  		case *rsa.PrivateKey, *ecdsa.PrivateKey:
   295  			return key, nil
   296  		default:
   297  			return nil, errors.New("tls: found unknown private key type in PKCS#8 wrapping")
   298  		}
   299  	}
   300  	if key, err := x509.ParseECPrivateKey(der); err == nil {
   301  		return key, nil
   302  	}
   303  
   304  	return nil, errors.New("tls: failed to parse private key")
   305  }