github.com/octohelm/storage@v0.0.0-20240516030302-1ac2cc1ea347/internal/sql/adapter/db.go (about)

     1  package adapter
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"runtime"
     7  
     8  	"github.com/octohelm/storage/pkg/sqlbuilder"
     9  )
    10  
    11  func Wrap(d *sql.DB, convertErr func(err error) error) DB {
    12  	return &db{
    13  		DB: d,
    14  		option: option{
    15  			convertErr: convertErr,
    16  		},
    17  	}
    18  }
    19  
    20  type option struct {
    21  	convertErr func(err error) error
    22  }
    23  
    24  type db struct {
    25  	option
    26  	*sql.DB
    27  }
    28  
    29  func (d *db) Exec(ctx context.Context, expr sqlbuilder.SqlExpr) (sql.Result, error) {
    30  	e := sqlbuilder.ResolveExprContext(ctx, expr)
    31  	if sqlbuilder.IsNilExpr(e) {
    32  		return nil, nil
    33  	}
    34  	if err := e.Err(); err != nil {
    35  		return nil, d.convertErr(err)
    36  	}
    37  
    38  	if sqlDo := SqlDoFromContext(ctx); sqlDo != nil {
    39  		result, err := sqlDo.ExecContext(ctx, e.Query(), e.Args()...)
    40  		if err != nil {
    41  			return nil, d.convertErr(err)
    42  		}
    43  		return result, nil
    44  	}
    45  
    46  	result, err := d.ExecContext(ctx, e.Query(), e.Args()...)
    47  	if err != nil {
    48  		return nil, d.convertErr(err)
    49  	}
    50  	return result, nil
    51  }
    52  
    53  func (d *db) Query(ctx context.Context, expr sqlbuilder.SqlExpr) (*sql.Rows, error) {
    54  	e := sqlbuilder.ResolveExprContext(ctx, expr)
    55  	if sqlbuilder.IsNilExpr(e) {
    56  		return nil, nil
    57  	}
    58  	if err := e.Err(); err != nil {
    59  		return nil, err
    60  	}
    61  	if sqlDo := SqlDoFromContext(ctx); sqlDo != nil {
    62  		return sqlDo.QueryContext(ctx, e.Query(), e.Args()...)
    63  	}
    64  	return d.QueryContext(ctx, e.Query(), e.Args()...)
    65  }
    66  
    67  func (d *db) Transaction(ctx context.Context, action func(ctx context.Context) error) (err error) {
    68  	var inScopeOfTxnCreated = false
    69  	var txn *sql.Tx
    70  
    71  	if sqlDo := SqlDoFromContext(ctx); sqlDo != nil {
    72  		if tx, ok := sqlDo.(*sql.Tx); ok {
    73  			txn = tx
    74  		}
    75  	}
    76  
    77  	if txn == nil {
    78  		tx, err := d.BeginTx(ctx, nil)
    79  		if err != nil {
    80  			return err
    81  		}
    82  		inScopeOfTxnCreated = true
    83  		txn = tx
    84  	}
    85  
    86  	defer func() {
    87  		if p := recover(); p != nil {
    88  			// make sure rollack
    89  			_ = txn.Rollback()
    90  
    91  			switch e := p.(type) {
    92  			case runtime.Error:
    93  				panic(e)
    94  			case error:
    95  				err = e
    96  			default:
    97  				panic(e)
    98  			}
    99  		} else if inScopeOfTxnCreated {
   100  			if err != nil {
   101  				_ = txn.Rollback()
   102  			} else {
   103  				err = txn.Commit()
   104  			}
   105  		}
   106  	}()
   107  
   108  	return action(ContextWithSqlDo(ctx, txn))
   109  }