github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/datastore/postgres/revisions.go (about)

     1  package postgres
     2  
     3  import (
     4  	"context"
     5  	"encoding/base64"
     6  	"errors"
     7  	"fmt"
     8  	"strconv"
     9  	"strings"
    10  	"time"
    11  
    12  	"github.com/jackc/pgx/v5"
    13  
    14  	"github.com/authzed/spicedb/pkg/datastore"
    15  	implv1 "github.com/authzed/spicedb/pkg/proto/impl/v1"
    16  )
    17  
const (
	// Format strings used to wrap errors produced while selecting, checking
	// and parsing revisions. Each takes the underlying error via %w.
	errRevision       = "unable to find revision: %w"
	errCheckRevision  = "unable to check revision: %w"
	errRevisionFormat = "invalid revision format: %w"

	// querySelectRevision rounds the database's current timestamp down to the
	// nearest quantization period and selects the first transaction (ordered by
	// timestamp) at or after that point, along with its stored snapshot. If no
	// such transaction exists, the returned xid is NULL and the database's
	// current snapshot (pg_current_snapshot()) is returned instead. The third
	// column is the number of nanoseconds until the next quantization boundary,
	// i.e. until a different optimized revision would be selected server-side,
	// for use with caching.
	//
	// The template is meant to be expanded with fmt (note the escaped %% modulo
	// operator). The input values for the format string are:
	//   %[1] Name of xid column
	//   %[2] Relationship tuple transaction table
	//   %[3] Name of timestamp column
	//   %[4] Quantization period (in nanoseconds)
	//   %[5] Name of snapshot column
	querySelectRevision = `
	WITH selected AS (SELECT COALESCE(
		(SELECT %[1]s FROM %[2]s WHERE %[3]s >= TO_TIMESTAMP(FLOOR(EXTRACT(EPOCH FROM NOW() AT TIME ZONE 'utc') * 1000000000 / %[4]d) * %[4]d / 1000000000) AT TIME ZONE 'utc' ORDER BY %[3]s ASC LIMIT 1),
		NULL
	) as xid)
	SELECT selected.xid,
	COALESCE((SELECT %[5]s FROM %[2]s WHERE %[1]s = selected.xid), (SELECT pg_current_snapshot())),
	%[4]d - CAST(EXTRACT(EPOCH FROM NOW() AT TIME ZONE 'utc') * 1000000000 as bigint) %% %[4]d
	FROM selected;`

	// queryValidTransaction will return a single row with three values:
	//   1) the transaction ID of the minimum valid (i.e. within the GC window) transaction
	//   2) the snapshot associated with the minimum valid transaction
	//   3) the current snapshot that would be used if a new transaction were created now
	//
	// The row carrying the maximum timestamp always matches the WHERE clause, so
	// the most recent transaction is used when nothing falls within the window.
	//
	// The input values for the format string are:
	//   %[1] Name of xid column
	//   %[2] Relationship tuple transaction table
	//   %[3] Name of timestamp column
	//   %[4] GC window in seconds, formatted as a float and subtracted from NOW()
	//   %[5] Name of the snapshot column
	queryValidTransaction = `
	WITH minvalid AS (
		SELECT %[1]s, %[5]s
        FROM %[2]s
        WHERE 
            %[3]s >= NOW() - INTERVAL '%[4]f seconds'
          OR
             %[3]s = (SELECT MAX(%[3]s) FROM %[2]s)
        ORDER BY %[3]s ASC
        LIMIT 1
	)
	SELECT minvalid.%[1]s, minvalid.%[5]s, pg_current_snapshot() FROM minvalid;`

	// queryCurrentSnapshot returns the snapshot the database would use right now.
	queryCurrentSnapshot = `SELECT pg_current_snapshot();`

	// queryCurrentTransactionID returns the ID of the current transaction, cast
	// through text to integer.
	// NOTE(review): a Postgres integer is 32 bits wide, so this cast presumably
	// fails for transaction IDs beyond that range — confirm against callers.
	queryCurrentTransactionID = `SELECT pg_current_xact_id()::text::integer;`
	// queryLatestXID returns the largest xid recorded in the transaction table.
	// Unlike the templates above, the table and column names are hard-coded.
	queryLatestXID            = `SELECT max(xid) FROM relation_tuple_transaction;`
)
    74  
    75  func (pgd *pgDatastore) optimizedRevisionFunc(ctx context.Context) (datastore.Revision, time.Duration, error) {
    76  	var revision xid8
    77  	var snapshot pgSnapshot
    78  	var validForNanos time.Duration
    79  	if err := pgd.readPool.QueryRow(ctx, pgd.optimizedRevisionQuery).
    80  		Scan(&revision, &snapshot, &validForNanos); err != nil {
    81  		return datastore.NoRevision, 0, fmt.Errorf(errRevision, err)
    82  	}
    83  
    84  	snapshot = snapshot.markComplete(revision.Uint64)
    85  
    86  	return postgresRevision{snapshot}, validForNanos, nil
    87  }
    88  
    89  func (pgd *pgDatastore) HeadRevision(ctx context.Context) (datastore.Revision, error) {
    90  	ctx, span := tracer.Start(ctx, "HeadRevision")
    91  	defer span.End()
    92  
    93  	var snapshot pgSnapshot
    94  	if err := pgd.readPool.QueryRow(ctx, queryCurrentSnapshot).Scan(&snapshot); err != nil {
    95  		if errors.Is(err, pgx.ErrNoRows) {
    96  			return datastore.NoRevision, nil
    97  		}
    98  		return datastore.NoRevision, fmt.Errorf(errRevision, err)
    99  	}
   100  
   101  	return postgresRevision{snapshot}, nil
   102  }
   103  
   104  func (pgd *pgDatastore) CheckRevision(ctx context.Context, revisionRaw datastore.Revision) error {
   105  	revision, ok := revisionRaw.(postgresRevision)
   106  	if !ok {
   107  		return datastore.NewInvalidRevisionErr(revisionRaw, datastore.CouldNotDetermineRevision)
   108  	}
   109  
   110  	var minXid xid8
   111  	var minSnapshot, currentSnapshot pgSnapshot
   112  	if err := pgd.readPool.QueryRow(ctx, pgd.validTransactionQuery).
   113  		Scan(&minXid, &minSnapshot, &currentSnapshot); err != nil {
   114  		return fmt.Errorf(errCheckRevision, err)
   115  	}
   116  
   117  	if revisionRaw.GreaterThan(postgresRevision{currentSnapshot}) {
   118  		return datastore.NewInvalidRevisionErr(revision, datastore.CouldNotDetermineRevision)
   119  	}
   120  	if minSnapshot.markComplete(minXid.Uint64).GreaterThan(revision.snapshot) {
   121  		return datastore.NewInvalidRevisionErr(revision, datastore.RevisionStale)
   122  	}
   123  
   124  	return nil
   125  }
   126  
   127  // RevisionFromString reverses the encoding process performed by MarshalBinary and String.
   128  func (pgd *pgDatastore) RevisionFromString(revisionStr string) (datastore.Revision, error) {
   129  	return ParseRevisionString(revisionStr)
   130  }
   131  
   132  // ParseRevisionString parses a revision string into a Postgres revision.
   133  func ParseRevisionString(revisionStr string) (rev datastore.Revision, err error) {
   134  	rev, err = parseRevisionProto(revisionStr)
   135  	if err != nil {
   136  		decimalRev, decimalErr := parseRevisionDecimal(revisionStr)
   137  		if decimalErr != nil {
   138  			// If decimal ALSO had an error than it was likely just a mangled original input
   139  			return
   140  		}
   141  		return decimalRev, nil
   142  	}
   143  	return
   144  }
   145  
   146  func parseRevisionProto(revisionStr string) (datastore.Revision, error) {
   147  	protoBytes, err := base64.StdEncoding.DecodeString(revisionStr)
   148  	if err != nil {
   149  		return datastore.NoRevision, fmt.Errorf(errRevisionFormat, err)
   150  	}
   151  
   152  	decoded := implv1.PostgresRevision{}
   153  	if err := decoded.UnmarshalVT(protoBytes); err != nil {
   154  		return datastore.NoRevision, fmt.Errorf(errRevisionFormat, err)
   155  	}
   156  
   157  	xminInt := int64(decoded.Xmin)
   158  
   159  	var xips []uint64
   160  	if len(decoded.RelativeXips) > 0 {
   161  		xips = make([]uint64, len(decoded.RelativeXips))
   162  		for i, relativeXip := range decoded.RelativeXips {
   163  			xips[i] = uint64(xminInt + relativeXip)
   164  		}
   165  	}
   166  
   167  	return postgresRevision{
   168  		pgSnapshot{
   169  			xmin:    decoded.Xmin,
   170  			xmax:    uint64(xminInt + decoded.RelativeXmax),
   171  			xipList: xips,
   172  		},
   173  	}, nil
   174  }
   175  
