github.com/dolthub/go-mysql-server@v0.18.0/sql/planbuilder/dml_validate.go (about)

     1  // Copyright 2024 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package planbuilder
    16  
    17  import (
    18  	"strings"
    19  
    20  	"github.com/dolthub/go-mysql-server/sql"
    21  	"github.com/dolthub/go-mysql-server/sql/expression"
    22  	"github.com/dolthub/go-mysql-server/sql/plan"
    23  )
    24  
    25  func (b *Builder) validateInsert(ins *plan.InsertInto) {
    26  	table := getResolvedTable(ins.Destination)
    27  	if table == nil {
    28  		return
    29  	}
    30  
    31  	insertable, err := plan.GetInsertable(table)
    32  	if err != nil {
    33  		b.handleErr(err)
    34  	}
    35  
    36  	if ins.IsReplace {
    37  		var ok bool
    38  		_, ok = insertable.(sql.ReplaceableTable)
    39  		if !ok {
    40  			err := plan.ErrReplaceIntoNotSupported.New()
    41  			b.handleErr(err)
    42  		}
    43  	}
    44  
    45  	if len(ins.OnDupExprs) > 0 {
    46  		var ok bool
    47  		_, ok = insertable.(sql.UpdatableTable)
    48  		if !ok {
    49  			err := plan.ErrOnDuplicateKeyUpdateNotSupported.New()
    50  			b.handleErr(err)
    51  		}
    52  	}
    53  
    54  	// normalize the column name
    55  	dstSchema := insertable.Schema()
    56  	columnNames := make([]string, len(ins.ColumnNames))
    57  	for i, name := range ins.ColumnNames {
    58  		columnNames[i] = strings.ToLower(name)
    59  	}
    60  
    61  	// If no columns are given and value tuples are not all empty, use the full schema
    62  	if len(columnNames) == 0 && existsNonZeroValueCount(ins.Source) {
    63  		columnNames = make([]string, len(dstSchema))
    64  		for i, f := range dstSchema {
    65  			columnNames[i] = f.Name
    66  		}
    67  	}
    68  
    69  	if len(ins.ColumnNames) > 0 {
    70  		err := validateColumns(table.Name(), columnNames, dstSchema, ins.Source)
    71  		if err != nil {
    72  			b.handleErr(err)
    73  		}
    74  	}
    75  
    76  	err = validateValueCount(columnNames, ins.Source)
    77  	if err != nil {
    78  		b.handleErr(err)
    79  	}
    80  }
    81  
    82  // Ensures that the number of elements in each Value tuple is empty
    83  func existsNonZeroValueCount(values sql.Node) bool {
    84  	switch node := values.(type) {
    85  	case *plan.Values:
    86  		for _, exprTuple := range node.ExpressionTuples {
    87  			if len(exprTuple) != 0 {
    88  				return true
    89  			}
    90  		}
    91  	default:
    92  		return true
    93  	}
    94  	return false
    95  }
    96  
    97  func validateColumns(tableName string, columnNames []string, dstSchema sql.Schema, source sql.Node) error {
    98  	dstColNames := make(map[string]*sql.Column)
    99  	for _, dstCol := range dstSchema {
   100  		dstColNames[strings.ToLower(dstCol.Name)] = dstCol
   101  	}
   102  	usedNames := make(map[string]struct{})
   103  	for i, columnName := range columnNames {
   104  		dstCol, exists := dstColNames[columnName]
   105  		if !exists {
   106  			return sql.ErrUnknownColumn.New(columnName, tableName)
   107  		}
   108  		if dstCol.Generated != nil && !validGeneratedColumnValue(i, source) {
   109  			return sql.ErrGeneratedColumnValue.New(dstCol.Name, tableName)
   110  		}
   111  		if _, exists := usedNames[columnName]; !exists {
   112  			usedNames[columnName] = struct{}{}
   113  		} else {
   114  			return sql.ErrColumnSpecifiedTwice.New(columnName)
   115  		}
   116  	}
   117  	return nil
   118  }
   119  
   120  // validGeneratedColumnValue returns true if the column is a generated column and the source node is not a values node.
   121  // Explicit default values (`DEFAULT`) are the only valid values to specify for a generated column
   122  func validGeneratedColumnValue(idx int, source sql.Node) bool {
   123  	switch source := source.(type) {
   124  	case *plan.Values:
   125  		for _, tuple := range source.ExpressionTuples {
   126  			switch val := tuple[idx].(type) {
   127  			case *sql.ColumnDefaultValue: // should be wrapped, but just in case
   128  				return true
   129  			case *expression.Wrapper:
   130  				if _, ok := val.Unwrap().(*sql.ColumnDefaultValue); ok {
   131  					return true
   132  				}
   133  				return false
   134  			default:
   135  				return false
   136  			}
   137  		}
   138  		return false
   139  	default:
   140  		return false
   141  	}
   142  }
   143  
   144  func validateValueCount(columnNames []string, values sql.Node) error {
   145  	if exchange, ok := values.(*plan.Exchange); ok {
   146  		values = exchange.Child
   147  	}
   148  
   149  	switch node := values.(type) {
   150  	case *plan.Values:
   151  		for _, exprTuple := range node.ExpressionTuples {
   152  			if len(exprTuple) != len(columnNames) {
   153  				return sql.ErrInsertIntoMismatchValueCount.New()
   154  			}
   155  		}
   156  	case *plan.LoadData:
   157  		dataColLen := len(node.ColumnNames)
   158  		if dataColLen == 0 {
   159  			dataColLen = len(node.Schema())
   160  		}
   161  		if len(columnNames) != dataColLen {
   162  			return sql.ErrInsertIntoMismatchValueCount.New()
   163  		}
   164  	default:
   165  		// Parser assures us that this will be some form of SelectStatement, so no need to type check it
   166  		if len(columnNames) != len(values.Schema()) {
   167  			return sql.ErrInsertIntoMismatchValueCount.New()
   168  		}
   169  	}
   170  	return nil
   171  }