github.com/aacfactory/fns@v1.2.86-0.20240310083819-80d667fc0a17/transports/ssl/default.go (about)

     1  /*
     2   * Copyright 2023 Wang Min Xiang
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   * 	http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   *
    16   */
    17  
    18  package ssl
    19  
    20  import (
    21  	"crypto/ecdsa"
    22  	"crypto/rsa"
    23  	"crypto/tls"
    24  	"crypto/x509"
    25  	"encoding/pem"
    26  	"fmt"
    27  	"github.com/aacfactory/afssl/gmsm/cfca"
    28  	"github.com/aacfactory/afssl/gmsm/sm2"
    29  	"github.com/aacfactory/afssl/gmsm/smx509"
    30  	"github.com/aacfactory/afssl/gmsm/tlcp"
    31  	"github.com/aacfactory/configures"
    32  	"github.com/aacfactory/errors"
    33  	"net"
    34  	"os"
    35  	"strings"
    36  	"time"
    37  )
    38  
// Keypair describes one certificate/private-key pair loaded by
// Keypairs.Certificates.
//
// Cert and Key each hold either inline PEM text or a path to a PEM file.
// Key is required. Password is only consulted for keys whose PEM block type
// is "CFCA": it names a file holding the password or, when no such file
// exists, is the password itself. For CFCA keys Cert is optional and, when
// set, is used to verify the certificate carried by the key.
type Keypair struct {
	Cert     string `json:"cert"`
	Key      string `json:"key"`
	Password string `json:"password"`
}

