github.com/dolthub/go-mysql-server@v0.18.0/memory/table.go (about)

     1  // Copyright 2020-2022 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package memory
    16  
    17  import (
    18  	"bytes"
    19  	"encoding/gob"
    20  	"fmt"
    21  	"io"
    22  	"sort"
    23  	"strconv"
    24  	"strings"
    25  	"sync"
    26  
    27  	errors "gopkg.in/src-d/go-errors.v1"
    28  
    29  	"github.com/dolthub/go-mysql-server/sql"
    30  	"github.com/dolthub/go-mysql-server/sql/analyzer/analyzererrors"
    31  	"github.com/dolthub/go-mysql-server/sql/expression"
    32  	"github.com/dolthub/go-mysql-server/sql/fulltext"
    33  	"github.com/dolthub/go-mysql-server/sql/transform"
    34  	"github.com/dolthub/go-mysql-server/sql/types"
    35  )
    36  
    37  type MemTable interface {
    38  	sql.Table
    39  	IgnoreSessionData() bool
    40  	UnderlyingTable() *Table
    41  }
    42  
    43  // Table represents an in-memory database table.
    44  type Table struct {
    45  	name string
    46  
    47  	// Schema and related info
    48  	data *TableData
    49  	// ignoreSessionData is used to ignore session data for versioned tables (smoke tests only), unused otherwise
    50  	ignoreSessionData bool
    51  
    52  	// Projection info and settings
    53  	pkIndexesEnabled bool
    54  	projection       []string
    55  	projectedSchema  sql.Schema
    56  	columns          []int
    57  
    58  	// filters is used for primary index scans with an index lookup
    59  	filters []sql.Expression
    60  
    61  	db *BaseDatabase
    62  }
    63  
    64  var _ sql.Table = (*Table)(nil)
    65  var _ MemTable = (*Table)(nil)
    66  var _ sql.InsertableTable = (*Table)(nil)
    67  var _ sql.UpdatableTable = (*Table)(nil)
    68  var _ sql.DeletableTable = (*Table)(nil)
    69  var _ sql.CommentedTable = (*Table)(nil)
    70  var _ sql.ReplaceableTable = (*Table)(nil)
    71  var _ sql.TruncateableTable = (*Table)(nil)
    72  var _ sql.AlterableTable = (*Table)(nil)
    73  var _ sql.IndexAlterableTable = (*Table)(nil)
    74  var _ sql.CollationAlterableTable = (*Table)(nil)
    75  var _ sql.ForeignKeyTable = (*Table)(nil)
    76  var _ sql.CheckAlterableTable = (*Table)(nil)
    77  var _ sql.RewritableTable = (*Table)(nil)
    78  var _ sql.CheckTable = (*Table)(nil)
    79  var _ sql.AutoIncrementTable = (*Table)(nil)
    80  var _ sql.StatisticsTable = (*Table)(nil)
    81  var _ sql.ProjectedTable = (*Table)(nil)
    82  var _ sql.PrimaryKeyAlterableTable = (*Table)(nil)
    83  var _ sql.PrimaryKeyTable = (*Table)(nil)
    84  var _ fulltext.IndexAlterableTable = (*Table)(nil)
    85  var _ sql.IndexBuildingTable = (*Table)(nil)
    86  var _ sql.Databaseable = (*Table)(nil)
    87  
    88  // NewTable creates a new Table with the given name and schema. Assigns the default collation, therefore if a different
    89  // collation is desired, please use NewTableWithCollation.
    90  func NewTable(db MemoryDatabase, name string, schema sql.PrimaryKeySchema, fkColl *ForeignKeyCollection) *Table {
    91  	var baseDatabase *BaseDatabase
    92  	// the dual table has no database
    93  	if db != nil {
    94  		baseDatabase = db.Database()
    95  	}
    96  	return NewPartitionedTableWithCollation(baseDatabase, name, schema, fkColl, 0, sql.Collation_Default, "")
    97  }
    98  
    99  // NewLocalTable returns a table suitable to use for transient non-memory applications
   100  func NewLocalTable(db MemoryDatabase, name string, schema sql.PrimaryKeySchema, fkColl *ForeignKeyCollection) *Table {
   101  	var baseDatabase *BaseDatabase
   102  	// the dual table has no database
   103  	if db != nil {
   104  		baseDatabase = db.Database()
   105  	}
   106  	tbl := NewPartitionedTableWithCollation(baseDatabase, name, schema, fkColl, 0, sql.Collation_Default, "")
   107  	tbl.ignoreSessionData = true
   108  	return tbl
   109  }
   110  
   111  // NewTableWithCollation creates a new Table with the given name, schema, and collation.
   112  func NewTableWithCollation(db *BaseDatabase, name string, schema sql.PrimaryKeySchema, fkColl *ForeignKeyCollection, collation sql.CollationID) *Table {
   113  	return NewPartitionedTableWithCollation(db, name, schema, fkColl, 0, collation, "")
   114  }
   115  
   116  // NewPartitionedTable creates a new Table with the given name, schema and number of partitions. Assigns the default
   117  // collation, therefore if a different collation is desired, please use NewPartitionedTableWithCollation.
   118  func NewPartitionedTable(db *BaseDatabase, name string, schema sql.PrimaryKeySchema, fkColl *ForeignKeyCollection, numPartitions int) *Table {
   119  	return NewPartitionedTableWithCollation(db, name, schema, fkColl, numPartitions, sql.Collation_Default, "")
   120  }
   121  
   122  // NewPartitionedTable creates a new Table with the given name, schema and number of partitions. Assigns the default
   123  // collation, therefore if a different collation is desired, please use NewPartitionedTableWithCollation.
   124  func NewPartitionedTableRevision(db *BaseDatabase, name string, schema sql.PrimaryKeySchema, fkColl *ForeignKeyCollection, numPartitions int) *TableRevision {
   125  	tbl := NewPartitionedTableWithCollation(db, name, schema, fkColl, numPartitions, sql.Collation_Default, "")
   126  	tbl.ignoreSessionData = true
   127  	return &TableRevision{tbl}
   128  }
   129  
   130  func stripTblNames(e sql.Expression) (sql.Expression, transform.TreeIdentity, error) {
   131  	switch e := e.(type) {
   132  	case *expression.GetField:
   133  		// strip table names
   134  		ne := expression.NewGetField(e.Index(), e.Type(), e.Name(), e.IsNullable())
   135  		ne = ne.WithBackTickNames(e.IsBackTickNames())
   136  		return ne, transform.NewTree, nil
   137  	default:
   138  	}
   139  	return e, transform.SameTree, nil
   140  }
   141  
   142  // NewPartitionedTableWithCollation creates a new Table with the given name, schema, number of partitions, collation,
   143  // and comment.
   144  func NewPartitionedTableWithCollation(db *BaseDatabase, name string, schema sql.PrimaryKeySchema, fkColl *ForeignKeyCollection, numPartitions int, collation sql.CollationID, comment string) *Table {
   145  	var keys [][]byte
   146  	var partitions = map[string][]sql.Row{}
   147  
   148  	if numPartitions < 1 {
   149  		numPartitions = 1
   150  	}
   151  
   152  	for i := 0; i < numPartitions; i++ {
   153  		key := strconv.Itoa(i)
   154  		keys = append(keys, []byte(key))
   155  		partitions[key] = []sql.Row{}
   156  	}
   157  
   158  	var autoIncVal uint64
   159  	autoIncIdx := -1
   160  	for i, c := range schema.Schema {
   161  		if c.AutoIncrement {
   162  			autoIncVal = uint64(1)
   163  			autoIncIdx = i
   164  			break
   165  		}
   166  	}
   167  
   168  	newSchema := make(sql.Schema, len(schema.Schema))
   169  	for i, c := range schema.Schema {
   170  		cCopy := c.Copy()
   171  		if cCopy.Default != nil {
   172  			newDef, _, _ := transform.Expr(cCopy.Default, stripTblNames)
   173  			defStr := newDef.String()
   174  			unrDef := sql.NewUnresolvedColumnDefaultValue(defStr)
   175  			cCopy.Default = unrDef
   176  		}
   177  		if cCopy.Generated != nil {
   178  			newDef, _, _ := transform.Expr(cCopy.Generated, stripTblNames)
   179  			defStr := newDef.String()
   180  			unrDef := sql.NewUnresolvedColumnDefaultValue(defStr)
   181  			cCopy.Generated = unrDef
   182  		}
   183  		if cCopy.OnUpdate != nil {
   184  			newDef, _, _ := transform.Expr(cCopy.OnUpdate, stripTblNames)
   185  			defStr := newDef.String()
   186  			unrDef := sql.NewUnresolvedColumnDefaultValue(defStr)
   187  			cCopy.OnUpdate = unrDef
   188  		}
   189  		newSchema[i] = cCopy
   190  	}
   191  
   192  	schema.Schema = newSchema
   193  
   194  	// The dual table has a nil database
   195  	dbName := ""
   196  	if db != nil {
   197  		dbName = db.Name()
   198  	}
   199  	return &Table{
   200  		name: name,
   201  		data: &TableData{
   202  			dbName:                dbName,
   203  			tableName:             name,
   204  			comment:               comment,
   205  			schema:                schema,
   206  			fkColl:                fkColl,
   207  			collation:             collation,
   208  			partitions:            partitions,
   209  			partitionKeys:         keys,
   210  			autoIncVal:            autoIncVal,
   211  			autoColIdx:            autoIncIdx,
   212  			secondaryIndexStorage: make(map[indexName][]sql.Row),
   213  		},
   214  		db: db,
   215  	}
   216  }
   217  
   218  // Name implements the sql.Table interface.
   219  func (t *Table) Name() string {
   220  	return t.name
   221  }
   222  
   223  func (t *Table) Database() string {
   224  	return t.dbName()
   225  }
   226  
   227  // Schema implements the sql.Table interface.
   228  func (t *Table) Schema() sql.Schema {
   229  	if t.projectedSchema != nil {
   230  		return t.projectedSchema
   231  	}
   232  	return t.data.schema.Schema
   233  }
   234  
   235  // Collation implements the sql.Table interface.
   236  func (t *Table) Collation() sql.CollationID {
   237  	return t.data.collation
   238  }
   239  
   240  // Comment implements the sql.CommentedTable interface.
   241  func (t *Table) Comment() string {
   242  	return t.data.comment
   243  }
   244  
   245  func (t *Table) IgnoreSessionData() bool {
   246  	return t.ignoreSessionData
   247  }
   248  
   249  func (t *Table) UnderlyingTable() *Table {
   250  	return t
   251  }
   252  
   253  func (t *Table) GetPartition(key string) []sql.Row {
   254  	rows, ok := t.data.partitions[key]
   255  	if ok {
   256  		return rows
   257  	}
   258  
   259  	return nil
   260  }
   261  
   262  // Partitions implements the sql.Table interface.
   263  func (t *Table) Partitions(ctx *sql.Context) (sql.PartitionIter, error) {
   264  	data := t.sessionTableData(ctx)
   265  
   266  	var keys [][]byte
   267  	for _, k := range data.partitionKeys {
   268  		if rows, ok := data.partitions[string(k)]; ok && len(rows) > 0 {
   269  			keys = append(keys, k)
   270  		}
   271  	}
   272  	return &partitionIter{keys: keys}, nil
   273  }
   274  
   275  // rangePartitionIter returns a partition that has range and table data access
   276  type rangePartitionIter struct {
   277  	child  *partitionIter
   278  	ranges sql.Expression
   279  }
   280  
   281  var _ sql.PartitionIter = (*rangePartitionIter)(nil)
   282  
   283  func (i rangePartitionIter) Close(ctx *sql.Context) error {
   284  	return i.child.Close(ctx)
   285  }
   286  
   287  func (i rangePartitionIter) Next(ctx *sql.Context) (sql.Partition, error) {
   288  	part, err := i.child.Next(ctx)
   289  	if err != nil {
   290  		return nil, err
   291  	}
   292  	return &rangePartition{
   293  		Partition: part.(*Partition),
   294  		rang:      i.ranges,
   295  	}, nil
   296  }
   297  
   298  // indexScanPartitionIter is a partition iterator that returns a single partition for an index scan
   299  type indexScanPartitionIter struct {
   300  	once   sync.Once
   301  	index  *Index
   302  	ranges sql.Expression
   303  	lookup sql.IndexLookup
   304  }
   305  
   306  type indexScanPartition struct {
   307  	index  *Index
   308  	lookup sql.IndexLookup
   309  	ranges sql.Expression
   310  }
   311  
   312  func (i indexScanPartition) Key() []byte {
   313  	return []byte("indexScanPartition")
   314  }
   315  
   316  var _ sql.PartitionIter = (*indexScanPartitionIter)(nil)
   317  
   318  func (i *indexScanPartitionIter) Close(ctx *sql.Context) error {
   319  	return nil
   320  }
   321  
   322  func (i *indexScanPartitionIter) Next(ctx *sql.Context) (sql.Partition, error) {
   323  	part, err := indexScanPartition{}, io.EOF
   324  
   325  	i.once.Do(func() {
   326  		part, err = indexScanPartition{
   327  			index:  i.index,
   328  			lookup: i.lookup,
   329  			ranges: i.ranges,
   330  		}, nil
   331  	})
   332  
   333  	return part, err
   334  }
   335  
   336  type rangePartition struct {
   337  	*Partition
   338  	rang sql.Expression
   339  }
   340  
   341  // spatialRangePartitionIter returns a partition that has range and table data access
   342  type spatialRangePartitionIter struct {
   343  	child                  *partitionIter
   344  	ord                    int
   345  	minX, minY, maxX, maxY float64
   346  }
   347  
   348  var _ sql.PartitionIter = (*spatialRangePartitionIter)(nil)
   349  
   350  func (i spatialRangePartitionIter) Close(ctx *sql.Context) error {
   351  	return i.child.Close(ctx)
   352  }
   353  
   354  func (i spatialRangePartitionIter) Next(ctx *sql.Context) (sql.Partition, error) {
   355  	part, err := i.child.Next(ctx)
   356  	if err != nil {
   357  		return nil, err
   358  	}
   359  	return &spatialRangePartition{
   360  		Partition: part.(*Partition),
   361  		ord:       i.ord,
   362  		minX:      i.minX,
   363  		minY:      i.minY,
   364  		maxX:      i.maxX,
   365  		maxY:      i.maxY,
   366  	}, nil
   367  }
   368  
   369  type spatialRangePartition struct {
   370  	*Partition
   371  	ord                    int
   372  	minX, minY, maxX, maxY float64
   373  }
   374  
   375  // PartitionCount implements the sql.PartitionCounter interface.
   376  func (t *Table) PartitionCount(ctx *sql.Context) (int64, error) {
   377  	data := t.sessionTableData(ctx)
   378  
   379  	return int64(len(data.partitions)), nil
   380  }
   381  
   382  type indexScanRowIter struct {
   383  	i             int
   384  	incrementFunc func()
   385  	index         *Index
   386  	lookup        sql.IndexLookup
   387  	ranges        sql.Expression
   388  	primaryRows   map[string][]sql.Row
   389  	indexRows     []sql.Row
   390  
   391  	columns     []int
   392  	numColumns  int
   393  	virtualCols []int
   394  }
   395  
   396  func newIndexScanRowIter(
   397  	index *Index,
   398  	lookup sql.IndexLookup,
   399  	ranges sql.Expression,
   400  	primaryRows map[string][]sql.Row,
   401  	indexRows []sql.Row,
   402  	columns []int,
   403  	numColumns int,
   404  	virtualCols []int,
   405  ) *indexScanRowIter {
   406  
   407  	i := 0
   408  	if lookup.IsReverse {
   409  		i = len(indexRows) - 1
   410  	}
   411  
   412  	iter := &indexScanRowIter{
   413  		i:           i,
   414  		index:       index,
   415  		lookup:      lookup,
   416  		ranges:      ranges,
   417  		primaryRows: primaryRows,
   418  		indexRows:   indexRows,
   419  		columns:     columns,
   420  		numColumns:  numColumns,
   421  		virtualCols: virtualCols,
   422  	}
   423  
   424  	if lookup.IsReverse {
   425  		iter.incrementFunc = func() {
   426  			iter.i--
   427  		}
   428  	} else {
   429  		iter.incrementFunc = func() {
   430  			iter.i++
   431  		}
   432  	}
   433  
   434  	return iter
   435  }
   436  
   437  func (i *indexScanRowIter) Next(ctx *sql.Context) (sql.Row, error) {
   438  	if i.i >= len(i.indexRows) || i.i < 0 {
   439  		return nil, io.EOF
   440  	}
   441  
   442  	var row sql.Row
   443  	for ; i.i < len(i.indexRows) && i.i >= 0; i.incrementFunc() {
   444  		idxRow := i.indexRows[i.i]
   445  		rowLoc := idxRow[len(idxRow)-1].(primaryRowLocation)
   446  		// this is a bit of a hack: during self-referential foreign key delete cascades, the index storage rows don't get
   447  		// updated at the same time the primary table storage does, since we update the slices directly in the case of
   448  		// the primary index but update the map entries for the secondary index storage.
   449  		// TODO: revisit this once we have b-tree storage in place
   450  		if len(i.primaryRows[rowLoc.partition]) <= rowLoc.idx {
   451  			continue
   452  		}
   453  
   454  		candidate := i.primaryRows[rowLoc.partition][rowLoc.idx]
   455  
   456  		matches, err := indexRowMatches(i.ranges, idxRow[:len(idxRow)-1])
   457  		if err != nil {
   458  			return nil, err
   459  		}
   460  
   461  		if matches {
   462  			row = candidate
   463  			i.incrementFunc()
   464  			break
   465  		}
   466  	}
   467  
   468  	if row == nil {
   469  		return nil, io.EOF
   470  	}
   471  
   472  	row = normalizeRowForRead(row, i.numColumns, i.virtualCols)
   473  
   474  	return projectRow(i.columns, row), nil
   475  }
   476  
   477  func indexRowMatches(ranges sql.Expression, candidate sql.Row) (bool, error) {
   478  	result, err := ranges.Eval(nil, candidate)
   479  	if err != nil {
   480  		return false, err
   481  	}
   482  
   483  	return sql.IsTrue(result), nil
   484  }
   485  
   486  func (i *indexScanRowIter) Close(context *sql.Context) error {
   487  	return nil
   488  }
   489  
   490  // PartitionRows implements the sql.PartitionRows interface.
   491  func (t *Table) PartitionRows(ctx *sql.Context, partition sql.Partition) (sql.RowIter, error) {
   492  	data := t.sessionTableData(ctx)
   493  
   494  	if isp, ok := partition.(indexScanPartition); ok {
   495  		numColumns := len(data.schema.Schema)
   496  		if len(t.columns) > 0 {
   497  			numColumns = len(t.columns)
   498  		}
   499  
   500  		return newIndexScanRowIter(
   501  			isp.index,
   502  			isp.lookup,
   503  			isp.ranges,
   504  			data.partitions,
   505  			data.secondaryIndexStorage[indexName(isp.index.Name)],
   506  			t.columns,
   507  			numColumns,
   508  			data.virtualColIndexes(),
   509  		), nil
   510  	}
   511  
   512  	filters := t.filters
   513  	if r, ok := partition.(*rangePartition); ok && r.rang != nil {
   514  		// index lookup is currently a single filter applied to a full table scan
   515  		filters = append(t.filters, r.rang)
   516  	}
   517  
   518  	rows, ok := data.partitions[string(partition.Key())]
   519  	if !ok {
   520  		return nil, sql.ErrPartitionNotFound.New(partition.Key())
   521  	}
   522  	// The slice could be altered by other operations taking place during iteration (such as deletion or insertion), so
   523  	// make a copy of the values as they exist when execution begins.
   524  	rowsCopy := make([]sql.Row, len(rows))
   525  	copy(rowsCopy, rows)
   526  
   527  	numColumns := len(data.schema.Schema)
   528  	if len(t.columns) > 0 {
   529  		numColumns = len(t.columns)
   530  	}
   531  
   532  	if r, ok := partition.(*spatialRangePartition); ok {
   533  		// TODO: virtual column support
   534  		return &spatialTableIter{
   535  			columns: t.columns,
   536  			ord:     r.ord,
   537  			minX:    r.minX,
   538  			minY:    r.minY,
   539  			maxX:    r.maxX,
   540  			maxY:    r.maxY,
   541  			rows:    rowsCopy,
   542  		}, nil
   543  	}
   544  
   545  	return &tableIter{
   546  		rows:        rowsCopy,
   547  		columns:     t.columns,
   548  		numColumns:  numColumns,
   549  		virtualCols: data.virtualColIndexes(),
   550  		filters:     filters,
   551  	}, nil
   552  }
   553  
   554  func (t *Table) DataLength(ctx *sql.Context) (uint64, error) {
   555  	data := t.sessionTableData(ctx)
   556  
   557  	var numBytesPerRow uint64
   558  	for _, col := range data.schema.Schema {
   559  		switch n := col.Type.(type) {
   560  		case sql.NumberType:
   561  			numBytesPerRow += 8
   562  		case sql.StringType:
   563  			numBytesPerRow += uint64(n.MaxByteLength())
   564  		case types.BitType:
   565  			numBytesPerRow += 1
   566  		case sql.DatetimeType:
   567  			numBytesPerRow += 8
   568  		case sql.DecimalType:
   569  			numBytesPerRow += uint64(n.MaximumScale())
   570  		case sql.EnumType:
   571  			numBytesPerRow += 2
   572  		case types.JsonType:
   573  			numBytesPerRow += 20
   574  		case sql.NullType:
   575  			numBytesPerRow += 1
   576  		case types.TimeType:
   577  			numBytesPerRow += 16
   578  		case sql.YearType:
   579  			numBytesPerRow += 8
   580  		default:
   581  			numBytesPerRow += 0
   582  		}
   583  	}
   584  
   585  	numRows, err := data.numRows(ctx)
   586  	if err != nil {
   587  		return 0, err
   588  	}
   589  
   590  	return numBytesPerRow * numRows, nil
   591  }
   592  
   593  func (t *Table) RowCount(ctx *sql.Context) (uint64, bool, error) {
   594  	data := t.sessionTableData(ctx)
   595  	rows, err := data.numRows(ctx)
   596  	return rows, true, err
   597  }
   598  
   599  func NewPartition(key []byte) *Partition {
   600  	return &Partition{key: key}
   601  }
   602  
   603  type Partition struct {
   604  	key []byte
   605  }
   606  
   607  func (p *Partition) Key() []byte { return p.key }
   608  
   609  type partitionIter struct {
   610  	keys [][]byte
   611  	pos  int
   612  }
   613  
   614  func (p *partitionIter) Next(*sql.Context) (sql.Partition, error) {
   615  	if p.pos >= len(p.keys) {
   616  		return nil, io.EOF
   617  	}
   618  
   619  	key := p.keys[p.pos]
   620  	p.pos++
   621  	return &Partition{key}, nil
   622  }
   623  
   624  func (p *partitionIter) Close(*sql.Context) error { return nil }
   625  
   626  type tableIter struct {
   627  	columns     []int
   628  	virtualCols []int
   629  	numColumns  int
   630  
   631  	rows        []sql.Row
   632  	filters     []sql.Expression
   633  	indexValues sql.IndexValueIter
   634  	pos         int
   635  }
   636  
   637  var _ sql.RowIter = (*tableIter)(nil)
   638  
   639  func (i *tableIter) Next(ctx *sql.Context) (sql.Row, error) {
   640  	row, err := i.getRow(ctx)
   641  	if err != nil {
   642  		return nil, err
   643  	}
   644  
   645  	row = normalizeRowForRead(row, i.numColumns, i.virtualCols)
   646  
   647  	for _, f := range i.filters {
   648  		result, err := f.Eval(ctx, row)
   649  		if err != nil {
   650  			return nil, err
   651  		}
   652  		result, _ = sql.ConvertToBool(ctx, result)
   653  		if result != true {
   654  			return i.Next(ctx)
   655  		}
   656  	}
   657  
   658  	return projectRow(i.columns, row), nil
   659  }
   660  
   661  func projectRow(columns []int, row sql.Row) sql.Row {
   662  	if columns != nil {
   663  		resultRow := make(sql.Row, len(columns))
   664  		for i, j := range columns {
   665  			resultRow[i] = row[j]
   666  		}
   667  		return resultRow
   668  	}
   669  	return row
   670  }
   671  
   672  // normalizeRowForRead returns a copy of the row with nil values inserted for any virtual columns
   673  func normalizeRowForRead(row sql.Row, numColumns int, virtualCols []int) sql.Row {
   674  	if len(virtualCols) == 0 {
   675  		return row
   676  	}
   677  
   678  	var virtualRow sql.Row
   679  
   680  	// Columns are the indexes of projected columns, which we don't always have. In either case, we are filling the row
   681  	// with nil values for virtual columns. The simple iteration below only works when the column and virtual column
   682  	// indexes are in ascending order, which is true for the time being.
   683  	var j int
   684  	virtualRow = make(sql.Row, numColumns)
   685  	for i := 0; i < numColumns; i++ {
   686  		if j < len(virtualCols) && i == virtualCols[j] {
   687  			j++
   688  		} else {
   689  			virtualRow[i] = row[i-j]
   690  		}
   691  	}
   692  
   693  	return virtualRow
   694  }
   695  
   696  func (i *tableIter) Close(ctx *sql.Context) error {
   697  	if i.indexValues == nil {
   698  		return nil
   699  	}
   700  
   701  	return i.indexValues.Close(ctx)
   702  }
   703  
   704  func (i *tableIter) getRow(ctx *sql.Context) (sql.Row, error) {
   705  	if i.indexValues != nil {
   706  		return i.getFromIndex(ctx)
   707  	}
   708  
   709  	if i.pos >= len(i.rows) {
   710  		return nil, io.EOF
   711  	}
   712  
   713  	row := i.rows[i.pos]
   714  	i.pos++
   715  	return row, nil
   716  }
   717  
   718  func projectOnRow(columns []int, row sql.Row) sql.Row {
   719  	if len(columns) < 1 {
   720  		return row
   721  	}
   722  
   723  	projected := make([]interface{}, len(columns))
   724  	for i, selected := range columns {
   725  		projected[i] = row[selected]
   726  	}
   727  
   728  	return projected
   729  }
   730  
   731  func (i *tableIter) getFromIndex(ctx *sql.Context) (sql.Row, error) {
   732  	data, err := i.indexValues.Next(ctx)
   733  	if err != nil {
   734  		return nil, err
   735  	}
   736  
   737  	value, err := DecodeIndexValue(data)
   738  	if err != nil {
   739  		return nil, err
   740  	}
   741  
   742  	return i.rows[value.Pos], nil
   743  }
   744  
   745  type spatialTableIter struct {
   746  	columns                []int
   747  	rows                   []sql.Row
   748  	pos                    int
   749  	ord                    int
   750  	minX, minY, maxX, maxY float64
   751  }
   752  
   753  var _ sql.RowIter = (*spatialTableIter)(nil)
   754  
   755  func (i *spatialTableIter) Next(ctx *sql.Context) (sql.Row, error) {
   756  	row, err := i.getRow(ctx)
   757  	if err != nil {
   758  		return nil, err
   759  	}
   760  
   761  	if len(i.columns) == 0 {
   762  		return row, nil
   763  	}
   764  
   765  	// check if bounding boxes of geometry and range intersect
   766  	// if the range [i.minX, i.maxX] and [gMinX, gMaxX] overlap and
   767  	// if the range [i.minY, i.maxY] and [gMinY, gMaxY] overlap
   768  	// then, the bounding boxes intersect
   769  	g, ok := row[i.ord].(types.GeometryValue)
   770  	if !ok {
   771  		return nil, fmt.Errorf("spatial index over non-geometry column")
   772  	}
   773  	gMinX, gMinY, gMaxX, gMaxY := g.BBox()
   774  	xInt := (gMinX <= i.minX && i.minX <= gMaxX) ||
   775  		(gMinX <= i.maxX && i.maxX <= gMaxX) ||
   776  		(i.minX <= gMinX && gMinX <= i.maxX) ||
   777  		(i.minX <= gMaxX && gMaxX <= i.maxX)
   778  	yInt := (gMinY <= i.minY && i.minY <= gMaxY) ||
   779  		(gMinY <= i.maxY && i.maxY <= gMaxY) ||
   780  		(i.minY <= gMinY && gMinY <= i.maxY) ||
   781  		(i.minY <= gMaxY && gMaxY <= i.maxY)
   782  	if !(xInt && yInt) {
   783  		return i.Next(ctx)
   784  	}
   785  
   786  	resultRow := make(sql.Row, len(i.columns))
   787  	for i, j := range i.columns {
   788  		resultRow[i] = row[j]
   789  	}
   790  	return resultRow, nil
   791  }
   792  
   793  func (i *spatialTableIter) Close(ctx *sql.Context) error {
   794  	return nil
   795  }
   796  
   797  func (i *spatialTableIter) getRow(ctx *sql.Context) (sql.Row, error) {
   798  	if i.pos >= len(i.rows) {
   799  		return nil, io.EOF
   800  	}
   801  
   802  	row := i.rows[i.pos]
   803  	i.pos++
   804  	return row, nil
   805  }
   806  
   807  type IndexValue struct {
   808  	Key string
   809  	Pos int
   810  }
   811  
   812  func DecodeIndexValue(data []byte) (*IndexValue, error) {
   813  	dec := gob.NewDecoder(bytes.NewReader(data))
   814  	var value IndexValue
   815  	if err := dec.Decode(&value); err != nil {
   816  		return nil, err
   817  	}
   818  
   819  	return &value, nil
   820  }
   821  
   822  func EncodeIndexValue(value *IndexValue) ([]byte, error) {
   823  	var buf bytes.Buffer
   824  	enc := gob.NewEncoder(&buf)
   825  	if err := enc.Encode(value); err != nil {
   826  		return nil, err
   827  	}
   828  
   829  	return buf.Bytes(), nil
   830  }
   831  
   832  func (t *Table) Inserter(ctx *sql.Context) sql.RowInserter {
   833  	return t.getTableEditor(ctx)
   834  }
   835  
   836  func (t *Table) Updater(ctx *sql.Context) sql.RowUpdater {
   837  	return t.getTableEditor(ctx)
   838  }
   839  
   840  func (t *Table) Replacer(ctx *sql.Context) sql.RowReplacer {
   841  	return t.getTableEditor(ctx)
   842  }
   843  
   844  func (t *Table) Deleter(ctx *sql.Context) sql.RowDeleter {
   845  	return t.getTableEditor(ctx)
   846  }
   847  
   848  func (t *Table) AutoIncrementSetter(ctx *sql.Context) sql.AutoIncrementSetter {
   849  	return t.getTableEditor(ctx).(sql.AutoIncrementSetter)
   850  }
   851  
   852  func (t *Table) getTableEditor(ctx *sql.Context) sql.TableEditor {
   853  	editor, err := t.newTableEditor(ctx)
   854  	if err != nil {
   855  		panic(err)
   856  	}
   857  
   858  	tableSets, err := t.getFulltextTableSets(ctx)
   859  	if err != nil {
   860  		panic(err)
   861  	}
   862  
   863  	if len(tableSets) > 0 {
   864  		editor = t.newFulltextTableEditor(ctx, editor, tableSets)
   865  	}
   866  
   867  	return editor
   868  }
   869  
   870  func (t *Table) getRewriteTableEditor(ctx *sql.Context, oldSchema, newSchema sql.PrimaryKeySchema) sql.TableEditor {
   871  	editor, err := t.tableEditorForRewrite(ctx, oldSchema, newSchema)
   872  	if err != nil {
   873  		panic(err)
   874  	}
   875  
   876  	tableUnderEdit := editor.(*tableEditor).editedTable
   877  	err = tableUnderEdit.modifyFulltextIndexesForRewrite(ctx, tableUnderEdit.data, oldSchema)
   878  	if err != nil {
   879  		panic(err)
   880  	}
   881  
   882  	tableSets, err := fulltextTableSets(ctx, tableUnderEdit.data, tableUnderEdit.db)
   883  	if err != nil {
   884  		panic(err)
   885  	}
   886  
   887  	if len(tableSets) > 0 {
   888  		_, insertCols, err := fulltext.GetKeyColumns(ctx, tableUnderEdit)
   889  		if err != nil {
   890  			panic(err)
   891  		}
   892  
   893  		// table editors used for rewrite need to truncate the fulltext tables as well as the primary table (which happens
   894  		// in the RewriteInserter method for all tables)
   895  		newTableSets := make([]fulltext.TableSet, len(tableSets))
   896  		for i := range tableSets {
   897  			ts := *(&tableSets[i])
   898  
   899  			positionSch, err := fulltext.NewSchema(fulltext.SchemaPosition, insertCols, ts.Position.Name(), tableUnderEdit.Collation())
   900  			if err != nil {
   901  				panic(err)
   902  			}
   903  
   904  			docCountSch, err := fulltext.NewSchema(fulltext.SchemaDocCount, insertCols, ts.DocCount.Name(), tableUnderEdit.Collation())
   905  			if err != nil {
   906  				panic(err)
   907  			}
   908  
   909  			globalCountSch, err := fulltext.NewSchema(fulltext.SchemaGlobalCount, nil, ts.GlobalCount.Name(), tableUnderEdit.Collation())
   910  			if err != nil {
   911  				panic(err)
   912  			}
   913  
   914  			rowCountSch, err := fulltext.NewSchema(fulltext.SchemaRowCount, nil, ts.RowCount.Name(), tableUnderEdit.Collation())
   915  			if err != nil {
   916  				panic(err)
   917  			}
   918  
   919  			ts.RowCount.(*Table).data = ts.RowCount.(*Table).data.copy().truncate(sql.NewPrimaryKeySchema(rowCountSch))
   920  			ts.DocCount.(*Table).data = ts.DocCount.(*Table).data.copy().truncate(sql.NewPrimaryKeySchema(docCountSch))
   921  			ts.GlobalCount.(*Table).data = ts.GlobalCount.(*Table).data.copy().truncate(sql.NewPrimaryKeySchema(globalCountSch))
   922  			ts.Position.(*Table).data = ts.Position.(*Table).data.copy().truncate(sql.NewPrimaryKeySchema(positionSch))
   923  			newTableSets[i] = ts
   924  
   925  			// When we get a rowcount editor below, we are going to use the session data for each of these tables. Since we
   926  			// are rewriting them anyway, update their session data with the new empty data and new schema
   927  			sess := SessionFromContext(ctx)
   928  			sess.putTable(ts.RowCount.(*Table).data)
   929  			sess.putTable(ts.DocCount.(*Table).data)
   930  			sess.putTable(ts.GlobalCount.(*Table).data)
   931  			sess.putTable(ts.Position.(*Table).data)
   932  		}
   933  
   934  		editor = tableUnderEdit.newFulltextTableEditor(ctx, editor, newTableSets)
   935  	}
   936  
   937  	return editor
   938  }
   939  
   940  func (t *Table) newTableEditor(ctx *sql.Context) (sql.TableEditor, error) {
   941  	var ea tableEditAccumulator
   942  	var data *TableData
   943  	if t.ignoreSessionData {
   944  		ea = newTableEditAccumulator(t.data)
   945  		data = t.data
   946  	} else {
   947  		sess := SessionFromContext(ctx)
   948  		ea = sess.editAccumulator(t)
   949  		data = sess.tableData(t)
   950  	}
   951  
   952  	tableUnderEdit := t.copy()
   953  	tableUnderEdit.data = data
   954  
   955  	uniqIdxCols, prefixLengths := t.data.indexColsForTableEditor()
   956  	var editor sql.TableEditor = &tableEditor{
   957  		editedTable:   tableUnderEdit,
   958  		initialTable:  t.copy(),
   959  		ea:            ea,
   960  		uniqueIdxCols: uniqIdxCols,
   961  		prefixLengths: prefixLengths,
   962  	}
   963  	return editor, nil
   964  }
   965  
   966  func (t *Table) tableEditorForRewrite(ctx *sql.Context, oldSchema, newSchema sql.PrimaryKeySchema) (sql.TableEditor, error) {
   967  	// Make a copy of the table under edit with the new schema and no data
   968  	// sess := SessionFromContext(ctx)
   969  	tableUnderEdit := t.copy()
   970  	// tableUnderEdit.data = sess.tableData(t).copy()
   971  	tableData := tableUnderEdit.data.truncate(normalizeSchemaForRewrite(newSchema))
   972  	tableUnderEdit.data = tableData
   973  
   974  	uniqIdxCols, prefixLengths := tableData.indexColsForTableEditor()
   975  	var editor sql.TableEditor = &tableEditor{
   976  		editedTable:   tableUnderEdit,
   977  		initialTable:  t.copy(),
   978  		ea:            newTableEditAccumulator(tableData),
   979  		uniqueIdxCols: uniqIdxCols,
   980  		prefixLengths: prefixLengths,
   981  	}
   982  	return editor, nil
   983  }
   984  
   985  func (t *Table) newFulltextTableEditor(ctx *sql.Context, parentEditor sql.TableEditor, tableSets []fulltext.TableSet) sql.TableEditor {
   986  	configTbl, ok, err := t.db.GetTableInsensitive(ctx, t.data.fullTextConfigTableName)
   987  	if err != nil {
   988  		panic(err)
   989  	}
   990  	if !ok { // This should never happen
   991  		panic(fmt.Sprintf("table `%s` declares the table `%s` as a FULLTEXT config table, but it could not be found", t.name, configTbl))
   992  	}
   993  	ftEditor, err := fulltext.CreateEditor(ctx, t, configTbl.(fulltext.EditableTable), tableSets...)
   994  	if err != nil {
   995  		panic(err)
   996  	}
   997  	parentEditor, err = fulltext.CreateMultiTableEditor(ctx, parentEditor, ftEditor)
   998  	if err != nil {
   999  		panic(err)
  1000  	}
  1001  	return parentEditor
  1002  }
  1003  
  1004  func (t *Table) getFulltextTableSets(ctx *sql.Context) ([]fulltext.TableSet, error) {
  1005  	data := t.sessionTableData(ctx)
  1006  	db := t.db
  1007  
  1008  	return fulltextTableSets(ctx, data, db)
  1009  }
  1010  
  1011  func fulltextTableSets(ctx *sql.Context, data *TableData, db *BaseDatabase) ([]fulltext.TableSet, error) {
  1012  	var tableSets []fulltext.TableSet
  1013  	for _, idx := range data.indexes {
  1014  		if !idx.IsFullText() {
  1015  			continue
  1016  		}
  1017  		if db == nil { // Rewrite your test if you run into this
  1018  			panic("database is nil, which can only happen when adding a table outside of the SQL path, such as during harness creation")
  1019  		}
  1020  		ftIdx, ok := idx.(fulltext.Index)
  1021  		if !ok { // This should never happen
  1022  			panic("index returns true for FULLTEXT, but does not implement interface")
  1023  		}
  1024  		ftTableNames, err := ftIdx.FullTextTableNames(ctx)
  1025  		if err != nil { // This should never happen
  1026  			panic(err.Error())
  1027  		}
  1028  
  1029  		positionTbl, ok, err := db.GetTableInsensitive(ctx, ftTableNames.Position)
  1030  		if err != nil {
  1031  			panic(err)
  1032  		}
  1033  		if !ok { // This should never happen
  1034  			panic(fmt.Sprintf("index `%s` declares the table `%s` as a FULLTEXT position table, but it could not be found", idx.ID(), ftTableNames.Position))
  1035  		}
  1036  		docCountTbl, ok, err := db.GetTableInsensitive(ctx, ftTableNames.DocCount)
  1037  		if err != nil {
  1038  			panic(err)
  1039  		}
  1040  		if !ok { // This should never happen
  1041  			panic(fmt.Sprintf("index `%s` declares the table `%s` as a FULLTEXT doc count table, but it could not be found", idx.ID(), ftTableNames.DocCount))
  1042  		}
  1043  		globalCountTbl, ok, err := db.GetTableInsensitive(ctx, ftTableNames.GlobalCount)
  1044  		if err != nil {
  1045  			panic(err)
  1046  		}
  1047  		if !ok { // This should never happen
  1048  			panic(fmt.Sprintf("index `%s` declares the table `%s` as a FULLTEXT global count table, but it could not be found", idx.ID(), ftTableNames.GlobalCount))
  1049  		}
  1050  		rowCountTbl, ok, err := db.GetTableInsensitive(ctx, ftTableNames.RowCount)
  1051  		if err != nil {
  1052  			panic(err)
  1053  		}
  1054  		if !ok { // This should never happen
  1055  			panic(fmt.Sprintf("index `%s` declares the table `%s` as a FULLTEXT row count table, but it could not be found", idx.ID(), ftTableNames.RowCount))
  1056  		}
  1057  
  1058  		tableSets = append(tableSets, fulltext.TableSet{
  1059  			Index:       ftIdx,
  1060  			Position:    positionTbl.(fulltext.EditableTable),
  1061  			DocCount:    docCountTbl.(fulltext.EditableTable),
  1062  			GlobalCount: globalCountTbl.(fulltext.EditableTable),
  1063  			RowCount:    rowCountTbl.(fulltext.EditableTable),
  1064  		})
  1065  	}
  1066  
  1067  	return tableSets, nil
  1068  }
  1069  
  1070  func (t *Table) Truncate(ctx *sql.Context) (int, error) {
  1071  	data := t.sessionTableData(ctx)
  1072  
  1073  	count := 0
  1074  	for key := range data.partitions {
  1075  		count += len(data.partitions[key])
  1076  	}
  1077  
  1078  	data.truncate(data.schema)
  1079  	return count, nil
  1080  }
  1081  
  1082  // Insert is a convenience method to avoid having to create an inserter in test setup
  1083  func (t *Table) Insert(ctx *sql.Context, row sql.Row) error {
  1084  	inserter := t.Inserter(ctx)
  1085  	if err := inserter.Insert(ctx, row); err != nil {
  1086  		return err
  1087  	}
  1088  	return inserter.Close(ctx)
  1089  }
  1090  
  1091  // PeekNextAutoIncrementValue peeks at the next AUTO_INCREMENT value
  1092  func (t *Table) PeekNextAutoIncrementValue(ctx *sql.Context) (uint64, error) {
  1093  	data := t.sessionTableData(ctx)
  1094  
  1095  	return data.autoIncVal, nil
  1096  }
  1097  
  1098  // GetNextAutoIncrementValue gets the next auto increment value for the memory table the increment.
  1099  func (t *Table) GetNextAutoIncrementValue(ctx *sql.Context, insertVal interface{}) (uint64, error) {
  1100  	data := t.sessionTableData(ctx)
  1101  
  1102  	cmp, err := types.Uint64.Compare(insertVal, data.autoIncVal)
  1103  	if err != nil {
  1104  		return 0, err
  1105  	}
  1106  
  1107  	if cmp > 0 && insertVal != nil {
  1108  		v, _, err := types.Uint64.Convert(insertVal)
  1109  		if err != nil {
  1110  			return 0, err
  1111  		}
  1112  		data.autoIncVal = v.(uint64)
  1113  	}
  1114  
  1115  	return data.autoIncVal, nil
  1116  }
  1117  
  1118  func (t *Table) AddColumn(ctx *sql.Context, column *sql.Column, order *sql.ColumnOrder) error {
  1119  	sess := SessionFromContext(ctx)
  1120  	data := sess.tableData(t)
  1121  
  1122  	newColIdx, data, err := addColumnToSchema(ctx, data, column, order)
  1123  	if err != nil {
  1124  		return err
  1125  	}
  1126  
  1127  	err = insertValueInRows(ctx, data, newColIdx, column.Default)
  1128  	if err != nil {
  1129  		return err
  1130  	}
  1131  
  1132  	sess.putTable(data)
  1133  	return nil
  1134  }
  1135  
  1136  // addColumnToSchema adds the given column to the schema and returns the new index
  1137  func addColumnToSchema(ctx *sql.Context, data *TableData, newCol *sql.Column, order *sql.ColumnOrder) (int, *TableData, error) {
  1138  	// TODO: might have wrong case
  1139  	newCol.Source = data.tableName
  1140  	newSch := make(sql.Schema, len(data.schema.Schema)+1)
  1141  
  1142  	// TODO: need to fix this in the engine itself
  1143  	if newCol.PrimaryKey {
  1144  		newCol.Nullable = false
  1145  	}
  1146  
  1147  	newColIdx := 0
  1148  	var i int
  1149  	if order != nil && order.First {
  1150  		newSch[i] = newCol
  1151  		i++
  1152  	}
  1153  
  1154  	numPrecedingVirtuals := 0
  1155  	for _, col := range data.schema.Schema {
  1156  		newSch[i] = col
  1157  		if col.Virtual {
  1158  			numPrecedingVirtuals++
  1159  		}
  1160  		i++
  1161  		if (order != nil && order.AfterColumn == col.Name) || (order == nil && i == len(data.schema.Schema)) {
  1162  			newSch[i] = newCol
  1163  			newColIdx = i - numPrecedingVirtuals
  1164  			i++
  1165  		}
  1166  	}
  1167  
  1168  	for _, newSchCol := range newSch {
  1169  		newDefault, _, _ := transform.Expr(newSchCol.Default, func(expr sql.Expression) (sql.Expression, transform.TreeIdentity, error) {
  1170  			if expr, ok := expr.(*expression.GetField); ok {
  1171  				return expr.WithIndex(newSch.IndexOf(expr.Name(), data.tableName)), transform.NewTree, nil
  1172  			}
  1173  			return expr, transform.SameTree, nil
  1174  		})
  1175  		newSchCol.Default = newDefault.(*sql.ColumnDefaultValue)
  1176  	}
  1177  
  1178  	if newCol.AutoIncrement {
  1179  		data.autoColIdx = newColIdx
  1180  		data.autoIncVal = 0
  1181  
  1182  		if newColIdx < len(data.schema.Schema) {
  1183  			for _, p := range data.partitions {
  1184  				for _, row := range p {
  1185  					if row[newColIdx] == nil {
  1186  						continue
  1187  					}
  1188  
  1189  					cmp, err := newCol.Type.Compare(row[newColIdx], data.autoIncVal)
  1190  					if err != nil {
  1191  						panic(err)
  1192  					}
  1193  
  1194  					if cmp > 0 {
  1195  						var val interface{}
  1196  						val, _, err = types.Uint64.Convert(row[newColIdx])
  1197  						if err != nil {
  1198  							panic(err)
  1199  						}
  1200  						data.autoIncVal = val.(uint64)
  1201  					}
  1202  				}
  1203  			}
  1204  		} else {
  1205  			data.autoIncVal = 0
  1206  		}
  1207  
  1208  		data.autoIncVal++
  1209  	}
  1210  
  1211  	newPkOrds := data.schema.PkOrdinals
  1212  	for i := 0; i < len(newPkOrds); i++ {
  1213  		// added column shifts the index of every column after
  1214  		// all ordinals above addIdx will be bumped
  1215  		if newColIdx <= newPkOrds[i] {
  1216  			newPkOrds[i]++
  1217  		}
  1218  	}
  1219  
  1220  	data.schema = sql.NewPrimaryKeySchema(newSch, newPkOrds...)
  1221  
  1222  	return newColIdx, data, nil
  1223  }
  1224  
  1225  func validateMaxRowLength(sch sql.Schema) error {
  1226  	if rowLen := maxRowStorageSize(sch); rowLen > types.MaxRowLength {
  1227  		return analyzererrors.ErrInvalidRowLength.New(types.MaxRowLength, rowLen)
  1228  	}
  1229  	return nil
  1230  }
  1231  
  1232  // maxRowStorageSize simulates InnoDB's storage limitations,
  1233  // which are different than Dolt's.
  1234  func maxRowStorageSize(schema sql.Schema) int64 {
  1235  	var numBytesPerRow int64 = 0
  1236  	for _, col := range schema {
  1237  		switch n := col.Type.(type) {
  1238  		case sql.NumberType:
  1239  			numBytesPerRow += 8
  1240  		case sql.StringType:
  1241  			if types.IsTextBlob(n) {
  1242  				numBytesPerRow += 16
  1243  			} else {
  1244  				numBytesPerRow += n.MaxByteLength()
  1245  			}
  1246  		case types.BitType:
  1247  			numBytesPerRow += 8
  1248  		case sql.DatetimeType:
  1249  			numBytesPerRow += 8
  1250  		case sql.DecimalType:
  1251  			numBytesPerRow += int64(n.MaximumScale())
  1252  		case sql.EnumType:
  1253  			numBytesPerRow += 2
  1254  		case types.JsonType:
  1255  			numBytesPerRow += 20
  1256  		case sql.NullType:
  1257  			numBytesPerRow += 1
  1258  		case types.TimeType:
  1259  			numBytesPerRow += 16
  1260  		case sql.YearType:
  1261  			numBytesPerRow += 8
  1262  		default:
  1263  			panic(fmt.Sprintf("unknown type in create table: %s", n.String()))
  1264  		}
  1265  	}
  1266  	return numBytesPerRow
  1267  }
  1268  
  1269  func (t *Table) DropColumn(ctx *sql.Context, columnName string) error {
  1270  	sess := SessionFromContext(ctx)
  1271  	data := sess.tableData(t)
  1272  
  1273  	droppedCol, data := dropColumnFromSchema(ctx, data, columnName)
  1274  	for k, p := range data.partitions {
  1275  		newP := make([]sql.Row, len(p))
  1276  		for i, row := range p {
  1277  			var newRow sql.Row
  1278  			newRow = append(newRow, row[:droppedCol]...)
  1279  			newRow = append(newRow, row[droppedCol+1:]...)
  1280  			newP[i] = newRow
  1281  		}
  1282  		data.partitions[k] = newP
  1283  	}
  1284  
  1285  	sess.putTable(data)
  1286  
  1287  	return nil
  1288  }
  1289  
  1290  // dropColumnFromSchema drops the given column name from the schema and returns its old index.
  1291  func dropColumnFromSchema(ctx *sql.Context, data *TableData, columnName string) (int, *TableData) {
  1292  	newSch := make(sql.Schema, len(data.schema.Schema)-1)
  1293  	var i int
  1294  	droppedCol := -1
  1295  	for _, col := range data.schema.Schema {
  1296  		if col.Name != columnName {
  1297  			newSch[i] = col
  1298  			i++
  1299  		} else {
  1300  			droppedCol = i
  1301  		}
  1302  	}
  1303  
  1304  	newPkOrds := data.schema.PkOrdinals
  1305  	for i := 0; i < len(newPkOrds); i++ {
  1306  		// deleting a column will shift subsequent column indices left
  1307  		// PK ordinals after dropIdx bumped down
  1308  		if droppedCol <= newPkOrds[i] {
  1309  			newPkOrds[i]--
  1310  		}
  1311  	}
  1312  
  1313  	data.schema = sql.NewPrimaryKeySchema(newSch, newPkOrds...)
  1314  	return droppedCol, data
  1315  }
  1316  
  1317  func (t *Table) ModifyColumn(ctx *sql.Context, columnName string, column *sql.Column, order *sql.ColumnOrder) error {
  1318  	sess := SessionFromContext(ctx)
  1319  	data := sess.tableData(t)
  1320  
  1321  	oldIdx := -1
  1322  	newIdx := 0
  1323  	for i, col := range data.schema.Schema {
  1324  		if col.Name == columnName {
  1325  			oldIdx = i
  1326  			column.PrimaryKey = col.PrimaryKey
  1327  			if column.PrimaryKey {
  1328  				column.Nullable = false
  1329  			}
  1330  			// We've removed auto increment through this modification so we need to do some bookkeeping
  1331  			if col.AutoIncrement && !column.AutoIncrement {
  1332  				data.autoColIdx = -1
  1333  				data.autoIncVal = 0
  1334  			}
  1335  			break
  1336  		}
  1337  	}
  1338  
  1339  	if order == nil {
  1340  		newIdx = oldIdx
  1341  		if newIdx == 0 {
  1342  			order = &sql.ColumnOrder{First: true}
  1343  		} else {
  1344  			order = &sql.ColumnOrder{AfterColumn: data.schema.Schema[newIdx-1].Name}
  1345  		}
  1346  	} else if !order.First {
  1347  		var oldSchemaWithoutCol sql.Schema
  1348  		oldSchemaWithoutCol = append(oldSchemaWithoutCol, data.schema.Schema[:oldIdx]...)
  1349  		oldSchemaWithoutCol = append(oldSchemaWithoutCol, data.schema.Schema[oldIdx+1:]...)
  1350  		for i, col := range oldSchemaWithoutCol {
  1351  			if col.Name == order.AfterColumn {
  1352  				newIdx = i + 1
  1353  				break
  1354  			}
  1355  		}
  1356  	}
  1357  
  1358  	for k, p := range data.partitions {
  1359  		newP := make([]sql.Row, len(p))
  1360  		for i, row := range p {
  1361  			var oldRowWithoutVal sql.Row
  1362  			oldRowWithoutVal = append(oldRowWithoutVal, row[:oldIdx]...)
  1363  			oldRowWithoutVal = append(oldRowWithoutVal, row[oldIdx+1:]...)
  1364  			newVal, inRange, err := column.Type.Convert(row[oldIdx])
  1365  			if err != nil {
  1366  				if sql.ErrNotMatchingSRID.Is(err) {
  1367  					err = sql.ErrNotMatchingSRIDWithColName.New(columnName, err)
  1368  				}
  1369  				return err
  1370  			}
  1371  			if !inRange {
  1372  				return sql.ErrValueOutOfRange.New(row[oldIdx], column.Type)
  1373  			}
  1374  			var newRow sql.Row
  1375  			newRow = append(newRow, oldRowWithoutVal[:newIdx]...)
  1376  			newRow = append(newRow, newVal)
  1377  			newRow = append(newRow, oldRowWithoutVal[newIdx:]...)
  1378  			newP[i] = newRow
  1379  		}
  1380  		data.partitions[k] = newP
  1381  	}
  1382  
  1383  	pkNameToOrdIdx := make(map[string]int)
  1384  	for i, ord := range data.schema.PkOrdinals {
  1385  		pkNameToOrdIdx[data.schema.Schema[ord].Name] = i
  1386  	}
  1387  
  1388  	oldSch := data.schema
  1389  	_, _ = dropColumnFromSchema(ctx, data, columnName)
  1390  	_, _, err := addColumnToSchema(ctx, data, column, order)
  1391  	if err != nil {
  1392  		data.schema = oldSch
  1393  		return err
  1394  	}
  1395  
  1396  	newPkOrds := make([]int, len(data.schema.PkOrdinals))
  1397  	for ord, col := range data.schema.Schema {
  1398  		if col.PrimaryKey {
  1399  			i := pkNameToOrdIdx[col.Name]
  1400  			newPkOrds[i] = ord
  1401  		}
  1402  	}
  1403  
  1404  	data.schema.PkOrdinals = newPkOrds
  1405  
  1406  	for _, index := range data.indexes {
  1407  		memIndex := index.(*Index)
  1408  		nameLowercase := strings.ToLower(columnName)
  1409  		for i, expr := range memIndex.Exprs {
  1410  			getField := expr.(*expression.GetField)
  1411  			if strings.ToLower(getField.Name()) == nameLowercase {
  1412  				memIndex.Exprs[i] = expression.NewGetFieldWithTable(newIdx, int(getField.TableId()), column.Type, getField.Database(), getField.Table(), column.Name, column.Nullable)
  1413  			}
  1414  		}
  1415  	}
  1416  
  1417  	sess.putTable(data)
  1418  
  1419  	return nil
  1420  }
  1421  
  1422  // PrimaryKeySchema implements sql.PrimaryKeyAlterableTable
  1423  func (t *Table) PrimaryKeySchema() sql.PrimaryKeySchema {
  1424  	return t.data.schema
  1425  }
  1426  
  1427  // String implements the sql.Table interface.
  1428  func (t *Table) String() string {
  1429  	return t.name
  1430  }
  1431  
  1432  var debugDataPrint = false
  1433  
  1434  func (t *Table) DebugString() string {
  1435  	if debugDataPrint {
  1436  		p := t.data.partitions["0"]
  1437  		s := ""
  1438  		for i, row := range p {
  1439  			if i > 0 {
  1440  				s += ", "
  1441  			}
  1442  			s += fmt.Sprintf("%v", row)
  1443  		}
  1444  		return s
  1445  	}
  1446  
  1447  	p := sql.NewTreePrinter()
  1448  
  1449  	children := []string{fmt.Sprintf("name: %s", t.name)}
  1450  
  1451  	if len(t.columns) > 0 {
  1452  		var projections []string
  1453  		for _, column := range t.columns {
  1454  			projections = append(projections, fmt.Sprintf("%d", column))
  1455  		}
  1456  		children = append(children, fmt.Sprintf("projections: %s", projections))
  1457  
  1458  	}
  1459  
  1460  	if len(t.filters) > 0 {
  1461  		var filters []string
  1462  		for _, filter := range t.filters {
  1463  			filters = append(filters, fmt.Sprintf("%s", sql.DebugString(filter)))
  1464  		}
  1465  		children = append(children, fmt.Sprintf("filters: %s", filters))
  1466  	}
  1467  	_ = p.WriteNode("Table")
  1468  	p.WriteChildren(children...)
  1469  	return p.String()
  1470  }
  1471  
  1472  // HandledFilters implements the sql.FilteredTable interface.
  1473  func (t *Table) HandledFilters(filters []sql.Expression) []sql.Expression {
  1474  	var handled []sql.Expression
  1475  	for _, f := range filters {
  1476  		var hasOtherFields bool
  1477  		sql.Inspect(f, func(e sql.Expression) bool {
  1478  			if e, ok := e.(*expression.GetField); ok {
  1479  				if e.Table() != t.name || !t.data.schema.Contains(e.Name(), t.name) {
  1480  					hasOtherFields = true
  1481  					return false
  1482  				}
  1483  			}
  1484  			return true
  1485  		})
  1486  
  1487  		if !hasOtherFields {
  1488  			handled = append(handled, f)
  1489  		}
  1490  	}
  1491  
  1492  	return handled
  1493  }
  1494  
  1495  // FilteredTable functionality in the Table type was disabled for a long period of time, and has developed major
  1496  // issues with the current analyzer logic. It's only used in the pushdown unit tests, and sql.FilteredTable should be
  1497  // considered unstable until this situation is fixed.
  1498  type FilteredTable struct {
  1499  	*Table
  1500  }
  1501  
  1502  var _ sql.FilteredTable = (*FilteredTable)(nil)
  1503  
  1504  func NewFilteredTable(db MemoryDatabase, name string, schema sql.PrimaryKeySchema, fkColl *ForeignKeyCollection) *FilteredTable {
  1505  	return &FilteredTable{
  1506  		Table: NewTable(db, name, schema, fkColl),
  1507  	}
  1508  }
  1509  
  1510  // WithFilters implements the sql.FilteredTable interface.
  1511  func (t *FilteredTable) WithFilters(ctx *sql.Context, filters []sql.Expression) sql.Table {
  1512  	if len(filters) == 0 {
  1513  		return t
  1514  	}
  1515  
  1516  	nt := *t
  1517  	nt.filters = filters
  1518  	return &nt
  1519  }
  1520  
  1521  // WithProjections implements sql.ProjectedTable
  1522  func (t *FilteredTable) WithProjections(schema []string) sql.Table {
  1523  	table := t.Table.WithProjections(schema)
  1524  
  1525  	nt := *t
  1526  	nt.Table = table.(*Table)
  1527  	return &nt
  1528  }
  1529  
  1530  // Projections implements sql.ProjectedTable
  1531  func (t *FilteredTable) Projections() []string {
  1532  	return t.projection
  1533  }
  1534  
