github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/datastore/postgres/common/pgx.go (about)

     1  package common
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"errors"
     7  	"fmt"
     8  	"time"
     9  
    10  	"github.com/exaring/otelpgx"
    11  	"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/retry"
    12  	zerologadapter "github.com/jackc/pgx-zerolog"
    13  	"github.com/jackc/pgx/v5"
    14  	"github.com/jackc/pgx/v5/pgconn"
    15  	"github.com/jackc/pgx/v5/pgxpool"
    16  	"github.com/jackc/pgx/v5/tracelog"
    17  	"github.com/rs/zerolog"
    18  	"go.opentelemetry.io/otel/attribute"
    19  	"go.opentelemetry.io/otel/trace"
    20  
    21  	"github.com/authzed/spicedb/internal/datastore/common"
    22  	log "github.com/authzed/spicedb/internal/logging"
    23  	corev1 "github.com/authzed/spicedb/pkg/proto/core/v1"
    24  )
    25  
    26  const errUnableToQueryTuples = "unable to query tuples: %w"
    27  
    28  // NewPGXExecutor creates an executor that uses the pgx library to make the specified queries.
    29  func NewPGXExecutor(querier DBFuncQuerier) common.ExecuteQueryFunc {
    30  	return func(ctx context.Context, sql string, args []any) ([]*corev1.RelationTuple, error) {
    31  		span := trace.SpanFromContext(ctx)
    32  		return queryTuples(ctx, sql, args, span, querier)
    33  	}
    34  }
    35  
    36  // queryTuples queries tuples for the given query and transaction.
    37  func queryTuples(ctx context.Context, sqlStatement string, args []any, span trace.Span, tx DBFuncQuerier) ([]*corev1.RelationTuple, error) {
    38  	var tuples []*corev1.RelationTuple
    39  	err := tx.QueryFunc(ctx, func(ctx context.Context, rows pgx.Rows) error {
    40  		span.AddEvent("Query issued to database")
    41  
    42  		for rows.Next() {
    43  			nextTuple := &corev1.RelationTuple{
    44  				ResourceAndRelation: &corev1.ObjectAndRelation{},
    45  				Subject:             &corev1.ObjectAndRelation{},
    46  			}
    47  			var caveatName sql.NullString
    48  			var caveatCtx map[string]any
    49  			err := rows.Scan(
    50  				&nextTuple.ResourceAndRelation.Namespace,
    51  				&nextTuple.ResourceAndRelation.ObjectId,
    52  				&nextTuple.ResourceAndRelation.Relation,
    53  				&nextTuple.Subject.Namespace,
    54  				&nextTuple.Subject.ObjectId,
    55  				&nextTuple.Subject.Relation,
    56  				&caveatName,
    57  				&caveatCtx,
    58  			)
    59  			if err != nil {
    60  				return fmt.Errorf(errUnableToQueryTuples, fmt.Errorf("scan err: %w", err))
    61  			}
    62  
    63  			nextTuple.Caveat, err = common.ContextualizedCaveatFrom(caveatName.String, caveatCtx)
    64  			if err != nil {
    65  				return fmt.Errorf(errUnableToQueryTuples, fmt.Errorf("unable to fetch caveat context: %w", err))
    66  			}
    67  			tuples = append(tuples, nextTuple)
    68  		}
    69  		if err := rows.Err(); err != nil {
    70  			return fmt.Errorf(errUnableToQueryTuples, fmt.Errorf("rows err: %w", err))
    71  		}
    72  
    73  		span.AddEvent("Tuples loaded", trace.WithAttributes(attribute.Int("tupleCount", len(tuples))))
    74  		return nil
    75  	}, sqlStatement, args...)
    76  	if err != nil {
    77  		return nil, err
    78  	}
    79  
    80  	return tuples, nil
    81  }
    82  
    83  // ParseConfigWithInstrumentation returns a pgx.ConnConfig that has been instrumented for observability
    84  func ParseConfigWithInstrumentation(url string) (*pgx.ConnConfig, error) {
    85  	connConfig, err := pgx.ParseConfig(url)
    86  	if err != nil {
    87  		return nil, err
    88  	}
    89  
    90  	ConfigurePGXLogger(connConfig)
    91  	ConfigureOTELTracer(connConfig)
    92  
    93  	return connConfig, nil
    94  }
    95  
    96  // ConnectWithInstrumentation returns a pgx.Conn that has been instrumented for observability
    97  func ConnectWithInstrumentation(ctx context.Context, url string) (*pgx.Conn, error) {
    98  	connConfig, err := ParseConfigWithInstrumentation(url)
    99  	if err != nil {
   100  		return nil, err
   101  	}
   102  
   103  	return pgx.ConnectConfig(ctx, connConfig)
   104  }
   105  
   106  // ConfigurePGXLogger sets zerolog global logger into the connection pool configuration, and maps
   107  // info level events to debug, as they are rather verbose for SpiceDB's info level
   108  func ConfigurePGXLogger(connConfig *pgx.ConnConfig) {
   109  	levelMappingFn := func(logger tracelog.Logger) tracelog.LoggerFunc {
   110  		return func(ctx context.Context, level tracelog.LogLevel, msg string, data map[string]interface{}) {
   111  			if level == tracelog.LogLevelInfo {
   112  				level = tracelog.LogLevelDebug
   113  			}
   114  
   115  			truncateLargeSQL(data)
   116  
   117  			// log cancellation and serialization errors at debug level
   118  			if errArg, ok := data["err"]; ok {
   119  				err, ok := errArg.(error)
   120  				if ok && (IsCancellationError(err) || IsSerializationError(err)) {
   121  					logger.Log(ctx, tracelog.LogLevelDebug, msg, data)
   122  					return
   123  				}
   124  			}
   125  
   126  			logger.Log(ctx, level, msg, data)
   127  		}
   128  	}
   129  
   130  	l := zerologadapter.NewLogger(log.Logger, zerologadapter.WithoutPGXModule(), zerologadapter.WithSubDictionary("pgx"),
   131  		zerologadapter.WithContextFunc(func(ctx context.Context, z zerolog.Context) zerolog.Context {
   132  			if logger := log.Ctx(ctx); logger != nil {
   133  				return logger.With()
   134  			}
   135  
   136  			return z
   137  		}))
   138  	addTracer(connConfig, &tracelog.TraceLog{Logger: levelMappingFn(l), LogLevel: tracelog.LogLevelInfo})
   139  }
   140  
   141  // truncateLargeSQL takes arguments of a SQL statement provided via pgx's tracelog.LoggerFunc and
   142  // replaces SQL statements and SQL arguments with placeholders when the statements and/or arguments
   143  // exceed a certain length. This helps de-clutter logs when statements have hundreds to thousands of placeholders.
   144  // The change is done in place.
   145  func truncateLargeSQL(data map[string]any) {
   146  	const (
   147  		maxSQLLen     = 350
   148  		maxSQLArgsLen = 50
   149  	)
   150  
   151  	if sqlData, ok := data["sql"]; ok {
   152  		sqlString, ok := sqlData.(string)
   153  		if ok && len(sqlString) > maxSQLLen {
   154  			data["sql"] = sqlString[:maxSQLLen] + "..."
   155  		}
   156  	}
   157  	if argsData, ok := data["args"]; ok {
   158  		argsSlice, ok := argsData.([]any)
   159  		if ok && len(argsSlice) > maxSQLArgsLen {
   160  			data["args"] = argsSlice[:maxSQLArgsLen]
   161  		}
   162  	}
   163  }
   164  
   165  // IsCancellationError determines if an error returned by pgx has been caused by context cancellation.
   166  func IsCancellationError(err error) bool {
   167  	if errors.Is(err, context.Canceled) ||
   168  		errors.Is(err, context.DeadlineExceeded) ||
   169  		err.Error() == "conn closed" { // conns are sometimes closed async upon cancellation
   170  		return true
   171  	}
   172  	return false
   173  }
   174  
   175  func IsSerializationError(err error) bool {
   176  	var pgerr *pgconn.PgError
   177  	if errors.As(err, &pgerr) &&
   178  		// We need to check unique constraint here because some versions of postgres have an error where
   179  		// unique constraint violations are raised instead of serialization errors.
   180  		// (e.g. https://www.postgresql.org/message-id/flat/CAGPCyEZG76zjv7S31v_xPeLNRuzj-m%3DY2GOY7PEzu7vhB%3DyQog%40mail.gmail.com)
   181  		(pgerr.SQLState() == pgSerializationFailure || pgerr.SQLState() == pgUniqueConstraintViolation || pgerr.SQLState() == pgTransactionAborted) {
   182  		return true
   183  	}
   184  
   185  	if errors.Is(err, pgx.ErrTxCommitRollback) {
   186  		return true
   187  	}
   188  
   189  	return false
   190  }
   191  
   192  // ConfigureOTELTracer adds OTEL tracing to a pgx.ConnConfig
   193  func ConfigureOTELTracer(connConfig *pgx.ConnConfig) {
   194  	addTracer(connConfig, otelpgx.NewTracer(otelpgx.WithTrimSQLInSpanName()))
   195  }
   196  
   197  func addTracer(connConfig *pgx.ConnConfig, tracer pgx.QueryTracer) {
   198  	composedTracer := addComposedTracer(connConfig)
   199  	composedTracer.Tracers = append(composedTracer.Tracers, tracer)
   200  }
   201  
   202  func addComposedTracer(connConfig *pgx.ConnConfig) *ComposedTracer {
   203  	var composedTracer *ComposedTracer
   204  	if connConfig.Tracer == nil {
   205  		composedTracer = &ComposedTracer{}
   206  		connConfig.Tracer = composedTracer
   207  	} else {
   208  		var ok bool
   209  		composedTracer, ok = connConfig.Tracer.(*ComposedTracer)
   210  		if !ok {
   211  			composedTracer.Tracers = append(composedTracer.Tracers, connConfig.Tracer)
   212  			connConfig.Tracer = composedTracer
   213  		}
   214  	}
   215  	return composedTracer
   216  }
   217  
   218  // ComposedTracer allows adding multiple tracers to a pgx.ConnConfig
   219  type ComposedTracer struct {
   220  	Tracers []pgx.QueryTracer
   221  }
   222  
   223  func (m *ComposedTracer) TraceQueryStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context {
   224  	for _, t := range m.Tracers {
   225  		ctx = t.TraceQueryStart(ctx, conn, data)
   226  	}
   227  
   228  	return ctx
   229  }
   230  
   231  func (m *ComposedTracer) TraceQueryEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) {
   232  	for _, t := range m.Tracers {
   233  		t.TraceQueryEnd(ctx, conn, data)
   234  	}
   235  }
   236  
   237  // DBFuncQuerier is satisfied by RetryPool and QuerierFuncs (which can wrap a pgxpool or transaction)
   238  type DBFuncQuerier interface {
   239  	ExecFunc(ctx context.Context, tagFunc func(ctx context.Context, tag pgconn.CommandTag, err error) error, sql string, arguments ...any) error
   240  	QueryFunc(ctx context.Context, rowsFunc func(ctx context.Context, rows pgx.Rows) error, sql string, optionsAndArgs ...any) error
   241  	QueryRowFunc(ctx context.Context, rowFunc func(ctx context.Context, row pgx.Row) error, sql string, optionsAndArgs ...any) error
   242  }
   243  
