github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/libraries/doltcore/sqle/dsess/autoincrement_tracker.go (about)

     1  // Copyright 2023 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package dsess
    16  
    17  import (
    18  	"context"
    19  	"io"
    20  	"math"
    21  	"strings"
    22  	"sync"
    23  
    24  	"github.com/dolthub/go-mysql-server/sql"
    25  	gmstypes "github.com/dolthub/go-mysql-server/sql/types"
    26  
    27  	"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
    28  	"github.com/dolthub/dolt/go/libraries/doltcore/doltdb/durable"
    29  	"github.com/dolthub/dolt/go/libraries/doltcore/ref"
    30  	"github.com/dolthub/dolt/go/libraries/doltcore/schema"
    31  	"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess/mutexmap"
    32  	"github.com/dolthub/dolt/go/libraries/doltcore/sqle/globalstate"
    33  	"github.com/dolthub/dolt/go/store/prolly/tree"
    34  	"github.com/dolthub/dolt/go/store/types"
    35  )
    36  
    37  type LockMode int64
    38  
    39  var (
    40  	LockMode_Traditional LockMode = 0
    41  	LockMode_Concurret   LockMode = 1
    42  	LockMode_Interleaved LockMode = 2
    43  )
    44  
// AutoIncrementTracker tracks the next AUTO_INCREMENT value for each table of a single database. A table's value is
// tracked across all branches of the database (see NewAutoIncrementTracker), not per branch.
type AutoIncrementTracker struct {
	// dbName is the name of the database whose tables are tracked.
	dbName    string
	// sequences holds, per lowercased table name, the next value to be generated for that table (stored as uint64).
	sequences *sync.Map // map[string]uint64
	// mm provides per-table locks, keyed by table name (Set, Next and DropTable lowercase the name before locking).
	mm        *mutexmap.MutexMap
	// lockMode is the lock mode last recorded by AcquireTableLock; Next takes the per-table lock itself only when
	// this equals LockMode_Interleaved.
	lockMode  LockMode
}

