github.com/lingyao2333/mo-zero@v1.4.1/core/stores/sqlx/stmt.go (about)

     1  package sqlx
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"time"
     7  
     8  	"github.com/lingyao2333/mo-zero/core/logx"
     9  	"github.com/lingyao2333/mo-zero/core/syncx"
    10  	"github.com/lingyao2333/mo-zero/core/timex"
    11  )
    12  
    13  const defaultSlowThreshold = time.Millisecond * 500
    14  
    15  var (
    16  	slowThreshold = syncx.ForAtomicDuration(defaultSlowThreshold)
    17  	logSql        = syncx.ForAtomicBool(true)
    18  	logSlowSql    = syncx.ForAtomicBool(true)
    19  )
    20  
    21  // DisableLog disables logging of sql statements, includes info and slow logs.
    22  func DisableLog() {
    23  	logSql.Set(false)
    24  	logSlowSql.Set(false)
    25  }
    26  
    27  // DisableStmtLog disables info logging of sql statements, but keeps slow logs.
    28  func DisableStmtLog() {
    29  	logSql.Set(false)
    30  }
    31  
    32  // SetSlowThreshold sets the slow threshold.
    33  func SetSlowThreshold(threshold time.Duration) {
    34  	slowThreshold.Set(threshold)
    35  }
    36  
    37  func exec(ctx context.Context, conn sessionConn, q string, args ...interface{}) (sql.Result, error) {
    38  	guard := newGuard("exec")
    39  	if err := guard.start(q, args...); err != nil {
    40  		return nil, err
    41  	}
    42  
    43  	result, err := conn.ExecContext(ctx, q, args...)
    44  	guard.finish(ctx, err)
    45  
    46  	return result, err
    47  }
    48  
    49  func execStmt(ctx context.Context, conn stmtConn, q string, args ...interface{}) (sql.Result, error) {
    50  	guard := newGuard("execStmt")
    51  	if err := guard.start(q, args...); err != nil {
    52  		return nil, err
    53  	}
    54  
    55  	result, err := conn.ExecContext(ctx, args...)
    56  	guard.finish(ctx, err)
    57  
    58  	return result, err
    59  }
    60  
    61  func query(ctx context.Context, conn sessionConn, scanner func(*sql.Rows) error,
    62  	q string, args ...interface{}) error {
    63  	guard := newGuard("query")
    64  	if err := guard.start(q, args...); err != nil {
    65  		return err
    66  	}
    67  
    68  	rows, err := conn.QueryContext(ctx, q, args...)
    69  	guard.finish(ctx, err)
    70  	if err != nil {
    71  		return err
    72  	}
    73  	defer rows.Close()
    74  
    75  	return scanner(rows)
    76  }
    77  
    78  func queryStmt(ctx context.Context, conn stmtConn, scanner func(*sql.Rows) error,
    79  	q string, args ...interface{}) error {
    80  	guard := newGuard("queryStmt")
    81  	if err := guard.start(q, args...); err != nil {
    82  		return err
    83  	}
    84  
    85  	rows, err := conn.QueryContext(ctx, args...)
    86  	guard.finish(ctx, err)
    87  	if err != nil {
    88  		return err
    89  	}
    90  	defer rows.Close()
    91  
    92  	return scanner(rows)
    93  }
    94  
    95  type (
    96  	sqlGuard interface {
    97  		start(q string, args ...interface{}) error
    98  		finish(ctx context.Context, err error)
    99  	}
   100  
   101  	nilGuard struct{}
   102  
   103  	realSqlGuard struct {
   104  		command   string
   105  		stmt      string
   106  		startTime time.Duration
   107  	}
   108  )
   109  
   110  func newGuard(command string) sqlGuard {
   111  	if logSql.True() || logSlowSql.True() {
   112  		return &realSqlGuard{
   113  			command: command,
   114  		}
   115  	}
   116  
   117  	return nilGuard{}
   118  }
   119  
   120  func (n nilGuard) start(_ string, _ ...interface{}) error {
   121  	return nil
   122  }
   123  
   124  func (n nilGuard) finish(_ context.Context, _ error) {
   125  }
   126  
   127  func (e *realSqlGuard) finish(ctx context.Context, err error) {
   128  	duration := timex.Since(e.startTime)
   129  	if duration > slowThreshold.Load() {
   130  		logx.WithContext(ctx).WithDuration(duration).Slowf("[SQL] %s: slowcall - %s", e.command, e.stmt)
   131  	} else if logSql.True() {
   132  		logx.WithContext(ctx).WithDuration(duration).Infof("sql %s: %s", e.command, e.stmt)
   133  	}
   134  
   135  	if err != nil {
   136  		logSqlError(ctx, e.stmt, err)
   137  	}
   138  
   139  	metricReqDur.Observe(int64(duration/time.Millisecond), e.command)
   140  }
   141  
   142  func (e *realSqlGuard) start(q string, args ...interface{}) error {
   143  	stmt, err := format(q, args...)
   144  	if err != nil {
   145  		return err
   146  	}
   147  
   148  	e.stmt = stmt
   149  	e.startTime = timex.Now()
   150  
   151  	return nil
   152  }