github.com/dolthub/go-mysql-server@v0.18.0/sql/analyzer/inserts.go (about)

     1  // Copyright 2020-2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package analyzer
    16  
    17  import (
    18  	"fmt"
    19  	"strings"
    20  
    21  	"github.com/dolthub/go-mysql-server/sql"
    22  	"github.com/dolthub/go-mysql-server/sql/expression"
    23  	"github.com/dolthub/go-mysql-server/sql/plan"
    24  	"github.com/dolthub/go-mysql-server/sql/transform"
    25  	"github.com/dolthub/go-mysql-server/sql/types"
    26  )
    27  
    28  func resolveInsertRows(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector) (sql.Node, transform.TreeIdentity, error) {
    29  	if _, ok := n.(*plan.TriggerExecutor); ok {
    30  		return n, transform.SameTree, nil
    31  	} else if _, ok := n.(*plan.CreateProcedure); ok {
    32  		return n, transform.SameTree, nil
    33  	}
    34  	// We capture all INSERTs along the tree, such as those inside of block statements.
    35  	return transform.Node(n, func(n sql.Node) (sql.Node, transform.TreeIdentity, error) {
    36  		insert, ok := n.(*plan.InsertInto)
    37  		if !ok {
    38  			return n, transform.SameTree, nil
    39  		}
    40  
    41  		table := getResolvedTable(insert.Destination)
    42  
    43  		insertable, err := plan.GetInsertable(table)
    44  		if err != nil {
    45  			return nil, transform.SameTree, err
    46  		}
    47  
    48  		source := insert.Source
    49  		// TriggerExecutor has already been analyzed
    50  		if _, ok := insert.Source.(*plan.TriggerExecutor); !ok {
    51  			// Analyze the source of the insert independently
    52  			if _, ok := insert.Source.(*plan.Values); ok {
    53  				scope = scope.NewScope(plan.NewProject(
    54  					expression.SchemaToGetFields(insert.Source.Schema()[:len(insert.ColumnNames)], sql.ColSet{}),
    55  					plan.NewSubqueryAlias("dummy", "", insert.Source),
    56  				))
    57  			}
    58  			source, _, err = a.analyzeWithSelector(ctx, insert.Source, scope, SelectAllBatches, newInsertSourceSelector(sel))
    59  			if err != nil {
    60  				return nil, transform.SameTree, err
    61  			}
    62  
    63  			source = StripPassthroughNodes(source)
    64  		}
    65  
    66  		dstSchema := insertable.Schema()
    67  
    68  		// normalize the column name
    69  		columnNames := make([]string, len(insert.ColumnNames))
    70  		for i, name := range insert.ColumnNames {
    71  			columnNames[i] = strings.ToLower(name)
    72  		}
    73  
    74  		// If no columns are given and value tuples are not all empty, use the full schema
    75  		if len(columnNames) == 0 && existsNonZeroValueCount(source) {
    76  			columnNames = make([]string, len(dstSchema))
    77  			for i, f := range dstSchema {
    78  				columnNames[i] = f.Name
    79  			}
    80  		}
    81  
    82  		// The schema of the destination node and the underlying table differ subtly in terms of defaults
    83  		project, autoAutoIncrement, err := wrapRowSource(ctx, source, insertable, insert.Destination.Schema(), columnNames)
    84  		if err != nil {
    85  			return nil, transform.SameTree, err
    86  		}
    87  
    88  		return insert.WithSource(project).WithUnspecifiedAutoIncrement(autoAutoIncrement), transform.NewTree, nil
    89  	})
    90  }
    91  
    92  // Ensures that the number of elements in each Value tuple is empty
    93  func existsNonZeroValueCount(values sql.Node) bool {
    94  	switch node := values.(type) {
    95  	case *plan.Values:
    96  		for _, exprTuple := range node.ExpressionTuples {
    97  			if len(exprTuple) != 0 {
    98  				return true
    99  			}
   100  		}
   101  	default:
   102  		return true
   103  	}
   104  	return false
   105  }
   106  
   107  // wrapRowSource returns a projection that wraps the original row source so that its schema matches the full schema of
   108  // the underlying table, in the same order. Also returns a boolean value that indicates whether this row source will
   109  // result in an automatically generated value for an auto_increment column.
   110  func wrapRowSource(ctx *sql.Context, insertSource sql.Node, destTbl sql.Table, schema sql.Schema, columnNames []string) (sql.Node, bool, error) {
   111  	projExprs := make([]sql.Expression, len(schema))
   112  	autoAutoIncrement := false
   113  
   114  	for i, f := range schema {
   115  		columnExplicitlySpecified := false
   116  		for j, col := range columnNames {
   117  			if strings.EqualFold(f.Name, col) {
   118  				projExprs[i] = expression.NewGetField(j, f.Type, f.Name, f.Nullable)
   119  				columnExplicitlySpecified = true
   120  				break
   121  			}
   122  		}
   123  
   124  		if !columnExplicitlySpecified {
   125  			defaultExpr := f.Default
   126  			if defaultExpr == nil {
   127  				defaultExpr = f.Generated
   128  			}
   129  
   130  			if !f.Nullable && defaultExpr == nil && !f.AutoIncrement {
   131  				return nil, false, sql.ErrInsertIntoNonNullableDefaultNullColumn.New(f.Name)
   132  			}
   133  			var err error
   134  
   135  			colIdx := make(map[string]int)
   136  			for i, c := range schema {
   137  				colIdx[fmt.Sprintf("%s.%s", strings.ToLower(c.Source), strings.ToLower(c.Name))] = i
   138  			}
   139  			def, _, err := transform.Expr(defaultExpr, func(e sql.Expression) (sql.Expression, transform.TreeIdentity, error) {
   140  				switch e := e.(type) {
   141  				case *expression.GetField:
   142  					idx, ok := colIdx[strings.ToLower(e.WithTable(destTbl.Name()).String())]
   143  					if !ok {
   144  						return nil, transform.SameTree, fmt.Errorf("field not found: %s", e.String())
   145  					}
   146  					return e.WithIndex(idx), transform.NewTree, nil
   147  				default:
   148  					return e, transform.SameTree, nil
   149  				}
   150  			})
   151  			if err != nil {
   152  				return nil, false, err
   153  			}
   154  			projExprs[i] = def
   155  		}
   156  
   157  		if f.AutoIncrement {
   158  			ai, err := expression.NewAutoIncrement(ctx, destTbl, projExprs[i])
   159  			if err != nil {
   160  				return nil, false, err
   161  			}
   162  			projExprs[i] = ai
   163  
   164  			if !columnExplicitlySpecified {
   165  				autoAutoIncrement = true
   166  			}
   167  		}
   168  	}
   169  
   170  	err := validateRowSource(insertSource, projExprs)
   171  	if err != nil {
   172  		return nil, false, err
   173  	}
   174  
   175  	return plan.NewProject(projExprs, insertSource), autoAutoIncrement, nil
   176  }
   177  
   178  // validGeneratedColumnValue returns true if the column is a generated column and the source node is not a values node.
   179  // Explicit default values (`DEFAULT`) are the only valid values to specify for a generated column
   180  func validGeneratedColumnValue(idx int, source sql.Node) bool {
   181  	switch source := source.(type) {
   182  	case *plan.Values:
   183  		for _, tuple := range source.ExpressionTuples {
   184  			switch val := tuple[idx].(type) {
   185  			case *sql.ColumnDefaultValue: // should be wrapped, but just in case
   186  				return true
   187  			case *expression.Wrapper:
   188  				if _, ok := val.Unwrap().(*sql.ColumnDefaultValue); ok {
   189  					return true
   190  				}
   191  				return false
   192  			default:
   193  				return false
   194  			}
   195  		}
   196  		return false
   197  	default:
   198  		return false
   199  	}
   200  }
   201  
   202  func assertCompatibleSchemas(projExprs []sql.Expression, schema sql.Schema) error {
   203  	for _, expr := range projExprs {
   204  		switch e := expr.(type) {
   205  		case *expression.Literal,
   206  			*expression.AutoIncrement,
   207  			*sql.ColumnDefaultValue:
   208  			continue
   209  		case *expression.GetField:
   210  			otherCol := schema[e.Index()]
   211  			// special case: null field type, will get checked at execution time
   212  			if otherCol.Type == types.Null {
   213  				continue
   214  			}
   215  			exprType := expr.Type()
   216  			_, _, err := exprType.Convert(otherCol.Type.Zero())
   217  			if err != nil {
   218  				// The zero value will fail when passing string values to ENUM, so we specially handle this case
   219  				if _, ok := exprType.(sql.EnumType); ok && types.IsText(otherCol.Type) {
   220  					continue
   221  				}
   222  				return plan.ErrInsertIntoIncompatibleTypes.New(otherCol.Type.String(), expr.Type().String())
   223  			}
   224  		default:
   225  			return plan.ErrInsertIntoUnsupportedValues.New(expr)
   226  		}
   227  	}
   228  	return nil
   229  }
   230  
   231  func validateRowSource(values sql.Node, projExprs []sql.Expression) error {
   232  	if exchange, ok := values.(*plan.Exchange); ok {
   233  		values = exchange.Child
   234  	}
   235  
   236  	switch n := values.(type) {
   237  	case *plan.Values, *plan.LoadData:
   238  		// already verified
   239  		return nil
   240  	default:
   241  		// Parser assures us that this will be some form of SelectStatement, so no need to type check it
   242  		return assertCompatibleSchemas(projExprs, n.Schema())
   243  	}
   244  }