github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/datastore/postgres/common/interceptor.go (about) 1 package common 2 3 import ( 4 "context" 5 6 "github.com/jackc/pgx/v5" 7 "github.com/jackc/pgx/v5/pgconn" 8 ) 9 10 // Querier holds common methods for connections and pools, equivalent to 11 // Querier (which is deprecated for pgx v5) 12 type Querier interface { 13 Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error) 14 Query(ctx context.Context, sql string, optionsAndArgs ...any) (pgx.Rows, error) 15 QueryRow(ctx context.Context, sql string, optionsAndArgs ...any) pgx.Row 16 } 17 18 // ConnPooler is an interface to pgx.Pool methods used by postgres-based datastores 19 type ConnPooler interface { 20 Querier 21 Begin(ctx context.Context) (pgx.Tx, error) 22 BeginTx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx, error) 23 CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) 24 Close() 25 } 26 27 // QueryInterceptor exposes a mechanism to intercept all methods exposed in Querier 28 // This can be used as a sort of middleware layer for pgx queries 29 type QueryInterceptor interface { 30 // InterceptExec is the method to intercept Querier.Exec. The implementation is responsible to invoke the 31 // delegate with the provided arguments 32 InterceptExec(ctx context.Context, delegate Querier, sql string, arguments ...any) (pgconn.CommandTag, error) 33 34 // InterceptQuery is the method to intercept Querier.Query. The implementation is responsible to invoke the 35 // delegate with the provided arguments 36 InterceptQuery(ctx context.Context, delegate Querier, sql string, args ...any) (pgx.Rows, error) 37 38 // InterceptQueryRow is the method to intercept Querier.QueryRow. The implementation is responsible to invoke the 39 // delegate with the provided arguments 40 InterceptQueryRow(ctx context.Context, delegate Querier, sql string, optionsAndArgs ...any) pgx.Row 41 } 42 43 type querierInterceptor struct { 44 delegate Querier 45 interceptor QueryInterceptor 46 } 47 48 func newQuerierInterceptor(delegate Querier, interceptor QueryInterceptor) Querier { 49 if interceptor == nil { 50 return delegate 51 } 52 return querierInterceptor{delegate: delegate, interceptor: interceptor} 53 } 54 55 func (q querierInterceptor) Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error) { 56 return q.interceptor.InterceptExec(ctx, q.delegate, sql, arguments...) 57 } 58 59 func (q querierInterceptor) Query(ctx context.Context, sql string, arguments ...any) (pgx.Rows, error) { 60 return q.interceptor.InterceptQuery(ctx, q.delegate, sql, arguments...) 61 } 62 63 func (q querierInterceptor) QueryRow(ctx context.Context, sql string, arguments ...any) pgx.Row { 64 return q.interceptor.InterceptQueryRow(ctx, q.delegate, sql, arguments...) 65 } 66 67 func newTxInterceptor(interceptor QueryInterceptor) interceptTxFunc { 68 return func(tx pgx.Tx) pgx.Tx { 69 if interceptor == nil { 70 return tx 71 } 72 return txInterceptor{delegate: tx, interceptor: interceptor} 73 } 74 } 75 76 type txInterceptor struct { 77 delegate pgx.Tx 78 interceptor QueryInterceptor 79 } 80 81 type interceptTxFunc func(tx pgx.Tx) pgx.Tx 82 83 func (t txInterceptor) Begin(ctx context.Context) (pgx.Tx, error) { 84 return t.delegate.Begin(ctx) 85 } 86 87 func (t txInterceptor) BeginFunc(ctx context.Context, f func(pgx.Tx) error) (err error) { 88 return pgx.BeginFunc(ctx, t.delegate, f) 89 } 90 91 func (t txInterceptor) Commit(ctx context.Context) error { 92 return t.delegate.Commit(ctx) 93 } 94 95 func (t txInterceptor) Rollback(ctx context.Context) error { 96 return t.delegate.Rollback(ctx) 97 } 98 99 func (t txInterceptor) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) { 100 return t.delegate.CopyFrom(ctx, tableName, columnNames, rowSrc) 101 } 102 103 func (t txInterceptor) SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults { 104 return t.delegate.SendBatch(ctx, b) 105 } 106 107 func (t txInterceptor) LargeObjects() pgx.LargeObjects { 108 return t.delegate.LargeObjects() 109 } 110 111 func (t txInterceptor) Prepare(ctx context.Context, name, sql string) (*pgconn.StatementDescription, error) { 112 return t.delegate.Prepare(ctx, name, sql) 113 } 114 115 func (t txInterceptor) Exec(ctx context.Context, sql string, args ...any) (commandTag pgconn.CommandTag, err error) { 116 return t.interceptor.InterceptExec(ctx, t.delegate, sql, args...) 117 } 118 119 func (t txInterceptor) Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) { 120 return t.interceptor.InterceptQuery(ctx, t.delegate, sql, args...) 121 } 122 123 func (t txInterceptor) QueryRow(ctx context.Context, sql string, args ...any) pgx.Row { 124 return t.interceptor.InterceptQueryRow(ctx, t.delegate, sql, args...) 125 } 126 127 func (t txInterceptor) Conn() *pgx.Conn { 128 return t.delegate.Conn() 129 } 130 131 func MustNewInterceptorPooler(pooler ConnPooler, interceptor QueryInterceptor) ConnPooler { 132 if pooler == nil { 133 panic("unexpected nil ConnPooler") 134 } 135 if interceptor == nil { 136 return pooler 137 } 138 return InterceptorPooler{ 139 delegate: pooler, 140 interceptingQuerier: newQuerierInterceptor(pooler, interceptor), 141 txInterceptor: newTxInterceptor(interceptor), 142 } 143 } 144 145 type InterceptorPooler struct { 146 delegate ConnPooler 147 interceptingQuerier Querier 148 txInterceptor interceptTxFunc 149 } 150 151 func (i InterceptorPooler) Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error) { 152 return i.interceptingQuerier.Exec(ctx, sql, arguments...) 153 } 154 155 func (i InterceptorPooler) Query(ctx context.Context, sql string, optionsAndArgs ...any) (pgx.Rows, error) { 156 return i.interceptingQuerier.Query(ctx, sql, optionsAndArgs...) 157 } 158 159 func (i InterceptorPooler) QueryRow(ctx context.Context, sql string, optionsAndArgs ...any) pgx.Row { 160 return i.interceptingQuerier.QueryRow(ctx, sql, optionsAndArgs...) 161 } 162 163 func (i InterceptorPooler) Begin(ctx context.Context) (pgx.Tx, error) { 164 tx, err := i.delegate.Begin(ctx) 165 if err != nil { 166 return nil, err 167 } 168 return i.txInterceptor(tx), nil 169 } 170 171 func (i InterceptorPooler) BeginTx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx, error) { 172 tx, err := i.delegate.BeginTx(ctx, txOptions) 173 if err != nil { 174 return nil, err 175 } 176 return i.txInterceptor(tx), nil 177 } 178 179 func (i InterceptorPooler) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) { 180 return i.delegate.CopyFrom(ctx, tableName, columnNames, rowSrc) 181 } 182 183 func (i InterceptorPooler) Close() { 184 i.delegate.Close() 185 }