github.com/kunlun-qilian/sqlx/v3@v3.0.0/db.go (about)

     1  package sqlx
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"time"
     7  
     8  	"github.com/pkg/errors"
     9  
    10  	"github.com/kunlun-qilian/sqlx/v3/builder"
    11  )
    12  
    13  var ErrNotTx = errors.New("db is not *sql.Tx")
    14  var ErrNotDB = errors.New("db is not *sql.DB")
    15  
    16  type SqlExecutor interface {
    17  	ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
    18  	QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
    19  }
    20  
    21  type SqlxExecutor interface {
    22  	SqlExecutor
    23  	ExecExpr(expr builder.SqlExpr) (sql.Result, error)
    24  	QueryExpr(expr builder.SqlExpr) (*sql.Rows, error)
    25  
    26  	QueryExprAndScan(expr builder.SqlExpr, v interface{}) error
    27  }
    28  
    29  type Migrator interface {
    30  	Migrate(ctx context.Context, db DBExecutor) error
    31  }
    32  
    33  type TableResolver interface {
    34  	// T return table of the connecting database
    35  	T(model builder.Model) *builder.Table
    36  }
    37  
    38  type DBExecutor interface {
    39  	SqlxExecutor
    40  	TableResolver
    41  
    42  	// Dialect of databases
    43  	Dialect() builder.Dialect
    44  	// D return database which is connecting
    45  	D() *Database
    46  	// WithSchema switch database schema
    47  	WithSchema(schema string) DBExecutor
    48  
    49  	Context() context.Context
    50  	WithContext(ctx context.Context) DBExecutor
    51  }
    52  
    53  type MaybeTxExecutor interface {
    54  	IsTx() bool
    55  	BeginTx(*sql.TxOptions) (DBExecutor, error)
    56  	Begin() (DBExecutor, error)
    57  	Commit() error
    58  	Rollback() error
    59  }
    60  
    61  type DB struct {
    62  	dialect builder.Dialect
    63  	*Database
    64  	SqlExecutor
    65  	ctx context.Context
    66  }
    67  
    68  func (d *DB) WithContext(ctx context.Context) DBExecutor {
    69  	dd := new(DB)
    70  	*dd = *d
    71  	dd.ctx = ctx
    72  	return dd
    73  }
    74  
    75  func (d *DB) Context() context.Context {
    76  	if d.ctx != nil {
    77  		return d.ctx
    78  	}
    79  	return context.Background()
    80  }
    81  
    82  func (d DB) WithSchema(schema string) DBExecutor {
    83  	d.Database = d.Database.WithSchema(schema)
    84  	return &d
    85  }
    86  
    87  func (d *DB) Dialect() builder.Dialect {
    88  	return d.dialect
    89  }
    90  
    91  func (d *DB) Migrate(ctx context.Context, db DBExecutor) error {
    92  	if migrator, ok := d.dialect.(Migrator); ok {
    93  		return migrator.Migrate(ctx, db)
    94  	}
    95  	return nil
    96  }
    97  
    98  func (d *DB) D() *Database {
    99  	return d.Database
   100  }
   101  
   102  func (d *DB) ExecExpr(expr builder.SqlExpr) (sql.Result, error) {
   103  	e := builder.ResolveExprContext(d.Context(), expr)
   104  	if builder.IsNilExpr(e) {
   105  		return nil, nil
   106  	}
   107  	if err := e.Err(); err != nil {
   108  		return nil, err
   109  	}
   110  	result, err := d.ExecContext(d.Context(), e.Query(), e.Args()...)
   111  	if err != nil {
   112  		if d.dialect.IsErrorConflict(err) {
   113  			return nil, NewSqlError(sqlErrTypeConflict, err.Error())
   114  		}
   115  		return nil, err
   116  	}
   117  	return result, nil
   118  }
   119  
   120  func (d *DB) QueryExpr(expr builder.SqlExpr) (*sql.Rows, error) {
   121  	e := builder.ResolveExprContext(d.Context(), expr)
   122  	if builder.IsNilExpr(e) {
   123  		return nil, nil
   124  	}
   125  	if err := e.Err(); err != nil {
   126  		return nil, err
   127  	}
   128  	return d.QueryContext(d.Context(), e.Query(), e.Args()...)
   129  }
   130  
   131  func (d *DB) QueryExprAndScan(expr builder.SqlExpr, v interface{}) error {
   132  	rows, err := d.QueryExpr(expr)
   133  	if err != nil {
   134  		return err
   135  	}
   136  	return Scan(rows, v)
   137  }
   138  
   139  func (d *DB) IsTx() bool {
   140  	_, ok := d.SqlExecutor.(*sql.Tx)
   141  	return ok
   142  }
   143  
   144  func (d *DB) Begin() (DBExecutor, error) {
   145  	return d.BeginTx(nil)
   146  }
   147  
   148  func (d *DB) BeginTx(opt *sql.TxOptions) (DBExecutor, error) {
   149  	if d.IsTx() {
   150  		return nil, ErrNotDB
   151  	}
   152  	db, err := d.SqlExecutor.(*sql.DB).BeginTx(d.Context(), opt)
   153  	if err != nil {
   154  		return nil, err
   155  	}
   156  	return &DB{
   157  		Database:    d.Database,
   158  		dialect:     d.dialect,
   159  		SqlExecutor: db,
   160  		ctx:         d.Context(),
   161  	}, nil
   162  }
   163  
   164  func (d *DB) Commit() error {
   165  	if !d.IsTx() {
   166  		return ErrNotTx
   167  	}
   168  	if d.Context().Err() == context.Canceled {
   169  		return context.Canceled
   170  	}
   171  	return d.SqlExecutor.(*sql.Tx).Commit()
   172  }
   173  
   174  func (d *DB) Rollback() error {
   175  	if !d.IsTx() {
   176  		return ErrNotTx
   177  	}
   178  	if d.Context().Err() == context.Canceled {
   179  		return context.Canceled
   180  	}
   181  	return d.SqlExecutor.(*sql.Tx).Rollback()
   182  }
   183  
   184  func (d *DB) SetMaxOpenConns(n int) {
   185  	d.SqlExecutor.(*sql.DB).SetMaxOpenConns(n)
   186  }
   187  
   188  func (d *DB) SetMaxIdleConns(n int) {
   189  	d.SqlExecutor.(*sql.DB).SetMaxIdleConns(n)
   190  }
   191  
   192  func (d *DB) SetConnMaxLifetime(t time.Duration) {
   193  	d.SqlExecutor.(*sql.DB).SetConnMaxLifetime(t)
   194  }