github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/libraries/doltcore/sqle/schema_table.go (about)

     1  // Copyright 2020 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package sqle
    16  
    17  import (
    18  	"fmt"
    19  	"io"
    20  	"strings"
    21  	"time"
    22  
    23  	"github.com/dolthub/go-mysql-server/sql"
    24  	gmstypes "github.com/dolthub/go-mysql-server/sql/types"
    25  	"github.com/dolthub/vitess/go/vt/proto/query"
    26  
    27  	"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
    28  	"github.com/dolthub/dolt/go/libraries/doltcore/schema"
    29  	"github.com/dolthub/dolt/go/libraries/doltcore/schema/typeinfo"
    30  )
    31  
    32  const (
    33  	viewFragment    = "view"
    34  	triggerFragment = "trigger"
    35  	eventFragment   = "event"
    36  )
    37  
    38  type Extra struct {
    39  	CreatedAt int64
    40  }
    41  
    42  func mustNewColWithTypeInfo(name string, tag uint64, typeInfo typeinfo.TypeInfo, partOfPK bool, defaultVal string, autoIncrement bool, comment string, constraints ...schema.ColConstraint) schema.Column {
    43  	col, err := schema.NewColumnWithTypeInfo(name, tag, typeInfo, partOfPK, defaultVal, autoIncrement, comment, constraints...)
    44  	if err != nil {
    45  		panic(err)
    46  	}
    47  	return col
    48  }
    49  
    50  func mustCreateStringType(baseType query.Type, length int64, collation sql.CollationID) sql.StringType {
    51  	ti, err := gmstypes.CreateString(baseType, length, collation)
    52  	if err != nil {
    53  		panic(err)
    54  	}
    55  	return ti
    56  }
    57  
    58  // dolt_schemas columns
    59  var schemasTableCols = schema.NewColCollection(
    60  	mustNewColWithTypeInfo(doltdb.SchemasTablesTypeCol, schema.DoltSchemasTypeTag, typeinfo.CreateVarStringTypeFromSqlType(mustCreateStringType(query.Type_VARCHAR, 64, sql.Collation_utf8mb4_0900_ai_ci)), true, "", false, ""),
    61  	mustNewColWithTypeInfo(doltdb.SchemasTablesNameCol, schema.DoltSchemasNameTag, typeinfo.CreateVarStringTypeFromSqlType(mustCreateStringType(query.Type_VARCHAR, 64, sql.Collation_utf8mb4_0900_ai_ci)), true, "", false, ""),
    62  	mustNewColWithTypeInfo(doltdb.SchemasTablesFragmentCol, schema.DoltSchemasFragmentTag, typeinfo.CreateVarStringTypeFromSqlType(gmstypes.LongText), false, "", false, ""),
    63  	mustNewColWithTypeInfo(doltdb.SchemasTablesExtraCol, schema.DoltSchemasExtraTag, typeinfo.JSONType, false, "", false, ""),
    64  	mustNewColWithTypeInfo(doltdb.SchemasTablesSqlModeCol, schema.DoltSchemasSqlModeTag, typeinfo.CreateVarStringTypeFromSqlType(mustCreateStringType(query.Type_VARCHAR, 256, sql.Collation_utf8mb4_0900_ai_ci)), false, "", false, ""),
    65  )
    66  
    67  var schemaTableSchema = schema.MustSchemaFromCols(schemasTableCols)
    68  
    69  // getOrCreateDoltSchemasTable returns the `dolt_schemas` table in `db`, creating it if it does not already exist.
    70  // Also migrates data to the correct format if necessary.
    71  func getOrCreateDoltSchemasTable(ctx *sql.Context, db Database) (retTbl *WritableDoltTable, retErr error) {
    72  	tbl, found, err := db.GetTableInsensitive(ctx, doltdb.SchemasTableName)
    73  	if err != nil {
    74  		return nil, err
    75  	}
    76  
    77  	if found {
    78  		schemasTable := tbl.(*WritableDoltTable)
    79  		// Old schemas table contains the `id` column or is missing an `extra` column.
    80  		if tbl.Schema().Contains(doltdb.SchemasTablesIdCol, doltdb.SchemasTableName) || !tbl.Schema().Contains(doltdb.SchemasTablesExtraCol, doltdb.SchemasTableName) {
    81  			return migrateOldSchemasTableToNew(ctx, db, schemasTable)
    82  		} else {
    83  			return schemasTable, nil
    84  		}
    85  	}
    86  
    87  	root, err := db.GetRoot(ctx)
    88  	if err != nil {
    89  		return nil, err
    90  	}
    91  
    92  	// Create new empty table
    93  	err = db.createDoltTable(ctx, doltdb.SchemasTableName, doltdb.DefaultSchemaName, root, schemaTableSchema)
    94  	if err != nil {
    95  		return nil, err
    96  	}
    97  	tbl, found, err = db.GetTableInsensitive(ctx, doltdb.SchemasTableName)
    98  	if err != nil {
    99  		return nil, err
   100  	}
   101  	if !found {
   102  		return nil, sql.ErrTableNotFound.New(doltdb.SchemasTableName)
   103  	}
   104  
   105  	return tbl.(*WritableDoltTable), nil
   106  }
   107  
   108  func migrateOldSchemasTableToNew(ctx *sql.Context, db Database, schemasTable *WritableDoltTable) (newTable *WritableDoltTable, rerr error) {
   109  	// Copy all of the old data over and add an index column and an extra column
   110  	iter, err := SqlTableToRowIter(ctx, schemasTable.DoltTable, nil)
   111  	if err != nil {
   112  		return nil, err
   113  	}
   114  
   115  	// The dolt_schemas table has undergone various changes over time and multiple possible schemas for it exist, so we
   116  	// need to get the column indexes from the current schema
   117  	nameIdx := schemasTable.sqlSchema().IndexOfColName(doltdb.SchemasTablesNameCol)
   118  	typeIdx := schemasTable.sqlSchema().IndexOfColName(doltdb.SchemasTablesTypeCol)
   119  	fragmentIdx := schemasTable.sqlSchema().IndexOfColName(doltdb.SchemasTablesFragmentCol)
   120  	extraIdx := schemasTable.sqlSchema().IndexOfColName(doltdb.SchemasTablesExtraCol)
   121  	sqlModeIdx := schemasTable.sqlSchema().IndexOfColName(doltdb.SchemasTablesSqlModeCol)
   122  
   123  	defer func(iter sql.RowIter, ctx *sql.Context) {
   124  		err := iter.Close(ctx)
   125  		if err != nil && rerr == nil {
   126  			rerr = err
   127  		}
   128  	}(iter, ctx)
   129  
   130  	var newRows []sql.Row
   131  	for {
   132  		sqlRow, err := iter.Next(ctx)
   133  		if err == io.EOF {
   134  			break
   135  		}
   136  		if err != nil {
   137  			return nil, err
   138  		}
   139  
   140  		newRow := make(sql.Row, schemasTableCols.Size())
   141  		newRow[0] = sqlRow[typeIdx]
   142  		newRow[1] = sqlRow[nameIdx]
   143  		newRow[2] = sqlRow[fragmentIdx]
   144  		if extraIdx >= 0 {
   145  			newRow[3] = sqlRow[extraIdx]
   146  		}
   147  		if sqlModeIdx >= 0 {
   148  			newRow[4] = sqlRow[sqlModeIdx]
   149  		}
   150  
   151  		newRows = append(newRows, newRow)
   152  	}
   153  
   154  	err = db.dropTable(ctx, doltdb.SchemasTableName)
   155  	if err != nil {
   156  		return nil, err
   157  	}
   158  
   159  	root, err := db.GetRoot(ctx)
   160  	if err != nil {
   161  		return nil, err
   162  	}
   163  
   164  	err = db.createDoltTable(ctx, doltdb.SchemasTableName, doltdb.DefaultSchemaName, root, schemaTableSchema)
   165  	if err != nil {
   166  		return nil, err
   167  	}
   168  
   169  	tbl, found, err := db.GetTableInsensitive(ctx, doltdb.SchemasTableName)
   170  	if err != nil {
   171  		return nil, err
   172  	}
   173  	if !found {
   174  		return nil, sql.ErrTableNotFound.New(doltdb.SchemasTableName)
   175  	}
   176  
   177  	inserter := tbl.(*WritableDoltTable).Inserter(ctx)
   178  	for _, row := range newRows {
   179  		err = inserter.Insert(ctx, row)
   180  		if err != nil {
   181  			return nil, err
   182  		}
   183  	}
   184  
   185  	err = inserter.Close(ctx)
   186  	if err != nil {
   187  		return nil, err
   188  	}
   189  
   190  	return tbl.(*WritableDoltTable), nil
   191  }
   192  
   193  // fragFromSchemasTable returns the row with the given schema fragment if it exists.
   194  func fragFromSchemasTable(ctx *sql.Context, tbl *WritableDoltTable, fragType string, name string) (r sql.Row, found bool, rerr error) {
   195  	fragType, name = strings.ToLower(fragType), strings.ToLower(name)
   196  
   197  	// This performs a full table scan in the worst case, but it's only used when adding or dropping a trigger or view
   198  	iter, err := SqlTableToRowIter(ctx, tbl.DoltTable, nil)
   199  	if err != nil {
   200  		return nil, false, err
   201  	}
   202  
   203  	defer func(iter sql.RowIter, ctx *sql.Context) {
   204  		err := iter.Close(ctx)
   205  		if err != nil && rerr == nil {
   206  			rerr = err
   207  		}
   208  	}(iter, ctx)
   209  
   210  	// The dolt_schemas table has undergone various changes over time and multiple possible schemas for it exist, so we
   211  	// need to get the column indexes from the current schema
   212  	nameIdx := tbl.sqlSchema().IndexOfColName(doltdb.SchemasTablesNameCol)
   213  	typeIdx := tbl.sqlSchema().IndexOfColName(doltdb.SchemasTablesTypeCol)
   214  
   215  	for {
   216  		sqlRow, err := iter.Next(ctx)
   217  		if err == io.EOF {
   218  			break
   219  		}
   220  		if err != nil {
   221  			return nil, false, err
   222  		}
   223  
   224  		// These columns are case insensitive, make sure to do a case-insensitive comparison
   225  		if strings.ToLower(sqlRow[typeIdx].(string)) == fragType && strings.ToLower(sqlRow[nameIdx].(string)) == name {
   226  			return sqlRow, true, nil
   227  		}
   228  	}
   229  
   230  	return nil, false, nil
   231  }
   232  
   233  type schemaFragment struct {
   234  	name     string
   235  	fragment string
   236  	created  time.Time
   237  	// sqlMode indicates the SQL_MODE that was used when this schema fragment was initially parsed. SQL_MODE settings
   238  	// such as ANSI_QUOTES control customized parsing behavior needed for some schema fragments.
   239  	sqlMode string
   240  }
   241  
   242  func getSchemaFragmentsOfType(ctx *sql.Context, tbl *WritableDoltTable, fragType string) (sf []schemaFragment, rerr error) {
   243  	iter, err := SqlTableToRowIter(ctx, tbl.DoltTable, nil)
   244  	if err != nil {
   245  		return nil, err
   246  	}
   247  
   248  	// The dolt_schemas table has undergone various changes over time and multiple possible schemas for it exist, so we
   249  	// need to get the column indexes from the current schema
   250  	nameIdx := tbl.sqlSchema().IndexOfColName(doltdb.SchemasTablesNameCol)
   251  	typeIdx := tbl.sqlSchema().IndexOfColName(doltdb.SchemasTablesTypeCol)
   252  	fragmentIdx := tbl.sqlSchema().IndexOfColName(doltdb.SchemasTablesFragmentCol)
   253  	extraIdx := tbl.sqlSchema().IndexOfColName(doltdb.SchemasTablesExtraCol)
   254  	sqlModeIdx := tbl.sqlSchema().IndexOfColName(doltdb.SchemasTablesSqlModeCol)
   255  
   256  	defer func(iter sql.RowIter, ctx *sql.Context) {
   257  		err := iter.Close(ctx)
   258  		if err != nil && rerr == nil {
   259  			rerr = err
   260  		}
   261  	}(iter, ctx)
   262  
   263  	var frags []schemaFragment
   264  	for {
   265  		sqlRow, err := iter.Next(ctx)
   266  		if err == io.EOF {
   267  			break
   268  		}
   269  		if err != nil {
   270  			return nil, err
   271  		}
   272  
   273  		if sqlRow[typeIdx] != fragType {
   274  			continue
   275  		}
   276  
   277  		sqlModeString := ""
   278  		if sqlModeIdx >= 0 {
   279  			if s, ok := sqlRow[sqlModeIdx].(string); ok {
   280  				sqlModeString = s
   281  			}
   282  		} else {
   283  			defaultSqlMode, err := loadDefaultSqlMode()
   284  			if err != nil {
   285  				return nil, err
   286  			}
   287  			sqlModeString = defaultSqlMode
   288  		}
   289  
   290  		// For older tables, use 1 as the trigger creation time
   291  		if extraIdx < 0 || sqlRow[extraIdx] == nil {
   292  			frags = append(frags, schemaFragment{
   293  				name:     sqlRow[nameIdx].(string),
   294  				fragment: sqlRow[fragmentIdx].(string),
   295  				created:  time.Unix(1, 0).UTC(), // TablePlus editor thinks 0 is out of range
   296  				sqlMode:  sqlModeString,
   297  			})
   298  			continue
   299  		}
   300  
   301  		// Extract Created Time from JSON column
   302  		createdTime, err := getCreatedTime(ctx, sqlRow[extraIdx].(sql.JSONWrapper))
   303  		if err != nil {
   304  			return nil, err
   305  		}
   306  
   307  		frags = append(frags, schemaFragment{
   308  			name:     sqlRow[nameIdx].(string),
   309  			fragment: sqlRow[fragmentIdx].(string),
   310  			created:  time.Unix(createdTime, 0).UTC(),
   311  			sqlMode:  sqlModeString,
   312  		})
   313  	}
   314  
   315  	return frags, nil
   316  }
   317  
   318  // loadDefaultSqlMode loads the default value for the @@SQL_MODE system variable and returns it, along
   319  // with any unexpected errors encountered while reading the default value.
   320  func loadDefaultSqlMode() (string, error) {
   321  	global, _, ok := sql.SystemVariables.GetGlobal("SQL_MODE")
   322  	if !ok {
   323  		return "", fmt.Errorf("unable to load default @@SQL_MODE")
   324  	}
   325  	s, ok := global.GetDefault().(string)
   326  	if !ok {
   327  		return "", fmt.Errorf("unexpected type for @@SQL_MODE default value: %T", global.GetDefault())
   328  	}
   329  	return s, nil
   330  }
   331  
   332  func getCreatedTime(ctx *sql.Context, extraCol sql.JSONWrapper) (int64, error) {
   333  	doc, err := extraCol.ToInterface()
   334  	if err != nil {
   335  		return 0, err
   336  	}
   337  
   338  	err = fmt.Errorf("value %v does not contain creation time", doc)
   339  
   340  	obj, ok := doc.(map[string]interface{})
   341  	if !ok {
   342  		return 0, err
   343  	}
   344  
   345  	v, ok := obj["CreatedAt"]
   346  	if !ok {
   347  		return 0, err
   348  	}
   349  
   350  	f, ok := v.(float64)
   351  	if !ok {
   352  		return 0, err
   353  	}
   354  	return int64(f), nil
   355  }