github.com/bingoohuang/gg@v0.0.0-20240325092523-45da7dee9335/pkg/sqx/adapter.go (about)

     1  package sqx
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"log"
     7  	"strings"
     8  
     9  	"github.com/bingoohuang/gg/pkg/sqlparse/sqlparser"
    10  	"go.uber.org/multierr"
    11  )
    12  
// DBTypeAware is implemented by values that can report the sqlparser.DBType
// of the database they are connected to.
type DBTypeAware interface {
	GetDBType() sqlparser.DBType
}

// GetDBType returns s.DBType. It uses a value receiver, so both Sqx and *Sqx
// satisfy DBTypeAware.
func (s Sqx) GetDBType() sqlparser.DBType { return s.DBType }

// Sqx bundles a database handle with its database type and an optional
// close hook.
//
// DB is the handle that statements run on. It is the *sql.DB passed to
// NewSqx, or the *sql.Tx handed to the callback of Tx/TxContext.
//
// DBType is the sqlparser.DBType. NewSqx derives it from the driver name.
//
// CloseFn, when non-nil, is invoked by Close. NewSqx sets it to db.Close.
// The transaction-bound Sqx built by TxContext leaves it nil.
type Sqx struct {
	DB      SqxDB
	DBType  sqlparser.DBType
	CloseFn func() error
}
    24  
    25  func LogSqlResultDesc(desc string, lastResult sql.Result) {
    26  	lastInsertId, _ := lastResult.LastInsertId()
    27  	rowsAffected, _ := lastResult.RowsAffected()
    28  	log.Printf("%sresult lastInsertId: %d, rowsAffected: %d", quoteDesc(desc), lastInsertId, rowsAffected)
    29  }
    30  
    31  func quoteDesc(desc string) string {
    32  	if desc != "" && !strings.HasPrefix(desc, "[") {
    33  		desc = "[" + desc + "] "
    34  	}
    35  	return desc
    36  }
    37  
    38  func logQueryError(nolog bool, desc string, result sql.Result, err error) {
    39  	if nolog {
    40  		return
    41  	}
    42  
    43  	if err != nil {
    44  		log.Printf("%squery error: %v", quoteDesc(desc), err)
    45  	} else if result != nil {
    46  		LogSqlResultDesc(desc, result)
    47  	}
    48  }
    49  
    50  func logRows(desc string, rows int) {
    51  	log.Printf("%squery %d rows", quoteDesc(desc), rows)
    52  }
    53  
    54  func logQuery(desc, query string, args []interface{}) {
    55  	log.Printf("%squery [%s] with args: %v", quoteDesc(desc), query, args)
    56  }
    57  
    58  type Executable interface {
    59  	ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
    60  }
    61  
    62  type ExecFn func(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
    63  
    64  func (f ExecFn) Exec(query string, args ...interface{}) (sql.Result, error) {
    65  	return f(context.Background(), query, args...)
    66  }
    67  
    68  func (f ExecFn) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
    69  	return f(ctx, query, args...)
    70  }
    71  
    72  type Queryable interface {
    73  	QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
    74  }
    75  
    76  type QueryFn func(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
    77  
    78  func (f QueryFn) Query(query string, args ...interface{}) (*sql.Rows, error) {
    79  	return f(context.Background(), query, args...)
    80  }
    81  
    82  func (f QueryFn) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
    83  	return f(ctx, query, args...)
    84  }
    85  
    86  func (s *Sqx) Close() error {
    87  	if s.CloseFn != nil {
    88  		return s.CloseFn()
    89  	}
    90  
    91  	return nil
    92  }
    93  
    94  func (s *Sqx) DoQuery(arg *QueryArgs) error {
    95  	return arg.DoQuery(s.typedDB())
    96  }
    97  
    98  func (s *Sqx) DoExec(arg *QueryArgs) (int64, error) {
    99  	return arg.DoExec(s.typedDB())
   100  }
   101  
   102  func (s *Sqx) DoExecRaw(arg *QueryArgs) (sql.Result, error) {
   103  	return arg.DoExecRaw(s.typedDB())
   104  }
   105  
   106  func (a *QueryArgs) DoExecRaw(db SqxDB) (sql.Result, error) {
   107  	ns := NewSQL(a.Query, a.Args...)
   108  	ns.Ctx = a.Ctx
   109  	return ns.WithConvertOptions(a.Options).UpdateRaw(db)
   110  }
   111  
   112  func (a *QueryArgs) DoExec(db SqxDB) (int64, error) {
   113  	ns := NewSQL(a.Query, a.Args...)
   114  	ns.Ctx = a.Ctx
   115  	return ns.WithConvertOptions(a.Options).Update(db)
   116  }
   117  
   118  func (a *QueryArgs) DoQuery(db SqxDB) error {
   119  	ns := NewSQL(a.Query, a.Args...)
   120  	ns.Ctx = a.Ctx
   121  	return ns.WithConvertOptions(a.Options).Query(db, a.Dest)
   122  }
   123  
   124  func NewSqx(db *sql.DB) *Sqx {
   125  	return &Sqx{DB: db, DBType: sqlparser.ToDBType(DriverName(db.Driver())), CloseFn: db.Close}
   126  }
   127  
