github.com/go-courier/sqlx/v2@v2.23.13/connectors/mysql/driver.go (about)

     1  package mysql
     2  
     3  import (
     4  	"context"
     5  	"database/sql/driver"
     6  	"fmt"
     7  	"strings"
     8  	"time"
     9  
    10  	"github.com/go-courier/sqlx/v2"
    11  
    12  	"github.com/go-courier/logr"
    13  	"github.com/pkg/errors"
    14  
    15  	"github.com/go-sql-driver/mysql"
    16  )
    17  
    18  func init() {
    19  	_ = mysql.SetLogger(&logger{})
    20  }
    21  
    22  type logger struct{}
    23  
    24  func (l *logger) Print(args ...interface{}) {
    25  }
    26  
    27  var _ interface {
    28  	driver.Driver
    29  } = (*MySqlLoggingDriver)(nil)
    30  
    31  type MySqlLoggingDriver struct {
    32  	driver mysql.MySQLDriver
    33  }
    34  
    35  func (d *MySqlLoggingDriver) Open(dsn string) (driver.Conn, error) {
    36  	cfg, err := mysql.ParseDSN(dsn)
    37  	if err != nil {
    38  		return nil, err
    39  	}
    40  	cfg.Passwd = strings.Repeat("*", len(cfg.Passwd))
    41  
    42  	conn, err := d.driver.Open(dsn)
    43  	if err != nil {
    44  		return nil, errors.Wrapf(err, "failed to open connection: %s", cfg.FormatDSN())
    45  	}
    46  	return &loggerConn{Conn: conn, cfg: cfg}, nil
    47  }
    48  
    49  func (d *MySqlLoggingDriver) Driver() driver.Driver {
    50  	return d
    51  }
    52  
    53  var _ interface {
    54  	driver.ConnBeginTx
    55  	driver.ExecerContext
    56  	driver.QueryerContext
    57  } = (*loggerConn)(nil)
    58  
    59  type loggerConn struct {
    60  	cfg *mysql.Config
    61  	driver.Conn
    62  }
    63  
    64  func (c *loggerConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
    65  	logger := logr.FromContext(ctx)
    66  
    67  	logger.Debug("=========== Beginning Transaction ===========")
    68  	tx, err := c.Conn.(driver.ConnBeginTx).BeginTx(ctx, opts)
    69  	if err != nil {
    70  		logger.Error(errors.Wrap(err, "failed to begin transaction"))
    71  		return nil, err
    72  	}
    73  	return &loggingTx{Tx: tx, logger: logger}, nil
    74  }
    75  
    76  func (c *loggerConn) Prepare(query string) (driver.Stmt, error) {
    77  	panic(fmt.Errorf("don't use Prepare"))
    78  }
    79  
    80  func (c *loggerConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (rows driver.Rows, err error) {
    81  	cost := startTimer()
    82  	newCtx, logger := logr.Start(ctx, "Query")
    83  
    84  	defer func() {
    85  		q := c.interpolateParams(query, args)
    86  
    87  		if err != nil {
    88  			if mysqlErr, ok := sqlx.UnwrapAll(err).(*mysql.MySQLError); !ok {
    89  				logger.Error(errors.Wrapf(err, "query failed: %s", q))
    90  			} else {
    91  				logger.Warn(errors.Wrapf(mysqlErr, "query failed: %s", q))
    92  			}
    93  		} else {
    94  			logger.WithValues("cost", cost().String()).Debug(q.String())
    95  		}
    96  
    97  		logger.End()
    98  	}()
    99  
   100  	rows, err = c.Conn.(driver.QueryerContext).QueryContext(newCtx, query, args)
   101  	return
   102  }
   103  
   104  func (c *loggerConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (result driver.Result, err error) {
   105  	cost := startTimer()
   106  	newCtx, logger := logr.Start(ctx, "Query")
   107  
   108  	defer func() {
   109  		q := c.interpolateParams(query, args)
   110  
   111  		if err != nil {
   112  			if mysqlErr, ok := sqlx.UnwrapAll(err).(*mysql.MySQLError); !ok {
   113  				logger.Error(errors.Wrapf(err, "exec failed: %s", q))
   114  			} else if mysqlErr.Number == DuplicateEntryErrNumber {
   115  				logger.Error(errors.Wrapf(mysqlErr, "exec failed: %s", q))
   116  			} else {
   117  				logger.Warn(errors.Wrapf(mysqlErr, "exec failed: %s", q))
   118  			}
   119  		} else {
   120  			logger.WithValues("cost", cost().String()).Debug(q.String())
   121  		}
   122  
   123  		logger.End()
   124  	}()
   125  
   126  	result, err = c.Conn.(driver.ExecerContext).ExecContext(newCtx, query, args)
   127  	return
   128  }
   129  
   130  func (c *loggerConn) interpolateParams(query string, args []driver.NamedValue) fmt.Stringer {
   131  	return &SqlPrinter{query, args, c.cfg}
   132  }
   133  
   134  type SqlPrinter struct {
   135  	query string
   136  	args  []driver.NamedValue
   137  	cfg   *mysql.Config
   138  }
   139  
   140  func (p *SqlPrinter) String() string {
   141  	if len(p.args) == 0 {
   142  		return p.query
   143  	}
   144  	argValues, err := namedValueToValue(p.args)
   145  	if err != nil {
   146  		return p.query
   147  	}
   148  	sqlForLog, err := interpolateParams(p.query, argValues, p.cfg.Loc, p.cfg.MaxAllowedPacket)
   149  	if err != nil {
   150  		return p.query
   151  	}
   152  
   153  	return sqlForLog
   154  }
   155  
   156  var DuplicateEntryErrNumber uint16 = 1062
   157  
   158  func startTimer() func() time.Duration {
   159  	startTime := time.Now()
   160  	return func() time.Duration {
   161  		return time.Since(startTime)
   162  	}
   163  }
   164  
   165  type loggingTx struct {
   166  	logger logr.Logger
   167  	driver.Tx
   168  }
   169  
   170  func (tx *loggingTx) Commit() error {
   171  	if err := tx.Tx.Commit(); err != nil {
   172  		tx.logger.Debug("failed to commit transaction: %s", err)
   173  		return err
   174  	}
   175  	tx.logger.Debug("=========== Committed Transaction ===========")
   176  	return nil
   177  }
   178  
   179  func (tx *loggingTx) Rollback() error {
   180  	if err := tx.Tx.Rollback(); err != nil {
   181  		tx.logger.Debug("failed to rollback transaction: %s", err)
   182  		return err
   183  	}
   184  	tx.logger.Debug("=========== Rollback Transaction ===========")
   185  	return nil
   186  }