github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/libraries/doltcore/sqle/dsess/autoincrement_tracker.go (about) 1 // Copyright 2023 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package dsess 16 17 import ( 18 "context" 19 "io" 20 "math" 21 "strings" 22 "sync" 23 24 "github.com/dolthub/go-mysql-server/sql" 25 gmstypes "github.com/dolthub/go-mysql-server/sql/types" 26 27 "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" 28 "github.com/dolthub/dolt/go/libraries/doltcore/doltdb/durable" 29 "github.com/dolthub/dolt/go/libraries/doltcore/ref" 30 "github.com/dolthub/dolt/go/libraries/doltcore/schema" 31 "github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess/mutexmap" 32 "github.com/dolthub/dolt/go/libraries/doltcore/sqle/globalstate" 33 "github.com/dolthub/dolt/go/store/prolly/tree" 34 "github.com/dolthub/dolt/go/store/types" 35 ) 36 37 type LockMode int64 38 39 var ( 40 LockMode_Traditional LockMode = 0 41 LockMode_Concurret LockMode = 1 42 LockMode_Interleaved LockMode = 2 43 ) 44 45 type AutoIncrementTracker struct { 46 dbName string 47 sequences *sync.Map // map[string]uint64 48 mm *mutexmap.MutexMap 49 lockMode LockMode 50 } 51 52 var _ globalstate.AutoIncrementTracker = &AutoIncrementTracker{} 53 54 // NewAutoIncrementTracker returns a new autoincrement tracker for the roots given. All roots sets must be 55 // considered because the auto increment value for a table is tracked globally, across all branches. 56 // Roots provided should be the working sets when available, or the branches when they are not (e.g. for remote 57 // branches that don't have a local working set) 58 func NewAutoIncrementTracker(ctx context.Context, dbName string, roots ...doltdb.Rootish) (*AutoIncrementTracker, error) { 59 ait := AutoIncrementTracker{ 60 dbName: dbName, 61 sequences: &sync.Map{}, 62 mm: mutexmap.NewMutexMap(), 63 } 64 65 for _, root := range roots { 66 root, err := root.ResolveRootValue(ctx) 67 if err != nil { 68 return &AutoIncrementTracker{}, err 69 } 70 71 err = root.IterTables(ctx, func(tableName string, table *doltdb.Table, sch schema.Schema) (bool, error) { 72 ok := schema.HasAutoIncrement(sch) 73 if !ok { 74 return false, nil 75 } 76 77 tableName = strings.ToLower(tableName) 78 79 seq, err := table.GetAutoIncrementValue(ctx) 80 if err != nil { 81 return true, err 82 } 83 84 oldValue, loaded := ait.sequences.LoadOrStore(tableName, seq) 85 if loaded && seq > oldValue.(uint64) { 86 ait.sequences.Store(tableName, seq) 87 } 88 89 return false, nil 90 }) 91 92 if err != nil { 93 return &AutoIncrementTracker{}, err 94 } 95 } 96 97 return &ait, nil 98 } 99 100 func loadAutoIncValue(sequences *sync.Map, tableName string) uint64 { 101 tableName = strings.ToLower(tableName) 102 current, hasCurrent := sequences.Load(tableName) 103 if !hasCurrent { 104 return 0 105 } 106 return current.(uint64) 107 } 108 109 // Current returns the next value to be generated in the auto increment sequence for the table named 110 func (a AutoIncrementTracker) Current(tableName string) uint64 { 111 return loadAutoIncValue(a.sequences, tableName) 112 } 113 114 // Next returns the next auto increment value for the table named using the provided value from an insert (which may 115 // be null or 0, in which case it will be generated from the sequence). 116 func (a AutoIncrementTracker) Next(tbl string, insertVal interface{}) (uint64, error) { 117 tbl = strings.ToLower(tbl) 118 119 given, err := CoerceAutoIncrementValue(insertVal) 120 if err != nil { 121 return 0, err 122 } 123 124 if a.lockMode == LockMode_Interleaved { 125 release := a.mm.Lock(tbl) 126 defer release() 127 } 128 129 curr := loadAutoIncValue(a.sequences, tbl) 130 131 if given == 0 { 132 // |given| is 0 or NULL 133 a.sequences.Store(tbl, curr+1) 134 return curr, nil 135 } 136 137 if given >= curr { 138 a.sequences.Store(tbl, given+1) 139 return given, nil 140 } 141 142 // |given| < curr 143 return given, nil 144 } 145 146 func (a AutoIncrementTracker) CoerceAutoIncrementValue(val interface{}) (uint64, error) { 147 return CoerceAutoIncrementValue(val) 148 } 149 150 // CoerceAutoIncrementValue converts |val| into an AUTO_INCREMENT sequence value 151 func CoerceAutoIncrementValue(val interface{}) (uint64, error) { 152 switch typ := val.(type) { 153 case float32: 154 val = math.Round(float64(typ)) 155 case float64: 156 val = math.Round(typ) 157 } 158 159 var err error 160 val, _, err = gmstypes.Uint64.Convert(val) 161 if err != nil { 162 return 0, err 163 } 164 if val == nil || val == uint64(0) { 165 return 0, nil 166 } 167 return val.(uint64), nil 168 } 169 170 // Set sets the auto increment value for the table named, if it's greater than the one already registered for this 171 // table. Otherwise, the update is silently disregarded. So far this matches the MySQL behavior, but Dolt uses the 172 // maximum value for this table across all branches. 173 func (a AutoIncrementTracker) Set(ctx *sql.Context, tableName string, table *doltdb.Table, ws ref.WorkingSetRef, newAutoIncVal uint64) (*doltdb.Table, error) { 174 tableName = strings.ToLower(tableName) 175 176 release := a.mm.Lock(tableName) 177 defer release() 178 179 existing := loadAutoIncValue(a.sequences, tableName) 180 if newAutoIncVal > existing { 181 a.sequences.Store(tableName, newAutoIncVal) 182 return table.SetAutoIncrementValue(ctx, newAutoIncVal) 183 } else { 184 // If the value is not greater than the current tracker, we have more work to do 185 return a.deepSet(ctx, tableName, table, ws, newAutoIncVal) 186 } 187 } 188 189 // deepSet sets the auto increment value for the table named, if it's greater than the one on any branch head for this 190 // database, ignoring the current in-memory tracker value 191 func (a AutoIncrementTracker) deepSet(ctx *sql.Context, tableName string, table *doltdb.Table, ws ref.WorkingSetRef, newAutoIncVal uint64) (*doltdb.Table, error) { 192 sess := DSessFromSess(ctx.Session) 193 db, ok := sess.Provider().BaseDatabase(ctx, a.dbName) 194 195 // just give up if we can't find this db for any reason, or it's a non-versioned DB 196 if !ok || !db.Versioned() { 197 return table, nil 198 } 199 200 // First, establish whether to update this table based on the given value and its current max value. 201 sch, err := table.GetSchema(ctx) 202 if err != nil { 203 return nil, err 204 } 205 206 aiCol, ok := schema.GetAutoIncrementColumn(sch) 207 if !ok { 208 return nil, nil 209 } 210 211 var indexData durable.Index 212 aiIndex, ok := sch.Indexes().GetIndexByColumnNames(aiCol.Name) 213 if ok { 214 indexes, err := table.GetIndexSet(ctx) 215 if err != nil { 216 return nil, err 217 } 218 219 indexData, err = indexes.GetIndex(ctx, sch, aiIndex.Name()) 220 if err != nil { 221 return nil, err 222 } 223 } else { 224 indexData, err = table.GetRowData(ctx) 225 if err != nil { 226 return nil, err 227 } 228 } 229 230 currentMax, err := getMaxIndexValue(ctx, indexData) 231 if err != nil { 232 return nil, err 233 } 234 235 // If the given value is less than the current one, the operation is a no-op, bail out early 236 if newAutoIncVal <= currentMax { 237 return table, nil 238 } 239 240 table, err = table.SetAutoIncrementValue(ctx, newAutoIncVal) 241 if err != nil { 242 return nil, err 243 } 244 245 // Now that we have established the current max for this table, reset the global max accordingly 246 maxAutoInc := newAutoIncVal 247 doltdbs := db.DoltDatabases() 248 for _, db := range doltdbs { 249 branches, err := db.GetBranches(ctx) 250 if err != nil { 251 return nil, err 252 } 253 254 remotes, err := db.GetRemoteRefs(ctx) 255 if err != nil { 256 return nil, err 257 } 258 259 rootRefs := make([]ref.DoltRef, 0, len(branches)+len(remotes)) 260 rootRefs = append(rootRefs, branches...) 261 rootRefs = append(rootRefs, remotes...) 262 263 for _, b := range rootRefs { 264 var rootish doltdb.Rootish 265 switch b.GetType() { 266 case ref.BranchRefType: 267 wsRef, err := ref.WorkingSetRefForHead(b) 268 if err != nil { 269 return nil, err 270 } 271 272 if wsRef == ws { 273 // we don't need to check the working set we're updating 274 continue 275 } 276 277 ws, err := db.ResolveWorkingSet(ctx, wsRef) 278 if err == doltdb.ErrWorkingSetNotFound { 279 // use the branch head if there isn't a working set for it 280 cm, err := db.ResolveCommitRef(ctx, b) 281 if err != nil { 282 return nil, err 283 } 284 rootish = cm 285 } else if err != nil { 286 return nil, err 287 } else { 288 rootish = ws 289 } 290 case ref.RemoteRefType: 291 cm, err := db.ResolveCommitRef(ctx, b) 292 if err != nil { 293 return nil, err 294 } 295 rootish = cm 296 } 297 298 root, err := rootish.ResolveRootValue(ctx) 299 if err != nil { 300 return nil, err 301 } 302 303 table, _, ok, err := doltdb.GetTableInsensitive(ctx, root, tableName) 304 if err != nil { 305 return nil, err 306 } 307 if !ok { 308 continue 309 } 310 311 sch, err := table.GetSchema(ctx) 312 if err != nil { 313 return nil, err 314 } 315 316 if !schema.HasAutoIncrement(sch) { 317 continue 318 } 319 320 tableName = strings.ToLower(tableName) 321 seq, err := table.GetAutoIncrementValue(ctx) 322 if err != nil { 323 return nil, err 324 } 325 326 if seq > maxAutoInc { 327 maxAutoInc = seq 328 } 329 } 330 } 331 332 a.sequences.Store(tableName, maxAutoInc) 333 return table, nil 334 } 335 336 func getMaxIndexValue(ctx context.Context, indexData durable.Index) (uint64, error) { 337 if types.IsFormat_DOLT(indexData.Format()) { 338 idx := durable.ProllyMapFromIndex(indexData) 339 340 iter, err := idx.IterAllReverse(ctx) 341 if err != nil { 342 return 0, err 343 } 344 345 kd, _ := idx.Descriptors() 346 k, _, err := iter.Next(ctx) 347 if err == io.EOF { 348 return 0, nil 349 } else if err != nil { 350 return 0, err 351 } 352 353 // TODO: is the auto-inc column always the first column in the index? 354 field, err := tree.GetField(ctx, kd, 0, k, idx.NodeStore()) 355 if err != nil { 356 return 0, err 357 } 358 359 maxVal, err := CoerceAutoIncrementValue(field) 360 if err != nil { 361 return 0, err 362 } 363 364 return maxVal, nil 365 } 366 367 // For an LD format table, this operation won't succeed 368 return math.MaxUint64, nil 369 } 370 371 // AddNewTable initializes a new table with an auto increment column to the tracker, as necessary 372 func (a AutoIncrementTracker) AddNewTable(tableName string) { 373 tableName = strings.ToLower(tableName) 374 // only initialize the sequence for this table if no other branch has such a table 375 a.sequences.LoadOrStore(tableName, uint64(1)) 376 } 377 378 // DropTable drops the table with the name given. 379 // To establish the new auto increment value, callers must also pass all other working sets in scope that may include 380 // a table with the same name, omitting the working set that just deleted the table named. 381 func (a AutoIncrementTracker) DropTable(ctx *sql.Context, tableName string, wses ...*doltdb.WorkingSet) error { 382 tableName = strings.ToLower(tableName) 383 384 release := a.mm.Lock(tableName) 385 defer release() 386 387 newHighestValue := uint64(1) 388 389 // Get the new highest value from all tables in the working sets given 390 for _, ws := range wses { 391 table, _, exists, err := doltdb.GetTableInsensitive(ctx, ws.WorkingRoot(), tableName) 392 if err != nil { 393 return err 394 } 395 396 if !exists { 397 continue 398 } 399 400 sch, err := table.GetSchema(ctx) 401 if err != nil { 402 return err 403 } 404 405 if schema.HasAutoIncrement(sch) { 406 seq, err := table.GetAutoIncrementValue(ctx) 407 if err != nil { 408 return err 409 } 410 411 if seq > newHighestValue { 412 newHighestValue = seq 413 } 414 } 415 } 416 417 a.sequences.Store(tableName, newHighestValue) 418 419 return nil 420 } 421 422 func (a *AutoIncrementTracker) AcquireTableLock(ctx *sql.Context, tableName string) (func(), error) { 423 _, i, _ := sql.SystemVariables.GetGlobal("innodb_autoinc_lock_mode") 424 lockMode := LockMode(i.(int64)) 425 if lockMode == LockMode_Interleaved { 426 panic("Attempted to acquire AutoInc lock for entire insert operation, but lock mode was set to Interleaved") 427 } 428 a.lockMode = lockMode 429 return a.mm.Lock(tableName), nil 430 }