github.com/machinefi/w3bstream@v1.6.5-rc9.0.20240426031326-b8c7c4876e72/pkg/depends/kit/sqlx/database.go (about)

     1  package sqlx
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"database/sql/driver"
     7  	"time"
     8  
     9  	"github.com/machinefi/w3bstream/pkg/depends/kit/sqlx/builder"
    10  )
    11  
    12  type DBExecutor interface {
    13  	ExprExecutor
    14  	TableResolver
    15  
    16  	Dialect() builder.Dialect
    17  	D() *Database
    18  	WithSchema(string) DBExecutor
    19  
    20  	Context() context.Context
    21  	WithContext(ctx context.Context) DBExecutor
    22  }
    23  
    24  type WithDBName interface {
    25  	WithDBName(string) driver.Connector
    26  }
    27  
    28  type SqlExecutor interface {
    29  	ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
    30  	QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
    31  }
    32  
    33  type ExprExecutor interface {
    34  	SqlExecutor
    35  	Exec(builder.SqlExpr) (sql.Result, error)
    36  	Query(builder.SqlExpr) (*sql.Rows, error)
    37  	QueryAndScan(builder.SqlExpr, interface{}) error
    38  }
    39  
    40  type TableResolver interface {
    41  	T(model builder.Model) *builder.Table
    42  }
    43  
    44  type TxExecutor interface {
    45  	IsTx() bool
    46  	BeginTx(*sql.TxOptions) (DBExecutor, error)
    47  	Begin() (DBExecutor, error)
    48  	Commit() error
    49  	Rollback() error
    50  }
    51  
    52  type Migrator interface {
    53  	Migrate(ctx context.Context, db DBExecutor) error
    54  }
    55  
    56  func NewDatabase(name string) *Database {
    57  	return &Database{Name: name, Tables: builder.Tables{}}
    58  }
    59  
    60  type Database struct {
    61  	Name   string
    62  	Schema string
    63  	Tables builder.Tables
    64  }
    65  
    66  func (d Database) WithSchema(schema string) *Database {
    67  	d.Schema = schema
    68  
    69  	tables := builder.Tables{}
    70  	d.Tables.Range(func(t *builder.Table, _ int) {
    71  		tables.Add(t.WithSchema(schema))
    72  	})
    73  	d.Tables = tables
    74  	return &d
    75  }
    76  
    77  func (d *Database) OpenDB(connector driver.Connector) *DB {
    78  	if c, ok := connector.(WithDBName); ok {
    79  		connector = c.WithDBName(d.Name)
    80  	}
    81  	dialect, ok := connector.(builder.Dialect)
    82  	if !ok {
    83  		panic("connect MUST be a builder.Dialect")
    84  	}
    85  	return &DB{
    86  		Database:    d,
    87  		dialect:     dialect,
    88  		SqlExecutor: sql.OpenDB(connector),
    89  	}
    90  }
    91  
    92  func (d *Database) T(model builder.Model) *builder.Table {
    93  	if t, ok := model.(builder.TableDefinition); ok {
    94  		return t.T()
    95  	}
    96  	if t, ok := model.(*builder.Table); ok {
    97  		return t
    98  	}
    99  	return d.Table(model.TableName())
   100  }
   101  
   102  func (d *Database) Table(name string) *builder.Table { return d.Tables.Table(name) }
   103  
   104  func (d *Database) Register(m builder.Model) *builder.Table {
   105  	t := builder.TableFromModel(m)
   106  	t.Schema = d.Schema
   107  	d.AddTable(t)
   108  	return t
   109  }
   110  
   111  func (d *Database) AddTable(t *builder.Table) { d.Tables.Add(t) }
   112  
