github.com/treeverse/lakefs@v1.24.1-0.20240520134607-95648127bfb0/pkg/kv/postgres/store.go (about)

     1  package postgres
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"errors"
     7  	"fmt"
     8  	"hash/fnv"
     9  	"strconv"
    10  
    11  	"github.com/IBM/pgxpoolprometheus"
    12  	"github.com/georgysavva/scany/v2/pgxscan"
    13  	"github.com/jackc/pgx/v5"
    14  	"github.com/jackc/pgx/v5/pgconn"
    15  	"github.com/jackc/pgx/v5/pgxpool"
    16  	"github.com/prometheus/client_golang/prometheus"
    17  	"github.com/treeverse/lakefs/pkg/kv"
    18  	"github.com/treeverse/lakefs/pkg/kv/kvparams"
    19  )
    20  
// Driver opens postgres-backed kv stores. It is registered with the kv
// package under DriverName from this package's init function.
type Driver struct{}
    22  
// Store is the kv.Store returned by Driver.Open: all operations are executed
// against a single kv table through a pgx connection pool.
type Store struct {
	// Pool is the connection pool used by every operation; Close closes it.
	Pool           *pgxpool.Pool
	// Params holds the settings resolved by parseStoreConfig when opening.
	Params         *Params
	// TableSanitized is the kv table name quoted via pgx.Identifier.Sanitize.
	TableSanitized string
	// collector publishes pool statistics as Prometheus metrics. It is nil
	// unless Params.Metrics was set when opening; Close unregisters it.
	collector      prometheus.Collector
}
    29  
// EntriesIterator walks the entries of one partition in key order, fetching
// them from the database one page (at most limit rows) at a time.
type EntriesIterator struct {
	// ctx is used for every page query the iterator issues.
	ctx          context.Context
	// partitionKey is the partition being scanned.
	partitionKey []byte
	// startKey is where the next page query begins; nil means the first key
	// of the partition.
	startKey     []byte
	// includeStart selects ">=" (true) or ">" (false) when comparing keys
	// against startKey.
	includeStart bool
	store        *Store
	// entries is the currently buffered page.
	entries      []kv.Entry
	// limit is the maximum number of rows fetched by one page query.
	limit        int
	// currEntryIdx indexes the current entry within entries; -1 means
	// "before the first entry of the page".
	currEntryIdx int
	// err is sticky: once set, Next returns false. Close sets it to
	// kv.ErrClosedEntries.
	err          error
}
    41  
    42  const (
    43  	DriverName = "postgres"
    44  
    45  	DefaultTableName = "kv"
    46  	paramTableName   = "lakefskv_table"
    47  
    48  	// DefaultPartitions Changing the below value means repartitioning and probably a migration.
    49  	// Change it only if you really know what you're doing.
    50  	DefaultPartitions   = 100
    51  	DefaultScanPageSize = 1000
    52  )
    53  
    54  //nolint:gochecknoinits
    55  func init() {
    56  	kv.Register(DriverName, &Driver{})
    57  }
    58  
    59  func (d *Driver) Open(ctx context.Context, kvParams kvparams.Config) (kv.Store, error) {
    60  	if kvParams.Postgres == nil {
    61  		return nil, fmt.Errorf("missing %s settings: %w", DriverName, kv.ErrDriverConfiguration)
    62  	}
    63  	config, err := newPgxpoolConfig(kvParams)
    64  	if err != nil {
    65  		return nil, err
    66  	}
    67  
    68  	pool, err := pgxpool.NewWithConfig(ctx, config)
    69  	if err != nil {
    70  		return nil, fmt.Errorf("%w: %s", kv.ErrConnectFailed, err)
    71  	}
    72  	defer func() {
    73  		// if we return before store uses the pool, free it
    74  		if pool != nil {
    75  			pool.Close()
    76  		}
    77  	}()
    78  
    79  	// acquire connection and make sure we reach the database
    80  	conn, err := pool.Acquire(ctx)
    81  	if err != nil {
    82  		return nil, fmt.Errorf("%w: %s", kv.ErrConnectFailed, err)
    83  	}
    84  	defer conn.Release()
    85  	err = conn.Conn().Ping(ctx)
    86  	if err != nil {
    87  		return nil, fmt.Errorf("%w: %s", kv.ErrConnectFailed, err)
    88  	}
    89  
    90  	params := parseStoreConfig(config.ConnConfig.RuntimeParams, kvParams.Postgres)
    91  	err = setupKeyValueDatabase(ctx, conn, params.TableName, params.PartitionsAmount)
    92  	if err != nil {
    93  		return nil, fmt.Errorf("%w: %s", kv.ErrSetupFailed, err)
    94  	}
    95  
    96  	// register collector to publish pgx's pool stats as metrics
    97  	var collector prometheus.Collector
    98  	if params.Metrics {
    99  		collector = pgxpoolprometheus.NewCollector(pool, map[string]string{"db_name": params.TableName})
   100  		err := prometheus.Register(collector)
   101  		if err != nil {
   102  			return nil, err
   103  		}
   104  	}
   105  
   106  	store := &Store{
   107  		Pool:           pool,
   108  		Params:         params,
   109  		TableSanitized: pgx.Identifier{params.TableName}.Sanitize(),
   110  		collector:      collector,
   111  	}
   112  	pool = nil
   113  	return store, nil
   114  }
   115  
   116  func newPgxpoolConfig(kvParams kvparams.Config) (*pgxpool.Config, error) {
   117  	config, err := pgxpool.ParseConfig(kvParams.Postgres.ConnectionString)
   118  	if err != nil {
   119  		return nil, fmt.Errorf("%w: %s", kv.ErrDriverConfiguration, err)
   120  	}
   121  	if kvParams.Postgres.MaxOpenConnections > 0 {
   122  		config.MaxConns = kvParams.Postgres.MaxOpenConnections
   123  	}
   124  	if kvParams.Postgres.MaxIdleConnections > 0 {
   125  		config.MinConns = kvParams.Postgres.MaxIdleConnections
   126  	}
   127  	if kvParams.Postgres.ConnectionMaxLifetime > 0 {
   128  		config.MaxConnLifetime = kvParams.Postgres.ConnectionMaxLifetime
   129  	}
   130  	return config, err
   131  }
   132  
