github.com/dolthub/go-mysql-server@v0.18.0/sql/analyzer/resolve_column_defaults.go (about)

     1  // Copyright 2020-2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package analyzer
    16  
    17  import (
    18  	"github.com/dolthub/vitess/go/sqltypes"
    19  
    20  	"github.com/dolthub/go-mysql-server/sql"
    21  	"github.com/dolthub/go-mysql-server/sql/expression"
    22  	"github.com/dolthub/go-mysql-server/sql/information_schema"
    23  	"github.com/dolthub/go-mysql-server/sql/plan"
    24  	"github.com/dolthub/go-mysql-server/sql/transform"
    25  )
    26  
    27  // Resolving column defaults is a multi-phase process, with different analyzer rules for each phase.
    28  //
    29  // * parseColumnDefaults: Some integrators (dolt but not GMS) store their column defaults as strings, which we need to
    30  // 	parse into expressions before we can analyze them any further.
    31  // * resolveColumnDefaults: Once we have an expression for a default value, it may contain expressions that need
    32  // 	simplification before further phases of processing can take place.
    33  //
    34  // After this stage, expressions in column default values are handled by the normal analyzer machinery responsible for
    35  // resolving expressions, including things like columns and functions. Every node that needs to do this for its default
    36  // values implements `sql.Expressioner` to expose such expressions. There is custom logic in `resolveColumns` to help
    37  // identify the correct indexes for column references, which can vary based on the node type.
    38  //
    39  // Finally there are cleanup phases:
    40  // * validateColumnDefaults: ensures that newly created column defaults from a DDL statement are legal for the type of
    41  // 	column, various other business logic checks to match MySQL's logic.
    42  // * stripTableNamesFromColumnDefaults: column defaults headed for storage or serialization in a query result need the table
    43  // 	names in any GetField expressions stripped out so that they serialize to strings without such table names. Table
    44  // 	names in GetField expressions are expected in much of the rest of the analyzer, so we do this after the bulk of
    45  // 	analyzer work.
    46  //
    47  // The `information_schema.columns` table also needs access to the default values of every column in the database, and
    48  // because it's a table it can't implement `sql.Expressioner` like other node types. Instead it has special handling
    49  // here, as well as in the `resolve_functions` rule.
    50  
    51  func validateColumnDefaults(ctx *sql.Context, _ *Analyzer, n sql.Node, _ *plan.Scope, _ RuleSelector) (sql.Node, transform.TreeIdentity, error) {
    52  	span, ctx := ctx.Span("validateColumnDefaults")
    53  	defer span.End()
    54  
    55  	return transform.Node(n, func(n sql.Node) (sql.Node, transform.TreeIdentity, error) {
    56  		switch node := n.(type) {
    57  		case *plan.AlterDefaultSet:
    58  			table := getResolvedTable(node)
    59  			sch := table.Schema()
    60  			index := sch.IndexOfColName(node.ColumnName)
    61  			if index == -1 {
    62  				return nil, transform.SameTree, sql.ErrColumnNotFound.New(node.ColumnName)
    63  			}
    64  			col := sch[index]
    65  			eWrapper := expression.WrapExpression(node.Default)
    66  			err := validateColumnDefault(ctx, col, eWrapper)
    67  			if err != nil {
    68  				return node, transform.SameTree, err
    69  			}
    70  			return node, transform.SameTree, nil
    71  		case sql.SchemaTarget:
    72  			switch node.(type) {
    73  			case *plan.AlterPK, *plan.AddColumn, *plan.ModifyColumn, *plan.AlterDefaultDrop, *plan.CreateTable, *plan.DropColumn:
    74  				// DDL nodes must validate any new column defaults, continue to logic below
    75  			default:
    76  				// other node types are not altering the schema and therefore don't need validation of column defaults
    77  				return n, transform.SameTree, nil
    78  			}
    79  
    80  			// There may be multiple DDL nodes in the plan (ALTER TABLE statements can have many clauses), and for each of them
    81  			// we need to count the column indexes in the very hacky way outlined above.
    82  			i := 0
    83  			return transform.NodeExprs(n, func(e sql.Expression) (sql.Expression, transform.TreeIdentity, error) {
    84  				eWrapper, ok := e.(*expression.Wrapper)
    85  				if !ok {
    86  					return e, transform.SameTree, nil
    87  				}
    88  
    89  				defer func() {
    90  					i++
    91  				}()
    92  
    93  				if eWrapper.Unwrap() == nil {
    94  					return e, transform.SameTree, nil
    95  				}
    96  
    97  				col, err := lookupColumnForTargetSchema(ctx, node, i)
    98  				if err != nil {
    99  					return nil, transform.SameTree, err
   100  				}
   101  
   102  				err = validateColumnDefault(ctx, col, eWrapper)
   103  				if err != nil {
   104  					return nil, transform.SameTree, err
   105  				}
   106  
   107  				return e, transform.SameTree, nil
   108  			})
   109  		default:
   110  			return node, transform.SameTree, nil
   111  		}
   112  	})
   113  }
   114  
   115  // stripTableNamesFromColumnDefaults removes the table name from any GetField expressions in column default expressions.
   116  // Default values can only reference their host table, and since we serialize the GetField expression for storage, it's
   117  // important that we remove the table name before passing it off for storage. Otherwise we end up with serialized
   118  // defaults like `tableName.field + 1` instead of just `field + 1`.
   119  func stripTableNamesFromColumnDefaults(ctx *sql.Context, _ *Analyzer, n sql.Node, _ *plan.Scope, _ RuleSelector) (sql.Node, transform.TreeIdentity, error) {
   120  	span, ctx := ctx.Span("stripTableNamesFromColumnDefaults")
   121  	defer span.End()
   122  
   123  	return transform.Node(n, func(n sql.Node) (sql.Node, transform.TreeIdentity, error) {
   124  		switch node := n.(type) {
   125  		case *plan.AlterDefaultSet:
   126  			eWrapper := expression.WrapExpression(node.Default)
   127  			newExpr, same, err := stripTableNamesFromDefault(eWrapper)
   128  			if err != nil {
   129  				return node, transform.SameTree, err
   130  			}
   131  			if same {
   132  				return node, transform.SameTree, nil
   133  			}
   134  
   135  			newNode, err := node.WithDefault(newExpr)
   136  			if err != nil {
   137  				return node, transform.SameTree, err
   138  			}
   139  			return newNode, transform.NewTree, nil
   140  		case sql.SchemaTarget:
   141  			return transform.NodeExprs(n, func(e sql.Expression) (sql.Expression, transform.TreeIdentity, error) {
   142  				eWrapper, ok := e.(*expression.Wrapper)
   143  				if !ok {
   144  					return e, transform.SameTree, nil
   145  				}
   146  
   147  				return stripTableNamesFromDefault(eWrapper)
   148  			})
   149  		case *plan.ResolvedTable:
   150  			ct, ok := node.Table.(*information_schema.ColumnsTable)
   151  			if !ok {
   152  				return node, transform.SameTree, nil
   153  			}
   154  
   155  			allColumns, err := ct.AllColumns(ctx)
   156  			if err != nil {
   157  				return nil, transform.SameTree, err
   158  			}
   159  
   160  			allDefaults, same, err := transform.Exprs(transform.WrappedColumnDefaults(allColumns), func(e sql.Expression) (sql.Expression, transform.TreeIdentity, error) {
   161  				eWrapper, ok := e.(*expression.Wrapper)
   162  				if !ok {
   163  					return e, transform.SameTree, nil
   164  				}
   165  
   166  				return stripTableNamesFromDefault(eWrapper)
   167  			})
   168  
   169  			if err != nil {
   170  				return nil, transform.SameTree, err
   171  			}
   172  
   173  			if !same {
   174  				node.Table, err = ct.WithColumnDefaults(allDefaults)
   175  				if err != nil {
   176  					return nil, transform.SameTree, err
   177  				}
   178  				return node, transform.NewTree, err
   179  			}
   180  
   181  			return node, transform.SameTree, err
   182  		default:
   183  			return node, transform.SameTree, nil
   184  		}
   185  	})
   186  }
   187  
   188  // lookupColumnForTargetSchema looks at the target schema for the specified SchemaTarget node and returns
   189  // the column based on the specified index. For most node types, this is simply indexing into the target
   190  // schema but a few types require special handling.
   191  func lookupColumnForTargetSchema(_ *sql.Context, node sql.SchemaTarget, colIndex int) (*sql.Column, error) {
   192  	schema := node.TargetSchema()
   193  
   194  	switch n := node.(type) {
   195  	case *plan.ModifyColumn:
   196  		if colIndex < len(schema) {
   197  			return schema[colIndex], nil
   198  		} else {
   199  			return n.NewColumn(), nil
   200  		}
   201  	case *plan.AddColumn:
   202  		if colIndex < len(schema) {
   203  			return schema[colIndex], nil
   204  		} else {
   205  			return n.Column(), nil
   206  		}
   207  	case *plan.AlterDefaultSet:
   208  		index := schema.IndexOfColName(n.ColumnName)
   209  		if index == -1 {
   210  			return nil, sql.ErrTableColumnNotFound.New(n.Table, n.ColumnName)
   211  		}
   212  		return schema[index], nil
   213  	default:
   214  		if colIndex < len(schema) {
   215  			return schema[colIndex], nil
   216  		} else {
   217  			// TODO: sql.ErrColumnNotFound would be a better error here, but we need to add all the different node types to
   218  			//  the switch to get it
   219  			return nil, expression.ErrIndexOutOfBounds.New(colIndex, len(schema))
   220  		}
   221  	}
   222  }
   223  
   224  // validateColumnDefault validates that the column default expression is valid for the column type and returns an error
   225  // if not
   226  func validateColumnDefault(ctx *sql.Context, col *sql.Column, e *expression.Wrapper) error {
   227  	newDefault, ok := e.Unwrap().(*sql.ColumnDefaultValue)
   228  	if !ok {
   229  		return nil
   230  	}
   231  
   232  	if newDefault == nil {
   233  		return nil
   234  	}
   235  
   236  	var err error
   237  	sql.Inspect(newDefault.Expr, func(e sql.Expression) bool {
   238  		switch e.(type) {
   239  		case sql.FunctionExpression, *expression.UnresolvedFunction:
   240  			var funcName string
   241  			switch expr := e.(type) {
   242  			case sql.FunctionExpression:
   243  				funcName = expr.FunctionName()
   244  				// TODO: We don't currently support user created functions, but when we do, we need to prevent them
   245  				//       from being used in column default value expressions, since only built-in functions are allowed.
   246  			case *expression.UnresolvedFunction:
   247  				funcName = expr.Name()
   248  			}
   249  
   250  			if !newDefault.IsParenthesized() {
   251  				if funcName == "now" || funcName == "current_timestamp" {
   252  					// now and current_timestamps are the only functions that don't have to be enclosed in
   253  					// parens when used as a column default value, but ONLY when they are used with a
   254  					// datetime or timestamp column, otherwise it's invalid.
   255  					if col.Type.Type() == sqltypes.Datetime || col.Type.Type() == sqltypes.Timestamp {
   256  						return true
   257  					} else {
   258  						err = sql.ErrColumnDefaultDatetimeOnlyFunc.New()
   259  						return false
   260  					}
   261  				}
   262  			}
   263  			return true
   264  		case *plan.Subquery:
   265  			err = sql.ErrColumnDefaultSubquery.New(col.Name)
   266  			return false
   267  		case *expression.GetField:
   268  			if newDefault.IsParenthesized() == false {
   269  				err = sql.ErrInvalidColumnDefaultValue.New(col.Name)
   270  				return false
   271  			} else {
   272  				return true
   273  			}
   274  		default:
   275  			return true
   276  		}
   277  	})
   278  
   279  	if err != nil {
   280  		return err
   281  	}
   282  
   283  	// validate type of default expression
   284  	if err = newDefault.CheckType(ctx); err != nil {
   285  		return err
   286  	}
   287  
   288  	return nil
   289  }
   290  
   291  func stripTableNamesFromDefault(e *expression.Wrapper) (sql.Expression, transform.TreeIdentity, error) {
   292  	newDefault, ok := e.Unwrap().(*sql.ColumnDefaultValue)
   293  	if !ok {
   294  		return e, transform.SameTree, nil
   295  	}
   296  
   297  	if newDefault == nil {
   298  		return e, transform.SameTree, nil
   299  	}
   300  
   301  	newExpr, same, err := transform.Expr(newDefault.Expr, func(e sql.Expression) (sql.Expression, transform.TreeIdentity, error) {
   302  		if expr, ok := e.(*expression.GetField); ok {
   303  			return expr.WithTable(""), transform.NewTree, nil
   304  		}
   305  		return e, transform.SameTree, nil
   306  	})
   307  	if err != nil {
   308  		return nil, transform.SameTree, err
   309  	}
   310  
   311  	if same {
   312  		return e, transform.SameTree, nil
   313  	}
   314  
   315  	nd := *newDefault
   316  	nd.Expr = newExpr
   317  	return expression.WrapExpression(&nd), transform.NewTree, nil
   318  }
   319  
   320  func backtickDefaultColumnValueNames(ctx *sql.Context, _ *Analyzer, n sql.Node, _ *plan.Scope, _ RuleSelector) (sql.Node, transform.TreeIdentity, error) {
   321  	span, ctx := ctx.Span("backtickDefaultColumnValueNames")
   322  	defer span.End()
   323  
   324  	return transform.Node(n, func(n sql.Node) (sql.Node, transform.TreeIdentity, error) {
   325  		switch node := n.(type) {
   326  		case *plan.AlterDefaultSet:
   327  			eWrapper := expression.WrapExpression(node.Default)
   328  			newExpr, same, err := backtickDefault(eWrapper)
   329  			if err != nil {
   330  				return node, transform.SameTree, err
   331  			}
   332  			if same {
   333  				return node, transform.SameTree, nil
   334  			}
   335  
   336  			newNode, err := node.WithDefault(newExpr)
   337  			if err != nil {
   338  				return node, transform.SameTree, err
   339  			}
   340  			return newNode, transform.NewTree, nil
   341  		case sql.SchemaTarget:
   342  			return transform.NodeExprs(n, func(e sql.Expression) (sql.Expression, transform.TreeIdentity, error) {
   343  				eWrapper, ok := e.(*expression.Wrapper)
   344  				if !ok {
   345  					return e, transform.SameTree, nil
   346  				}
   347  
   348  				return backtickDefault(eWrapper)
   349  			})
   350  		case *plan.ResolvedTable:
   351  			ct, ok := node.Table.(*information_schema.ColumnsTable)
   352  			if !ok {
   353  				return node, transform.SameTree, nil
   354  			}
   355  
   356  			allColumns, err := ct.AllColumns(ctx)
   357  			if err != nil {
   358  				return nil, transform.SameTree, err
   359  			}
   360  
   361  			allDefaults, same, err := transform.Exprs(transform.WrappedColumnDefaults(allColumns), func(e sql.Expression) (sql.Expression, transform.TreeIdentity, error) {
   362  				eWrapper, ok := e.(*expression.Wrapper)
   363  				if !ok {
   364  					return e, transform.SameTree, nil
   365  				}
   366  
   367  				return backtickDefault(eWrapper)
   368  			})
   369  
   370  			if err != nil {
   371  				return nil, transform.SameTree, err
   372  			}
   373  
   374  			if !same {
   375  				node.Table, err = ct.WithColumnDefaults(allDefaults)
   376  				if err != nil {
   377  					return nil, transform.SameTree, err
   378  				}
   379  				return node, transform.NewTree, err
   380  			}
   381  
   382  			return node, transform.SameTree, err
   383  		default:
   384  			return node, transform.SameTree, nil
   385  		}
   386  	})
   387  }
   388  
   389  func backtickDefault(wrap *expression.Wrapper) (sql.Expression, transform.TreeIdentity, error) {
   390  	newDefault, ok := wrap.Unwrap().(*sql.ColumnDefaultValue)
   391  	if !ok {
   392  		return wrap, transform.SameTree, nil
   393  	}
   394  
   395  	if newDefault == nil {
   396  		return wrap, transform.SameTree, nil
   397  	}
   398  
   399  	newExpr, same, err := transform.Expr(newDefault.Expr, func(expr sql.Expression) (sql.Expression, transform.TreeIdentity, error) {
   400  		if e, isGf := expr.(*expression.GetField); isGf {
   401  			return e.WithBackTickNames(true), transform.NewTree, nil
   402  		}
   403  		return expr, transform.SameTree, nil
   404  	})
   405  	if err != nil {
   406  		return nil, transform.SameTree, err
   407  	}
   408  	if same {
   409  		return wrap, transform.SameTree, nil
   410  	}
   411  
   412  	nd := *newDefault
   413  	nd.Expr = newExpr
   414  	return expression.WrapExpression(&nd), transform.NewTree, nil
   415  }