// Compile-time check that *AutoIncrementTracker satisfies the globalstate interface.
var _ globalstate.AutoIncrementTracker = &AutoIncrementTracker{}
    53  
    54  // NewAutoIncrementTracker returns a new autoincrement tracker for the roots given. All roots sets must be
    55  // considered because the auto increment value for a table is tracked globally, across all branches.
    56  // Roots provided should be the working sets when available, or the branches when they are not (e.g. for remote
    57  // branches that don't have a local working set)
    58  func NewAutoIncrementTracker(ctx context.Context, dbName string, roots ...doltdb.Rootish) (*AutoIncrementTracker, error) {
    59  	ait := AutoIncrementTracker{
    60  		dbName:    dbName,
    61  		sequences: &sync.Map{},
    62  		mm:        mutexmap.NewMutexMap(),
    63  	}
    64  
    65  	for _, root := range roots {
    66  		root, err := root.ResolveRootValue(ctx)
    67  		if err != nil {
    68  			return &AutoIncrementTracker{}, err
    69  		}
    70  
    71  		err = root.IterTables(ctx, func(tableName string, table *doltdb.Table, sch schema.Schema) (bool, error) {
    72  			ok := schema.HasAutoIncrement(sch)
    73  			if !ok {
    74  				return false, nil
    75  			}
    76  
    77  			tableName = strings.ToLower(tableName)
    78  
    79  			seq, err := table.GetAutoIncrementValue(ctx)
    80  			if err != nil {
    81  				return true, err
    82  			}
    83  
    84  			oldValue, loaded := ait.sequences.LoadOrStore(tableName, seq)
    85  			if loaded && seq > oldValue.(uint64) {
    86  				ait.sequences.Store(tableName, seq)
    87  			}
    88  
    89  			return false, nil
    90  		})
    91  
    92  		if err != nil {
    93  			return &AutoIncrementTracker{}, err
    94  		}
    95  	}
    96  
    97  	return &ait, nil
    98  }
    99  
   100  func loadAutoIncValue(sequences *sync.Map, tableName string) uint64 {
   101  	tableName = strings.ToLower(tableName)
   102  	current, hasCurrent := sequences.Load(tableName)
   103  	if !hasCurrent {
   104  		return 0
   105  	}
   106  	return current.(uint64)
   107  }
   108  
   109  // Current returns the next value to be generated in the auto increment sequence for the table named
   110  func (a AutoIncrementTracker) Current(tableName string) uint64 {
   111  	return loadAutoIncValue(a.sequences, tableName)
   112  }
   113  
   114  // Next returns the next auto increment value for the table named using the provided value from an insert (which may
   115  // be null or 0, in which case it will be generated from the sequence).
   116  func (a AutoIncrementTracker) Next(tbl string, insertVal interface{}) (uint64, error) {
   117  	tbl = strings.ToLower(tbl)
   118  
   119  	given, err := CoerceAutoIncrementValue(insertVal)
   120  	if err != nil {
   121  		return 0, err
   122  	}
   123  
   124  	if a.lockMode == LockMode_Interleaved {
   125  		release := a.mm.Lock(tbl)
   126  		defer release()
   127  	}
   128  
   129  	curr := loadAutoIncValue(a.sequences, tbl)
   130  
   131  	if given == 0 {
   132  		// |given| is 0 or NULL
   133  		a.sequences.Store(tbl, curr+1)
   134  		return curr, nil
   135  	}
   136  
   137  	if given >= curr {
   138  		a.sequences.Store(tbl, given+1)
   139  		return given, nil
   140  	}
   141  
   142  	// |given| < curr
   143  	return given, nil
   144  }
   145  
   146  func (a AutoIncrementTracker) CoerceAutoIncrementValue(val interface{}) (uint64, error) {
   147  	return CoerceAutoIncrementValue(val)
   148  }
   149  
   150  // CoerceAutoIncrementValue converts |val| into an AUTO_INCREMENT sequence value
   151  func CoerceAutoIncrementValue(val interface{}) (uint64, error) {
   152  	switch typ := val.(type) {
   153  	case float32:
   154  		val = math.Round(float64(typ))
   155  	case float64:
   156  		val = math.Round(typ)
   157  	}
   158  
   159  	var err error
   160  	val, _, err = gmstypes.Uint64.Convert(val)
   161  	if err != nil {
   162  		return 0, err
   163  	}
   164  	if val == nil || val == uint64(0) {
   165  		return 0, nil
   166  	}
   167  	return val.(uint64), nil
   168  }
   169  
   170  // Set sets the auto increment value for the table named, if it's greater than the one already registered for this
   171  // table. Otherwise, the update is silently disregarded. So far this matches the MySQL behavior, but Dolt uses the
   172  // maximum value for this table across all branches.
   173  func (a AutoIncrementTracker) Set(ctx *sql.Context, tableName string, table *doltdb.Table, ws ref.WorkingSetRef, newAutoIncVal uint64) (*doltdb.Table, error) {
   174  	tableName = strings.ToLower(tableName)
   175  
   176  	release := a.mm.Lock(tableName)
   177  	defer release()
   178  
   179  	existing := loadAutoIncValue(a.sequences, tableName)
   180  	if newAutoIncVal > existing {
   181  		a.sequences.Store(tableName, newAutoIncVal)
   182  		return table.SetAutoIncrementValue(ctx, newAutoIncVal)
   183  	} else {
   184  		// If the value is not greater than the current tracker, we have more work to do
   185  		return a.deepSet(ctx, tableName, table, ws, newAutoIncVal)
   186  	}
   187  }
   188  
   189  // deepSet sets the auto increment value for the table named, if it's greater than the one on any branch head for this
   190  // database, ignoring the current in-memory tracker value
   191  func (a AutoIncrementTracker) deepSet(ctx *sql.Context, tableName string, table *doltdb.Table, ws ref.WorkingSetRef, newAutoIncVal uint64) (*doltdb.Table, error) {
   192  	sess := DSessFromSess(ctx.Session)
   193  	db, ok := sess.Provider().BaseDatabase(ctx, a.dbName)
   194  
   195  	// just give up if we can't find this db for any reason, or it's a non-versioned DB
   196  	if !ok || !db.Versioned() {
   197  		return table, nil
   198  	}
   199  
   200  	// First, establish whether to update this table based on the given value and its current max value.
   201  	sch, err := table.GetSchema(ctx)
   202  	if err != nil {
   203  		return nil, err
   204  	}
   205  
   206  	aiCol, ok := schema.GetAutoIncrementColumn(sch)
   207  	if !ok {
   208  		return nil, nil
   209  	}
   210  
   211  	var indexData durable.Index
   212  	aiIndex, ok := sch.Indexes().GetIndexByColumnNames(aiCol.Name)
   213  	if ok {
   214  		indexes, err := table.GetIndexSet(ctx)
   215  		if err != nil {
   216  			return nil, err
   217  		}
   218  
   219  		indexData, err = indexes.GetIndex(ctx, sch, aiIndex.Name())
   220  		if err != nil {
   221  			return nil, err
   222  		}
   223  	} else {
   224  		indexData, err = table.GetRowData(ctx)
   225  		if err != nil {
   226  			return nil, err
   227  		}
   228  	}
   229  
   230  	currentMax, err := getMaxIndexValue(ctx, indexData)
   231  	if err != nil {
   232  		return nil, err
   233  	}
   234  
   235  	// If the given value is less than the current one, the operation is a no-op, bail out early
   236  	if newAutoIncVal <= currentMax {
   237  		return table, nil
   238  	}
   239  
   240  	table, err = table.SetAutoIncrementValue(ctx, newAutoIncVal)
   241  	if err != nil {
   242  		return nil, err
   243  	}
   244  
   245  	// Now that we have established the current max for this table, reset the global max accordingly
   246  	maxAutoInc := newAutoIncVal
   247  	doltdbs := db.DoltDatabases()
   248  	for _, db := range doltdbs {
   249  		branches, err := db.GetBranches(ctx)
   250  		if err != nil {
   251  			return nil, err
   252  		}
   253  
   254  		remotes, err := db.GetRemoteRefs(ctx)
   255  		if err != nil {
   256  			return nil, err
   257  		}
   258  
   259  		rootRefs := make([]ref.DoltRef, 0, len(branches)+len(remotes))
   260  		rootRefs = append(rootRefs, branches...)
   261  		rootRefs = append(rootRefs, remotes...)
   262  
   263  		for _, b := range rootRefs {
   264  			var rootish doltdb.Rootish
   265  			switch b.GetType() {
   266  			case ref.BranchRefType:
   267  				wsRef, err := ref.WorkingSetRefForHead(b)
   268  				if err != nil {
   269  					return nil, err
   270  				}
   271  
   272  				if wsRef == ws {
   273  					// we don't need to check the working set we're updating
   274  					continue
   275  				}
   276  
   277  				ws, err := db.ResolveWorkingSet(ctx, wsRef)
   278  				if err == doltdb.ErrWorkingSetNotFound {
   279  					// use the branch head if there isn't a working set for it
   280  					cm, err := db.ResolveCommitRef(ctx, b)
   281  					if err != nil {
   282  						return nil, err
   283  					}
   284  					rootish = cm
   285  				} else if err != nil {
   286  					return nil, err
   287  				} else {
   288  					rootish = ws
   289  				}
   290  			case ref.RemoteRefType:
   291  				cm, err := db.ResolveCommitRef(ctx, b)
   292  				if err != nil {
   293  					return nil, err
   294  				}
   295  				rootish = cm
   296  			}
   297  
   298  			root, err := rootish.ResolveRootValue(ctx)
   299  			if err != nil {
   300  				return nil, err
   301  			}
   302  
   303  			table, _, ok, err := doltdb.GetTableInsensitive(ctx, root, tableName)
   304  			if err != nil {
   305  				return nil, err
   306  			}
   307  			if !ok {
   308  				continue
   309  			}
   310  
   311  			sch, err := table.GetSchema(ctx)
   312  			if err != nil {
   313  				return nil, err
   314  			}
   315  
   316  			if !schema.HasAutoIncrement(sch) {
   317  				continue
   318  			}
   319  
   320  			tableName = strings.ToLower(tableName)
   321  			seq, err := table.GetAutoIncrementValue(ctx)
   322  			if err != nil {
   323  				return nil, err
   324  			}
   325  
   326  			if seq > maxAutoInc {
   327  				maxAutoInc = seq
   328  			}
   329  		}
   330  	}
   331  
   332  	a.sequences.Store(tableName, maxAutoInc)
   333  	return table, nil
   334  }
   335  
   336  func getMaxIndexValue(ctx context.Context, indexData durable.Index) (uint64, error) {
   337  	if types.IsFormat_DOLT(indexData.Format()) {
   338  		idx := durable.ProllyMapFromIndex(indexData)
   339  
   340  		iter, err := idx.IterAllReverse(ctx)
   341  		if err != nil {
   342  			return 0, err
   343  		}
   344  
   345  		kd, _ := idx.Descriptors()
   346  		k, _, err := iter.Next(ctx)
   347  		if err == io.EOF {
   348  			return 0, nil
   349  		} else if err != nil {
   350  			return 0, err
   351  		}
   352  
   353  		// TODO: is the auto-inc column always the first column in the index?
   354  		field, err := tree.GetField(ctx, kd, 0, k, idx.NodeStore())
   355  		if err != nil {
   356  			return 0, err
   357  		}
   358  
   359  		maxVal, err := CoerceAutoIncrementValue(field)
   360  		if err != nil {
   361  			return 0, err
   362  		}
   363  
   364  		return maxVal, nil
   365  	}
   366  
   367  	// For an LD format table, this operation won't succeed
   368  	return math.MaxUint64, nil
   369  }
   370  
   371  // AddNewTable initializes a new table with an auto increment column to the tracker, as necessary
   372  func (a AutoIncrementTracker) AddNewTable(tableName string) {
   373  	tableName = strings.ToLower(tableName)
   374  	// only initialize the sequence for this table if no other branch has such a table
   375  	a.sequences.LoadOrStore(tableName, uint64(1))
   376  }
   377  
   378  // DropTable drops the table with the name given.
   379  // To establish the new auto increment value, callers must also pass all other working sets in scope that may include
   380  // a table with the same name, omitting the working set that just deleted the table named.
   381  func (a AutoIncrementTracker) DropTable(ctx *sql.Context, tableName string, wses ...*doltdb.WorkingSet) error {
   382  	tableName = strings.ToLower(tableName)
   383  
   384  	release := a.mm.Lock(tableName)
   385  	defer release()
   386  
   387  	newHighestValue := uint64(1)
   388  
   389  	// Get the new highest value from all tables in the working sets given
   390  	for _, ws := range wses {
   391  		table, _, exists, err := doltdb.GetTableInsensitive(ctx, ws.WorkingRoot(), tableName)
   392  		if err != nil {
   393  			return err
   394  		}
   395  
   396  		if !exists {
   397  			continue
   398  		}
   399  
   400  		sch, err := table.GetSchema(ctx)
   401  		if err != nil {
   402  			return err
   403  		}
   404  
   405  		if schema.HasAutoIncrement(sch) {
   406  			seq, err := table.GetAutoIncrementValue(ctx)
   407  			if err != nil {
   408  				return err
   409  			}
   410  
   411  			if seq > newHighestValue {
   412  				newHighestValue = seq
   413  			}
   414  		}
   415  	}
   416  
   417  	a.sequences.Store(tableName, newHighestValue)
   418  
   419  	return nil
   420  }
   421  
   422  func (a *AutoIncrementTracker) AcquireTableLock(ctx *sql.Context, tableName string) (func(), error) {
   423  	_, i, _ := sql.SystemVariables.GetGlobal("innodb_autoinc_lock_mode")
   424  	lockMode := LockMode(i.(int64))
   425  	if lockMode == LockMode_Interleaved {
   426  		panic("Attempted to acquire AutoInc lock for entire insert operation, but lock mode was set to Interleaved")
   427  	}
   428  	a.lockMode = lockMode
   429  	return a.mm.Lock(tableName), nil
   430  }