github.com/lingyao2333/mo-zero@v1.4.1/core/stores/sqlx/stmt.go (about) 1 package sqlx 2 3 import ( 4 "context" 5 "database/sql" 6 "time" 7 8 "github.com/lingyao2333/mo-zero/core/logx" 9 "github.com/lingyao2333/mo-zero/core/syncx" 10 "github.com/lingyao2333/mo-zero/core/timex" 11 ) 12 13 const defaultSlowThreshold = time.Millisecond * 500 14 15 var ( 16 slowThreshold = syncx.ForAtomicDuration(defaultSlowThreshold) 17 logSql = syncx.ForAtomicBool(true) 18 logSlowSql = syncx.ForAtomicBool(true) 19 ) 20 21 // DisableLog disables logging of sql statements, includes info and slow logs. 22 func DisableLog() { 23 logSql.Set(false) 24 logSlowSql.Set(false) 25 } 26 27 // DisableStmtLog disables info logging of sql statements, but keeps slow logs. 28 func DisableStmtLog() { 29 logSql.Set(false) 30 } 31 32 // SetSlowThreshold sets the slow threshold. 33 func SetSlowThreshold(threshold time.Duration) { 34 slowThreshold.Set(threshold) 35 } 36 37 func exec(ctx context.Context, conn sessionConn, q string, args ...interface{}) (sql.Result, error) { 38 guard := newGuard("exec") 39 if err := guard.start(q, args...); err != nil { 40 return nil, err 41 } 42 43 result, err := conn.ExecContext(ctx, q, args...) 44 guard.finish(ctx, err) 45 46 return result, err 47 } 48 49 func execStmt(ctx context.Context, conn stmtConn, q string, args ...interface{}) (sql.Result, error) { 50 guard := newGuard("execStmt") 51 if err := guard.start(q, args...); err != nil { 52 return nil, err 53 } 54 55 result, err := conn.ExecContext(ctx, args...) 56 guard.finish(ctx, err) 57 58 return result, err 59 } 60 61 func query(ctx context.Context, conn sessionConn, scanner func(*sql.Rows) error, 62 q string, args ...interface{}) error { 63 guard := newGuard("query") 64 if err := guard.start(q, args...); err != nil { 65 return err 66 } 67 68 rows, err := conn.QueryContext(ctx, q, args...) 69 guard.finish(ctx, err) 70 if err != nil { 71 return err 72 } 73 defer rows.Close() 74 75 return scanner(rows) 76 } 77 78 func queryStmt(ctx context.Context, conn stmtConn, scanner func(*sql.Rows) error, 79 q string, args ...interface{}) error { 80 guard := newGuard("queryStmt") 81 if err := guard.start(q, args...); err != nil { 82 return err 83 } 84 85 rows, err := conn.QueryContext(ctx, args...) 86 guard.finish(ctx, err) 87 if err != nil { 88 return err 89 } 90 defer rows.Close() 91 92 return scanner(rows) 93 } 94 95 type ( 96 sqlGuard interface { 97 start(q string, args ...interface{}) error 98 finish(ctx context.Context, err error) 99 } 100 101 nilGuard struct{} 102 103 realSqlGuard struct { 104 command string 105 stmt string 106 startTime time.Duration 107 } 108 ) 109 110 func newGuard(command string) sqlGuard { 111 if logSql.True() || logSlowSql.True() { 112 return &realSqlGuard{ 113 command: command, 114 } 115 } 116 117 return nilGuard{} 118 } 119 120 func (n nilGuard) start(_ string, _ ...interface{}) error { 121 return nil 122 } 123 124 func (n nilGuard) finish(_ context.Context, _ error) { 125 } 126 127 func (e *realSqlGuard) finish(ctx context.Context, err error) { 128 duration := timex.Since(e.startTime) 129 if duration > slowThreshold.Load() { 130 logx.WithContext(ctx).WithDuration(duration).Slowf("[SQL] %s: slowcall - %s", e.command, e.stmt) 131 } else if logSql.True() { 132 logx.WithContext(ctx).WithDuration(duration).Infof("sql %s: %s", e.command, e.stmt) 133 } 134 135 if err != nil { 136 logSqlError(ctx, e.stmt, err) 137 } 138 139 metricReqDur.Observe(int64(duration/time.Millisecond), e.command) 140 } 141 142 func (e *realSqlGuard) start(q string, args ...interface{}) error { 143 stmt, err := format(q, args...) 144 if err != nil { 145 return err 146 } 147 148 e.stmt = stmt 149 e.startTime = timex.Now() 150 151 return nil 152 }