// IndexedTable is a table that expects to return one or more partitions
// for range lookups. It pairs a *Table with the sql.IndexLookup it was
// created for (see Table.IndexedAccess). PartitionRows uses Lookup to order
// rows by the lookup index's expressions.
type IndexedTable struct {
	*Table
	// Lookup is the index lookup this table view was created for.
	Lookup sql.IndexLookup
}
  1541  
// LookupPartitions returns the partition iterator to use for lookup. The strategy depends on the index:
//   - spatial index: the table's regular partitions are wrapped in a spatialRangePartitionIter, bounded by the
//     points taken from the first range's lower and upper cut keys;
//   - the PRIMARY index: the table's regular partitions are wrapped in a rangePartitionIter with the range filter;
//   - any other index: an indexScanPartitionIter walks the index, with the range filter's field indexes remapped
//     to the index's own column layout (see adjustRangeScanFilterForIndexLookup).
//
// NOTE(review): the spatial branch reads lookup.Ranges[0][0] without a length check. It assumes spatial lookups
// always carry at least one range.
func (t *IndexedTable) LookupPartitions(ctx *sql.Context, lookup sql.IndexLookup) (sql.PartitionIter, error) {
	memIdx := lookup.Index.(*Index)
	filter, err := memIdx.rangeFilterExpr(ctx, lookup.Ranges...)
	if err != nil {
		return nil, err
	}

	if lookup.Index.IsSpatial() {
		child, err := t.Table.Partitions(ctx)
		if err != nil {
			return nil, err
		}

		// The bounding box corners must be points; anything else is rejected as invalid GIS data.
		lower := sql.GetRangeCutKey(lookup.Ranges[0][0].LowerBound)
		upper := sql.GetRangeCutKey(lookup.Ranges[0][0].UpperBound)
		minPoint, ok := lower.(types.Point)
		if !ok {
			return nil, sql.ErrInvalidGISData.New()
		}
		maxPoint, ok := upper.(types.Point)
		if !ok {
			return nil, sql.ErrInvalidGISData.New()
		}

		// Row position of the indexed column; only the index's first expression is used.
		ord := memIdx.Exprs[0].(*expression.GetField).Index()
		return spatialRangePartitionIter{
			child: child.(*partitionIter),
			ord:   ord,
			minX:  minPoint.X,
			minY:  minPoint.Y,
			maxX:  maxPoint.X,
			maxY:  maxPoint.Y,
		}, nil
	}

	if lookup.Index.ID() == "PRIMARY" {
		child, err := t.Table.Partitions(ctx)
		if err != nil {
			return nil, err
		}

		return rangePartitionIter{
			child:  child.(*partitionIter),
			ranges: filter,
		}, nil
	}

	// Secondary index: the filter must address the index's column layout, not the table's.
	indexFilter := adjustRangeScanFilterForIndexLookup(filter, memIdx)

	return &indexScanPartitionIter{
		index:  memIdx,
		lookup: lookup,
		ranges: indexFilter,
	}, nil
}
  1597  
  1598  func adjustRangeScanFilterForIndexLookup(filter sql.Expression, index *Index) sql.Expression {
  1599  	exprs := index.ExtendedExprs()
  1600  
  1601  	indexStorageSchema := make(sql.Schema, len(exprs))
  1602  	for i, e := range exprs {
  1603  		indexStorageSchema[i] = &sql.Column{
  1604  			Name: e.(*expression.GetField).Name(),
  1605  		}
  1606  	}
  1607  
  1608  	filter, _, err := transform.Expr(filter, func(e sql.Expression) (sql.Expression, transform.TreeIdentity, error) {
  1609  		if gf, ok := e.(*expression.GetField); ok {
  1610  			idxIdx := indexStorageSchema.IndexOfColName(gf.Name())
  1611  			return gf.WithIndex(idxIdx), transform.NewTree, nil
  1612  		}
  1613  		return e, transform.SameTree, nil
  1614  	})
  1615  
  1616  	if err != nil {
  1617  		panic(err)
  1618  	}
  1619  
  1620  	return filter
  1621  }
  1622  
  1623  // PartitionRows implements the sql.PartitionRows interface.
  1624  func (t *IndexedTable) PartitionRows(ctx *sql.Context, partition sql.Partition) (sql.RowIter, error) {
  1625  	iter, err := t.Table.PartitionRows(ctx, partition)
  1626  	if err != nil {
  1627  		return nil, err
  1628  	}
  1629  
  1630  	// Sorting code below is only for spatial indexes, which use a different partition iterator
  1631  	if _, ok := partition.(indexScanPartition); ok {
  1632  		return iter, nil
  1633  	}
  1634  
  1635  	if t.Lookup.Index != nil {
  1636  		idx := t.Lookup.Index.(*Index)
  1637  		sf := make(sql.SortFields, len(idx.Exprs))
  1638  		for i, e := range idx.Exprs {
  1639  			sf[i] = sql.SortField{Column: e}
  1640  			if t.Lookup.IsReverse {
  1641  				sf[i].Order = sql.Descending
  1642  				// TODO: null ordering?
  1643  			}
  1644  		}
  1645  		var sorter *expression.Sorter
  1646  		if i, ok := iter.(*tableIter); ok {
  1647  			sorter = &expression.Sorter{
  1648  				SortFields: sf,
  1649  				Rows:       i.rows,
  1650  				LastError:  nil,
  1651  				Ctx:        ctx,
  1652  			}
  1653  		} else if i, ok := iter.(*spatialTableIter); ok {
  1654  			sorter = &expression.Sorter{
  1655  				SortFields: sf,
  1656  				Rows:       i.rows,
  1657  				LastError:  nil,
  1658  				Ctx:        ctx,
  1659  			}
  1660  		}
  1661  
  1662  		sort.Stable(sorter)
  1663  	}
  1664  
  1665  	return iter, nil
  1666  }
  1667  
  1668  func (t *Table) IndexedAccess(lookup sql.IndexLookup) sql.IndexedTable {
  1669  	return &IndexedTable{Table: t, Lookup: lookup}
  1670  }
  1671  
  1672  func (t *Table) PreciseMatch() bool {
  1673  	return true
  1674  }
  1675  
  1676  // WithProjections implements sql.ProjectedTable
  1677  func (t *Table) WithProjections(cols []string) sql.Table {
  1678  	nt := *t
  1679  	if cols == nil {
  1680  		nt.projectedSchema = nil
  1681  		nt.projection = nil
  1682  		nt.columns = nil
  1683  		return &nt
  1684  	}
  1685  	columns, err := nt.data.columnIndexes(cols)
  1686  	if err != nil {
  1687  		panic(err)
  1688  	}
  1689  
  1690  	nt.columns = columns
  1691  
  1692  	projectedSchema := make(sql.Schema, len(columns))
  1693  	for i, j := range columns {
  1694  		projectedSchema[i] = nt.data.schema.Schema[j]
  1695  	}
  1696  	nt.projectedSchema = projectedSchema
  1697  	nt.projection = cols
  1698  
  1699  	return &nt
  1700  }
  1701  
  1702  // Projections implements sql.ProjectedTable
  1703  func (t *Table) Projections() []string {
  1704  	return t.projection
  1705  }
  1706  
  1707  // EnablePrimaryKeyIndexes enables the use of primary key indexes on this table.
  1708  func (t *Table) EnablePrimaryKeyIndexes() {
  1709  	t.pkIndexesEnabled = true
  1710  	t.data.primaryKeyIndexes = true
  1711  }
  1712  
  1713  func (t *Table) dbName() string {
  1714  	if t.db != nil {
  1715  		return t.db.Name()
  1716  	}
  1717  	return ""
  1718  }
  1719  
  1720  // GetIndexes implements sql.IndexedTable
  1721  func (t *Table) GetIndexes(ctx *sql.Context) ([]sql.Index, error) {
  1722  	data := t.sessionTableData(ctx)
  1723  
  1724  	indexes := make([]sql.Index, 0)
  1725  
  1726  	if data.primaryKeyIndexes {
  1727  		if len(data.schema.PkOrdinals) > 0 {
  1728  			exprs := make([]sql.Expression, len(data.schema.PkOrdinals))
  1729  			for i, ord := range data.schema.PkOrdinals {
  1730  				column := data.schema.Schema[ord]
  1731  				idx, field := data.getColumnOrdinal(column.Name)
  1732  				exprs[i] = expression.NewGetFieldWithTable(idx, 0, field.Type, t.dbName(), t.name, field.Name, field.Nullable)
  1733  			}
  1734  			indexes = append(indexes, &Index{
  1735  				DB:         t.dbName(),
  1736  				DriverName: "",
  1737  				Tbl:        t,
  1738  				TableName:  t.name,
  1739  				Exprs:      exprs,
  1740  				Name:       "PRIMARY",
  1741  				Unique:     true,
  1742  			})
  1743  		}
  1744  	}
  1745  
  1746  	nonPrimaryIndexes := make([]sql.Index, len(data.indexes))
  1747  	var i int
  1748  	for _, index := range data.indexes {
  1749  		nonPrimaryIndexes[i] = index
  1750  		i++
  1751  	}
  1752  	sort.Slice(nonPrimaryIndexes, func(i, j int) bool {
  1753  		return nonPrimaryIndexes[i].ID() < nonPrimaryIndexes[j].ID()
  1754  	})
  1755  
  1756  	return append(indexes, nonPrimaryIndexes...), nil
  1757  }
  1758  
  1759  // GetDeclaredForeignKeys implements the interface sql.ForeignKeyTable.
  1760  func (t *Table) GetDeclaredForeignKeys(ctx *sql.Context) ([]sql.ForeignKeyConstraint, error) {
  1761  	data := t.sessionTableData(ctx)
  1762  
  1763  	//TODO: may not be the best location, need to handle db as well
  1764  	var fks []sql.ForeignKeyConstraint
  1765  	lowerName := strings.ToLower(t.name)
  1766  	for _, fk := range data.fkColl.Keys() {
  1767  		if strings.ToLower(fk.Table) == lowerName {
  1768  			fks = append(fks, fk)
  1769  		}
  1770  	}
  1771  	return fks, nil
  1772  }
  1773  
  1774  // GetReferencedForeignKeys implements the interface sql.ForeignKeyTable.
  1775  func (t *Table) GetReferencedForeignKeys(ctx *sql.Context) ([]sql.ForeignKeyConstraint, error) {
  1776  	data := t.sessionTableData(ctx)
  1777  
  1778  	// TODO: may not be the best location, need to handle db as well
  1779  	var fks []sql.ForeignKeyConstraint
  1780  	lowerName := strings.ToLower(t.name)
  1781  	for _, fk := range data.fkColl.Keys() {
  1782  		if strings.ToLower(fk.ParentTable) == lowerName {
  1783  			fks = append(fks, fk)
  1784  		}
  1785  	}
  1786  	return fks, nil
  1787  }
  1788  
  1789  // AddForeignKey implements sql.ForeignKeyTable. Foreign partitionKeys are not enforced on update / delete.
  1790  func (t *Table) AddForeignKey(ctx *sql.Context, fk sql.ForeignKeyConstraint) error {
  1791  	sess := SessionFromContext(ctx)
  1792  	data := sess.tableData(t)
  1793  
  1794  	lowerName := strings.ToLower(fk.Name)
  1795  	for _, key := range data.fkColl.Keys() {
  1796  		if strings.ToLower(key.Name) == lowerName {
  1797  			return fmt.Errorf("Constraint %s already exists", fk.Name)
  1798  		}
  1799  	}
  1800  	data.fkColl.AddFK(fk)
  1801  
  1802  	return nil
  1803  }
  1804  
  1805  // DropForeignKey implements sql.ForeignKeyTable.
  1806  func (t *Table) DropForeignKey(ctx *sql.Context, fkName string) error {
  1807  	sess := SessionFromContext(ctx)
  1808  	data := sess.tableData(t)
  1809  
  1810  	if data.fkColl.DropFK(fkName) {
  1811  		return nil
  1812  	}
  1813  
  1814  	return sql.ErrForeignKeyNotFound.New(fkName, t.name)
  1815  }
  1816  
  1817  // UpdateForeignKey implements sql.ForeignKeyTable.
  1818  func (t *Table) UpdateForeignKey(ctx *sql.Context, fkName string, fk sql.ForeignKeyConstraint) error {
  1819  	sess := SessionFromContext(ctx)
  1820  	data := sess.tableData(t)
  1821  
  1822  	data.fkColl.DropFK(fkName)
  1823  	lowerName := strings.ToLower(fk.Name)
  1824  	for _, key := range data.fkColl.Keys() {
  1825  		if strings.ToLower(key.Name) == lowerName {
  1826  			return fmt.Errorf("Constraint %s already exists", fk.Name)
  1827  		}
  1828  	}
  1829  	data.fkColl.AddFK(fk)
  1830  
  1831  	return nil
  1832  }
  1833  
  1834  // CreateIndexForForeignKey implements sql.ForeignKeyTable.
  1835  func (t *Table) CreateIndexForForeignKey(ctx *sql.Context, idx sql.IndexDef) error {
  1836  	return t.CreateIndex(ctx, idx)
  1837  }
  1838  
  1839  // SetForeignKeyResolved implements sql.ForeignKeyTable.
  1840  func (t *Table) SetForeignKeyResolved(ctx *sql.Context, fkName string) error {
  1841  	data := t.sessionTableData(ctx)
  1842  
  1843  	if !data.fkColl.SetResolved(fkName) {
  1844  		return sql.ErrForeignKeyNotFound.New(fkName, t.name)
  1845  	}
  1846  	return nil
  1847  }
  1848  
  1849  // GetForeignKeyEditor implements sql.ForeignKeyTable.
  1850  func (t *Table) GetForeignKeyEditor(ctx *sql.Context) sql.ForeignKeyEditor {
  1851  	return t.getTableEditor(ctx).(sql.ForeignKeyEditor)
  1852  }
  1853  
  1854  // GetChecks implements sql.CheckTable
  1855  func (t *Table) GetChecks(ctx *sql.Context) ([]sql.CheckDefinition, error) {
  1856  	data := t.sessionTableData(ctx)
  1857  	return data.checks, nil
  1858  }
  1859  
  1860  func (t *Table) sessionTableData(ctx *sql.Context) *TableData {
  1861  	if t.ignoreSessionData {
  1862  		return t.data
  1863  	}
  1864  	sess := SessionFromContext(ctx)
  1865  	return sess.tableData(t)
  1866  }
  1867  
  1868  // CreateCheck implements sql.CheckAlterableTable
  1869  func (t *Table) CreateCheck(ctx *sql.Context, check *sql.CheckDefinition) error {
  1870  	data := t.sessionTableData(ctx)
  1871  
  1872  	toInsert := *check
  1873  	if toInsert.Name == "" {
  1874  		toInsert.Name = data.generateCheckName()
  1875  	}
  1876  
  1877  	for _, key := range data.checks {
  1878  		if key.Name == toInsert.Name {
  1879  			return fmt.Errorf("constraint %s already exists", toInsert.Name)
  1880  		}
  1881  	}
  1882  
  1883  	data.checks = append(data.checks, toInsert)
  1884  	return nil
  1885  }
  1886  
  1887  // DropCheck implements sql.CheckAlterableTable.
  1888  func (t *Table) DropCheck(ctx *sql.Context, chName string) error {
  1889  	data := t.sessionTableData(ctx)
  1890  
  1891  	lowerName := strings.ToLower(chName)
  1892  	for i, key := range data.checks {
  1893  		if strings.ToLower(key.Name) == lowerName {
  1894  			data.checks = append(data.checks[:i], data.checks[i+1:]...)
  1895  			return nil
  1896  		}
  1897  	}
  1898  	//TODO: add SQL error
  1899  	return fmt.Errorf("check '%s' was not found on the table", chName)
  1900  }
  1901  
  1902  func (t *Table) createIndex(data *TableData, name string, columns []sql.IndexColumn, constraint sql.IndexConstraint, comment string) (sql.Index, error) {
  1903  	if name == "" {
  1904  		for _, column := range columns {
  1905  			name += column.Name + "_"
  1906  		}
  1907  	}
  1908  	if data.indexes[name] != nil {
  1909  		// TODO: extract a standard error type for this
  1910  		return nil, fmt.Errorf("Error: index already exists")
  1911  	}
  1912  
  1913  	exprs := make([]sql.Expression, len(columns))
  1914  	colNames := make([]string, len(columns))
  1915  	for i, column := range columns {
  1916  		idx, field := data.getColumnOrdinal(column.Name)
  1917  		exprs[i] = expression.NewGetFieldWithTable(idx, 0, field.Type, t.dbName(), t.name, field.Name, field.Nullable)
  1918  		colNames[i] = column.Name
  1919  	}
  1920  
  1921  	var hasNonZeroLengthColumn bool
  1922  	for _, column := range columns {
  1923  		if column.Length > 0 {
  1924  			hasNonZeroLengthColumn = true
  1925  			break
  1926  		}
  1927  	}
  1928  	var prefixLengths []uint16
  1929  	if hasNonZeroLengthColumn {
  1930  		prefixLengths = make([]uint16, len(columns))
  1931  		for i, column := range columns {
  1932  			prefixLengths[i] = uint16(column.Length)
  1933  		}
  1934  	}
  1935  
  1936  	if constraint == sql.IndexConstraint_Unique {
  1937  		err := data.errIfDuplicateEntryExist(colNames, name)
  1938  		if err != nil {
  1939  			return nil, err
  1940  		}
  1941  	}
  1942  
  1943  	return &Index{
  1944  		DB:         t.dbName(),
  1945  		DriverName: "",
  1946  		Tbl:        t,
  1947  		TableName:  t.name,
  1948  		Exprs:      exprs,
  1949  		Name:       name,
  1950  		Unique:     constraint == sql.IndexConstraint_Unique,
  1951  		Spatial:    constraint == sql.IndexConstraint_Spatial,
  1952  		Fulltext:   constraint == sql.IndexConstraint_Fulltext,
  1953  		CommentStr: comment,
  1954  		PrefixLens: prefixLengths,
  1955  	}, nil
  1956  }
  1957  
  1958  // CreateIndex implements sql.IndexAlterableTable
  1959  func (t *Table) CreateIndex(ctx *sql.Context, idx sql.IndexDef) error {
  1960  	sess := SessionFromContext(ctx)
  1961  	data := sess.tableData(t)
  1962  
  1963  	if data.indexes == nil {
  1964  		data.indexes = make(map[string]sql.Index)
  1965  	}
  1966  
  1967  	index, err := t.createIndex(data, idx.Name, idx.Columns, idx.Constraint, idx.Comment)
  1968  	if err != nil {
  1969  		return err
  1970  	}
  1971  
  1972  	// Store the computed index name in the case of an empty index name being passed in
  1973  	data.indexes[index.ID()] = index
  1974  	sess.putTable(data)
  1975  
  1976  	return nil
  1977  }
  1978  
  1979  // DropIndex implements sql.IndexAlterableTable
  1980  func (t *Table) DropIndex(ctx *sql.Context, name string) error {
  1981  	if strings.ToLower(name) == "primary" {
  1982  		return t.DropPrimaryKey(ctx)
  1983  	}
  1984  
  1985  	data := t.sessionTableData(ctx)
  1986  
  1987  	for idxName := range data.indexes {
  1988  		if strings.ToLower(idxName) == strings.ToLower(name) {
  1989  			delete(data.indexes, idxName)
  1990  			delete(data.secondaryIndexStorage, indexName(idxName))
  1991  			return nil
  1992  		}
  1993  	}
  1994  
  1995  	return sql.ErrIndexNotFound.New(name)
  1996  }
  1997  
  1998  // RenameIndex implements sql.IndexAlterableTable
  1999  func (t *Table) RenameIndex(ctx *sql.Context, fromIndexName string, toIndexName string) error {
  2000  	data := t.sessionTableData(ctx)
  2001  
  2002  	if fromIndexName == toIndexName {
  2003  		return nil
  2004  	}
  2005  	if idx, ok := data.indexes[fromIndexName]; ok {
  2006  		delete(data.indexes, fromIndexName)
  2007  		data.indexes[toIndexName] = idx
  2008  		idx.(*Index).Name = toIndexName
  2009  	}
  2010  	return nil
  2011  }
  2012  
  2013  // CreateFulltextIndex implements fulltext.IndexAlterableTable
  2014  func (t *Table) CreateFulltextIndex(ctx *sql.Context, indexDef sql.IndexDef, keyCols fulltext.KeyColumns, tableNames fulltext.IndexTableNames) error {
  2015  	sess := SessionFromContext(ctx)
  2016  	data := sess.tableData(t)
  2017  
  2018  	if len(data.fullTextConfigTableName) > 0 {
  2019  		if data.fullTextConfigTableName != tableNames.Config {
  2020  			return fmt.Errorf("Full-Text config table name has been changed from `%s` to `%s`", data.fullTextConfigTableName, tableNames.Config)
  2021  		}
  2022  	} else {
  2023  		data.fullTextConfigTableName = tableNames.Config
  2024  	}
  2025  
  2026  	if data.indexes == nil {
  2027  		data.indexes = make(map[string]sql.Index)
  2028  	}
  2029  
  2030  	index, err := t.createIndex(data, indexDef.Name, indexDef.Columns, indexDef.Constraint, indexDef.Comment)
  2031  	if err != nil {
  2032  		return err
  2033  	}
  2034  	index.(*Index).fulltextInfo = fulltextInfo{
  2035  		PositionTableName:    tableNames.Position,
  2036  		DocCountTableName:    tableNames.DocCount,
  2037  		GlobalCountTableName: tableNames.GlobalCount,
  2038  		RowCountTableName:    tableNames.RowCount,
  2039  		KeyColumns:           keyCols,
  2040  	}
  2041  
  2042  	data.indexes[index.ID()] = index // We should store the computed index name in the case of an empty index name being passed in
  2043  	sess.putTable(data)
  2044  
  2045  	return nil
  2046  }
  2047  
  2048  // ModifyStoredCollation implements sql.CollationAlterableTable
  2049  func (t *Table) ModifyStoredCollation(ctx *sql.Context, collation sql.CollationID) error {
  2050  	return fmt.Errorf("converting the collations of columns is not yet supported")
  2051  }
  2052  
  2053  // ModifyDefaultCollation implements sql.CollationAlterableTable
  2054  func (t *Table) ModifyDefaultCollation(ctx *sql.Context, collation sql.CollationID) error {
  2055  	data := t.sessionTableData(ctx)
  2056  
  2057  	data.collation = collation
  2058  	return nil
  2059  }
  2060  
  2061  // Filters implements the sql.FilteredTable interface.
  2062  func (t *Table) Filters() []sql.Expression {
  2063  	return t.filters
  2064  }
  2065  
  2066  // CreatePrimaryKey implements the PrimaryKeyAlterableTable
  2067  func (t *Table) CreatePrimaryKey(ctx *sql.Context, columns []sql.IndexColumn) error {
  2068  	data := t.sessionTableData(ctx)
  2069  
  2070  	// TODO: create alternate table implementation that doesn't implement rewriter to test this
  2071  	// First check that a primary key already exists
  2072  	for _, col := range data.schema.Schema {
  2073  		if col.PrimaryKey {
  2074  			return sql.ErrMultiplePrimaryKeysDefined.New()
  2075  		}
  2076  	}
  2077  
  2078  	potentialSchema := data.schema.Schema.Copy()
  2079  
  2080  	pkOrdinals := make([]int, len(columns))
  2081  	for i, newCol := range columns {
  2082  		found := false
  2083  		for j, currCol := range potentialSchema {
  2084  			if strings.ToLower(currCol.Name) == strings.ToLower(newCol.Name) {
  2085  				if types.IsText(currCol.Type) && newCol.Length > 0 {
  2086  					return sql.ErrUnsupportedIndexPrefix.New(currCol.Name)
  2087  				}
  2088  				currCol.PrimaryKey = true
  2089  				currCol.Nullable = false
  2090  				found = true
  2091  				pkOrdinals[i] = j
  2092  				break
  2093  			}
  2094  		}
  2095  
  2096  		if !found {
  2097  			return sql.ErrKeyColumnDoesNotExist.New(newCol.Name)
  2098  		}
  2099  	}
  2100  
  2101  	return nil
  2102  }
  2103  
