gitee.com/zhaochuninhefei/gmgo@v0.0.31-0.20240209061119-069254a02979/gmtls/tls.go (about) 1 // Copyright (c) 2022 zhaochun 2 // gmgo is licensed under Mulan PSL v2. 3 // You can use this software according to the terms and conditions of the Mulan PSL v2. 4 // You may obtain a copy of Mulan PSL v2 at: 5 // http://license.coscl.org.cn/MulanPSL2 6 // THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. 7 // See the Mulan PSL v2 for more details. 8 9 /* 10 gmtls是基于`golang/go`的`tls`包实现的国密改造版本。 11 对应版权声明: thrid_licenses/github.com/golang/go/LICENSE 12 */ 13 14 // Package gmtls partially implements TLS 1.2, as specified in RFC 5246, 15 // and TLS 1.3, as specified in RFC 8446. 16 package gmtls 17 18 // BUG(agl): The crypto/tls package only implements some countermeasures 19 // against Lucky13 attacks on CBC-mode encryption, and only on SHA1 20 // variants. See http://www.isg.rhul.ac.uk/tls/TLStiming.pdf and 21 // https://www.imperialviolet.org/2013/02/04/luckythirteen.html. 22 23 import ( 24 "bytes" 25 "context" 26 "crypto" 27 "crypto/ecdsa" 28 "crypto/ed25519" 29 "crypto/rsa" 30 "encoding/pem" 31 "errors" 32 "fmt" 33 "gitee.com/zhaochuninhefei/gmgo/ecdsa_ext" 34 "gitee.com/zhaochuninhefei/zcgolog/zclog" 35 "net" 36 "os" 37 "strings" 38 39 "gitee.com/zhaochuninhefei/gmgo/sm2" 40 "gitee.com/zhaochuninhefei/gmgo/x509" 41 ) 42 43 // Server 生成tls通信Server 44 // Server returns a new TLS server side connection 45 // using conn as the underlying transport. 46 // The configuration config must be non-nil and must include 47 // at least one certificate or else set GetCertificate. 48 func Server(conn net.Conn, config *Config) *Conn { 49 c := &Conn{ 50 conn: conn, 51 config: config, 52 } 53 // 绑定握手函数 54 c.handshakeFn = c.serverHandshake 55 return c 56 } 57 58 // Client 生成tls通信Client 59 // Client returns a new TLS client side connection 60 // using conn as the underlying transport. 61 // The config cannot be nil: users must set either ServerName or 62 // InsecureSkipVerify in the config. 63 func Client(conn net.Conn, config *Config) *Conn { 64 c := &Conn{ 65 conn: conn, 66 config: config, 67 isClient: true, 68 } 69 // 绑定握手函数 70 c.handshakeFn = c.clientHandshake 71 return c 72 } 73 74 // A listener implements a network listener (net.Listener) for TLS connections. 75 type listener struct { 76 net.Listener 77 config *Config 78 } 79 80 // Accept waits for and returns the next incoming TLS connection. 81 // The returned connection is of type *Conn. 82 func (l *listener) Accept() (net.Conn, error) { 83 c, err := l.Listener.Accept() 84 if err != nil { 85 return nil, err 86 } 87 return Server(c, l.config), nil 88 } 89 90 // NewListener creates a Listener which accepts connections from an inner 91 // Listener and wraps each connection with Server. 92 // The configuration config must be non-nil and must include 93 // at least one certificate or else set GetCertificate. 94 func NewListener(inner net.Listener, config *Config) net.Listener { 95 l := new(listener) 96 l.Listener = inner 97 l.config = config 98 return l 99 } 100 101 // Listen creates a TLS listener accepting connections on the 102 // given network address using net.Listen. 103 // The configuration config must be non-nil and must include 104 // at least one certificate or else set GetCertificate. 105 func Listen(network, laddr string, config *Config) (net.Listener, error) { 106 if config == nil || len(config.Certificates) == 0 && 107 config.GetCertificate == nil && config.GetConfigForClient == nil { 108 return nil, errors.New("gmtls: neither Certificates, GetCertificate, nor GetConfigForClient set in Config") 109 } 110 l, err := net.Listen(network, laddr) 111 if err != nil { 112 return nil, err 113 } 114 return NewListener(l, config), nil 115 } 116 117 // type timeoutError struct{} 118 119 // func (timeoutError) Error() string { return "gmtls: DialWithDialer timed out" } 120 // func (timeoutError) Timeout() bool { return true } 121 // func (timeoutError) Temporary() bool { return true } 122 123 // DialWithDialer connects to the given network address using dialer.Dial and 124 // then initiates a TLS handshake, returning the resulting TLS connection. Any 125 // timeout or deadline given in the dialer apply to connection and TLS 126 // handshake as a whole. 127 // 128 // DialWithDialer interprets a nil configuration as equivalent to the zero 129 // configuration; see the documentation of Config for the defaults. 130 // 131 // DialWithDialer uses context.Background internally; to specify the context, 132 // use Dialer.DialContext with NetDialer set to the desired dialer. 133 func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*Conn, error) { 134 return dial(context.Background(), dialer, network, addr, config) 135 } 136 137 // 客户端拨号,发起tls通信请求 138 func dial(ctx context.Context, netDialer *net.Dialer, network, addr string, config *Config) (*Conn, error) { 139 if netDialer.Timeout != 0 { 140 var cancel context.CancelFunc 141 ctx, cancel = context.WithTimeout(ctx, netDialer.Timeout) 142 defer cancel() 143 } 144 145 if !netDialer.Deadline.IsZero() { 146 var cancel context.CancelFunc 147 ctx, cancel = context.WithDeadline(ctx, netDialer.Deadline) 148 defer cancel() 149 } 150 151 rawConn, err := netDialer.DialContext(ctx, network, addr) 152 if err != nil { 153 return nil, err 154 } 155 156 colonPos := strings.LastIndex(addr, ":") 157 if colonPos == -1 { 158 colonPos = len(addr) 159 } 160 hostname := addr[:colonPos] 161 162 if config == nil { 163 config = defaultConfig() 164 } 165 // If no ServerName is set, infer the ServerName 166 // from the hostname we're connecting to. 167 if config.ServerName == "" { 168 // Make a copy to avoid polluting argument or default. 169 c := config.Clone() 170 c.ServerName = hostname 171 config = c 172 } 173 174 conn := Client(rawConn, config) 175 // 客户端发起tls握手 176 if err := conn.HandshakeContext(ctx); err != nil { 177 _ = rawConn.Close() 178 return nil, err 179 } 180 return conn, nil 181 } 182 183 // Dial connects to the given network address using net.Dial 184 // and then initiates a TLS handshake, returning the resulting 185 // TLS connection. 186 // Dial interprets a nil configuration as equivalent to 187 // the zero configuration; see the documentation of Config 188 // for the defaults. 189 func Dial(network, addr string, config *Config) (*Conn, error) { 190 return DialWithDialer(new(net.Dialer), network, addr, config) 191 } 192 193 // Dialer dials TLS connections given a configuration and a Dialer for the 194 // underlying connection. 195 type Dialer struct { 196 // NetDialer is the optional dialer to use for the TLS connections' 197 // underlying TCP connections. 198 // A nil NetDialer is equivalent to the net.Dialer zero value. 199 NetDialer *net.Dialer 200 201 // Config is the TLS configuration to use for new connections. 202 // A nil configuration is equivalent to the zero 203 // configuration; see the documentation of Config for the 204 // defaults. 205 Config *Config 206 } 207 208 // Dial connects to the given network address and initiates a TLS 209 // handshake, returning the resulting TLS connection. 210 // 211 // The returned Conn, if any, will always be of type *Conn. 212 // 213 // Dial uses context.Background internally; to specify the context, 214 // use DialContext. 215 func (d *Dialer) Dial(network, addr string) (net.Conn, error) { 216 return d.DialContext(context.Background(), network, addr) 217 } 218 219 func (d *Dialer) netDialer() *net.Dialer { 220 if d.NetDialer != nil { 221 return d.NetDialer 222 } 223 return new(net.Dialer) 224 } 225 226 // DialContext connects to the given network address and initiates a TLS 227 // handshake, returning the resulting TLS connection. 228 // 229 // The provided Context must be non-nil. If the context expires before 230 // the connection is complete, an error is returned. Once successfully 231 // connected, any expiration of the context will not affect the 232 // connection. 233 // 234 // The returned Conn, if any, will always be of type *Conn. 235 func (d *Dialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { 236 c, err := dial(ctx, d.netDialer(), network, addr, d.Config) 237 if err != nil { 238 // Don't return c (a typed nil) in an interface. 239 return nil, err 240 } 241 return c, nil 242 } 243 244 // LoadX509KeyPair reads and parses a public/private key pair from a pair 245 // of files. The files must contain PEM encoded data. The certificate file 246 // may contain intermediate certificates following the leaf certificate to 247 // form a certificate chain. On successful return, Certificate.Leaf will 248 // be nil because the parsed form of the certificate is not retained. 249 func LoadX509KeyPair(certFile, keyFile string) (Certificate, error) { 250 certPEMBlock, err := os.ReadFile(certFile) 251 if err != nil { 252 return Certificate{}, err 253 } 254 keyPEMBlock, err := os.ReadFile(keyFile) 255 if err != nil { 256 return Certificate{}, err 257 } 258 return X509KeyPair(certPEMBlock, keyPEMBlock) 259 } 260 261 // X509KeyPair parses a public/private key pair from a pair of 262 // PEM encoded data. On successful return, Certificate.Leaf will be nil because 263 // the parsed form of the certificate is not retained. 264 func X509KeyPair(certPEMBlock, keyPEMBlock []byte) (Certificate, error) { 265 fail := func(err error) (Certificate, error) { return Certificate{}, err } 266 267 var cert Certificate 268 var skippedBlockTypes []string 269 for { 270 var certDERBlock *pem.Block 271 // 将证书PEM字节数组解码为DER字节数组 272 certDERBlock, certPEMBlock = pem.Decode(certPEMBlock) 273 if certDERBlock == nil { 274 break 275 } 276 if certDERBlock.Type == "CERTIFICATE" { 277 // 将证书DER字节数组加入证书链的证书列表 278 cert.Certificate = append(cert.Certificate, certDERBlock.Bytes) 279 } else { 280 skippedBlockTypes = append(skippedBlockTypes, certDERBlock.Type) 281 } 282 } 283 284 if len(cert.Certificate) == 0 { 285 if len(skippedBlockTypes) == 0 { 286 return fail(errors.New("gmtls: failed to find any PEM data in certificate input")) 287 } 288 if len(skippedBlockTypes) == 1 && strings.HasSuffix(skippedBlockTypes[0], "PRIVATE KEY") { 289 return fail(errors.New("gmtls: failed to find certificate PEM data in certificate input, but did find a private key; PEM inputs may have been switched")) 290 } 291 return fail(fmt.Errorf("gmtls: failed to find \"CERTIFICATE\" PEM block in certificate input after skipping PEM blocks of the following types: %v", skippedBlockTypes)) 292 } 293 294 skippedBlockTypes = skippedBlockTypes[:0] 295 var keyDERBlock *pem.Block 296 for { 297 keyDERBlock, keyPEMBlock = pem.Decode(keyPEMBlock) 298 if keyDERBlock == nil { 299 if len(skippedBlockTypes) == 0 { 300 return fail(errors.New("gmtls: failed to find any PEM data in key input")) 301 } 302 if len(skippedBlockTypes) == 1 && skippedBlockTypes[0] == "CERTIFICATE" { 303 return fail(errors.New("gmtls: found a certificate rather than a key in the PEM for the private key")) 304 } 305 return fail(fmt.Errorf("gmtls: failed to find PEM block with type ending in \"PRIVATE KEY\" in key input after skipping PEM blocks of the following types: %v", skippedBlockTypes)) 306 } 307 if keyDERBlock.Type == "PRIVATE KEY" || strings.HasSuffix(keyDERBlock.Type, " PRIVATE KEY") { 308 break 309 } 310 skippedBlockTypes = append(skippedBlockTypes, keyDERBlock.Type) 311 } 312 // 读取证书链中的首个证书(子证书),转为x509.Certificate 313 // We don't need to parse the public key for TLS, but we so do anyway 314 // to check that it looks sane and matches the private key. 315 x509Cert, err := x509.ParseCertificate(cert.Certificate[0]) 316 if err != nil { 317 return fail(err) 318 } 319 320 var signatures []SignatureScheme 321 zclog.Debugf("x509Cert.SignatureAlgorithm: %s", x509Cert.SignatureAlgorithm.String()) 322 switch x509Cert.SignatureAlgorithm { 323 case x509.SM2WithSM3: 324 signatures = append(signatures, SM2WITHSM3) 325 case x509.ECDSAWithSHA256: 326 signatures = append(signatures, ECDSAWithP256AndSHA256) 327 case x509.ECDSAWithSHA384: 328 signatures = append(signatures, ECDSAWithP384AndSHA384) 329 case x509.ECDSAWithSHA512: 330 signatures = append(signatures, ECDSAWithP521AndSHA512) 331 case x509.ECDSAEXTWithSHA256: 332 signatures = append(signatures, ECDSAEXTWithP256AndSHA256) 333 case x509.ECDSAEXTWithSHA384: 334 signatures = append(signatures, ECDSAEXTWithP384AndSHA384) 335 case x509.ECDSAEXTWithSHA512: 336 signatures = append(signatures, ECDSAEXTWithP521AndSHA512) 337 } 338 if len(signatures) > 0 { 339 cert.SupportedSignatureAlgorithms = signatures 340 } 341 zclog.Debugf("cert.SupportedSignatureAlgorithms: %s", cert.SupportedSignatureAlgorithms) 342 343 // 将key的DER字节数组转为私钥 344 cert.PrivateKey, err = parsePrivateKey(keyDERBlock.Bytes) 345 if err != nil { 346 return fail(err) 347 } 348 // ECDSA_EXT私钥特殊处理 349 if keyDERBlock.Type == "ECDSA_EXT PRIVATE KEY" { 350 if privKey, ok := cert.PrivateKey.(*ecdsa.PrivateKey); ok { 351 cert.PrivateKey = &ecdsa_ext.PrivateKey{ 352 PrivateKey: *privKey, 353 } 354 zclog.Debugln("读取到ECDSA_EXT PRIVATE KEY,并转为ecdsa_ext.PrivateKey") 355 hasEcdsaExt := false 356 for _, algorithm := range cert.SupportedSignatureAlgorithms { 357 if algorithm == ECDSAEXTWithP256AndSHA256 || 358 algorithm == ECDSAEXTWithP384AndSHA384 || 359 algorithm == ECDSAEXTWithP521AndSHA512 { 360 hasEcdsaExt = true 361 break 362 } 363 } 364 if !hasEcdsaExt { 365 // 临时对应,解决SupportedSignatureAlgorithms在ecdsa_ext时可能不正确的问题 366 cert.SupportedSignatureAlgorithms = []SignatureScheme{ECDSAEXTWithP256AndSHA256} 367 zclog.Debugf("临时修改cert.SupportedSignatureAlgorithms为: %s", cert.SupportedSignatureAlgorithms) 368 } 369 } else if _, ok := cert.PrivateKey.(*ecdsa_ext.PrivateKey); ok { 370 // ok 371 } else { 372 return fail(errors.New("pem文件类型为`ECDSA_EXT PRIVATE KEY`, 但证书中的私钥类型不是*ecdsa.PrivateKey")) 373 } 374 } 375 // 检查私钥与证书中的公钥是否匹配 376 switch pub := x509Cert.PublicKey.(type) { 377 // 补充SM2分支 378 case *sm2.PublicKey: 379 priv, ok := cert.PrivateKey.(*sm2.PrivateKey) 380 if !ok { 381 return fail(errors.New("gmtls: private key type does not match public key type")) 382 } 383 if pub.X.Cmp(priv.X) != 0 || pub.Y.Cmp(priv.Y) != 0 { 384 return fail(errors.New("gmtls: private key does not match public key")) 385 } 386 case *rsa.PublicKey: 387 priv, ok := cert.PrivateKey.(*rsa.PrivateKey) 388 if !ok { 389 return fail(errors.New("gmtls: private key type does not match public key type")) 390 } 391 if pub.N.Cmp(priv.N) != 0 { 392 return fail(errors.New("gmtls: private key does not match public key")) 393 } 394 case *ecdsa.PublicKey: 395 priv, ok := cert.PrivateKey.(*ecdsa.PrivateKey) 396 if !ok { 397 privExt, okExt := cert.PrivateKey.(*ecdsa_ext.PrivateKey) 398 if !okExt { 399 return fail(errors.New("gmtls: private key type does not match public key type")) 400 } 401 if pub.X.Cmp(privExt.X) != 0 || pub.Y.Cmp(privExt.Y) != 0 { 402 return fail(errors.New("gmtls: private key does not match public key")) 403 } 404 } else { 405 if pub.X.Cmp(priv.X) != 0 || pub.Y.Cmp(priv.Y) != 0 { 406 return fail(errors.New("gmtls: private key does not match public key")) 407 } 408 } 409 case *ecdsa_ext.PublicKey: 410 priv, ok := cert.PrivateKey.(*ecdsa_ext.PrivateKey) 411 if !ok { 412 return fail(errors.New("gmtls: private key type does not match public key type")) 413 } 414 if pub.X.Cmp(priv.X) != 0 || pub.Y.Cmp(priv.Y) != 0 { 415 return fail(errors.New("gmtls: private key does not match public key")) 416 } 417 case ed25519.PublicKey: 418 priv, ok := cert.PrivateKey.(ed25519.PrivateKey) 419 if !ok { 420 return fail(errors.New("gmtls: private key type does not match public key type")) 421 } 422 if !bytes.Equal(priv.Public().(ed25519.PublicKey), pub) { 423 return fail(errors.New("gmtls: private key does not match public key")) 424 } 425 default: 426 return fail(errors.New("gmtls: unknown public key algorithm")) 427 } 428 429 return cert, nil 430 } 431 432 // 将DER字节数组转为对应的私钥 433 // Attempt to parse the given private key DER block. OpenSSL 0.9.8 generates 434 // PKCS #1 private keys by default, while OpenSSL 1.0.0 generates PKCS #8 keys. 435 // OpenSSL ecparam generates SEC1 EC private keys for ECDSA. We try all three. 436 func parsePrivateKey(der []byte) (crypto.PrivateKey, error) { 437 if key, err := x509.ParsePKCS1PrivateKey(der); err == nil { 438 return key, nil 439 } 440 if key, err := x509.ParsePKCS8PrivateKey(der); err == nil { 441 switch key := key.(type) { 442 // 添加SM2, ecdsa_ext 443 case *sm2.PrivateKey, *rsa.PrivateKey, *ecdsa.PrivateKey, *ecdsa_ext.PrivateKey, ed25519.PrivateKey: 444 return key, nil 445 default: 446 return nil, errors.New("gmtls: found unknown private key type in PKCS#8 wrapping") 447 } 448 } 449 if key, err := x509.ParseECPrivateKey(der); err == nil { 450 return key, nil 451 } 452 453 return nil, errors.New("gmtls: failed to parse private key") 454 } 455 456 // NewServerConfigByClientHello 根据客户端发出的ClientHello的协议与密码套件决定Server的证书链 457 // 当客户端支持tls1.3或gmssl,且客户端支持的密码套件包含 TLS_SM4_GCM_SM3 时,服务端证书采用gmSigCert。 458 // - gmSigCert 国密证书 459 // - genericCert 一般证书 460 //goland:noinspection GoUnusedExportedFunction 461 func NewServerConfigByClientHello(gmSigCert, genericCert *Certificate) (*Config, error) { 462 // 根据ClientHelloInfo中支持的协议,返回服务端证书 463 fncGetSignCertKeypair := func(info *ClientHelloInfo) (*Certificate, error) { 464 gmFlag := false 465 // 检查客户端支持的协议中是否包含TLS1.3或GMSSL 466 for _, v := range info.SupportedVersions { 467 if v == VersionGMSSL || v == VersionTLS13 { 468 for _, curveID := range info.SupportedCurves { 469 if curveID == Curve256Sm2 { 470 gmFlag = true 471 break 472 } 473 } 474 if gmFlag { 475 break 476 } 477 // 检查客户端支持的密码套件是否包含 TLS_SM4_GCM_SM3 478 for _, c := range info.CipherSuites { 479 if c == TLS_SM4_GCM_SM3 { 480 gmFlag = true 481 break 482 } 483 } 484 break 485 } 486 } 487 if gmFlag { 488 return gmSigCert, nil 489 } else { 490 return genericCert, nil 491 } 492 } 493 494 return &Config{ 495 Certificates: nil, 496 GetCertificate: fncGetSignCertKeypair, 497 }, nil 498 } 499 500 //func NewServerConfigByClientHelloCurve(certMap map[string]*Certificate) (*Config, error) { 501 // // 根据ClientHelloInfo中支持的协议,返回服务端证书 502 // fncGetSignCertKeypair := func(info *ClientHelloInfo) (*Certificate, error) { 503 // //info.config.CurvePreferences 504 // 505 // gmFlag := false 506 // // 检查客户端支持的协议中是否包含TLS1.3或GMSSL 507 // for _, v := range info.SupportedVersions { 508 // if v == VersionGMSSL || v == VersionTLS13 { 509 // // 检查客户端支持的密码套件是否包含 TLS_SM4_GCM_SM3 510 // for _, c := range info.CipherSuites { 511 // if c == TLS_SM4_GCM_SM3 { 512 // gmFlag = true 513 // break 514 // } 515 // } 516 // break 517 // } 518 // } 519 // if gmFlag { 520 // return gmSigCert, nil 521 // } else { 522 // return genericCert, nil 523 // } 524 // } 525 // 526 // return &Config{ 527 // Certificates: nil, 528 // GetCertificate: fncGetSignCertKeypair, 529 // }, nil 530 //}