// QueryArgs bundles the inputs of a single query or statement run through
// DoQuery, DoExec or DoExecRaw.
//
//   - Desc: description label (set by SelectDesc/GetDesc).
//   - Dest: destination handed to SQL.Query by DoQuery.
//   - Query, Args: the SQL text and its bind arguments.
//   - Limit: set to 1 by Get/GetDesc/ExecContext.
//   - Options: conversion options applied via WithConvertOptions.
//   - Ctx: copied onto the SQL before execution. It may be nil when the
//     caller leaves it unset; TODO confirm how SQL treats a nil Ctx.
//
// NOTE(review): the QueryArgs.DoQuery/DoExec/DoExecRaw methods in this file
// never read Desc or Limit, so those fields currently do not affect
// execution. Confirm whether they should be forwarded to the SQL value.
type QueryArgs struct {
	Desc    string
	Dest    interface{}
	Query   string
	Args    []interface{}
	Limit   int
	Options []sqlparser.ConvertOption
	Ctx     context.Context
}
   137  
   138  func (s *Sqx) SelectDesc(desc string, dest interface{}, query string, args ...interface{}) error {
   139  	return s.DoQuery(&QueryArgs{Desc: desc, Dest: dest, Query: query, Args: args})
   140  }
   141  
   142  func (s *Sqx) Select(dest interface{}, query string, args ...interface{}) error {
   143  	return s.DoQuery(&QueryArgs{Dest: dest, Query: query, Args: args})
   144  }
   145  
   146  func (s *Sqx) GetDesc(desc string, dest interface{}, query string, args ...interface{}) error {
   147  	return s.DoQuery(&QueryArgs{Desc: desc, Dest: dest, Query: query, Args: args, Limit: 1})
   148  }
   149  
   150  func (s *Sqx) Get(dest interface{}, query string, args ...interface{}) error {
   151  	return s.DoQuery(&QueryArgs{Dest: dest, Query: query, Args: args, Limit: 1})
   152  }
   153  
   154  func (s *Sqx) Upsert(insertQuery, updateQuery string, args ...interface{}) (ur UpsertResult, err error) {
   155  	return Upsert(s, insertQuery, updateQuery, args...)
   156  }
   157  
   158  func (s *Sqx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
   159  	return s.DoExecRaw(&QueryArgs{Ctx: ctx, Query: query, Args: args, Limit: 1})
   160  }
   161  
   162  func (s *Sqx) Exec(query string, args ...interface{}) (sql.Result, error) {
   163  	return s.ExecContext(context.Background(), query, args...)
   164  }
   165  
   166  func (s *Sqx) Query(query string, args ...interface{}) (*sql.Rows, error) {
   167  	return s.QueryContext(context.Background(), query, args...)
   168  }
   169  
   170  type ContextKey int
   171  
   172  const (
   173  	AdaptedKey ContextKey = iota + 1
   174  )
   175  
   176  func (s *Sqx) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
   177  	s2 := SQL{Ctx: ctx, Q: query, Vars: args}
   178  
   179  	if adapted, ok := ctx.Value(AdaptedKey).(bool); !ok || !adapted {
   180  		if err := s2.adaptQuery(s); err != nil {
   181  			return nil, err
   182  		}
   183  	}
   184  
   185  	rows, err := s.DB.QueryContext(ctx, s2.Q, s2.Vars...)
   186  	if err != nil {
   187  		logQueryError(false, "", nil, err)
   188  	}
   189  	return rows, err
   190  }
   191  
