github.com/dolthub/go-mysql-server@v0.18.0/sql/analyzer/resolve_column_defaults.go (about) 1 // Copyright 2020-2021 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package analyzer 16 17 import ( 18 "github.com/dolthub/vitess/go/sqltypes" 19 20 "github.com/dolthub/go-mysql-server/sql" 21 "github.com/dolthub/go-mysql-server/sql/expression" 22 "github.com/dolthub/go-mysql-server/sql/information_schema" 23 "github.com/dolthub/go-mysql-server/sql/plan" 24 "github.com/dolthub/go-mysql-server/sql/transform" 25 ) 26 27 // Resolving column defaults is a multi-phase process, with different analyzer rules for each phase. 28 // 29 // * parseColumnDefaults: Some integrators (dolt but not GMS) store their column defaults as strings, which we need to 30 // parse into expressions before we can analyze them any further. 31 // * resolveColumnDefaults: Once we have an expression for a default value, it may contain expressions that need 32 // simplification before further phases of processing can take place. 33 // 34 // After this stage, expressions in column default values are handled by the normal analyzer machinery responsible for 35 // resolving expressions, including things like columns and functions. Every node that needs to do this for its default 36 // values implements `sql.Expressioner` to expose such expressions. 
// There is custom logic in `resolveColumns` to help
// identify the correct indexes for column references, which can vary based on the node type.
//
// Finally, there are cleanup phases:
// * validateColumnDefaults: ensures that column defaults newly created by a DDL statement are legal for the type of
// their column, and performs various other checks to match MySQL's behavior.
// * stripTableNamesFromColumnDefaults: column defaults headed for storage, or for serialization in a query result, need
// the table names stripped from any GetField expressions so that they serialize to strings without such table names.
// Much of the rest of the analyzer expects table names in GetField expressions, so this is done after the bulk of
// analyzer work.
// * backtickDefaultColumnValueNames (also defined in this file): calls `WithBackTickNames(true)` on every GetField
// expression in a column default, presumably so that column names are quoted with backticks when the default is
// rendered.
//
// The `information_schema.columns` table also needs access to the default values of every column in the database, and
// because it's a table it can't implement `sql.Expressioner` like other node types. Instead, it has special handling
// here, as well as in the `resolve_functions` rule.
50 51 func validateColumnDefaults(ctx *sql.Context, _ *Analyzer, n sql.Node, _ *plan.Scope, _ RuleSelector) (sql.Node, transform.TreeIdentity, error) { 52 span, ctx := ctx.Span("validateColumnDefaults") 53 defer span.End() 54 55 return transform.Node(n, func(n sql.Node) (sql.Node, transform.TreeIdentity, error) { 56 switch node := n.(type) { 57 case *plan.AlterDefaultSet: 58 table := getResolvedTable(node) 59 sch := table.Schema() 60 index := sch.IndexOfColName(node.ColumnName) 61 if index == -1 { 62 return nil, transform.SameTree, sql.ErrColumnNotFound.New(node.ColumnName) 63 } 64 col := sch[index] 65 eWrapper := expression.WrapExpression(node.Default) 66 err := validateColumnDefault(ctx, col, eWrapper) 67 if err != nil { 68 return node, transform.SameTree, err 69 } 70 return node, transform.SameTree, nil 71 case sql.SchemaTarget: 72 switch node.(type) { 73 case *plan.AlterPK, *plan.AddColumn, *plan.ModifyColumn, *plan.AlterDefaultDrop, *plan.CreateTable, *plan.DropColumn: 74 // DDL nodes must validate any new column defaults, continue to logic below 75 default: 76 // other node types are not altering the schema and therefore don't need validation of column defaults 77 return n, transform.SameTree, nil 78 } 79 80 // There may be multiple DDL nodes in the plan (ALTER TABLE statements can have many clauses), and for each of them 81 // we need to count the column indexes in the very hacky way outlined above. 
82 i := 0 83 return transform.NodeExprs(n, func(e sql.Expression) (sql.Expression, transform.TreeIdentity, error) { 84 eWrapper, ok := e.(*expression.Wrapper) 85 if !ok { 86 return e, transform.SameTree, nil 87 } 88 89 defer func() { 90 i++ 91 }() 92 93 if eWrapper.Unwrap() == nil { 94 return e, transform.SameTree, nil 95 } 96 97 col, err := lookupColumnForTargetSchema(ctx, node, i) 98 if err != nil { 99 return nil, transform.SameTree, err 100 } 101 102 err = validateColumnDefault(ctx, col, eWrapper) 103 if err != nil { 104 return nil, transform.SameTree, err 105 } 106 107 return e, transform.SameTree, nil 108 }) 109 default: 110 return node, transform.SameTree, nil 111 } 112 }) 113 } 114 115 // stripTableNamesFromColumnDefaults removes the table name from any GetField expressions in column default expressions. 116 // Default values can only reference their host table, and since we serialize the GetField expression for storage, it's 117 // important that we remove the table name before passing it off for storage. Otherwise we end up with serialized 118 // defaults like `tableName.field + 1` instead of just `field + 1`. 
119 func stripTableNamesFromColumnDefaults(ctx *sql.Context, _ *Analyzer, n sql.Node, _ *plan.Scope, _ RuleSelector) (sql.Node, transform.TreeIdentity, error) { 120 span, ctx := ctx.Span("stripTableNamesFromColumnDefaults") 121 defer span.End() 122 123 return transform.Node(n, func(n sql.Node) (sql.Node, transform.TreeIdentity, error) { 124 switch node := n.(type) { 125 case *plan.AlterDefaultSet: 126 eWrapper := expression.WrapExpression(node.Default) 127 newExpr, same, err := stripTableNamesFromDefault(eWrapper) 128 if err != nil { 129 return node, transform.SameTree, err 130 } 131 if same { 132 return node, transform.SameTree, nil 133 } 134 135 newNode, err := node.WithDefault(newExpr) 136 if err != nil { 137 return node, transform.SameTree, err 138 } 139 return newNode, transform.NewTree, nil 140 case sql.SchemaTarget: 141 return transform.NodeExprs(n, func(e sql.Expression) (sql.Expression, transform.TreeIdentity, error) { 142 eWrapper, ok := e.(*expression.Wrapper) 143 if !ok { 144 return e, transform.SameTree, nil 145 } 146 147 return stripTableNamesFromDefault(eWrapper) 148 }) 149 case *plan.ResolvedTable: 150 ct, ok := node.Table.(*information_schema.ColumnsTable) 151 if !ok { 152 return node, transform.SameTree, nil 153 } 154 155 allColumns, err := ct.AllColumns(ctx) 156 if err != nil { 157 return nil, transform.SameTree, err 158 } 159 160 allDefaults, same, err := transform.Exprs(transform.WrappedColumnDefaults(allColumns), func(e sql.Expression) (sql.Expression, transform.TreeIdentity, error) { 161 eWrapper, ok := e.(*expression.Wrapper) 162 if !ok { 163 return e, transform.SameTree, nil 164 } 165 166 return stripTableNamesFromDefault(eWrapper) 167 }) 168 169 if err != nil { 170 return nil, transform.SameTree, err 171 } 172 173 if !same { 174 node.Table, err = ct.WithColumnDefaults(allDefaults) 175 if err != nil { 176 return nil, transform.SameTree, err 177 } 178 return node, transform.NewTree, err 179 } 180 181 return node, transform.SameTree, err 182 
default: 183 return node, transform.SameTree, nil 184 } 185 }) 186 } 187 188 // lookupColumnForTargetSchema looks at the target schema for the specified SchemaTarget node and returns 189 // the column based on the specified index. For most node types, this is simply indexing into the target 190 // schema but a few types require special handling. 191 func lookupColumnForTargetSchema(_ *sql.Context, node sql.SchemaTarget, colIndex int) (*sql.Column, error) { 192 schema := node.TargetSchema() 193 194 switch n := node.(type) { 195 case *plan.ModifyColumn: 196 if colIndex < len(schema) { 197 return schema[colIndex], nil 198 } else { 199 return n.NewColumn(), nil 200 } 201 case *plan.AddColumn: 202 if colIndex < len(schema) { 203 return schema[colIndex], nil 204 } else { 205 return n.Column(), nil 206 } 207 case *plan.AlterDefaultSet: 208 index := schema.IndexOfColName(n.ColumnName) 209 if index == -1 { 210 return nil, sql.ErrTableColumnNotFound.New(n.Table, n.ColumnName) 211 } 212 return schema[index], nil 213 default: 214 if colIndex < len(schema) { 215 return schema[colIndex], nil 216 } else { 217 // TODO: sql.ErrColumnNotFound would be a better error here, but we need to add all the different node types to 218 // the switch to get it 219 return nil, expression.ErrIndexOutOfBounds.New(colIndex, len(schema)) 220 } 221 } 222 } 223 224 // validateColumnDefault validates that the column default expression is valid for the column type and returns an error 225 // if not 226 func validateColumnDefault(ctx *sql.Context, col *sql.Column, e *expression.Wrapper) error { 227 newDefault, ok := e.Unwrap().(*sql.ColumnDefaultValue) 228 if !ok { 229 return nil 230 } 231 232 if newDefault == nil { 233 return nil 234 } 235 236 var err error 237 sql.Inspect(newDefault.Expr, func(e sql.Expression) bool { 238 switch e.(type) { 239 case sql.FunctionExpression, *expression.UnresolvedFunction: 240 var funcName string 241 switch expr := e.(type) { 242 case sql.FunctionExpression: 243 funcName 
= expr.FunctionName() 244 // TODO: We don't currently support user created functions, but when we do, we need to prevent them 245 // from being used in column default value expressions, since only built-in functions are allowed. 246 case *expression.UnresolvedFunction: 247 funcName = expr.Name() 248 } 249 250 if !newDefault.IsParenthesized() { 251 if funcName == "now" || funcName == "current_timestamp" { 252 // now and current_timestamps are the only functions that don't have to be enclosed in 253 // parens when used as a column default value, but ONLY when they are used with a 254 // datetime or timestamp column, otherwise it's invalid. 255 if col.Type.Type() == sqltypes.Datetime || col.Type.Type() == sqltypes.Timestamp { 256 return true 257 } else { 258 err = sql.ErrColumnDefaultDatetimeOnlyFunc.New() 259 return false 260 } 261 } 262 } 263 return true 264 case *plan.Subquery: 265 err = sql.ErrColumnDefaultSubquery.New(col.Name) 266 return false 267 case *expression.GetField: 268 if newDefault.IsParenthesized() == false { 269 err = sql.ErrInvalidColumnDefaultValue.New(col.Name) 270 return false 271 } else { 272 return true 273 } 274 default: 275 return true 276 } 277 }) 278 279 if err != nil { 280 return err 281 } 282 283 // validate type of default expression 284 if err = newDefault.CheckType(ctx); err != nil { 285 return err 286 } 287 288 return nil 289 } 290 291 func stripTableNamesFromDefault(e *expression.Wrapper) (sql.Expression, transform.TreeIdentity, error) { 292 newDefault, ok := e.Unwrap().(*sql.ColumnDefaultValue) 293 if !ok { 294 return e, transform.SameTree, nil 295 } 296 297 if newDefault == nil { 298 return e, transform.SameTree, nil 299 } 300 301 newExpr, same, err := transform.Expr(newDefault.Expr, func(e sql.Expression) (sql.Expression, transform.TreeIdentity, error) { 302 if expr, ok := e.(*expression.GetField); ok { 303 return expr.WithTable(""), transform.NewTree, nil 304 } 305 return e, transform.SameTree, nil 306 }) 307 if err != nil { 308 
return nil, transform.SameTree, err 309 } 310 311 if same { 312 return e, transform.SameTree, nil 313 } 314 315 nd := *newDefault 316 nd.Expr = newExpr 317 return expression.WrapExpression(&nd), transform.NewTree, nil 318 } 319 320 func backtickDefaultColumnValueNames(ctx *sql.Context, _ *Analyzer, n sql.Node, _ *plan.Scope, _ RuleSelector) (sql.Node, transform.TreeIdentity, error) { 321 span, ctx := ctx.Span("backtickDefaultColumnValueNames") 322 defer span.End() 323 324 return transform.Node(n, func(n sql.Node) (sql.Node, transform.TreeIdentity, error) { 325 switch node := n.(type) { 326 case *plan.AlterDefaultSet: 327 eWrapper := expression.WrapExpression(node.Default) 328 newExpr, same, err := backtickDefault(eWrapper) 329 if err != nil { 330 return node, transform.SameTree, err 331 } 332 if same { 333 return node, transform.SameTree, nil 334 } 335 336 newNode, err := node.WithDefault(newExpr) 337 if err != nil { 338 return node, transform.SameTree, err 339 } 340 return newNode, transform.NewTree, nil 341 case sql.SchemaTarget: 342 return transform.NodeExprs(n, func(e sql.Expression) (sql.Expression, transform.TreeIdentity, error) { 343 eWrapper, ok := e.(*expression.Wrapper) 344 if !ok { 345 return e, transform.SameTree, nil 346 } 347 348 return backtickDefault(eWrapper) 349 }) 350 case *plan.ResolvedTable: 351 ct, ok := node.Table.(*information_schema.ColumnsTable) 352 if !ok { 353 return node, transform.SameTree, nil 354 } 355 356 allColumns, err := ct.AllColumns(ctx) 357 if err != nil { 358 return nil, transform.SameTree, err 359 } 360 361 allDefaults, same, err := transform.Exprs(transform.WrappedColumnDefaults(allColumns), func(e sql.Expression) (sql.Expression, transform.TreeIdentity, error) { 362 eWrapper, ok := e.(*expression.Wrapper) 363 if !ok { 364 return e, transform.SameTree, nil 365 } 366 367 return backtickDefault(eWrapper) 368 }) 369 370 if err != nil { 371 return nil, transform.SameTree, err 372 } 373 374 if !same { 375 node.Table, err = 
ct.WithColumnDefaults(allDefaults) 376 if err != nil { 377 return nil, transform.SameTree, err 378 } 379 return node, transform.NewTree, err 380 } 381 382 return node, transform.SameTree, err 383 default: 384 return node, transform.SameTree, nil 385 } 386 }) 387 } 388 389 func backtickDefault(wrap *expression.Wrapper) (sql.Expression, transform.TreeIdentity, error) { 390 newDefault, ok := wrap.Unwrap().(*sql.ColumnDefaultValue) 391 if !ok { 392 return wrap, transform.SameTree, nil 393 } 394 395 if newDefault == nil { 396 return wrap, transform.SameTree, nil 397 } 398 399 newExpr, same, err := transform.Expr(newDefault.Expr, func(expr sql.Expression) (sql.Expression, transform.TreeIdentity, error) { 400 if e, isGf := expr.(*expression.GetField); isGf { 401 return e.WithBackTickNames(true), transform.NewTree, nil 402 } 403 return expr, transform.SameTree, nil 404 }) 405 if err != nil { 406 return nil, transform.SameTree, err 407 } 408 if same { 409 return wrap, transform.SameTree, nil 410 } 411 412 nd := *newDefault 413 nd.Expr = newExpr 414 return expression.WrapExpression(&nd), transform.NewTree, nil 415 }