github.com/artisanhe/tools@v1.0.1-0.20210607022958-19a8fef2eb04/sqlx/mysql_logger_driver/statement.go (about) 1 package mysql_logger_driver 2 3 import ( 4 "database/sql/driver" 5 "strconv" 6 "strings" 7 "time" 8 9 "github.com/fatih/color" 10 "github.com/go-sql-driver/mysql" 11 "github.com/sirupsen/logrus" 12 13 "github.com/artisanhe/tools/duration" 14 ) 15 16 type loggerStmt struct { 17 cfg *mysql.Config 18 query string 19 stmt driver.Stmt 20 } 21 22 func (s *loggerStmt) Close() error { 23 if err := s.stmt.Close(); err != nil { 24 logrus.Errorf("failed to close statement: %s", err) 25 return err 26 } 27 return nil 28 } 29 30 var DuplicateEntryErrNumber uint16 = 1062 31 32 func (s *loggerStmt) Exec(args []driver.Value) (driver.Result, error) { 33 cost := duration.NewDuration() 34 if len(args) != 0 { 35 sqlForLog, err := s.interpolateParams(s.query, args) 36 if err != nil { 37 logrus.Warnf("failed exec %s: %s", err, color.RedString(s.query)) 38 return nil, err 39 } 40 s.query = sqlForLog 41 } 42 result, err := s.stmt.Exec(args) 43 if err != nil { 44 if mysqlErr, ok := err.(*mysql.MySQLError); !ok { 45 logrus.Errorf("failed exec %s: %s", err, color.RedString(s.query)) 46 } else if mysqlErr.Number == DuplicateEntryErrNumber { 47 logrus.Warnf("failed exec %s: %s", err, color.RedString(s.query)) 48 } else { 49 logrus.Errorf("failed exec %s: %s", err, color.RedString(s.query)) 50 } 51 return nil, err 52 } 53 cost.ToLogger().Debugf(color.YellowString(s.query)) 54 return result, nil 55 } 56 57 func (s *loggerStmt) Query(args []driver.Value) (driver.Rows, error) { 58 cost := duration.NewDuration() 59 if len(args) != 0 { 60 sqlForLog, err := s.interpolateParams(s.query, args) 61 if err != nil { 62 if mysqlErr, ok := err.(*mysql.MySQLError); !ok { 63 logrus.Errorf("failed exec %s: %s", err, color.RedString(s.query)) 64 } else { 65 logrus.Warnf("failed exec %s: %s", mysqlErr, color.RedString(s.query)) 66 } 67 return nil, err 68 } 69 s.query = sqlForLog 70 } 71 rows, err := s.stmt.Query(args) 72 if err != nil { 73 logrus.Warnf("failed query %s: %s", err, color.RedString(s.query)) 74 return nil, err 75 } 76 cost.ToLogger().Debugf(color.GreenString(s.query)) 77 return rows, nil 78 } 79 80 func (s *loggerStmt) NumInput() int { 81 i := s.stmt.NumInput() 82 return i 83 } 84 85 func (s *loggerStmt) interpolateParams(query string, args []driver.Value) (string, error) { 86 if strings.Count(query, "?") != len(args) { 87 return "", driver.ErrSkip 88 } 89 90 buf := []byte{} 91 buf = buf[:0] 92 argPos := 0 93 94 for i := 0; i < len(query); i++ { 95 q := strings.IndexByte(query[i:], '?') 96 if q == -1 { 97 buf = append(buf, query[i:]...) 98 break 99 } 100 buf = append(buf, query[i:i+q]...) 101 i += q 102 103 arg := args[argPos] 104 argPos++ 105 106 if arg == nil { 107 buf = append(buf, "NULL"...) 108 continue 109 } 110 111 switch v := arg.(type) { 112 case int64: 113 buf = strconv.AppendInt(buf, v, 10) 114 case float64: 115 buf = strconv.AppendFloat(buf, v, 'g', -1, 64) 116 case bool: 117 if v { 118 buf = append(buf, '1') 119 } else { 120 buf = append(buf, '0') 121 } 122 case time.Time: 123 if v.IsZero() { 124 buf = append(buf, "'0000-00-00'"...) 125 } else { 126 v := v.In(s.cfg.Loc) 127 v = v.Add(time.Nanosecond * 500) // Write round under microsecond 128 year := v.Year() 129 year100 := year / 100 130 year1 := year % 100 131 month := v.Month() 132 day := v.Day() 133 hour := v.Hour() 134 minute := v.Minute() 135 second := v.Second() 136 micro := v.Nanosecond() / 1000 137 138 buf = append(buf, []byte{ 139 '\'', 140 digits10[year100], digits01[year100], 141 digits10[year1], digits01[year1], 142 '-', 143 digits10[month], digits01[month], 144 '-', 145 digits10[day], digits01[day], 146 ' ', 147 digits10[hour], digits01[hour], 148 ':', 149 digits10[minute], digits01[minute], 150 ':', 151 digits10[second], digits01[second], 152 }...) 153 154 if micro != 0 { 155 micro10000 := micro / 10000 156 micro100 := micro / 100 % 100 157 micro1 := micro % 100 158 buf = append(buf, []byte{ 159 '.', 160 digits10[micro10000], digits01[micro10000], 161 digits10[micro100], digits01[micro100], 162 digits10[micro1], digits01[micro1], 163 }...) 164 } 165 buf = append(buf, '\'') 166 } 167 case []byte: 168 if v == nil { 169 buf = append(buf, "NULL"...) 170 } else { 171 buf = append(buf, "_binary'"...) 172 buf = escapeBytesBackslash(buf, v) 173 buf = append(buf, '\'') 174 } 175 case string: 176 buf = append(buf, '\'') 177 buf = escapeBytesBackslash(buf, []byte(v)) 178 buf = append(buf, '\'') 179 default: 180 return "", driver.ErrSkip 181 } 182 183 if len(buf)+4 > s.cfg.MaxAllowedPacket { 184 return "", driver.ErrSkip 185 } 186 } 187 if argPos != len(args) { 188 return "", driver.ErrSkip 189 } 190 return string(buf), nil 191 } 192 193 // copy from mysql driver 194 195 const digits01 = "0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789" 196 const digits10 = "0000000000111111111122222222223333333333444444444455555555556666666666777777777788888888889999999999" 197 198 func escapeBytesBackslash(buf, v []byte) []byte { 199 pos := len(buf) 200 buf = reserveBuffer(buf, len(v)*2) 201 202 for _, c := range v { 203 switch c { 204 case '\x00': 205 buf[pos] = '\\' 206 buf[pos+1] = '0' 207 pos += 2 208 case '\n': 209 buf[pos] = '\\' 210 buf[pos+1] = 'n' 211 pos += 2 212 case '\r': 213 buf[pos] = '\\' 214 buf[pos+1] = 'r' 215 pos += 2 216 case '\x1a': 217 buf[pos] = '\\' 218 buf[pos+1] = 'Z' 219 pos += 2 220 case '\'': 221 buf[pos] = '\\' 222 buf[pos+1] = '\'' 223 pos += 2 224 case '"': 225 buf[pos] = '\\' 226 buf[pos+1] = '"' 227 pos += 2 228 case '\\': 229 buf[pos] = '\\' 230 buf[pos+1] = '\\' 231 pos += 2 232 default: 233 buf[pos] = c 234 pos++ 235 } 236 } 237 238 return buf[:pos] 239 } 240 241 func reserveBuffer(buf []byte, appendSize int) []byte { 242 newSize := len(buf) + appendSize 243 if cap(buf) < newSize { 244 // Grow buffer exponentially 245 newBuf := make([]byte, len(buf)*2+appendSize) 246 copy(newBuf, buf) 247 buf = newBuf 248 } 249 return buf[:newSize] 250 }