
     1  // Copyright 2023 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
//     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    15  package memory
    17  import (
    18  	"fmt"
    19  	"sort"
    20  	"strconv"
    22  	""
    24  	""
    25  	""
    26  	""
    27  	""
    28  )
    30  // TableData encapsulates all schema and data for a table's schema and rows. Other aspects of a table can change
    31  // freely as needed for different views on a table (column projections, index lookups, filters, etc.) but the
    32  // storage of underlying data lives here.
    33  type TableData struct {
    34  	dbName    string
    35  	tableName string
    36  	comment   string
    38  	// Schema / config data
    39  	schema                  sql.PrimaryKeySchema
    40  	indexes                 map[string]sql.Index
    41  	fkColl                  *ForeignKeyCollection
    42  	checks                  []sql.CheckDefinition
    43  	collation               sql.CollationID
    44  	autoColIdx              int
    45  	primaryKeyIndexes       bool
    46  	fullTextConfigTableName string
    48  	// Data storage
    49  	partitions    map[string][]sql.Row
    50  	partitionKeys [][]byte
    51  	autoIncVal    uint64
    53  	// Indexes are implemented as an unordered slice of rows. The first N elements in the row are the values of the
    54  	// indexed columns, and the final value is the location of the row in the primary storage.
    55  	// TODO: we could make these much more performant by using a tree or other ordered collection
    56  	secondaryIndexStorage map[indexName][]sql.Row
    57  }
    59  type indexName string
    61  // primaryRowLocation is a special marker element in index storage rows containing the partition and index of the row
    62  // in the primary storage.
    63  type primaryRowLocation struct {
    64  	partition string
    65  	idx       int
    66  }
    68  // Table returns a table with this data
    69  func (td TableData) Table(database *BaseDatabase) *Table {
    70  	return &Table{
    71  		db:               database,
    72  		name:             td.tableName,
    73  		data:             &td,
    74  		pkIndexesEnabled: td.primaryKeyIndexes,
    75  	}
    76  }
    78  func (td TableData) copy() *TableData {
    79  	sch := td.schema.Schema.Copy()
    80  	pkSch := sql.NewPrimaryKeySchema(sch, td.schema.PkOrdinals...)
    81  	td.schema = pkSch
    83  	parts := make(map[string][]sql.Row, len(td.partitions))
    84  	for k, v := range td.partitions {
    85  		data := make([]sql.Row, len(v))
    86  		copy(data, v)
    87  		parts[k] = data
    88  	}
    90  	keys := make([][]byte, len(td.partitionKeys))
    91  	for i := range td.partitionKeys {
    92  		keys[i] = make([]byte, len(td.partitionKeys[i]))
    93  		copy(keys[i], td.partitionKeys[i])
    94  	}
    96  	idxStorage := make(map[indexName][]sql.Row, len(td.secondaryIndexStorage))
    97  	for k, v := range td.secondaryIndexStorage {
    98  		data := make([]sql.Row, len(v))
    99  		copy(data, v)
   100  		idxStorage[k] = data
   101  	}
   102  	td.secondaryIndexStorage = idxStorage
   104  	td.partitionKeys, td.partitions = keys, parts
   106  	if td.checks != nil {
   107  		checks := make([]sql.CheckDefinition, len(td.checks))
   108  		copy(checks, td.checks)
   109  		td.checks = checks
   110  	}
   112  	return &td
   113  }
   115  // partition returns the partition for the row given. Uses the primary key columns if they exist, or all columns
   116  // otherwise
   117  func (td TableData) partition(row sql.Row) (int, error) {
   118  	var keyColumns []int
   119  	if len(td.schema.PkOrdinals) > 0 {
   120  		keyColumns = td.schema.PkOrdinals
   121  	} else {
   122  		keyColumns = make([]int, len(td.schema.Schema))
   123  		for i := range keyColumns {
   124  			keyColumns[i] = i
   125  		}
   126  	}
   128  	hash := xxhash.New()
   129  	var err error
   130  	for i := range keyColumns {
   131  		v := row[keyColumns[i]]
   132  		if i > 0 {
   133  			// separate each column with a null byte
   134  			if _, err = hash.Write([]byte{0}); err != nil {
   135  				return 0, err
   136  			}
   137  		}
   139  		t, isStringType := td.schema.Schema[i].Type.(sql.StringType)
   140  		if isStringType && v != nil {
   141  			v, err = types.ConvertToString(v, t)
   142  			if err == nil {
   143  				err = t.Collation().WriteWeightString(hash, v.(string))
   144  			}
   145  		} else {
   146  			_, err = fmt.Fprintf(hash, "%v", v)
   147  		}
   148  		if err != nil {
   149  			return 0, err
   150  		}
   151  	}
   153  	sum64 := hash.Sum64()
   154  	return int(sum64 % uint64(len(td.partitionKeys))), nil
   155  }
   157  func (td *TableData) truncate(schema sql.PrimaryKeySchema) *TableData {
   158  	var keys [][]byte
   159  	var partitions = map[string][]sql.Row{}
   160  	numParts := len(td.partitionKeys)
   162  	for i := 0; i < numParts; i++ {
   163  		key := strconv.Itoa(i)
   164  		keys = append(keys, []byte(key))
   165  		partitions[key] = []sql.Row{}
   166  	}
   168  	td.partitionKeys = keys
   169  	td.partitions = partitions
   170  	td.schema = schema
   172  	td.indexes = rewriteIndexes(td.indexes, schema)
   173  	td.secondaryIndexStorage = make(map[indexName][]sql.Row)
   175  	td.autoIncVal = 0
   176  	if schema.HasAutoIncrement() {
   177  		td.autoIncVal = 1
   178  	}
   180  	return td
   181  }
   183  // rewriteIndexes returns a new set of indexes appropriate for the new schema provided. Index expressions are adjusted
   184  // as necessary, and any indexes for columns that no longer exist are removed from the set.
   185  func rewriteIndexes(indexes map[string]sql.Index, schema sql.PrimaryKeySchema) map[string]sql.Index {
   186  	newIdxes := make(map[string]sql.Index)
   187  	for name, idx := range indexes {
   188  		newIdx := rewriteIndex(idx.(*Index), schema)
   189  		if newIdx != nil {
   190  			newIdxes[name] = newIdx
   191  		}
   192  	}
   193  	return newIdxes
   194  }
   196  // rewriteIndex returns a new index appropriate for the new schema provided, or nil if no columns remain to be indexed
   197  // in the schema
   198  func rewriteIndex(idx *Index, schema sql.PrimaryKeySchema) *Index {
   199  	var newExprs []sql.Expression
   200  	for _, expr := range idx.Exprs {
   201  		newE, _, _ := transform.Expr(expr, func(e sql.Expression) (sql.Expression, transform.TreeIdentity, error) {
   202  			if gf, ok := e.(*expression.GetField); ok {
   203  				newIdx := schema.IndexOfColName(gf.Name())
   204  				if newIdx < 0 {
   205  					return nil, transform.SameTree, nil
   206  				}
   207  				return gf.WithIndex(newIdx), transform.NewTree, nil
   208  			}
   210  			return e, transform.SameTree, nil
   211  		})
   212  		if newE != nil {
   213  			newExprs = append(newExprs, newE)
   214  		}
   215  	}
   217  	if len(newExprs) == 0 {
   218  		return nil
   219  	}
   221  	newIdx := *idx
   222  	newIdx.Exprs = newExprs
   223  	return &newIdx
   224  }
   226  func (td *TableData) columnIndexes(colNames []string) ([]int, error) {
   227  	columns := make([]int, 0, len(colNames))
   229  	for _, name := range colNames {
   230  		i := td.schema.IndexOf(name, td.tableName)
   231  		if i == -1 {
   232  			return nil, errColumnNotFound.New(name)
   233  		}
   235  		columns = append(columns, i)
   236  	}
   238  	return columns, nil
   239  }
   241  // toStorageRow returns the given row normalized for storage, omitting virtual columns
   242  func (td *TableData) toStorageRow(row sql.Row) sql.Row {
   243  	if !td.schema.HasVirtualColumns() {
   244  		return row
   245  	}
   247  	storageRow := make(sql.Row, len(td.schema.Schema))
   248  	storageRowIdx := 0
   249  	for i, col := range td.schema.Schema {
   250  		if col.Virtual {
   251  			continue
   252  		}
   253  		storageRow[storageRowIdx] = row[i]
   254  		storageRowIdx++
   255  	}
   257  	return storageRow[:storageRowIdx]
   258  }
   260  func (td *TableData) numRows(ctx *sql.Context) (uint64, error) {
   261  	var count uint64
   262  	for _, rows := range td.partitions {
   263  		count += uint64(len(rows))
   264  	}
   266  	return count, nil
   267  }
   269  // throws an error if any two or more rows share the same |cols| values.
   270  func (td *TableData) errIfDuplicateEntryExist(cols []string, idxName string) error {
   271  	columnMapping, err := td.columnIndexes(cols)
   272  	if err != nil {
   273  		return err
   274  	}
   275  	unique := make(map[uint64]struct{})
   276  	for _, partition := range td.partitions {
   277  		for _, row := range partition {
   278  			idxPrefixKey := projectOnRow(columnMapping, row)
   279  			if hasNulls(idxPrefixKey) {
   280  				continue
   281  			}
   282  			h, err := sql.HashOf(idxPrefixKey)
   283  			if err != nil {
   284  				return err
   285  			}
   286  			if _, ok := unique[h]; ok {
   287  				return sql.NewUniqueKeyErr(formatRow(row, columnMapping), false, nil)
   288  			}
   289  			unique[h] = struct{}{}
   290  		}
   291  	}
   292  	return nil
   293  }
   295  func hasNulls(row sql.Row) bool {
   296  	for _, v := range row {
   297  		if v == nil {
   298  			return true
   299  		}
   300  	}
   301  	return false
   302  }
   304  // getColumnOrdinal returns the index in the schema and column with the name given, if it exists, or -1, nil otherwise.
   305  func (td *TableData) getColumnOrdinal(col string) (int, *sql.Column) {
   306  	i := td.schema.IndexOf(col, td.tableName)
   307  	if i == -1 {
   308  		return -1, nil
   309  	}
   311  	return i, td.schema.Schema[i]
   312  }
   314  func (td *TableData) generateCheckName() string {
   315  	i := 1
   316  Top:
   317  	for {
   318  		name := fmt.Sprintf("%s_chk_%d", td.tableName, i)
   319  		for _, check := range td.checks {
   320  			if check.Name == name {
   321  				i++
   322  				continue Top
   323  			}
   324  		}
   325  		return name
   326  	}
   327  }
   329  func (td *TableData) indexColsForTableEditor() ([][]int, [][]uint16) {
   330  	var uniqIdxCols [][]int
   331  	var prefixLengths [][]uint16
   332  	for _, idx := range td.indexes {
   333  		if !idx.IsUnique() {
   334  			continue
   335  		}
   336  		var colNames []string
   337  		expressions := idx.(*Index).Exprs
   338  		for _, exp := range expressions {
   339  			colNames = append(colNames, exp.(*expression.GetField).Name())
   340  		}
   341  		colIdxs, err := td.columnIndexes(colNames)
   342  		if err != nil {
   343  			// this means that the column names in this index aren't in the schema, which can happen in the case of a
   344  			// table rewrite
   345  			continue
   346  		}
   347  		uniqIdxCols = append(uniqIdxCols, colIdxs)
   348  		prefixLengths = append(prefixLengths, idx.PrefixLengths())
   349  	}
   350  	return uniqIdxCols, prefixLengths
   351  }
   353  // Sorts the rows in the partitions of the table to be in primary key order.
   354  func (td *TableData) sortRows() {
   355  	var pk []pkfield
   356  	for _, column := range td.schema.Schema {
   357  		if column.PrimaryKey {
   358  			idx, col := td.getColumnOrdinal(column.Name)
   359  			pk = append(pk, pkfield{idx, col})
   360  		}
   361  	}
   363  	var flattenedRows []partitionRow
   364  	for _, k := range td.partitionKeys {
   365  		p := td.partitions[string(k)]
   366  		for i := 0; i < len(p); i++ {
   367  			flattenedRows = append(flattenedRows, partitionRow{string(k), i})
   368  		}
   369  	}
   371  	sort.Sort(partitionssort{
   372  		pk:      pk,
   373  		ps:      td.partitions,
   374  		allRows: flattenedRows,
   375  		indexes: td.secondaryIndexStorage,
   376  	})
   378  	td.sortSecondaryIndexes()
   379  }
   381  func (td *TableData) sortSecondaryIndexes() {
   382  	for idxName, idxStorage := range td.secondaryIndexStorage {
   383  		idx := td.indexes[string(idxName)].(*Index)
   384  		fieldIndexes := idx.columnIndexes(td.schema.Schema)
   385  		types := make([]sql.Type, len(fieldIndexes))
   386  		for i, idx := range fieldIndexes {
   387  			types[i] = td.schema.Schema[idx].Type
   388  		}
   389  		sort.Slice(idxStorage, func(i, j int) bool {
   390  			for t, typ := range types {
   391  				left := idxStorage[i][t]
   392  				right := idxStorage[j][t]
   394  				// Compare doesn't handle nil values, so we need to handle that case. Nils sort before other values
   395  				if left == nil {
   396  					if right == nil {
   397  						continue
   398  					} else {
   399  						return true
   400  					}
   401  				} else if right == nil {
   402  					return false
   403  				}
   405  				compare, err := typ.Compare(left, right)
   406  				if err != nil {
   407  					panic(err)
   408  				}
   409  				if compare != 0 {
   410  					return compare < 0
   411  				}
   412  			}
   413  			return false
   414  		})
   415  	}
   416  }
   418  func (td TableData) virtualColIndexes() []int {
   419  	var indexes []int
   420  	for i, col := range td.schema.Schema {
   421  		if col.Virtual {
   422  			indexes = append(indexes, i)
   423  		}
   424  	}
   425  	return indexes
   426  }
   428  func insertValueInRows(ctx *sql.Context, data *TableData, colIdx int, colDefault *sql.ColumnDefaultValue) error {
   429  	for k, p := range data.partitions {
   430  		newP := make([]sql.Row, len(p))
   431  		for i, row := range p {
   432  			var newRow sql.Row
   433  			newRow = append(newRow, row[:colIdx]...)
   434  			newRow = append(newRow, nil)
   435  			newRow = append(newRow, row[colIdx:]...)
   436  			var err error
   437  			if !data.schema.Schema[colIdx].Nullable && colDefault == nil {
   438  				newRow[colIdx] = data.schema.Schema[colIdx].Type.Zero()
   439  			} else {
   440  				newRow[colIdx], err = colDefault.Eval(ctx, newRow)
   441  				if err != nil {
   442  					return err
   443  				}
   444  			}
   445  			newP[i] = newRow
   446  		}
   447  		data.partitions[k] = newP
   448  	}
   449  	return nil
   450  }