github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/datastore/postgres/snapshot.go (about)

     1  package postgres
     2  
     3  import (
     4  	"fmt"
     5  	"slices"
     6  	"strconv"
     7  	"strings"
     8  
     9  	"github.com/jackc/pgx/v5/pgtype"
    10  )
    11  
    12  // RegisterTypes registers pgSnapshot and xid8 with a pgtype.ConnInfo.
    13  func RegisterTypes(m *pgtype.Map) {
    14  	m.RegisterType(&pgtype.Type{
    15  		Name:  "snapshot",
    16  		OID:   5038,
    17  		Codec: SnapshotCodec{},
    18  	})
    19  	m.RegisterType(&pgtype.Type{
    20  		Name:  "xid",
    21  		OID:   5069,
    22  		Codec: Uint64Codec{},
    23  	})
    24  	m.RegisterDefaultPgType(pgSnapshot{}, "snapshot")
    25  	m.RegisterDefaultPgType(xid8{}, "xid")
    26  }
    27  
    28  type SnapshotCodec struct {
    29  	pgtype.TextCodec
    30  }
    31  
    32  func (SnapshotCodec) DecodeValue(tm *pgtype.Map, oid uint32, format int16, src []byte) (interface{}, error) {
    33  	if src == nil {
    34  		return nil, nil
    35  	}
    36  
    37  	var target pgSnapshot
    38  	scanPlan := tm.PlanScan(oid, format, &target)
    39  	if scanPlan == nil {
    40  		return nil, fmt.Errorf("PlanScan did not find a plan")
    41  	}
    42  
    43  	err := scanPlan.Scan(src, &target)
    44  	if err != nil {
    45  		return nil, err
    46  	}
    47  
    48  	return target, nil
    49  }
    50  
    51  type pgSnapshot struct {
    52  	xmin, xmax uint64
    53  	xipList    []uint64 // Must always be sorted
    54  }
    55  
    56  var (
    57  	_ pgtype.TextScanner = &pgSnapshot{}
    58  	_ pgtype.TextValuer  = &pgSnapshot{}
    59  )
    60  
    61  func (s *pgSnapshot) ScanText(v pgtype.Text) error {
    62  	if !v.Valid {
    63  		return fmt.Errorf("cannot scan NULL into pgSnapshot")
    64  	}
    65  
    66  	components := strings.SplitN(v.String, ":", 3)
    67  	if len(components) != 3 {
    68  		return fmt.Errorf("wrong number of snapshot components: %s", v.String)
    69  	}
    70  
    71  	var err error
    72  	s.xmin, err = strconv.ParseUint(components[0], 10, 64)
    73  	if err != nil {
    74  		return fmt.Errorf("unable to parse xmin: %s", components[0])
    75  	}
    76  
    77  	s.xmax, err = strconv.ParseUint(components[1], 10, 64)
    78  	if err != nil {
    79  		return fmt.Errorf("unable to parse xmax: %s", components[1])
    80  	}
    81  
    82  	if components[2] != "" {
    83  		xipStrings := strings.Split(components[2], ",")
    84  		s.xipList = make([]uint64, len(xipStrings))
    85  		for i, xipStr := range xipStrings {
    86  			s.xipList[i], err = strconv.ParseUint(xipStr, 10, 64)
    87  			if err != nil {
    88  				return fmt.Errorf("unable to parse xip: %s", xipStr)
    89  			}
    90  		}
    91  
    92  		// Do a defensive sort in case the server is feeling out of sorts.
    93  		slices.Sort(s.xipList)
    94  	} else {
    95  		s.xipList = nil
    96  	}
    97  
    98  	return nil
    99  }
   100  
   101  func (s pgSnapshot) TextValue() (pgtype.Text, error) {
   102  	return pgtype.Text{String: s.String(), Valid: true}, nil
   103  }
   104  
   105  // String uses the official postgres encoding for snapshots, which is described here:
   106  // https://www.postgresql.org/docs/current/functions-info.html#FUNCTIONS-PG-SNAPSHOT-PARTS
   107  func (s pgSnapshot) String() string {
   108  	xipStrs := make([]string, len(s.xipList))
   109  	for i, xip := range s.xipList {
   110  		xipStrs[i] = strconv.FormatUint(xip, 10)
   111  	}
   112  
   113  	components := []string{
   114  		strconv.FormatUint(s.xmin, 10),
   115  		strconv.FormatUint(s.xmax, 10),
   116  		strings.Join(xipStrs, ","),
   117  	}
   118  
   119  	return strings.Join(components, ":")
   120  }
   121  
   122  func (s pgSnapshot) Equal(rhs pgSnapshot) bool {
   123  	return s.compare(rhs) == equal
   124  }
   125  
   126  func (s pgSnapshot) GreaterThan(rhs pgSnapshot) bool {
   127  	return s.compare(rhs) == gt
   128  }
   129  
   130  func (s pgSnapshot) LessThan(rhs pgSnapshot) bool {
   131  	return s.compare(rhs) == lt
   132  }
   133  
   134  type comparisonResult uint8
   135  
   136  const (
   137  	_ comparisonResult = iota
   138  	equal
   139  	lt
   140  	gt
   141  	concurrent
   142  )
   143  
   144  // compare will return whether we can definitely assert that one snapshot was
   145  // definitively created after, before, at the same time, or was executed
   146  // concurrent with another transaction. We assess this based on whether a
   147  // transaction has more, less, or conflicting information about the resolution
   148  // of in-progress transactions. E.g. if one snapshot only sees txids 1 and 3 as
   149  // visible but another transaction sees 1-3 as visible, that transaction is
   150  // greater.
   151  func (s pgSnapshot) compare(rhs pgSnapshot) comparisonResult {
   152  	rhsHasMoreInfo := rhs.anyTXVisible(s.xmax, s.xipList)
   153  	lhsHasMoreInfo := s.anyTXVisible(rhs.xmax, rhs.xipList)
   154  
   155  	switch {
   156  	case rhsHasMoreInfo && lhsHasMoreInfo:
   157  		return concurrent
   158  	case rhsHasMoreInfo:
   159  		return lt
   160  	case lhsHasMoreInfo:
   161  		return gt
   162  	default:
   163  		return equal
   164  	}
   165  }
   166  
   167  func (s pgSnapshot) anyTXVisible(first uint64, others []uint64) bool {
   168  	if s.txVisible(first) {
   169  		return true
   170  	}
   171  	for _, txid := range others {
   172  		if s.txVisible(txid) {
   173  			return true
   174  		}
   175  	}
   176  
   177  	return false
   178  }
   179  
   180  // markComplete will create a new snapshot where the specified transaction will be marked as
   181  // complete and visible. For example, if txid was present in the xip list of this snapshot
   182  // it will be removed and the xmin and xmax will be adjusted accordingly.
   183  func (s pgSnapshot) markComplete(txid uint64) pgSnapshot {
   184  	if txid < s.xmin {
   185  		// Nothing to do
   186  		return s
   187  	}
   188  
   189  	xipListCopy := make([]uint64, len(s.xipList))
   190  	copy(xipListCopy, s.xipList)
   191  
   192  	newSnapshot := pgSnapshot{
   193  		s.xmin,
   194  		s.xmax,
   195  		xipListCopy,
   196  	}
   197  
   198  	// Adjust the xmax and running tx if necessary
   199  	if txid >= s.xmax {
   200  		for newIP := s.xmax; newIP < txid+1; newIP++ {
   201  			newSnapshot.xipList = append(newSnapshot.xipList, newIP)
   202  		}
   203  		newSnapshot.xmax = txid + 1
   204  	}
   205  
   206  	// Mark the tx complete if it's in the xipList
   207  	// Note: we only find the first if it was erroneously duplicate
   208  	pos, found := slices.BinarySearch(newSnapshot.xipList, txid)
   209  	if found {
   210  		newSnapshot.xipList = slices.Delete(newSnapshot.xipList, pos, pos+1)
   211  	}
   212  
   213  	// Adjust the xmin if necessary
   214  	if len(newSnapshot.xipList) > 0 {
   215  		newSnapshot.xmin = newSnapshot.xipList[0]
   216  	} else {
   217  		newSnapshot.xmin = newSnapshot.xmax
   218  		newSnapshot.xipList = nil
   219  	}
   220  
   221  	return newSnapshot
   222  }
   223  
   224  // markInProgress will create a new snapshot where the specified transaction will be marked as
   225  // in-progress and therefore invisible. For example, if the specified xmin falls between two
   226  // values in the xip list, it will be inserted in order.
   227  func (s pgSnapshot) markInProgress(txid uint64) pgSnapshot {
   228  	if txid >= s.xmax {
   229  		// Nothing to do
   230  		return s
   231  	}
   232  
   233  	xipListCopy := make([]uint64, len(s.xipList))
   234  	copy(xipListCopy, s.xipList)
   235  
   236  	newSnapshot := pgSnapshot{
   237  		s.xmin,
   238  		s.xmax,
   239  		xipListCopy,
   240  	}
   241  
   242  	// Adjust the xmax and running tx if necessary
   243  	if txid < s.xmin {
   244  		// Adjust the xmin and prepend the newly running tx
   245  		newSnapshot.xmin = txid
   246  		newSnapshot.xipList = append([]uint64{txid}, newSnapshot.xipList...)
   247  	} else {
   248  		// Add the newly in-progress xip to the list of in-progress transactions
   249  		pos, found := slices.BinarySearch(newSnapshot.xipList, txid)
   250  		if !found {
   251  			newSnapshot.xipList = slices.Insert(newSnapshot.xipList, pos, txid)
   252  		}
   253  	}
   254  
   255  	// Adjust the xmax if necessary
   256  	var numToDrop int
   257  	startingXipLen := len(newSnapshot.xipList)
   258  	for numToDrop = 0; numToDrop < startingXipLen; numToDrop++ {
   259  		if newSnapshot.xipList[startingXipLen-1-numToDrop] != newSnapshot.xmax-uint64(numToDrop)-1 {
   260  			break
   261  		}
   262  	}
   263  
   264  	if numToDrop > 0 {
   265  		newSnapshot.xmax = newSnapshot.xipList[startingXipLen-numToDrop]
   266  		newSnapshot.xipList = newSnapshot.xipList[:startingXipLen-numToDrop]
   267  		if len(newSnapshot.xipList) == 0 {
   268  			newSnapshot.xipList = nil
   269  		}
   270  	}
   271  
   272  	return newSnapshot
   273  }
   274  
   275  // txVisible will return whether the specified txid has a disposition (i.e. committed or rolled back)
   276  // in the specified snapshot, and is therefore txVisible to transactions using this snapshot.
   277  func (s pgSnapshot) txVisible(txid uint64) bool {
   278  	switch {
   279  	case txid < s.xmin:
   280  		return true
   281  	case txid >= s.xmax:
   282  		return false
   283  	default:
   284  		_, txInProgress := slices.BinarySearch(s.xipList, txid)
   285  		return !txInProgress
   286  	}
   287  }