github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/libraries/doltcore/sqle/temp_table.go (about)

     1  // Copyright 2019 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package sqle
    16  
    17  import (
    18  	"context"
    19  	"fmt"
    20  	"math/rand"
    21  	"strconv"
    22  
    23  	"github.com/dolthub/go-mysql-server/sql"
    24  
    25  	"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
    26  	"github.com/dolthub/dolt/go/libraries/doltcore/doltdb/durable"
    27  	"github.com/dolthub/dolt/go/libraries/doltcore/schema"
    28  	"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
    29  	"github.com/dolthub/dolt/go/libraries/doltcore/sqle/index"
    30  	"github.com/dolthub/dolt/go/libraries/doltcore/sqle/sqlutil"
    31  	"github.com/dolthub/dolt/go/libraries/doltcore/sqle/writer"
    32  	"github.com/dolthub/dolt/go/libraries/doltcore/table/editor"
    33  	"github.com/dolthub/dolt/go/libraries/doltcore/table/editor/creation"
    34  	"github.com/dolthub/dolt/go/store/types"
    35  )
    36  
    37  type TempTable struct {
    38  	tableName string
    39  	dbName    string
    40  	pkSch     sql.PrimaryKeySchema
    41  
    42  	table *doltdb.Table
    43  	sch   schema.Schema
    44  
    45  	lookup sql.IndexLookup
    46  
    47  	ed   writer.TableWriter
    48  	opts editor.Options
    49  }
    50  
    51  var _ sql.TemporaryTable = &TempTable{}
    52  var _ sql.Table = &TempTable{}
    53  var _ sql.PrimaryKeyTable = &TempTable{}
    54  var _ sql.IndexedTable = &TempTable{}
    55  var _ sql.IndexAlterableTable = &TempTable{}
    56  var _ sql.ForeignKeyTable = &TempTable{}
    57  var _ sql.CheckTable = &TempTable{}
    58  var _ sql.CheckAlterableTable = &TempTable{}
    59  var _ sql.StatisticsTable = &TempTable{}
    60  
    61  func NewTempTable(
    62  	ctx *sql.Context,
    63  	ddb *doltdb.DoltDB,
    64  	pkSch sql.PrimaryKeySchema,
    65  	name, db string,
    66  	opts editor.Options,
    67  	collation sql.CollationID,
    68  ) (*TempTable, error) {
    69  	sess := dsess.DSessFromSess(ctx.Session)
    70  
    71  	dbState, ok, err := sess.LookupDbState(ctx, db)
    72  	if err != nil {
    73  		return nil, err
    74  	}
    75  
    76  	if !ok {
    77  		return nil, fmt.Errorf("database %s not found in session", db)
    78  	}
    79  
    80  	ws := dbState.WorkingSet()
    81  	if ws == nil {
    82  		return nil, doltdb.ErrOperationNotSupportedInDetachedHead
    83  	}
    84  
    85  	sch, err := temporaryDoltSchema(ctx, pkSch, collation)
    86  	if err != nil {
    87  		return nil, err
    88  	}
    89  	vrw := ddb.ValueReadWriter()
    90  	ns := ddb.NodeStore()
    91  
    92  	idx, err := durable.NewEmptyIndex(ctx, vrw, ns, sch)
    93  	if err != nil {
    94  		return nil, err
    95  	}
    96  	set, err := durable.NewIndexSet(ctx, vrw, ns)
    97  	if err != nil {
    98  		return nil, err
    99  	}
   100  
   101  	tbl, err := doltdb.NewTable(ctx, vrw, ns, sch, idx, set, nil)
   102  	if err != nil {
   103  		return nil, err
   104  	}
   105  
   106  	newRoot, err := ws.WorkingRoot().PutTable(ctx, doltdb.TableName{Name: name}, tbl)
   107  	if err != nil {
   108  		return nil, err
   109  	}
   110  
   111  	newWs := ws.WithWorkingRoot(newRoot)
   112  
   113  	ait, err := dsess.NewAutoIncrementTracker(ctx, db, newWs)
   114  	if err != nil {
   115  		return nil, err
   116  	}
   117  
   118  	writeSession := writer.NewWriteSession(tbl.Format(), newWs, ait, opts)
   119  
   120  	tempTable := &TempTable{
   121  		tableName: name,
   122  		dbName:    db,
   123  		pkSch:     pkSch,
   124  		table:     tbl,
   125  		sch:       sch,
   126  		opts:      opts,
   127  	}
   128  
   129  	tempTable.ed, err = writeSession.GetTableWriter(ctx, doltdb.TableName{Name: name}, db, setTempTableRoot(tempTable))
   130  	if err != nil {
   131  		return nil, err
   132  	}
   133  
   134  	return tempTable, nil
   135  }
   136  
   137  func setTempTableRoot(t *TempTable) func(ctx *sql.Context, dbName string, newRoot doltdb.RootValue) error {
   138  	return func(ctx *sql.Context, dbName string, newRoot doltdb.RootValue) error {
   139  		newTable, _, err := newRoot.GetTable(ctx, doltdb.TableName{Name: t.tableName})
   140  		if err != nil {
   141  			return err
   142  		}
   143  
   144  		t.table = newTable
   145  
   146  		sess := dsess.DSessFromSess(ctx.Session)
   147  
   148  		dbState, ok, err := sess.LookupDbState(ctx, t.dbName)
   149  		if err != nil {
   150  			return err
   151  		}
   152  
   153  		if !ok {
   154  			return fmt.Errorf("database %s not found in session", t.dbName)
   155  		}
   156  
   157  		ws := dbState.WorkingSet()
   158  		if ws == nil {
   159  			return doltdb.ErrOperationNotSupportedInDetachedHead
   160  		}
   161  		newWs := ws.WithWorkingRoot(newRoot)
   162  
   163  		ait, err := dsess.NewAutoIncrementTracker(ctx, t.dbName, newWs)
   164  		if err != nil {
   165  			return err
   166  		}
   167  
   168  		writeSession := writer.NewWriteSession(newTable.Format(), newWs, ait, t.opts)
   169  		t.ed, err = writeSession.GetTableWriter(ctx, doltdb.TableName{Name: t.tableName}, t.dbName, setTempTableRoot(t))
   170  		if err != nil {
   171  			return err
   172  		}
   173  
   174  		return nil
   175  	}
   176  }
   177  
   178  func (t *TempTable) RowCount(ctx *sql.Context) (uint64, bool, error) {
   179  	rows, err := t.table.GetRowData(ctx)
   180  	if err != nil {
   181  		return 0, false, err
   182  	}
   183  	cnt, err := rows.Count()
   184  	return cnt, true, err
   185  }
   186  
   187  func (t *TempTable) GetIndexes(ctx *sql.Context) ([]sql.Index, error) {
   188  	return index.DoltIndexesFromTable(ctx, t.dbName, t.tableName, t.table)
   189  }
   190  
   191  func (t *TempTable) PreciseMatch() bool {
   192  	return true
   193  }
   194  
   195  func (t *TempTable) Name() string {
   196  	return t.tableName
   197  }
   198  
   199  func (t *TempTable) String() string {
   200  	return t.tableName
   201  }
   202  
   203  func (t *TempTable) Format() *types.NomsBinFormat {
   204  	return t.table.Format()
   205  }
   206  
   207  func (t *TempTable) Schema() sql.Schema {
   208  	return t.pkSch.Schema
   209  }
   210  
   211  func (t *TempTable) Collation() sql.CollationID {
   212  	return sql.CollationID(t.sch.GetCollation())
   213  }
   214  
   215  func (t *TempTable) sqlSchema() sql.PrimaryKeySchema {
   216  	return t.pkSch
   217  }
   218  
   219  func (t *TempTable) Partitions(ctx *sql.Context) (sql.PartitionIter, error) {
   220  	rows, err := t.table.GetRowData(ctx)
   221  	if err != nil {
   222  		return nil, err
   223  	}
   224  	parts, err := partitionsFromRows(ctx, rows)
   225  	if err != nil {
   226  		return nil, err
   227  	}
   228  	return newDoltTablePartitionIter(rows, parts...), nil
   229  }
   230  
   231  func (t *TempTable) IsTemporary() bool {
   232  	return true
   233  }
   234  
   235  // DataLength implements the sql.StatisticsTable interface.
   236  func (t *TempTable) DataLength(ctx *sql.Context) (uint64, error) {
   237  	idx, err := t.table.GetRowData(ctx)
   238  	if err != nil {
   239  		return 0, err
   240  	}
   241  	return idx.Count()
   242  }
   243  
   244  func (t *TempTable) DoltTable(ctx *sql.Context) (*doltdb.Table, error) {
   245  	return t.table, nil
   246  }
   247  
   248  func (t *TempTable) DataCacheKey(ctx *sql.Context) (doltdb.DataCacheKey, bool, error) {
   249  	return doltdb.DataCacheKey{}, false, nil
   250  }
   251  
   252  func (t *TempTable) LookupPartitions(ctx *sql.Context, lookup sql.IndexLookup) (sql.PartitionIter, error) {
   253  	t.lookup = lookup
   254  	return t.Partitions(ctx)
   255  }
   256  
   257  func (t *TempTable) PartitionRows(ctx *sql.Context, partition sql.Partition) (sql.RowIter, error) {
   258  	if !t.lookup.IsEmpty() {
   259  		return index.RowIterForIndexLookup(ctx, t, t.lookup, t.pkSch, nil)
   260  	} else {
   261  		return partitionRows(ctx, t.table, nil, partition)
   262  	}
   263  }
   264  
   265  func (t *TempTable) IndexedAccess(lookup sql.IndexLookup) sql.IndexedTable {
   266  	return t
   267  }
   268  
   269  func (t *TempTable) CreateIndex(ctx *sql.Context, idx sql.IndexDef) error {
   270  	if idx.Constraint != sql.IndexConstraint_None && idx.Constraint != sql.IndexConstraint_Unique && idx.Constraint != sql.IndexConstraint_Spatial {
   271  		return fmt.Errorf("only the following types of index constraints are supported: none, unique, spatial")
   272  	}
   273  	cols := make([]string, len(idx.Columns))
   274  	for i, c := range idx.Columns {
   275  		cols[i] = c.Name
   276  	}
   277  
   278  	ret, err := creation.CreateIndex(ctx, t.table, t.Name(), idx.Name, cols, allocatePrefixLengths(idx.Columns), schema.IndexProperties{
   279  		IsUnique:      idx.Constraint == sql.IndexConstraint_Unique,
   280  		IsSpatial:     idx.Constraint == sql.IndexConstraint_Spatial,
   281  		IsFullText:    idx.Constraint == sql.IndexConstraint_Fulltext,
   282  		IsUserDefined: true,
   283  		Comment:       idx.Comment,
   284  	}, t.opts)
   285  	if err != nil {
   286  		return err
   287  	}
   288  
   289  	t.table = ret.NewTable
   290  	return nil
   291  }
   292  
   293  func (t *TempTable) DropIndex(ctx *sql.Context, indexName string) error {
   294  	_, err := t.sch.Indexes().RemoveIndex(indexName)
   295  	if err != nil {
   296  		return err
   297  	}
   298  
   299  	newTable, err := t.table.UpdateSchema(ctx, t.sch)
   300  	if err != nil {
   301  		return err
   302  	}
   303  	newTable, err = newTable.DeleteIndexRowData(ctx, indexName)
   304  	if err != nil {
   305  		return err
   306  	}
   307  	t.table = newTable
   308  
   309  	return nil
   310  }
   311  
   312  func (t *TempTable) RenameIndex(ctx *sql.Context, fromIndexName string, toIndexName string) error {
   313  	_, err := t.sch.Indexes().RenameIndex(fromIndexName, toIndexName)
   314  	if err != nil {
   315  		return err
   316  	}
   317  
   318  	newTable, err := t.table.UpdateSchema(ctx, t.sch)
   319  	if err != nil {
   320  		return err
   321  	}
   322  	newTable, err = newTable.RenameIndexRowData(ctx, fromIndexName, toIndexName)
   323  	if err != nil {
   324  		return err
   325  	}
   326  	t.table = newTable
   327  
   328  	return nil
   329  }
   330  
   331  func (t *TempTable) GetDeclaredForeignKeys(ctx *sql.Context) ([]sql.ForeignKeyConstraint, error) {
   332  	return nil, nil
   333  }
   334  
   335  func (t *TempTable) GetReferencedForeignKeys(ctx *sql.Context) ([]sql.ForeignKeyConstraint, error) {
   336  	return nil, nil
   337  }
   338  
   339  func (t *TempTable) CreateIndexForForeignKey(ctx *sql.Context, idx sql.IndexDef) error {
   340  	return sql.ErrTemporaryTablesForeignKeySupport.New()
   341  }
   342  
   343  func (t *TempTable) AddForeignKey(ctx *sql.Context, fk sql.ForeignKeyConstraint) error {
   344  	return sql.ErrTemporaryTablesForeignKeySupport.New()
   345  }
   346  
   347  func (t *TempTable) UpdateForeignKey(ctx *sql.Context, fkName string, fk sql.ForeignKeyConstraint) error {
   348  	return sql.ErrTemporaryTablesForeignKeySupport.New()
   349  }
   350  
   351  func (t *TempTable) DropForeignKey(ctx *sql.Context, fkName string) error {
   352  	return sql.ErrTemporaryTablesForeignKeySupport.New()
   353  }
   354  
   355  func (t *TempTable) GetForeignKeyEditor(ctx *sql.Context) sql.ForeignKeyEditor {
   356  	return nil
   357  }
   358  
   359  func (t *TempTable) Inserter(*sql.Context) sql.RowInserter {
   360  	return t
   361  }
   362  
   363  func (t *TempTable) Deleter(*sql.Context) sql.RowDeleter {
   364  	return t
   365  }
   366  
   367  func (t *TempTable) Replacer(*sql.Context) sql.RowReplacer {
   368  	return t
   369  }
   370  
   371  func (t *TempTable) Updater(*sql.Context) sql.RowUpdater {
   372  	return t
   373  }
   374  
   375  func (t *TempTable) GetChecks(*sql.Context) ([]sql.CheckDefinition, error) {
   376  	return checksInSchema(t.sch), nil
   377  }
   378  
   379  func (t *TempTable) PrimaryKeySchema() sql.PrimaryKeySchema {
   380  	return t.pkSch
   381  }
   382  
   383  func (t *TempTable) CreateCheck(ctx *sql.Context, check *sql.CheckDefinition) error {
   384  	sch, err := t.table.GetSchema(ctx)
   385  	if err != nil {
   386  		return err
   387  	}
   388  
   389  	check = &(*check)
   390  	if check.Name == "" {
   391  		check.Name = strconv.Itoa(rand.Int())
   392  	}
   393  
   394  	_, err = sch.Checks().AddCheck(check.Name, check.CheckExpression, check.Enforced)
   395  	if err != nil {
   396  		return err
   397  	}
   398  	t.table, err = t.table.UpdateSchema(ctx, sch)
   399  
   400  	return err
   401  }
   402  
   403  func (t *TempTable) DropCheck(ctx *sql.Context, chName string) error {
   404  	err := t.sch.Checks().DropCheck(chName)
   405  	if err != nil {
   406  		return err
   407  	}
   408  	t.table, err = t.table.UpdateSchema(ctx, t.sch)
   409  
   410  	return err
   411  }
   412  
   413  func (t *TempTable) Insert(ctx *sql.Context, sqlRow sql.Row) error {
   414  	return t.ed.Insert(ctx, sqlRow)
   415  }
   416  
   417  func (t *TempTable) Update(ctx *sql.Context, oldRow sql.Row, newRow sql.Row) error {
   418  	return t.ed.Update(ctx, oldRow, newRow)
   419  }
   420  
   421  func (t *TempTable) Delete(ctx *sql.Context, sqlRow sql.Row) error {
   422  	return t.ed.Delete(ctx, sqlRow)
   423  }
   424  
   425  func (t *TempTable) StatementBegin(ctx *sql.Context) {
   426  	return
   427  }
   428  
   429  func (t *TempTable) DiscardChanges(ctx *sql.Context, errorEncountered error) error {
   430  	t.lookup = sql.IndexLookup{}
   431  	return nil
   432  }
   433  
   434  func (t *TempTable) StatementComplete(ctx *sql.Context) error {
   435  	t.lookup = sql.IndexLookup{}
   436  	return nil
   437  }
   438  
   439  func (t *TempTable) Close(ctx *sql.Context) error {
   440  	err := t.ed.Close(ctx)
   441  
   442  	t.lookup = sql.IndexLookup{}
   443  	return err
   444  }
   445  
   446  func temporaryDoltSchema(ctx context.Context, pkSch sql.PrimaryKeySchema, collation sql.CollationID) (sch schema.Schema, err error) {
   447  	cols := make([]schema.Column, len(pkSch.Schema))
   448  	for i, col := range pkSch.Schema {
   449  		tag := uint64(i)
   450  		cols[i], err = sqlutil.ToDoltCol(tag, col)
   451  		if err != nil {
   452  			return nil, err
   453  		}
   454  	}
   455  
   456  	sch, err = schema.SchemaFromCols(schema.NewColCollection(cols...))
   457  	if err != nil {
   458  		return nil, err
   459  	}
   460  
   461  	err = sch.SetPkOrdinals(pkSch.PkOrdinals)
   462  	if err != nil {
   463  		return nil, err
   464  	}
   465  	sch.SetCollation(schema.Collation(collation))
   466  
   467  	return sch, nil
   468  }