github.com/lingyao2333/mo-zero@v1.4.1/core/stores/sqlx/sqlconn.go (about)

     1  package sqlx
     2  
import (
	"context"
	"database/sql"
	"errors"

	"github.com/lingyao2333/mo-zero/core/breaker"
	"github.com/lingyao2333/mo-zero/core/logx"
)
    10  
    11  // spanName is used to identify the span name for the SQL execution.
    12  const spanName = "sql"
    13  
    14  // ErrNotFound is an alias of sql.ErrNoRows
    15  var ErrNotFound = sql.ErrNoRows
    16  
    17  type (
    18  	// Session stands for raw connections or transaction sessions
    19  	Session interface {
    20  		Exec(query string, args ...interface{}) (sql.Result, error)
    21  		ExecCtx(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
    22  		Prepare(query string) (StmtSession, error)
    23  		PrepareCtx(ctx context.Context, query string) (StmtSession, error)
    24  		QueryRow(v interface{}, query string, args ...interface{}) error
    25  		QueryRowCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error
    26  		QueryRowPartial(v interface{}, query string, args ...interface{}) error
    27  		QueryRowPartialCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error
    28  		QueryRows(v interface{}, query string, args ...interface{}) error
    29  		QueryRowsCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error
    30  		QueryRowsPartial(v interface{}, query string, args ...interface{}) error
    31  		QueryRowsPartialCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error
    32  	}
    33  
    34  	// SqlConn only stands for raw connections, so Transact method can be called.
    35  	SqlConn interface {
    36  		Session
    37  		// RawDB is for other ORM to operate with, use it with caution.
    38  		// Notice: don't close it.
    39  		RawDB() (*sql.DB, error)
    40  		Transact(fn func(Session) error) error
    41  		TransactCtx(ctx context.Context, fn func(context.Context, Session) error) error
    42  	}
    43  
    44  	// SqlOption defines the method to customize a sql connection.
    45  	SqlOption func(*commonSqlConn)
    46  
    47  	// StmtSession interface represents a session that can be used to execute statements.
    48  	StmtSession interface {
    49  		Close() error
    50  		Exec(args ...interface{}) (sql.Result, error)
    51  		ExecCtx(ctx context.Context, args ...interface{}) (sql.Result, error)
    52  		QueryRow(v interface{}, args ...interface{}) error
    53  		QueryRowCtx(ctx context.Context, v interface{}, args ...interface{}) error
    54  		QueryRowPartial(v interface{}, args ...interface{}) error
    55  		QueryRowPartialCtx(ctx context.Context, v interface{}, args ...interface{}) error
    56  		QueryRows(v interface{}, args ...interface{}) error
    57  		QueryRowsCtx(ctx context.Context, v interface{}, args ...interface{}) error
    58  		QueryRowsPartial(v interface{}, args ...interface{}) error
    59  		QueryRowsPartialCtx(ctx context.Context, v interface{}, args ...interface{}) error
    60  	}
    61  
    62  	// thread-safe
    63  	// Because CORBA doesn't support PREPARE, so we need to combine the
    64  	// query arguments into one string and do underlying query without arguments
    65  	commonSqlConn struct {
    66  		connProv connProvider
    67  		onError  func(error)
    68  		beginTx  beginnable
    69  		brk      breaker.Breaker
    70  		accept   func(error) bool
    71  	}
    72  
    73  	connProvider func() (*sql.DB, error)
    74  
    75  	sessionConn interface {
    76  		Exec(query string, args ...interface{}) (sql.Result, error)
    77  		ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
    78  		Query(query string, args ...interface{}) (*sql.Rows, error)
    79  		QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
    80  	}
    81  
    82  	statement struct {
    83  		query string
    84  		stmt  *sql.Stmt
    85  	}
    86  
    87  	stmtConn interface {
    88  		Exec(args ...interface{}) (sql.Result, error)
    89  		ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error)
    90  		Query(args ...interface{}) (*sql.Rows, error)
    91  		QueryContext(ctx context.Context, args ...interface{}) (*sql.Rows, error)
    92  	}
    93  )
    94  
    95  // NewSqlConn returns a SqlConn with given driver name and datasource.
    96  func NewSqlConn(driverName, datasource string, opts ...SqlOption) SqlConn {
    97  	conn := &commonSqlConn{
    98  		connProv: func() (*sql.DB, error) {
    99  			return getSqlConn(driverName, datasource)
   100  		},
   101  		onError: func(err error) {
   102  			logInstanceError(datasource, err)
   103  		},
   104  		beginTx: begin,
   105  		brk:     breaker.NewBreaker(),
   106  	}
   107  	for _, opt := range opts {
   108  		opt(conn)
   109  	}
   110  
   111  	return conn
   112  }
   113  
   114  // NewSqlConnFromDB returns a SqlConn with the given sql.DB.
   115  // Use it with caution, it's provided for other ORM to interact with.
   116  func NewSqlConnFromDB(db *sql.DB, opts ...SqlOption) SqlConn {
   117  	conn := &commonSqlConn{
   118  		connProv: func() (*sql.DB, error) {
   119  			return db, nil
   120  		},
   121  		onError: func(err error) {
   122  			logx.Errorf("Error on getting sql instance: %v", err)
   123  		},
   124  		beginTx: begin,
   125  		brk:     breaker.NewBreaker(),
   126  	}
   127  	for _, opt := range opts {
   128  		opt(conn)
   129  	}
   130  
   131  	return conn
   132  }
   133  
   134  func (db *commonSqlConn) Exec(q string, args ...interface{}) (result sql.Result, err error) {
   135  	return db.ExecCtx(context.Background(), q, args...)
   136  }
   137  
   138  func (db *commonSqlConn) ExecCtx(ctx context.Context, q string, args ...interface{}) (
   139  	result sql.Result, err error) {
   140  	ctx, span := startSpan(ctx, "Exec")
   141  	defer func() {
   142  		endSpan(span, err)
   143  	}()
   144  
   145  	err = db.brk.DoWithAcceptable(func() error {
   146  		var conn *sql.DB
   147  		conn, err = db.connProv()
   148  		if err != nil {
   149  			db.onError(err)
   150  			return err
   151  		}
   152  
   153  		result, err = exec(ctx, conn, q, args...)
   154  		return err
   155  	}, db.acceptable)
   156  	if err == breaker.ErrServiceUnavailable {
   157  		metricReqErr.Inc("Exec", "breaker")
   158  	}
   159  
   160  	return
   161  }
   162  
   163  func (db *commonSqlConn) Prepare(query string) (stmt StmtSession, err error) {
   164  	return db.PrepareCtx(context.Background(), query)
   165  }
   166  
   167  func (db *commonSqlConn) PrepareCtx(ctx context.Context, query string) (stmt StmtSession, err error) {
   168  	ctx, span := startSpan(ctx, "Prepare")
   169  	defer func() {
   170  		endSpan(span, err)
   171  	}()
   172  
   173  	err = db.brk.DoWithAcceptable(func() error {
   174  		var conn *sql.DB
   175  		conn, err = db.connProv()
   176  		if err != nil {
   177  			db.onError(err)
   178  			return err
   179  		}
   180  
   181  		st, err := conn.PrepareContext(ctx, query)
   182  		if err != nil {
   183  			return err
   184  		}
   185  
   186  		stmt = statement{
   187  			query: query,
   188  			stmt:  st,
   189  		}
   190  		return nil
   191  	}, db.acceptable)
   192  	if err == breaker.ErrServiceUnavailable {
   193  		metricReqErr.Inc("Prepare", "breaker")
   194  	}
   195  
   196  	return
   197  }
   198  
   199  func (db *commonSqlConn) QueryRow(v interface{}, q string, args ...interface{}) error {
   200  	return db.QueryRowCtx(context.Background(), v, q, args...)
   201  }
   202  
   203  func (db *commonSqlConn) QueryRowCtx(ctx context.Context, v interface{}, q string,
   204  	args ...interface{}) (err error) {
   205  	ctx, span := startSpan(ctx, "QueryRow")
   206  	defer func() {
   207  		endSpan(span, err)
   208  	}()
   209  
   210  	return db.queryRows(ctx, func(rows *sql.Rows) error {
   211  		return unmarshalRow(v, rows, true)
   212  	}, q, args...)
   213  }
   214  
   215  func (db *commonSqlConn) QueryRowPartial(v interface{}, q string, args ...interface{}) error {
   216  	return db.QueryRowPartialCtx(context.Background(), v, q, args...)
   217  }
   218  
   219  func (db *commonSqlConn) QueryRowPartialCtx(ctx context.Context, v interface{},
   220  	q string, args ...interface{}) (err error) {
   221  	ctx, span := startSpan(ctx, "QueryRowPartial")
   222  	defer func() {
   223  		endSpan(span, err)
   224  	}()
   225  
   226  	return db.queryRows(ctx, func(rows *sql.Rows) error {
   227  		return unmarshalRow(v, rows, false)
   228  	}, q, args...)
   229  }
   230  
   231  func (db *commonSqlConn) QueryRows(v interface{}, q string, args ...interface{}) error {
   232  	return db.QueryRowsCtx(context.Background(), v, q, args...)
   233  }
   234  
   235  func (db *commonSqlConn) QueryRowsCtx(ctx context.Context, v interface{}, q string,
   236  	args ...interface{}) (err error) {
   237  	ctx, span := startSpan(ctx, "QueryRows")
   238  	defer func() {
   239  		endSpan(span, err)
   240  	}()
   241  
   242  	return db.queryRows(ctx, func(rows *sql.Rows) error {
   243  		return unmarshalRows(v, rows, true)
   244  	}, q, args...)
   245  }
   246  
   247  func (db *commonSqlConn) QueryRowsPartial(v interface{}, q string, args ...interface{}) error {
   248  	return db.QueryRowsPartialCtx(context.Background(), v, q, args...)
   249  }
   250  
   251  func (db *commonSqlConn) QueryRowsPartialCtx(ctx context.Context, v interface{},
   252  	q string, args ...interface{}) (err error) {
   253  	ctx, span := startSpan(ctx, "QueryRowsPartial")
   254  	defer func() {
   255  		endSpan(span, err)
   256  	}()
   257  
   258  	return db.queryRows(ctx, func(rows *sql.Rows) error {
   259  		return unmarshalRows(v, rows, false)
   260  	}, q, args...)
   261  }
   262  
   263  func (db *commonSqlConn) RawDB() (*sql.DB, error) {
   264  	return db.connProv()
   265  }
   266  
   267  func (db *commonSqlConn) Transact(fn func(Session) error) error {
   268  	return db.TransactCtx(context.Background(), func(_ context.Context, session Session) error {
   269  		return fn(session)
   270  	})
   271  }
   272  
   273  func (db *commonSqlConn) TransactCtx(ctx context.Context, fn func(context.Context, Session) error) (err error) {
   274  	ctx, span := startSpan(ctx, "Transact")
   275  	defer func() {
   276  		endSpan(span, err)
   277  	}()
   278  
   279  	err = db.brk.DoWithAcceptable(func() error {
   280  		return transact(ctx, db, db.beginTx, fn)
   281  	}, db.acceptable)
   282  	if err == breaker.ErrServiceUnavailable {
   283  		metricReqErr.Inc("Transact", "breaker")
   284  	}
   285  
   286  	return
   287  }
   288  
   289  func (db *commonSqlConn) acceptable(err error) bool {
   290  	ok := err == nil || err == sql.ErrNoRows || err == sql.ErrTxDone || err == context.Canceled
   291  	if db.accept == nil {
   292  		return ok
   293  	}
   294  
   295  	return ok || db.accept(err)
   296  }
   297  
   298  func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(*sql.Rows) error,
   299  	q string, args ...interface{}) (err error) {
   300  	var qerr error
   301  	err = db.brk.DoWithAcceptable(func() error {
   302  		conn, err := db.connProv()
   303  		if err != nil {
   304  			db.onError(err)
   305  			return err
   306  		}
   307  
   308  		return query(ctx, conn, func(rows *sql.Rows) error {
   309  			qerr = scanner(rows)
   310  			return qerr
   311  		}, q, args...)
   312  	}, func(err error) bool {
   313  		return qerr == err || db.acceptable(err)
   314  	})
   315  	if err == breaker.ErrServiceUnavailable {
   316  		metricReqErr.Inc("queryRows", "breaker")
   317  	}
   318  
   319  	return
   320  }
   321  
   322  func (s statement) Close() error {
   323  	return s.stmt.Close()
   324  }
   325  
   326  func (s statement) Exec(args ...interface{}) (sql.Result, error) {
   327  	return s.ExecCtx(context.Background(), args...)
   328  }
   329  
   330  func (s statement) ExecCtx(ctx context.Context, args ...interface{}) (result sql.Result, err error) {
   331  	ctx, span := startSpan(ctx, "Exec")
   332  	defer func() {
   333  		endSpan(span, err)
   334  	}()
   335  
   336  	return execStmt(ctx, s.stmt, s.query, args...)
   337  }
   338  
   339  func (s statement) QueryRow(v interface{}, args ...interface{}) error {
   340  	return s.QueryRowCtx(context.Background(), v, args...)
   341  }
   342  
   343  func (s statement) QueryRowCtx(ctx context.Context, v interface{}, args ...interface{}) (err error) {
   344  	ctx, span := startSpan(ctx, "QueryRow")
   345  	defer func() {
   346  		endSpan(span, err)
   347  	}()
   348  
   349  	return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error {
   350  		return unmarshalRow(v, rows, true)
   351  	}, s.query, args...)
   352  }
   353  
   354  func (s statement) QueryRowPartial(v interface{}, args ...interface{}) error {
   355  	return s.QueryRowPartialCtx(context.Background(), v, args...)
   356  }
   357  
   358  func (s statement) QueryRowPartialCtx(ctx context.Context, v interface{}, args ...interface{}) (err error) {
   359  	ctx, span := startSpan(ctx, "QueryRowPartial")
   360  	defer func() {
   361  		endSpan(span, err)
   362  	}()
   363  
   364  	return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error {
   365  		return unmarshalRow(v, rows, false)
   366  	}, s.query, args...)
   367  }
   368  
   369  func (s statement) QueryRows(v interface{}, args ...interface{}) error {
   370  	return s.QueryRowsCtx(context.Background(), v, args...)
   371  }
   372  
   373  func (s statement) QueryRowsCtx(ctx context.Context, v interface{}, args ...interface{}) (err error) {
   374  	ctx, span := startSpan(ctx, "QueryRows")
   375  	defer func() {
   376  		endSpan(span, err)
   377  	}()
   378  
   379  	return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error {
   380  		return unmarshalRows(v, rows, true)
   381  	}, s.query, args...)
   382  }
   383  
   384  func (s statement) QueryRowsPartial(v interface{}, args ...interface{}) error {
   385  	return s.QueryRowsPartialCtx(context.Background(), v, args...)
   386  }
   387  
   388  func (s statement) QueryRowsPartialCtx(ctx context.Context, v interface{}, args ...interface{}) (err error) {
   389  	ctx, span := startSpan(ctx, "QueryRowsPartial")
   390  	defer func() {
   391  		endSpan(span, err)
   392  	}()
   393  
   394  	return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error {
   395  		return unmarshalRows(v, rows, false)
   396  	}, s.query, args...)
   397  }