github.com/polarismesh/polaris@v1.17.8/store/mysql/base_db.go (about)

     1  /**
     2   * Tencent is pleased to support the open source community by making Polaris available.
     3   *
     4   * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
     5   *
     6   * Licensed under the BSD 3-Clause License (the "License");
     7   * you may not use this file except in compliance with the License.
     8   * You may obtain a copy of the License at
     9   *
    10   * https://opensource.org/licenses/BSD-3-Clause
    11   *
    12   * Unless required by applicable law or agreed to in writing, software distributed
    13   * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
    14   * CONDITIONS OF ANY KIND, either express or implied. See the License for the
    15   * specific language governing permissions and limitations under the License.
    16   */
    17  
    18  package sqldb
    19  
    20  import (
    21  	"context"
    22  	"database/sql"
    23  	"fmt"
    24  	"strings"
    25  	"time"
    26  
    27  	"github.com/polarismesh/polaris/plugin"
    28  )
    29  
    30  // db抛出的异常,需要重试的字符串组
    31  var errMsg = []string{"Deadlock", "bad connection", "invalid connection"}
    32  
    33  // BaseDB 对sql.DB的封装
    34  type BaseDB struct {
    35  	*sql.DB
    36  	cfg            *dbConfig
    37  	isolationLevel sql.IsolationLevel
    38  	parsePwd       plugin.ParsePassword
    39  }
    40  
    41  // dbConfig store的配置
    42  type dbConfig struct {
    43  	dbType           string
    44  	dbUser           string
    45  	dbPwd            string
    46  	dbAddr           string
    47  	dbName           string
    48  	maxOpenConns     int
    49  	maxIdleConns     int
    50  	connMaxLifetime  int
    51  	txIsolationLevel int
    52  }
    53  
    54  // NewBaseDB 新建一个BaseDB
    55  func NewBaseDB(cfg *dbConfig, parsePwd plugin.ParsePassword) (*BaseDB, error) {
    56  	baseDb := &BaseDB{cfg: cfg, parsePwd: parsePwd}
    57  	if cfg.txIsolationLevel > 0 {
    58  		baseDb.isolationLevel = sql.IsolationLevel(cfg.txIsolationLevel)
    59  		log.Infof("[Store][database] use isolation level: %s", baseDb.isolationLevel.String())
    60  	}
    61  
    62  	if err := baseDb.openDatabase(); err != nil {
    63  		return nil, err
    64  	}
    65  
    66  	return baseDb, nil
    67  }
    68  
    69  // openDatabase 与数据库进行连接
    70  func (b *BaseDB) openDatabase() error {
    71  	c := b.cfg
    72  
    73  	// 使用密码解析插件
    74  	if b.parsePwd != nil {
    75  		pwd, err := b.parsePwd.ParsePassword(c.dbPwd)
    76  		if err != nil {
    77  			log.Errorf("[Store][database][ParsePwdPlugin] parse password err: %s", err.Error())
    78  			return err
    79  		}
    80  		c.dbPwd = pwd
    81  	}
    82  
    83  	dns := fmt.Sprintf("%s:%s@tcp(%s)/%s", c.dbUser, c.dbPwd, c.dbAddr, c.dbName)
    84  
    85  	db, err := sql.Open(c.dbType, dns)
    86  	if err != nil {
    87  		log.Errorf("[Store][database] sql open err: %s", err.Error())
    88  		return err
    89  	}
    90  	if pingErr := db.Ping(); pingErr != nil {
    91  		log.Errorf("[Store][database] database ping err: %s", pingErr.Error())
    92  		return pingErr
    93  	}
    94  	if c.maxOpenConns > 0 {
    95  		log.Infof("[Store][database] db set max open conns: %d", c.maxOpenConns)
    96  		db.SetMaxOpenConns(c.maxOpenConns)
    97  	}
    98  	if c.maxIdleConns > 0 {
    99  		log.Infof("[Store][database] db set max idle conns: %d", c.maxIdleConns)
   100  		db.SetMaxIdleConns(c.maxIdleConns)
   101  	}
   102  	if c.connMaxLifetime > 0 {
   103  		log.Infof("[Store][database] db set conn max life time: %d", c.connMaxLifetime)
   104  		db.SetConnMaxLifetime(time.Second * time.Duration(c.connMaxLifetime))
   105  	}
   106  
   107  	b.DB = db
   108  	return nil
   109  }
   110  
   111  // Exec 重写db.Exec函数 提供重试功能
   112  func (b *BaseDB) Exec(query string, args ...interface{}) (sql.Result, error) {
   113  	var result sql.Result
   114  	var err error
   115  	Retry("exec "+query, func() error {
   116  		result, err = b.DB.Exec(query, args...)
   117  		return err
   118  	})
   119  
   120  	return result, err
   121  }
   122  
   123  // Query 重写db.Query函数
   124  func (b *BaseDB) Query(query string, args ...interface{}) (*sql.Rows, error) {
   125  	var rows *sql.Rows
   126  	var err error
   127  	Retry("query "+query, func() error {
   128  		rows, err = b.DB.Query(query, args...)
   129  		return err
   130  	})
   131  
   132  	return rows, err
   133  }
   134  
   135  // Begin 重写db.Begin
   136  func (b *BaseDB) Begin() (*BaseTx, error) {
   137  	var tx *sql.Tx
   138  	var err error
   139  	var option *sql.TxOptions
   140  	if b.isolationLevel > 0 {
   141  		option = &sql.TxOptions{Isolation: sql.IsolationLevel(b.isolationLevel)}
   142  	}
   143  	Retry("begin", func() error {
   144  		tx, err = b.DB.BeginTx(context.Background(), option)
   145  		return err
   146  	})
   147  
   148  	return &BaseTx{Tx: tx}, err
   149  }
   150  
   151  // BaseTx 对sql.Tx的封装
   152  type BaseTx struct {
   153  	*sql.Tx
   154  }
   155  
   156  // Retry 重试主函数
   157  // 最多重试20次,每次等待5ms*重试次数
   158  func Retry(label string, handle func() error) {
   159  	var err error
   160  	maxTryTimes := 20
   161  	for i := 1; i <= maxTryTimes; i++ {
   162  		err = handle()
   163  		if err == nil {
   164  			return
   165  		}
   166  
   167  		repeated := false // 是否重试
   168  		for _, msg := range errMsg {
   169  			if strings.Contains(err.Error(), msg) {
   170  				log.Warnf("[Store][database][%s] get error msg: %s. Repeated doing(%d)", label, err.Error(), i)
   171  				time.Sleep(time.Millisecond * 5 * time.Duration(i))
   172  				repeated = true
   173  				break
   174  			}
   175  		}
   176  		if !repeated {
   177  			return
   178  		}
   179  	}
   180  }
   181  
   182  // RetryTransaction 事务重试
   183  func RetryTransaction(label string, handle func() error) error {
   184  	var err error
   185  	Retry(label, func() error {
   186  		err = handle()
   187  		return err
   188  	})
   189  	return err
   190  }
   191  
   192  func (b *BaseDB) processWithTransaction(label string, handle func(*BaseTx) error) error {
   193  	tx, err := b.Begin()
   194  	if err != nil {
   195  		log.Errorf("[Store][database] %s begin tx err: %s", label, err.Error())
   196  		return err
   197  	}
   198  
   199  	defer func() {
   200  		_ = tx.Rollback()
   201  	}()
   202  	return handle(tx)
   203  }