github.com/dolthub/go-mysql-server@v0.18.0/sql/planbuilder/dml_validate.go (about) 1 // Copyright 2024 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package planbuilder 16 17 import ( 18 "strings" 19 20 "github.com/dolthub/go-mysql-server/sql" 21 "github.com/dolthub/go-mysql-server/sql/expression" 22 "github.com/dolthub/go-mysql-server/sql/plan" 23 ) 24 25 func (b *Builder) validateInsert(ins *plan.InsertInto) { 26 table := getResolvedTable(ins.Destination) 27 if table == nil { 28 return 29 } 30 31 insertable, err := plan.GetInsertable(table) 32 if err != nil { 33 b.handleErr(err) 34 } 35 36 if ins.IsReplace { 37 var ok bool 38 _, ok = insertable.(sql.ReplaceableTable) 39 if !ok { 40 err := plan.ErrReplaceIntoNotSupported.New() 41 b.handleErr(err) 42 } 43 } 44 45 if len(ins.OnDupExprs) > 0 { 46 var ok bool 47 _, ok = insertable.(sql.UpdatableTable) 48 if !ok { 49 err := plan.ErrOnDuplicateKeyUpdateNotSupported.New() 50 b.handleErr(err) 51 } 52 } 53 54 // normalize the column name 55 dstSchema := insertable.Schema() 56 columnNames := make([]string, len(ins.ColumnNames)) 57 for i, name := range ins.ColumnNames { 58 columnNames[i] = strings.ToLower(name) 59 } 60 61 // If no columns are given and value tuples are not all empty, use the full schema 62 if len(columnNames) == 0 && existsNonZeroValueCount(ins.Source) { 63 columnNames = make([]string, len(dstSchema)) 64 for i, f := range dstSchema { 65 columnNames[i] = f.Name 66 } 67 } 68 69 if len(ins.ColumnNames) > 0 { 70 err := validateColumns(table.Name(), columnNames, dstSchema, ins.Source) 71 if err != nil { 72 b.handleErr(err) 73 } 74 } 75 76 err = validateValueCount(columnNames, ins.Source) 77 if err != nil { 78 b.handleErr(err) 79 } 80 } 81 82 // Ensures that the number of elements in each Value tuple is empty 83 func existsNonZeroValueCount(values sql.Node) bool { 84 switch node := values.(type) { 85 case *plan.Values: 86 for _, exprTuple := range node.ExpressionTuples { 87 if len(exprTuple) != 0 { 88 return true 89 } 90 } 91 default: 92 return true 93 } 94 return false 95 } 96 97 func validateColumns(tableName string, columnNames []string, dstSchema sql.Schema, source sql.Node) error { 98 dstColNames := make(map[string]*sql.Column) 99 for _, dstCol := range dstSchema { 100 dstColNames[strings.ToLower(dstCol.Name)] = dstCol 101 } 102 usedNames := make(map[string]struct{}) 103 for i, columnName := range columnNames { 104 dstCol, exists := dstColNames[columnName] 105 if !exists { 106 return sql.ErrUnknownColumn.New(columnName, tableName) 107 } 108 if dstCol.Generated != nil && !validGeneratedColumnValue(i, source) { 109 return sql.ErrGeneratedColumnValue.New(dstCol.Name, tableName) 110 } 111 if _, exists := usedNames[columnName]; !exists { 112 usedNames[columnName] = struct{}{} 113 } else { 114 return sql.ErrColumnSpecifiedTwice.New(columnName) 115 } 116 } 117 return nil 118 } 119 120 // validGeneratedColumnValue returns true if the column is a generated column and the source node is not a values node. 121 // Explicit default values (`DEFAULT`) are the only valid values to specify for a generated column 122 func validGeneratedColumnValue(idx int, source sql.Node) bool { 123 switch source := source.(type) { 124 case *plan.Values: 125 for _, tuple := range source.ExpressionTuples { 126 switch val := tuple[idx].(type) { 127 case *sql.ColumnDefaultValue: // should be wrapped, but just in case 128 return true 129 case *expression.Wrapper: 130 if _, ok := val.Unwrap().(*sql.ColumnDefaultValue); ok { 131 return true 132 } 133 return false 134 default: 135 return false 136 } 137 } 138 return false 139 default: 140 return false 141 } 142 } 143 144 func validateValueCount(columnNames []string, values sql.Node) error { 145 if exchange, ok := values.(*plan.Exchange); ok { 146 values = exchange.Child 147 } 148 149 switch node := values.(type) { 150 case *plan.Values: 151 for _, exprTuple := range node.ExpressionTuples { 152 if len(exprTuple) != len(columnNames) { 153 return sql.ErrInsertIntoMismatchValueCount.New() 154 } 155 } 156 case *plan.LoadData: 157 dataColLen := len(node.ColumnNames) 158 if dataColLen == 0 { 159 dataColLen = len(node.Schema()) 160 } 161 if len(columnNames) != dataColLen { 162 return sql.ErrInsertIntoMismatchValueCount.New() 163 } 164 default: 165 // Parser assures us that this will be some form of SelectStatement, so no need to type check it 166 if len(columnNames) != len(values.Schema()) { 167 return sql.ErrInsertIntoMismatchValueCount.New() 168 } 169 } 170 return nil 171 }