github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/in.go (about)

     1  // Copyright 2020-2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package expression
    16  
    17  import (
    18  	"fmt"
    19  	"strconv"
    20  
    21  	"github.com/dolthub/go-mysql-server/sql"
    22  	"github.com/dolthub/go-mysql-server/sql/types"
    23  )
    24  
    25  // InTuple is an expression that checks an expression is inside a list of expressions.
    26  type InTuple struct {
    27  	BinaryExpressionStub
    28  }
    29  
    30  // We implement Comparer because we have a Left() and a Right(), but we can't be Compare()d
    31  var _ Comparer = (*InTuple)(nil)
    32  var _ sql.CollationCoercible = (*InTuple)(nil)
    33  
    34  func (in *InTuple) Compare(ctx *sql.Context, row sql.Row) (int, error) {
    35  	panic("Compare not implemented for InTuple")
    36  }
    37  
    38  func (in *InTuple) Type() sql.Type {
    39  	return types.Boolean
    40  }
    41  
    42  // CollationCoercibility implements the interface sql.CollationCoercible.
    43  func (*InTuple) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
    44  	return sql.Collation_binary, 5
    45  }
    46  
    47  func (in *InTuple) Left() sql.Expression {
    48  	return in.BinaryExpressionStub.LeftChild
    49  }
    50  
    51  func (in *InTuple) Right() sql.Expression {
    52  	return in.BinaryExpressionStub.RightChild
    53  }
    54  
    55  // NewInTuple creates an InTuple expression.
    56  func NewInTuple(left sql.Expression, right sql.Expression) *InTuple {
    57  	disableRounding(left)
    58  	disableRounding(right)
    59  	return &InTuple{BinaryExpressionStub{left, right}}
    60  }
    61  
    62  // Eval implements the Expression interface.
    63  func (in *InTuple) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
    64  	typ := in.Left().Type().Promote()
    65  	leftElems := types.NumColumns(typ)
    66  	originalLeft, err := in.Left().Eval(ctx, row)
    67  	if err != nil {
    68  		return nil, err
    69  	}
    70  
    71  	if originalLeft == nil {
    72  		return nil, nil
    73  	}
    74  
    75  	// The NULL handling for IN expressions is tricky. According to
    76  	// https://dev.mysql.com/doc/refman/8.0/en/comparison-operators.html#operator_in:
    77  	// To comply with the SQL standard, IN() returns NULL not only if the expression on the left hand side is NULL, but
    78  	// also if no match is found in the list and one of the expressions in the list is NULL.
    79  	rightNull := false
    80  
    81  	left, _, err := typ.Convert(originalLeft)
    82  	if err != nil {
    83  		return nil, err
    84  	}
    85  
    86  	switch right := in.Right().(type) {
    87  	case Tuple:
    88  		for _, el := range right {
    89  			if types.NumColumns(el.Type()) != leftElems {
    90  				return nil, sql.ErrInvalidOperandColumns.New(leftElems, types.NumColumns(el.Type()))
    91  			}
    92  		}
    93  
    94  		for _, el := range right {
    95  			originalRight, err := el.Eval(ctx, row)
    96  			if err != nil {
    97  				return nil, err
    98  			}
    99  
   100  			if !rightNull && originalRight == nil {
   101  				rightNull = true
   102  				continue
   103  			}
   104  
   105  			var cmp int
   106  			elType := el.Type()
   107  			if types.IsDecimal(elType) || types.IsFloat(elType) {
   108  				rtyp := el.Type().Promote()
   109  				left, err := convertOrTruncate(ctx, left, rtyp)
   110  				if err != nil {
   111  					return nil, err
   112  				}
   113  				right, err := convertOrTruncate(ctx, originalRight, rtyp)
   114  				if err != nil {
   115  					return nil, err
   116  				}
   117  				cmp, err = rtyp.Compare(left, right)
   118  				if err != nil {
   119  					return nil, err
   120  				}
   121  			} else {
   122  				right, err := convertOrTruncate(ctx, originalRight, typ)
   123  				if err != nil {
   124  					return nil, err
   125  				}
   126  				cmp, err = typ.Compare(left, right)
   127  				if err != nil {
   128  					return nil, err
   129  				}
   130  			}
   131  
   132  			if cmp == 0 {
   133  				return true, nil
   134  			}
   135  		}
   136  
   137  		if rightNull {
   138  			return nil, nil
   139  		}
   140  
   141  		return false, nil
   142  	default:
   143  		return nil, ErrUnsupportedInOperand.New(right)
   144  	}
   145  }
   146  
   147  // WithChildren implements the Expression interface.
   148  func (in *InTuple) WithChildren(children ...sql.Expression) (sql.Expression, error) {
   149  	if len(children) != 2 {
   150  		return nil, sql.ErrInvalidChildrenNumber.New(in, len(children), 2)
   151  	}
   152  	return NewInTuple(children[0], children[1]), nil
   153  }
   154  
   155  func (in *InTuple) String() string {
   156  	// scalar expression must round-trip
   157  	return fmt.Sprintf("(%s IN %s)", in.Left(), in.Right())
   158  }
   159  
   160  func (in *InTuple) DebugString() string {
   161  	pr := sql.NewTreePrinter()
   162  	_ = pr.WriteNode("IN")
   163  	children := []string{fmt.Sprintf("left: %s", sql.DebugString(in.Left())), fmt.Sprintf("right: %s", sql.DebugString(in.Right()))}
   164  	_ = pr.WriteChildren(children...)
   165  	return pr.String()
   166  }
   167  
   168  // Children implements the Expression interface.
   169  func (in *InTuple) Children() []sql.Expression {
   170  	return []sql.Expression{in.Left(), in.Right()}
   171  }
   172  
   173  // NewNotInTuple creates a new NotInTuple expression.
   174  func NewNotInTuple(left sql.Expression, right sql.Expression) sql.Expression {
   175  	return NewNot(NewInTuple(left, right))
   176  }
   177  
   178  // HashInTuple is an expression that checks an expression is inside a list of expressions using a hashmap.
   179  type HashInTuple struct {
   180  	in      *InTuple
   181  	cmp     map[uint64]sql.Expression
   182  	hasNull bool
   183  }
   184  
   185  var _ Comparer = (*HashInTuple)(nil)
   186  var _ sql.CollationCoercible = (*HashInTuple)(nil)
   187  var _ sql.Expression = (*HashInTuple)(nil)
   188  
   189  // NewHashInTuple creates an InTuple expression.
   190  func NewHashInTuple(ctx *sql.Context, left, right sql.Expression) (*HashInTuple, error) {
   191  	rightTup, ok := right.(Tuple)
   192  	if !ok {
   193  		return nil, ErrUnsupportedInOperand.New(right)
   194  	}
   195  
   196  	cmp, hasNull, err := newInMap(ctx, rightTup, left.Type())
   197  	if err != nil {
   198  		return nil, err
   199  	}
   200  
   201  	return &HashInTuple{in: NewInTuple(left, right), cmp: cmp, hasNull: hasNull}, nil
   202  }
   203  
   204  // newInMap hashes static expressions in the right child Tuple of a InTuple node
   205  func newInMap(ctx *sql.Context, right Tuple, lType sql.Type) (map[uint64]sql.Expression, bool, error) {
   206  	if lType == types.Null {
   207  		return nil, true, nil
   208  	}
   209  
   210  	elements := make(map[uint64]sql.Expression)
   211  	hasNull := false
   212  	lColumnCount := types.NumColumns(lType)
   213  
   214  	for _, el := range right {
   215  		rType := el.Type().Promote()
   216  		rColumnCount := types.NumColumns(rType)
   217  		if rColumnCount != lColumnCount {
   218  			return nil, false, sql.ErrInvalidOperandColumns.New(lColumnCount, rColumnCount)
   219  		}
   220  
   221  		if rType == types.Null {
   222  			hasNull = true
   223  			continue
   224  		}
   225  		i, err := el.Eval(ctx, sql.Row{})
   226  		if err != nil {
   227  			return nil, hasNull, err
   228  		}
   229  		if i == nil {
   230  			hasNull = true
   231  			continue
   232  		}
   233  
   234  		var key uint64
   235  		if types.IsDecimal(rType) || types.IsFloat(rType) {
   236  			key, err = hashOfSimple(ctx, i, rType)
   237  		} else {
   238  			key, err = hashOfSimple(ctx, i, lType)
   239  		}
   240  		if err != nil {
   241  			return nil, false, err
   242  		}
   243  		elements[key] = el
   244  	}
   245  
   246  	return elements, hasNull, nil
   247  }
   248  
   249  func hashOfSimple(ctx *sql.Context, i interface{}, t sql.Type) (uint64, error) {
   250  	if i == nil {
   251  		return 0, nil
   252  	}
   253  
   254  	var str string
   255  	coll := sql.Collation_Default
   256  	if types.IsTextOnly(t) {
   257  		coll = t.(sql.StringType).Collation()
   258  		if s, ok := i.(string); ok {
   259  			str = s
   260  		} else {
   261  			converted, err := convertOrTruncate(ctx, i, t)
   262  			if err != nil {
   263  				return 0, err
   264  			}
   265  			str = converted.(string)
   266  		}
   267  	} else {
   268  		x, err := convertOrTruncate(ctx, i, t.Promote())
   269  		if err != nil {
   270  			return 0, err
   271  		}
   272  
   273  		// Remove trailing 0s from floats
   274  		switch v := x.(type) {
   275  		case float32:
   276  			str = strconv.FormatFloat(float64(v), 'f', -1, 32)
   277  			if str == "-0" {
   278  				str = "0"
   279  			}
   280  		case float64:
   281  			str = strconv.FormatFloat(v, 'f', -1, 64)
   282  			if str == "-0" {
   283  				str = "0"
   284  			}
   285  		default:
   286  			str = fmt.Sprintf("%v", v)
   287  		}
   288  	}
   289  
   290  	// Collated strings that are equivalent may have different runes, so we must make them hash to the same value
   291  	return coll.HashToUint(str)
   292  }
   293  
   294  // Eval implements the Expression interface.
   295  func (hit *HashInTuple) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
   296  	leftElems := types.NumColumns(hit.in.Left().Type().Promote())
   297  
   298  	leftVal, err := hit.in.Left().Eval(ctx, row)
   299  	if err != nil {
   300  		return nil, err
   301  	}
   302  
   303  	if leftVal == nil {
   304  		return nil, nil
   305  	}
   306  
   307  	key, err := hashOfSimple(ctx, leftVal, hit.in.Left().Type())
   308  	if err != nil {
   309  		return nil, err
   310  	}
   311  
   312  	right, ok := hit.cmp[key]
   313  	if !ok {
   314  		if hit.hasNull {
   315  			return nil, nil
   316  		}
   317  		return false, nil
   318  	}
   319  
   320  	if types.NumColumns(right.Type().Promote()) != leftElems {
   321  		return nil, sql.ErrInvalidOperandColumns.New(leftElems, types.NumColumns(right.Type().Promote()))
   322  	}
   323  
   324  	return true, nil
   325  }
   326  
   327  // convertOrTruncate converts the value |i| to type |t| and returns the converted value; if the value does not convert
   328  // cleanly and the type is automatically coerced (i.e. string and numeric types), then a warning is logged and the
   329  // value is truncated to the Zero value for type |t|. If the value does not convert and the type is not automatically
   330  // coerced, then an error is returned.
   331  func convertOrTruncate(ctx *sql.Context, i interface{}, t sql.Type) (interface{}, error) {
   332  	converted, _, err := t.Convert(i)
   333  	if err == nil {
   334  		return converted, nil
   335  	}
   336  
   337  	// If a value can't be converted to an enum or set type, truncate it to a value that is guaranteed
   338  	// to not match any enum value.
   339  	if types.IsEnum(t) || types.IsSet(t) {
   340  		return nil, nil
   341  	}
   342  
   343  	// Values for numeric and string types are automatically coerced. For all other types, if they
   344  	// don't convert cleanly, it's an error.
   345  	if err != nil && !(types.IsNumber(t) || types.IsTextOnly(t)) {
   346  		return nil, err
   347  	}
   348  
   349  	// For numeric and string types, if the value can't be cleanly converted, truncate to the zero value for
   350  	// the type and log a warning in the session.
   351  	warning := sql.Warning{
   352  		Level:   "Warning",
   353  		Message: fmt.Sprintf("Truncated incorrect %s value: %v", t.String(), i),
   354  		Code:    1292,
   355  	}
   356  
   357  	if ctx != nil && ctx.Session != nil {
   358  		ctx.Session.Warn(&warning)
   359  	}
   360  
   361  	return t.Zero(), nil
   362  }
   363  
   364  func (hit *HashInTuple) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
   365  	return hit.in.CollationCoercibility(ctx)
   366  }
   367  
   368  func (hit *HashInTuple) Resolved() bool {
   369  	return hit.in.Resolved()
   370  }
   371  
   372  func (hit *HashInTuple) Type() sql.Type {
   373  	return hit.in.Type()
   374  }
   375  
   376  func (hit *HashInTuple) IsNullable() bool {
   377  	return hit.in.IsNullable()
   378  }
   379  
   380  func (hit *HashInTuple) Children() []sql.Expression {
   381  	return hit.in.Children()
   382  }
   383  
   384  func (hit *HashInTuple) WithChildren(children ...sql.Expression) (sql.Expression, error) {
   385  	if len(children) != 2 {
   386  		return nil, sql.ErrInvalidChildrenNumber.New(hit, len(children), 2)
   387  	}
   388  	ret := *hit
   389  	newIn, err := ret.in.WithChildren(children...)
   390  	ret.in = newIn.(*InTuple)
   391  	return &ret, err
   392  }
   393  
   394  func (hit *HashInTuple) Compare(ctx *sql.Context, row sql.Row) (int, error) {
   395  	return hit.in.Compare(ctx, row)
   396  }
   397  
   398  func (hit *HashInTuple) Left() sql.Expression {
   399  	return hit.in.Left()
   400  }
   401  
   402  func (hit *HashInTuple) Right() sql.Expression {
   403  	return hit.in.Right()
   404  }
   405  
   406  func (hit *HashInTuple) String() string {
   407  	return fmt.Sprintf("(%s HASH IN %s)", hit.in.Left(), hit.in.Right())
   408  }
   409  
   410  func (hit *HashInTuple) DebugString() string {
   411  	pr := sql.NewTreePrinter()
   412  	_ = pr.WriteNode("HashIn")
   413  	children := []string{sql.DebugString(hit.in.Left()), sql.DebugString(hit.in.Right())}
   414  	_ = pr.WriteChildren(children...)
   415  	return pr.String()
   416  }