github.com/dolthub/go-mysql-server@v0.18.0/sql/analyzer/inserts.go (about) 1 // Copyright 2020-2021 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package analyzer 16 17 import ( 18 "fmt" 19 "strings" 20 21 "github.com/dolthub/go-mysql-server/sql" 22 "github.com/dolthub/go-mysql-server/sql/expression" 23 "github.com/dolthub/go-mysql-server/sql/plan" 24 "github.com/dolthub/go-mysql-server/sql/transform" 25 "github.com/dolthub/go-mysql-server/sql/types" 26 ) 27 28 func resolveInsertRows(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector) (sql.Node, transform.TreeIdentity, error) { 29 if _, ok := n.(*plan.TriggerExecutor); ok { 30 return n, transform.SameTree, nil 31 } else if _, ok := n.(*plan.CreateProcedure); ok { 32 return n, transform.SameTree, nil 33 } 34 // We capture all INSERTs along the tree, such as those inside of block statements. 35 return transform.Node(n, func(n sql.Node) (sql.Node, transform.TreeIdentity, error) { 36 insert, ok := n.(*plan.InsertInto) 37 if !ok { 38 return n, transform.SameTree, nil 39 } 40 41 table := getResolvedTable(insert.Destination) 42 43 insertable, err := plan.GetInsertable(table) 44 if err != nil { 45 return nil, transform.SameTree, err 46 } 47 48 source := insert.Source 49 // TriggerExecutor has already been analyzed 50 if _, ok := insert.Source.(*plan.TriggerExecutor); !ok { 51 // Analyze the source of the insert independently 52 if _, ok := insert.Source.(*plan.Values); ok { 53 scope = scope.NewScope(plan.NewProject( 54 expression.SchemaToGetFields(insert.Source.Schema()[:len(insert.ColumnNames)], sql.ColSet{}), 55 plan.NewSubqueryAlias("dummy", "", insert.Source), 56 )) 57 } 58 source, _, err = a.analyzeWithSelector(ctx, insert.Source, scope, SelectAllBatches, newInsertSourceSelector(sel)) 59 if err != nil { 60 return nil, transform.SameTree, err 61 } 62 63 source = StripPassthroughNodes(source) 64 } 65 66 dstSchema := insertable.Schema() 67 68 // normalize the column name 69 columnNames := make([]string, len(insert.ColumnNames)) 70 for i, name := range insert.ColumnNames { 71 columnNames[i] = strings.ToLower(name) 72 } 73 74 // If no columns are given and value tuples are not all empty, use the full schema 75 if len(columnNames) == 0 && existsNonZeroValueCount(source) { 76 columnNames = make([]string, len(dstSchema)) 77 for i, f := range dstSchema { 78 columnNames[i] = f.Name 79 } 80 } 81 82 // The schema of the destination node and the underlying table differ subtly in terms of defaults 83 project, autoAutoIncrement, err := wrapRowSource(ctx, source, insertable, insert.Destination.Schema(), columnNames) 84 if err != nil { 85 return nil, transform.SameTree, err 86 } 87 88 return insert.WithSource(project).WithUnspecifiedAutoIncrement(autoAutoIncrement), transform.NewTree, nil 89 }) 90 } 91 92 // Ensures that the number of elements in each Value tuple is empty 93 func existsNonZeroValueCount(values sql.Node) bool { 94 switch node := values.(type) { 95 case *plan.Values: 96 for _, exprTuple := range node.ExpressionTuples { 97 if len(exprTuple) != 0 { 98 return true 99 } 100 } 101 default: 102 return true 103 } 104 return false 105 } 106 107 // wrapRowSource returns a projection that wraps the original row source so that its schema matches the full schema of 108 // the underlying table, in the same order. Also returns a boolean value that indicates whether this row source will 109 // result in an automatically generated value for an auto_increment column. 110 func wrapRowSource(ctx *sql.Context, insertSource sql.Node, destTbl sql.Table, schema sql.Schema, columnNames []string) (sql.Node, bool, error) { 111 projExprs := make([]sql.Expression, len(schema)) 112 autoAutoIncrement := false 113 114 for i, f := range schema { 115 columnExplicitlySpecified := false 116 for j, col := range columnNames { 117 if strings.EqualFold(f.Name, col) { 118 projExprs[i] = expression.NewGetField(j, f.Type, f.Name, f.Nullable) 119 columnExplicitlySpecified = true 120 break 121 } 122 } 123 124 if !columnExplicitlySpecified { 125 defaultExpr := f.Default 126 if defaultExpr == nil { 127 defaultExpr = f.Generated 128 } 129 130 if !f.Nullable && defaultExpr == nil && !f.AutoIncrement { 131 return nil, false, sql.ErrInsertIntoNonNullableDefaultNullColumn.New(f.Name) 132 } 133 var err error 134 135 colIdx := make(map[string]int) 136 for i, c := range schema { 137 colIdx[fmt.Sprintf("%s.%s", strings.ToLower(c.Source), strings.ToLower(c.Name))] = i 138 } 139 def, _, err := transform.Expr(defaultExpr, func(e sql.Expression) (sql.Expression, transform.TreeIdentity, error) { 140 switch e := e.(type) { 141 case *expression.GetField: 142 idx, ok := colIdx[strings.ToLower(e.WithTable(destTbl.Name()).String())] 143 if !ok { 144 return nil, transform.SameTree, fmt.Errorf("field not found: %s", e.String()) 145 } 146 return e.WithIndex(idx), transform.NewTree, nil 147 default: 148 return e, transform.SameTree, nil 149 } 150 }) 151 if err != nil { 152 return nil, false, err 153 } 154 projExprs[i] = def 155 } 156 157 if f.AutoIncrement { 158 ai, err := expression.NewAutoIncrement(ctx, destTbl, projExprs[i]) 159 if err != nil { 160 return nil, false, err 161 } 162 projExprs[i] = ai 163 164 if !columnExplicitlySpecified { 165 autoAutoIncrement = true 166 } 167 } 168 } 169 170 err := validateRowSource(insertSource, projExprs) 171 if err != nil { 172 return nil, false, err 173 } 174 175 return plan.NewProject(projExprs, insertSource), autoAutoIncrement, nil 176 } 177 178 // validGeneratedColumnValue returns true if the column is a generated column and the source node is not a values node. 179 // Explicit default values (`DEFAULT`) are the only valid values to specify for a generated column 180 func validGeneratedColumnValue(idx int, source sql.Node) bool { 181 switch source := source.(type) { 182 case *plan.Values: 183 for _, tuple := range source.ExpressionTuples { 184 switch val := tuple[idx].(type) { 185 case *sql.ColumnDefaultValue: // should be wrapped, but just in case 186 return true 187 case *expression.Wrapper: 188 if _, ok := val.Unwrap().(*sql.ColumnDefaultValue); ok { 189 return true 190 } 191 return false 192 default: 193 return false 194 } 195 } 196 return false 197 default: 198 return false 199 } 200 } 201 202 func assertCompatibleSchemas(projExprs []sql.Expression, schema sql.Schema) error { 203 for _, expr := range projExprs { 204 switch e := expr.(type) { 205 case *expression.Literal, 206 *expression.AutoIncrement, 207 *sql.ColumnDefaultValue: 208 continue 209 case *expression.GetField: 210 otherCol := schema[e.Index()] 211 // special case: null field type, will get checked at execution time 212 if otherCol.Type == types.Null { 213 continue 214 } 215 exprType := expr.Type() 216 _, _, err := exprType.Convert(otherCol.Type.Zero()) 217 if err != nil { 218 // The zero value will fail when passing string values to ENUM, so we specially handle this case 219 if _, ok := exprType.(sql.EnumType); ok && types.IsText(otherCol.Type) { 220 continue 221 } 222 return plan.ErrInsertIntoIncompatibleTypes.New(otherCol.Type.String(), expr.Type().String()) 223 } 224 default: 225 return plan.ErrInsertIntoUnsupportedValues.New(expr) 226 } 227 } 228 return nil 229 } 230 231 func validateRowSource(values sql.Node, projExprs []sql.Expression) error { 232 if exchange, ok := values.(*plan.Exchange); ok { 233 values = exchange.Child 234 } 235 236 switch n := values.(type) { 237 case *plan.Values, *plan.LoadData: 238 // already verified 239 return nil 240 default: 241 // Parser assures us that this will be some form of SelectStatement, so no need to type check it 242 return assertCompatibleSchemas(projExprs, n.Schema()) 243 } 244 }