github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/datastore/postgres/snapshot.go (about) 1 package postgres 2 3 import ( 4 "fmt" 5 "slices" 6 "strconv" 7 "strings" 8 9 "github.com/jackc/pgx/v5/pgtype" 10 ) 11 12 // RegisterTypes registers pgSnapshot and xid8 with a pgtype.ConnInfo. 13 func RegisterTypes(m *pgtype.Map) { 14 m.RegisterType(&pgtype.Type{ 15 Name: "snapshot", 16 OID: 5038, 17 Codec: SnapshotCodec{}, 18 }) 19 m.RegisterType(&pgtype.Type{ 20 Name: "xid", 21 OID: 5069, 22 Codec: Uint64Codec{}, 23 }) 24 m.RegisterDefaultPgType(pgSnapshot{}, "snapshot") 25 m.RegisterDefaultPgType(xid8{}, "xid") 26 } 27 28 type SnapshotCodec struct { 29 pgtype.TextCodec 30 } 31 32 func (SnapshotCodec) DecodeValue(tm *pgtype.Map, oid uint32, format int16, src []byte) (interface{}, error) { 33 if src == nil { 34 return nil, nil 35 } 36 37 var target pgSnapshot 38 scanPlan := tm.PlanScan(oid, format, &target) 39 if scanPlan == nil { 40 return nil, fmt.Errorf("PlanScan did not find a plan") 41 } 42 43 err := scanPlan.Scan(src, &target) 44 if err != nil { 45 return nil, err 46 } 47 48 return target, nil 49 } 50 51 type pgSnapshot struct { 52 xmin, xmax uint64 53 xipList []uint64 // Must always be sorted 54 } 55 56 var ( 57 _ pgtype.TextScanner = &pgSnapshot{} 58 _ pgtype.TextValuer = &pgSnapshot{} 59 ) 60 61 func (s *pgSnapshot) ScanText(v pgtype.Text) error { 62 if !v.Valid { 63 return fmt.Errorf("cannot scan NULL into pgSnapshot") 64 } 65 66 components := strings.SplitN(v.String, ":", 3) 67 if len(components) != 3 { 68 return fmt.Errorf("wrong number of snapshot components: %s", v.String) 69 } 70 71 var err error 72 s.xmin, err = strconv.ParseUint(components[0], 10, 64) 73 if err != nil { 74 return fmt.Errorf("unable to parse xmin: %s", components[0]) 75 } 76 77 s.xmax, err = strconv.ParseUint(components[1], 10, 64) 78 if err != nil { 79 return fmt.Errorf("unable to parse xmax: %s", components[1]) 80 } 81 82 if components[2] != "" { 83 xipStrings := strings.Split(components[2], ",") 84 s.xipList = make([]uint64, len(xipStrings)) 85 for i, xipStr := range xipStrings { 86 s.xipList[i], err = strconv.ParseUint(xipStr, 10, 64) 87 if err != nil { 88 return fmt.Errorf("unable to parse xip: %s", xipStr) 89 } 90 } 91 92 // Do a defensive sort in case the server is feeling out of sorts. 93 slices.Sort(s.xipList) 94 } else { 95 s.xipList = nil 96 } 97 98 return nil 99 } 100 101 func (s pgSnapshot) TextValue() (pgtype.Text, error) { 102 return pgtype.Text{String: s.String(), Valid: true}, nil 103 } 104 105 // String uses the official postgres encoding for snapshots, which is described here: 106 // https://www.postgresql.org/docs/current/functions-info.html#FUNCTIONS-PG-SNAPSHOT-PARTS 107 func (s pgSnapshot) String() string { 108 xipStrs := make([]string, len(s.xipList)) 109 for i, xip := range s.xipList { 110 xipStrs[i] = strconv.FormatUint(xip, 10) 111 } 112 113 components := []string{ 114 strconv.FormatUint(s.xmin, 10), 115 strconv.FormatUint(s.xmax, 10), 116 strings.Join(xipStrs, ","), 117 } 118 119 return strings.Join(components, ":") 120 } 121 122 func (s pgSnapshot) Equal(rhs pgSnapshot) bool { 123 return s.compare(rhs) == equal 124 } 125 126 func (s pgSnapshot) GreaterThan(rhs pgSnapshot) bool { 127 return s.compare(rhs) == gt 128 } 129 130 func (s pgSnapshot) LessThan(rhs pgSnapshot) bool { 131 return s.compare(rhs) == lt 132 } 133 134 type comparisonResult uint8 135 136 const ( 137 _ comparisonResult = iota 138 equal 139 lt 140 gt 141 concurrent 142 ) 143 144 // compare will return whether we can definitely assert that one snapshot was 145 // definitively created after, before, at the same time, or was executed 146 // concurrent with another transaction. We assess this based on whether a 147 // transaction has more, less, or conflicting information about the resolution 148 // of in-progress transactions. E.g. if one snapshot only sees txids 1 and 3 as 149 // visible but another transaction sees 1-3 as visible, that transaction is 150 // greater. 151 func (s pgSnapshot) compare(rhs pgSnapshot) comparisonResult { 152 rhsHasMoreInfo := rhs.anyTXVisible(s.xmax, s.xipList) 153 lhsHasMoreInfo := s.anyTXVisible(rhs.xmax, rhs.xipList) 154 155 switch { 156 case rhsHasMoreInfo && lhsHasMoreInfo: 157 return concurrent 158 case rhsHasMoreInfo: 159 return lt 160 case lhsHasMoreInfo: 161 return gt 162 default: 163 return equal 164 } 165 } 166 167 func (s pgSnapshot) anyTXVisible(first uint64, others []uint64) bool { 168 if s.txVisible(first) { 169 return true 170 } 171 for _, txid := range others { 172 if s.txVisible(txid) { 173 return true 174 } 175 } 176 177 return false 178 } 179 180 // markComplete will create a new snapshot where the specified transaction will be marked as 181 // complete and visible. For example, if txid was present in the xip list of this snapshot 182 // it will be removed and the xmin and xmax will be adjusted accordingly. 183 func (s pgSnapshot) markComplete(txid uint64) pgSnapshot { 184 if txid < s.xmin { 185 // Nothing to do 186 return s 187 } 188 189 xipListCopy := make([]uint64, len(s.xipList)) 190 copy(xipListCopy, s.xipList) 191 192 newSnapshot := pgSnapshot{ 193 s.xmin, 194 s.xmax, 195 xipListCopy, 196 } 197 198 // Adjust the xmax and running tx if necessary 199 if txid >= s.xmax { 200 for newIP := s.xmax; newIP < txid+1; newIP++ { 201 newSnapshot.xipList = append(newSnapshot.xipList, newIP) 202 } 203 newSnapshot.xmax = txid + 1 204 } 205 206 // Mark the tx complete if it's in the xipList 207 // Note: we only find the first if it was erroneously duplicate 208 pos, found := slices.BinarySearch(newSnapshot.xipList, txid) 209 if found { 210 newSnapshot.xipList = slices.Delete(newSnapshot.xipList, pos, pos+1) 211 } 212 213 // Adjust the xmin if necessary 214 if len(newSnapshot.xipList) > 0 { 215 newSnapshot.xmin = newSnapshot.xipList[0] 216 } else { 217 newSnapshot.xmin = newSnapshot.xmax 218 newSnapshot.xipList = nil 219 } 220 221 return newSnapshot 222 } 223 224 // markInProgress will create a new snapshot where the specified transaction will be marked as 225 // in-progress and therefore invisible. For example, if the specified xmin falls between two 226 // values in the xip list, it will be inserted in order. 227 func (s pgSnapshot) markInProgress(txid uint64) pgSnapshot { 228 if txid >= s.xmax { 229 // Nothing to do 230 return s 231 } 232 233 xipListCopy := make([]uint64, len(s.xipList)) 234 copy(xipListCopy, s.xipList) 235 236 newSnapshot := pgSnapshot{ 237 s.xmin, 238 s.xmax, 239 xipListCopy, 240 } 241 242 // Adjust the xmax and running tx if necessary 243 if txid < s.xmin { 244 // Adjust the xmin and prepend the newly running tx 245 newSnapshot.xmin = txid 246 newSnapshot.xipList = append([]uint64{txid}, newSnapshot.xipList...) 247 } else { 248 // Add the newly in-progress xip to the list of in-progress transactions 249 pos, found := slices.BinarySearch(newSnapshot.xipList, txid) 250 if !found { 251 newSnapshot.xipList = slices.Insert(newSnapshot.xipList, pos, txid) 252 } 253 } 254 255 // Adjust the xmax if necessary 256 var numToDrop int 257 startingXipLen := len(newSnapshot.xipList) 258 for numToDrop = 0; numToDrop < startingXipLen; numToDrop++ { 259 if newSnapshot.xipList[startingXipLen-1-numToDrop] != newSnapshot.xmax-uint64(numToDrop)-1 { 260 break 261 } 262 } 263 264 if numToDrop > 0 { 265 newSnapshot.xmax = newSnapshot.xipList[startingXipLen-numToDrop] 266 newSnapshot.xipList = newSnapshot.xipList[:startingXipLen-numToDrop] 267 if len(newSnapshot.xipList) == 0 { 268 newSnapshot.xipList = nil 269 } 270 } 271 272 return newSnapshot 273 } 274 275 // txVisible will return whether the specified txid has a disposition (i.e. committed or rolled back) 276 // in the specified snapshot, and is therefore txVisible to transactions using this snapshot. 277 func (s pgSnapshot) txVisible(txid uint64) bool { 278 switch { 279 case txid < s.xmin: 280 return true 281 case txid >= s.xmax: 282 return false 283 default: 284 _, txInProgress := slices.BinarySearch(s.xipList, txid) 285 return !txInProgress 286 } 287 }