github.com/hasnat/dolt/go@v0.0.0-20210628190320-9eb5d843fbb7/libraries/doltcore/sqle/database.go (about)

     1  // Copyright 2019-2020 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package sqle
    16  
    17  import (
    18  	"context"
    19  	"fmt"
    20  	"io"
    21  	"strings"
    22  	"time"
    23  
    24  	"github.com/dolthub/go-mysql-server/sql"
    25  	"github.com/dolthub/go-mysql-server/sql/parse"
    26  	"github.com/dolthub/go-mysql-server/sql/plan"
    27  	"github.com/dolthub/vitess/go/vt/proto/query"
    28  	"gopkg.in/src-d/go-errors.v1"
    29  
    30  	"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
    31  	"github.com/dolthub/dolt/go/libraries/doltcore/env"
    32  	"github.com/dolthub/dolt/go/libraries/doltcore/env/actions/commitwalk"
    33  	"github.com/dolthub/dolt/go/libraries/doltcore/ref"
    34  	"github.com/dolthub/dolt/go/libraries/doltcore/row"
    35  	"github.com/dolthub/dolt/go/libraries/doltcore/schema"
    36  	"github.com/dolthub/dolt/go/libraries/doltcore/schema/alterschema"
    37  	"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dtables"
    38  	"github.com/dolthub/dolt/go/libraries/doltcore/sqle/sqlfmt"
    39  	"github.com/dolthub/dolt/go/libraries/doltcore/sqle/sqlutil"
    40  	"github.com/dolthub/dolt/go/libraries/doltcore/table/editor"
    41  	"github.com/dolthub/dolt/go/store/types"
    42  )
    43  
    44  var ErrInvalidTableName = errors.NewKind("Invalid table name %s. Table names must match the regular expression " + doltdb.TableNameRegexStr)
    45  var ErrReservedTableName = errors.NewKind("Invalid table name %s. Table names beginning with `dolt_` are reserved for internal use")
    46  var ErrSystemTableAlter = errors.NewKind("Cannot alter table %s: system tables cannot be dropped or altered")
    47  
    48  type SqlDatabase interface {
    49  	sql.Database
    50  	GetRoot(*sql.Context) (*doltdb.RootValue, error)
    51  	GetTemporaryTablesRoot(*sql.Context) (*doltdb.RootValue, bool)
    52  }
    53  
    54  // Database implements sql.Database for a dolt DB.
    55  type Database struct {
    56  	name string
    57  	ddb  *doltdb.DoltDB
    58  	rsr  env.RepoStateReader
    59  	rsw  env.RepoStateWriter
    60  	drw  env.DocsReadWriter
    61  }
    62  
    63  var _ sql.Database = (*Database)(nil)
    64  var _ sql.TableCreator = (*Database)(nil)
    65  var _ sql.TemporaryTableCreator = (*Database)(nil)
    66  var _ sql.TemporaryTableDatabase = (*Database)(nil)
    67  
    68  // DisabledTransaction is a no-op transaction type that lets us feature-gate transaction logic changes
    69  type DisabledTransaction struct{}
    70  
    71  func (d DisabledTransaction) String() string {
    72  	return "Disabled transaction"
    73  }
    74  
    75  func (db Database) StartTransaction(ctx *sql.Context) (sql.Transaction, error) {
    76  	if !TransactionsEnabled(ctx) {
    77  		return DisabledTransaction{}, nil
    78  	}
    79  
    80  	dsession := DSessFromSess(ctx.Session)
    81  
    82  	// When we begin the transaction, we must synchronize the state of this session with the global state for the
    83  	// current head ref. Any pending transaction has already been committed before this happens.
    84  	wsRef := dsession.workingSets[ctx.GetCurrentDatabase()]
    85  
    86  	ws, err := db.ddb.ResolveWorkingSet(ctx, wsRef)
    87  	if err != nil {
    88  		return nil, err
    89  	}
    90  
    91  	err = db.SetRoot(ctx, ws.RootValue())
    92  	if err != nil {
    93  		return nil, err
    94  	}
    95  
    96  	root, err := db.GetRoot(ctx)
    97  	if err != nil {
    98  		return nil, err
    99  	}
   100  
   101  	err = db.setHeadHash(ctx, wsRef)
   102  	if err != nil {
   103  		return nil, err
   104  	}
   105  
   106  	return NewDoltTransaction(root, wsRef, db.DbData()), nil
   107  }
   108  
   109  func (db Database) setHeadHash(ctx *sql.Context, ref ref.WorkingSetRef) error {
   110  	// TODO: use the session HEAD ref here instead of the repo state one
   111  	// headRef, err := ref.ToHeadRef()
   112  	// if err != nil {
   113  	// 	return err
   114  	// }
   115  
   116  	headCommit, err := db.ddb.Resolve(ctx, db.rsr.CWBHeadSpec(), db.rsr.CWBHeadRef())
   117  	if err != nil {
   118  		return err
   119  	}
   120  	headHash, err := headCommit.HashOf()
   121  	if err != nil {
   122  		return err
   123  	}
   124  	if doltSession, ok := ctx.Session.(*DoltSession); ok {
   125  		return doltSession.SetSessionVarDirectly(ctx, HeadKey(db.name), headHash.String())
   126  	} else {
   127  		return ctx.SetSessionVariable(ctx, HeadKey(db.name), headHash.String())
   128  	}
   129  }
   130  
   131  func (db Database) CommitTransaction(ctx *sql.Context, tx sql.Transaction) error {
   132  	dsession := DSessFromSess(ctx.Session)
   133  	return dsession.CommitTransaction(ctx, db.name, tx)
   134  }
   135  
   136  func (db Database) Rollback(ctx *sql.Context, tx sql.Transaction) error {
   137  	dsession := DSessFromSess(ctx.Session)
   138  	return dsession.RollbackTransaction(ctx, db.name, tx)
   139  }
   140  
   141  func (db Database) CreateSavepoint(ctx *sql.Context, tx sql.Transaction, name string) error {
   142  	dsession := DSessFromSess(ctx.Session)
   143  	return dsession.CreateSavepoint(ctx, name, db.name, tx)
   144  }
   145  
   146  func (db Database) RollbackToSavepoint(ctx *sql.Context, tx sql.Transaction, name string) error {
   147  	dsession := DSessFromSess(ctx.Session)
   148  	return dsession.RollbackToSavepoint(ctx, name, db.name, tx)
   149  }
   150  
   151  func (db Database) ReleaseSavepoint(ctx *sql.Context, tx sql.Transaction, name string) error {
   152  	dsession := DSessFromSess(ctx.Session)
   153  	return dsession.ReleaseSavepoint(ctx, name, db.name, tx)
   154  }
   155  
   156  var _ SqlDatabase = Database{}
   157  var _ sql.VersionedDatabase = Database{}
   158  var _ sql.TableDropper = Database{}
   159  var _ sql.TableCreator = Database{}
   160  var _ sql.TemporaryTableCreator = Database{}
   161  var _ sql.TableRenamer = Database{}
   162  var _ sql.TriggerDatabase = Database{}
   163  var _ sql.StoredProcedureDatabase = Database{}
   164  var _ sql.TransactionDatabase = Database{}
   165  
   166  // NewDatabase returns a new dolt database to use in queries.
   167  func NewDatabase(name string, dbData env.DbData) Database {
   168  	return Database{
   169  		name: name,
   170  		ddb:  dbData.Ddb,
   171  		rsr:  dbData.Rsr,
   172  		rsw:  dbData.Rsw,
   173  		drw:  dbData.Drw,
   174  	}
   175  }
   176  
   177  // Name returns the name of this database, set at creation time.
   178  func (db Database) Name() string {
   179  	return db.name
   180  }
   181  
   182  // GetDoltDB gets the underlying DoltDB of the Database
   183  func (db Database) GetDoltDB() *doltdb.DoltDB {
   184  	return db.ddb
   185  }
   186  
   187  // GetStateReader gets the RepoStateReader for a Database
   188  func (db Database) GetStateReader() env.RepoStateReader {
   189  	return db.rsr
   190  }
   191  
   192  // GetStateWriter gets the RepoStateWriter for a Database
   193  func (db Database) GetStateWriter() env.RepoStateWriter {
   194  	return db.rsw
   195  }
   196  
   197  func (db Database) GetDocsReadWriter() env.DocsReadWriter {
   198  	return db.drw
   199  }
   200  
   201  func (db Database) DbData() env.DbData {
   202  	return env.DbData{
   203  		Ddb: db.ddb,
   204  		Rsw: db.rsw,
   205  		Rsr: db.rsr,
   206  		Drw: db.drw,
   207  	}
   208  }
   209  
   210  // GetTableInsensitive is used when resolving tables in queries. It returns a best-effort case-insensitive match for
   211  // the table name given.
   212  func (db Database) GetTableInsensitive(ctx *sql.Context, tblName string) (sql.Table, bool, error) {
   213  	// We start by first checking whether the input table is a temporary table. Temporary tables with name `x` take
   214  	// priority over persisted tables of name `x`.
   215  	tempTableRootValue, tempRootExists := db.GetTemporaryTablesRoot(ctx)
   216  	if tempRootExists {
   217  		tbl, tempTableFound, err := db.getTable(ctx, tempTableRootValue, tblName, true)
   218  		if err != nil {
   219  			return nil, false, err
   220  		}
   221  
   222  		if tempTableFound {
   223  			return tbl, true, nil
   224  		}
   225  	}
   226  
   227  	root, err := db.GetRoot(ctx)
   228  
   229  	if err != nil {
   230  		return nil, false, err
   231  	}
   232  
   233  	return db.GetTableInsensitiveWithRoot(ctx, root, tblName)
   234  }
   235  
   236  func (db Database) GetTableInsensitiveWithRoot(ctx *sql.Context, root *doltdb.RootValue, tblName string) (dt sql.Table, found bool, err error) {
   237  	lwrName := strings.ToLower(tblName)
   238  
   239  	head, _, err := DSessFromSess(ctx.Session).GetHeadCommit(ctx, db.name)
   240  	if err != nil {
   241  		return nil, false, err
   242  	}
   243  
   244  	// NOTE: system tables are not suitable for caching
   245  	switch {
   246  	case strings.HasPrefix(lwrName, doltdb.DoltDiffTablePrefix):
   247  		suffix := tblName[len(doltdb.DoltDiffTablePrefix):]
   248  		found = true
   249  		dt, err = dtables.NewDiffTable(ctx, suffix, db.ddb, root, head)
   250  	case strings.HasPrefix(lwrName, doltdb.DoltCommitDiffTablePrefix):
   251  		suffix := tblName[len(doltdb.DoltCommitDiffTablePrefix):]
   252  		found = true
   253  		dt, err = dtables.NewCommitDiffTable(ctx, suffix, db.ddb, root)
   254  	case strings.HasPrefix(lwrName, doltdb.DoltHistoryTablePrefix):
   255  		suffix := tblName[len(doltdb.DoltHistoryTablePrefix):]
   256  		found = true
   257  		dt, err = dtables.NewHistoryTable(ctx, suffix, db.ddb, root, head)
   258  	case strings.HasPrefix(lwrName, doltdb.DoltConfTablePrefix):
   259  		suffix := tblName[len(doltdb.DoltConfTablePrefix):]
   260  		found = true
   261  		dt, err = dtables.NewConflictsTable(ctx, suffix, root, dtables.RootSetter(db))
   262  	}
   263  	if err != nil {
   264  		return nil, false, err
   265  	}
   266  	if found {
   267  		return dt, found, nil
   268  	}
   269  
   270  	// NOTE: system tables are not suitable for caching
   271  	switch lwrName {
   272  	case doltdb.LogTableName:
   273  		dt, found = dtables.NewLogTable(ctx, db.ddb, head), true
   274  	case doltdb.TableOfTablesInConflictName:
   275  		dt, found = dtables.NewTableOfTablesInConflict(ctx, db.ddb, root), true
   276  	case doltdb.BranchesTableName:
   277  		dt, found = dtables.NewBranchesTable(ctx, db.ddb), true
   278  	case doltdb.CommitsTableName:
   279  		dt, found = dtables.NewCommitsTable(ctx, db.ddb), true
   280  	case doltdb.CommitAncestorsTableName:
   281  		dt, found = dtables.NewCommitAncestorsTable(ctx, db.ddb), true
   282  	case doltdb.StatusTableName:
   283  		dt, found = dtables.NewStatusTable(ctx, db.ddb, db.rsr, db.drw), true
   284  	}
   285  	if found {
   286  		return dt, found, nil
   287  	}
   288  
   289  	return db.getTable(ctx, root, tblName, false)
   290  }
   291  
   292  // GetTableInsensitiveAsOf implements sql.VersionedDatabase
   293  func (db Database) GetTableInsensitiveAsOf(ctx *sql.Context, tableName string, asOf interface{}) (sql.Table, bool, error) {
   294  	root, err := db.rootAsOf(ctx, asOf)
   295  
   296  	if err != nil {
   297  		return nil, false, err
   298  	} else if root == nil {
   299  		return nil, false, nil
   300  	}
   301  
   302  	table, ok, err := db.getTable(ctx, root, tableName, false)
   303  	if err != nil {
   304  		return nil, false, err
   305  	}
   306  	if !ok {
   307  		return nil, false, nil
   308  	}
   309  
   310  	switch table := table.(type) {
   311  	case *DoltTable:
   312  		return table.LockedToRoot(root), true, nil
   313  	case *AlterableDoltTable:
   314  		return table.LockedToRoot(root), true, nil
   315  	case *WritableDoltTable:
   316  		return table.LockedToRoot(root), true, nil
   317  	default:
   318  		panic(fmt.Sprintf("unexpected table type %T", table))
   319  	}
   320  }
   321  
   322  // rootAsOf returns the root of the DB as of the expression given, which may be nil in the case that it refers to an
   323  // expression before the first commit.
   324  func (db Database) rootAsOf(ctx *sql.Context, asOf interface{}) (*doltdb.RootValue, error) {
   325  	switch x := asOf.(type) {
   326  	case string:
   327  		return db.getRootForCommitRef(ctx, x)
   328  	case time.Time:
   329  		return db.getRootForTime(ctx, x)
   330  	default:
   331  		panic(fmt.Sprintf("unsupported AS OF type %T", asOf))
   332  	}
   333  }
   334  
   335  func (db Database) getRootForTime(ctx *sql.Context, asOf time.Time) (*doltdb.RootValue, error) {
   336  	cs, err := doltdb.NewCommitSpec("HEAD")
   337  	if err != nil {
   338  		return nil, err
   339  	}
   340  
   341  	cm, err := db.ddb.Resolve(ctx, cs, db.rsr.CWBHeadRef())
   342  	if err != nil {
   343  		return nil, err
   344  	}
   345  
   346  	hash, err := cm.HashOf()
   347  	if err != nil {
   348  		return nil, err
   349  	}
   350  
   351  	cmItr, err := commitwalk.GetTopologicalOrderIterator(ctx, db.ddb, hash)
   352  	if err != nil {
   353  		return nil, err
   354  	}
   355  
   356  	for {
   357  		_, curr, err := cmItr.Next(ctx)
   358  		if err == io.EOF {
   359  			break
   360  		} else if err != nil {
   361  			return nil, err
   362  		}
   363  
   364  		meta, err := curr.GetCommitMeta()
   365  		if err != nil {
   366  			return nil, err
   367  		}
   368  
   369  		if meta.Time().Equal(asOf) || meta.Time().Before(asOf) {
   370  			return curr.GetRootValue()
   371  		}
   372  	}
   373  
   374  	return nil, nil
   375  }
   376  
   377  func (db Database) getRootForCommitRef(ctx *sql.Context, commitRef string) (*doltdb.RootValue, error) {
   378  	cs, err := doltdb.NewCommitSpec(commitRef)
   379  	if err != nil {
   380  		return nil, err
   381  	}
   382  
   383  	cm, err := db.ddb.Resolve(ctx, cs, db.rsr.CWBHeadRef())
   384  	if err != nil {
   385  		return nil, err
   386  	}
   387  
   388  	root, err := cm.GetRootValue()
   389  	if err != nil {
   390  		return nil, err
   391  	}
   392  
   393  	return root, nil
   394  }
   395  
   396  // GetTableNamesAsOf implements sql.VersionedDatabase
   397  func (db Database) GetTableNamesAsOf(ctx *sql.Context, time interface{}) ([]string, error) {
   398  	root, err := db.rootAsOf(ctx, time)
   399  	if err != nil {
   400  		return nil, err
   401  	} else if root == nil {
   402  		return nil, nil
   403  	}
   404  
   405  	tblNames, err := root.GetTableNames(ctx)
   406  	if err != nil {
   407  		return nil, err
   408  	}
   409  	return filterDoltInternalTables(tblNames), nil
   410  }
   411  
   412  // getTable gets the table with the exact name given at the root value given. The database caches tables for all root
   413  // values to avoid doing schema lookups on every table lookup, which are expensive.
   414  func (db Database) getTable(ctx *sql.Context, root *doltdb.RootValue, tableName string, temporary bool) (sql.Table, bool, error) {
   415  	tableNames, err := getAllTableNames(ctx, root)
   416  	if err != nil {
   417  		return nil, true, err
   418  	}
   419  
   420  	tableName, ok := sql.GetTableNameInsensitive(tableName, tableNames)
   421  	if !ok {
   422  		return nil, false, nil
   423  	}
   424  
   425  	tbl, ok, err := root.GetTable(ctx, tableName)
   426  	if err != nil {
   427  		return nil, false, err
   428  	} else if !ok {
   429  		// Should be impossible
   430  		return nil, false, doltdb.ErrTableNotFound
   431  	}
   432  
   433  	sch, err := tbl.GetSchema(ctx)
   434  	if err != nil {
   435  		return nil, false, err
   436  	}
   437  
   438  	var table sql.Table
   439  
   440  	readonlyTable := NewDoltTable(tableName, sch, tbl, db, temporary)
   441  	if doltdb.IsReadOnlySystemTable(tableName) {
   442  		table = readonlyTable
   443  	} else if doltdb.HasDoltPrefix(tableName) {
   444  		table = &WritableDoltTable{DoltTable: readonlyTable, db: db}
   445  	} else {
   446  		table = &AlterableDoltTable{WritableDoltTable{DoltTable: readonlyTable, db: db}}
   447  	}
   448  
   449  	return table, true, nil
   450  }
   451  
   452  // GetTableNames returns the names of all user tables. System tables in user space (e.g. dolt_docs, dolt_query_catalog)
   453  // are filtered out. This method is used for queries that examine the schema of the database, e.g. show tables. Table
   454  // name resolution in queries is handled by GetTableInsensitive. Use GetAllTableNames for an unfiltered list of all
   455  // tables in user space.
   456  func (db Database) GetTableNames(ctx *sql.Context) ([]string, error) {
   457  	tblNames, err := db.GetAllTableNames(ctx)
   458  	if err != nil {
   459  		return nil, err
   460  	}
   461  	return filterDoltInternalTables(tblNames), nil
   462  }
   463  
   464  // GetAllTableNames returns all user-space tables, including system tables in user space
   465  // (e.g. dolt_docs, dolt_query_catalog).
   466  func (db Database) GetAllTableNames(ctx *sql.Context) ([]string, error) {
   467  	root, err := db.GetRoot(ctx)
   468  
   469  	if err != nil {
   470  		return nil, err
   471  	}
   472  
   473  	return getAllTableNames(ctx, root)
   474  }
   475  
   476  func getAllTableNames(ctx context.Context, root *doltdb.RootValue) ([]string, error) {
   477  	return root.GetTableNames(ctx)
   478  }
   479  
   480  func filterDoltInternalTables(tblNames []string) []string {
   481  	result := []string{}
   482  	for _, tbl := range tblNames {
   483  		if !doltdb.HasDoltPrefix(tbl) {
   484  			result = append(result, tbl)
   485  		}
   486  	}
   487  	return result
   488  }
   489  
   490  func HeadKey(dbName string) string {
   491  	return dbName + HeadKeySuffix
   492  }
   493  
   494  func HeadRefKey(dbName string) string {
   495  	return dbName + HeadRefKeySuffix
   496  }
   497  
   498  func WorkingKey(dbName string) string {
   499  	return dbName + WorkingKeySuffix
   500  }
   501  
   502  var hashType = sql.MustCreateString(query.Type_TEXT, 32, sql.Collation_ascii_bin)
   503  
   504  // GetRoot returns the root value for this database session
   505  func (db Database) GetRoot(ctx *sql.Context) (*doltdb.RootValue, error) {
   506  	dsess := DSessFromSess(ctx.Session)
   507  	currRoot, dbRootOk := dsess.roots[db.name]
   508  
   509  	if !dbRootOk {
   510  		return nil, fmt.Errorf("no root value found in session")
   511  	}
   512  
   513  	return currRoot.root, nil
   514  }
   515  
   516  func (db Database) GetTemporaryTablesRoot(ctx *sql.Context) (*doltdb.RootValue, bool) {
   517  	dsess := DSessFromSess(ctx.Session)
   518  	return dsess.GetTempTableRootValue(ctx, db.Name())
   519  }
   520  
   521  // SetRoot should typically be called on the Session, which is where this state lives. But it's available here as a
   522  // convenience.
   523  func (db Database) SetRoot(ctx *sql.Context, newRoot *doltdb.RootValue) error {
   524  	dsess := DSessFromSess(ctx.Session)
   525  	return dsess.SetRoot(ctx, db.name, newRoot)
   526  }
   527  
   528  // SetTemporaryRoot sets the root value holding temporary tables not persisted to the repo state after the session.
   529  func (db Database) SetTemporaryRoot(ctx *sql.Context, newRoot *doltdb.RootValue) error {
   530  	dsess := DSessFromSess(ctx.Session)
   531  	return dsess.SetTempTableRoot(ctx, db.name, newRoot)
   532  }
   533  
   534  // LoadRootFromRepoState loads the root value from the repo state's working hash, then calls SetRoot with the loaded
   535  // root value.
   536  func (db Database) LoadRootFromRepoState(ctx *sql.Context) error {
   537  	workingHash := db.rsr.WorkingHash()
   538  	root, err := db.ddb.ReadRootValue(ctx, workingHash)
   539  	if err != nil {
   540  		return err
   541  	}
   542  
   543  	return db.SetRoot(ctx, root)
   544  }
   545  
   546  func (db Database) GetHeadRoot(ctx *sql.Context) (*doltdb.RootValue, error) {
   547  	dsess := DSessFromSess(ctx.Session)
   548  	head, _, err := dsess.GetHeadCommit(ctx, db.name)
   549  	if err != nil {
   550  		return nil, err
   551  	}
   552  	return head.GetRootValue()
   553  }
   554  
   555  // DropTable drops the table with the name given.
   556  // The planner returns the correct case sensitive name in tableName
   557  func (db Database) DropTable(ctx *sql.Context, tableName string) error {
   558  	if doltdb.IsReadOnlySystemTable(tableName) {
   559  		return ErrSystemTableAlter.New(tableName)
   560  	}
   561  
   562  	// Temporary Tables Get Precedence over schema tables
   563  	tempTableRoot, tempRootExists := db.GetTemporaryTablesRoot(ctx)
   564  	if tempRootExists {
   565  		tempTableExists, err := tempTableRoot.HasTable(ctx, tableName)
   566  		if err != nil {
   567  			return err
   568  		}
   569  
   570  		if tempTableExists {
   571  			newRoot, err := tempTableRoot.RemoveTables(ctx, tableName)
   572  			if err != nil {
   573  				return err
   574  			}
   575  
   576  			return db.SetTemporaryRoot(ctx, newRoot)
   577  		}
   578  	}
   579  
   580  	root, err := db.GetRoot(ctx)
   581  	if err != nil {
   582  		return err
   583  	}
   584  
   585  	tableExists, err := root.HasTable(ctx, tableName)
   586  	if err != nil {
   587  		return err
   588  	}
   589  
   590  	if !tableExists {
   591  		return sql.ErrTableNotFound.New(tableName)
   592  	}
   593  
   594  	newRoot, err := root.RemoveTables(ctx, tableName)
   595  	if err != nil {
   596  		return err
   597  	}
   598  
   599  	return db.SetRoot(ctx, newRoot)
   600  }
   601  
   602  // CreateTable creates a table with the name and schema given.
   603  func (db Database) CreateTable(ctx *sql.Context, tableName string, sch sql.Schema) error {
   604  	if doltdb.HasDoltPrefix(tableName) {
   605  		return ErrReservedTableName.New(tableName)
   606  	}
   607  
   608  	if !doltdb.IsValidTableName(tableName) {
   609  		return ErrInvalidTableName.New(tableName)
   610  	}
   611  
   612  	return db.createSqlTable(ctx, tableName, sch)
   613  }
   614  
   615  // Unlike the exported version CreateTable, createSqlTable doesn't enforce any table name checks.
   616  func (db Database) createSqlTable(ctx *sql.Context, tableName string, sch sql.Schema) error {
   617  	root, err := db.GetRoot(ctx)
   618  	if err != nil {
   619  		return err
   620  	}
   621  
   622  	if exists, err := root.HasTable(ctx, tableName); err != nil {
   623  		return err
   624  	} else if exists {
   625  		return sql.ErrTableAlreadyExists.New(tableName)
   626  	}
   627  
   628  	headRoot, err := db.GetHeadRoot(ctx)
   629  	if err != nil {
   630  		return err
   631  	}
   632  
   633  	doltSch, err := sqlutil.ToDoltSchema(ctx, root, tableName, sch, headRoot)
   634  	if err != nil {
   635  		return err
   636  	}
   637  
   638  	return db.createDoltTable(ctx, tableName, root, doltSch)
   639  }
   640  
   641  // createDoltTable creates a table on the database using the given dolt schema while not enforcing table name checks.
   642  func (db Database) createDoltTable(ctx *sql.Context, tableName string, root *doltdb.RootValue, doltSch schema.Schema) error {
   643  	if exists, err := root.HasTable(ctx, tableName); err != nil {
   644  		return err
   645  	} else if exists {
   646  		return sql.ErrTableAlreadyExists.New(tableName)
   647  	}
   648  
   649  	var conflictingTbls []string
   650  	_ = doltSch.GetAllCols().Iter(func(tag uint64, col schema.Column) (stop bool, err error) {
   651  		_, tbl, exists, err := root.GetTableByColTag(ctx, tag)
   652  		if err != nil {
   653  			return true, err
   654  		}
   655  		if exists && tbl != tableName {
   656  			errStr := schema.ErrTagPrevUsed(tag, col.Name, tbl).Error()
   657  			conflictingTbls = append(conflictingTbls, errStr)
   658  		}
   659  		return false, nil
   660  	})
   661  
   662  	if len(conflictingTbls) > 0 {
   663  		return fmt.Errorf(strings.Join(conflictingTbls, "\n"))
   664  	}
   665  
   666  	newRoot, err := root.CreateEmptyTable(ctx, tableName, doltSch)
   667  	if err != nil {
   668  		return err
   669  	}
   670  
   671  	return db.SetRoot(ctx, newRoot)
   672  }
   673  
   674  // CreateTemporaryTable creates a table that only exists the length of a session.
   675  func (db Database) CreateTemporaryTable(ctx *sql.Context, tableName string, sch sql.Schema) error {
   676  	if doltdb.HasDoltPrefix(tableName) {
   677  		return ErrReservedTableName.New(tableName)
   678  	}
   679  
   680  	if !doltdb.IsValidTableName(tableName) {
   681  		return ErrInvalidTableName.New(tableName)
   682  	}
   683  
   684  	return db.createTempSQLTable(ctx, tableName, sch)
   685  }
   686  
   687  func (db Database) createTempSQLTable(ctx *sql.Context, tableName string, sch sql.Schema) error {
   688  	// Get temporary root value
   689  	dsess := DSessFromSess(ctx.Session)
   690  	tempTableRootValue, exists := db.GetTemporaryTablesRoot(ctx)
   691  
   692  	// create the root value only when needed.
   693  	if !exists {
   694  		err := dsess.CreateTemporaryTablesRoot(ctx, db.Name(), db.GetDoltDB())
   695  		if err != nil {
   696  			return err
   697  		}
   698  
   699  		tempTableRootValue, _ = db.GetTemporaryTablesRoot(ctx)
   700  	}
   701  
   702  	doltSch, err := sqlutil.ToDoltSchema(ctx, tempTableRootValue, tableName, sch, nil)
   703  	if err != nil {
   704  		return err
   705  	}
   706  
   707  	return db.createTempDoltTable(ctx, tableName, tempTableRootValue, doltSch, dsess)
   708  }
   709  
   710  func (db Database) createTempDoltTable(ctx *sql.Context, tableName string, root *doltdb.RootValue, doltSch schema.Schema, dsess *DoltSession) error {
   711  	if exists, err := root.HasTable(ctx, tableName); err != nil {
   712  		return err
   713  	} else if exists {
   714  		return sql.ErrTableAlreadyExists.New(tableName)
   715  	}
   716  
   717  	_ = doltSch.GetAllCols().Iter(func(tag uint64, col schema.Column) (stop bool, err error) {
   718  		_, tbl, exists, err := root.GetTableByColTag(ctx, tag)
   719  		if err != nil {
   720  			return true, err
   721  		}
   722  		if exists && tbl != tableName {
   723  			panic("Table's tags are associated with a different table name")
   724  		}
   725  		return false, nil
   726  	})
   727  
   728  	newRoot, err := root.CreateEmptyTable(ctx, tableName, doltSch)
   729  	if err != nil {
   730  		return err
   731  	}
   732  
   733  	return dsess.SetTempTableRoot(ctx, db.Name(), newRoot)
   734  }
   735  
   736  // RenameTable implements sql.TableRenamer
   737  func (db Database) RenameTable(ctx *sql.Context, oldName, newName string) error {
   738  	root, err := db.GetRoot(ctx)
   739  
   740  	if err != nil {
   741  		return err
   742  	}
   743  
   744  	if doltdb.IsReadOnlySystemTable(oldName) {
   745  		return ErrSystemTableAlter.New(oldName)
   746  	}
   747  
   748  	if doltdb.HasDoltPrefix(newName) {
   749  		return ErrReservedTableName.New(newName)
   750  	}
   751  
   752  	if !doltdb.IsValidTableName(newName) {
   753  		return ErrInvalidTableName.New(newName)
   754  	}
   755  
   756  	if _, ok, _ := db.GetTableInsensitive(ctx, newName); ok {
   757  		return sql.ErrTableAlreadyExists.New(newName)
   758  	}
   759  
   760  	newRoot, err := alterschema.RenameTable(ctx, root, oldName, newName)
   761  
   762  	if err != nil {
   763  		return err
   764  	}
   765  
   766  	return db.SetRoot(ctx, newRoot)
   767  }
   768  
   769  // Flush flushes the current batch of outstanding changes and returns any errors.
   770  func (db Database) Flush(ctx *sql.Context) error {
   771  	dsess := DSessFromSess(ctx.Session)
   772  	editSession := dsess.editSessions[db.name]
   773  
   774  	newRoot, err := editSession.Flush(ctx)
   775  	if err != nil {
   776  		return err
   777  	}
   778  
   779  	err = db.SetRoot(ctx, newRoot)
   780  	if err != nil {
   781  		return nil
   782  	}
   783  
   784  	// Flush any changes made to temporary tables
   785  	// TODO: Shouldn't always be updating both roots. Needs to update either both roots or neither of them, atomically
   786  	tempTableEditSession, sessionExists := dsess.tempTableEditSessions[db.Name()]
   787  	if sessionExists {
   788  		newTempTableRoot, err := tempTableEditSession.Flush(ctx)
   789  		if err != nil {
   790  			return nil
   791  		}
   792  
   793  		return dsess.SetTempTableRoot(ctx, db.Name(), newTempTableRoot)
   794  	}
   795  
   796  	return nil
   797  }
   798  
   799  // CreateView implements sql.ViewCreator. Persists the view in the dolt database, so
   800  // it can exist in a sql session later. Returns sql.ErrExistingView if a view
   801  // with that name already exists.
   802  func (db Database) CreateView(ctx *sql.Context, name string, definition string) error {
   803  	return db.addFragToSchemasTable(ctx, "view", name, definition, sql.ErrExistingView.New(name))
   804  }
   805  
   806  // DropView implements sql.ViewDropper. Removes a view from persistence in the
   807  // dolt database. Returns sql.ErrNonExistingView if the view did not
   808  // exist.
   809  func (db Database) DropView(ctx *sql.Context, name string) error {
   810  	return db.dropFragFromSchemasTable(ctx, "view", name, sql.ErrNonExistingView.New(name))
   811  }
   812  
   813  // GetTriggers implements sql.TriggerDatabase.
   814  func (db Database) GetTriggers(ctx *sql.Context) ([]sql.TriggerDefinition, error) {
   815  	root, err := db.GetRoot(ctx)
   816  	if err != nil {
   817  		return nil, err
   818  	}
   819  
   820  	tbl, ok, err := root.GetTable(ctx, doltdb.SchemasTableName)
   821  	if err != nil {
   822  		return nil, err
   823  	}
   824  	if !ok {
   825  		return nil, nil
   826  	}
   827  
   828  	sch, err := tbl.GetSchema(ctx)
   829  	if err != nil {
   830  		return nil, err
   831  	}
   832  
   833  	typeCol, ok := sch.GetAllCols().GetByName(doltdb.SchemasTablesTypeCol)
   834  	if !ok {
   835  		return nil, fmt.Errorf("`%s` schema in unexpected format", doltdb.SchemasTableName)
   836  	}
   837  	nameCol, ok := sch.GetAllCols().GetByName(doltdb.SchemasTablesNameCol)
   838  	if !ok {
   839  		return nil, fmt.Errorf("`%s` schema in unexpected format", doltdb.SchemasTableName)
   840  	}
   841  	fragCol, ok := sch.GetAllCols().GetByName(doltdb.SchemasTablesFragmentCol)
   842  	if !ok {
   843  		return nil, fmt.Errorf("`%s` schema in unexpected format", doltdb.SchemasTableName)
   844  	}
   845  
   846  	rowData, err := tbl.GetRowData(ctx)
   847  	if err != nil {
   848  		return nil, err
   849  	}
   850  	var triggers []sql.TriggerDefinition
   851  	err = rowData.Iter(ctx, func(key, val types.Value) (stop bool, err error) {
   852  		dRow, err := row.FromNoms(sch, key.(types.Tuple), val.(types.Tuple))
   853  		if err != nil {
   854  			return true, err
   855  		}
   856  		if typeColVal, ok := dRow.GetColVal(typeCol.Tag); ok && typeColVal.Equals(types.String("trigger")) {
   857  			name, ok := dRow.GetColVal(nameCol.Tag)
   858  			if !ok {
   859  				taggedVals, _ := dRow.TaggedValues()
   860  				return true, fmt.Errorf("missing `%s` value for trigger row: (%s)", doltdb.SchemasTablesNameCol, taggedVals)
   861  			}
   862  			createStmt, ok := dRow.GetColVal(fragCol.Tag)
   863  			if !ok {
   864  				taggedVals, _ := dRow.TaggedValues()
   865  				return true, fmt.Errorf("missing `%s` value for trigger row: (%s)", doltdb.SchemasTablesFragmentCol, taggedVals)
   866  			}
   867  			triggers = append(triggers, sql.TriggerDefinition{
   868  				Name:            string(name.(types.String)),
   869  				CreateStatement: string(createStmt.(types.String)),
   870  			})
   871  		}
   872  		return false, nil
   873  	})
   874  	if err != nil {
   875  		return nil, err
   876  	}
   877  	return triggers, nil
   878  }
   879  
   880  // CreateTrigger implements sql.TriggerDatabase.
   881  func (db Database) CreateTrigger(ctx *sql.Context, definition sql.TriggerDefinition) error {
   882  	return db.addFragToSchemasTable(ctx,
   883  		"trigger",
   884  		definition.Name,
   885  		definition.CreateStatement,
   886  		fmt.Errorf("triggers `%s` already exists", definition.Name), //TODO: add a sql error and return that instead
   887  	)
   888  }
   889  
   890  // DropTrigger implements sql.TriggerDatabase.
   891  func (db Database) DropTrigger(ctx *sql.Context, name string) error {
   892  	//TODO: add a sql error and use that as the param error instead
   893  	return db.dropFragFromSchemasTable(ctx, "trigger", name, sql.ErrTriggerDoesNotExist.New(name))
   894  }
   895  
   896  // GetStoredProcedures implements sql.StoredProcedureDatabase.
   897  func (db Database) GetStoredProcedures(ctx *sql.Context) ([]sql.StoredProcedureDetails, error) {
   898  	missingValue := errors.NewKind("missing `%s` value for procedure row: (%s)")
   899  
   900  	root, err := db.GetRoot(ctx)
   901  	if err != nil {
   902  		return nil, err
   903  	}
   904  
   905  	table, ok, err := root.GetTable(ctx, doltdb.ProceduresTableName)
   906  	if err != nil {
   907  		return nil, err
   908  	}
   909  	if !ok {
   910  		return nil, nil
   911  	}
   912  
   913  	rowData, err := table.GetRowData(ctx)
   914  	if err != nil {
   915  		return nil, err
   916  	}
   917  
   918  	sch, err := table.GetSchema(ctx)
   919  	if err != nil {
   920  		return nil, err
   921  	}
   922  
   923  	var spds []sql.StoredProcedureDetails
   924  	err = rowData.Iter(ctx, func(key, val types.Value) (stop bool, err error) {
   925  		dRow, err := row.FromNoms(sch, key.(types.Tuple), val.(types.Tuple))
   926  		if err != nil {
   927  			return true, err
   928  		}
   929  		taggedVals, err := dRow.TaggedValues()
   930  		if err != nil {
   931  			return true, err
   932  		}
   933  
   934  		name, ok := dRow.GetColVal(schema.DoltProceduresNameTag)
   935  		if !ok {
   936  			return true, missingValue.New(doltdb.ProceduresTableNameCol, taggedVals)
   937  		}
   938  		createStmt, ok := dRow.GetColVal(schema.DoltProceduresCreateStmtTag)
   939  		if !ok {
   940  			return true, missingValue.New(doltdb.ProceduresTableCreateStmtCol, taggedVals)
   941  		}
   942  		createdAt, ok := dRow.GetColVal(schema.DoltProceduresCreatedAtTag)
   943  		if !ok {
   944  			return true, missingValue.New(doltdb.ProceduresTableCreatedAtCol, taggedVals)
   945  		}
   946  		modifiedAt, ok := dRow.GetColVal(schema.DoltProceduresModifiedAtTag)
   947  		if !ok {
   948  			return true, missingValue.New(doltdb.ProceduresTableModifiedAtCol, taggedVals)
   949  		}
   950  		spds = append(spds, sql.StoredProcedureDetails{
   951  			Name:            string(name.(types.String)),
   952  			CreateStatement: string(createStmt.(types.String)),
   953  			CreatedAt:       time.Time(createdAt.(types.Timestamp)),
   954  			ModifiedAt:      time.Time(modifiedAt.(types.Timestamp)),
   955  		})
   956  		return false, nil
   957  	})
   958  	if err != nil {
   959  		return nil, err
   960  	}
   961  	return spds, nil
   962  }
   963  
   964  // SaveStoredProcedure implements sql.StoredProcedureDatabase.
   965  func (db Database) SaveStoredProcedure(ctx *sql.Context, spd sql.StoredProcedureDetails) error {
   966  	return DoltProceduresAddProcedure(ctx, db, spd)
   967  }
   968  
   969  // DropStoredProcedure implements sql.StoredProcedureDatabase.
   970  func (db Database) DropStoredProcedure(ctx *sql.Context, name string) error {
   971  	return DoltProceduresDropProcedure(ctx, db, name)
   972  }
   973  
   974  func (db Database) addFragToSchemasTable(ctx *sql.Context, fragType, name, definition string, existingErr error) (retErr error) {
   975  	tbl, err := GetOrCreateDoltSchemasTable(ctx, db)
   976  	if err != nil {
   977  		return err
   978  	}
   979  
   980  	_, exists, err := fragFromSchemasTable(ctx, tbl, fragType, name)
   981  	if err != nil {
   982  		return err
   983  	}
   984  	if exists {
   985  		return existingErr
   986  	}
   987  
   988  	// If rows exist, then grab the highest id and add 1 to get the new id
   989  	indexToUse := int64(1)
   990  	te, err := db.TableEditSession(ctx, tbl.IsTemporary()).GetTableEditor(ctx, doltdb.SchemasTableName, tbl.sch)
   991  	if err != nil {
   992  		return err
   993  	}
   994  	dTable, err := te.Table(ctx)
   995  	if err != nil {
   996  		return err
   997  	}
   998  	rowData, err := dTable.GetRowData(ctx)
   999  	if err != nil {
  1000  		return err
  1001  	}
  1002  	if rowData.Len() > 0 {
  1003  		keyTpl, _, err := rowData.Last(ctx)
  1004  		if err != nil {
  1005  			return err
  1006  		}
  1007  		if keyTpl != nil {
  1008  			key, err := keyTpl.(types.Tuple).Get(1)
  1009  			if err != nil {
  1010  				return err
  1011  			}
  1012  			indexToUse = int64(key.(types.Int)) + 1
  1013  		}
  1014  	}
  1015  
  1016  	// Insert the new row into the db
  1017  	inserter := tbl.Inserter(ctx)
  1018  	defer func() {
  1019  		err := inserter.Close(ctx)
  1020  		if retErr == nil {
  1021  			retErr = err
  1022  		}
  1023  	}()
  1024  	return inserter.Insert(ctx, sql.Row{fragType, name, definition, indexToUse})
  1025  }
  1026  
  1027  func (db Database) dropFragFromSchemasTable(ctx *sql.Context, fragType, name string, missingErr error) error {
  1028  	stbl, found, err := db.GetTableInsensitive(ctx, doltdb.SchemasTableName)
  1029  	if err != nil {
  1030  		return err
  1031  	}
  1032  	if !found {
  1033  		return missingErr
  1034  	}
  1035  
  1036  	tbl := stbl.(*WritableDoltTable)
  1037  	row, exists, err := fragFromSchemasTable(ctx, tbl, fragType, name)
  1038  	if err != nil {
  1039  		return err
  1040  	}
  1041  	if !exists {
  1042  		return missingErr
  1043  	}
  1044  	deleter := tbl.Deleter(ctx)
  1045  	err = deleter.Delete(ctx, row)
  1046  	if err != nil {
  1047  		return err
  1048  	}
  1049  
  1050  	return deleter.Close(ctx)
  1051  }
  1052  
  1053  // TableEditSession returns the TableEditSession for this database from the given context.
  1054  func (db Database) TableEditSession(ctx *sql.Context, isTemporary bool) *editor.TableEditSession {
  1055  	if isTemporary {
  1056  		return DSessFromSess(ctx.Session).tempTableEditSessions[db.name]
  1057  	}
  1058  	return DSessFromSess(ctx.Session).editSessions[db.name]
  1059  }
  1060  
  1061  // GetAllTemporaryTables returns all temporary tables
  1062  func (db Database) GetAllTemporaryTables(ctx *sql.Context) ([]sql.Table, error) {
  1063  	dsess := DSessFromSess(ctx.Session)
  1064  
  1065  	tables := make([]sql.Table, 0)
  1066  
  1067  	for _, root := range dsess.tempTableRoots {
  1068  		tNames, err := root.GetTableNames(ctx)
  1069  		if err != nil {
  1070  			return nil, err
  1071  		}
  1072  
  1073  		for _, tName := range tNames {
  1074  			tbl, ok, err := db.GetTableInsensitive(ctx, tName)
  1075  			if err != nil {
  1076  				return nil, err
  1077  			}
  1078  
  1079  			if ok {
  1080  				tables = append(tables, tbl)
  1081  			}
  1082  		}
  1083  	}
  1084  
  1085  	return tables, nil
  1086  }
  1087  
  1088  // RegisterSchemaFragments register SQL schema fragments that are persisted in the given
  1089  // `Database` with the provided `sql.ViewRegistry`. Returns an error if
  1090  // there are I/O issues, but currently silently fails to register some
  1091  // schema fragments if they don't parse, or if registries within the
  1092  // `catalog` return errors.
  1093  func RegisterSchemaFragments(ctx *sql.Context, db Database, root *doltdb.RootValue) error {
  1094  	root, err := db.GetRoot(ctx)
  1095  	if err != nil {
  1096  		return err
  1097  	}
  1098  
  1099  	tbl, found, err := root.GetTable(ctx, doltdb.SchemasTableName)
  1100  	if err != nil {
  1101  		return err
  1102  	}
  1103  	if !found {
  1104  		return nil
  1105  	}
  1106  
  1107  	rowData, err := tbl.GetRowData(ctx)
  1108  
  1109  	if err != nil {
  1110  		return err
  1111  	}
  1112  
  1113  	iter, err := newRowIterator(ctx, tbl, nil, &doltTablePartition{rowData: rowData, end: NoUpperBound})
  1114  	if err != nil {
  1115  		return err
  1116  	}
  1117  	defer iter.Close(ctx)
  1118  
  1119  	var parseErrors []error
  1120  
  1121  	r, err := iter.Next()
  1122  	for err == nil {
  1123  		if r[0] == "view" {
  1124  			name := r[1].(string)
  1125  			definition := r[2].(string)
  1126  			cv, err := parse.Parse(ctx, fmt.Sprintf("create view %s as %s", sqlfmt.QuoteIdentifier(name), definition))
  1127  			if err != nil {
  1128  				parseErrors = append(parseErrors, err)
  1129  			} else {
  1130  				err = ctx.Register(db.Name(), cv.(*plan.CreateView).Definition.AsView())
  1131  				if err != nil {
  1132  					return err
  1133  				}
  1134  			}
  1135  		}
  1136  		r, err = iter.Next()
  1137  	}
  1138  	if err != io.EOF {
  1139  		return err
  1140  	}
  1141  
  1142  	if len(parseErrors) > 0 {
  1143  		// TODO: Warning for uncreated views...
  1144  	}
  1145  
  1146  	return nil
  1147  }