// MaxLegacyXIPDelta is the maximum allowed delta between the xmin and
// xmax transaction IDs of a *legacy* revision stored as a revision decimal.
// This is set to prevent a delta that is too large from blowing out the
// memory usage of the allocated slice, or even causing a panic in the case
// of a VERY large delta (which can be produced by, for example, a CRDB revision
// being given to a Postgres datastore accidentally).
const MaxLegacyXIPDelta = 1000
   183  
   184  // parseRevisionDecimal parses a deprecated decimal.Decimal encoding of the revision
   185  // with an optional xmin component, in the format of revision.xmin, e.g. 100.99.
   186  // Because we're encoding to a snapshot, we want the revision to be considered visible,
   187  // so we set the xmax and xmin for 1 past the encoded revision for the simple cases.
   188  func parseRevisionDecimal(revisionStr string) (datastore.Revision, error) {
   189  	components := strings.Split(revisionStr, ".")
   190  	numComponents := len(components)
   191  	if numComponents != 1 && numComponents != 2 {
   192  		return datastore.NoRevision, fmt.Errorf(
   193  			errRevisionFormat,
   194  			fmt.Errorf("wrong number of components %d != 1 or 2", len(components)),
   195  		)
   196  	}
   197  
   198  	xid, err := strconv.ParseUint(components[0], 10, 64)
   199  	if err != nil {
   200  		return datastore.NoRevision, fmt.Errorf(errRevisionFormat, err)
   201  	}
   202  
   203  	xmax := xid + 1
   204  	xmin := xid + 1
   205  
   206  	if numComponents == 2 {
   207  		xminCandidate, err := strconv.ParseUint(components[1], 10, 64)
   208  		if err != nil {
   209  			return datastore.NoRevision, fmt.Errorf(errRevisionFormat, err)
   210  		}
   211  		if xminCandidate < xid {
   212  			xmin = xminCandidate
   213  		}
   214  	}
   215  
   216  	var xipList []uint64
   217  	if xmax > xmin {
   218  		// Ensure that the delta is not too large to cause memory issues or a panic.
   219  		if xmax-xmin > MaxLegacyXIPDelta {
   220  			return nil, fmt.Errorf("received revision delta in excess of that expected; are you sure you're not passing a ZedToken from an incompatible datastore?")
   221  		}
   222  
   223  		// TODO(jschorr): Remove this deprecated code path once we have per-datastore-marked ZedTokens.
   224  		xipList = make([]uint64, 0, xmax-xmin)
   225  		for i := xmin; i < xid; i++ {
   226  			xipList = append(xipList, i)
   227  		}
   228  	}
   229  
   230  	return postgresRevision{pgSnapshot{
   231  		xmin:    xmin,
   232  		xmax:    xmax,
   233  		xipList: xipList,
   234  	}}, nil
   235  }
   236  
   237  func createNewTransaction(ctx context.Context, tx pgx.Tx) (newXID xid8, newSnapshot pgSnapshot, err error) {
   238  	ctx, span := tracer.Start(ctx, "createNewTransaction")
   239  	defer span.End()
   240  
   241  	cterr := tx.QueryRow(ctx, createTxn).Scan(&newXID, &newSnapshot)
   242  	if cterr != nil {
   243  		err = fmt.Errorf("error when trying to create a new transaction: %w", cterr)
   244  	}
   245  	return
   246  }
   247  