// Keypairs is an ordered list of Keypair entries; Certificates converts them
// into TLCP (GM) and standard TLS certificates.
type Keypairs []Keypair
    46  
    47  func (kps Keypairs) Certificates() (tlcps []tlcp.Certificate, standards []tls.Certificate, err error) {
    48  	if len(kps) == 0 {
    49  		return
    50  	}
    51  	for _, keypair := range kps {
    52  		cert := strings.TrimSpace(keypair.Cert)
    53  		key := strings.TrimSpace(keypair.Key)
    54  		// key
    55  		if key == "" {
    56  			err = errors.Warning("fns: keypairs build certificates failed").WithCause(fmt.Errorf("key is undefined"))
    57  			return
    58  		}
    59  		var keyPEM []byte
    60  		if strings.IndexAny(key, "-----BEGIN") < 0 {
    61  			keyPEM, err = os.ReadFile(key)
    62  			if err != nil {
    63  				err = errors.Warning("fns: keypairs build certificates failed").WithCause(err)
    64  				return
    65  			}
    66  		} else {
    67  			keyPEM = []byte(key)
    68  		}
    69  		keyBlock, _ := pem.Decode(keyPEM)
    70  		if keyBlock.Type == "CFCA" {
    71  			password := strings.TrimSpace(keypair.Password)
    72  			if password == "" {
    73  				err = errors.Warning("fns: keypairs build certificates failed").WithCause(fmt.Errorf("password is undefined"))
    74  				return
    75  			}
    76  			pass, readPassErr := os.ReadFile(password)
    77  			if readPassErr != nil {
    78  				if !os.IsNotExist(readPassErr) {
    79  					err = errors.Warning("fns: keypairs build certificates failed").WithCause(readPassErr)
    80  					return
    81  				}
    82  				pass = []byte(password)
    83  			}
    84  			cfcaCert, cfcaKey, cfcaErr := cfca.Parse(keyPEM, pass)
    85  			if cfcaErr != nil {
    86  				err = errors.Warning("fns: keypairs build certificates failed").WithCause(cfcaErr)
    87  				return
    88  			}
    89  			if cert != "" {
    90  				var certPEM []byte
    91  				if strings.IndexAny(cert, "-----BEGIN") < 0 {
    92  					certPEM, err = os.ReadFile(cert)
    93  					if err != nil {
    94  						err = errors.Warning("fns: keypairs build certificates failed").WithCause(err)
    95  						return
    96  					}
    97  				} else {
    98  					certPEM = []byte(cert)
    99  				}
   100  				certBlock, _ := pem.Decode(certPEM)
   101  				if certBlock == nil {
   102  					err = errors.Warning("fns: keypairs build certificates failed").WithCause(errors.Warning("x509: failed to decode PEM block containing certificate"))
   103  					return
   104  				}
   105  				rootCert, rootCertErr := smx509.ParseCertificate(certBlock.Bytes)
   106  				if rootCertErr != nil {
   107  					err = errors.Warning("fns: keypairs build certificates failed").WithCause(rootCertErr)
   108  					return
   109  				}
   110  				checkSignatureErr := rootCert.CheckSignature(smx509.SignatureAlgorithm(cfcaCert.SignatureAlgorithm), cfcaCert.RawTBSCertificate, cfcaCert.Signature)
   111  				if checkSignatureErr != nil {
   112  					err = errors.Warning("fns: keypairs build certificates failed").WithCause(checkSignatureErr)
   113  					return
   114  				}
   115  			}
   116  			certificate := tlcp.Certificate{
   117  				Certificate: [][]byte{keyBlock.Bytes},
   118  				PrivateKey:  cfcaKey,
   119  				Leaf:        cfcaCert,
   120  			}
   121  			switch pub := cfcaCert.PublicKey.(type) {
   122  			case *rsa.PublicKey:
   123  				priv, ok := certificate.PrivateKey.(*rsa.PrivateKey)
   124  				if !ok {
   125  					err = errors.Warning("fns: keypairs build certificates failed").WithCause(errors.Warning("tlcp: private key type does not match public key type"))
   126  					return
   127  				}
   128  				if pub.N.Cmp(priv.N) != 0 {
   129  					err = errors.Warning("fns: keypairs build certificates failed").WithCause(errors.Warning("tlcp: private key does not match public key"))
   130  					return
   131  				}
   132  			case *ecdsa.PublicKey:
   133  				priv, ok := certificate.PrivateKey.(*sm2.PrivateKey)
   134  				if !ok {
   135  					err = errors.Warning("fns: keypairs build certificates failed").WithCause(errors.Warning("tlcp: private key type does not match public key type"))
   136  					return
   137  				}
   138  				if pub.X.Cmp(priv.X) != 0 || pub.Y.Cmp(priv.Y) != 0 {
   139  					err = errors.Warning("fns: keypairs build certificates failed").WithCause(errors.Warning("tlcp: private key does not match public key"))
   140  					return
   141  				}
   142  			default:
   143  				err = errors.Warning("fns: keypairs build certificates failed").WithCause(errors.Warning("tlcp: unknown public key algorithm"))
   144  				return
   145  			}
   146  			if tlcps == nil {
   147  				tlcps = make([]tlcp.Certificate, 0, 1)
   148  			}
   149  			tlcps = append(tlcps, certificate)
   150  			continue
   151  		}
   152  		keyType, getKeyTypeErr := smx509.GetGMPrivateKeyType(keyBlock.Bytes)
   153  		if getKeyTypeErr != nil {
   154  			err = errors.Warning("fns: keypairs build certificates failed").WithCause(getKeyTypeErr)
   155  			return
   156  		}
   157  		// certPEM
   158  		var certPEM []byte
   159  		if strings.IndexAny(cert, "-----BEGIN") < 0 {
   160  			certPEM, err = os.ReadFile(cert)
   161  			if err != nil {
   162  				err = errors.Warning("fns: keypairs build certificates failed").WithCause(err)
   163  				return
   164  			}
   165  		} else {
   166  			certPEM = []byte(cert)
   167  		}
   168  		if keyType == smx509.SM2Key {
   169  			certificate, certificateErr := tlcp.X509KeyPair(certPEM, keyPEM)
   170  			if certificateErr != nil {
   171  				err = errors.Warning("fns: keypairs build certificates failed").WithCause(certificateErr)
   172  				return
   173  			}
   174  			if tlcps == nil {
   175  				tlcps = make([]tlcp.Certificate, 0, 1)
   176  			}
   177  			tlcps = append(tlcps, certificate)
   178  		} else if keyType == smx509.SM9Key {
   179  			err = errors.Warning("fns: keypairs build certificates failed").WithCause(errors.Warning("sm9 key is unsupported"))
   180  			return
   181  		} else {
   182  			certificate, certificateErr := tls.X509KeyPair(certPEM, keyPEM)
   183  			if certificateErr != nil {
   184  				err = errors.Warning("fns: keypairs build certificates failed").WithCause(certificateErr)
   185  				return
   186  			}
   187  			if standards == nil {
   188  				standards = make([]tls.Certificate, 0, 1)
   189  			}
   190  			standards = append(standards, certificate)
   191  		}
   192  	}
   193  	return
   194  }
   195  
