github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/datastore/postgres/common/interceptor.go (about)

     1  package common
     2  
     3  import (
     4  	"context"
     5  
     6  	"github.com/jackc/pgx/v5"
     7  	"github.com/jackc/pgx/v5/pgconn"
     8  )
     9  
    10  // Querier holds common methods for connections and pools, equivalent to
    11  // Querier (which is deprecated for pgx v5)
    12  type Querier interface {
    13  	Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error)
    14  	Query(ctx context.Context, sql string, optionsAndArgs ...any) (pgx.Rows, error)
    15  	QueryRow(ctx context.Context, sql string, optionsAndArgs ...any) pgx.Row
    16  }
    17  
    18  // ConnPooler is an interface to pgx.Pool methods used by postgres-based datastores
    19  type ConnPooler interface {
    20  	Querier
    21  	Begin(ctx context.Context) (pgx.Tx, error)
    22  	BeginTx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx, error)
    23  	CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error)
    24  	Close()
    25  }
    26  
    27  // QueryInterceptor exposes a mechanism to intercept all methods exposed in Querier
    28  // This can be used as a sort of middleware layer for pgx queries
    29  type QueryInterceptor interface {
    30  	// InterceptExec is the method to intercept Querier.Exec. The implementation is responsible to invoke the
    31  	// delegate with the provided arguments
    32  	InterceptExec(ctx context.Context, delegate Querier, sql string, arguments ...any) (pgconn.CommandTag, error)
    33  
    34  	// InterceptQuery is the method to intercept Querier.Query. The implementation is responsible to invoke the
    35  	// delegate with the provided arguments
    36  	InterceptQuery(ctx context.Context, delegate Querier, sql string, args ...any) (pgx.Rows, error)
    37  
    38  	// InterceptQueryRow is the method to intercept Querier.QueryRow. The implementation is responsible to invoke the
    39  	// delegate with the provided arguments
    40  	InterceptQueryRow(ctx context.Context, delegate Querier, sql string, optionsAndArgs ...any) pgx.Row
    41  }
    42  
    43  type querierInterceptor struct {
    44  	delegate    Querier
    45  	interceptor QueryInterceptor
    46  }
    47  
    48  func newQuerierInterceptor(delegate Querier, interceptor QueryInterceptor) Querier {
    49  	if interceptor == nil {
    50  		return delegate
    51  	}
    52  	return querierInterceptor{delegate: delegate, interceptor: interceptor}
    53  }
    54  
    55  func (q querierInterceptor) Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error) {
    56  	return q.interceptor.InterceptExec(ctx, q.delegate, sql, arguments...)
    57  }
    58  
    59  func (q querierInterceptor) Query(ctx context.Context, sql string, arguments ...any) (pgx.Rows, error) {
    60  	return q.interceptor.InterceptQuery(ctx, q.delegate, sql, arguments...)
    61  }
    62  
    63  func (q querierInterceptor) QueryRow(ctx context.Context, sql string, arguments ...any) pgx.Row {
    64  	return q.interceptor.InterceptQueryRow(ctx, q.delegate, sql, arguments...)
    65  }
    66  
    67  func newTxInterceptor(interceptor QueryInterceptor) interceptTxFunc {
    68  	return func(tx pgx.Tx) pgx.Tx {
    69  		if interceptor == nil {
    70  			return tx
    71  		}
    72  		return txInterceptor{delegate: tx, interceptor: interceptor}
    73  	}
    74  }
    75  
    76  type txInterceptor struct {
    77  	delegate    pgx.Tx
    78  	interceptor QueryInterceptor
    79  }
    80  
    81  type interceptTxFunc func(tx pgx.Tx) pgx.Tx
    82  
    83  func (t txInterceptor) Begin(ctx context.Context) (pgx.Tx, error) {
    84  	return t.delegate.Begin(ctx)
    85  }
    86  
    87  func (t txInterceptor) BeginFunc(ctx context.Context, f func(pgx.Tx) error) (err error) {
    88  	return pgx.BeginFunc(ctx, t.delegate, f)
    89  }
    90  
    91  func (t txInterceptor) Commit(ctx context.Context) error {
    92  	return t.delegate.Commit(ctx)
    93  }
    94  
    95  func (t txInterceptor) Rollback(ctx context.Context) error {
    96  	return t.delegate.Rollback(ctx)
    97  }
    98  
    99  func (t txInterceptor) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) {
   100  	return t.delegate.CopyFrom(ctx, tableName, columnNames, rowSrc)
   101  }
   102  
   103  func (t txInterceptor) SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults {
   104  	return t.delegate.SendBatch(ctx, b)
   105  }
   106  
   107  func (t txInterceptor) LargeObjects() pgx.LargeObjects {
   108  	return t.delegate.LargeObjects()
   109  }
   110  
   111  func (t txInterceptor) Prepare(ctx context.Context, name, sql string) (*pgconn.StatementDescription, error) {
   112  	return t.delegate.Prepare(ctx, name, sql)
   113  }
   114  
   115  func (t txInterceptor) Exec(ctx context.Context, sql string, args ...any) (commandTag pgconn.CommandTag, err error) {
   116  	return t.interceptor.InterceptExec(ctx, t.delegate, sql, args...)
   117  }
   118  
   119  func (t txInterceptor) Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) {
   120  	return t.interceptor.InterceptQuery(ctx, t.delegate, sql, args...)
   121  }
   122  
   123  func (t txInterceptor) QueryRow(ctx context.Context, sql string, args ...any) pgx.Row {
   124  	return t.interceptor.InterceptQueryRow(ctx, t.delegate, sql, args...)
   125  }
   126  
   127  func (t txInterceptor) Conn() *pgx.Conn {
   128  	return t.delegate.Conn()
   129  }
   130  
   131  func MustNewInterceptorPooler(pooler ConnPooler, interceptor QueryInterceptor) ConnPooler {
   132  	if pooler == nil {
   133  		panic("unexpected nil ConnPooler")
   134  	}
   135  	if interceptor == nil {
   136  		return pooler
   137  	}
   138  	return InterceptorPooler{
   139  		delegate:            pooler,
   140  		interceptingQuerier: newQuerierInterceptor(pooler, interceptor),
   141  		txInterceptor:       newTxInterceptor(interceptor),
   142  	}
   143  }
   144  
   145  type InterceptorPooler struct {
   146  	delegate            ConnPooler
   147  	interceptingQuerier Querier
   148  	txInterceptor       interceptTxFunc
   149  }
   150  
   151  func (i InterceptorPooler) Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error) {
   152  	return i.interceptingQuerier.Exec(ctx, sql, arguments...)
   153  }
   154  
   155  func (i InterceptorPooler) Query(ctx context.Context, sql string, optionsAndArgs ...any) (pgx.Rows, error) {
   156  	return i.interceptingQuerier.Query(ctx, sql, optionsAndArgs...)
   157  }
   158  
   159  func (i InterceptorPooler) QueryRow(ctx context.Context, sql string, optionsAndArgs ...any) pgx.Row {
   160  	return i.interceptingQuerier.QueryRow(ctx, sql, optionsAndArgs...)
   161  }
   162  
   163  func (i InterceptorPooler) Begin(ctx context.Context) (pgx.Tx, error) {
   164  	tx, err := i.delegate.Begin(ctx)
   165  	if err != nil {
   166  		return nil, err
   167  	}
   168  	return i.txInterceptor(tx), nil
   169  }
   170  
   171  func (i InterceptorPooler) BeginTx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx, error) {
   172  	tx, err := i.delegate.BeginTx(ctx, txOptions)
   173  	if err != nil {
   174  		return nil, err
   175  	}
   176  	return i.txInterceptor(tx), nil
   177  }
   178  
   179  func (i InterceptorPooler) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) {
   180  	return i.delegate.CopyFrom(ctx, tableName, columnNames, rowSrc)
   181  }
   182  
   183  func (i InterceptorPooler) Close() {
   184  	i.delegate.Close()
   185  }