github.com/artisanhe/tools@v1.0.1-0.20210607022958-19a8fef2eb04/sqlx/mysql_logger_driver/statement.go (about)

     1  package mysql_logger_driver
     2  
     3  import (
     4  	"database/sql/driver"
     5  	"strconv"
     6  	"strings"
     7  	"time"
     8  
     9  	"github.com/fatih/color"
    10  	"github.com/go-sql-driver/mysql"
    11  	"github.com/sirupsen/logrus"
    12  
    13  	"github.com/artisanhe/tools/duration"
    14  )
    15  
// loggerStmt wraps a driver.Stmt and logs every Exec/Query call, rendering
// the SQL with its arguments inlined so the logged statement is readable.
type loggerStmt struct {
	// cfg supplies the time zone (Loc) and packet-size limit
	// (MaxAllowedPacket) used when inlining arguments into the logged SQL.
	cfg   *mysql.Config
	// query is the statement's SQL text; its '?' placeholders are filled in
	// with the call arguments for logging.
	// NOTE(review): presumably the text passed to Prepare — it is set outside
	// this file, confirm at the constructor.
	query string
	// stmt is the wrapped driver statement; Close, Exec, Query and NumInput
	// delegate to it.
	stmt  driver.Stmt
}
    21  
    22  func (s *loggerStmt) Close() error {
    23  	if err := s.stmt.Close(); err != nil {
    24  		logrus.Errorf("failed to close statement: %s", err)
    25  		return err
    26  	}
    27  	return nil
    28  }
    29  
    30  var DuplicateEntryErrNumber uint16 = 1062
    31  
    32  func (s *loggerStmt) Exec(args []driver.Value) (driver.Result, error) {
    33  	cost := duration.NewDuration()
    34  	if len(args) != 0 {
    35  		sqlForLog, err := s.interpolateParams(s.query, args)
    36  		if err != nil {
    37  			logrus.Warnf("failed exec %s: %s", err, color.RedString(s.query))
    38  			return nil, err
    39  		}
    40  		s.query = sqlForLog
    41  	}
    42  	result, err := s.stmt.Exec(args)
    43  	if err != nil {
    44  		if mysqlErr, ok := err.(*mysql.MySQLError); !ok {
    45  			logrus.Errorf("failed exec %s: %s", err, color.RedString(s.query))
    46  		} else if mysqlErr.Number == DuplicateEntryErrNumber {
    47  			logrus.Warnf("failed exec %s: %s", err, color.RedString(s.query))
    48  		} else {
    49  			logrus.Errorf("failed exec %s: %s", err, color.RedString(s.query))
    50  		}
    51  		return nil, err
    52  	}
    53  	cost.ToLogger().Debugf(color.YellowString(s.query))
    54  	return result, nil
    55  }
    56  
    57  func (s *loggerStmt) Query(args []driver.Value) (driver.Rows, error) {
    58  	cost := duration.NewDuration()
    59  	if len(args) != 0 {
    60  		sqlForLog, err := s.interpolateParams(s.query, args)
    61  		if err != nil {
    62  			if mysqlErr, ok := err.(*mysql.MySQLError); !ok {
    63  				logrus.Errorf("failed exec %s: %s", err, color.RedString(s.query))
    64  			} else {
    65  				logrus.Warnf("failed exec %s: %s", mysqlErr, color.RedString(s.query))
    66  			}
    67  			return nil, err
    68  		}
    69  		s.query = sqlForLog
    70  	}
    71  	rows, err := s.stmt.Query(args)
    72  	if err != nil {
    73  		logrus.Warnf("failed query %s: %s", err, color.RedString(s.query))
    74  		return nil, err
    75  	}
    76  	cost.ToLogger().Debugf(color.GreenString(s.query))
    77  	return rows, nil
    78  }
    79  
    80  func (s *loggerStmt) NumInput() int {
    81  	i := s.stmt.NumInput()
    82  	return i
    83  }
    84  
    85  func (s *loggerStmt) interpolateParams(query string, args []driver.Value) (string, error) {
    86  	if strings.Count(query, "?") != len(args) {
    87  		return "", driver.ErrSkip
    88  	}
    89  
    90  	buf := []byte{}
    91  	buf = buf[:0]
    92  	argPos := 0
    93  
    94  	for i := 0; i < len(query); i++ {
    95  		q := strings.IndexByte(query[i:], '?')
    96  		if q == -1 {
    97  			buf = append(buf, query[i:]...)
    98  			break
    99  		}
   100  		buf = append(buf, query[i:i+q]...)
   101  		i += q
   102  
   103  		arg := args[argPos]
   104  		argPos++
   105  
   106  		if arg == nil {
   107  			buf = append(buf, "NULL"...)
   108  			continue
   109  		}
   110  
   111  		switch v := arg.(type) {
   112  		case int64:
   113  			buf = strconv.AppendInt(buf, v, 10)
   114  		case float64:
   115  			buf = strconv.AppendFloat(buf, v, 'g', -1, 64)
   116  		case bool:
   117  			if v {
   118  				buf = append(buf, '1')
   119  			} else {
   120  				buf = append(buf, '0')
   121  			}
   122  		case time.Time:
   123  			if v.IsZero() {
   124  				buf = append(buf, "'0000-00-00'"...)
   125  			} else {
   126  				v := v.In(s.cfg.Loc)
   127  				v = v.Add(time.Nanosecond * 500) // Write round under microsecond
   128  				year := v.Year()
   129  				year100 := year / 100
   130  				year1 := year % 100
   131  				month := v.Month()
   132  				day := v.Day()
   133  				hour := v.Hour()
   134  				minute := v.Minute()
   135  				second := v.Second()
   136  				micro := v.Nanosecond() / 1000
   137  
   138  				buf = append(buf, []byte{
   139  					'\'',
   140  					digits10[year100], digits01[year100],
   141  					digits10[year1], digits01[year1],
   142  					'-',
   143  					digits10[month], digits01[month],
   144  					'-',
   145  					digits10[day], digits01[day],
   146  					' ',
   147  					digits10[hour], digits01[hour],
   148  					':',
   149  					digits10[minute], digits01[minute],
   150  					':',
   151  					digits10[second], digits01[second],
   152  				}...)
   153  
   154  				if micro != 0 {
   155  					micro10000 := micro / 10000
   156  					micro100 := micro / 100 % 100
   157  					micro1 := micro % 100
   158  					buf = append(buf, []byte{
   159  						'.',
   160  						digits10[micro10000], digits01[micro10000],
   161  						digits10[micro100], digits01[micro100],
   162  						digits10[micro1], digits01[micro1],
   163  					}...)
   164  				}
   165  				buf = append(buf, '\'')
   166  			}
   167  		case []byte:
   168  			if v == nil {
   169  				buf = append(buf, "NULL"...)
   170  			} else {
   171  				buf = append(buf, "_binary'"...)
   172  				buf = escapeBytesBackslash(buf, v)
   173  				buf = append(buf, '\'')
   174  			}
   175  		case string:
   176  			buf = append(buf, '\'')
   177  			buf = escapeBytesBackslash(buf, []byte(v))
   178  			buf = append(buf, '\'')
   179  		default:
   180  			return "", driver.ErrSkip
   181  		}
   182  
   183  		if len(buf)+4 > s.cfg.MaxAllowedPacket {
   184  			return "", driver.ErrSkip
   185  		}
   186  	}
   187  	if argPos != len(args) {
   188  		return "", driver.ErrSkip
   189  	}
   190  	return string(buf), nil
   191  }
   192  
   193  // copy from mysql driver
   194  
   195  const digits01 = "0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789"
   196  const digits10 = "0000000000111111111122222222223333333333444444444455555555556666666666777777777788888888889999999999"
   197  
   198  func escapeBytesBackslash(buf, v []byte) []byte {
   199  	pos := len(buf)
   200  	buf = reserveBuffer(buf, len(v)*2)
   201  
   202  	for _, c := range v {
   203  		switch c {
   204  		case '\x00':
   205  			buf[pos] = '\\'
   206  			buf[pos+1] = '0'
   207  			pos += 2
   208  		case '\n':
   209  			buf[pos] = '\\'
   210  			buf[pos+1] = 'n'
   211  			pos += 2
   212  		case '\r':
   213  			buf[pos] = '\\'
   214  			buf[pos+1] = 'r'
   215  			pos += 2
   216  		case '\x1a':
   217  			buf[pos] = '\\'
   218  			buf[pos+1] = 'Z'
   219  			pos += 2
   220  		case '\'':
   221  			buf[pos] = '\\'
   222  			buf[pos+1] = '\''
   223  			pos += 2
   224  		case '"':
   225  			buf[pos] = '\\'
   226  			buf[pos+1] = '"'
   227  			pos += 2
   228  		case '\\':
   229  			buf[pos] = '\\'
   230  			buf[pos+1] = '\\'
   231  			pos += 2
   232  		default:
   233  			buf[pos] = c
   234  			pos++
   235  		}
   236  	}
   237  
   238  	return buf[:pos]
   239  }
   240  
   241  func reserveBuffer(buf []byte, appendSize int) []byte {
   242  	newSize := len(buf) + appendSize
   243  	if cap(buf) < newSize {
   244  		// Grow buffer exponentially
   245  		newBuf := make([]byte, len(buf)*2+appendSize)
   246  		copy(newBuf, buf)
   247  		buf = newBuf
   248  	}
   249  	return buf[:newSize]
   250  }