// ServerConfig holds the server-side TLS settings decoded from configuration.
//
// ClientAuth is interpreted as a crypto/tls ClientAuthType and must lie between
// tls.NoClientCert and tls.RequireAndVerifyClientCert (Config rejects other
// values); it is applied to both the GM and the standard config. Keypair lists
// the server certificates and must not be empty.
type ServerConfig struct {
	ClientAuth int      `json:"clientAuth"`
	Keypair    Keypairs `json:"keypair"`
}
   200  
   201  func (config *ServerConfig) Config() (gm *tlcp.Config, standard *tls.Config, err error) {
   202  	clientAuth := tls.ClientAuthType(config.ClientAuth)
   203  	if clientAuth < tls.NoClientCert || clientAuth > tls.RequireAndVerifyClientCert {
   204  		err = errors.Warning("fns: build server side tls config failed").WithCause(fmt.Errorf("clientAuth is invalid"))
   205  		return
   206  	}
   207  	if len(config.Keypair) == 0 {
   208  		err = errors.Warning("fns: build server side tls config failed").WithCause(fmt.Errorf("keypair is undefined"))
   209  		return
   210  	}
   211  	tlcps, standards, certErr := config.Keypair.Certificates()
   212  	if certErr != nil {
   213  		err = errors.Warning("fns: build server side tls config failed").WithCause(certErr)
   214  		return
   215  	}
   216  	if len(tlcps) > 0 {
   217  		gm = &tlcp.Config{
   218  			Certificates: tlcps,
   219  			ClientAuth:   tlcp.ClientAuthType(clientAuth),
   220  		}
   221  	}
   222  	if len(standards) > 0 {
   223  		standard = &tls.Config{
   224  			Certificates: standards,
   225  			ClientAuth:   clientAuth,
   226  		}
   227  	}
   228  	return
   229  }
   230  
// ClientConfig holds the client-side TLS settings decoded from configuration.
//
// InsecureSkipVerify is copied into every client config that Config builds.
// Keypair lists the client certificates; Config requires at least one, and the
// kinds of keys present decide whether a GM (TLCP) config, a standard TLS
// config, or both are produced.
type ClientConfig struct {
	InsecureSkipVerify bool     `json:"insecureSkipVerify"`
	Keypair            Keypairs `json:"keypair"`
}
   235  
   236  func (config *ClientConfig) Config() (gm *tlcp.Config, standard *tls.Config, err error) {
   237  	if len(config.Keypair) == 0 {
   238  		err = errors.Warning("fns: build client side tls config failed").WithCause(fmt.Errorf("keypair is undefined"))
   239  		return
   240  	}
   241  	tlcps, standards, certErr := config.Keypair.Certificates()
   242  	if certErr != nil {
   243  		err = errors.Warning("fns: build client side tls config failed").WithCause(certErr)
   244  		return
   245  	}
   246  	if len(tlcps) > 0 {
   247  		gm = &tlcp.Config{
   248  			Certificates:       tlcps,
   249  			InsecureSkipVerify: config.InsecureSkipVerify,
   250  		}
   251  	}
   252  	if len(standards) > 0 {
   253  		standard = &tls.Config{
   254  			Certificates:       standards,
   255  			InsecureSkipVerify: config.InsecureSkipVerify,
   256  		}
   257  	}
   258  	return
   259  }
   260  
// DefaultConfigOptions is the configuration decoded by DefaultConfig.Construct
// and turned into TLS configs by Build.
//
// CA entries are inline PEM text or PEM file paths (blank entries are skipped)
// whose certificates Build installs as ClientCAs on server configs and as
// RootCAs on client configs. Server is required; Client is optional.
type DefaultConfigOptions struct {
	CA     []string      `json:"ca"`
	Server *ServerConfig `json:"server"`
	Client *ClientConfig `json:"client"`
}
   266  
   267  func (options DefaultConfigOptions) Build() (srvGmTLS *tlcp.Config, cliGmTLS *tlcp.Config, srvStdTLS *tls.Config, cliStdTLS *tls.Config, err error) {
   268  	if options.Server == nil {
   269  		err = errors.Warning("fns: build default tls config failed").WithCause(fmt.Errorf("server side config is required"))
   270  		return
   271  	}
   272  	srvGmTLS, srvStdTLS, err = options.Server.Config()
   273  	if err != nil {
   274  		err = errors.Warning("fns: build default tls config failed").WithCause(err)
   275  		return
   276  	}
   277  	if options.Client != nil {
   278  		cliGmTLS, cliStdTLS, err = options.Client.Config()
   279  		if err != nil {
   280  			err = errors.Warning("fns: build default tls config failed").WithCause(err)
   281  			return
   282  		}
   283  	}
   284  	var gmCAS *smx509.CertPool
   285  	var stCAS *x509.CertPool
   286  	if len(options.CA) > 0 {
   287  		caPEMs := make([][]byte, 0, 1)
   288  		for _, ca := range options.CA {
   289  			ca = strings.TrimSpace(ca)
   290  			if ca == "" {
   291  				continue
   292  			}
   293  			var caPEM []byte
   294  			if strings.IndexAny(ca, "-----BEGIN") < 0 {
   295  				caPEM, err = os.ReadFile(ca)
   296  				if err != nil {
   297  					err = errors.Warning("fns: build default tls config failed").WithCause(err)
   298  					return
   299  				}
   300  			} else {
   301  				caPEM = []byte(ca)
   302  			}
   303  			caPEMs = append(caPEMs, caPEM)
   304  		}
   305  		if srvGmTLS != nil {
   306  			gmCAS = smx509.NewCertPool()
   307  			for _, caPEM := range caPEMs {
   308  				gmCAS.AppendCertsFromPEM(caPEM)
   309  			}
   310  			srvGmTLS.ClientCAs = gmCAS
   311  		}
   312  		if srvStdTLS != nil {
   313  			stCAS = x509.NewCertPool()
   314  			for _, caPEM := range caPEMs {
   315  				stCAS.AppendCertsFromPEM(caPEM)
   316  			}
   317  			srvStdTLS.ClientCAs = stCAS
   318  		}
   319  		if cliGmTLS != nil {
   320  			cliGmTLS.RootCAs = gmCAS
   321  		}
   322  		if cliStdTLS != nil {
   323  			cliStdTLS.RootCAs = stCAS
   324  		}
   325  	}
   326  	return
   327  }
   328  
   329  func NewDefaultConfig(srv *tls.Config, cli *tls.Config, srvGM *tlcp.Config, cliGM *tlcp.Config) *DefaultConfig {
   330  	return &DefaultConfig{
   331  		srvStdTLS: srv,
   332  		cliStdTLS: cli,
   333  		srvGmTLS:  srvGM,
   334  		cliGmTLS:  cliGM,
   335  	}
   336  }
   337  
