github.com/acoshift/pgsql@v0.15.3/pgctx/pgctx.go (about) 1 package pgctx 2 3 import ( 4 "context" 5 "database/sql" 6 "errors" 7 "net/http" 8 9 "github.com/acoshift/pgsql" 10 ) 11 12 type DB interface { 13 Queryer 14 pgsql.BeginTxer 15 } 16 17 // Queryer interface 18 type Queryer interface { 19 QueryRowContext(context.Context, string, ...any) *sql.Row 20 QueryContext(context.Context, string, ...any) (*sql.Rows, error) 21 ExecContext(context.Context, string, ...any) (sql.Result, error) 22 PrepareContext(context.Context, string) (*sql.Stmt, error) 23 } 24 25 func NewKeyContext(ctx context.Context, key any, db DB) context.Context { 26 return context.WithValue(ctx, ctxKeyDB{key}, db) 27 } 28 29 // NewContext creates new context 30 func NewContext(ctx context.Context, db DB) context.Context { 31 return NewKeyContext(ctx, nil, db) 32 } 33 34 func KeyMiddleware(key any, db DB) func(h http.Handler) http.Handler { 35 return func(h http.Handler) http.Handler { 36 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 37 r = r.WithContext(NewKeyContext(r.Context(), key, db)) 38 h.ServeHTTP(w, r) 39 }) 40 } 41 } 42 43 // Middleware injects db into request's context 44 func Middleware(db DB) func(h http.Handler) http.Handler { 45 return KeyMiddleware(nil, db) 46 } 47 48 // With creates new empty key context with db from keyed context 49 func With(ctx context.Context, key any) context.Context { 50 db := ctx.Value(ctxKeyDB{key}) 51 return context.WithValue(ctx, ctxKeyDB{}, db) 52 } 53 54 func GetDB(ctx context.Context) DB { 55 return ctx.Value(ctxKeyDB{}).(DB) 56 } 57 58 func GetDBKey(ctx context.Context, key any) DB { 59 return ctx.Value(ctxKeyDB{key}).(DB) 60 } 61 62 func GetTx(ctx context.Context) *sql.Tx { 63 return ctx.Value(ctxKeyQueryer{}).(*wrapTx).Tx // panic if not in tx 64 } 65 66 type wrapTx struct { 67 *sql.Tx 68 onCommitted []func(ctx context.Context) 69 } 70 71 var _ Queryer = &wrapTx{} 72 73 // RunInTxOptions starts sql tx if not started 74 func RunInTxOptions(ctx context.Context, opt *pgsql.TxOptions, f func(ctx context.Context) error) error { 75 if IsInTx(ctx) { 76 return f(ctx) 77 } 78 79 db := ctx.Value(ctxKeyDB{}).(pgsql.BeginTxer) 80 var pTx wrapTx 81 abort := false 82 err := pgsql.RunInTxContext(ctx, db, opt, func(tx *sql.Tx) error { 83 pTx = wrapTx{Tx: tx} 84 ctx := context.WithValue(ctx, ctxKeyQueryer{}, &pTx) 85 err := f(ctx) 86 if errors.Is(err, pgsql.ErrAbortTx) { 87 abort = true 88 } 89 return err 90 }) 91 if err != nil { 92 return err 93 } 94 if !abort && len(pTx.onCommitted) > 0 { 95 for _, f := range pTx.onCommitted { 96 f(ctx) 97 } 98 } 99 return nil 100 } 101 102 // RunInTx calls RunInTxOptions with default options 103 func RunInTx(ctx context.Context, f func(ctx context.Context) error) error { 104 return RunInTxOptions(ctx, nil, f) 105 } 106 107 // RunInReadOnlyTx calls RunInTxOptions with read only options 108 func RunInReadOnlyTx(ctx context.Context, f func(ctx context.Context) error) error { 109 var opts pgsql.TxOptions 110 opts.TxOptions.ReadOnly = true 111 return RunInTxOptions(ctx, &opts, f) 112 } 113 114 // IsInTx checks is context inside RunInTx 115 func IsInTx(ctx context.Context) bool { 116 _, ok := ctx.Value(ctxKeyQueryer{}).(*wrapTx) 117 return ok 118 } 119 120 // Committed calls f after committed or immediate if not in tx 121 func Committed(ctx context.Context, f func(ctx context.Context)) { 122 if f == nil { 123 return 124 } 125 126 if !IsInTx(ctx) { 127 f(ctx) 128 return 129 } 130 131 pTx := ctx.Value(ctxKeyQueryer{}).(*wrapTx) 132 pTx.onCommitted = append(pTx.onCommitted, f) 133 } 134 135 type ( 136 ctxKeyDB struct { 137 key any 138 } 139 ctxKeyQueryer struct{} 140 ) 141 142 func q(ctx context.Context) Queryer { 143 if q, ok := ctx.Value(ctxKeyQueryer{}).(Queryer); ok { 144 return q 145 } 146 return ctx.Value(ctxKeyDB{}).(Queryer) 147 } 148 149 // QueryRow calls db.QueryRowContext 150 func QueryRow(ctx context.Context, query string, args ...any) *pgsql.Row { 151 return &pgsql.Row{q(ctx).QueryRowContext(ctx, query, args...)} 152 } 153 154 // Query calls db.QueryContext 155 func Query(ctx context.Context, query string, args ...any) (*pgsql.Rows, error) { 156 rows, err := q(ctx).QueryContext(ctx, query, args...) 157 if err != nil { 158 return nil, err 159 } 160 return &pgsql.Rows{rows}, nil 161 } 162 163 // Exec calls db.ExecContext 164 func Exec(ctx context.Context, query string, args ...any) (sql.Result, error) { 165 return q(ctx).ExecContext(ctx, query, args...) 166 } 167 168 // Iter calls pgsql.IterContext 169 func Iter(ctx context.Context, iter pgsql.Iterator, query string, args ...any) error { 170 return pgsql.IterContext(ctx, q(ctx), iter, query, args...) 171 } 172 173 // Prepare calls db.PrepareContext 174 func Prepare(ctx context.Context, query string) (*sql.Stmt, error) { 175 return q(ctx).PrepareContext(ctx, query) 176 }