github.com/pingcap/tiflow@v0.0.0-20240520035814-5bf52d54e205/pkg/security/credential.go (about)

     1  // Copyright 2020 PingCAP, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package security
    15  
    16  import (
    17  	"crypto/tls"
    18  	"crypto/x509"
    19  	"database/sql/driver"
    20  	"encoding/json"
    21  	"encoding/pem"
    22  	"os"
    23  	"strings"
    24  
    25  	"github.com/pingcap/tiflow/pkg/errors"
    26  	pd "github.com/tikv/pd/client"
    27  	"google.golang.org/grpc"
    28  	"google.golang.org/grpc/credentials"
    29  )
    30  
// Credential holds the paths of the CA, certificate and key files, plus the
// related verification options, needed to build a *tls.Config.
//
// Value and Scan persist a Credential as JSON, so the json tags below define
// its stored form.
type Credential struct {
	// CAPath is a PEM file of trusted CA certificates. The same pool is used to
	// verify servers (RootCAs) and clients (ClientCAs); when CAPath is empty no
	// tls.Config is built at all.
	//
	// CertPath and KeyPath are this node's certificate and private key. They
	// are only used when both are set, and the pair is re-read from disk every
	// time the TLS stack requests a certificate.
	//
	// CertAllowedCN lists the Common Names accepted from peers when the config
	// is built with ToTLSConfigWithVerify. Entries are whitespace-trimmed, and
	// a non-empty list forces tls.RequireAndVerifyClientCert.
	CAPath        string   `toml:"ca-path" json:"ca-path"`
	CertPath      string   `toml:"cert-path" json:"cert-path"`
	KeyPath       string   `toml:"key-path" json:"key-path"`
	CertAllowedCN []string `toml:"cert-allowed-cn" json:"cert-allowed-cn"`

	// MTLS indicates whether to use mTLS. When true, the generated tls.Config
	// requires and verifies client certificates. By default it affects all
	// connections, including:
	// 1) connections between TiCDC and TiKV
	// 2) connections between TiCDC and PD
	// 3) the HTTP server of TiCDC, which serves the open API
	// 4) the p2p server of TiCDC, which is used to send messages between TiCDC nodes
	// TODO: only enable mTLS for 3) and 4) by default.
	MTLS bool `toml:"mtls" json:"mtls"`

	// ClientUserRequired and ClientAllowedUser are not read anywhere in this
	// file. NOTE(review): presumably they restrict API access to the listed
	// client user names — confirm against their users elsewhere.
	ClientUserRequired bool     `toml:"client-user-required" json:"client-user-required"`
	ClientAllowedUser  []string `toml:"client-allowed-user" json:"client-allowed-user"`
}
    50  
    51  // Value implements the driver.Valuer interface
    52  func (s Credential) Value() (driver.Value, error) {
    53  	return json.Marshal(s)
    54  }
    55  
    56  // Scan implements the sql.Scanner interface
    57  func (s *Credential) Scan(value interface{}) error {
    58  	b, ok := value.([]byte)
    59  	if !ok {
    60  		return errors.New("type assertion to []byte failed")
    61  	}
    62  
    63  	return json.Unmarshal(b, s)
    64  }
    65  
    66  // IsTLSEnabled checks whether TLS is enabled or not.
    67  func (s *Credential) IsTLSEnabled() bool {
    68  	return len(s.CAPath) != 0 && len(s.CertPath) != 0 && len(s.KeyPath) != 0
    69  }
    70  
    71  // IsEmpty checks whether Credential is empty or not.
    72  func (s *Credential) IsEmpty() bool {
    73  	return len(s.CAPath) == 0 && len(s.CertPath) == 0 && len(s.KeyPath) == 0
    74  }
    75  
    76  // PDSecurityOption creates a new pd SecurityOption from Security
    77  func (s *Credential) PDSecurityOption() pd.SecurityOption {
    78  	return pd.SecurityOption{
    79  		CAPath:   s.CAPath,
    80  		CertPath: s.CertPath,
    81  		KeyPath:  s.KeyPath,
    82  	}
    83  }
    84  
    85  // ToGRPCDialOption constructs a gRPC dial option.
    86  func (s *Credential) ToGRPCDialOption() (grpc.DialOption, error) {
    87  	tlsCfg, err := s.ToTLSConfig()
    88  	if err != nil || tlsCfg == nil {
    89  		return grpc.WithInsecure(), err
    90  	}
    91  	return grpc.WithTransportCredentials(credentials.NewTLS(tlsCfg)), nil
    92  }
    93  
    94  // ToTLSConfig generates tls's config from *Security
    95  func (s *Credential) ToTLSConfig() (*tls.Config, error) {
    96  	cfg, err := ToTLSConfigWithVerify(s.CAPath, s.CertPath, s.KeyPath, nil, s.MTLS)
    97  	return cfg, errors.WrapError(errors.ErrToTLSConfigFailed, err)
    98  }
    99  
   100  // ToTLSConfigWithVerify generates tls's config from *Security and requires
   101  // the remote common name to be verified.
   102  func (s *Credential) ToTLSConfigWithVerify() (*tls.Config, error) {
   103  	cfg, err := ToTLSConfigWithVerify(s.CAPath, s.CertPath, s.KeyPath, s.CertAllowedCN, s.MTLS)
   104  	return cfg, errors.WrapError(errors.ErrToTLSConfigFailed, err)
   105  }
   106  
   107  func (s *Credential) getSelfCommonName() (string, error) {
   108  	if s.CertPath == "" {
   109  		return "", nil
   110  	}
   111  	data, err := os.ReadFile(s.CertPath)
   112  	if err != nil {
   113  		return "", errors.WrapError(errors.ErrToTLSConfigFailed, err)
   114  	}
   115  	block, _ := pem.Decode(data)
   116  	if block == nil || block.Type != "CERTIFICATE" {
   117  		return "", errors.ErrToTLSConfigFailed.
   118  			GenWithStack("failed to decode PEM block to certificate")
   119  	}
   120  	certificate, err := x509.ParseCertificate(block.Bytes)
   121  	if err != nil {
   122  		return "", errors.WrapError(errors.ErrToTLSConfigFailed, err)
   123  	}
   124  	return certificate.Subject.CommonName, nil
   125  }
   126  
   127  // AddSelfCommonName add Common Name in certificate that specified by s.CertPath
   128  // to s.CertAllowedCN
   129  func (s *Credential) AddSelfCommonName() error {
   130  	cn, err := s.getSelfCommonName()
   131  	if err != nil {
   132  		return err
   133  	}
   134  	if cn == "" {
   135  		return nil
   136  	}
   137  	s.CertAllowedCN = append(s.CertAllowedCN, cn)
   138  	return nil
   139  }
   140  
   141  // ToTLSConfigWithVerify constructs a `*tls.Config` from the CA, certification and key
   142  // paths, and add verify for CN.
   143  //
   144  // If the CA path is empty, returns nil.
   145  func ToTLSConfigWithVerify(
   146  	caPath, certPath, keyPath string, verifyCN []string, mTLS bool,
   147  ) (*tls.Config, error) {
   148  	if len(caPath) == 0 {
   149  		return nil, nil
   150  	}
   151  
   152  	// Create a certificate pool from CA
   153  	certPool := x509.NewCertPool()
   154  	ca, err := os.ReadFile(caPath)
   155  	if err != nil {
   156  		return nil, errors.Annotate(err, "could not read ca certificate")
   157  	}
   158  
   159  	// Append the certificates from the CA
   160  	if !certPool.AppendCertsFromPEM(ca) {
   161  		return nil, errors.New("failed to append ca certs")
   162  	}
   163  
   164  	tlsCfg := &tls.Config{
   165  		RootCAs:    certPool,
   166  		ClientCAs:  certPool,
   167  		NextProtos: []string{"h2", "http/1.1"}, // specify `h2` to let Go use HTTP/2.
   168  		MinVersion: tls.VersionTLS12,
   169  	}
   170  
   171  	if mTLS {
   172  		tlsCfg.ClientAuth = tls.RequireAndVerifyClientCert
   173  	}
   174  
   175  	if len(certPath) != 0 && len(keyPath) != 0 {
   176  		loadCert := func() (*tls.Certificate, error) {
   177  			cert, err := tls.LoadX509KeyPair(certPath, keyPath)
   178  			if err != nil {
   179  				return nil, errors.Annotate(err, "could not load client key pair")
   180  			}
   181  			return &cert, nil
   182  		}
   183  		tlsCfg.GetClientCertificate = func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
   184  			return loadCert()
   185  		}
   186  		tlsCfg.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
   187  			return loadCert()
   188  		}
   189  	}
   190  
   191  	addVerifyPeerCertificate(tlsCfg, verifyCN)
   192  	return tlsCfg, nil
   193  }
   194  
   195  func addVerifyPeerCertificate(tlsCfg *tls.Config, verifyCN []string) {
   196  	if len(verifyCN) != 0 {
   197  		checkCN := make(map[string]struct{})
   198  		for _, cn := range verifyCN {
   199  			cn = strings.TrimSpace(cn)
   200  			checkCN[cn] = struct{}{}
   201  		}
   202  		tlsCfg.ClientAuth = tls.RequireAndVerifyClientCert
   203  		tlsCfg.VerifyPeerCertificate = func(
   204  			rawCerts [][]byte, verifiedChains [][]*x509.Certificate,
   205  		) error {
   206  			cns := make([]string, 0, len(verifiedChains))
   207  			for _, chains := range verifiedChains {
   208  				for _, chain := range chains {
   209  					cns = append(cns, chain.Subject.CommonName)
   210  					if _, match := checkCN[chain.Subject.CommonName]; match {
   211  						return nil
   212  					}
   213  				}
   214  			}
   215  			return errors.Errorf("client certificate authentication failed. "+
   216  				"The Common Name from the client certificate %v was not found "+
   217  				"in the configuration cluster-verify-cn with value: %s", cns, verifyCN)
   218  		}
   219  	}
   220  }