// pkfield pairs a primary key column with the position of its value in a row;
// partitionssort.pkLess compares rows one pkfield at a time.
type pkfield struct {
	// i is the index of the column's value within a sql.Row.
	i int
	// c is the column definition; c.Type performs the comparison.
	c *sql.Column
}
  2108  
// partitionRow identifies one row position: the name of the partition holding it
// and its index within that partition's row slice.
type partitionRow struct {
	partitionName string
	rowIdx        int
}
  2113  
// partitionssort provides Len/Less/Swap for sorting rows spread across several
// partitions by primary key. Swap exchanges row data in place and rewrites the
// row locations stored in index rows so they follow the rows they refer to.
type partitionssort struct {
	// pk lists the primary key columns, compared in order.
	pk      []pkfield
	// ps maps a partition name to that partition's rows; rows are swapped in place.
	ps      map[string][]sql.Row
	// allRows lists every row position; Len/Less/Swap index into it, and Swap
	// exchanges the data at two positions rather than reordering allRows itself.
	allRows []partitionRow
	// indexes holds index rows whose last element is a primaryRowLocation.
	indexes map[indexName][]sql.Row
}
  2120  
  2121  func (ps partitionssort) Len() int {
  2122  	return len(ps.allRows)
  2123  }
  2124  
  2125  func (ps partitionssort) Less(i, j int) bool {
  2126  	lidx := ps.allRows[i]
  2127  	ridx := ps.allRows[j]
  2128  	lr := ps.ps[lidx.partitionName][lidx.rowIdx]
  2129  	rr := ps.ps[ridx.partitionName][ridx.rowIdx]
  2130  	return ps.pkLess(lr, rr)
  2131  }
  2132  
  2133  func (ps partitionssort) pkLess(l, r sql.Row) bool {
  2134  	for _, f := range ps.pk {
  2135  		r, err := f.c.Type.Compare(l[f.i], r[f.i])
  2136  		if err != nil {
  2137  			panic(err)
  2138  		}
  2139  		if r != 0 {
  2140  			return r < 0
  2141  		}
  2142  	}
  2143  	return false
  2144  }
  2145  
  2146  func (ps partitionssort) Swap(i, j int) {
  2147  	lidx := ps.allRows[i]
  2148  	ridx := ps.allRows[j]
  2149  	ps.ps[lidx.partitionName][lidx.rowIdx], ps.ps[ridx.partitionName][ridx.rowIdx] = ps.ps[ridx.partitionName][ridx.rowIdx], ps.ps[lidx.partitionName][lidx.rowIdx]
  2150  
  2151  	// Now update the index storage locations for the swap we just performed as well. This is frankly awful performance
  2152  	// that turns the sort operation into worse than cubic. Doing better requires doing something more intelligent than
  2153  	// sorted slices for rows and indexes, some sort of sorted collection.
  2154  	for _, indexRows := range ps.indexes {
  2155  		for _, idxRow := range indexRows {
  2156  			rowLoc := idxRow[len(idxRow)-1].(primaryRowLocation)
  2157  			if rowLoc.partition == lidx.partitionName && rowLoc.idx == lidx.rowIdx {
  2158  				idxRow[len(idxRow)-1] = primaryRowLocation{
  2159  					partition: ridx.partitionName,
  2160  					idx:       ridx.rowIdx,
  2161  				}
  2162  			} else if rowLoc.partition == ridx.partitionName && rowLoc.idx == ridx.rowIdx {
  2163  				idxRow[len(idxRow)-1] = primaryRowLocation{
  2164  					partition: lidx.partitionName,
  2165  					idx:       lidx.rowIdx,
  2166  				}
  2167  			}
  2168  		}
  2169  	}
  2170  }
  2171  
  2172  func (t Table) copy() *Table {
  2173  	t.data = t.data.copy()
  2174  
  2175  	if t.projection != nil {
  2176  		projection := make([]string, len(t.projection))
  2177  		copy(projection, t.projection)
  2178  		t.projection = projection
  2179  	}
  2180  
  2181  	if t.columns != nil {
  2182  		columns := make([]int, len(t.columns))
  2183  		copy(columns, t.columns)
  2184  		t.columns = columns
  2185  	}
  2186  
  2187  	return &t
  2188  }
  2189  
  2190  // replaceData replaces the data in this table with the one in the source
  2191  func (t *Table) replaceData(src *TableData) {
  2192  	t.data = src.copy()
  2193  }
  2194  
  2195  // normalizeSchemaForRewrite returns a copy of the schema provided suitable for rewriting. This is necessary because
  2196  // the engine doesn't currently enforce that primary key columns are not nullable, rather taking the definition
  2197  // directly from the user.
  2198  func normalizeSchemaForRewrite(newSch sql.PrimaryKeySchema) sql.PrimaryKeySchema {
  2199  	schema := newSch.Schema.Copy()
  2200  	for _, col := range schema {
  2201  		if col.PrimaryKey {
  2202  			col.Nullable = false
  2203  		}
  2204  	}
  2205  
  2206  	return sql.NewPrimaryKeySchema(schema, newSch.PkOrdinals...)
  2207  }
  2208  
  2209  // DropPrimaryKey implements the PrimaryKeyAlterableTable
  2210  // TODO: get rid of this / make it error?
  2211  func (t *Table) DropPrimaryKey(ctx *sql.Context) error {
  2212  	data := t.sessionTableData(ctx)
  2213  
  2214  	err := sql.ValidatePrimaryKeyDrop(ctx, t, t.PrimaryKeySchema())
  2215  	if err != nil {
  2216  		return err
  2217  	}
  2218  
  2219  	pks := make([]*sql.Column, 0)
  2220  	for _, col := range data.schema.Schema {
  2221  		if col.PrimaryKey {
  2222  			pks = append(pks, col)
  2223  		}
  2224  	}
  2225  
  2226  	if len(pks) == 0 {
  2227  		return sql.ErrCantDropFieldOrKey.New("PRIMARY")
  2228  	}
  2229  
  2230  	// Check for foreign key relationships
  2231  	for _, pk := range pks {
  2232  		if fkName, ok := columnInFkRelationship(pk.Name, data.fkColl.Keys()); ok {
  2233  			return sql.ErrCantDropIndex.New("PRIMARY", fkName)
  2234  		}
  2235  	}
  2236  
  2237  	for _, c := range pks {
  2238  		c.PrimaryKey = false
  2239  	}
  2240  
  2241  	delete(data.indexes, "PRIMARY")
  2242  	data.schema.PkOrdinals = []int{}
  2243  
  2244  	return nil
  2245  }
  2246  
  2247  func columnInFkRelationship(col string, fkc []sql.ForeignKeyConstraint) (string, bool) {
  2248  	colsInFks := make(map[string]string)
  2249  	for _, fk := range fkc {
  2250  		allCols := append(fk.Columns, fk.ParentColumns...)
  2251  		for _, ac := range allCols {
  2252  			colsInFks[ac] = fk.Name
  2253  		}
  2254  	}
  2255  
  2256  	fkName, ok := colsInFks[col]
  2257  	return fkName, ok
  2258  }
  2259  
  2260  var errColumnNotFound = errors.NewKind("could not find column %s")
  2261  
  2262  func (t *Table) ShouldRewriteTable(ctx *sql.Context, oldSchema, newSchema sql.PrimaryKeySchema, oldColumn, newColumn *sql.Column) bool {
  2263  	return orderChanged(oldSchema, newSchema, oldColumn, newColumn) ||
  2264  		isColumnDrop(oldSchema, newSchema) ||
  2265  		isPrimaryKeyChange(oldSchema, newSchema)
  2266  }
  2267  
  2268  func orderChanged(oldSchema, newSchema sql.PrimaryKeySchema, oldColumn, newColumn *sql.Column) bool {
  2269  	if oldColumn == nil || newColumn == nil {
  2270  		return false
  2271  	}
  2272  
  2273  	return oldSchema.Schema.IndexOfColName(oldColumn.Name) != newSchema.Schema.IndexOfColName(newColumn.Name)
  2274  }
  2275  
  2276  func isPrimaryKeyChange(oldSchema sql.PrimaryKeySchema,
  2277  	newSchema sql.PrimaryKeySchema) bool {
  2278  	return len(newSchema.PkOrdinals) != len(oldSchema.PkOrdinals)
  2279  }
  2280  
  2281  func isColumnDrop(oldSchema sql.PrimaryKeySchema, newSchema sql.PrimaryKeySchema) bool {
  2282  	return len(oldSchema.Schema) > len(newSchema.Schema)
  2283  }
  2284  
  2285  func (t *Table) RewriteInserter(ctx *sql.Context, oldSchema, newSchema sql.PrimaryKeySchema, _, _ *sql.Column, idxCols []sql.IndexColumn) (sql.RowInserter, error) {
  2286  	// TODO: this is insufficient: we need prevent dropping any index that is used by a primary key (or the engine does)
  2287  	if isPrimaryKeyDrop(oldSchema, newSchema) {
  2288  		err := sql.ValidatePrimaryKeyDrop(ctx, t, oldSchema)
  2289  		if err != nil {
  2290  			return nil, err
  2291  		}
  2292  	}
  2293  
  2294  	if isPrimaryKeyChange(oldSchema, newSchema) {
  2295  		err := validatePrimaryKeyChange(ctx, oldSchema, newSchema, idxCols)
  2296  		if err != nil {
  2297  			return nil, err
  2298  		}
  2299  	}
  2300  
  2301  	return t.getRewriteTableEditor(ctx, oldSchema, newSchema), nil
  2302  }
  2303  
  2304  func validatePrimaryKeyChange(ctx *sql.Context, oldSchema sql.PrimaryKeySchema, newSchema sql.PrimaryKeySchema, idxCols []sql.IndexColumn) error {
  2305  	for _, idxCol := range idxCols {
  2306  		idx := newSchema.Schema.IndexOfColName(idxCol.Name)
  2307  		if idx < 0 {
  2308  			return sql.ErrColumnNotFound.New(idxCol.Name)
  2309  		}
  2310  		col := newSchema.Schema[idx]
  2311  		if col.PrimaryKey && idxCol.Length > 0 && types.IsText(col.Type) {
  2312  			return sql.ErrUnsupportedIndexPrefix.New(col.Name)
  2313  		}
  2314  	}
  2315  
  2316  	return nil
  2317  }
  2318  
  2319  func isPrimaryKeyDrop(oldSchema sql.PrimaryKeySchema, newSchema sql.PrimaryKeySchema) bool {
  2320  	return len(oldSchema.PkOrdinals) > 0 && len(newSchema.PkOrdinals) == 0
  2321  }
  2322  
  2323  // modifyFulltextIndexesForRewrite will modify the fulltext indexes of a table to correspond to a new schema before a rewrite.
  2324  func (t *Table) modifyFulltextIndexesForRewrite(
  2325  	ctx *sql.Context,
  2326  	data *TableData,
  2327  	oldSchema sql.PrimaryKeySchema,
  2328  ) error {
  2329  	keyCols, _, err := fulltext.GetKeyColumns(ctx, data.Table(t.db))
  2330  	if err != nil {
  2331  		return err
  2332  	}
  2333  
  2334  	newIndexes := make(map[string]sql.Index)
  2335  	for name, idx := range data.indexes {
  2336  		if !idx.IsFullText() {
  2337  			newIndexes[name] = idx
  2338  			continue
  2339  		}
  2340  
  2341  		if t.db == nil { // Rewrite your test if you run into this
  2342  			return fmt.Errorf("database is nil, which can only happen when adding a table outside of the SQL path, such as during harness creation")
  2343  		}
  2344  
  2345  		memIdx, ok := idx.(*Index)
  2346  		if !ok { // This should never happen
  2347  			return fmt.Errorf("index returns true for FULLTEXT, but does not implement interface")
  2348  		}
  2349  
  2350  		newExprs := removeDroppedColumns(data.schema, memIdx)
  2351  		if len(newExprs) == 0 {
  2352  			// omit this index, no columns in it left in new schema
  2353  			continue
  2354  		}
  2355  
  2356  		newIdx := memIdx.copy()
  2357  		newIdx.fulltextInfo.KeyColumns = keyCols
  2358  		newIdx.Exprs = newExprs
  2359  
  2360  		newIndexes[name] = newIdx
  2361  	}
  2362  
  2363  	data.indexes = newIndexes
  2364  
  2365  	return nil
  2366  }
  2367  
  2368  func removeDroppedColumns(schema sql.PrimaryKeySchema, idx *Index) []sql.Expression {
  2369  	var newExprs []sql.Expression
  2370  	for _, expr := range idx.Exprs {
  2371  		if gf, ok := expr.(*expression.GetField); ok {
  2372  			idx := schema.Schema.IndexOfColName(gf.Name())
  2373  			if idx < 0 {
  2374  				continue
  2375  			}
  2376  		}
  2377  		newExprs = append(newExprs, expr)
  2378  	}
  2379  	return newExprs
  2380  }
  2381  
  2382  func hasNullForAnyCols(row sql.Row, cols []int) bool {
  2383  	for _, idx := range cols {
  2384  		if row[idx] == nil {
  2385  			return true
  2386  		}
  2387  	}
  2388  	return false
  2389  }
  2390  
  2391  func (t *Table) ShouldBuildIndex(ctx *sql.Context, indexDef sql.IndexDef) (bool, error) {
  2392  	// We always want help building new indexes
  2393  	return true, nil
  2394  }
  2395  
  2396  func (t *Table) BuildIndex(ctx *sql.Context, indexDef sql.IndexDef) (sql.RowInserter, error) {
  2397  	data := t.sessionTableData(ctx)
  2398  	_, ok := data.indexes[indexDef.Name]
  2399  	if !ok {
  2400  		return nil, sql.ErrIndexNotFound.New(indexDef.Name)
  2401  	}
  2402  
  2403  	// For now we're just rewriting the entire table, but we could also just rewrite the index with a little work
  2404  	return t.getRewriteTableEditor(ctx, data.schema, data.schema), nil
  2405  }
  2406  
