gitee.com/ks-custle/core-gm@v0.0.0-20230922171213-b83bdd97b62c/gmtls/tls.go (about) 1 // Copyright (c) 2022 zhaochun 2 // core-gm is licensed under Mulan PSL v2. 3 // You can use this software according to the terms and conditions of the Mulan PSL v2. 4 // You may obtain a copy of Mulan PSL v2 at: 5 // http://license.coscl.org.cn/MulanPSL2 6 // THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. 7 // See the Mulan PSL v2 for more details. 8 9 /* 10 gmtls是基于`golang/go`的`tls`包实现的国密改造版本。 11 对应版权声明: thrid_licenses/github.com/golang/go/LICENSE 12 */ 13 14 // Package gmtls partially implements TLS 1.2, as specified in RFC 5246, 15 // and TLS 1.3, as specified in RFC 8446. 16 package gmtls 17 18 // BUG(agl): The crypto/tls package only implements some countermeasures 19 // against Lucky13 attacks on CBC-mode encryption, and only on SHA1 20 // variants. See http://www.isg.rhul.ac.uk/tls/TLStiming.pdf and 21 // https://www.imperialviolet.org/2013/02/04/luckythirteen.html. 22 23 import ( 24 "bytes" 25 "context" 26 "crypto" 27 "crypto/ecdsa" 28 "crypto/ed25519" 29 "crypto/rsa" 30 "encoding/pem" 31 "errors" 32 "fmt" 33 "gitee.com/ks-custle/core-gm/ecdsa_ext" 34 "net" 35 "os" 36 "strings" 37 38 "gitee.com/ks-custle/core-gm/sm2" 39 "gitee.com/ks-custle/core-gm/x509" 40 ) 41 42 // Server 生成tls通信Server 43 // Server returns a new TLS server side connection 44 // using conn as the underlying transport. 45 // The configuration config must be non-nil and must include 46 // at least one certificate or else set GetCertificate. 47 func Server(conn net.Conn, config *Config) *Conn { 48 c := &Conn{ 49 conn: conn, 50 config: config, 51 } 52 // 绑定握手函数 53 c.handshakeFn = c.serverHandshake 54 return c 55 } 56 57 // Client 生成tls通信Client 58 // Client returns a new TLS client side connection 59 // using conn as the underlying transport. 60 // The config cannot be nil: users must set either ServerName or 61 // InsecureSkipVerify in the config. 62 func Client(conn net.Conn, config *Config) *Conn { 63 c := &Conn{ 64 conn: conn, 65 config: config, 66 isClient: true, 67 } 68 // 绑定握手函数 69 c.handshakeFn = c.clientHandshake 70 return c 71 } 72 73 // A listener implements a network listener (net.Listener) for TLS connections. 74 type listener struct { 75 net.Listener 76 config *Config 77 } 78 79 // Accept waits for and returns the next incoming TLS connection. 80 // The returned connection is of type *Conn. 81 func (l *listener) Accept() (net.Conn, error) { 82 c, err := l.Listener.Accept() 83 if err != nil { 84 return nil, err 85 } 86 return Server(c, l.config), nil 87 } 88 89 // NewListener creates a Listener which accepts connections from an inner 90 // Listener and wraps each connection with Server. 91 // The configuration config must be non-nil and must include 92 // at least one certificate or else set GetCertificate. 93 func NewListener(inner net.Listener, config *Config) net.Listener { 94 l := new(listener) 95 l.Listener = inner 96 l.config = config 97 return l 98 } 99 100 // Listen creates a TLS listener accepting connections on the 101 // given network address using net.Listen. 102 // The configuration config must be non-nil and must include 103 // at least one certificate or else set GetCertificate. 104 func Listen(network, laddr string, config *Config) (net.Listener, error) { 105 if config == nil || len(config.Certificates) == 0 && 106 config.GetCertificate == nil && config.GetConfigForClient == nil { 107 return nil, errors.New("gmtls: neither Certificates, GetCertificate, nor GetConfigForClient set in Config") 108 } 109 l, err := net.Listen(network, laddr) 110 if err != nil { 111 return nil, err 112 } 113 return NewListener(l, config), nil 114 } 115 116 // type timeoutError struct{} 117 118 // func (timeoutError) Error() string { return "gmtls: DialWithDialer timed out" } 119 // func (timeoutError) Timeout() bool { return true } 120 // func (timeoutError) Temporary() bool { return true } 121 122 // DialWithDialer connects to the given network address using dialer.Dial and 123 // then initiates a TLS handshake, returning the resulting TLS connection. Any 124 // timeout or deadline given in the dialer apply to connection and TLS 125 // handshake as a whole. 126 // 127 // DialWithDialer interprets a nil configuration as equivalent to the zero 128 // configuration; see the documentation of Config for the defaults. 129 // 130 // DialWithDialer uses context.Background internally; to specify the context, 131 // use Dialer.DialContext with NetDialer set to the desired dialer. 132 func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*Conn, error) { 133 return dial(context.Background(), dialer, network, addr, config) 134 } 135 136 // 客户端拨号,发起tls通信请求 137 func dial(ctx context.Context, netDialer *net.Dialer, network, addr string, config *Config) (*Conn, error) { 138 if netDialer.Timeout != 0 { 139 var cancel context.CancelFunc 140 ctx, cancel = context.WithTimeout(ctx, netDialer.Timeout) 141 defer cancel() 142 } 143 144 if !netDialer.Deadline.IsZero() { 145 var cancel context.CancelFunc 146 ctx, cancel = context.WithDeadline(ctx, netDialer.Deadline) 147 defer cancel() 148 } 149 150 rawConn, err := netDialer.DialContext(ctx, network, addr) 151 if err != nil { 152 return nil, err 153 } 154 155 colonPos := strings.LastIndex(addr, ":") 156 if colonPos == -1 { 157 colonPos = len(addr) 158 } 159 hostname := addr[:colonPos] 160 161 if config == nil { 162 config = defaultConfig() 163 } 164 // If no ServerName is set, infer the ServerName 165 // from the hostname we're connecting to. 166 if config.ServerName == "" { 167 // Make a copy to avoid polluting argument or default. 168 c := config.Clone() 169 c.ServerName = hostname 170 config = c 171 } 172 173 conn := Client(rawConn, config) 174 // 客户端发起tls握手 175 if err := conn.HandshakeContext(ctx); err != nil { 176 _ = rawConn.Close() 177 return nil, err 178 } 179 return conn, nil 180 } 181 182 // Dial connects to the given network address using net.Dial 183 // and then initiates a TLS handshake, returning the resulting 184 // TLS connection. 185 // Dial interprets a nil configuration as equivalent to 186 // the zero configuration; see the documentation of Config 187 // for the defaults. 188 func Dial(network, addr string, config *Config) (*Conn, error) { 189 return DialWithDialer(new(net.Dialer), network, addr, config) 190 } 191 192 // Dialer dials TLS connections given a configuration and a Dialer for the 193 // underlying connection. 194 type Dialer struct { 195 // NetDialer is the optional dialer to use for the TLS connections' 196 // underlying TCP connections. 197 // A nil NetDialer is equivalent to the net.Dialer zero value. 198 NetDialer *net.Dialer 199 200 // Config is the TLS configuration to use for new connections. 201 // A nil configuration is equivalent to the zero 202 // configuration; see the documentation of Config for the 203 // defaults. 204 Config *Config 205 } 206 207 // Dial connects to the given network address and initiates a TLS 208 // handshake, returning the resulting TLS connection. 209 // 210 // The returned Conn, if any, will always be of type *Conn. 211 // 212 // Dial uses context.Background internally; to specify the context, 213 // use DialContext. 214 func (d *Dialer) Dial(network, addr string) (net.Conn, error) { 215 return d.DialContext(context.Background(), network, addr) 216 } 217 218 func (d *Dialer) netDialer() *net.Dialer { 219 if d.NetDialer != nil { 220 return d.NetDialer 221 } 222 return new(net.Dialer) 223 } 224 225 // DialContext connects to the given network address and initiates a TLS 226 // handshake, returning the resulting TLS connection. 227 // 228 // The provided Context must be non-nil. If the context expires before 229 // the connection is complete, an error is returned. Once successfully 230 // connected, any expiration of the context will not affect the 231 // connection. 232 // 233 // The returned Conn, if any, will always be of type *Conn. 234 func (d *Dialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { 235 c, err := dial(ctx, d.netDialer(), network, addr, d.Config) 236 if err != nil { 237 // Don't return c (a typed nil) in an interface. 238 return nil, err 239 } 240 return c, nil 241 } 242 243 // LoadX509KeyPair reads and parses a public/private key pair from a pair 244 // of files. The files must contain PEM encoded data. The certificate file 245 // may contain intermediate certificates following the leaf certificate to 246 // form a certificate chain. On successful return, Certificate.Leaf will 247 // be nil because the parsed form of the certificate is not retained. 248 func LoadX509KeyPair(certFile, keyFile string) (Certificate, error) { 249 certPEMBlock, err := os.ReadFile(certFile) 250 if err != nil { 251 return Certificate{}, err 252 } 253 keyPEMBlock, err := os.ReadFile(keyFile) 254 if err != nil { 255 return Certificate{}, err 256 } 257 return X509KeyPair(certPEMBlock, keyPEMBlock) 258 } 259 260 // X509KeyPair parses a public/private key pair from a pair of 261 // PEM encoded data. On successful return, Certificate.Leaf will be nil because 262 // the parsed form of the certificate is not retained. 263 func X509KeyPair(certPEMBlock, keyPEMBlock []byte) (Certificate, error) { 264 fail := func(err error) (Certificate, error) { return Certificate{}, err } 265 266 var cert Certificate 267 var skippedBlockTypes []string 268 for { 269 var certDERBlock *pem.Block 270 // 将证书PEM字节数组解码为DER字节数组 271 certDERBlock, certPEMBlock = pem.Decode(certPEMBlock) 272 if certDERBlock == nil { 273 break 274 } 275 if certDERBlock.Type == "CERTIFICATE" { 276 // 将证书DER字节数组加入证书链的证书列表 277 cert.Certificate = append(cert.Certificate, certDERBlock.Bytes) 278 } else { 279 skippedBlockTypes = append(skippedBlockTypes, certDERBlock.Type) 280 } 281 } 282 283 if len(cert.Certificate) == 0 { 284 if len(skippedBlockTypes) == 0 { 285 return fail(errors.New("gmtls: failed to find any PEM data in certificate input")) 286 } 287 if len(skippedBlockTypes) == 1 && strings.HasSuffix(skippedBlockTypes[0], "PRIVATE KEY") { 288 return fail(errors.New("gmtls: failed to find certificate PEM data in certificate input, but did find a private key; PEM inputs may have been switched")) 289 } 290 return fail(fmt.Errorf("gmtls: failed to find \"CERTIFICATE\" PEM block in certificate input after skipping PEM blocks of the following types: %v", skippedBlockTypes)) 291 } 292 293 skippedBlockTypes = skippedBlockTypes[:0] 294 var keyDERBlock *pem.Block 295 for { 296 keyDERBlock, keyPEMBlock = pem.Decode(keyPEMBlock) 297 if keyDERBlock == nil { 298 if len(skippedBlockTypes) == 0 { 299 return fail(errors.New("gmtls: failed to find any PEM data in key input")) 300 } 301 if len(skippedBlockTypes) == 1 && skippedBlockTypes[0] == "CERTIFICATE" { 302 return fail(errors.New("gmtls: found a certificate rather than a key in the PEM for the private key")) 303 } 304 return fail(fmt.Errorf("gmtls: failed to find PEM block with type ending in \"PRIVATE KEY\" in key input after skipping PEM blocks of the following types: %v", skippedBlockTypes)) 305 } 306 if keyDERBlock.Type == "PRIVATE KEY" || strings.HasSuffix(keyDERBlock.Type, " PRIVATE KEY") { 307 break 308 } 309 skippedBlockTypes = append(skippedBlockTypes, keyDERBlock.Type) 310 } 311 // 读取证书链中的首个证书(子证书),转为x509.Certificate 312 // We don't need to parse the public key for TLS, but we so do anyway 313 // to check that it looks sane and matches the private key. 314 x509Cert, err := x509.ParseCertificate(cert.Certificate[0]) 315 if err != nil { 316 return fail(err) 317 } 318 319 var signatures []SignatureScheme 320 fmt.Println("x509Cert.SignatureAlgorithm: %s", x509Cert.SignatureAlgorithm.String()) 321 switch x509Cert.SignatureAlgorithm { 322 case x509.SM2WithSM3: 323 signatures = append(signatures, SM2WITHSM3) 324 case x509.ECDSAWithSHA256: 325 signatures = append(signatures, ECDSAWithP256AndSHA256) 326 case x509.ECDSAWithSHA384: 327 signatures = append(signatures, ECDSAWithP384AndSHA384) 328 case x509.ECDSAWithSHA512: 329 signatures = append(signatures, ECDSAWithP521AndSHA512) 330 case x509.ECDSAEXTWithSHA256: 331 signatures = append(signatures, ECDSAEXTWithP256AndSHA256) 332 case x509.ECDSAEXTWithSHA384: 333 signatures = append(signatures, ECDSAEXTWithP384AndSHA384) 334 case x509.ECDSAEXTWithSHA512: 335 signatures = append(signatures, ECDSAEXTWithP521AndSHA512) 336 } 337 if len(signatures) > 0 { 338 cert.SupportedSignatureAlgorithms = signatures 339 } 340 fmt.Println("cert.SupportedSignatureAlgorithms: %s", cert.SupportedSignatureAlgorithms) 341 342 // 将key的DER字节数组转为私钥 343 cert.PrivateKey, err = parsePrivateKey(keyDERBlock.Bytes) 344 if err != nil { 345 return fail(err) 346 } 347 // ECDSA_EXT私钥特殊处理 348 if keyDERBlock.Type == "ECDSA_EXT PRIVATE KEY" { 349 if privKey, ok := cert.PrivateKey.(*ecdsa.PrivateKey); ok { 350 cert.PrivateKey = &ecdsa_ext.PrivateKey{ 351 PrivateKey: *privKey, 352 } 353 fmt.Println("读取到ECDSA_EXT PRIVATE KEY,并转为ecdsa_ext.PrivateKey") 354 hasEcdsaExt := false 355 for _, algorithm := range cert.SupportedSignatureAlgorithms { 356 if algorithm == ECDSAEXTWithP256AndSHA256 || 357 algorithm == ECDSAEXTWithP384AndSHA384 || 358 algorithm == ECDSAEXTWithP521AndSHA512 { 359 hasEcdsaExt = true 360 break 361 } 362 } 363 if !hasEcdsaExt { 364 // 临时对应,解决SupportedSignatureAlgorithms在ecdsa_ext时可能不正确的问题 365 cert.SupportedSignatureAlgorithms = []SignatureScheme{ECDSAEXTWithP256AndSHA256} 366 fmt.Println("临时修改cert.SupportedSignatureAlgorithms为: %s", cert.SupportedSignatureAlgorithms) 367 } 368 } else if _, ok := cert.PrivateKey.(*ecdsa_ext.PrivateKey); ok { 369 // ok 370 } else { 371 return fail(errors.New("pem文件类型为`ECDSA_EXT PRIVATE KEY`, 但证书中的私钥类型不是*ecdsa.PrivateKey")) 372 } 373 } 374 // 检查私钥与证书中的公钥是否匹配 375 switch pub := x509Cert.PublicKey.(type) { 376 // 补充SM2分支 377 case *sm2.PublicKey: 378 priv, ok := cert.PrivateKey.(*sm2.PrivateKey) 379 if !ok { 380 return fail(errors.New("gmtls: private key type does not match public key type")) 381 } 382 if pub.X.Cmp(priv.X) != 0 || pub.Y.Cmp(priv.Y) != 0 { 383 return fail(errors.New("gmtls: private key does not match public key")) 384 } 385 case *rsa.PublicKey: 386 priv, ok := cert.PrivateKey.(*rsa.PrivateKey) 387 if !ok { 388 return fail(errors.New("gmtls: private key type does not match public key type")) 389 } 390 if pub.N.Cmp(priv.N) != 0 { 391 return fail(errors.New("gmtls: private key does not match public key")) 392 } 393 case *ecdsa.PublicKey: 394 priv, ok := cert.PrivateKey.(*ecdsa.PrivateKey) 395 if !ok { 396 privExt, okExt := cert.PrivateKey.(*ecdsa_ext.PrivateKey) 397 if !okExt { 398 return fail(errors.New("gmtls: private key type does not match public key type")) 399 } 400 if pub.X.Cmp(privExt.X) != 0 || pub.Y.Cmp(privExt.Y) != 0 { 401 return fail(errors.New("gmtls: private key does not match public key")) 402 } 403 } else { 404 if pub.X.Cmp(priv.X) != 0 || pub.Y.Cmp(priv.Y) != 0 { 405 return fail(errors.New("gmtls: private key does not match public key")) 406 } 407 } 408 case *ecdsa_ext.PublicKey: 409 priv, ok := cert.PrivateKey.(*ecdsa_ext.PrivateKey) 410 if !ok { 411 return fail(errors.New("gmtls: private key type does not match public key type")) 412 } 413 if pub.X.Cmp(priv.X) != 0 || pub.Y.Cmp(priv.Y) != 0 { 414 return fail(errors.New("gmtls: private key does not match public key")) 415 } 416 case ed25519.PublicKey: 417 priv, ok := cert.PrivateKey.(ed25519.PrivateKey) 418 if !ok { 419 return fail(errors.New("gmtls: private key type does not match public key type")) 420 } 421 if !bytes.Equal(priv.Public().(ed25519.PublicKey), pub) { 422 return fail(errors.New("gmtls: private key does not match public key")) 423 } 424 default: 425 return fail(errors.New("gmtls: unknown public key algorithm")) 426 } 427 428 return cert, nil 429 } 430 431 // 将DER字节数组转为对应的私钥 432 // Attempt to parse the given private key DER block. OpenSSL 0.9.8 generates 433 // PKCS #1 private keys by default, while OpenSSL 1.0.0 generates PKCS #8 keys. 434 // OpenSSL ecparam generates SEC1 EC private keys for ECDSA. We try all three. 435 func parsePrivateKey(der []byte) (crypto.PrivateKey, error) { 436 if key, err := x509.ParsePKCS1PrivateKey(der); err == nil { 437 return key, nil 438 } 439 if key, err := x509.ParsePKCS8PrivateKey(der); err == nil { 440 switch key := key.(type) { 441 // 添加SM2, ecdsa_ext 442 case *sm2.PrivateKey, *rsa.PrivateKey, *ecdsa.PrivateKey, *ecdsa_ext.PrivateKey, ed25519.PrivateKey: 443 return key, nil 444 default: 445 return nil, errors.New("gmtls: found unknown private key type in PKCS#8 wrapping") 446 } 447 } 448 if key, err := x509.ParseECPrivateKey(der); err == nil { 449 return key, nil 450 } 451 452 return nil, errors.New("gmtls: failed to parse private key") 453 } 454 455 // NewServerConfigByClientHello 根据客户端发出的ClientHello的协议与密码套件决定Server的证书链 456 // 457 // 当客户端支持tls1.3或gmssl,且客户端支持的密码套件包含 TLS_SM4_GCM_SM3 时,服务端证书采用gmSigCert。 458 // - gmSigCert 国密证书 459 // - genericCert 一般证书 460 // 461 //goland:noinspection GoUnusedExportedFunction 462 func NewServerConfigByClientHello(gmSigCert, genericCert *Certificate) (*Config, error) { 463 // 根据ClientHelloInfo中支持的协议,返回服务端证书 464 fncGetSignCertKeypair := func(info *ClientHelloInfo) (*Certificate, error) { 465 gmFlag := false 466 // 检查客户端支持的协议中是否包含TLS1.3或GMSSL 467 for _, v := range info.SupportedVersions { 468 if v == VersionGMSSL || v == VersionTLS13 { 469 for _, curveID := range info.SupportedCurves { 470 if curveID == Curve256Sm2 { 471 gmFlag = true 472 break 473 } 474 } 475 if gmFlag { 476 break 477 } 478 // 检查客户端支持的密码套件是否包含 TLS_SM4_GCM_SM3 479 for _, c := range info.CipherSuites { 480 if c == TLS_SM4_GCM_SM3 { 481 gmFlag = true 482 break 483 } 484 } 485 break 486 } 487 } 488 if gmFlag { 489 return gmSigCert, nil 490 } else { 491 return genericCert, nil 492 } 493 } 494 495 return &Config{ 496 Certificates: nil, 497 GetCertificate: fncGetSignCertKeypair, 498 }, nil 499 } 500 501 //func NewServerConfigByClientHelloCurve(certMap map[string]*Certificate) (*Config, error) { 502 // // 根据ClientHelloInfo中支持的协议,返回服务端证书 503 // fncGetSignCertKeypair := func(info *ClientHelloInfo) (*Certificate, error) { 504 // //info.config.CurvePreferences 505 // 506 // gmFlag := false 507 // // 检查客户端支持的协议中是否包含TLS1.3或GMSSL 508 // for _, v := range info.SupportedVersions { 509 // if v == VersionGMSSL || v == VersionTLS13 { 510 // // 检查客户端支持的密码套件是否包含 TLS_SM4_GCM_SM3 511 // for _, c := range info.CipherSuites { 512 // if c == TLS_SM4_GCM_SM3 { 513 // gmFlag = true 514 // break 515 // } 516 // } 517 // break 518 // } 519 // } 520 // if gmFlag { 521 // return gmSigCert, nil 522 // } else { 523 // return genericCert, nil 524 // } 525 // } 526 // 527 // return &Config{ 528 // Certificates: nil, 529 // GetCertificate: fncGetSignCertKeypair, 530 // }, nil 531 //}