github.com/hellobchain/newcryptosm@v0.0.0-20221019060107-edb949a317e9/tls/tls.go (about)

     1  /*
     2  Copyright Suzhou Tongji Fintech Research Institute 2017 All Rights Reserved.
     3  Licensed under the Apache License, Version 2.0 (the "License");
     4  you may not use this file except in compliance with the License.
     5  You may obtain a copy of the License at
     6  
     7  	http://www.apache.org/licenses/LICENSE-2.0
     8  
     9  Unless required by applicable law or agreed to in writing, software
    10  distributed under the License is distributed on an "AS IS" BASIS,
    11  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  See the License for the specific language governing permissions and
    13  limitations under the License.
    14  */
    15  
// Package tls implements TLS with added SM2 support.
    17  package tls
    18  
    19  import (
    20  	"crypto"
    21  	"crypto/rsa"
    22  	"encoding/pem"
    23  	"errors"
    24  	"fmt"
    25  	"github.com/hellobchain/newcryptosm/ecdsa"
    26  	x5092 "github.com/hellobchain/newcryptosm/x509"
    27  	"io/ioutil"
    28  	"net"
    29  	"strings"
    30  	"time"
    31  )
    32  
    33  // Server returns a new TLS server side connection
    34  // using conn as the underlying transport.
    35  // The configuration config must be non-nil and must include
    36  // at least one certificate or else set GetCertificate.
    37  func Server(conn net.Conn, config *Config) *Conn {
    38  	return &Conn{conn: conn, config: config}
    39  }
    40  
    41  // Client returns a new TLS client side connection
    42  // using conn as the underlying transport.
    43  // The config cannot be nil: users must set either ServerName or
    44  // InsecureSkipVerify in the config.
    45  func Client(conn net.Conn, config *Config) *Conn {
    46  	return &Conn{conn: conn, config: config, isClient: true}
    47  }
    48  
    49  // A listener implements a network listener (net.Listener) for TLS connections.
    50  type listener struct {
    51  	net.Listener
    52  	config *Config
    53  }
    54  
    55  // Accept waits for and returns the next incoming TLS connection.
    56  // The returned connection is of type *Conn.
    57  func (l *listener) Accept() (net.Conn, error) {
    58  	c, err := l.Listener.Accept()
    59  	if err != nil {
    60  		return nil, err
    61  	}
    62  	return Server(c, l.config), nil
    63  }
    64  
    65  // NewListener creates a Listener which accepts connections from an inner
    66  // Listener and wraps each connection with Server.
    67  // The configuration config must be non-nil and must include
    68  // at least one certificate or else set GetCertificate.
    69  func NewListener(inner net.Listener, config *Config) net.Listener {
    70  	l := new(listener)
    71  	l.Listener = inner
    72  	l.config = config
    73  	return l
    74  }
    75  
    76  // Listen creates a TLS listener accepting connections on the
    77  // given network address using net.Listen.
    78  // The configuration config must be non-nil and must include
    79  // at least one certificate or else set GetCertificate.
    80  func Listen(network, laddr string, config *Config) (net.Listener, error) {
    81  	if config == nil || (len(config.Certificates) == 0 && config.GetCertificate == nil) {
    82  		return nil, errors.New("tls: neither Certificates nor GetCertificate set in Config")
    83  	}
    84  	l, err := net.Listen(network, laddr)
    85  	if err != nil {
    86  		return nil, err
    87  	}
    88  	return NewListener(l, config), nil
    89  }
    90  
    91  type timeoutError struct{}
    92  
    93  func (timeoutError) Error() string   { return "tls: DialWithDialer timed out" }
    94  func (timeoutError) Timeout() bool   { return true }
    95  func (timeoutError) Temporary() bool { return true }
    96  
    97  // DialWithDialer connects to the given network address using dialer.Dial and
    98  // then initiates a TLS handshake, returning the resulting TLS connection. Any
    99  // timeout or deadline given in the dialer apply to connection and TLS
   100  // handshake as a whole.
   101  //
   102  // DialWithDialer interprets a nil configuration as equivalent to the zero
   103  // configuration; see the documentation of Config for the defaults.
   104  func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*Conn, error) {
   105  	// We want the Timeout and Deadline values from dialer to cover the
   106  	// whole process: TCP connection and TLS handshake. This means that we
   107  	// also need to start our own timers now.
   108  	timeout := dialer.Timeout
   109  
   110  	if !dialer.Deadline.IsZero() {
   111  		//		deadlineTimeout := time.Until(dialer.Deadline)
   112  		deadlineTimeout := dialer.Deadline.Sub(time.Now()) // support go before 1.8
   113  		if timeout == 0 || deadlineTimeout < timeout {
   114  			timeout = deadlineTimeout
   115  		}
   116  	}
   117  
   118  	var errChannel chan error
   119  
   120  	if timeout != 0 {
   121  		errChannel = make(chan error, 2)
   122  		time.AfterFunc(timeout, func() {
   123  			errChannel <- timeoutError{}
   124  		})
   125  	}
   126  
   127  	rawConn, err := dialer.Dial(network, addr)
   128  	if err != nil {
   129  		return nil, err
   130  	}
   131  
   132  	colonPos := strings.LastIndex(addr, ":")
   133  	if colonPos == -1 {
   134  		colonPos = len(addr)
   135  	}
   136  	hostname := addr[:colonPos]
   137  
   138  	if config == nil {
   139  		config = defaultConfig()
   140  	}
   141  	// If no ServerName is set, infer the ServerName
   142  	// from the hostname we're connecting to.
   143  	if config.ServerName == "" {
   144  		// Make a copy to avoid polluting argument or default.
   145  		c := config.Clone()
   146  		c.ServerName = hostname
   147  		config = c
   148  	}
   149  
   150  	conn := Client(rawConn, config)
   151  
   152  	if timeout == 0 {
   153  		err = conn.Handshake()
   154  	} else {
   155  		go func() {
   156  			errChannel <- conn.Handshake()
   157  		}()
   158  
   159  		err = <-errChannel
   160  	}
   161  
   162  	if err != nil {
   163  		rawConn.Close()
   164  		return nil, err
   165  	}
   166  
   167  	return conn, nil
   168  }
   169  
   170  // Dial connects to the given network address using net.Dial
   171  // and then initiates a TLS handshake, returning the resulting
   172  // TLS connection.
   173  // Dial interprets a nil configuration as equivalent to
   174  // the zero configuration; see the documentation of Config
   175  // for the defaults.
   176  func Dial(network, addr string, config *Config) (*Conn, error) {
   177  	return DialWithDialer(new(net.Dialer), network, addr, config)
   178  }
   179  
   180  // LoadX509KeyPair reads and parses a public/private key pair from a pair
   181  // of files. The files must contain PEM encoded data. The certificate file
   182  // may contain intermediate certificates following the leaf certificate to
   183  // form a certificate chain. On successful return, Certificate.Leaf will
   184  // be nil because the parsed form of the certificate is not retained.
   185  func LoadX509KeyPair(certFile, keyFile string) (Certificate, error) {
   186  	certPEMBlock, err := ioutil.ReadFile(certFile)
   187  	if err != nil {
   188  		return Certificate{}, err
   189  	}
   190  	keyPEMBlock, err := ioutil.ReadFile(keyFile)
   191  	if err != nil {
   192  		return Certificate{}, err
   193  	}
   194  	return X509KeyPair(certPEMBlock, keyPEMBlock)
   195  }
   196  
   197  // X509KeyPair parses a public/private key pair from a pair of
   198  // PEM encoded data. On successful return, Certificate.Leaf will be nil because
   199  // the parsed form of the certificate is not retained.
   200  func X509KeyPair(certPEMBlock, keyPEMBlock []byte) (Certificate, error) {
   201  	fail := func(err error) (Certificate, error) { return Certificate{}, err }
   202  
   203  	var cert Certificate
   204  	var skippedBlockTypes []string
   205  	for {
   206  		var certDERBlock *pem.Block
   207  		certDERBlock, certPEMBlock = pem.Decode(certPEMBlock)
   208  		if certDERBlock == nil {
   209  			break
   210  		}
   211  		if certDERBlock.Type == "CERTIFICATE" {
   212  			cert.Certificate = append(cert.Certificate, certDERBlock.Bytes)
   213  		} else {
   214  			skippedBlockTypes = append(skippedBlockTypes, certDERBlock.Type)
   215  		}
   216  	}
   217  
   218  	if len(cert.Certificate) == 0 {
   219  		if len(skippedBlockTypes) == 0 {
   220  			return fail(errors.New("tls: failed to find any PEM data in certificate input"))
   221  		}
   222  		if len(skippedBlockTypes) == 1 && strings.HasSuffix(skippedBlockTypes[0], "PRIVATE KEY") {
   223  			return fail(errors.New("tls: failed to find certificate PEM data in certificate input, but did find a private key; PEM inputs may have been switched"))
   224  		}
   225  		return fail(fmt.Errorf("tls: failed to find \"CERTIFICATE\" PEM block in certificate input after skipping PEM blocks of the following types: %v", skippedBlockTypes))
   226  	}
   227  
   228  	skippedBlockTypes = skippedBlockTypes[:0]
   229  	var keyDERBlock *pem.Block
   230  	for {
   231  		keyDERBlock, keyPEMBlock = pem.Decode(keyPEMBlock)
   232  		if keyDERBlock == nil {
   233  			if len(skippedBlockTypes) == 0 {
   234  				return fail(errors.New("tls: failed to find any PEM data in key input"))
   235  			}
   236  			if len(skippedBlockTypes) == 1 && skippedBlockTypes[0] == "CERTIFICATE" {
   237  				return fail(errors.New("tls: found a certificate rather than a key in the PEM for the private key"))
   238  			}
   239  			return fail(fmt.Errorf("tls: failed to find PEM block with type ending in \"PRIVATE KEY\" in key input after skipping PEM blocks of the following types: %v", skippedBlockTypes))
   240  		}
   241  		if keyDERBlock.Type == "PRIVATE KEY" || strings.HasSuffix(keyDERBlock.Type, " PRIVATE KEY") {
   242  			break
   243  		}
   244  		skippedBlockTypes = append(skippedBlockTypes, keyDERBlock.Type)
   245  	}
   246  
   247  	var err error
   248  	cert.PrivateKey, err = parsePrivateKey(keyDERBlock.Bytes)
   249  	if err != nil {
   250  		return fail(err)
   251  	}
   252  
   253  	// We don't need to parse the public key for TLS, but we so do anyway
   254  	// to check that it looks sane and matches the private key.
   255  	x509Cert, err := x5092.ParseCertificate(cert.Certificate[0])
   256  	if err != nil {
   257  		return fail(err)
   258  	}
   259  
   260  	switch pub := x509Cert.PublicKey.(type) {
   261  	case *rsa.PublicKey:
   262  		priv, ok := cert.PrivateKey.(*rsa.PrivateKey)
   263  		if !ok {
   264  			return fail(errors.New("tls: private key type does not match public key type"))
   265  		}
   266  		if pub.N.Cmp(priv.N) != 0 {
   267  			return fail(errors.New("tls: private key does not match public key"))
   268  		}
   269  	case *ecdsa.PublicKey:
   270  		priv, ok := cert.PrivateKey.(*ecdsa.PrivateKey)
   271  		if !ok {
   272  			return fail(errors.New("tls: ecdsa sm2 private key type does not match public key type"))
   273  		}
   274  		if pub.X.Cmp(priv.X) != 0 || pub.Y.Cmp(priv.Y) != 0 {
   275  			return fail(errors.New("tls: ecdsa sm2 private key does not match public key"))
   276  		}
   277  	default:
   278  		return fail(errors.New("tls: unknown public key algorithm"))
   279  	}
   280  	return cert, nil
   281  }
   282  
   283  func parsePrivateKey(der []byte) (crypto.PrivateKey, error) {
   284  	if key, err := x5092.ParsePKCS1PrivateKey(der); err == nil {
   285  		return key, nil
   286  	}
   287  	if key, err := x5092.ParsePKCS8PrivateKey(der); err == nil {
   288  		switch key := key.(type) {
   289  		case *rsa.PrivateKey, *ecdsa.PrivateKey:
   290  			return key, nil
   291  		default:
   292  			return nil, errors.New("GM tls: found unknown private key type in PKCS#8 wrapping")
   293  		}
   294  	}
   295  	if key, err := x5092.ParseECPrivateKey(der); err == nil {
   296  		return key, nil
   297  	}
   298  	return nil, errors.New("GM tls: gm failed to parse private key")
   299  }