gitee.com/ks-custle/core-gm@v0.0.0-20230922171213-b83bdd97b62c/gmtls/tls.go (about)

     1  // Copyright (c) 2022 zhaochun
     2  // core-gm is licensed under Mulan PSL v2.
     3  // You can use this software according to the terms and conditions of the Mulan PSL v2.
     4  // You may obtain a copy of Mulan PSL v2 at:
     5  //          http://license.coscl.org.cn/MulanPSL2
     6  // THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
     7  // See the Mulan PSL v2 for more details.
     8  
     9  /*
    10  gmtls是基于`golang/go`的`tls`包实现的国密改造版本。
    11  对应版权声明: thrid_licenses/github.com/golang/go/LICENSE
    12  */
    13  
    14  // Package gmtls partially implements TLS 1.2, as specified in RFC 5246,
    15  // and TLS 1.3, as specified in RFC 8446.
    16  package gmtls
    17  
    18  // BUG(agl): The crypto/tls package only implements some countermeasures
    19  // against Lucky13 attacks on CBC-mode encryption, and only on SHA1
    20  // variants. See http://www.isg.rhul.ac.uk/tls/TLStiming.pdf and
    21  // https://www.imperialviolet.org/2013/02/04/luckythirteen.html.
    22  
    23  import (
    24  	"bytes"
    25  	"context"
    26  	"crypto"
    27  	"crypto/ecdsa"
    28  	"crypto/ed25519"
    29  	"crypto/rsa"
    30  	"encoding/pem"
    31  	"errors"
    32  	"fmt"
    33  	"gitee.com/ks-custle/core-gm/ecdsa_ext"
    34  	"net"
    35  	"os"
    36  	"strings"
    37  
    38  	"gitee.com/ks-custle/core-gm/sm2"
    39  	"gitee.com/ks-custle/core-gm/x509"
    40  )
    41  
    42  // Server 生成tls通信Server
    43  // Server returns a new TLS server side connection
    44  // using conn as the underlying transport.
    45  // The configuration config must be non-nil and must include
    46  // at least one certificate or else set GetCertificate.
    47  func Server(conn net.Conn, config *Config) *Conn {
    48  	c := &Conn{
    49  		conn:   conn,
    50  		config: config,
    51  	}
    52  	// 绑定握手函数
    53  	c.handshakeFn = c.serverHandshake
    54  	return c
    55  }
    56  
    57  // Client 生成tls通信Client
    58  // Client returns a new TLS client side connection
    59  // using conn as the underlying transport.
    60  // The config cannot be nil: users must set either ServerName or
    61  // InsecureSkipVerify in the config.
    62  func Client(conn net.Conn, config *Config) *Conn {
    63  	c := &Conn{
    64  		conn:     conn,
    65  		config:   config,
    66  		isClient: true,
    67  	}
    68  	// 绑定握手函数
    69  	c.handshakeFn = c.clientHandshake
    70  	return c
    71  }
    72  
    73  // A listener implements a network listener (net.Listener) for TLS connections.
    74  type listener struct {
    75  	net.Listener
    76  	config *Config
    77  }
    78  
    79  // Accept waits for and returns the next incoming TLS connection.
    80  // The returned connection is of type *Conn.
    81  func (l *listener) Accept() (net.Conn, error) {
    82  	c, err := l.Listener.Accept()
    83  	if err != nil {
    84  		return nil, err
    85  	}
    86  	return Server(c, l.config), nil
    87  }
    88  
    89  // NewListener creates a Listener which accepts connections from an inner
    90  // Listener and wraps each connection with Server.
    91  // The configuration config must be non-nil and must include
    92  // at least one certificate or else set GetCertificate.
    93  func NewListener(inner net.Listener, config *Config) net.Listener {
    94  	l := new(listener)
    95  	l.Listener = inner
    96  	l.config = config
    97  	return l
    98  }
    99  
   100  // Listen creates a TLS listener accepting connections on the
   101  // given network address using net.Listen.
   102  // The configuration config must be non-nil and must include
   103  // at least one certificate or else set GetCertificate.
   104  func Listen(network, laddr string, config *Config) (net.Listener, error) {
   105  	if config == nil || len(config.Certificates) == 0 &&
   106  		config.GetCertificate == nil && config.GetConfigForClient == nil {
   107  		return nil, errors.New("gmtls: neither Certificates, GetCertificate, nor GetConfigForClient set in Config")
   108  	}
   109  	l, err := net.Listen(network, laddr)
   110  	if err != nil {
   111  		return nil, err
   112  	}
   113  	return NewListener(l, config), nil
   114  }
   115  
   116  // type timeoutError struct{}
   117  
   118  // func (timeoutError) Error() string   { return "gmtls: DialWithDialer timed out" }
   119  // func (timeoutError) Timeout() bool   { return true }
   120  // func (timeoutError) Temporary() bool { return true }
   121  
   122  // DialWithDialer connects to the given network address using dialer.Dial and
   123  // then initiates a TLS handshake, returning the resulting TLS connection. Any
   124  // timeout or deadline given in the dialer apply to connection and TLS
   125  // handshake as a whole.
   126  //
   127  // DialWithDialer interprets a nil configuration as equivalent to the zero
   128  // configuration; see the documentation of Config for the defaults.
   129  //
   130  // DialWithDialer uses context.Background internally; to specify the context,
   131  // use Dialer.DialContext with NetDialer set to the desired dialer.
   132  func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*Conn, error) {
   133  	return dial(context.Background(), dialer, network, addr, config)
   134  }
   135  
   136  // 客户端拨号,发起tls通信请求
   137  func dial(ctx context.Context, netDialer *net.Dialer, network, addr string, config *Config) (*Conn, error) {
   138  	if netDialer.Timeout != 0 {
   139  		var cancel context.CancelFunc
   140  		ctx, cancel = context.WithTimeout(ctx, netDialer.Timeout)
   141  		defer cancel()
   142  	}
   143  
   144  	if !netDialer.Deadline.IsZero() {
   145  		var cancel context.CancelFunc
   146  		ctx, cancel = context.WithDeadline(ctx, netDialer.Deadline)
   147  		defer cancel()
   148  	}
   149  
   150  	rawConn, err := netDialer.DialContext(ctx, network, addr)
   151  	if err != nil {
   152  		return nil, err
   153  	}
   154  
   155  	colonPos := strings.LastIndex(addr, ":")
   156  	if colonPos == -1 {
   157  		colonPos = len(addr)
   158  	}
   159  	hostname := addr[:colonPos]
   160  
   161  	if config == nil {
   162  		config = defaultConfig()
   163  	}
   164  	// If no ServerName is set, infer the ServerName
   165  	// from the hostname we're connecting to.
   166  	if config.ServerName == "" {
   167  		// Make a copy to avoid polluting argument or default.
   168  		c := config.Clone()
   169  		c.ServerName = hostname
   170  		config = c
   171  	}
   172  
   173  	conn := Client(rawConn, config)
   174  	// 客户端发起tls握手
   175  	if err := conn.HandshakeContext(ctx); err != nil {
   176  		_ = rawConn.Close()
   177  		return nil, err
   178  	}
   179  	return conn, nil
   180  }
   181  
   182  // Dial connects to the given network address using net.Dial
   183  // and then initiates a TLS handshake, returning the resulting
   184  // TLS connection.
   185  // Dial interprets a nil configuration as equivalent to
   186  // the zero configuration; see the documentation of Config
   187  // for the defaults.
   188  func Dial(network, addr string, config *Config) (*Conn, error) {
   189  	return DialWithDialer(new(net.Dialer), network, addr, config)
   190  }
   191  
   192  // Dialer dials TLS connections given a configuration and a Dialer for the
   193  // underlying connection.
   194  type Dialer struct {
   195  	// NetDialer is the optional dialer to use for the TLS connections'
   196  	// underlying TCP connections.
   197  	// A nil NetDialer is equivalent to the net.Dialer zero value.
   198  	NetDialer *net.Dialer
   199  
   200  	// Config is the TLS configuration to use for new connections.
   201  	// A nil configuration is equivalent to the zero
   202  	// configuration; see the documentation of Config for the
   203  	// defaults.
   204  	Config *Config
   205  }
   206  
   207  // Dial connects to the given network address and initiates a TLS
   208  // handshake, returning the resulting TLS connection.
   209  //
   210  // The returned Conn, if any, will always be of type *Conn.
   211  //
   212  // Dial uses context.Background internally; to specify the context,
   213  // use DialContext.
   214  func (d *Dialer) Dial(network, addr string) (net.Conn, error) {
   215  	return d.DialContext(context.Background(), network, addr)
   216  }
   217  
   218  func (d *Dialer) netDialer() *net.Dialer {
   219  	if d.NetDialer != nil {
   220  		return d.NetDialer
   221  	}
   222  	return new(net.Dialer)
   223  }
   224  
   225  // DialContext connects to the given network address and initiates a TLS
   226  // handshake, returning the resulting TLS connection.
   227  //
   228  // The provided Context must be non-nil. If the context expires before
   229  // the connection is complete, an error is returned. Once successfully
   230  // connected, any expiration of the context will not affect the
   231  // connection.
   232  //
   233  // The returned Conn, if any, will always be of type *Conn.
   234  func (d *Dialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
   235  	c, err := dial(ctx, d.netDialer(), network, addr, d.Config)
   236  	if err != nil {
   237  		// Don't return c (a typed nil) in an interface.
   238  		return nil, err
   239  	}
   240  	return c, nil
   241  }
   242  
   243  // LoadX509KeyPair reads and parses a public/private key pair from a pair
   244  // of files. The files must contain PEM encoded data. The certificate file
   245  // may contain intermediate certificates following the leaf certificate to
   246  // form a certificate chain. On successful return, Certificate.Leaf will
   247  // be nil because the parsed form of the certificate is not retained.
   248  func LoadX509KeyPair(certFile, keyFile string) (Certificate, error) {
   249  	certPEMBlock, err := os.ReadFile(certFile)
   250  	if err != nil {
   251  		return Certificate{}, err
   252  	}
   253  	keyPEMBlock, err := os.ReadFile(keyFile)
   254  	if err != nil {
   255  		return Certificate{}, err
   256  	}
   257  	return X509KeyPair(certPEMBlock, keyPEMBlock)
   258  }
   259  
   260  // X509KeyPair parses a public/private key pair from a pair of
   261  // PEM encoded data. On successful return, Certificate.Leaf will be nil because
   262  // the parsed form of the certificate is not retained.
   263  func X509KeyPair(certPEMBlock, keyPEMBlock []byte) (Certificate, error) {
   264  	fail := func(err error) (Certificate, error) { return Certificate{}, err }
   265  
   266  	var cert Certificate
   267  	var skippedBlockTypes []string
   268  	for {
   269  		var certDERBlock *pem.Block
   270  		// 将证书PEM字节数组解码为DER字节数组
   271  		certDERBlock, certPEMBlock = pem.Decode(certPEMBlock)
   272  		if certDERBlock == nil {
   273  			break
   274  		}
   275  		if certDERBlock.Type == "CERTIFICATE" {
   276  			// 将证书DER字节数组加入证书链的证书列表
   277  			cert.Certificate = append(cert.Certificate, certDERBlock.Bytes)
   278  		} else {
   279  			skippedBlockTypes = append(skippedBlockTypes, certDERBlock.Type)
   280  		}
   281  	}
   282  
   283  	if len(cert.Certificate) == 0 {
   284  		if len(skippedBlockTypes) == 0 {
   285  			return fail(errors.New("gmtls: failed to find any PEM data in certificate input"))
   286  		}
   287  		if len(skippedBlockTypes) == 1 && strings.HasSuffix(skippedBlockTypes[0], "PRIVATE KEY") {
   288  			return fail(errors.New("gmtls: failed to find certificate PEM data in certificate input, but did find a private key; PEM inputs may have been switched"))
   289  		}
   290  		return fail(fmt.Errorf("gmtls: failed to find \"CERTIFICATE\" PEM block in certificate input after skipping PEM blocks of the following types: %v", skippedBlockTypes))
   291  	}
   292  
   293  	skippedBlockTypes = skippedBlockTypes[:0]
   294  	var keyDERBlock *pem.Block
   295  	for {
   296  		keyDERBlock, keyPEMBlock = pem.Decode(keyPEMBlock)
   297  		if keyDERBlock == nil {
   298  			if len(skippedBlockTypes) == 0 {
   299  				return fail(errors.New("gmtls: failed to find any PEM data in key input"))
   300  			}
   301  			if len(skippedBlockTypes) == 1 && skippedBlockTypes[0] == "CERTIFICATE" {
   302  				return fail(errors.New("gmtls: found a certificate rather than a key in the PEM for the private key"))
   303  			}
   304  			return fail(fmt.Errorf("gmtls: failed to find PEM block with type ending in \"PRIVATE KEY\" in key input after skipping PEM blocks of the following types: %v", skippedBlockTypes))
   305  		}
   306  		if keyDERBlock.Type == "PRIVATE KEY" || strings.HasSuffix(keyDERBlock.Type, " PRIVATE KEY") {
   307  			break
   308  		}
   309  		skippedBlockTypes = append(skippedBlockTypes, keyDERBlock.Type)
   310  	}
   311  	// 读取证书链中的首个证书(子证书),转为x509.Certificate
   312  	// We don't need to parse the public key for TLS, but we so do anyway
   313  	// to check that it looks sane and matches the private key.
   314  	x509Cert, err := x509.ParseCertificate(cert.Certificate[0])
   315  	if err != nil {
   316  		return fail(err)
   317  	}
   318  
   319  	var signatures []SignatureScheme
   320  	fmt.Println("x509Cert.SignatureAlgorithm: %s", x509Cert.SignatureAlgorithm.String())
   321  	switch x509Cert.SignatureAlgorithm {
   322  	case x509.SM2WithSM3:
   323  		signatures = append(signatures, SM2WITHSM3)
   324  	case x509.ECDSAWithSHA256:
   325  		signatures = append(signatures, ECDSAWithP256AndSHA256)
   326  	case x509.ECDSAWithSHA384:
   327  		signatures = append(signatures, ECDSAWithP384AndSHA384)
   328  	case x509.ECDSAWithSHA512:
   329  		signatures = append(signatures, ECDSAWithP521AndSHA512)
   330  	case x509.ECDSAEXTWithSHA256:
   331  		signatures = append(signatures, ECDSAEXTWithP256AndSHA256)
   332  	case x509.ECDSAEXTWithSHA384:
   333  		signatures = append(signatures, ECDSAEXTWithP384AndSHA384)
   334  	case x509.ECDSAEXTWithSHA512:
   335  		signatures = append(signatures, ECDSAEXTWithP521AndSHA512)
   336  	}
   337  	if len(signatures) > 0 {
   338  		cert.SupportedSignatureAlgorithms = signatures
   339  	}
   340  	fmt.Println("cert.SupportedSignatureAlgorithms: %s", cert.SupportedSignatureAlgorithms)
   341  
   342  	// 将key的DER字节数组转为私钥
   343  	cert.PrivateKey, err = parsePrivateKey(keyDERBlock.Bytes)
   344  	if err != nil {
   345  		return fail(err)
   346  	}
   347  	// ECDSA_EXT私钥特殊处理
   348  	if keyDERBlock.Type == "ECDSA_EXT PRIVATE KEY" {
   349  		if privKey, ok := cert.PrivateKey.(*ecdsa.PrivateKey); ok {
   350  			cert.PrivateKey = &ecdsa_ext.PrivateKey{
   351  				PrivateKey: *privKey,
   352  			}
   353  			fmt.Println("读取到ECDSA_EXT PRIVATE KEY,并转为ecdsa_ext.PrivateKey")
   354  			hasEcdsaExt := false
   355  			for _, algorithm := range cert.SupportedSignatureAlgorithms {
   356  				if algorithm == ECDSAEXTWithP256AndSHA256 ||
   357  					algorithm == ECDSAEXTWithP384AndSHA384 ||
   358  					algorithm == ECDSAEXTWithP521AndSHA512 {
   359  					hasEcdsaExt = true
   360  					break
   361  				}
   362  			}
   363  			if !hasEcdsaExt {
   364  				// 临时对应,解决SupportedSignatureAlgorithms在ecdsa_ext时可能不正确的问题
   365  				cert.SupportedSignatureAlgorithms = []SignatureScheme{ECDSAEXTWithP256AndSHA256}
   366  				fmt.Println("临时修改cert.SupportedSignatureAlgorithms为: %s", cert.SupportedSignatureAlgorithms)
   367  			}
   368  		} else if _, ok := cert.PrivateKey.(*ecdsa_ext.PrivateKey); ok {
   369  			// ok
   370  		} else {
   371  			return fail(errors.New("pem文件类型为`ECDSA_EXT PRIVATE KEY`, 但证书中的私钥类型不是*ecdsa.PrivateKey"))
   372  		}
   373  	}
   374  	// 检查私钥与证书中的公钥是否匹配
   375  	switch pub := x509Cert.PublicKey.(type) {
   376  	// 补充SM2分支
   377  	case *sm2.PublicKey:
   378  		priv, ok := cert.PrivateKey.(*sm2.PrivateKey)
   379  		if !ok {
   380  			return fail(errors.New("gmtls: private key type does not match public key type"))
   381  		}
   382  		if pub.X.Cmp(priv.X) != 0 || pub.Y.Cmp(priv.Y) != 0 {
   383  			return fail(errors.New("gmtls: private key does not match public key"))
   384  		}
   385  	case *rsa.PublicKey:
   386  		priv, ok := cert.PrivateKey.(*rsa.PrivateKey)
   387  		if !ok {
   388  			return fail(errors.New("gmtls: private key type does not match public key type"))
   389  		}
   390  		if pub.N.Cmp(priv.N) != 0 {
   391  			return fail(errors.New("gmtls: private key does not match public key"))
   392  		}
   393  	case *ecdsa.PublicKey:
   394  		priv, ok := cert.PrivateKey.(*ecdsa.PrivateKey)
   395  		if !ok {
   396  			privExt, okExt := cert.PrivateKey.(*ecdsa_ext.PrivateKey)
   397  			if !okExt {
   398  				return fail(errors.New("gmtls: private key type does not match public key type"))
   399  			}
   400  			if pub.X.Cmp(privExt.X) != 0 || pub.Y.Cmp(privExt.Y) != 0 {
   401  				return fail(errors.New("gmtls: private key does not match public key"))
   402  			}
   403  		} else {
   404  			if pub.X.Cmp(priv.X) != 0 || pub.Y.Cmp(priv.Y) != 0 {
   405  				return fail(errors.New("gmtls: private key does not match public key"))
   406  			}
   407  		}
   408  	case *ecdsa_ext.PublicKey:
   409  		priv, ok := cert.PrivateKey.(*ecdsa_ext.PrivateKey)
   410  		if !ok {
   411  			return fail(errors.New("gmtls: private key type does not match public key type"))
   412  		}
   413  		if pub.X.Cmp(priv.X) != 0 || pub.Y.Cmp(priv.Y) != 0 {
   414  			return fail(errors.New("gmtls: private key does not match public key"))
   415  		}
   416  	case ed25519.PublicKey:
   417  		priv, ok := cert.PrivateKey.(ed25519.PrivateKey)
   418  		if !ok {
   419  			return fail(errors.New("gmtls: private key type does not match public key type"))
   420  		}
   421  		if !bytes.Equal(priv.Public().(ed25519.PublicKey), pub) {
   422  			return fail(errors.New("gmtls: private key does not match public key"))
   423  		}
   424  	default:
   425  		return fail(errors.New("gmtls: unknown public key algorithm"))
   426  	}
   427  
   428  	return cert, nil
   429  }
   430  
   431  // 将DER字节数组转为对应的私钥
   432  // Attempt to parse the given private key DER block. OpenSSL 0.9.8 generates
   433  // PKCS #1 private keys by default, while OpenSSL 1.0.0 generates PKCS #8 keys.
   434  // OpenSSL ecparam generates SEC1 EC private keys for ECDSA. We try all three.
   435  func parsePrivateKey(der []byte) (crypto.PrivateKey, error) {
   436  	if key, err := x509.ParsePKCS1PrivateKey(der); err == nil {
   437  		return key, nil
   438  	}
   439  	if key, err := x509.ParsePKCS8PrivateKey(der); err == nil {
   440  		switch key := key.(type) {
   441  		// 添加SM2, ecdsa_ext
   442  		case *sm2.PrivateKey, *rsa.PrivateKey, *ecdsa.PrivateKey, *ecdsa_ext.PrivateKey, ed25519.PrivateKey:
   443  			return key, nil
   444  		default:
   445  			return nil, errors.New("gmtls: found unknown private key type in PKCS#8 wrapping")
   446  		}
   447  	}
   448  	if key, err := x509.ParseECPrivateKey(der); err == nil {
   449  		return key, nil
   450  	}
   451  
   452  	return nil, errors.New("gmtls: failed to parse private key")
   453  }
   454  
   455  // NewServerConfigByClientHello 根据客户端发出的ClientHello的协议与密码套件决定Server的证书链
   456  //
   457  //	当客户端支持tls1.3或gmssl,且客户端支持的密码套件包含 TLS_SM4_GCM_SM3 时,服务端证书采用gmSigCert。
   458  //	- gmSigCert 国密证书
   459  //	- genericCert 一般证书
   460  //
   461  //goland:noinspection GoUnusedExportedFunction
   462  func NewServerConfigByClientHello(gmSigCert, genericCert *Certificate) (*Config, error) {
   463  	// 根据ClientHelloInfo中支持的协议,返回服务端证书
   464  	fncGetSignCertKeypair := func(info *ClientHelloInfo) (*Certificate, error) {
   465  		gmFlag := false
   466  		// 检查客户端支持的协议中是否包含TLS1.3或GMSSL
   467  		for _, v := range info.SupportedVersions {
   468  			if v == VersionGMSSL || v == VersionTLS13 {
   469  				for _, curveID := range info.SupportedCurves {
   470  					if curveID == Curve256Sm2 {
   471  						gmFlag = true
   472  						break
   473  					}
   474  				}
   475  				if gmFlag {
   476  					break
   477  				}
   478  				// 检查客户端支持的密码套件是否包含 TLS_SM4_GCM_SM3
   479  				for _, c := range info.CipherSuites {
   480  					if c == TLS_SM4_GCM_SM3 {
   481  						gmFlag = true
   482  						break
   483  					}
   484  				}
   485  				break
   486  			}
   487  		}
   488  		if gmFlag {
   489  			return gmSigCert, nil
   490  		} else {
   491  			return genericCert, nil
   492  		}
   493  	}
   494  
   495  	return &Config{
   496  		Certificates:   nil,
   497  		GetCertificate: fncGetSignCertKeypair,
   498  	}, nil
   499  }
   500  
   501  //func NewServerConfigByClientHelloCurve(certMap map[string]*Certificate) (*Config, error) {
   502  //	// 根据ClientHelloInfo中支持的协议,返回服务端证书
   503  //	fncGetSignCertKeypair := func(info *ClientHelloInfo) (*Certificate, error) {
   504  //		//info.config.CurvePreferences
   505  //
   506  //		gmFlag := false
   507  //		// 检查客户端支持的协议中是否包含TLS1.3或GMSSL
   508  //		for _, v := range info.SupportedVersions {
   509  //			if v == VersionGMSSL || v == VersionTLS13 {
   510  //				// 检查客户端支持的密码套件是否包含 TLS_SM4_GCM_SM3
   511  //				for _, c := range info.CipherSuites {
   512  //					if c == TLS_SM4_GCM_SM3 {
   513  //						gmFlag = true
   514  //						break
   515  //					}
   516  //				}
   517  //				break
   518  //			}
   519  //		}
   520  //		if gmFlag {
   521  //			return gmSigCert, nil
   522  //		} else {
   523  //			return genericCert, nil
   524  //		}
   525  //	}
   526  //
   527  //	return &Config{
   528  //		Certificates:   nil,
   529  //		GetCertificate: fncGetSignCertKeypair,
   530  //	}, nil
   531  //}