github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/libraries/doltcore/sqle/procedures_table.go (about)

     1  // Copyright 2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package sqle
    16  
    17  import (
    18  	"fmt"
    19  	"io"
    20  	"strings"
    21  	"time"
    22  
    23  	"github.com/dolthub/go-mysql-server/sql"
    24  	"gopkg.in/src-d/go-errors.v1"
    25  
    26  	"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
    27  	"github.com/dolthub/dolt/go/libraries/doltcore/schema"
    28  	"github.com/dolthub/dolt/go/libraries/doltcore/sqle/index"
    29  	"github.com/dolthub/dolt/go/libraries/doltcore/sqle/sqlutil"
    30  	"github.com/dolthub/dolt/go/store/types"
    31  )
    32  
    33  const (
    34  	// ProceduresTableName is the name of the dolt stored procedures table.
    35  	ProceduresTableName = "dolt_procedures"
    36  	// ProceduresTableNameCol is the name of the stored procedure. Using CREATE PROCEDURE, will always be lowercase.
    37  	ProceduresTableNameCol = "name"
    38  	// ProceduresTableCreateStmtCol is the CREATE PROCEDURE statement for this stored procedure.
    39  	ProceduresTableCreateStmtCol = "create_stmt"
    40  	// ProceduresTableCreatedAtCol is the time that the stored procedure was created at, in UTC.
    41  	ProceduresTableCreatedAtCol = "created_at"
    42  	// ProceduresTableModifiedAtCol is the time that the stored procedure was last modified, in UTC.
    43  	ProceduresTableModifiedAtCol = "modified_at"
    44  )
    45  
    46  // The fixed SQL schema for the `dolt_procedures` table.
    47  func ProceduresTableSqlSchema() sql.PrimaryKeySchema {
    48  	sqlSchema, err := sqlutil.FromDoltSchema("", doltdb.ProceduresTableName, ProceduresTableSchema())
    49  	if err != nil {
    50  		panic(err) // should never happen
    51  	}
    52  	return sqlSchema
    53  }
    54  
    55  // The fixed dolt schema for the `dolt_procedures` table.
    56  func ProceduresTableSchema() schema.Schema {
    57  	colColl := schema.NewColCollection(
    58  		schema.NewColumn(doltdb.ProceduresTableNameCol, schema.DoltProceduresNameTag, types.StringKind, true, schema.NotNullConstraint{}),
    59  		schema.NewColumn(doltdb.ProceduresTableCreateStmtCol, schema.DoltProceduresCreateStmtTag, types.StringKind, false),
    60  		schema.NewColumn(doltdb.ProceduresTableCreatedAtCol, schema.DoltProceduresCreatedAtTag, types.TimestampKind, false),
    61  		schema.NewColumn(doltdb.ProceduresTableModifiedAtCol, schema.DoltProceduresModifiedAtTag, types.TimestampKind, false),
    62  		schema.NewColumn(doltdb.ProceduresTableSqlModeCol, schema.DoltProceduresSqlModeTag, types.StringKind, false),
    63  	)
    64  	return schema.MustSchemaFromCols(colColl)
    65  }
    66  
    67  // DoltProceduresGetOrCreateTable returns the `dolt_procedures` table from the given db, creating it in the db's
    68  // current root if it doesn't exist
    69  func DoltProceduresGetOrCreateTable(ctx *sql.Context, db Database) (*WritableDoltTable, error) {
    70  	tbl, found, err := db.GetTableInsensitive(ctx, doltdb.ProceduresTableName)
    71  	if err != nil {
    72  		return nil, err
    73  	}
    74  	if found {
    75  		// Make sure the schema is up to date
    76  		writableDoltTable := tbl.(*WritableDoltTable)
    77  		return migrateDoltProceduresSchema(ctx, db, writableDoltTable)
    78  	}
    79  
    80  	root, err := db.GetRoot(ctx)
    81  	if err != nil {
    82  		return nil, err
    83  	}
    84  
    85  	err = db.createDoltTable(ctx, doltdb.ProceduresTableName, doltdb.DefaultSchemaName, root, ProceduresTableSchema())
    86  	if err != nil {
    87  		return nil, err
    88  	}
    89  
    90  	tbl, found, err = db.GetTableInsensitive(ctx, doltdb.ProceduresTableName)
    91  	if err != nil {
    92  		return nil, err
    93  	}
    94  	// Verify it was created successfully
    95  	if !found {
    96  		return nil, sql.ErrTableNotFound.New(ProceduresTableName)
    97  	}
    98  	return tbl.(*WritableDoltTable), nil
    99  }
   100  
   101  // migrateDoltProceduresSchema migrates the dolt_procedures system table from a previous schema version to the current
   102  // schema version by adding any columns that do not exist.
   103  func migrateDoltProceduresSchema(ctx *sql.Context, db Database, oldTable *WritableDoltTable) (newTable *WritableDoltTable, rerr error) {
   104  	// Check whether the table needs to be migrated
   105  	targetSchema := ProceduresTableSqlSchema().Schema
   106  	if len(oldTable.Schema()) == len(targetSchema) {
   107  		return oldTable, nil
   108  	}
   109  
   110  	// Copy all the old data
   111  	iter, err := SqlTableToRowIter(ctx, oldTable.DoltTable, nil)
   112  	if err != nil {
   113  		return nil, err
   114  	}
   115  
   116  	nameIdx := oldTable.sqlSchema().IndexOfColName(doltdb.ProceduresTableNameCol)
   117  	createStatementIdx := oldTable.sqlSchema().IndexOfColName(doltdb.ProceduresTableCreateStmtCol)
   118  	createdAtIdx := oldTable.sqlSchema().IndexOfColName(doltdb.ProceduresTableCreatedAtCol)
   119  	modifiedAtIdx := oldTable.sqlSchema().IndexOfColName(doltdb.ProceduresTableModifiedAtCol)
   120  	sqlModeIdx := oldTable.sqlSchema().IndexOfColName(doltdb.ProceduresTableSqlModeCol)
   121  
   122  	defer func(iter sql.RowIter, ctx *sql.Context) {
   123  		err := iter.Close(ctx)
   124  		if err != nil && rerr == nil {
   125  			rerr = err
   126  		}
   127  	}(iter, ctx)
   128  
   129  	var newRows []sql.Row
   130  	for {
   131  		sqlRow, err := iter.Next(ctx)
   132  		if err == io.EOF {
   133  			break
   134  		}
   135  		if err != nil {
   136  			return nil, err
   137  		}
   138  
   139  		newRow := make(sql.Row, ProceduresTableSchema().GetAllCols().Size())
   140  		newRow[0] = sqlRow[nameIdx]
   141  		newRow[1] = sqlRow[createStatementIdx]
   142  		newRow[2] = sqlRow[createdAtIdx]
   143  		newRow[3] = sqlRow[modifiedAtIdx]
   144  		if sqlModeIdx >= 0 {
   145  			newRow[4] = sqlRow[sqlModeIdx]
   146  		}
   147  		newRows = append(newRows, newRow)
   148  	}
   149  
   150  	err = db.dropTable(ctx, doltdb.ProceduresTableName)
   151  	if err != nil {
   152  		return nil, err
   153  	}
   154  
   155  	root, err := db.GetRoot(ctx)
   156  	if err != nil {
   157  		return nil, err
   158  	}
   159  
   160  	err = db.createDoltTable(ctx, doltdb.ProceduresTableName, doltdb.DefaultSchemaName, root, ProceduresTableSchema())
   161  	if err != nil {
   162  		return nil, err
   163  	}
   164  
   165  	tbl, found, err := db.GetTableInsensitive(ctx, doltdb.ProceduresTableName)
   166  	if err != nil {
   167  		return nil, err
   168  	}
   169  	if !found {
   170  		return nil, sql.ErrTableNotFound.New(doltdb.ProceduresTableName)
   171  	}
   172  
   173  	inserter := tbl.(*WritableDoltTable).Inserter(ctx)
   174  	for _, row := range newRows {
   175  		err = inserter.Insert(ctx, row)
   176  		if err != nil {
   177  			return nil, err
   178  		}
   179  	}
   180  
   181  	err = inserter.Close(ctx)
   182  	if err != nil {
   183  		return nil, err
   184  	}
   185  
   186  	return tbl.(*WritableDoltTable), nil
   187  }
   188  
   189  // DoltProceduresGetTable returns the `dolt_procedures` table from the given db, or nil if the table doesn't exist
   190  func DoltProceduresGetTable(ctx *sql.Context, db Database) (*WritableDoltTable, error) {
   191  	tbl, found, err := db.GetTableInsensitive(ctx, doltdb.ProceduresTableName)
   192  	if err != nil {
   193  		return nil, err
   194  	}
   195  	if found {
   196  		// Make sure the schema is up to date
   197  		writableDoltTable := tbl.(*WritableDoltTable)
   198  		return migrateDoltProceduresSchema(ctx, db, writableDoltTable)
   199  	} else {
   200  		return nil, nil
   201  	}
   202  }
   203  
   204  // DoltProceduresGetAll returns all stored procedures for the database if the procedureName is blank (and empty string),
   205  // or it returns only the procedure with the matching name if one is given. The name is not case-sensitive.
   206  func DoltProceduresGetAll(ctx *sql.Context, db Database, procedureName string) ([]sql.StoredProcedureDetails, error) {
   207  	tbl, err := DoltProceduresGetTable(ctx, db)
   208  	if err != nil {
   209  		return nil, err
   210  	} else if tbl == nil {
   211  		return nil, nil
   212  	}
   213  
   214  	indexes, err := tbl.GetIndexes(ctx)
   215  	if err != nil {
   216  		return nil, err
   217  	}
   218  	if len(indexes) == 0 {
   219  		return nil, fmt.Errorf("missing index for stored procedures")
   220  	}
   221  	idx := indexes[0]
   222  
   223  	if len(idx.Expressions()) == 0 {
   224  		return nil, fmt.Errorf("missing index expression for stored procedures")
   225  	}
   226  	nameExpr := idx.Expressions()[0]
   227  
   228  	var lookup sql.IndexLookup
   229  	if procedureName == "" {
   230  		lookup, err = sql.NewIndexBuilder(idx).IsNotNull(ctx, nameExpr).Build(ctx)
   231  	} else {
   232  		lookup, err = sql.NewIndexBuilder(idx).Equals(ctx, nameExpr, procedureName).Build(ctx)
   233  	}
   234  	if err != nil {
   235  		return nil, err
   236  	}
   237  
   238  	iter, err := index.RowIterForIndexLookup(ctx, tbl.DoltTable, lookup, tbl.sqlSch, nil)
   239  	if err != nil {
   240  		return nil, err
   241  	}
   242  	defer func() {
   243  		if cerr := iter.Close(ctx); cerr != nil {
   244  			err = cerr
   245  		}
   246  	}()
   247  
   248  	var sqlRow sql.Row
   249  	var details []sql.StoredProcedureDetails
   250  	missingValue := errors.NewKind("missing `%s` value for procedure row: (%s)")
   251  
   252  	for {
   253  		sqlRow, err = iter.Next(ctx)
   254  		if err == io.EOF {
   255  			break
   256  		}
   257  		if err != nil {
   258  			return nil, err
   259  		}
   260  
   261  		var d sql.StoredProcedureDetails
   262  		var ok bool
   263  
   264  		if d.Name, ok = sqlRow[0].(string); !ok {
   265  			return nil, missingValue.New(doltdb.ProceduresTableNameCol, sqlRow)
   266  		}
   267  		if d.CreateStatement, ok = sqlRow[1].(string); !ok {
   268  			return nil, missingValue.New(doltdb.ProceduresTableCreateStmtCol, sqlRow)
   269  		}
   270  		if d.CreatedAt, ok = sqlRow[2].(time.Time); !ok {
   271  			return nil, missingValue.New(doltdb.ProceduresTableCreatedAtCol, sqlRow)
   272  		}
   273  		if d.ModifiedAt, ok = sqlRow[3].(time.Time); !ok {
   274  			return nil, missingValue.New(doltdb.ProceduresTableModifiedAtCol, sqlRow)
   275  		}
   276  		if s, ok := sqlRow[4].(string); ok {
   277  			d.SqlMode = s
   278  		} else {
   279  			defaultSqlMode, err := loadDefaultSqlMode()
   280  			if err != nil {
   281  				return nil, err
   282  			}
   283  			d.SqlMode = defaultSqlMode
   284  		}
   285  		details = append(details, d)
   286  	}
   287  	return details, nil
   288  }
   289  
   290  // DoltProceduresAddProcedure adds the stored procedure to the `dolt_procedures` table in the given db, creating it if
   291  // it does not exist.
   292  func DoltProceduresAddProcedure(ctx *sql.Context, db Database, spd sql.StoredProcedureDetails) (retErr error) {
   293  	tbl, err := DoltProceduresGetOrCreateTable(ctx, db)
   294  	if err != nil {
   295  		return err
   296  	}
   297  	_, ok, err := DoltProceduresGetDetails(ctx, tbl, spd.Name)
   298  	if err != nil {
   299  		return err
   300  	}
   301  	if ok {
   302  		return sql.ErrStoredProcedureAlreadyExists.New(spd.Name)
   303  	}
   304  	inserter := tbl.Inserter(ctx)
   305  	defer func() {
   306  		err := inserter.Close(ctx)
   307  		if retErr == nil {
   308  			retErr = err
   309  		}
   310  	}()
   311  	return inserter.Insert(ctx, sql.Row{
   312  		strings.ToLower(spd.Name),
   313  		spd.CreateStatement,
   314  		spd.CreatedAt.UTC(),
   315  		spd.ModifiedAt.UTC(),
   316  		spd.SqlMode,
   317  	})
   318  }
   319  
   320  // DoltProceduresDropProcedure removes the stored procedure from the `dolt_procedures` table. The procedure named must
   321  // exist.
   322  func DoltProceduresDropProcedure(ctx *sql.Context, db Database, name string) (retErr error) {
   323  	name = strings.ToLower(name)
   324  	tbl, err := DoltProceduresGetTable(ctx, db)
   325  	if err != nil {
   326  		return err
   327  	} else if tbl == nil {
   328  		return sql.ErrStoredProcedureDoesNotExist.New(name)
   329  	}
   330  
   331  	_, ok, err := DoltProceduresGetDetails(ctx, tbl, name)
   332  	if err != nil {
   333  		return err
   334  	}
   335  	if !ok {
   336  		return sql.ErrStoredProcedureDoesNotExist.New(name)
   337  	}
   338  	deleter := tbl.Deleter(ctx)
   339  	defer func() {
   340  		err := deleter.Close(ctx)
   341  		if retErr == nil {
   342  			retErr = err
   343  		}
   344  	}()
   345  	return deleter.Delete(ctx, sql.Row{name})
   346  }
   347  
   348  // DoltProceduresGetDetails returns the stored procedure with the given name from `dolt_procedures` if it exists.
   349  func DoltProceduresGetDetails(ctx *sql.Context, tbl *WritableDoltTable, name string) (sql.StoredProcedureDetails, bool, error) {
   350  	name = strings.ToLower(name)
   351  	indexes, err := tbl.GetIndexes(ctx)
   352  	if err != nil {
   353  		return sql.StoredProcedureDetails{}, false, err
   354  	}
   355  	var fragNameIndex sql.Index
   356  	for _, idx := range indexes {
   357  		if idx.ID() == "PRIMARY" {
   358  			fragNameIndex = idx
   359  			break
   360  		}
   361  	}
   362  	if fragNameIndex == nil {
   363  		return sql.StoredProcedureDetails{}, false, fmt.Errorf("could not find primary key index on system table `%s`", doltdb.ProceduresTableName)
   364  	}
   365  
   366  	indexLookup, err := sql.NewIndexBuilder(fragNameIndex).Equals(ctx, fragNameIndex.Expressions()[0], name).Build(ctx)
   367  	if err != nil {
   368  		return sql.StoredProcedureDetails{}, false, err
   369  	}
   370  
   371  	rowIter, err := index.RowIterForIndexLookup(ctx, tbl.DoltTable, indexLookup, tbl.sqlSch, nil)
   372  	if err != nil {
   373  		return sql.StoredProcedureDetails{}, false, err
   374  	}
   375  	defer func() {
   376  		if cerr := rowIter.Close(ctx); cerr != nil {
   377  			err = cerr
   378  		}
   379  	}()
   380  
   381  	sqlRow, err := rowIter.Next(ctx)
   382  	if err == nil {
   383  		if len(sqlRow) != 5 {
   384  			return sql.StoredProcedureDetails{}, false, fmt.Errorf("unexpected row in dolt_procedures:\n%v", sqlRow)
   385  		}
   386  		return sql.StoredProcedureDetails{
   387  			Name:            sqlRow[0].(string),
   388  			CreateStatement: sqlRow[1].(string),
   389  			CreatedAt:       sqlRow[2].(time.Time),
   390  			ModifiedAt:      sqlRow[3].(time.Time),
   391  		}, true, nil
   392  	} else if err == io.EOF {
   393  		return sql.StoredProcedureDetails{}, false, nil
   394  	} else {
   395  		return sql.StoredProcedureDetails{}, false, err
   396  	}
   397  }