github.com/pingcap/tiflow@v0.0.0-20240520035814-5bf52d54e205/pkg/cmd/factory/factory.go (about)

     1  // Copyright 2021 PingCAP, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package factory
    15  
    16  import (
    17  	"crypto/tls"
    18  	"fmt"
    19  	"net/url"
    20  	"os"
    21  	"path/filepath"
    22  
    23  	"github.com/BurntSushi/toml"
    24  	"github.com/pingcap/log"
    25  	"github.com/pingcap/tiflow/cdc/api"
    26  	apiv2client "github.com/pingcap/tiflow/pkg/api/v2"
    27  	"github.com/pingcap/tiflow/pkg/cmd/util"
    28  	"github.com/pingcap/tiflow/pkg/errors"
    29  	"github.com/pingcap/tiflow/pkg/etcd"
    30  	"github.com/pingcap/tiflow/pkg/security"
    31  	"github.com/spf13/cobra"
    32  	pd "github.com/tikv/pd/client"
    33  	"golang.org/x/term"
    34  	"google.golang.org/grpc"
    35  )
    36  
    37  const (
    38  	defaultCrendentialConfigFile = ".ticdc/credentials"
    39  	// User Credential Environment Variables
    40  	envVarTiCDCUser     = "TICDC_USER"
    41  	envVarTiCDCPassword = "TICDC_PASSWORD"
    42  	// TLS Client Certificate Environment Variables
    43  	envVarTiCDCCAPath   = "TICDC_CA_PATH"
    44  	envVarTiCDCCertPath = "TICDC_CERT_PATH"
    45  	envVarTiCDCKeyPath  = "TICDC_KEY_PATH"
    46  )
    47  
    48  // Factory defines the client-side construction factory.
    49  type Factory interface {
    50  	ClientGetter
    51  	EtcdClient() (*etcd.CDCEtcdClientImpl, error)
    52  	PdClient() (pd.Client, error)
    53  	APIV2Client() (apiv2client.APIV2Interface, error)
    54  }
    55  
    56  // ClientGetter defines the client getter.
    57  type ClientGetter interface {
    58  	ToTLSConfig() (*tls.Config, error)
    59  	ToGRPCDialOption() (grpc.DialOption, error)
    60  	GetPdAddr() string
    61  	GetServerAddr() string
    62  	GetLogLevel() string
    63  	GetCredential() *security.Credential
    64  	GetAuthParameters() url.Values
    65  }
    66  
    67  // ClientAuth specifies the authentication parameters.
    68  type ClientAuth struct {
    69  	// User Credential
    70  	User     string `toml:"ticdc_user,omitempty"`
    71  	Password string `toml:"ticdc_password,omitempty"`
    72  
    73  	// TLS Client Certificate
    74  	CaPath   string `toml:"ca_path,omitempty"`
    75  	CertPath string `toml:"cert_path,omitempty"`
    76  	KeyPath  string `toml:"key_path,omitempty"`
    77  }
    78  
    79  // StoreToDefaultPath stores the client authentication to default path.
    80  func (c *ClientAuth) StoreToDefaultPath() error {
    81  	homeDir, err := os.UserHomeDir()
    82  	if err != nil {
    83  		msg := "failed to get user home directory"
    84  		return fmt.Errorf("%s: %w", msg, err)
    85  	}
    86  
    87  	filename := filepath.Join(homeDir, defaultCrendentialConfigFile)
    88  	err = os.MkdirAll(filepath.Dir(filename), os.ModePerm)
    89  	if err != nil {
    90  		msg := fmt.Sprintf("failed to create directory for creandential file <%s>", filename)
    91  		return fmt.Errorf("%s: %w", msg, err)
    92  	}
    93  	file, err := os.Create(filename)
    94  	if err != nil {
    95  		msg := fmt.Sprintf("failed to create creandential file <%s>", filename)
    96  		return fmt.Errorf("%s: %w", msg, err)
    97  	}
    98  
    99  	err = toml.NewEncoder(file).Encode(c)
   100  	if err != nil {
   101  		msg := fmt.Sprintf("failed to encode client authentication to creandential file <%s>", filename)
   102  		return fmt.Errorf("%s: %w", msg, err)
   103  	}
   104  
   105  	err = file.Close()
   106  	if err != nil {
   107  		msg := fmt.Sprintf("failed to store client authentication to creandential file <%s>", filename)
   108  		return fmt.Errorf("%s: %w", msg, err)
   109  	}
   110  	return nil
   111  }
   112  
   113  // ReadFromDefaultPath reads the client authentication from default path.
   114  func ReadFromDefaultPath() (*ClientAuth, error) {
   115  	homeDir, err := os.UserHomeDir()
   116  	if err != nil {
   117  		msg := "failed to get user home directory"
   118  		return nil, fmt.Errorf("%s: %w", msg, err)
   119  	}
   120  
   121  	res := &ClientAuth{}
   122  	filename := filepath.Join(homeDir, defaultCrendentialConfigFile)
   123  	if _, err := os.Stat(filename); err == nil {
   124  		err = util.StrictDecodeFile(filename, "cdc cli auth config", res)
   125  		if err != nil {
   126  			msg := fmt.Sprintf("failed to parse client authentication from creandential file <%s>", filename)
   127  			return nil, fmt.Errorf("%s: %w", msg, err)
   128  		}
   129  	}
   130  	return res, nil
   131  }
   132  
   133  // ClientFlags specifies the parameters needed to construct the client.
   134  type ClientFlags struct {
   135  	ClientAuth
   136  	pdAddr     string
   137  	serverAddr string
   138  	logLevel   string
   139  }
   140  
   141  var _ ClientGetter = &ClientFlags{}
   142  
   143  // ToTLSConfig returns the configuration of tls.
   144  func (c *ClientFlags) ToTLSConfig() (*tls.Config, error) {
   145  	credential := c.GetCredential()
   146  	tlsConfig, err := credential.ToTLSConfig()
   147  	if err != nil {
   148  		return nil, errors.Annotate(err, "fail to validate TLS settings")
   149  	}
   150  	return tlsConfig, nil
   151  }
   152  
   153  // ToGRPCDialOption returns the option of GRPC dial.
   154  func (c *ClientFlags) ToGRPCDialOption() (grpc.DialOption, error) {
   155  	credential := c.GetCredential()
   156  	grpcTLSOption, err := credential.ToGRPCDialOption()
   157  	if err != nil {
   158  		return nil, errors.Annotate(err, "fail to validate TLS settings")
   159  	}
   160  
   161  	return grpcTLSOption, nil
   162  }
   163  
   164  // GetPdAddr returns pd address.
   165  func (c *ClientFlags) GetPdAddr() string {
   166  	return c.pdAddr
   167  }
   168  
   169  // GetLogLevel returns log level.
   170  func (c *ClientFlags) GetLogLevel() string {
   171  	return c.logLevel
   172  }
   173  
   174  // GetServerAddr returns cdc cluster id.
   175  func (c *ClientFlags) GetServerAddr() string {
   176  	return c.serverAddr
   177  }
   178  
   179  // NewClientFlags creates new client flags.
   180  func NewClientFlags() *ClientFlags {
   181  	return &ClientFlags{}
   182  }
   183  
   184  // AddFlags receives a *cobra.Command reference and binds
   185  // flags related to template printing to it.
   186  func (c *ClientFlags) AddFlags(cmd *cobra.Command) {
   187  	cmd.PersistentFlags().StringVar(&c.serverAddr, "server",
   188  		"", "CDC server address")
   189  	cmd.PersistentFlags().StringVar(&c.pdAddr, "pd", "",
   190  		"PD address, use ',' to separate multiple PDs, "+
   191  			"Parameter --pd is deprecated, please use parameter --server instead.")
   192  	cmd.PersistentFlags().StringVar(&c.CaPath, "ca", "",
   193  		"CA certificate path for TLS connection to CDC server")
   194  	cmd.PersistentFlags().StringVar(&c.CertPath, "cert", "",
   195  		"Certificate path for TLS connection to CDC server")
   196  	cmd.PersistentFlags().StringVar(&c.KeyPath, "key", "",
   197  		"Private key path for TLS connection to CDC server")
   198  	cmd.PersistentFlags().StringVar(&c.logLevel, "log-level", "warn",
   199  		"log level (etc: debug|info|warn|error)")
   200  
   201  	cmd.PersistentFlags().StringVar(&c.User, "user", "", "User name for authentication. "+
   202  		"You can sqpecify it via environment variable TICDC_USER")
   203  	cmd.PersistentFlags().StringVar(&c.Password, "password", "", "Password for authentication. "+
   204  		"You can specify it via environment variable TICDC_PASSWORD")
   205  }
   206  
   207  // GetCredential returns credential.
   208  func (c *ClientFlags) GetCredential() *security.Credential {
   209  	var certAllowedCN []string
   210  
   211  	return &security.Credential{
   212  		CAPath:        c.CaPath,
   213  		CertPath:      c.CertPath,
   214  		KeyPath:       c.KeyPath,
   215  		CertAllowedCN: certAllowedCN,
   216  	}
   217  }
   218  
   219  // CompleteClientAuthParameters completes the authentication parameters.
   220  func (c *ClientFlags) CompleteClientAuthParameters(cmd *cobra.Command) error {
   221  	c.completeTLSClientCertificate(cmd)
   222  	return c.completeUserCredential(cmd)
   223  }
   224  
   225  func (c *ClientFlags) completeUserCredential(cmd *cobra.Command) (err error) {
   226  	authType := "command line"
   227  	defer func() {
   228  		if err == nil {
   229  			if c.User == "" && c.Password != "" {
   230  				err = errors.ErrCredentialNotFound.GenWithStackByArgs("invalid atuhentication: password is specified without user")
   231  			}
   232  		}
   233  		log.Info(fmt.Sprintf("cli authentication type: %s", authType))
   234  	}()
   235  	// If user is specified via command line, password should be specified as well.
   236  	if c.User != "" {
   237  		if c.Password == "" {
   238  			cmd.Print("Enter password: ")
   239  			password, err := term.ReadPassword(int(os.Stdin.Fd()))
   240  			if err != nil {
   241  				return errors.ErrCredentialNotFound.GenWithStackByArgs(c.User, "Error reading password, ", err)
   242  			}
   243  			cmd.Println()
   244  			c.Password = string(password)
   245  		}
   246  		return nil
   247  	}
   248  
   249  	// If user is not specified via command line, try to get it from environment variable.
   250  	authType = "environment variable"
   251  	c.User = os.Getenv(envVarTiCDCUser)
   252  	c.Password = os.Getenv(envVarTiCDCPassword)
   253  	if c.User != "" {
   254  		return nil
   255  	}
   256  
   257  	// If user is not specified via command line or environment variable, try to get it from credential file.
   258  	authType = "credential file"
   259  	res, err := ReadFromDefaultPath()
   260  	if err != nil {
   261  		return errors.WrapError(errors.ErrCredentialNotFound, err)
   262  	}
   263  	if res != nil {
   264  		c.User = res.User
   265  		c.Password = res.Password
   266  	}
   267  	return nil
   268  }
   269  
   270  func (c *ClientFlags) completeTLSClientCertificate(cmd *cobra.Command) {
   271  	authType := "command line"
   272  	defer func() {
   273  		if c.CaPath == "" && c.CertPath == "" && c.KeyPath == "" {
   274  			authType = "disabled"
   275  		}
   276  		log.Info(fmt.Sprintf("cli tls client certificate type: %s", authType))
   277  	}()
   278  	// If one of the client certificate is specified via command line, all of them should be specified.
   279  	if c.CaPath != "" || c.CertPath != "" || c.KeyPath != "" {
   280  		return
   281  	}
   282  
   283  	// If none of the client certificate is specified via command line, try to get it from environment variable.
   284  	authType = "environment variable"
   285  	c.CaPath = os.Getenv(envVarTiCDCCAPath)
   286  	c.CertPath = os.Getenv(envVarTiCDCCertPath)
   287  	c.KeyPath = os.Getenv(envVarTiCDCKeyPath)
   288  	if c.CaPath != "" || c.CertPath != "" || c.KeyPath != "" {
   289  		return
   290  	}
   291  
   292  	// If none of the client certificate is specified via command line or environment variable, try to get it from credential file.
   293  	authType = "credential file"
   294  	res, err := ReadFromDefaultPath()
   295  	if err != nil {
   296  		cmd.Println("failed to read client certificate from default config file: , try to use insecure connection", err)
   297  	}
   298  	if res != nil {
   299  		c.CaPath = res.CaPath
   300  		c.CertPath = res.CertPath
   301  		c.KeyPath = res.KeyPath
   302  	}
   303  }
   304  
   305  // GetAuthParameters returns the authentication parameters.
   306  func (c *ClientFlags) GetAuthParameters() url.Values {
   307  	if c.User == "" {
   308  		return nil
   309  	}
   310  	return url.Values{
   311  		api.APIOpVarTiCDCUser:     {c.User},
   312  		api.APIOpVarTiCDCPassword: {c.Password},
   313  	}
   314  }