github.com/Hyperledger-TWGC/tjfoc-gm@v1.4.0/gmtls/tls.go (about) 1 /* 2 Copyright Suzhou Tongji Fintech Research Institute 2017 All Rights Reserved. 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 */ 15 16 // add sm2 support 17 package gmtls 18 19 import ( 20 "crypto" 21 "crypto/ecdsa" 22 "crypto/rsa" 23 "crypto/x509" 24 "encoding/pem" 25 "errors" 26 "fmt" 27 "io/ioutil" 28 "net" 29 "strings" 30 "time" 31 32 "github.com/Hyperledger-TWGC/tjfoc-gm/sm2" 33 X "github.com/Hyperledger-TWGC/tjfoc-gm/x509" 34 ) 35 36 // Server returns a new TLS server side connection 37 // using conn as the underlying transport. 38 // The configuration config must be non-nil and must include 39 // at least one certificate or else set GetCertificate. 40 func Server(conn net.Conn, config *Config) *Conn { 41 return &Conn{conn: conn, config: config} 42 } 43 44 // Client returns a new TLS client side connection 45 // using conn as the underlying transport. 46 // The config cannot be nil: users must set either ServerName or 47 // InsecureSkipVerify in the config. 48 func Client(conn net.Conn, config *Config) *Conn { 49 return &Conn{conn: conn, config: config, isClient: true} 50 } 51 52 // A listener implements a network listener (net.Listener) for TLS connections. 53 type listener struct { 54 net.Listener 55 config *Config 56 } 57 58 // Accept waits for and returns the next incoming TLS connection. 59 // The returned connection is of type *Conn. 60 func (l *listener) Accept() (net.Conn, error) { 61 c, err := l.Listener.Accept() 62 if err != nil { 63 return nil, err 64 } 65 return Server(c, l.config), nil 66 } 67 68 // NewListener creates a Listener which accepts connections from an inner 69 // Listener and wraps each connection with Server. 70 // The configuration config must be non-nil and must include 71 // at least one certificate or else set GetCertificate. 72 func NewListener(inner net.Listener, config *Config) net.Listener { 73 l := new(listener) 74 l.Listener = inner 75 l.config = config 76 return l 77 } 78 79 // Listen creates a TLS listener accepting connections on the 80 // given network address using net.Listen. 81 // The configuration config must be non-nil and must include 82 // at least one certificate or else set GetCertificate. 83 func Listen(network, laddr string, config *Config) (net.Listener, error) { 84 if config == nil || (len(config.Certificates) == 0 && config.GetCertificate == nil) { 85 return nil, errors.New("tls: neither Certificates nor GetCertificate set in Config") 86 } 87 l, err := net.Listen(network, laddr) 88 if err != nil { 89 return nil, err 90 } 91 return NewListener(l, config), nil 92 } 93 94 type timeoutError struct{} 95 96 func (timeoutError) Error() string { return "tls: DialWithDialer timed out" } 97 func (timeoutError) Timeout() bool { return true } 98 func (timeoutError) Temporary() bool { return true } 99 100 // DialWithDialer connects to the given network address using dialer.Dial and 101 // then initiates a TLS handshake, returning the resulting TLS connection. Any 102 // timeout or deadline given in the dialer apply to connection and TLS 103 // handshake as a whole. 104 // 105 // DialWithDialer interprets a nil configuration as equivalent to the zero 106 // configuration; see the documentation of Config for the defaults. 107 func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*Conn, error) { 108 // We want the Timeout and Deadline values from dialer to cover the 109 // whole process: TCP connection and TLS handshake. This means that we 110 // also need to start our own timers now. 111 timeout := dialer.Timeout 112 113 if !dialer.Deadline.IsZero() { 114 // deadlineTimeout := time.Until(dialer.Deadline) 115 deadlineTimeout := dialer.Deadline.Sub(time.Now()) // support go before 1.8 116 if timeout == 0 || deadlineTimeout < timeout { 117 timeout = deadlineTimeout 118 } 119 } 120 121 var errChannel chan error 122 123 if timeout != 0 { 124 errChannel = make(chan error, 2) 125 time.AfterFunc(timeout, func() { 126 errChannel <- timeoutError{} 127 }) 128 } 129 130 rawConn, err := dialer.Dial(network, addr) 131 if err != nil { 132 return nil, err 133 } 134 135 colonPos := strings.LastIndex(addr, ":") 136 if colonPos == -1 { 137 colonPos = len(addr) 138 } 139 hostname := addr[:colonPos] 140 141 if config == nil { 142 config = defaultConfig() 143 } 144 // If no ServerName is set, infer the ServerName 145 // from the hostname we're connecting to. 146 if config.ServerName == "" { 147 // Make a copy to avoid polluting argument or default. 148 c := config.Clone() 149 c.ServerName = hostname 150 config = c 151 } 152 153 conn := Client(rawConn, config) 154 155 if timeout == 0 { 156 err = conn.Handshake() 157 } else { 158 go func() { 159 errChannel <- conn.Handshake() 160 }() 161 162 err = <-errChannel 163 } 164 165 if err != nil { 166 rawConn.Close() 167 return nil, err 168 } 169 170 return conn, nil 171 } 172 173 // Dial connects to the given network address using net.Dial 174 // and then initiates a TLS handshake, returning the resulting 175 // TLS connection. 176 // Dial interprets a nil configuration as equivalent to 177 // the zero configuration; see the documentation of Config 178 // for the defaults. 179 func Dial(network, addr string, config *Config) (*Conn, error) { 180 return DialWithDialer(new(net.Dialer), network, addr, config) 181 } 182 183 // LoadX509KeyPair reads and parses a public/private key pair from a pair 184 // of files. The files must contain PEM encoded data. The certificate file 185 // may contain intermediate certificates following the leaf certificate to 186 // form a certificate chain. On successful return, Certificate.Leaf will 187 // be nil because the parsed form of the certificate is not retained. 188 func LoadX509KeyPair(certFile, keyFile string) (Certificate, error) { 189 certPEMBlock, err := ioutil.ReadFile(certFile) 190 if err != nil { 191 return Certificate{}, err 192 } 193 keyPEMBlock, err := ioutil.ReadFile(keyFile) 194 if err != nil { 195 return Certificate{}, err 196 } 197 return X509KeyPair(certPEMBlock, keyPEMBlock) 198 } 199 200 // X509KeyPair parses a public/private key pair from a pair of 201 // PEM encoded data. On successful return, Certificate.Leaf will be nil because 202 // the parsed form of the certificate is not retained. 203 func X509KeyPair(certPEMBlock, keyPEMBlock []byte) (Certificate, error) { 204 fail := func(err error) (Certificate, error) { return Certificate{}, err } 205 206 var cert Certificate 207 var skippedBlockTypes []string 208 for { 209 var certDERBlock *pem.Block 210 certDERBlock, certPEMBlock = pem.Decode(certPEMBlock) 211 if certDERBlock == nil { 212 break 213 } 214 if certDERBlock.Type == "CERTIFICATE" { 215 cert.Certificate = append(cert.Certificate, certDERBlock.Bytes) 216 } else { 217 skippedBlockTypes = append(skippedBlockTypes, certDERBlock.Type) 218 } 219 } 220 221 if len(cert.Certificate) == 0 { 222 if len(skippedBlockTypes) == 0 { 223 return fail(errors.New("tls: failed to find any PEM data in certificate input")) 224 } 225 if len(skippedBlockTypes) == 1 && strings.HasSuffix(skippedBlockTypes[0], "PRIVATE KEY") { 226 return fail(errors.New("tls: failed to find certificate PEM data in certificate input, but did find a private key; PEM inputs may have been switched")) 227 } 228 return fail(fmt.Errorf("tls: failed to find \"CERTIFICATE\" PEM block in certificate input after skipping PEM blocks of the following types: %v", skippedBlockTypes)) 229 } 230 231 skippedBlockTypes = skippedBlockTypes[:0] 232 var keyDERBlock *pem.Block 233 for { 234 keyDERBlock, keyPEMBlock = pem.Decode(keyPEMBlock) 235 if keyDERBlock == nil { 236 if len(skippedBlockTypes) == 0 { 237 return fail(errors.New("tls: failed to find any PEM data in key input")) 238 } 239 if len(skippedBlockTypes) == 1 && skippedBlockTypes[0] == "CERTIFICATE" { 240 return fail(errors.New("tls: found a certificate rather than a key in the PEM for the private key")) 241 } 242 return fail(fmt.Errorf("tls: failed to find PEM block with type ending in \"PRIVATE KEY\" in key input after skipping PEM blocks of the following types: %v", skippedBlockTypes)) 243 } 244 if keyDERBlock.Type == "PRIVATE KEY" || strings.HasSuffix(keyDERBlock.Type, " PRIVATE KEY") { 245 break 246 } 247 skippedBlockTypes = append(skippedBlockTypes, keyDERBlock.Type) 248 } 249 250 var err error 251 cert.PrivateKey, err = parsePrivateKey(keyDERBlock.Bytes) 252 if err != nil { 253 return fail(err) 254 } 255 256 // We don't need to parse the public key for TLS, but we so do anyway 257 // to check that it looks sane and matches the private key. 258 x509Cert, err := X.ParseCertificate(cert.Certificate[0]) 259 if err != nil { 260 return fail(err) 261 } 262 263 switch pub := x509Cert.PublicKey.(type) { 264 case *rsa.PublicKey: 265 priv, ok := cert.PrivateKey.(*rsa.PrivateKey) 266 if !ok { 267 return fail(errors.New("tls: private key type does not match public key type")) 268 } 269 if pub.N.Cmp(priv.N) != 0 { 270 return fail(errors.New("tls: private key does not match public key")) 271 } 272 case *ecdsa.PublicKey: 273 pub, _ = x509Cert.PublicKey.(*ecdsa.PublicKey) 274 switch pub.Curve { 275 case sm2.P256Sm2(): 276 priv, ok := cert.PrivateKey.(*sm2.PrivateKey) 277 if !ok { 278 return fail(errors.New("tls: sm2 private key type does not match public key type")) 279 } 280 if pub.X.Cmp(priv.X) != 0 || pub.Y.Cmp(priv.Y) != 0 { 281 return fail(errors.New("tls: sm2 private key does not match public key")) 282 } 283 default: 284 priv, ok := cert.PrivateKey.(*ecdsa.PrivateKey) 285 if !ok { 286 return fail(errors.New("tls: private key type does not match public key type")) 287 } 288 if pub.X.Cmp(priv.X) != 0 || pub.Y.Cmp(priv.Y) != 0 { 289 return fail(errors.New("tls: private key does not match public key")) 290 } 291 } 292 default: 293 return fail(errors.New("tls: unknown public key algorithm")) 294 } 295 return cert, nil 296 } 297 298 func parsePrivateKey(der []byte) (crypto.PrivateKey, error) { 299 if key, err := x509.ParsePKCS1PrivateKey(der); err == nil { 300 return key, nil 301 } 302 if key, err := x509.ParsePKCS8PrivateKey(der); err == nil { 303 switch key := key.(type) { 304 case *rsa.PrivateKey, *ecdsa.PrivateKey: 305 return key, nil 306 default: 307 return nil, errors.New("tls: found unknown private key type in PKCS#8 wrapping") 308 } 309 } 310 if key, err := X.ParsePKCS8UnecryptedPrivateKey(der); err == nil { 311 return key, nil 312 } 313 return nil, errors.New("tls: failed to parse private key") 314 }