github.com/Hyperledger-TWGC/tjfoc-gm@v1.4.0/gmtls/tls.go (about)

     1  /*
     2  Copyright Suzhou Tongji Fintech Research Institute 2017 All Rights Reserved.
     3  Licensed under the Apache License, Version 2.0 (the "License");
     4  you may not use this file except in compliance with the License.
     5  You may obtain a copy of the License at
     6  
     7  	http://www.apache.org/licenses/LICENSE-2.0
     8  
     9  Unless required by applicable law or agreed to in writing, software
    10  distributed under the License is distributed on an "AS IS" BASIS,
    11  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  See the License for the specific language governing permissions and
    13  limitations under the License.
    14  */
    15  
    16  // add sm2 support
    17  package gmtls
    18  
    19  import (
    20  	"crypto"
    21  	"crypto/ecdsa"
    22  	"crypto/rsa"
    23  	"crypto/x509"
    24  	"encoding/pem"
    25  	"errors"
    26  	"fmt"
    27  	"io/ioutil"
    28  	"net"
    29  	"strings"
    30  	"time"
    31  
    32  	"github.com/Hyperledger-TWGC/tjfoc-gm/sm2"
    33  	X "github.com/Hyperledger-TWGC/tjfoc-gm/x509"
    34  )
    35  
    36  // Server returns a new TLS server side connection
    37  // using conn as the underlying transport.
    38  // The configuration config must be non-nil and must include
    39  // at least one certificate or else set GetCertificate.
    40  func Server(conn net.Conn, config *Config) *Conn {
    41  	return &Conn{conn: conn, config: config}
    42  }
    43  
    44  // Client returns a new TLS client side connection
    45  // using conn as the underlying transport.
    46  // The config cannot be nil: users must set either ServerName or
    47  // InsecureSkipVerify in the config.
    48  func Client(conn net.Conn, config *Config) *Conn {
    49  	return &Conn{conn: conn, config: config, isClient: true}
    50  }
    51  
    52  // A listener implements a network listener (net.Listener) for TLS connections.
    53  type listener struct {
    54  	net.Listener
    55  	config *Config
    56  }
    57  
    58  // Accept waits for and returns the next incoming TLS connection.
    59  // The returned connection is of type *Conn.
    60  func (l *listener) Accept() (net.Conn, error) {
    61  	c, err := l.Listener.Accept()
    62  	if err != nil {
    63  		return nil, err
    64  	}
    65  	return Server(c, l.config), nil
    66  }
    67  
    68  // NewListener creates a Listener which accepts connections from an inner
    69  // Listener and wraps each connection with Server.
    70  // The configuration config must be non-nil and must include
    71  // at least one certificate or else set GetCertificate.
    72  func NewListener(inner net.Listener, config *Config) net.Listener {
    73  	l := new(listener)
    74  	l.Listener = inner
    75  	l.config = config
    76  	return l
    77  }
    78  
    79  // Listen creates a TLS listener accepting connections on the
    80  // given network address using net.Listen.
    81  // The configuration config must be non-nil and must include
    82  // at least one certificate or else set GetCertificate.
    83  func Listen(network, laddr string, config *Config) (net.Listener, error) {
    84  	if config == nil || (len(config.Certificates) == 0 && config.GetCertificate == nil) {
    85  		return nil, errors.New("tls: neither Certificates nor GetCertificate set in Config")
    86  	}
    87  	l, err := net.Listen(network, laddr)
    88  	if err != nil {
    89  		return nil, err
    90  	}
    91  	return NewListener(l, config), nil
    92  }
    93  
    94  type timeoutError struct{}
    95  
    96  func (timeoutError) Error() string   { return "tls: DialWithDialer timed out" }
    97  func (timeoutError) Timeout() bool   { return true }
    98  func (timeoutError) Temporary() bool { return true }
    99  
   100  // DialWithDialer connects to the given network address using dialer.Dial and
   101  // then initiates a TLS handshake, returning the resulting TLS connection. Any
   102  // timeout or deadline given in the dialer apply to connection and TLS
   103  // handshake as a whole.
   104  //
   105  // DialWithDialer interprets a nil configuration as equivalent to the zero
   106  // configuration; see the documentation of Config for the defaults.
   107  func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*Conn, error) {
   108  	// We want the Timeout and Deadline values from dialer to cover the
   109  	// whole process: TCP connection and TLS handshake. This means that we
   110  	// also need to start our own timers now.
   111  	timeout := dialer.Timeout
   112  
   113  	if !dialer.Deadline.IsZero() {
   114  		//		deadlineTimeout := time.Until(dialer.Deadline)
   115  		deadlineTimeout := dialer.Deadline.Sub(time.Now()) // support go before 1.8
   116  		if timeout == 0 || deadlineTimeout < timeout {
   117  			timeout = deadlineTimeout
   118  		}
   119  	}
   120  
   121  	var errChannel chan error
   122  
   123  	if timeout != 0 {
   124  		errChannel = make(chan error, 2)
   125  		time.AfterFunc(timeout, func() {
   126  			errChannel <- timeoutError{}
   127  		})
   128  	}
   129  
   130  	rawConn, err := dialer.Dial(network, addr)
   131  	if err != nil {
   132  		return nil, err
   133  	}
   134  
   135  	colonPos := strings.LastIndex(addr, ":")
   136  	if colonPos == -1 {
   137  		colonPos = len(addr)
   138  	}
   139  	hostname := addr[:colonPos]
   140  
   141  	if config == nil {
   142  		config = defaultConfig()
   143  	}
   144  	// If no ServerName is set, infer the ServerName
   145  	// from the hostname we're connecting to.
   146  	if config.ServerName == "" {
   147  		// Make a copy to avoid polluting argument or default.
   148  		c := config.Clone()
   149  		c.ServerName = hostname
   150  		config = c
   151  	}
   152  
   153  	conn := Client(rawConn, config)
   154  
   155  	if timeout == 0 {
   156  		err = conn.Handshake()
   157  	} else {
   158  		go func() {
   159  			errChannel <- conn.Handshake()
   160  		}()
   161  
   162  		err = <-errChannel
   163  	}
   164  
   165  	if err != nil {
   166  		rawConn.Close()
   167  		return nil, err
   168  	}
   169  
   170  	return conn, nil
   171  }
   172  
   173  // Dial connects to the given network address using net.Dial
   174  // and then initiates a TLS handshake, returning the resulting
   175  // TLS connection.
   176  // Dial interprets a nil configuration as equivalent to
   177  // the zero configuration; see the documentation of Config
   178  // for the defaults.
   179  func Dial(network, addr string, config *Config) (*Conn, error) {
   180  	return DialWithDialer(new(net.Dialer), network, addr, config)
   181  }
   182  
   183  // LoadX509KeyPair reads and parses a public/private key pair from a pair
   184  // of files. The files must contain PEM encoded data. The certificate file
   185  // may contain intermediate certificates following the leaf certificate to
   186  // form a certificate chain. On successful return, Certificate.Leaf will
   187  // be nil because the parsed form of the certificate is not retained.
   188  func LoadX509KeyPair(certFile, keyFile string) (Certificate, error) {
   189  	certPEMBlock, err := ioutil.ReadFile(certFile)
   190  	if err != nil {
   191  		return Certificate{}, err
   192  	}
   193  	keyPEMBlock, err := ioutil.ReadFile(keyFile)
   194  	if err != nil {
   195  		return Certificate{}, err
   196  	}
   197  	return X509KeyPair(certPEMBlock, keyPEMBlock)
   198  }
   199  
   200  // X509KeyPair parses a public/private key pair from a pair of
   201  // PEM encoded data. On successful return, Certificate.Leaf will be nil because
   202  // the parsed form of the certificate is not retained.
   203  func X509KeyPair(certPEMBlock, keyPEMBlock []byte) (Certificate, error) {
   204  	fail := func(err error) (Certificate, error) { return Certificate{}, err }
   205  
   206  	var cert Certificate
   207  	var skippedBlockTypes []string
   208  	for {
   209  		var certDERBlock *pem.Block
   210  		certDERBlock, certPEMBlock = pem.Decode(certPEMBlock)
   211  		if certDERBlock == nil {
   212  			break
   213  		}
   214  		if certDERBlock.Type == "CERTIFICATE" {
   215  			cert.Certificate = append(cert.Certificate, certDERBlock.Bytes)
   216  		} else {
   217  			skippedBlockTypes = append(skippedBlockTypes, certDERBlock.Type)
   218  		}
   219  	}
   220  
   221  	if len(cert.Certificate) == 0 {
   222  		if len(skippedBlockTypes) == 0 {
   223  			return fail(errors.New("tls: failed to find any PEM data in certificate input"))
   224  		}
   225  		if len(skippedBlockTypes) == 1 && strings.HasSuffix(skippedBlockTypes[0], "PRIVATE KEY") {
   226  			return fail(errors.New("tls: failed to find certificate PEM data in certificate input, but did find a private key; PEM inputs may have been switched"))
   227  		}
   228  		return fail(fmt.Errorf("tls: failed to find \"CERTIFICATE\" PEM block in certificate input after skipping PEM blocks of the following types: %v", skippedBlockTypes))
   229  	}
   230  
   231  	skippedBlockTypes = skippedBlockTypes[:0]
   232  	var keyDERBlock *pem.Block
   233  	for {
   234  		keyDERBlock, keyPEMBlock = pem.Decode(keyPEMBlock)
   235  		if keyDERBlock == nil {
   236  			if len(skippedBlockTypes) == 0 {
   237  				return fail(errors.New("tls: failed to find any PEM data in key input"))
   238  			}
   239  			if len(skippedBlockTypes) == 1 && skippedBlockTypes[0] == "CERTIFICATE" {
   240  				return fail(errors.New("tls: found a certificate rather than a key in the PEM for the private key"))
   241  			}
   242  			return fail(fmt.Errorf("tls: failed to find PEM block with type ending in \"PRIVATE KEY\" in key input after skipping PEM blocks of the following types: %v", skippedBlockTypes))
   243  		}
   244  		if keyDERBlock.Type == "PRIVATE KEY" || strings.HasSuffix(keyDERBlock.Type, " PRIVATE KEY") {
   245  			break
   246  		}
   247  		skippedBlockTypes = append(skippedBlockTypes, keyDERBlock.Type)
   248  	}
   249  
   250  	var err error
   251  	cert.PrivateKey, err = parsePrivateKey(keyDERBlock.Bytes)
   252  	if err != nil {
   253  		return fail(err)
   254  	}
   255  
   256  	// We don't need to parse the public key for TLS, but we so do anyway
   257  	// to check that it looks sane and matches the private key.
   258  	x509Cert, err := X.ParseCertificate(cert.Certificate[0])
   259  	if err != nil {
   260  		return fail(err)
   261  	}
   262  
   263  	switch pub := x509Cert.PublicKey.(type) {
   264  	case *rsa.PublicKey:
   265  		priv, ok := cert.PrivateKey.(*rsa.PrivateKey)
   266  		if !ok {
   267  			return fail(errors.New("tls: private key type does not match public key type"))
   268  		}
   269  		if pub.N.Cmp(priv.N) != 0 {
   270  			return fail(errors.New("tls: private key does not match public key"))
   271  		}
   272  	case *ecdsa.PublicKey:
   273  		pub, _ = x509Cert.PublicKey.(*ecdsa.PublicKey)
   274  		switch pub.Curve {
   275  		case sm2.P256Sm2():
   276  			priv, ok := cert.PrivateKey.(*sm2.PrivateKey)
   277  			if !ok {
   278  				return fail(errors.New("tls: sm2 private key type does not match public key type"))
   279  			}
   280  			if pub.X.Cmp(priv.X) != 0 || pub.Y.Cmp(priv.Y) != 0 {
   281  				return fail(errors.New("tls: sm2 private key does not match public key"))
   282  			}
   283  		default:
   284  			priv, ok := cert.PrivateKey.(*ecdsa.PrivateKey)
   285  			if !ok {
   286  				return fail(errors.New("tls: private key type does not match public key type"))
   287  			}
   288  			if pub.X.Cmp(priv.X) != 0 || pub.Y.Cmp(priv.Y) != 0 {
   289  				return fail(errors.New("tls: private key does not match public key"))
   290  			}
   291  		}
   292  	default:
   293  		return fail(errors.New("tls: unknown public key algorithm"))
   294  	}
   295  	return cert, nil
   296  }
   297  
   298  func parsePrivateKey(der []byte) (crypto.PrivateKey, error) {
   299  	if key, err := x509.ParsePKCS1PrivateKey(der); err == nil {
   300  		return key, nil
   301  	}
   302  	if key, err := x509.ParsePKCS8PrivateKey(der); err == nil {
   303  		switch key := key.(type) {
   304  		case *rsa.PrivateKey, *ecdsa.PrivateKey:
   305  			return key, nil
   306  		default:
   307  			return nil, errors.New("tls: found unknown private key type in PKCS#8 wrapping")
   308  		}
   309  	}
   310  	if key, err := X.ParsePKCS8UnecryptedPrivateKey(der); err == nil {
   311  		return key, nil
   312  	}
   313  	return nil, errors.New("tls: failed to parse private key")
   314  }