// postgresRevision is the Postgres implementation of datastore.Revision.
// A revision is represented by the pgSnapshot it wraps.
type postgresRevision struct {
	snapshot pgSnapshot
}
   251  
   252  func (pr postgresRevision) Equal(rhsRaw datastore.Revision) bool {
   253  	rhs, ok := rhsRaw.(postgresRevision)
   254  	return ok && pr.snapshot.Equal(rhs.snapshot)
   255  }
   256  
   257  func (pr postgresRevision) GreaterThan(rhsRaw datastore.Revision) bool {
   258  	if rhsRaw == datastore.NoRevision {
   259  		return true
   260  	}
   261  
   262  	rhs, ok := rhsRaw.(postgresRevision)
   263  	return ok && pr.snapshot.GreaterThan(rhs.snapshot)
   264  }
   265  
   266  func (pr postgresRevision) LessThan(rhsRaw datastore.Revision) bool {
   267  	rhs, ok := rhsRaw.(postgresRevision)
   268  	return ok && pr.snapshot.LessThan(rhs.snapshot)
   269  }
   270  
   271  func (pr postgresRevision) DebugString() string {
   272  	return pr.snapshot.String()
   273  }
   274  
   275  func (pr postgresRevision) String() string {
   276  	return base64.StdEncoding.EncodeToString(pr.mustMarshalBinary())
   277  }
   278  
   279  func (pr postgresRevision) mustMarshalBinary() []byte {
   280  	serialized, err := pr.MarshalBinary()
   281  	if err != nil {
   282  		panic(fmt.Sprintf("unexpected error marshaling proto: %s", err))
   283  	}
   284  	return serialized
   285  }
   286  
   287  // MarshalBinary creates a version of the snapshot that uses relative encoding
   288  // for xmax and xip list values to save bytes when encoded as varint protos.
   289  // For example, snapshot 1001:1004:1001,1003 becomes 1000:3:0,2.
   290  func (pr postgresRevision) MarshalBinary() ([]byte, error) {
   291  	xminInt := int64(pr.snapshot.xmin)
   292  	relativeXips := make([]int64, len(pr.snapshot.xipList))
   293  	for i, xip := range pr.snapshot.xipList {
   294  		relativeXips[i] = int64(xip) - xminInt
   295  	}
   296  
   297  	protoRevision := implv1.PostgresRevision{
   298  		Xmin:         pr.snapshot.xmin,
   299  		RelativeXmax: int64(pr.snapshot.xmax) - xminInt,
   300  		RelativeXips: relativeXips,
   301  	}
   302  
   303  	return protoRevision.MarshalVT()
   304  }
   305  
// Compile-time assertion that postgresRevision satisfies datastore.Revision.
var _ datastore.Revision = postgresRevision{}
   307  
   308  func revisionKeyFunc(rev revisionWithXid) uint64 {
   309  	return rev.tx.Uint64
   310  }