// TableRevision is a container for memory tables to run basic smoke tests for versioned queries. It embeds *Table and
// overrides only the methods those tests require. Memory tables have a flag to force them to ignore session data and
// use their embedded data instead, which the versioned table tests need in order to pass.
type TableRevision struct {
	// Table is the wrapped table; methods not overridden here are promoted from it.
	*Table
}

// Compile-time assertion that *TableRevision implements MemTable.
var _ MemTable = (*TableRevision)(nil)
  2415  
  2416  func (t *TableRevision) Inserter(ctx *sql.Context) sql.RowInserter {
  2417  	ea := newTableEditAccumulator(t.Table.data)
  2418  
  2419  	uniqIdxCols, prefixLengths := t.data.indexColsForTableEditor()
  2420  	return &tableEditor{
  2421  		editedTable:   t.Table,
  2422  		initialTable:  t.copy(),
  2423  		ea:            ea,
  2424  		uniqueIdxCols: uniqIdxCols,
  2425  		prefixLengths: prefixLengths,
  2426  	}
  2427  }
  2428  
  2429  func (t *TableRevision) AddColumn(ctx *sql.Context, column *sql.Column, order *sql.ColumnOrder) error {
  2430  	newColIdx, data, err := addColumnToSchema(ctx, t.data, column, order)
  2431  	if err != nil {
  2432  		return err
  2433  	}
  2434  
  2435  	err = insertValueInRows(ctx, data, newColIdx, column.Default)
  2436  	if err != nil {
  2437  		return err
  2438  	}
  2439  
  2440  	t.data = data
  2441  	return nil
  2442  }
  2443  
  2444  func (t *TableRevision) IgnoreSessionData() bool {
  2445  	return true
  2446  }