// DefaultConfig carries the built server and client configs for both standard
// TLS and TLCP (GM); each may be nil. It is filled either directly by
// NewDefaultConfig or from configuration by Construct, and read by Server and
// Client.
type DefaultConfig struct {
	srvStdTLS *tls.Config
	cliStdTLS *tls.Config
	srvGmTLS  *tlcp.Config
	cliGmTLS  *tlcp.Config
}
   344  
   345  func (config *DefaultConfig) Construct(options configures.Config) (err error) {
   346  	opt := DefaultConfigOptions{}
   347  	optErr := options.As(&opt)
   348  	if optErr != nil {
   349  		err = errors.Warning("fns: build default tls config failed").WithCause(optErr)
   350  		return
   351  	}
   352  	config.srvGmTLS, config.cliGmTLS, config.srvStdTLS, config.cliStdTLS, err = opt.Build()
   353  	return
   354  }
   355  
   356  func (config *DefaultConfig) Server() (srvTLS *tls.Config, ln ListenerFunc) {
   357  	if config.srvGmTLS != nil {
   358  		if config.srvStdTLS != nil {
   359  			srvTLS = config.srvStdTLS
   360  			ln = func(inner net.Listener) (v net.Listener) {
   361  				v = tlcp.NewProtocolSwitcherListener(inner, config.srvGmTLS.Clone(), config.srvStdTLS.Clone())
   362  				return
   363  			}
   364  			return
   365  		}
   366  		ln = func(inner net.Listener) (v net.Listener) {
   367  			v = tlcp.NewListener(inner, config.srvGmTLS.Clone())
   368  			return
   369  		}
   370  		return
   371  	}
   372  	if config.srvStdTLS != nil {
   373  		srvTLS = config.srvStdTLS
   374  		ln = func(inner net.Listener) (v net.Listener) {
   375  			v = tls.NewListener(inner, config.srvStdTLS)
   376  			return
   377  		}
   378  	}
   379  	return
   380  }
   381  
   382  func (config *DefaultConfig) Client() (cliTLS *tls.Config, dialer Dialer) {
   383  	if config.cliStdTLS != nil {
   384  		cliTLS = config.cliStdTLS
   385  		return
   386  	}
   387  	if config.cliGmTLS != nil {
   388  		if config.cliStdTLS != nil {
   389  			cliTLS = config.cliStdTLS
   390  		}
   391  		nd := &net.Dialer{
   392  			Timeout:        30 * time.Second,
   393  			Deadline:       time.Time{},
   394  			LocalAddr:      nil,
   395  			FallbackDelay:  0,
   396  			KeepAlive:      60 * time.Second,
   397  			Resolver:       nil,
   398  			Control:        nil,
   399  			ControlContext: nil,
   400  		}
   401  		dialer = &tlcp.Dialer{NetDialer: nd, Config: config.cliGmTLS}
   402  		return
   403  	}
   404  	return
   405  }