github.com/acoshift/pgsql@v0.15.3/pgctx/pgctx.go (about)

     1  package pgctx
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"errors"
     7  	"net/http"
     8  
     9  	"github.com/acoshift/pgsql"
    10  )
    11  
    12  type DB interface {
    13  	Queryer
    14  	pgsql.BeginTxer
    15  }
    16  
    17  // Queryer interface
    18  type Queryer interface {
    19  	QueryRowContext(context.Context, string, ...any) *sql.Row
    20  	QueryContext(context.Context, string, ...any) (*sql.Rows, error)
    21  	ExecContext(context.Context, string, ...any) (sql.Result, error)
    22  	PrepareContext(context.Context, string) (*sql.Stmt, error)
    23  }
    24  
    25  func NewKeyContext(ctx context.Context, key any, db DB) context.Context {
    26  	return context.WithValue(ctx, ctxKeyDB{key}, db)
    27  }
    28  
    29  // NewContext creates new context
    30  func NewContext(ctx context.Context, db DB) context.Context {
    31  	return NewKeyContext(ctx, nil, db)
    32  }
    33  
    34  func KeyMiddleware(key any, db DB) func(h http.Handler) http.Handler {
    35  	return func(h http.Handler) http.Handler {
    36  		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    37  			r = r.WithContext(NewKeyContext(r.Context(), key, db))
    38  			h.ServeHTTP(w, r)
    39  		})
    40  	}
    41  }
    42  
    43  // Middleware injects db into request's context
    44  func Middleware(db DB) func(h http.Handler) http.Handler {
    45  	return KeyMiddleware(nil, db)
    46  }
    47  
    48  // With creates new empty key context with db from keyed context
    49  func With(ctx context.Context, key any) context.Context {
    50  	db := ctx.Value(ctxKeyDB{key})
    51  	return context.WithValue(ctx, ctxKeyDB{}, db)
    52  }
    53  
    54  func GetDB(ctx context.Context) DB {
    55  	return ctx.Value(ctxKeyDB{}).(DB)
    56  }
    57  
    58  func GetDBKey(ctx context.Context, key any) DB {
    59  	return ctx.Value(ctxKeyDB{key}).(DB)
    60  }
    61  
    62  func GetTx(ctx context.Context) *sql.Tx {
    63  	return ctx.Value(ctxKeyQueryer{}).(*wrapTx).Tx // panic if not in tx
    64  }
    65  
    66  type wrapTx struct {
    67  	*sql.Tx
    68  	onCommitted []func(ctx context.Context)
    69  }
    70  
    71  var _ Queryer = &wrapTx{}
    72  
    73  // RunInTxOptions starts sql tx if not started
    74  func RunInTxOptions(ctx context.Context, opt *pgsql.TxOptions, f func(ctx context.Context) error) error {
    75  	if IsInTx(ctx) {
    76  		return f(ctx)
    77  	}
    78  
    79  	db := ctx.Value(ctxKeyDB{}).(pgsql.BeginTxer)
    80  	var pTx wrapTx
    81  	abort := false
    82  	err := pgsql.RunInTxContext(ctx, db, opt, func(tx *sql.Tx) error {
    83  		pTx = wrapTx{Tx: tx}
    84  		ctx := context.WithValue(ctx, ctxKeyQueryer{}, &pTx)
    85  		err := f(ctx)
    86  		if errors.Is(err, pgsql.ErrAbortTx) {
    87  			abort = true
    88  		}
    89  		return err
    90  	})
    91  	if err != nil {
    92  		return err
    93  	}
    94  	if !abort && len(pTx.onCommitted) > 0 {
    95  		for _, f := range pTx.onCommitted {
    96  			f(ctx)
    97  		}
    98  	}
    99  	return nil
   100  }
   101  
   102  // RunInTx calls RunInTxOptions with default options
   103  func RunInTx(ctx context.Context, f func(ctx context.Context) error) error {
   104  	return RunInTxOptions(ctx, nil, f)
   105  }
   106  
   107  // RunInReadOnlyTx calls RunInTxOptions with read only options
   108  func RunInReadOnlyTx(ctx context.Context, f func(ctx context.Context) error) error {
   109  	var opts pgsql.TxOptions
   110  	opts.TxOptions.ReadOnly = true
   111  	return RunInTxOptions(ctx, &opts, f)
   112  }
   113  
   114  // IsInTx checks is context inside RunInTx
   115  func IsInTx(ctx context.Context) bool {
   116  	_, ok := ctx.Value(ctxKeyQueryer{}).(*wrapTx)
   117  	return ok
   118  }
   119  
   120  // Committed calls f after committed or immediate if not in tx
   121  func Committed(ctx context.Context, f func(ctx context.Context)) {
   122  	if f == nil {
   123  		return
   124  	}
   125  
   126  	if !IsInTx(ctx) {
   127  		f(ctx)
   128  		return
   129  	}
   130  
   131  	pTx := ctx.Value(ctxKeyQueryer{}).(*wrapTx)
   132  	pTx.onCommitted = append(pTx.onCommitted, f)
   133  }
   134  
   135  type (
   136  	ctxKeyDB struct {
   137  		key any
   138  	}
   139  	ctxKeyQueryer struct{}
   140  )
   141  
   142  func q(ctx context.Context) Queryer {
   143  	if q, ok := ctx.Value(ctxKeyQueryer{}).(Queryer); ok {
   144  		return q
   145  	}
   146  	return ctx.Value(ctxKeyDB{}).(Queryer)
   147  }
   148  
   149  // QueryRow calls db.QueryRowContext
   150  func QueryRow(ctx context.Context, query string, args ...any) *pgsql.Row {
   151  	return &pgsql.Row{q(ctx).QueryRowContext(ctx, query, args...)}
   152  }
   153  
   154  // Query calls db.QueryContext
   155  func Query(ctx context.Context, query string, args ...any) (*pgsql.Rows, error) {
   156  	rows, err := q(ctx).QueryContext(ctx, query, args...)
   157  	if err != nil {
   158  		return nil, err
   159  	}
   160  	return &pgsql.Rows{rows}, nil
   161  }
   162  
   163  // Exec calls db.ExecContext
   164  func Exec(ctx context.Context, query string, args ...any) (sql.Result, error) {
   165  	return q(ctx).ExecContext(ctx, query, args...)
   166  }
   167  
   168  // Iter calls pgsql.IterContext
   169  func Iter(ctx context.Context, iter pgsql.Iterator, query string, args ...any) error {
   170  	return pgsql.IterContext(ctx, q(ctx), iter, query, args...)
   171  }
   172  
   173  // Prepare calls db.PrepareContext
   174  func Prepare(ctx context.Context, query string) (*sql.Stmt, error) {
   175  	return q(ctx).PrepareContext(ctx, query)
   176  }