github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/datastore/postgres/common/pgx.go (about) 1 package common 2 3 import ( 4 "context" 5 "database/sql" 6 "errors" 7 "fmt" 8 "time" 9 10 "github.com/exaring/otelpgx" 11 "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/retry" 12 zerologadapter "github.com/jackc/pgx-zerolog" 13 "github.com/jackc/pgx/v5" 14 "github.com/jackc/pgx/v5/pgconn" 15 "github.com/jackc/pgx/v5/pgxpool" 16 "github.com/jackc/pgx/v5/tracelog" 17 "github.com/rs/zerolog" 18 "go.opentelemetry.io/otel/attribute" 19 "go.opentelemetry.io/otel/trace" 20 21 "github.com/authzed/spicedb/internal/datastore/common" 22 log "github.com/authzed/spicedb/internal/logging" 23 corev1 "github.com/authzed/spicedb/pkg/proto/core/v1" 24 ) 25 26 const errUnableToQueryTuples = "unable to query tuples: %w" 27 28 // NewPGXExecutor creates an executor that uses the pgx library to make the specified queries. 29 func NewPGXExecutor(querier DBFuncQuerier) common.ExecuteQueryFunc { 30 return func(ctx context.Context, sql string, args []any) ([]*corev1.RelationTuple, error) { 31 span := trace.SpanFromContext(ctx) 32 return queryTuples(ctx, sql, args, span, querier) 33 } 34 } 35 36 // queryTuples queries tuples for the given query and transaction. 37 func queryTuples(ctx context.Context, sqlStatement string, args []any, span trace.Span, tx DBFuncQuerier) ([]*corev1.RelationTuple, error) { 38 var tuples []*corev1.RelationTuple 39 err := tx.QueryFunc(ctx, func(ctx context.Context, rows pgx.Rows) error { 40 span.AddEvent("Query issued to database") 41 42 for rows.Next() { 43 nextTuple := &corev1.RelationTuple{ 44 ResourceAndRelation: &corev1.ObjectAndRelation{}, 45 Subject: &corev1.ObjectAndRelation{}, 46 } 47 var caveatName sql.NullString 48 var caveatCtx map[string]any 49 err := rows.Scan( 50 &nextTuple.ResourceAndRelation.Namespace, 51 &nextTuple.ResourceAndRelation.ObjectId, 52 &nextTuple.ResourceAndRelation.Relation, 53 &nextTuple.Subject.Namespace, 54 &nextTuple.Subject.ObjectId, 55 &nextTuple.Subject.Relation, 56 &caveatName, 57 &caveatCtx, 58 ) 59 if err != nil { 60 return fmt.Errorf(errUnableToQueryTuples, fmt.Errorf("scan err: %w", err)) 61 } 62 63 nextTuple.Caveat, err = common.ContextualizedCaveatFrom(caveatName.String, caveatCtx) 64 if err != nil { 65 return fmt.Errorf(errUnableToQueryTuples, fmt.Errorf("unable to fetch caveat context: %w", err)) 66 } 67 tuples = append(tuples, nextTuple) 68 } 69 if err := rows.Err(); err != nil { 70 return fmt.Errorf(errUnableToQueryTuples, fmt.Errorf("rows err: %w", err)) 71 } 72 73 span.AddEvent("Tuples loaded", trace.WithAttributes(attribute.Int("tupleCount", len(tuples)))) 74 return nil 75 }, sqlStatement, args...) 76 if err != nil { 77 return nil, err 78 } 79 80 return tuples, nil 81 } 82 83 // ParseConfigWithInstrumentation returns a pgx.ConnConfig that has been instrumented for observability 84 func ParseConfigWithInstrumentation(url string) (*pgx.ConnConfig, error) { 85 connConfig, err := pgx.ParseConfig(url) 86 if err != nil { 87 return nil, err 88 } 89 90 ConfigurePGXLogger(connConfig) 91 ConfigureOTELTracer(connConfig) 92 93 return connConfig, nil 94 } 95 96 // ConnectWithInstrumentation returns a pgx.Conn that has been instrumented for observability 97 func ConnectWithInstrumentation(ctx context.Context, url string) (*pgx.Conn, error) { 98 connConfig, err := ParseConfigWithInstrumentation(url) 99 if err != nil { 100 return nil, err 101 } 102 103 return pgx.ConnectConfig(ctx, connConfig) 104 } 105 106 // ConfigurePGXLogger sets zerolog global logger into the connection pool configuration, and maps 107 // info level events to debug, as they are rather verbose for SpiceDB's info level 108 func ConfigurePGXLogger(connConfig *pgx.ConnConfig) { 109 levelMappingFn := func(logger tracelog.Logger) tracelog.LoggerFunc { 110 return func(ctx context.Context, level tracelog.LogLevel, msg string, data map[string]interface{}) { 111 if level == tracelog.LogLevelInfo { 112 level = tracelog.LogLevelDebug 113 } 114 115 truncateLargeSQL(data) 116 117 // log cancellation and serialization errors at debug level 118 if errArg, ok := data["err"]; ok { 119 err, ok := errArg.(error) 120 if ok && (IsCancellationError(err) || IsSerializationError(err)) { 121 logger.Log(ctx, tracelog.LogLevelDebug, msg, data) 122 return 123 } 124 } 125 126 logger.Log(ctx, level, msg, data) 127 } 128 } 129 130 l := zerologadapter.NewLogger(log.Logger, zerologadapter.WithoutPGXModule(), zerologadapter.WithSubDictionary("pgx"), 131 zerologadapter.WithContextFunc(func(ctx context.Context, z zerolog.Context) zerolog.Context { 132 if logger := log.Ctx(ctx); logger != nil { 133 return logger.With() 134 } 135 136 return z 137 })) 138 addTracer(connConfig, &tracelog.TraceLog{Logger: levelMappingFn(l), LogLevel: tracelog.LogLevelInfo}) 139 } 140 141 // truncateLargeSQL takes arguments of a SQL statement provided via pgx's tracelog.LoggerFunc and 142 // replaces SQL statements and SQL arguments with placeholders when the statements and/or arguments 143 // exceed a certain length. This helps de-clutter logs when statements have hundreds to thousands of placeholders. 144 // The change is done in place. 145 func truncateLargeSQL(data map[string]any) { 146 const ( 147 maxSQLLen = 350 148 maxSQLArgsLen = 50 149 ) 150 151 if sqlData, ok := data["sql"]; ok { 152 sqlString, ok := sqlData.(string) 153 if ok && len(sqlString) > maxSQLLen { 154 data["sql"] = sqlString[:maxSQLLen] + "..." 155 } 156 } 157 if argsData, ok := data["args"]; ok { 158 argsSlice, ok := argsData.([]any) 159 if ok && len(argsSlice) > maxSQLArgsLen { 160 data["args"] = argsSlice[:maxSQLArgsLen] 161 } 162 } 163 } 164 165 // IsCancellationError determines if an error returned by pgx has been caused by context cancellation. 166 func IsCancellationError(err error) bool { 167 if errors.Is(err, context.Canceled) || 168 errors.Is(err, context.DeadlineExceeded) || 169 err.Error() == "conn closed" { // conns are sometimes closed async upon cancellation 170 return true 171 } 172 return false 173 } 174 175 func IsSerializationError(err error) bool { 176 var pgerr *pgconn.PgError 177 if errors.As(err, &pgerr) && 178 // We need to check unique constraint here because some versions of postgres have an error where 179 // unique constraint violations are raised instead of serialization errors. 180 // (e.g. https://www.postgresql.org/message-id/flat/CAGPCyEZG76zjv7S31v_xPeLNRuzj-m%3DY2GOY7PEzu7vhB%3DyQog%40mail.gmail.com) 181 (pgerr.SQLState() == pgSerializationFailure || pgerr.SQLState() == pgUniqueConstraintViolation || pgerr.SQLState() == pgTransactionAborted) { 182 return true 183 } 184 185 if errors.Is(err, pgx.ErrTxCommitRollback) { 186 return true 187 } 188 189 return false 190 } 191 192 // ConfigureOTELTracer adds OTEL tracing to a pgx.ConnConfig 193 func ConfigureOTELTracer(connConfig *pgx.ConnConfig) { 194 addTracer(connConfig, otelpgx.NewTracer(otelpgx.WithTrimSQLInSpanName())) 195 } 196 197 func addTracer(connConfig *pgx.ConnConfig, tracer pgx.QueryTracer) { 198 composedTracer := addComposedTracer(connConfig) 199 composedTracer.Tracers = append(composedTracer.Tracers, tracer) 200 } 201 202 func addComposedTracer(connConfig *pgx.ConnConfig) *ComposedTracer { 203 var composedTracer *ComposedTracer 204 if connConfig.Tracer == nil { 205 composedTracer = &ComposedTracer{} 206 connConfig.Tracer = composedTracer 207 } else { 208 var ok bool 209 composedTracer, ok = connConfig.Tracer.(*ComposedTracer) 210 if !ok { 211 composedTracer.Tracers = append(composedTracer.Tracers, connConfig.Tracer) 212 connConfig.Tracer = composedTracer 213 } 214 } 215 return composedTracer 216 } 217 218 // ComposedTracer allows adding multiple tracers to a pgx.ConnConfig 219 type ComposedTracer struct { 220 Tracers []pgx.QueryTracer 221 } 222 223 func (m *ComposedTracer) TraceQueryStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context { 224 for _, t := range m.Tracers { 225 ctx = t.TraceQueryStart(ctx, conn, data) 226 } 227 228 return ctx 229 } 230 231 func (m *ComposedTracer) TraceQueryEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) { 232 for _, t := range m.Tracers { 233 t.TraceQueryEnd(ctx, conn, data) 234 } 235 } 236 237 // DBFuncQuerier is satisfied by RetryPool and QuerierFuncs (which can wrap a pgxpool or transaction) 238 type DBFuncQuerier interface { 239 ExecFunc(ctx context.Context, tagFunc func(ctx context.Context, tag pgconn.CommandTag, err error) error, sql string, arguments ...any) error 240 QueryFunc(ctx context.Context, rowsFunc func(ctx context.Context, rows pgx.Rows) error, sql string, optionsAndArgs ...any) error 241 QueryRowFunc(ctx context.Context, rowFunc func(ctx context.Context, row pgx.Row) error, sql string, optionsAndArgs ...any) error 242 } 243 244 // PoolOptions is the set of configuration used for a pgx connection pool. 245 type PoolOptions struct { 246 ConnMaxIdleTime *time.Duration 247 ConnMaxLifetime *time.Duration 248 ConnMaxLifetimeJitter *time.Duration 249 ConnHealthCheckInterval *time.Duration 250 MinOpenConns *int 251 MaxOpenConns *int 252 } 253 254 // ConfigurePgx applies PoolOptions to a pgx connection pool confiugration. 255 func (opts PoolOptions) ConfigurePgx(pgxConfig *pgxpool.Config) { 256 if opts.MaxOpenConns != nil { 257 pgxConfig.MaxConns = int32(*opts.MaxOpenConns) 258 } 259 260 // Default to keeping the pool maxed out at all times. 261 pgxConfig.MinConns = pgxConfig.MaxConns 262 if opts.MinOpenConns != nil { 263 pgxConfig.MinConns = int32(*opts.MinOpenConns) 264 } 265 266 if pgxConfig.MaxConns > 0 && pgxConfig.MinConns > 0 && pgxConfig.MaxConns < pgxConfig.MinConns { 267 log.Warn().Int32("max-connections", pgxConfig.MaxConns).Int32("min-connections", pgxConfig.MinConns).Msg("maximum number of connections configured is less than minimum number of connections; minimum will be used") 268 } 269 270 if opts.ConnMaxIdleTime != nil { 271 pgxConfig.MaxConnIdleTime = *opts.ConnMaxIdleTime 272 } 273 274 if opts.ConnMaxLifetime != nil { 275 pgxConfig.MaxConnLifetime = *opts.ConnMaxLifetime 276 } 277 278 if opts.ConnHealthCheckInterval != nil { 279 pgxConfig.HealthCheckPeriod = *opts.ConnHealthCheckInterval 280 } 281 282 if opts.ConnMaxLifetimeJitter != nil { 283 pgxConfig.MaxConnLifetimeJitter = *opts.ConnMaxLifetimeJitter 284 } else if opts.ConnMaxLifetime != nil { 285 pgxConfig.MaxConnLifetimeJitter = time.Duration(0.2 * float64(*opts.ConnMaxLifetime)) 286 } 287 288 ConfigurePGXLogger(pgxConfig.ConnConfig) 289 ConfigureOTELTracer(pgxConfig.ConnConfig) 290 } 291 292 type QuerierFuncs struct { 293 d Querier 294 } 295 296 func (t *QuerierFuncs) ExecFunc(ctx context.Context, tagFunc func(ctx context.Context, tag pgconn.CommandTag, err error) error, sql string, arguments ...any) error { 297 tag, err := t.d.Exec(ctx, sql, arguments...) 298 return tagFunc(ctx, tag, err) 299 } 300 301 func (t *QuerierFuncs) QueryFunc(ctx context.Context, rowsFunc func(ctx context.Context, rows pgx.Rows) error, sql string, optionsAndArgs ...any) error { 302 rows, err := t.d.Query(ctx, sql, optionsAndArgs...) 303 if err != nil { 304 return err 305 } 306 defer rows.Close() 307 return rowsFunc(ctx, rows) 308 } 309 310 func (t *QuerierFuncs) QueryRowFunc(ctx context.Context, rowFunc func(ctx context.Context, row pgx.Row) error, sql string, optionsAndArgs ...any) error { 311 return rowFunc(ctx, t.d.QueryRow(ctx, sql, optionsAndArgs...)) 312 } 313 314 func QuerierFuncsFor(d Querier) DBFuncQuerier { 315 return &QuerierFuncs{d: d} 316 } 317 318 // SleepOnErr sleeps for a short period of time after an error has occurred. 319 func SleepOnErr(ctx context.Context, err error, retries uint8) { 320 after := retry.BackoffExponentialWithJitter(25*time.Millisecond, 0.5)(ctx, uint(retries+1)) // add one so we always wait at least a little bit 321 log.Ctx(ctx).Debug().Err(err).Dur("after", after).Uint8("retry", retries+1).Msg("retrying on database error") 322 323 select { 324 case <-time.After(after): 325 case <-ctx.Done(): 326 } 327 }