github.com/aacfactory/fns@v1.2.86-0.20240310083819-80d667fc0a17/transports/ssl/default.go (about) 1 /* 2 * Copyright 2023 Wang Min Xiang 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 * 16 */ 17 18 package ssl 19 20 import ( 21 "crypto/ecdsa" 22 "crypto/rsa" 23 "crypto/tls" 24 "crypto/x509" 25 "encoding/pem" 26 "fmt" 27 "github.com/aacfactory/afssl/gmsm/cfca" 28 "github.com/aacfactory/afssl/gmsm/sm2" 29 "github.com/aacfactory/afssl/gmsm/smx509" 30 "github.com/aacfactory/afssl/gmsm/tlcp" 31 "github.com/aacfactory/configures" 32 "github.com/aacfactory/errors" 33 "net" 34 "os" 35 "strings" 36 "time" 37 ) 38 39 type Keypair struct { 40 Cert string `json:"cert"` 41 Key string `json:"key"` 42 Password string `json:"password"` 43 } 44 45 type Keypairs []Keypair 46 47 func (kps Keypairs) Certificates() (tlcps []tlcp.Certificate, standards []tls.Certificate, err error) { 48 if len(kps) == 0 { 49 return 50 } 51 for _, keypair := range kps { 52 cert := strings.TrimSpace(keypair.Cert) 53 key := strings.TrimSpace(keypair.Key) 54 // key 55 if key == "" { 56 err = errors.Warning("fns: keypairs build certificates failed").WithCause(fmt.Errorf("key is undefined")) 57 return 58 } 59 var keyPEM []byte 60 if strings.IndexAny(key, "-----BEGIN") < 0 { 61 keyPEM, err = os.ReadFile(key) 62 if err != nil { 63 err = errors.Warning("fns: keypairs build certificates failed").WithCause(err) 64 return 65 } 66 } else { 67 keyPEM = []byte(key) 68 } 69 keyBlock, _ := pem.Decode(keyPEM) 70 if keyBlock.Type == "CFCA" { 71 password := strings.TrimSpace(keypair.Password) 72 if password == "" { 73 err = errors.Warning("fns: keypairs build certificates failed").WithCause(fmt.Errorf("password is undefined")) 74 return 75 } 76 pass, readPassErr := os.ReadFile(password) 77 if readPassErr != nil { 78 if !os.IsNotExist(readPassErr) { 79 err = errors.Warning("fns: keypairs build certificates failed").WithCause(readPassErr) 80 return 81 } 82 pass = []byte(password) 83 } 84 cfcaCert, cfcaKey, cfcaErr := cfca.Parse(keyPEM, pass) 85 if cfcaErr != nil { 86 err = errors.Warning("fns: keypairs build certificates failed").WithCause(cfcaErr) 87 return 88 } 89 if cert != "" { 90 var certPEM []byte 91 if strings.IndexAny(cert, "-----BEGIN") < 0 { 92 certPEM, err = os.ReadFile(cert) 93 if err != nil { 94 err = errors.Warning("fns: keypairs build certificates failed").WithCause(err) 95 return 96 } 97 } else { 98 certPEM = []byte(cert) 99 } 100 certBlock, _ := pem.Decode(certPEM) 101 if certBlock == nil { 102 err = errors.Warning("fns: keypairs build certificates failed").WithCause(errors.Warning("x509: failed to decode PEM block containing certificate")) 103 return 104 } 105 rootCert, rootCertErr := smx509.ParseCertificate(certBlock.Bytes) 106 if rootCertErr != nil { 107 err = errors.Warning("fns: keypairs build certificates failed").WithCause(rootCertErr) 108 return 109 } 110 checkSignatureErr := rootCert.CheckSignature(smx509.SignatureAlgorithm(cfcaCert.SignatureAlgorithm), cfcaCert.RawTBSCertificate, cfcaCert.Signature) 111 if checkSignatureErr != nil { 112 err = errors.Warning("fns: keypairs build certificates failed").WithCause(checkSignatureErr) 113 return 114 } 115 } 116 certificate := tlcp.Certificate{ 117 Certificate: [][]byte{keyBlock.Bytes}, 118 PrivateKey: cfcaKey, 119 Leaf: cfcaCert, 120 } 121 switch pub := cfcaCert.PublicKey.(type) { 122 case *rsa.PublicKey: 123 priv, ok := certificate.PrivateKey.(*rsa.PrivateKey) 124 if !ok { 125 err = errors.Warning("fns: keypairs build certificates failed").WithCause(errors.Warning("tlcp: private key type does not match public key type")) 126 return 127 } 128 if pub.N.Cmp(priv.N) != 0 { 129 err = errors.Warning("fns: keypairs build certificates failed").WithCause(errors.Warning("tlcp: private key does not match public key")) 130 return 131 } 132 case *ecdsa.PublicKey: 133 priv, ok := certificate.PrivateKey.(*sm2.PrivateKey) 134 if !ok { 135 err = errors.Warning("fns: keypairs build certificates failed").WithCause(errors.Warning("tlcp: private key type does not match public key type")) 136 return 137 } 138 if pub.X.Cmp(priv.X) != 0 || pub.Y.Cmp(priv.Y) != 0 { 139 err = errors.Warning("fns: keypairs build certificates failed").WithCause(errors.Warning("tlcp: private key does not match public key")) 140 return 141 } 142 default: 143 err = errors.Warning("fns: keypairs build certificates failed").WithCause(errors.Warning("tlcp: unknown public key algorithm")) 144 return 145 } 146 if tlcps == nil { 147 tlcps = make([]tlcp.Certificate, 0, 1) 148 } 149 tlcps = append(tlcps, certificate) 150 continue 151 } 152 keyType, getKeyTypeErr := smx509.GetGMPrivateKeyType(keyBlock.Bytes) 153 if getKeyTypeErr != nil { 154 err = errors.Warning("fns: keypairs build certificates failed").WithCause(getKeyTypeErr) 155 return 156 } 157 // certPEM 158 var certPEM []byte 159 if strings.IndexAny(cert, "-----BEGIN") < 0 { 160 certPEM, err = os.ReadFile(cert) 161 if err != nil { 162 err = errors.Warning("fns: keypairs build certificates failed").WithCause(err) 163 return 164 } 165 } else { 166 certPEM = []byte(cert) 167 } 168 if keyType == smx509.SM2Key { 169 certificate, certificateErr := tlcp.X509KeyPair(certPEM, keyPEM) 170 if certificateErr != nil { 171 err = errors.Warning("fns: keypairs build certificates failed").WithCause(certificateErr) 172 return 173 } 174 if tlcps == nil { 175 tlcps = make([]tlcp.Certificate, 0, 1) 176 } 177 tlcps = append(tlcps, certificate) 178 } else if keyType == smx509.SM9Key { 179 err = errors.Warning("fns: keypairs build certificates failed").WithCause(errors.Warning("sm9 key is unsupported")) 180 return 181 } else { 182 certificate, certificateErr := tls.X509KeyPair(certPEM, keyPEM) 183 if certificateErr != nil { 184 err = errors.Warning("fns: keypairs build certificates failed").WithCause(certificateErr) 185 return 186 } 187 if standards == nil { 188 standards = make([]tls.Certificate, 0, 1) 189 } 190 standards = append(standards, certificate) 191 } 192 } 193 return 194 } 195 196 type ServerConfig struct { 197 ClientAuth int `json:"clientAuth"` 198 Keypair Keypairs `json:"keypair"` 199 } 200 201 func (config *ServerConfig) Config() (gm *tlcp.Config, standard *tls.Config, err error) { 202 clientAuth := tls.ClientAuthType(config.ClientAuth) 203 if clientAuth < tls.NoClientCert || clientAuth > tls.RequireAndVerifyClientCert { 204 err = errors.Warning("fns: build server side tls config failed").WithCause(fmt.Errorf("clientAuth is invalid")) 205 return 206 } 207 if len(config.Keypair) == 0 { 208 err = errors.Warning("fns: build server side tls config failed").WithCause(fmt.Errorf("keypair is undefined")) 209 return 210 } 211 tlcps, standards, certErr := config.Keypair.Certificates() 212 if certErr != nil { 213 err = errors.Warning("fns: build server side tls config failed").WithCause(certErr) 214 return 215 } 216 if len(tlcps) > 0 { 217 gm = &tlcp.Config{ 218 Certificates: tlcps, 219 ClientAuth: tlcp.ClientAuthType(clientAuth), 220 } 221 } 222 if len(standards) > 0 { 223 standard = &tls.Config{ 224 Certificates: standards, 225 ClientAuth: clientAuth, 226 } 227 } 228 return 229 } 230 231 type ClientConfig struct { 232 InsecureSkipVerify bool `json:"insecureSkipVerify"` 233 Keypair Keypairs `json:"keypair"` 234 } 235 236 func (config *ClientConfig) Config() (gm *tlcp.Config, standard *tls.Config, err error) { 237 if len(config.Keypair) == 0 { 238 err = errors.Warning("fns: build client side tls config failed").WithCause(fmt.Errorf("keypair is undefined")) 239 return 240 } 241 tlcps, standards, certErr := config.Keypair.Certificates() 242 if certErr != nil { 243 err = errors.Warning("fns: build client side tls config failed").WithCause(certErr) 244 return 245 } 246 if len(tlcps) > 0 { 247 gm = &tlcp.Config{ 248 Certificates: tlcps, 249 InsecureSkipVerify: config.InsecureSkipVerify, 250 } 251 } 252 if len(standards) > 0 { 253 standard = &tls.Config{ 254 Certificates: standards, 255 InsecureSkipVerify: config.InsecureSkipVerify, 256 } 257 } 258 return 259 } 260 261 type DefaultConfigOptions struct { 262 CA []string `json:"ca"` 263 Server *ServerConfig `json:"server"` 264 Client *ClientConfig `json:"client"` 265 } 266 267 func (options DefaultConfigOptions) Build() (srvGmTLS *tlcp.Config, cliGmTLS *tlcp.Config, srvStdTLS *tls.Config, cliStdTLS *tls.Config, err error) { 268 if options.Server == nil { 269 err = errors.Warning("fns: build default tls config failed").WithCause(fmt.Errorf("server side config is required")) 270 return 271 } 272 srvGmTLS, srvStdTLS, err = options.Server.Config() 273 if err != nil { 274 err = errors.Warning("fns: build default tls config failed").WithCause(err) 275 return 276 } 277 if options.Client != nil { 278 cliGmTLS, cliStdTLS, err = options.Client.Config() 279 if err != nil { 280 err = errors.Warning("fns: build default tls config failed").WithCause(err) 281 return 282 } 283 } 284 var gmCAS *smx509.CertPool 285 var stCAS *x509.CertPool 286 if len(options.CA) > 0 { 287 caPEMs := make([][]byte, 0, 1) 288 for _, ca := range options.CA { 289 ca = strings.TrimSpace(ca) 290 if ca == "" { 291 continue 292 } 293 var caPEM []byte 294 if strings.IndexAny(ca, "-----BEGIN") < 0 { 295 caPEM, err = os.ReadFile(ca) 296 if err != nil { 297 err = errors.Warning("fns: build default tls config failed").WithCause(err) 298 return 299 } 300 } else { 301 caPEM = []byte(ca) 302 } 303 caPEMs = append(caPEMs, caPEM) 304 } 305 if srvGmTLS != nil { 306 gmCAS = smx509.NewCertPool() 307 for _, caPEM := range caPEMs { 308 gmCAS.AppendCertsFromPEM(caPEM) 309 } 310 srvGmTLS.ClientCAs = gmCAS 311 } 312 if srvStdTLS != nil { 313 stCAS = x509.NewCertPool() 314 for _, caPEM := range caPEMs { 315 stCAS.AppendCertsFromPEM(caPEM) 316 } 317 srvStdTLS.ClientCAs = stCAS 318 } 319 if cliGmTLS != nil { 320 cliGmTLS.RootCAs = gmCAS 321 } 322 if cliStdTLS != nil { 323 cliStdTLS.RootCAs = stCAS 324 } 325 } 326 return 327 } 328 329 func NewDefaultConfig(srv *tls.Config, cli *tls.Config, srvGM *tlcp.Config, cliGM *tlcp.Config) *DefaultConfig { 330 return &DefaultConfig{ 331 srvStdTLS: srv, 332 cliStdTLS: cli, 333 srvGmTLS: srvGM, 334 cliGmTLS: cliGM, 335 } 336 } 337 338 type DefaultConfig struct { 339 srvStdTLS *tls.Config 340 cliStdTLS *tls.Config 341 srvGmTLS *tlcp.Config 342 cliGmTLS *tlcp.Config 343 } 344 345 func (config *DefaultConfig) Construct(options configures.Config) (err error) { 346 opt := DefaultConfigOptions{} 347 optErr := options.As(&opt) 348 if optErr != nil { 349 err = errors.Warning("fns: build default tls config failed").WithCause(optErr) 350 return 351 } 352 config.srvGmTLS, config.cliGmTLS, config.srvStdTLS, config.cliStdTLS, err = opt.Build() 353 return 354 } 355 356 func (config *DefaultConfig) Server() (srvTLS *tls.Config, ln ListenerFunc) { 357 if config.srvGmTLS != nil { 358 if config.srvStdTLS != nil { 359 srvTLS = config.srvStdTLS 360 ln = func(inner net.Listener) (v net.Listener) { 361 v = tlcp.NewProtocolSwitcherListener(inner, config.srvGmTLS.Clone(), config.srvStdTLS.Clone()) 362 return 363 } 364 return 365 } 366 ln = func(inner net.Listener) (v net.Listener) { 367 v = tlcp.NewListener(inner, config.srvGmTLS.Clone()) 368 return 369 } 370 return 371 } 372 if config.srvStdTLS != nil { 373 srvTLS = config.srvStdTLS 374 ln = func(inner net.Listener) (v net.Listener) { 375 v = tls.NewListener(inner, config.srvStdTLS) 376 return 377 } 378 } 379 return 380 } 381 382 func (config *DefaultConfig) Client() (cliTLS *tls.Config, dialer Dialer) { 383 if config.cliStdTLS != nil { 384 cliTLS = config.cliStdTLS 385 return 386 } 387 if config.cliGmTLS != nil { 388 if config.cliStdTLS != nil { 389 cliTLS = config.cliStdTLS 390 } 391 nd := &net.Dialer{ 392 Timeout: 30 * time.Second, 393 Deadline: time.Time{}, 394 LocalAddr: nil, 395 FallbackDelay: 0, 396 KeepAlive: 60 * time.Second, 397 Resolver: nil, 398 Control: nil, 399 ControlContext: nil, 400 } 401 dialer = &tlcp.Dialer{NetDialer: nd, Config: config.cliGmTLS} 402 return 403 } 404 return 405 }