// PoolOptions is the set of configuration used for a pgx connection pool.
// Every field is optional: when ConfigurePgx applies the options, a nil field
// leaves the corresponding pgxpool.Config value as it was, except where noted.
type PoolOptions struct {
	// ConnMaxIdleTime sets pgxpool.Config.MaxConnIdleTime.
	ConnMaxIdleTime *time.Duration
	// ConnMaxLifetime sets pgxpool.Config.MaxConnLifetime.
	ConnMaxLifetime *time.Duration
	// ConnMaxLifetimeJitter sets pgxpool.Config.MaxConnLifetimeJitter. When nil
	// and ConnMaxLifetime is set, ConfigurePgx uses 20% of ConnMaxLifetime.
	ConnMaxLifetimeJitter *time.Duration
	// ConnHealthCheckInterval sets pgxpool.Config.HealthCheckPeriod.
	ConnHealthCheckInterval *time.Duration
	// MinOpenConns sets pgxpool.Config.MinConns. When nil, ConfigurePgx sets
	// MinConns equal to MaxConns so the pool is kept full.
	MinOpenConns *int
	// MaxOpenConns sets pgxpool.Config.MaxConns.
	MaxOpenConns *int
}
   253  
   254  // ConfigurePgx applies PoolOptions to a pgx connection pool confiugration.
   255  func (opts PoolOptions) ConfigurePgx(pgxConfig *pgxpool.Config) {
   256  	if opts.MaxOpenConns != nil {
   257  		pgxConfig.MaxConns = int32(*opts.MaxOpenConns)
   258  	}
   259  
   260  	// Default to keeping the pool maxed out at all times.
   261  	pgxConfig.MinConns = pgxConfig.MaxConns
   262  	if opts.MinOpenConns != nil {
   263  		pgxConfig.MinConns = int32(*opts.MinOpenConns)
   264  	}
   265  
   266  	if pgxConfig.MaxConns > 0 && pgxConfig.MinConns > 0 && pgxConfig.MaxConns < pgxConfig.MinConns {
   267  		log.Warn().Int32("max-connections", pgxConfig.MaxConns).Int32("min-connections", pgxConfig.MinConns).Msg("maximum number of connections configured is less than minimum number of connections; minimum will be used")
   268  	}
   269  
   270  	if opts.ConnMaxIdleTime != nil {
   271  		pgxConfig.MaxConnIdleTime = *opts.ConnMaxIdleTime
   272  	}
   273  
   274  	if opts.ConnMaxLifetime != nil {
   275  		pgxConfig.MaxConnLifetime = *opts.ConnMaxLifetime
   276  	}
   277  
   278  	if opts.ConnHealthCheckInterval != nil {
   279  		pgxConfig.HealthCheckPeriod = *opts.ConnHealthCheckInterval
   280  	}
   281  
   282  	if opts.ConnMaxLifetimeJitter != nil {
   283  		pgxConfig.MaxConnLifetimeJitter = *opts.ConnMaxLifetimeJitter
   284  	} else if opts.ConnMaxLifetime != nil {
   285  		pgxConfig.MaxConnLifetimeJitter = time.Duration(0.2 * float64(*opts.ConnMaxLifetime))
   286  	}
   287  
   288  	ConfigurePGXLogger(pgxConfig.ConnConfig)
   289  	ConfigureOTELTracer(pgxConfig.ConnConfig)
   290  }
   291  
   292  type QuerierFuncs struct {
   293  	d Querier
   294  }
   295  
   296  func (t *QuerierFuncs) ExecFunc(ctx context.Context, tagFunc func(ctx context.Context, tag pgconn.CommandTag, err error) error, sql string, arguments ...any) error {
   297  	tag, err := t.d.Exec(ctx, sql, arguments...)
   298  	return tagFunc(ctx, tag, err)
   299  }
   300  
   301  func (t *QuerierFuncs) QueryFunc(ctx context.Context, rowsFunc func(ctx context.Context, rows pgx.Rows) error, sql string, optionsAndArgs ...any) error {
   302  	rows, err := t.d.Query(ctx, sql, optionsAndArgs...)
   303  	if err != nil {
   304  		return err
   305  	}
   306  	defer rows.Close()
   307  	return rowsFunc(ctx, rows)
   308  }
   309  
   310  func (t *QuerierFuncs) QueryRowFunc(ctx context.Context, rowFunc func(ctx context.Context, row pgx.Row) error, sql string, optionsAndArgs ...any) error {
   311  	return rowFunc(ctx, t.d.QueryRow(ctx, sql, optionsAndArgs...))
   312  }
   313  
   314  func QuerierFuncsFor(d Querier) DBFuncQuerier {
   315  	return &QuerierFuncs{d: d}
   316  }
   317  
   318  // SleepOnErr sleeps for a short period of time after an error has occurred.
   319  func SleepOnErr(ctx context.Context, err error, retries uint8) {
   320  	after := retry.BackoffExponentialWithJitter(25*time.Millisecond, 0.5)(ctx, uint(retries+1)) // add one so we always wait at least a little bit
   321  	log.Ctx(ctx).Debug().Err(err).Dur("after", after).Uint8("retry", retries+1).Msg("retrying on database error")
   322  
   323  	select {
   324  	case <-time.After(after):
   325  	case <-ctx.Done():
   326  	}
   327  }