// Params are the store settings resolved by parseStoreConfig.
type Params struct {
	// TableName is the unquoted name of the kv table.
	TableName          string
	// SanitizedTableName is TableName quoted via pgx.Identifier.Sanitize; it
	// is the form concatenated into the store's SQL statements.
	SanitizedTableName string
	// PartitionsAmount is the number of hash partitions of the kv table.
	PartitionsAmount   int
	// ScanPageSize is the maximum number of rows fetched per scan query.
	ScanPageSize       int
	// Metrics enables publishing pool statistics as Prometheus metrics.
	Metrics            bool
}
   140  
   141  func parseStoreConfig(runtimeParams map[string]string, pgParams *kvparams.Postgres) *Params {
   142  	p := &Params{
   143  		TableName:        DefaultTableName,
   144  		PartitionsAmount: DefaultPartitions,
   145  		ScanPageSize:     DefaultScanPageSize,
   146  		Metrics:          pgParams.Metrics,
   147  	}
   148  	if tableName, ok := runtimeParams[paramTableName]; ok {
   149  		p.TableName = tableName
   150  	}
   151  
   152  	p.SanitizedTableName = pgx.Identifier{p.TableName}.Sanitize()
   153  	if pgParams.ScanPageSize > 0 {
   154  		p.ScanPageSize = pgParams.ScanPageSize
   155  	}
   156  	return p
   157  }
   158  
   159  // setupKeyValueDatabase setup everything required to enable kv over postgres
   160  func setupKeyValueDatabase(ctx context.Context, conn *pgxpool.Conn, table string, partitionsAmount int) (err error) {
   161  	var aid string
   162  	aid, err = generateAdvisoryLockID("lakefs:" + table)
   163  	if err != nil {
   164  		return err
   165  	}
   166  
   167  	// This will wait indefinitely until the lock can be acquired.
   168  	_, err = conn.Exec(ctx, `SELECT pg_advisory_lock($1)`, aid)
   169  	if err != nil {
   170  		return fmt.Errorf("try lock failed: %w", err)
   171  	}
   172  	defer func(ctx context.Context) {
   173  		_, unlockErr := conn.Exec(ctx, `SELECT pg_advisory_unlock($1)`, aid)
   174  		// prefer the last error over unlock error
   175  		if err == nil {
   176  			err = unlockErr
   177  		}
   178  	}(ctx)
   179  
   180  	// main kv table
   181  	tableSanitize := pgx.Identifier{table}.Sanitize()
   182  	_, err = conn.Exec(ctx, `CREATE TABLE IF NOT EXISTS `+tableSanitize+` (
   183  		partition_key BYTEA NOT NULL,
   184  		key BYTEA NOT NULL,
   185  		value BYTEA NOT NULL,
   186  		PRIMARY KEY (partition_key, key))
   187  	PARTITION BY HASH (partition_key)`)
   188  	if err != nil {
   189  		return err
   190  	}
   191  
   192  	// partitions
   193  	partitions := getTablePartitions(table, partitionsAmount)
   194  	for i := 0; i < len(partitions); i++ {
   195  		_, err = conn.Exec(ctx, `CREATE TABLE IF NOT EXISTS`+
   196  			pgx.Identifier{partitions[i]}.Sanitize()+` PARTITION OF `+
   197  			tableSanitize+` FOR VALUES WITH (MODULUS `+strconv.Itoa(partitionsAmount)+
   198  			`,REMAINDER `+strconv.Itoa(i)+`)`)
   199  		if err != nil {
   200  			return err
   201  		}
   202  	}
   203  	// view of kv table to help humans select from table (same as table with _v as suffix)
   204  	_, err = conn.Exec(ctx, `CREATE OR REPLACE VIEW `+pgx.Identifier{table + "_v"}.Sanitize()+
   205  		` AS SELECT ENCODE(partition_key, 'escape') AS partition_key, ENCODE(key, 'escape') AS key, value FROM `+tableSanitize)
   206  	return err
   207  }
   208  
   209  func generateAdvisoryLockID(name string) (string, error) {
   210  	h := fnv.New32a()
   211  	if _, err := h.Write([]byte(name)); err != nil {
   212  		return "", err
   213  	}
   214  	aid := fmt.Sprint(h.Sum32())
   215  	return aid, nil
   216  }
   217  
   218  func getTablePartitions(tableName string, partitionsAmount int) []string {
   219  	res := make([]string, 0, partitionsAmount)
   220  	for i := 0; i < partitionsAmount; i++ {
   221  		res = append(res, fmt.Sprintf("%s_%d", tableName, i))
   222  	}
   223  	return res
   224  }
   225  
   226  func (s *Store) Get(ctx context.Context, partitionKey, key []byte) (*kv.ValueWithPredicate, error) {
   227  	if len(partitionKey) == 0 {
   228  		return nil, kv.ErrMissingPartitionKey
   229  	}
   230  	if len(key) == 0 {
   231  		return nil, kv.ErrMissingKey
   232  	}
   233  
   234  	row := s.Pool.QueryRow(ctx, `SELECT value FROM `+s.Params.SanitizedTableName+` WHERE key = $1 AND partition_key = $2`, key, partitionKey)
   235  	var val []byte
   236  	err := row.Scan(&val)
   237  	if errors.Is(err, pgx.ErrNoRows) {
   238  		return nil, kv.ErrNotFound
   239  	}
   240  	if err != nil {
   241  		return nil, fmt.Errorf("postgres get: %w", err)
   242  	}
   243  	return &kv.ValueWithPredicate{
   244  		Value:     val,
   245  		Predicate: kv.Predicate(val),
   246  	}, nil
   247  }
   248  
   249  func (s *Store) Set(ctx context.Context, partitionKey, key, value []byte) error {
   250  	if len(partitionKey) == 0 {
   251  		return kv.ErrMissingPartitionKey
   252  	}
   253  	if len(key) == 0 {
   254  		return kv.ErrMissingKey
   255  	}
   256  	if value == nil {
   257  		return kv.ErrMissingValue
   258  	}
   259  
   260  	_, err := s.Pool.Exec(ctx, `INSERT INTO `+s.Params.SanitizedTableName+`(partition_key,key,value) VALUES($1,$2,$3)
   261  			ON CONFLICT (partition_key,key) DO UPDATE SET value = $3`, partitionKey, key, value)
   262  	if err != nil {
   263  		return fmt.Errorf("postgres set: %w", err)
   264  	}
   265  	return nil
   266  }
   267  
   268  func (s *Store) SetIf(ctx context.Context, partitionKey, key, value []byte, valuePredicate kv.Predicate) error {
   269  	if len(partitionKey) == 0 {
   270  		return kv.ErrMissingPartitionKey
   271  	}
   272  	if len(key) == 0 {
   273  		return kv.ErrMissingKey
   274  	}
   275  	if value == nil {
   276  		return kv.ErrMissingValue
   277  	}
   278  
   279  	var (
   280  		res pgconn.CommandTag
   281  		err error
   282  	)
   283  	switch valuePredicate {
   284  	case nil: // use insert to make sure there was no previous value before
   285  		res, err = s.Pool.Exec(ctx, `INSERT INTO `+s.Params.SanitizedTableName+`(partition_key,key,value) VALUES($1,$2,$3) ON CONFLICT DO NOTHING`, partitionKey, key, value)
   286  
   287  	case kv.PrecondConditionalExists: // update only if exists
   288  		res, err = s.Pool.Exec(ctx, `UPDATE `+s.Params.SanitizedTableName+` SET value=$3 WHERE key=$2 AND partition_key=$1`, partitionKey, key, value)
   289  
   290  	default: // update just in case the previous value was same as predicate value
   291  		res, err = s.Pool.Exec(ctx, `UPDATE `+s.Params.SanitizedTableName+` SET value=$3 WHERE key=$2 AND partition_key=$1 AND value=$4`, partitionKey, key, value, valuePredicate.([]byte))
   292  	}
   293  	if err != nil {
   294  		return fmt.Errorf("postgres setIf: %w", err)
   295  	}
   296  	if res.RowsAffected() != 1 {
   297  		return kv.ErrPredicateFailed
   298  	}
   299  	return nil
   300  }
   301  
   302  func (s *Store) Delete(ctx context.Context, partitionKey, key []byte) error {
   303  	if len(partitionKey) == 0 {
   304  		return kv.ErrMissingPartitionKey
   305  	}
   306  	if len(key) == 0 {
   307  		return kv.ErrMissingKey
   308  	}
   309  	_, err := s.Pool.Exec(ctx, `DELETE FROM `+s.Params.SanitizedTableName+` WHERE partition_key=$1 AND key=$2`, partitionKey, key)
   310  	if err != nil {
   311  		return fmt.Errorf("postgres delete: %w", err)
   312  	}
   313  	return nil
   314  }
   315  
   316  func (s *Store) Scan(ctx context.Context, partitionKey []byte, options kv.ScanOptions) (kv.EntriesIterator, error) {
   317  	if len(partitionKey) == 0 {
   318  		return nil, kv.ErrMissingPartitionKey
   319  	}
   320  
   321  	// limit based on the minimum between ScanPageSize and ScanOptions batch size
   322  	limit := s.Params.ScanPageSize
   323  	if options.BatchSize != 0 && s.Params.ScanPageSize != 0 && options.BatchSize < s.Params.ScanPageSize {
   324  		limit = options.BatchSize
   325  	}
   326  	it := &EntriesIterator{
   327  		ctx:          ctx,
   328  		partitionKey: partitionKey,
   329  		startKey:     options.KeyStart,
   330  		store:        s,
   331  		limit:        limit,
   332  		includeStart: true,
   333  	}
   334  	it.runQuery()
   335  	if it.err != nil {
   336  		return nil, it.err
   337  	}
   338  	return it, nil
   339  }
   340  
   341  func (s *Store) Close() {
   342  	if s.collector != nil {
   343  		prometheus.Unregister(s.collector)
   344  		s.collector = nil
   345  	}
   346  	s.Pool.Close()
   347  }
   348  
   349  // Next reads the next key/value.
   350  func (e *EntriesIterator) Next() bool {
   351  	if e.err != nil || len(e.entries) == 0 {
   352  		return false
   353  	}
   354  	if e.currEntryIdx+1 == len(e.entries) {
   355  		key := e.entries[e.currEntryIdx].Key
   356  		e.startKey = key
   357  		e.includeStart = false
   358  		e.runQuery()
   359  		if e.err != nil || len(e.entries) == 0 {
   360  			return false
   361  		}
   362  	}
   363  	e.currEntryIdx++
   364  	return true
   365  }
   366  
   367  func (e *EntriesIterator) SeekGE(key []byte) {
   368  	if !e.isInRange(key) {
   369  		e.startKey = key
   370  		e.includeStart = true
   371  		e.runQuery()
   372  		return
   373  	}
   374  	for i := range e.entries {
   375  		if bytes.Compare(key, e.entries[i].Key) <= 0 {
   376  			e.currEntryIdx = i - 1
   377  			return
   378  		}
   379  	}
   380  }
   381  
   382  func (e *EntriesIterator) Entry() *kv.Entry {
   383  	if e.entries == nil {
   384  		return nil
   385  	}
   386  	return &e.entries[e.currEntryIdx]
   387  }
   388  
   389  // Err return the last scan error or the cursor error
   390  func (e *EntriesIterator) Err() error {
   391  	return e.err
   392  }
   393  
   394  func (e *EntriesIterator) Close() {
   395  	e.entries = nil
   396  	e.currEntryIdx = -1
   397  	e.err = kv.ErrClosedEntries
   398  }
   399  
   400  func (e *EntriesIterator) runQuery() {
   401  	var (
   402  		rows pgx.Rows
   403  		err  error
   404  	)
   405  	if e.startKey == nil {
   406  		rows, err = e.store.Pool.Query(e.ctx, `SELECT partition_key,key,value FROM `+e.store.Params.SanitizedTableName+` WHERE partition_key=$1 ORDER BY key LIMIT $2`, e.partitionKey, e.limit)
   407  	} else {
   408  		compareOp := ">="
   409  		if !e.includeStart {
   410  			compareOp = ">"
   411  		}
   412  		rows, err = e.store.Pool.Query(e.ctx, `SELECT partition_key,key,value FROM `+e.store.Params.SanitizedTableName+` WHERE partition_key=$1 AND key `+compareOp+` $2 ORDER BY key LIMIT $3`, e.partitionKey, e.startKey, e.limit)
   413  	}
   414  	if err != nil {
   415  		e.err = fmt.Errorf("postgres scan: %w", err)
   416  		return
   417  	}
   418  	defer rows.Close()
   419  	err = pgxscan.ScanAll(&e.entries, rows)
   420  	if err != nil {
   421  		e.err = fmt.Errorf("scanning all entries: %w", err)
   422  		return
   423  	}
   424  	e.currEntryIdx = -1
   425  }
   426  
   427  func (e *EntriesIterator) isInRange(key []byte) bool {
   428  	if len(e.entries) == 0 {
   429  		return false
   430  	}
   431  	minKey := e.entries[0].Key
   432  	maxKey := e.entries[len(e.entries)-1].Key
   433  	return minKey != nil && maxKey != nil && bytes.Compare(key, minKey) >= 0 && bytes.Compare(key, maxKey) <= 0
   434  }