// DB is an executor bound to a Database definition. OpenDB sets SqlExecutor to
// a *sql.DB. Executors returned by BeginTx carry the *sql.Tx instead.
type DB struct {
	// ctx is the context statements run under; nil means
	// context.Background (see Context).
	ctx context.Context

	// dialect is the builder.Dialect taken from the connector in OpenDB.
	dialect builder.Dialect
	*Database
	// SqlExecutor runs the raw SQL rendered from expressions.
	SqlExecutor
}
   120  
   121  func (d *DB) WithContext(ctx context.Context) DBExecutor {
   122  	e := new(DB)
   123  	*e = *d
   124  	e.ctx = ctx
   125  	return e
   126  }
   127  
   128  func (d *DB) Context() context.Context {
   129  	if d.ctx == nil {
   130  		return context.Background()
   131  	}
   132  	return d.ctx
   133  }
   134  
   135  func (d DB) WithSchema(schema string) DBExecutor {
   136  	d.Database = d.Database.WithSchema(schema)
   137  	return &d
   138  }
   139  
   140  func (d *DB) Dialect() builder.Dialect { return d.dialect }
   141  
   142  func (d *DB) D() *Database { return d.Database }
   143  
   144  func (d *DB) Migrate(ctx context.Context, db DBExecutor) error {
   145  	if migrator, ok := d.dialect.(Migrator); ok {
   146  		return migrator.Migrate(ctx, db)
   147  	}
   148  	return nil
   149  }
   150  
   151  func (d *DB) Exec(e builder.SqlExpr) (sql.Result, error) {
   152  	ex := builder.ResolveExprContext(d.Context(), e)
   153  	if builder.IsNilExpr(ex) {
   154  		return nil, nil
   155  	}
   156  	if err := ex.Err(); err != nil {
   157  		return nil, err
   158  	}
   159  	res, err := d.ExecContext(d.Context(), ex.Query(), ex.Args()...)
   160  	if err != nil {
   161  		if d.dialect.IsErrorConflict(err) {
   162  			return nil, NewSqlError(SqlErrTypeConflict, err.Error())
   163  		}
   164  		return nil, err
   165  	}
   166  	return res, nil
   167  }
   168  
   169  func (d *DB) Query(e builder.SqlExpr) (*sql.Rows, error) {
   170  	ex := builder.ResolveExprContext(d.Context(), e)
   171  	if builder.IsNilExpr(ex) {
   172  		return nil, nil
   173  	}
   174  	if err := ex.Err(); err != nil {
   175  		return nil, err
   176  	}
   177  	return d.QueryContext(d.Context(), ex.Query(), ex.Args()...)
   178  }
   179  
   180  func (d *DB) QueryAndScan(e builder.SqlExpr, v interface{}) error {
   181  	rows, err := d.Query(e)
   182  	if err != nil {
   183  		return err
   184  	}
   185  	return Scan(d.Context(), rows, v)
   186  }
   187  
   188  func (d *DB) IsTx() bool { _, ok := d.SqlExecutor.(*sql.Tx); return ok }
   189  
   190  func (d *DB) Begin() (DBExecutor, error) { return d.BeginTx(nil) }
   191  
   192  func (d *DB) BeginTx(opt *sql.TxOptions) (DBExecutor, error) {
   193  	if d.IsTx() {
   194  		return nil, ErrNotDB
   195  	}
   196  	db, err := d.SqlExecutor.(*sql.DB).BeginTx(d.Context(), opt)
   197  	if err != nil {
   198  		return nil, err
   199  	}
   200  	e := *d
   201  	e.SqlExecutor = db
   202  	e.ctx = d.Context()
   203  	return &e, nil
   204  }
   205  
   206  func (d *DB) Commit() error {
   207  	if !d.IsTx() {
   208  		return ErrNotTx
   209  	}
   210  	if d.Context().Err() == context.Canceled {
   211  		return context.Canceled
   212  	}
   213  	return d.SqlExecutor.(*sql.Tx).Commit()
   214  }
   215  
   216  func (d *DB) Rollback() error {
   217  	if !d.IsTx() {
   218  		return ErrNotTx
   219  	}
   220  	if d.Context().Err() == context.Canceled {
   221  		return context.Canceled
   222  	}
   223  	return d.SqlExecutor.(*sql.Tx).Rollback()
   224  }
   225  
   226  func (d *DB) SetMaxOpenConns(n int) { d.SqlExecutor.(*sql.DB).SetMaxOpenConns(n) }
   227  
   228  func (d *DB) SetMaxIdleConns(n int) { d.SqlExecutor.(*sql.DB).SetMaxIdleConns(n) }
   229  
   230  func (d *DB) SetConnMaxLifetime(du time.Duration) { d.SqlExecutor.(*sql.DB).SetConnMaxLifetime(du) }