github.com/hxx258456/ccgo@v0.0.5-0.20230213014102-48b35f46f66f/gmtls/tls.go (about) 1 // Copyright 2022 s1ren@github.com/hxx258456. 2 3 /* 4 gmtls是基于`golang/go`的`tls`包实现的国密改造版本。 5 对应版权声明: thrid_licenses/github.com/golang/go/LICENSE 6 */ 7 8 // Package gmtls partially implements TLS 1.2, as specified in RFC 5246, 9 // and TLS 1.3, as specified in RFC 8446. 10 package gmtls 11 12 // BUG(agl): The crypto/tls package only implements some countermeasures 13 // against Lucky13 attacks on CBC-mode encryption, and only on SHA1 14 // variants. See http://www.isg.rhul.ac.uk/tls/TLStiming.pdf and 15 // https://www.imperialviolet.org/2013/02/04/luckythirteen.html. 16 17 import ( 18 "bytes" 19 "context" 20 "crypto" 21 "crypto/ecdsa" 22 "crypto/ed25519" 23 "crypto/rsa" 24 "encoding/pem" 25 "errors" 26 "fmt" 27 "net" 28 "os" 29 "strings" 30 31 "github.com/hxx258456/ccgo/sm2" 32 "github.com/hxx258456/ccgo/x509" 33 ) 34 35 // Server 生成tls通信Server 36 // Server returns a new TLS server side connection 37 // using conn as the underlying transport. 38 // The configuration config must be non-nil and must include 39 // at least one certificate or else set GetCertificate. 40 func Server(conn net.Conn, config *Config) *Conn { 41 c := &Conn{ 42 conn: conn, 43 config: config, 44 } 45 // 绑定握手函数 46 c.handshakeFn = c.serverHandshake 47 return c 48 } 49 50 // Client 生成tls通信Client 51 // Client returns a new TLS client side connection 52 // using conn as the underlying transport. 53 // The config cannot be nil: users must set either ServerName or 54 // InsecureSkipVerify in the config. 55 func Client(conn net.Conn, config *Config) *Conn { 56 c := &Conn{ 57 conn: conn, 58 config: config, 59 isClient: true, 60 } 61 // 绑定握手函数 62 c.handshakeFn = c.clientHandshake 63 return c 64 } 65 66 // A listener implements a network listener (net.Listener) for TLS connections. 67 type listener struct { 68 net.Listener 69 config *Config 70 } 71 72 // Accept waits for and returns the next incoming TLS connection. 73 // The returned connection is of type *Conn. 74 func (l *listener) Accept() (net.Conn, error) { 75 c, err := l.Listener.Accept() 76 if err != nil { 77 return nil, err 78 } 79 return Server(c, l.config), nil 80 } 81 82 // NewListener creates a Listener which accepts connections from an inner 83 // Listener and wraps each connection with Server. 84 // The configuration config must be non-nil and must include 85 // at least one certificate or else set GetCertificate. 86 func NewListener(inner net.Listener, config *Config) net.Listener { 87 l := new(listener) 88 l.Listener = inner 89 l.config = config 90 return l 91 } 92 93 // Listen creates a TLS listener accepting connections on the 94 // given network address using net.Listen. 95 // The configuration config must be non-nil and must include 96 // at least one certificate or else set GetCertificate. 97 func Listen(network, laddr string, config *Config) (net.Listener, error) { 98 if config == nil || len(config.Certificates) == 0 && 99 config.GetCertificate == nil && config.GetConfigForClient == nil { 100 return nil, errors.New("gmtls: neither Certificates, GetCertificate, nor GetConfigForClient set in Config") 101 } 102 l, err := net.Listen(network, laddr) 103 if err != nil { 104 return nil, err 105 } 106 return NewListener(l, config), nil 107 } 108 109 // type timeoutError struct{} 110 111 // func (timeoutError) Error() string { return "gmtls: DialWithDialer timed out" } 112 // func (timeoutError) Timeout() bool { return true } 113 // func (timeoutError) Temporary() bool { return true } 114 115 // DialWithDialer connects to the given network address using dialer.Dial and 116 // then initiates a TLS handshake, returning the resulting TLS connection. Any 117 // timeout or deadline given in the dialer apply to connection and TLS 118 // handshake as a whole. 119 // 120 // DialWithDialer interprets a nil configuration as equivalent to the zero 121 // configuration; see the documentation of Config for the defaults. 122 // 123 // DialWithDialer uses context.Background internally; to specify the context, 124 // use Dialer.DialContext with NetDialer set to the desired dialer. 125 func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*Conn, error) { 126 return dial(context.Background(), dialer, network, addr, config) 127 } 128 129 // 客户端拨号,发起tls通信请求 130 func dial(ctx context.Context, netDialer *net.Dialer, network, addr string, config *Config) (*Conn, error) { 131 if netDialer.Timeout != 0 { 132 var cancel context.CancelFunc 133 ctx, cancel = context.WithTimeout(ctx, netDialer.Timeout) 134 defer cancel() 135 } 136 137 if !netDialer.Deadline.IsZero() { 138 var cancel context.CancelFunc 139 ctx, cancel = context.WithDeadline(ctx, netDialer.Deadline) 140 defer cancel() 141 } 142 143 rawConn, err := netDialer.DialContext(ctx, network, addr) 144 if err != nil { 145 return nil, err 146 } 147 148 colonPos := strings.LastIndex(addr, ":") 149 if colonPos == -1 { 150 colonPos = len(addr) 151 } 152 hostname := addr[:colonPos] 153 154 if config == nil { 155 config = defaultConfig() 156 } 157 // If no ServerName is set, infer the ServerName 158 // from the hostname we're connecting to. 159 if config.ServerName == "" { 160 // Make a copy to avoid polluting argument or default. 161 c := config.Clone() 162 c.ServerName = hostname 163 config = c 164 } 165 166 conn := Client(rawConn, config) 167 // 客户端发起tls握手 168 if err := conn.HandshakeContext(ctx); err != nil { 169 _ = rawConn.Close() 170 return nil, err 171 } 172 return conn, nil 173 } 174 175 // Dial connects to the given network address using net.Dial 176 // and then initiates a TLS handshake, returning the resulting 177 // TLS connection. 178 // Dial interprets a nil configuration as equivalent to 179 // the zero configuration; see the documentation of Config 180 // for the defaults. 181 func Dial(network, addr string, config *Config) (*Conn, error) { 182 return DialWithDialer(new(net.Dialer), network, addr, config) 183 } 184 185 // Dialer dials TLS connections given a configuration and a Dialer for the 186 // underlying connection. 187 type Dialer struct { 188 // NetDialer is the optional dialer to use for the TLS connections' 189 // underlying TCP connections. 190 // A nil NetDialer is equivalent to the net.Dialer zero value. 191 NetDialer *net.Dialer 192 193 // Config is the TLS configuration to use for new connections. 194 // A nil configuration is equivalent to the zero 195 // configuration; see the documentation of Config for the 196 // defaults. 197 Config *Config 198 } 199 200 // Dial connects to the given network address and initiates a TLS 201 // handshake, returning the resulting TLS connection. 202 // 203 // The returned Conn, if any, will always be of type *Conn. 204 // 205 // Dial uses context.Background internally; to specify the context, 206 // use DialContext. 207 func (d *Dialer) Dial(network, addr string) (net.Conn, error) { 208 return d.DialContext(context.Background(), network, addr) 209 } 210 211 func (d *Dialer) netDialer() *net.Dialer { 212 if d.NetDialer != nil { 213 return d.NetDialer 214 } 215 return new(net.Dialer) 216 } 217 218 // DialContext connects to the given network address and initiates a TLS 219 // handshake, returning the resulting TLS connection. 220 // 221 // The provided Context must be non-nil. If the context expires before 222 // the connection is complete, an error is returned. Once successfully 223 // connected, any expiration of the context will not affect the 224 // connection. 225 // 226 // The returned Conn, if any, will always be of type *Conn. 227 func (d *Dialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { 228 c, err := dial(ctx, d.netDialer(), network, addr, d.Config) 229 if err != nil { 230 // Don't return c (a typed nil) in an interface. 231 return nil, err 232 } 233 return c, nil 234 } 235 236 // LoadX509KeyPair reads and parses a public/private key pair from a pair 237 // of files. The files must contain PEM encoded data. The certificate file 238 // may contain intermediate certificates following the leaf certificate to 239 // form a certificate chain. On successful return, Certificate.Leaf will 240 // be nil because the parsed form of the certificate is not retained. 241 func LoadX509KeyPair(certFile, keyFile string) (Certificate, error) { 242 certPEMBlock, err := os.ReadFile(certFile) 243 if err != nil { 244 return Certificate{}, err 245 } 246 keyPEMBlock, err := os.ReadFile(keyFile) 247 if err != nil { 248 return Certificate{}, err 249 } 250 return X509KeyPair(certPEMBlock, keyPEMBlock) 251 } 252 253 // X509KeyPair parses a public/private key pair from a pair of 254 // PEM encoded data. On successful return, Certificate.Leaf will be nil because 255 // the parsed form of the certificate is not retained. 256 func X509KeyPair(certPEMBlock, keyPEMBlock []byte) (Certificate, error) { 257 fail := func(err error) (Certificate, error) { return Certificate{}, err } 258 259 var cert Certificate 260 var skippedBlockTypes []string 261 for { 262 var certDERBlock *pem.Block 263 // 将证书PEM字节数组解码为DER字节数组 264 certDERBlock, certPEMBlock = pem.Decode(certPEMBlock) 265 if certDERBlock == nil { 266 break 267 } 268 if certDERBlock.Type == "CERTIFICATE" { 269 // 将证书DER字节数组加入证书链的证书列表 270 cert.Certificate = append(cert.Certificate, certDERBlock.Bytes) 271 } else { 272 skippedBlockTypes = append(skippedBlockTypes, certDERBlock.Type) 273 } 274 } 275 276 if len(cert.Certificate) == 0 { 277 if len(skippedBlockTypes) == 0 { 278 return fail(errors.New("gmtls: failed to find any PEM data in certificate input")) 279 } 280 if len(skippedBlockTypes) == 1 && strings.HasSuffix(skippedBlockTypes[0], "PRIVATE KEY") { 281 return fail(errors.New("gmtls: failed to find certificate PEM data in certificate input, but did find a private key; PEM inputs may have been switched")) 282 } 283 return fail(fmt.Errorf("gmtls: failed to find \"CERTIFICATE\" PEM block in certificate input after skipping PEM blocks of the following types: %v", skippedBlockTypes)) 284 } 285 286 skippedBlockTypes = skippedBlockTypes[:0] 287 var keyDERBlock *pem.Block 288 for { 289 keyDERBlock, keyPEMBlock = pem.Decode(keyPEMBlock) 290 if keyDERBlock == nil { 291 if len(skippedBlockTypes) == 0 { 292 return fail(errors.New("gmtls: failed to find any PEM data in key input")) 293 } 294 if len(skippedBlockTypes) == 1 && skippedBlockTypes[0] == "CERTIFICATE" { 295 return fail(errors.New("gmtls: found a certificate rather than a key in the PEM for the private key")) 296 } 297 return fail(fmt.Errorf("gmtls: failed to find PEM block with type ending in \"PRIVATE KEY\" in key input after skipping PEM blocks of the following types: %v", skippedBlockTypes)) 298 } 299 if keyDERBlock.Type == "PRIVATE KEY" || strings.HasSuffix(keyDERBlock.Type, " PRIVATE KEY") { 300 break 301 } 302 skippedBlockTypes = append(skippedBlockTypes, keyDERBlock.Type) 303 } 304 // 读取证书链中的首个证书(子证书),转为x509.Certificate 305 // We don't need to parse the public key for TLS, but we so do anyway 306 // to check that it looks sane and matches the private key. 307 x509Cert, err := x509.ParseCertificate(cert.Certificate[0]) 308 if err != nil { 309 return fail(err) 310 } 311 // 将key的DER字节数组转为私钥 312 cert.PrivateKey, err = parsePrivateKey(keyDERBlock.Bytes) 313 if err != nil { 314 return fail(err) 315 } 316 // 检查私钥与证书中的公钥是否匹配 317 switch pub := x509Cert.PublicKey.(type) { 318 // 补充SM2分支 319 case *sm2.PublicKey: 320 priv, ok := cert.PrivateKey.(*sm2.PrivateKey) 321 if !ok { 322 return fail(errors.New("gmtls: private key type does not match public key type")) 323 } 324 if pub.X.Cmp(priv.X) != 0 || pub.Y.Cmp(priv.Y) != 0 { 325 return fail(errors.New("gmtls: private key does not match public key")) 326 } 327 case *rsa.PublicKey: 328 priv, ok := cert.PrivateKey.(*rsa.PrivateKey) 329 if !ok { 330 return fail(errors.New("gmtls: private key type does not match public key type")) 331 } 332 if pub.N.Cmp(priv.N) != 0 { 333 return fail(errors.New("gmtls: private key does not match public key")) 334 } 335 case *ecdsa.PublicKey: 336 priv, ok := cert.PrivateKey.(*ecdsa.PrivateKey) 337 if !ok { 338 return fail(errors.New("gmtls: private key type does not match public key type")) 339 } 340 if pub.X.Cmp(priv.X) != 0 || pub.Y.Cmp(priv.Y) != 0 { 341 return fail(errors.New("gmtls: private key does not match public key")) 342 } 343 case ed25519.PublicKey: 344 priv, ok := cert.PrivateKey.(ed25519.PrivateKey) 345 if !ok { 346 return fail(errors.New("gmtls: private key type does not match public key type")) 347 } 348 if !bytes.Equal(priv.Public().(ed25519.PublicKey), pub) { 349 return fail(errors.New("gmtls: private key does not match public key")) 350 } 351 default: 352 return fail(errors.New("gmtls: unknown public key algorithm")) 353 } 354 355 return cert, nil 356 } 357 358 // 将DER字节数组转为对应的私钥 359 // Attempt to parse the given private key DER block. OpenSSL 0.9.8 generates 360 // PKCS #1 private keys by default, while OpenSSL 1.0.0 generates PKCS #8 keys. 361 // OpenSSL ecparam generates SEC1 EC private keys for ECDSA. We try all three. 362 func parsePrivateKey(der []byte) (crypto.PrivateKey, error) { 363 if key, err := x509.ParsePKCS1PrivateKey(der); err == nil { 364 return key, nil 365 } 366 if key, err := x509.ParsePKCS8PrivateKey(der); err == nil { 367 switch key := key.(type) { 368 // 添加SM2 369 case *sm2.PrivateKey, *rsa.PrivateKey, *ecdsa.PrivateKey, ed25519.PrivateKey: 370 return key, nil 371 default: 372 return nil, errors.New("gmtls: found unknown private key type in PKCS#8 wrapping") 373 } 374 } 375 if key, err := x509.ParseECPrivateKey(der); err == nil { 376 return key, nil 377 } 378 379 return nil, errors.New("gmtls: failed to parse private key") 380 } 381 382 // NewServerConfigByClientHello 根据客户端发出的ClientHello的协议与密码套件决定Server的证书链 383 // 当客户端支持tls1.3或gmssl,且客户端支持的密码套件包含 TLS_SM4_GCM_SM3 时,服务端证书采用gmSigCert。 384 // - gmSigCert 国密证书链 385 // - genericCert 一般证书链 386 func NewServerConfigByClientHello(gmSigCert, genericCert *Certificate) (*Config, error) { 387 // 根据ClientHelloInfo中支持的协议,返回服务端证书 388 fncGetSignCertKeypair := func(info *ClientHelloInfo) (*Certificate, error) { 389 gmFlag := false 390 // 检查客户端支持的协议中是否包含TLS1.3或GMSSL 391 for _, v := range info.SupportedVersions { 392 if v == VersionGMSSL || v == VersionTLS13 { 393 // 检查客户端支持的密码套件是否包含 TLS_SM4_GCM_SM3 394 for _, c := range info.CipherSuites { 395 if c == TLS_SM4_GCM_SM3 { 396 gmFlag = true 397 break 398 } 399 } 400 break 401 } 402 } 403 if gmFlag { 404 return gmSigCert, nil 405 } else { 406 return genericCert, nil 407 } 408 } 409 410 return &Config{ 411 Certificates: nil, 412 GetCertificate: fncGetSignCertKeypair, 413 }, nil 414 }