github.com/pingcap/tiflow@v0.0.0-20240520035814-5bf52d54e205/dm/pkg/conn/basedb.go (about)

     1  // Copyright 2019 PingCAP, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package conn
    15  
    16  import (
    17  	"context"
    18  	"crypto/tls"
    19  	"database/sql"
    20  	"fmt"
    21  	"net"
    22  	"net/url"
    23  	"strconv"
    24  	"strings"
    25  	"sync"
    26  	"sync/atomic"
    27  
    28  	"github.com/go-sql-driver/mysql"
    29  	"github.com/pingcap/errors"
    30  	"github.com/pingcap/failpoint"
    31  	"github.com/pingcap/tidb/pkg/util"
    32  	"github.com/pingcap/tiflow/dm/config/dbconfig"
    33  	tcontext "github.com/pingcap/tiflow/dm/pkg/context"
    34  	"github.com/pingcap/tiflow/dm/pkg/log"
    35  	"github.com/pingcap/tiflow/dm/pkg/retry"
    36  	"github.com/pingcap/tiflow/dm/pkg/terror"
    37  	"github.com/pingcap/tiflow/dm/pkg/utils"
    38  	"go.uber.org/zap"
    39  )
    40  
    41  var customID int64
    42  
    43  var netTimeout = DefaultDBTimeout
    44  
    45  // DBProvider providers BaseDB instance.
    46  type DBProvider interface {
    47  	Apply(config ScopedDBConfig) (*BaseDB, error)
    48  }
    49  
    50  // DefaultDBProviderImpl is default DBProvider implement.
    51  type DefaultDBProviderImpl struct{}
    52  
    53  // DefaultDBProvider is global instance of DBProvider.
    54  var DefaultDBProvider DBProvider
    55  
    56  func init() {
    57  	DefaultDBProvider = &DefaultDBProviderImpl{}
    58  }
    59  
    60  type ScopedDBConfig struct {
    61  	*dbconfig.DBConfig
    62  	Scope terror.ErrScope
    63  }
    64  
    65  func UpstreamDBConfig(cfg *dbconfig.DBConfig) ScopedDBConfig {
    66  	return ScopedDBConfig{
    67  		DBConfig: cfg,
    68  		Scope:    terror.ScopeUpstream,
    69  	}
    70  }
    71  
    72  func DownstreamDBConfig(cfg *dbconfig.DBConfig) ScopedDBConfig {
    73  	return ScopedDBConfig{
    74  		DBConfig: cfg,
    75  		Scope:    terror.ScopeDownstream,
    76  	}
    77  }
    78  
    79  func GetUpstreamDB(cfg *dbconfig.DBConfig) (*BaseDB, error) {
    80  	return DefaultDBProvider.Apply(UpstreamDBConfig(cfg))
    81  }
    82  
    83  func GetDownstreamDB(cfg *dbconfig.DBConfig) (*BaseDB, error) {
    84  	return DefaultDBProvider.Apply(DownstreamDBConfig(cfg))
    85  }
    86  
    87  // Apply will build BaseDB with DBConfig.
    88  func (d *DefaultDBProviderImpl) Apply(config ScopedDBConfig) (*BaseDB, error) {
    89  	// maxAllowedPacket=0 can be used to automatically fetch the max_allowed_packet variable from server on every connection.
    90  	// https://github.com/go-sql-driver/mysql#maxallowedpacket
    91  	hostPort := net.JoinHostPort(config.Host, strconv.Itoa(config.Port))
    92  	net := "tcp"
    93  	if config.Net != "" {
    94  		net = config.Net
    95  	}
    96  	dsn := fmt.Sprintf("%s:%s@%s(%s)/?charset=utf8mb4&interpolateParams=true&maxAllowedPacket=0",
    97  		config.User, config.Password, net, hostPort)
    98  
    99  	doFuncInClose := func() {}
   100  	if config.Security != nil {
   101  		if loadErr := config.Security.LoadTLSContent(); loadErr != nil {
   102  			return nil, terror.ErrCtlLoadTLSCfg.Delegate(loadErr)
   103  		}
   104  		tlsConfig, err := util.NewTLSConfig(
   105  			util.WithCAContent(config.Security.SSLCABytes),
   106  			util.WithCertAndKeyContent(config.Security.SSLCertBytes, config.Security.SSLKeyBytes),
   107  			util.WithVerifyCommonName(config.Security.CertAllowedCN),
   108  			util.WithMinTLSVersion(tls.VersionTLS10),
   109  		)
   110  		if err != nil {
   111  			return nil, terror.ErrConnInvalidTLSConfig.Delegate(err)
   112  		}
   113  
   114  		if tlsConfig != nil {
   115  			name := "dm" + strconv.FormatInt(atomic.AddInt64(&customID, 1), 10)
   116  			err = mysql.RegisterTLSConfig(name, tlsConfig)
   117  			if err != nil {
   118  				return nil, terror.ErrConnRegistryTLSConfig.Delegate(err)
   119  			}
   120  			dsn += "&tls=" + name
   121  
   122  			doFuncInClose = func() {
   123  				mysql.DeregisterTLSConfig(name)
   124  			}
   125  		}
   126  	}
   127  
   128  	var maxIdleConns int
   129  	rawCfg := config.RawDBCfg
   130  	if rawCfg != nil {
   131  		if rawCfg.ReadTimeout != "" {
   132  			dsn += fmt.Sprintf("&readTimeout=%s", rawCfg.ReadTimeout)
   133  		}
   134  		if rawCfg.WriteTimeout != "" {
   135  			dsn += fmt.Sprintf("&writeTimeout=%s", rawCfg.WriteTimeout)
   136  		}
   137  		maxIdleConns = rawCfg.MaxIdleConns
   138  	}
   139  
   140  	var setFK bool
   141  	for key, val := range config.Session {
   142  		// for num such as 1/"1", format as key='1'
   143  		// for string, format as key='string'
   144  		// both are valid for mysql and tidb
   145  		if strings.ToLower(key) == "foreign_key_checks" {
   146  			setFK = true
   147  		}
   148  		dsn += fmt.Sprintf("&%s='%s'", key, url.QueryEscape(val))
   149  	}
   150  
   151  	if !setFK {
   152  		dsn += "&foreign_key_checks=0"
   153  	}
   154  
   155  	db, err := sql.Open("mysql", dsn)
   156  	if err != nil {
   157  		return nil, terror.DBErrorAdapt(err, config.Scope, terror.ErrDBDriverError)
   158  	}
   159  
   160  	ctx, cancel := context.WithTimeout(context.Background(), netTimeout)
   161  	defer cancel()
   162  	err = db.PingContext(ctx)
   163  	failpoint.Inject("failDBPing", func(_ failpoint.Value) {
   164  		err = errors.New("injected error")
   165  	})
   166  	if err != nil {
   167  		db.Close()
   168  		doFuncInClose()
   169  		return nil, terror.DBErrorAdapt(err, config.Scope, terror.ErrDBDriverError)
   170  	}
   171  
   172  	db.SetMaxIdleConns(maxIdleConns)
   173  
   174  	return NewBaseDB(db, config.Scope, doFuncInClose), nil
   175  }
   176  
   177  // BaseDB wraps *sql.DB, control the BaseConn.
   178  type BaseDB struct {
   179  	DB *sql.DB
   180  
   181  	mu sync.Mutex // protects following fields
   182  	// hold all db connections generated from this BaseDB
   183  	conns map[*BaseConn]struct{}
   184  
   185  	Retry retry.Strategy
   186  
   187  	Scope terror.ErrScope
   188  	// this function will do when close the BaseDB
   189  	doFuncInClose []func()
   190  
   191  	// only use in unit test
   192  	doNotClose bool
   193  }
   194  
   195  // NewBaseDB returns *BaseDB object for test.
   196  func NewBaseDB(db *sql.DB, scope terror.ErrScope, doFuncInClose ...func()) *BaseDB {
   197  	conns := make(map[*BaseConn]struct{})
   198  	return &BaseDB{
   199  		DB:            db,
   200  		conns:         conns,
   201  		Retry:         &retry.FiniteRetryStrategy{},
   202  		Scope:         scope,
   203  		doFuncInClose: doFuncInClose,
   204  	}
   205  }
   206  
   207  // NewBaseDBForTest returns *BaseDB object for test.
   208  func NewBaseDBForTest(db *sql.DB, doFuncInClose ...func()) *BaseDB {
   209  	conns := make(map[*BaseConn]struct{})
   210  	return &BaseDB{
   211  		DB:            db,
   212  		conns:         conns,
   213  		Retry:         &retry.FiniteRetryStrategy{},
   214  		Scope:         terror.ScopeNotSet,
   215  		doFuncInClose: doFuncInClose,
   216  	}
   217  }
   218  
   219  // NewMockDB returns *BaseDB object for mock.
   220  func NewMockDB(db *sql.DB, doFuncInClose ...func()) *BaseDB {
   221  	baseDB := NewBaseDBForTest(db, doFuncInClose...)
   222  	baseDB.doNotClose = true
   223  	return baseDB
   224  }
   225  
   226  // GetBaseConn retrieves *BaseConn which has own retryStrategy.
   227  func (d *BaseDB) GetBaseConn(ctx context.Context) (*BaseConn, error) {
   228  	ctx, cancel := context.WithTimeout(ctx, netTimeout)
   229  	defer cancel()
   230  	conn, err := d.DB.Conn(ctx)
   231  	if err != nil {
   232  		return nil, terror.DBErrorAdapt(err, d.Scope, terror.ErrDBDriverError)
   233  	}
   234  	err = conn.PingContext(ctx)
   235  	if err != nil {
   236  		return nil, terror.DBErrorAdapt(err, d.Scope, terror.ErrDBDriverError)
   237  	}
   238  	baseConn := NewBaseConn(conn, d.Scope, d.Retry)
   239  	d.mu.Lock()
   240  	defer d.mu.Unlock()
   241  	d.conns[baseConn] = struct{}{}
   242  	return baseConn, nil
   243  }
   244  
   245  // TODO: retry can be done inside the BaseDB.
   246  func (d *BaseDB) ExecContext(tctx *tcontext.Context, query string, args ...interface{}) (sql.Result, error) {
   247  	if tctx.L().Core().Enabled(zap.DebugLevel) {
   248  		tctx.L().Debug("exec context",
   249  			zap.String("query", utils.TruncateString(query, -1)),
   250  			zap.String("argument", utils.TruncateInterface(args, -1)))
   251  	}
   252  	return d.DB.ExecContext(tctx.Ctx, query, args...)
   253  }
   254  
   255  // TODO: retry can be done inside the BaseDB.
   256  func (d *BaseDB) QueryContext(tctx *tcontext.Context, query string, args ...interface{}) (*sql.Rows, error) {
   257  	if tctx.L().Core().Enabled(zap.DebugLevel) {
   258  		tctx.L().Debug("query context",
   259  			zap.String("query", utils.TruncateString(query, -1)),
   260  			zap.String("argument", utils.TruncateInterface(args, -1)))
   261  	}
   262  	return d.DB.QueryContext(tctx.Ctx, query, args...)
   263  }
   264  
   265  func (d *BaseDB) DoTxWithRetry(tctx *tcontext.Context, queries []string, args [][]interface{}, retryer retry.Retryer) error {
   266  	workFunc := func(tctx *tcontext.Context) (interface{}, error) {
   267  		var (
   268  			err error
   269  			tx  *sql.Tx
   270  		)
   271  		tx, err = d.DB.BeginTx(tctx.Ctx, nil)
   272  		if err != nil {
   273  			return nil, errors.Trace(err)
   274  		}
   275  		defer func() {
   276  			if err != nil {
   277  				if rollbackErr := tx.Rollback(); rollbackErr != nil {
   278  					tctx.L().Warn("failed to rollback", zap.Error(errors.Trace(rollbackErr)))
   279  				}
   280  			} else {
   281  				err = tx.Commit()
   282  			}
   283  		}()
   284  		for i := range queries {
   285  			q := queries[i]
   286  			if tctx.L().Core().Enabled(zap.DebugLevel) {
   287  				tctx.L().Debug("exec in tx",
   288  					zap.String("query", utils.TruncateString(q, -1)),
   289  					zap.String("argument", utils.TruncateInterface(args[i], -1)))
   290  			}
   291  			if _, err = tx.ExecContext(tctx.Ctx, q, args[i]...); err != nil {
   292  				return nil, errors.Trace(err)
   293  			}
   294  		}
   295  		return nil, errors.Trace(err)
   296  	}
   297  
   298  	_, _, err := retryer.Apply(tctx, workFunc)
   299  	return err
   300  }
   301  
   302  // CloseConn release BaseConn resource from BaseDB, and returns the connection to the connection pool,
   303  // has the same meaning of sql.Conn.Close.
   304  func (d *BaseDB) CloseConn(conn *BaseConn) error {
   305  	d.mu.Lock()
   306  	defer d.mu.Unlock()
   307  	delete(d.conns, conn)
   308  	return conn.close()
   309  }
   310  
   311  // CloseConnWithoutErr release BaseConn resource from BaseDB, and returns the connection to the connection pool,
   312  // has the same meaning of sql.Conn.Close, and log warning on error.
   313  func (d *BaseDB) CloseConnWithoutErr(conn *BaseConn) {
   314  	if err := d.CloseConn(conn); err != nil {
   315  		log.L().Warn("close db connection failed", zap.Error(err))
   316  	}
   317  }
   318  
   319  // ForceCloseConn release BaseConn resource from BaseDB, and close BaseConn completely(not return to the connection pool).
   320  func (d *BaseDB) ForceCloseConn(conn *BaseConn) error {
   321  	d.mu.Lock()
   322  	defer d.mu.Unlock()
   323  	delete(d.conns, conn)
   324  	return conn.forceClose()
   325  }
   326  
   327  // ForceCloseConnWithoutErr close the connection completely(not return to the conn pool),
   328  // and output a warning log if meets an error.
   329  func (d *BaseDB) ForceCloseConnWithoutErr(conn *BaseConn) {
   330  	if err1 := d.ForceCloseConn(conn); err1 != nil {
   331  		log.L().Warn("close db connection failed", zap.Error(err1))
   332  	}
   333  }
   334  
   335  // Close release *BaseDB resource.
   336  func (d *BaseDB) Close() error {
   337  	if d == nil || d.DB == nil || d.doNotClose {
   338  		return nil
   339  	}
   340  	var err error
   341  	d.mu.Lock()
   342  	defer d.mu.Unlock()
   343  	for conn := range d.conns {
   344  		terr := conn.forceClose()
   345  		if err == nil {
   346  			err = terr
   347  		}
   348  	}
   349  	terr := d.DB.Close()
   350  	for _, f := range d.doFuncInClose {
   351  		f()
   352  	}
   353  
   354  	if err == nil {
   355  		return terr
   356  	}
   357  
   358  	return err
   359  }