gitee.com/zhaochuninhefei/gmgo@v0.0.31-0.20240209061119-069254a02979/gmtls/tls.go (about)

     1  // Copyright (c) 2022 zhaochun
     2  // gmgo is licensed under Mulan PSL v2.
     3  // You can use this software according to the terms and conditions of the Mulan PSL v2.
     4  // You may obtain a copy of Mulan PSL v2 at:
     5  //          http://license.coscl.org.cn/MulanPSL2
     6  // THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
     7  // See the Mulan PSL v2 for more details.
     8  
     9  /*
    10  gmtls是基于`golang/go`的`tls`包实现的国密改造版本。
    11  对应版权声明: thrid_licenses/github.com/golang/go/LICENSE
    12  */
    13  
    14  // Package gmtls partially implements TLS 1.2, as specified in RFC 5246,
    15  // and TLS 1.3, as specified in RFC 8446.
    16  package gmtls
    17  
    18  // BUG(agl): The crypto/tls package only implements some countermeasures
    19  // against Lucky13 attacks on CBC-mode encryption, and only on SHA1
    20  // variants. See http://www.isg.rhul.ac.uk/tls/TLStiming.pdf and
    21  // https://www.imperialviolet.org/2013/02/04/luckythirteen.html.
    22  
    23  import (
    24  	"bytes"
    25  	"context"
    26  	"crypto"
    27  	"crypto/ecdsa"
    28  	"crypto/ed25519"
    29  	"crypto/rsa"
    30  	"encoding/pem"
    31  	"errors"
    32  	"fmt"
    33  	"gitee.com/zhaochuninhefei/gmgo/ecdsa_ext"
    34  	"gitee.com/zhaochuninhefei/zcgolog/zclog"
    35  	"net"
    36  	"os"
    37  	"strings"
    38  
    39  	"gitee.com/zhaochuninhefei/gmgo/sm2"
    40  	"gitee.com/zhaochuninhefei/gmgo/x509"
    41  )
    42  
    43  // Server 生成tls通信Server
    44  // Server returns a new TLS server side connection
    45  // using conn as the underlying transport.
    46  // The configuration config must be non-nil and must include
    47  // at least one certificate or else set GetCertificate.
    48  func Server(conn net.Conn, config *Config) *Conn {
    49  	c := &Conn{
    50  		conn:   conn,
    51  		config: config,
    52  	}
    53  	// 绑定握手函数
    54  	c.handshakeFn = c.serverHandshake
    55  	return c
    56  }
    57  
    58  // Client 生成tls通信Client
    59  // Client returns a new TLS client side connection
    60  // using conn as the underlying transport.
    61  // The config cannot be nil: users must set either ServerName or
    62  // InsecureSkipVerify in the config.
    63  func Client(conn net.Conn, config *Config) *Conn {
    64  	c := &Conn{
    65  		conn:     conn,
    66  		config:   config,
    67  		isClient: true,
    68  	}
    69  	// 绑定握手函数
    70  	c.handshakeFn = c.clientHandshake
    71  	return c
    72  }
    73  
    74  // A listener implements a network listener (net.Listener) for TLS connections.
    75  type listener struct {
    76  	net.Listener
    77  	config *Config
    78  }
    79  
    80  // Accept waits for and returns the next incoming TLS connection.
    81  // The returned connection is of type *Conn.
    82  func (l *listener) Accept() (net.Conn, error) {
    83  	c, err := l.Listener.Accept()
    84  	if err != nil {
    85  		return nil, err
    86  	}
    87  	return Server(c, l.config), nil
    88  }
    89  
    90  // NewListener creates a Listener which accepts connections from an inner
    91  // Listener and wraps each connection with Server.
    92  // The configuration config must be non-nil and must include
    93  // at least one certificate or else set GetCertificate.
    94  func NewListener(inner net.Listener, config *Config) net.Listener {
    95  	l := new(listener)
    96  	l.Listener = inner
    97  	l.config = config
    98  	return l
    99  }
   100  
   101  // Listen creates a TLS listener accepting connections on the
   102  // given network address using net.Listen.
   103  // The configuration config must be non-nil and must include
   104  // at least one certificate or else set GetCertificate.
   105  func Listen(network, laddr string, config *Config) (net.Listener, error) {
   106  	if config == nil || len(config.Certificates) == 0 &&
   107  		config.GetCertificate == nil && config.GetConfigForClient == nil {
   108  		return nil, errors.New("gmtls: neither Certificates, GetCertificate, nor GetConfigForClient set in Config")
   109  	}
   110  	l, err := net.Listen(network, laddr)
   111  	if err != nil {
   112  		return nil, err
   113  	}
   114  	return NewListener(l, config), nil
   115  }
   116  
   117  // type timeoutError struct{}
   118  
   119  // func (timeoutError) Error() string   { return "gmtls: DialWithDialer timed out" }
   120  // func (timeoutError) Timeout() bool   { return true }
   121  // func (timeoutError) Temporary() bool { return true }
   122  
   123  // DialWithDialer connects to the given network address using dialer.Dial and
   124  // then initiates a TLS handshake, returning the resulting TLS connection. Any
   125  // timeout or deadline given in the dialer apply to connection and TLS
   126  // handshake as a whole.
   127  //
   128  // DialWithDialer interprets a nil configuration as equivalent to the zero
   129  // configuration; see the documentation of Config for the defaults.
   130  //
   131  // DialWithDialer uses context.Background internally; to specify the context,
   132  // use Dialer.DialContext with NetDialer set to the desired dialer.
   133  func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*Conn, error) {
   134  	return dial(context.Background(), dialer, network, addr, config)
   135  }
   136  
   137  // 客户端拨号,发起tls通信请求
   138  func dial(ctx context.Context, netDialer *net.Dialer, network, addr string, config *Config) (*Conn, error) {
   139  	if netDialer.Timeout != 0 {
   140  		var cancel context.CancelFunc
   141  		ctx, cancel = context.WithTimeout(ctx, netDialer.Timeout)
   142  		defer cancel()
   143  	}
   144  
   145  	if !netDialer.Deadline.IsZero() {
   146  		var cancel context.CancelFunc
   147  		ctx, cancel = context.WithDeadline(ctx, netDialer.Deadline)
   148  		defer cancel()
   149  	}
   150  
   151  	rawConn, err := netDialer.DialContext(ctx, network, addr)
   152  	if err != nil {
   153  		return nil, err
   154  	}
   155  
   156  	colonPos := strings.LastIndex(addr, ":")
   157  	if colonPos == -1 {
   158  		colonPos = len(addr)
   159  	}
   160  	hostname := addr[:colonPos]
   161  
   162  	if config == nil {
   163  		config = defaultConfig()
   164  	}
   165  	// If no ServerName is set, infer the ServerName
   166  	// from the hostname we're connecting to.
   167  	if config.ServerName == "" {
   168  		// Make a copy to avoid polluting argument or default.
   169  		c := config.Clone()
   170  		c.ServerName = hostname
   171  		config = c
   172  	}
   173  
   174  	conn := Client(rawConn, config)
   175  	// 客户端发起tls握手
   176  	if err := conn.HandshakeContext(ctx); err != nil {
   177  		_ = rawConn.Close()
   178  		return nil, err
   179  	}
   180  	return conn, nil
   181  }
   182  
   183  // Dial connects to the given network address using net.Dial
   184  // and then initiates a TLS handshake, returning the resulting
   185  // TLS connection.
   186  // Dial interprets a nil configuration as equivalent to
   187  // the zero configuration; see the documentation of Config
   188  // for the defaults.
   189  func Dial(network, addr string, config *Config) (*Conn, error) {
   190  	return DialWithDialer(new(net.Dialer), network, addr, config)
   191  }
   192  
   193  // Dialer dials TLS connections given a configuration and a Dialer for the
   194  // underlying connection.
   195  type Dialer struct {
   196  	// NetDialer is the optional dialer to use for the TLS connections'
   197  	// underlying TCP connections.
   198  	// A nil NetDialer is equivalent to the net.Dialer zero value.
   199  	NetDialer *net.Dialer
   200  
   201  	// Config is the TLS configuration to use for new connections.
   202  	// A nil configuration is equivalent to the zero
   203  	// configuration; see the documentation of Config for the
   204  	// defaults.
   205  	Config *Config
   206  }
   207  
   208  // Dial connects to the given network address and initiates a TLS
   209  // handshake, returning the resulting TLS connection.
   210  //
   211  // The returned Conn, if any, will always be of type *Conn.
   212  //
   213  // Dial uses context.Background internally; to specify the context,
   214  // use DialContext.
   215  func (d *Dialer) Dial(network, addr string) (net.Conn, error) {
   216  	return d.DialContext(context.Background(), network, addr)
   217  }
   218  
   219  func (d *Dialer) netDialer() *net.Dialer {
   220  	if d.NetDialer != nil {
   221  		return d.NetDialer
   222  	}
   223  	return new(net.Dialer)
   224  }
   225  
   226  // DialContext connects to the given network address and initiates a TLS
   227  // handshake, returning the resulting TLS connection.
   228  //
   229  // The provided Context must be non-nil. If the context expires before
   230  // the connection is complete, an error is returned. Once successfully
   231  // connected, any expiration of the context will not affect the
   232  // connection.
   233  //
   234  // The returned Conn, if any, will always be of type *Conn.
   235  func (d *Dialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
   236  	c, err := dial(ctx, d.netDialer(), network, addr, d.Config)
   237  	if err != nil {
   238  		// Don't return c (a typed nil) in an interface.
   239  		return nil, err
   240  	}
   241  	return c, nil
   242  }
   243  
   244  // LoadX509KeyPair reads and parses a public/private key pair from a pair
   245  // of files. The files must contain PEM encoded data. The certificate file
   246  // may contain intermediate certificates following the leaf certificate to
   247  // form a certificate chain. On successful return, Certificate.Leaf will
   248  // be nil because the parsed form of the certificate is not retained.
   249  func LoadX509KeyPair(certFile, keyFile string) (Certificate, error) {
   250  	certPEMBlock, err := os.ReadFile(certFile)
   251  	if err != nil {
   252  		return Certificate{}, err
   253  	}
   254  	keyPEMBlock, err := os.ReadFile(keyFile)
   255  	if err != nil {
   256  		return Certificate{}, err
   257  	}
   258  	return X509KeyPair(certPEMBlock, keyPEMBlock)
   259  }
   260  
   261  // X509KeyPair parses a public/private key pair from a pair of
   262  // PEM encoded data. On successful return, Certificate.Leaf will be nil because
   263  // the parsed form of the certificate is not retained.
   264  func X509KeyPair(certPEMBlock, keyPEMBlock []byte) (Certificate, error) {
   265  	fail := func(err error) (Certificate, error) { return Certificate{}, err }
   266  
   267  	var cert Certificate
   268  	var skippedBlockTypes []string
   269  	for {
   270  		var certDERBlock *pem.Block
   271  		// 将证书PEM字节数组解码为DER字节数组
   272  		certDERBlock, certPEMBlock = pem.Decode(certPEMBlock)
   273  		if certDERBlock == nil {
   274  			break
   275  		}
   276  		if certDERBlock.Type == "CERTIFICATE" {
   277  			// 将证书DER字节数组加入证书链的证书列表
   278  			cert.Certificate = append(cert.Certificate, certDERBlock.Bytes)
   279  		} else {
   280  			skippedBlockTypes = append(skippedBlockTypes, certDERBlock.Type)
   281  		}
   282  	}
   283  
   284  	if len(cert.Certificate) == 0 {
   285  		if len(skippedBlockTypes) == 0 {
   286  			return fail(errors.New("gmtls: failed to find any PEM data in certificate input"))
   287  		}
   288  		if len(skippedBlockTypes) == 1 && strings.HasSuffix(skippedBlockTypes[0], "PRIVATE KEY") {
   289  			return fail(errors.New("gmtls: failed to find certificate PEM data in certificate input, but did find a private key; PEM inputs may have been switched"))
   290  		}
   291  		return fail(fmt.Errorf("gmtls: failed to find \"CERTIFICATE\" PEM block in certificate input after skipping PEM blocks of the following types: %v", skippedBlockTypes))
   292  	}
   293  
   294  	skippedBlockTypes = skippedBlockTypes[:0]
   295  	var keyDERBlock *pem.Block
   296  	for {
   297  		keyDERBlock, keyPEMBlock = pem.Decode(keyPEMBlock)
   298  		if keyDERBlock == nil {
   299  			if len(skippedBlockTypes) == 0 {
   300  				return fail(errors.New("gmtls: failed to find any PEM data in key input"))
   301  			}
   302  			if len(skippedBlockTypes) == 1 && skippedBlockTypes[0] == "CERTIFICATE" {
   303  				return fail(errors.New("gmtls: found a certificate rather than a key in the PEM for the private key"))
   304  			}
   305  			return fail(fmt.Errorf("gmtls: failed to find PEM block with type ending in \"PRIVATE KEY\" in key input after skipping PEM blocks of the following types: %v", skippedBlockTypes))
   306  		}
   307  		if keyDERBlock.Type == "PRIVATE KEY" || strings.HasSuffix(keyDERBlock.Type, " PRIVATE KEY") {
   308  			break
   309  		}
   310  		skippedBlockTypes = append(skippedBlockTypes, keyDERBlock.Type)
   311  	}
   312  	// 读取证书链中的首个证书(子证书),转为x509.Certificate
   313  	// We don't need to parse the public key for TLS, but we so do anyway
   314  	// to check that it looks sane and matches the private key.
   315  	x509Cert, err := x509.ParseCertificate(cert.Certificate[0])
   316  	if err != nil {
   317  		return fail(err)
   318  	}
   319  
   320  	var signatures []SignatureScheme
   321  	zclog.Debugf("x509Cert.SignatureAlgorithm: %s", x509Cert.SignatureAlgorithm.String())
   322  	switch x509Cert.SignatureAlgorithm {
   323  	case x509.SM2WithSM3:
   324  		signatures = append(signatures, SM2WITHSM3)
   325  	case x509.ECDSAWithSHA256:
   326  		signatures = append(signatures, ECDSAWithP256AndSHA256)
   327  	case x509.ECDSAWithSHA384:
   328  		signatures = append(signatures, ECDSAWithP384AndSHA384)
   329  	case x509.ECDSAWithSHA512:
   330  		signatures = append(signatures, ECDSAWithP521AndSHA512)
   331  	case x509.ECDSAEXTWithSHA256:
   332  		signatures = append(signatures, ECDSAEXTWithP256AndSHA256)
   333  	case x509.ECDSAEXTWithSHA384:
   334  		signatures = append(signatures, ECDSAEXTWithP384AndSHA384)
   335  	case x509.ECDSAEXTWithSHA512:
   336  		signatures = append(signatures, ECDSAEXTWithP521AndSHA512)
   337  	}
   338  	if len(signatures) > 0 {
   339  		cert.SupportedSignatureAlgorithms = signatures
   340  	}
   341  	zclog.Debugf("cert.SupportedSignatureAlgorithms: %s", cert.SupportedSignatureAlgorithms)
   342  
   343  	// 将key的DER字节数组转为私钥
   344  	cert.PrivateKey, err = parsePrivateKey(keyDERBlock.Bytes)
   345  	if err != nil {
   346  		return fail(err)
   347  	}
   348  	// ECDSA_EXT私钥特殊处理
   349  	if keyDERBlock.Type == "ECDSA_EXT PRIVATE KEY" {
   350  		if privKey, ok := cert.PrivateKey.(*ecdsa.PrivateKey); ok {
   351  			cert.PrivateKey = &ecdsa_ext.PrivateKey{
   352  				PrivateKey: *privKey,
   353  			}
   354  			zclog.Debugln("读取到ECDSA_EXT PRIVATE KEY,并转为ecdsa_ext.PrivateKey")
   355  			hasEcdsaExt := false
   356  			for _, algorithm := range cert.SupportedSignatureAlgorithms {
   357  				if algorithm == ECDSAEXTWithP256AndSHA256 ||
   358  					algorithm == ECDSAEXTWithP384AndSHA384 ||
   359  					algorithm == ECDSAEXTWithP521AndSHA512 {
   360  					hasEcdsaExt = true
   361  					break
   362  				}
   363  			}
   364  			if !hasEcdsaExt {
   365  				// 临时对应,解决SupportedSignatureAlgorithms在ecdsa_ext时可能不正确的问题
   366  				cert.SupportedSignatureAlgorithms = []SignatureScheme{ECDSAEXTWithP256AndSHA256}
   367  				zclog.Debugf("临时修改cert.SupportedSignatureAlgorithms为: %s", cert.SupportedSignatureAlgorithms)
   368  			}
   369  		} else if _, ok := cert.PrivateKey.(*ecdsa_ext.PrivateKey); ok {
   370  			// ok
   371  		} else {
   372  			return fail(errors.New("pem文件类型为`ECDSA_EXT PRIVATE KEY`, 但证书中的私钥类型不是*ecdsa.PrivateKey"))
   373  		}
   374  	}
   375  	// 检查私钥与证书中的公钥是否匹配
   376  	switch pub := x509Cert.PublicKey.(type) {
   377  	// 补充SM2分支
   378  	case *sm2.PublicKey:
   379  		priv, ok := cert.PrivateKey.(*sm2.PrivateKey)
   380  		if !ok {
   381  			return fail(errors.New("gmtls: private key type does not match public key type"))
   382  		}
   383  		if pub.X.Cmp(priv.X) != 0 || pub.Y.Cmp(priv.Y) != 0 {
   384  			return fail(errors.New("gmtls: private key does not match public key"))
   385  		}
   386  	case *rsa.PublicKey:
   387  		priv, ok := cert.PrivateKey.(*rsa.PrivateKey)
   388  		if !ok {
   389  			return fail(errors.New("gmtls: private key type does not match public key type"))
   390  		}
   391  		if pub.N.Cmp(priv.N) != 0 {
   392  			return fail(errors.New("gmtls: private key does not match public key"))
   393  		}
   394  	case *ecdsa.PublicKey:
   395  		priv, ok := cert.PrivateKey.(*ecdsa.PrivateKey)
   396  		if !ok {
   397  			privExt, okExt := cert.PrivateKey.(*ecdsa_ext.PrivateKey)
   398  			if !okExt {
   399  				return fail(errors.New("gmtls: private key type does not match public key type"))
   400  			}
   401  			if pub.X.Cmp(privExt.X) != 0 || pub.Y.Cmp(privExt.Y) != 0 {
   402  				return fail(errors.New("gmtls: private key does not match public key"))
   403  			}
   404  		} else {
   405  			if pub.X.Cmp(priv.X) != 0 || pub.Y.Cmp(priv.Y) != 0 {
   406  				return fail(errors.New("gmtls: private key does not match public key"))
   407  			}
   408  		}
   409  	case *ecdsa_ext.PublicKey:
   410  		priv, ok := cert.PrivateKey.(*ecdsa_ext.PrivateKey)
   411  		if !ok {
   412  			return fail(errors.New("gmtls: private key type does not match public key type"))
   413  		}
   414  		if pub.X.Cmp(priv.X) != 0 || pub.Y.Cmp(priv.Y) != 0 {
   415  			return fail(errors.New("gmtls: private key does not match public key"))
   416  		}
   417  	case ed25519.PublicKey:
   418  		priv, ok := cert.PrivateKey.(ed25519.PrivateKey)
   419  		if !ok {
   420  			return fail(errors.New("gmtls: private key type does not match public key type"))
   421  		}
   422  		if !bytes.Equal(priv.Public().(ed25519.PublicKey), pub) {
   423  			return fail(errors.New("gmtls: private key does not match public key"))
   424  		}
   425  	default:
   426  		return fail(errors.New("gmtls: unknown public key algorithm"))
   427  	}
   428  
   429  	return cert, nil
   430  }
   431  
   432  // 将DER字节数组转为对应的私钥
   433  // Attempt to parse the given private key DER block. OpenSSL 0.9.8 generates
   434  // PKCS #1 private keys by default, while OpenSSL 1.0.0 generates PKCS #8 keys.
   435  // OpenSSL ecparam generates SEC1 EC private keys for ECDSA. We try all three.
   436  func parsePrivateKey(der []byte) (crypto.PrivateKey, error) {
   437  	if key, err := x509.ParsePKCS1PrivateKey(der); err == nil {
   438  		return key, nil
   439  	}
   440  	if key, err := x509.ParsePKCS8PrivateKey(der); err == nil {
   441  		switch key := key.(type) {
   442  		// 添加SM2, ecdsa_ext
   443  		case *sm2.PrivateKey, *rsa.PrivateKey, *ecdsa.PrivateKey, *ecdsa_ext.PrivateKey, ed25519.PrivateKey:
   444  			return key, nil
   445  		default:
   446  			return nil, errors.New("gmtls: found unknown private key type in PKCS#8 wrapping")
   447  		}
   448  	}
   449  	if key, err := x509.ParseECPrivateKey(der); err == nil {
   450  		return key, nil
   451  	}
   452  
   453  	return nil, errors.New("gmtls: failed to parse private key")
   454  }
   455  
   456  // NewServerConfigByClientHello 根据客户端发出的ClientHello的协议与密码套件决定Server的证书链
   457  //  当客户端支持tls1.3或gmssl,且客户端支持的密码套件包含 TLS_SM4_GCM_SM3 时,服务端证书采用gmSigCert。
   458  //  - gmSigCert 国密证书
   459  //  - genericCert 一般证书
   460  //goland:noinspection GoUnusedExportedFunction
   461  func NewServerConfigByClientHello(gmSigCert, genericCert *Certificate) (*Config, error) {
   462  	// 根据ClientHelloInfo中支持的协议,返回服务端证书
   463  	fncGetSignCertKeypair := func(info *ClientHelloInfo) (*Certificate, error) {
   464  		gmFlag := false
   465  		// 检查客户端支持的协议中是否包含TLS1.3或GMSSL
   466  		for _, v := range info.SupportedVersions {
   467  			if v == VersionGMSSL || v == VersionTLS13 {
   468  				for _, curveID := range info.SupportedCurves {
   469  					if curveID == Curve256Sm2 {
   470  						gmFlag = true
   471  						break
   472  					}
   473  				}
   474  				if gmFlag {
   475  					break
   476  				}
   477  				// 检查客户端支持的密码套件是否包含 TLS_SM4_GCM_SM3
   478  				for _, c := range info.CipherSuites {
   479  					if c == TLS_SM4_GCM_SM3 {
   480  						gmFlag = true
   481  						break
   482  					}
   483  				}
   484  				break
   485  			}
   486  		}
   487  		if gmFlag {
   488  			return gmSigCert, nil
   489  		} else {
   490  			return genericCert, nil
   491  		}
   492  	}
   493  
   494  	return &Config{
   495  		Certificates:   nil,
   496  		GetCertificate: fncGetSignCertKeypair,
   497  	}, nil
   498  }
   499  
   500  //func NewServerConfigByClientHelloCurve(certMap map[string]*Certificate) (*Config, error) {
   501  //	// 根据ClientHelloInfo中支持的协议,返回服务端证书
   502  //	fncGetSignCertKeypair := func(info *ClientHelloInfo) (*Certificate, error) {
   503  //		//info.config.CurvePreferences
   504  //
   505  //		gmFlag := false
   506  //		// 检查客户端支持的协议中是否包含TLS1.3或GMSSL
   507  //		for _, v := range info.SupportedVersions {
   508  //			if v == VersionGMSSL || v == VersionTLS13 {
   509  //				// 检查客户端支持的密码套件是否包含 TLS_SM4_GCM_SM3
   510  //				for _, c := range info.CipherSuites {
   511  //					if c == TLS_SM4_GCM_SM3 {
   512  //						gmFlag = true
   513  //						break
   514  //					}
   515  //				}
   516  //				break
   517  //			}
   518  //		}
   519  //		if gmFlag {
   520  //			return gmSigCert, nil
   521  //		} else {
   522  //			return genericCert, nil
   523  //		}
   524  //	}
   525  //
   526  //	return &Config{
   527  //		Certificates:   nil,
   528  //		GetCertificate: fncGetSignCertKeypair,
   529  //	}, nil
   530  //}