// BeginTx is the single method TxContext needs from a handle in order to
// start a transaction. It is satisfied by *sql.DB and *sql.Conn, but not by
// *sql.Tx.
type BeginTx interface {
	BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
}
   195  
   196  func (s *Sqx) Tx(f func(sqx *Sqx) error) error {
   197  	return s.TxContext(context.Background(), f)
   198  }
   199  
   200  func (s *Sqx) TxContext(ctx context.Context, f func(sqx *Sqx) error) error {
   201  	btx, ok := s.DB.(BeginTx)
   202  	if !ok {
   203  		panic("can't begin transaction")
   204  	}
   205  
   206  	tx, err := btx.BeginTx(ctx, nil)
   207  	if err != nil {
   208  		return err
   209  	}
   210  
   211  	if err := f(&Sqx{DB: tx, DBType: s.DBType}); err != nil {
   212  		return multierr.Append(err, tx.Rollback())
   213  	}
   214  
   215  	return tx.Commit()
   216  }
   217  
// DBRaw pairs a SqxDB with its sqlparser.DBType. Embedding SqxDB promotes
// all of its methods, and GetDBType makes DBRaw (and *DBRaw) satisfy
// DBTypeAware.
type DBRaw struct {
	SqxDB
	DBType sqlparser.DBType
}

// GetDBType returns the database type recorded on t.
func (t DBRaw) GetDBType() sqlparser.DBType { return t.DBType }
   224  
   225  func (s Sqx) typedDB() SqxDB {
   226  	return &DBRaw{
   227  		SqxDB:  s.DB,
   228  		DBType: s.DBType,
   229  	}
   230  }
   231  
   232  func VarsStr(keys ...string) []interface{} {
   233  	args := make([]interface{}, len(keys))
   234  	for i := 0; i < len(keys); i++ {
   235  		args[i] = keys[i]
   236  	}
   237  
   238  	return args
   239  }
   240  
   241  func Vars(keys ...interface{}) []interface{} {
   242  	args := make([]interface{}, len(keys))
   243  	for i := 0; i < len(keys); i++ {
   244  		args[i] = keys[i]
   245  	}
   246  
   247  	return args
   248  }
   249  
   250  type UpsertResult int
   251  
   252  const (
   253  	UpsertError UpsertResult = iota
   254  	UpsertInserted
   255  	UpsertUpdated
   256  )
   257  
   258  func Upsert(executable Executable, insertQuery, updateQuery string, args ...interface{}) (ur UpsertResult, err error) {
   259  	return UpsertContext(context.Background(), executable, insertQuery, updateQuery, args...)
   260  }
   261  
   262  func UpsertContext(ctx context.Context, executable Executable, insertQuery, updateQuery string, args ...interface{}) (ur UpsertResult, err error) {
   263  	_, err1 := executable.ExecContext(ctx, insertQuery, args...)
   264  	if err1 == nil {
   265  		return UpsertInserted, nil
   266  	}
   267  
   268  	_, err2 := executable.ExecContext(ctx, updateQuery, args...)
   269  	if err2 == nil {
   270  		return UpsertUpdated, nil
   271  	}
   272  
   273  	return UpsertError, multierr.Append(err1, err2)
   274  }