github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/in.go (about) 1 // Copyright 2020-2021 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package expression 16 17 import ( 18 "fmt" 19 "strconv" 20 21 "github.com/dolthub/go-mysql-server/sql" 22 "github.com/dolthub/go-mysql-server/sql/types" 23 ) 24 25 // InTuple is an expression that checks an expression is inside a list of expressions. 26 type InTuple struct { 27 BinaryExpressionStub 28 } 29 30 // We implement Comparer because we have a Left() and a Right(), but we can't be Compare()d 31 var _ Comparer = (*InTuple)(nil) 32 var _ sql.CollationCoercible = (*InTuple)(nil) 33 34 func (in *InTuple) Compare(ctx *sql.Context, row sql.Row) (int, error) { 35 panic("Compare not implemented for InTuple") 36 } 37 38 func (in *InTuple) Type() sql.Type { 39 return types.Boolean 40 } 41 42 // CollationCoercibility implements the interface sql.CollationCoercible. 43 func (*InTuple) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { 44 return sql.Collation_binary, 5 45 } 46 47 func (in *InTuple) Left() sql.Expression { 48 return in.BinaryExpressionStub.LeftChild 49 } 50 51 func (in *InTuple) Right() sql.Expression { 52 return in.BinaryExpressionStub.RightChild 53 } 54 55 // NewInTuple creates an InTuple expression. 56 func NewInTuple(left sql.Expression, right sql.Expression) *InTuple { 57 disableRounding(left) 58 disableRounding(right) 59 return &InTuple{BinaryExpressionStub{left, right}} 60 } 61 62 // Eval implements the Expression interface. 63 func (in *InTuple) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { 64 typ := in.Left().Type().Promote() 65 leftElems := types.NumColumns(typ) 66 originalLeft, err := in.Left().Eval(ctx, row) 67 if err != nil { 68 return nil, err 69 } 70 71 if originalLeft == nil { 72 return nil, nil 73 } 74 75 // The NULL handling for IN expressions is tricky. According to 76 // https://dev.mysql.com/doc/refman/8.0/en/comparison-operators.html#operator_in: 77 // To comply with the SQL standard, IN() returns NULL not only if the expression on the left hand side is NULL, but 78 // also if no match is found in the list and one of the expressions in the list is NULL. 79 rightNull := false 80 81 left, _, err := typ.Convert(originalLeft) 82 if err != nil { 83 return nil, err 84 } 85 86 switch right := in.Right().(type) { 87 case Tuple: 88 for _, el := range right { 89 if types.NumColumns(el.Type()) != leftElems { 90 return nil, sql.ErrInvalidOperandColumns.New(leftElems, types.NumColumns(el.Type())) 91 } 92 } 93 94 for _, el := range right { 95 originalRight, err := el.Eval(ctx, row) 96 if err != nil { 97 return nil, err 98 } 99 100 if !rightNull && originalRight == nil { 101 rightNull = true 102 continue 103 } 104 105 var cmp int 106 elType := el.Type() 107 if types.IsDecimal(elType) || types.IsFloat(elType) { 108 rtyp := el.Type().Promote() 109 left, err := convertOrTruncate(ctx, left, rtyp) 110 if err != nil { 111 return nil, err 112 } 113 right, err := convertOrTruncate(ctx, originalRight, rtyp) 114 if err != nil { 115 return nil, err 116 } 117 cmp, err = rtyp.Compare(left, right) 118 if err != nil { 119 return nil, err 120 } 121 } else { 122 right, err := convertOrTruncate(ctx, originalRight, typ) 123 if err != nil { 124 return nil, err 125 } 126 cmp, err = typ.Compare(left, right) 127 if err != nil { 128 return nil, err 129 } 130 } 131 132 if cmp == 0 { 133 return true, nil 134 } 135 } 136 137 if rightNull { 138 return nil, nil 139 } 140 141 return false, nil 142 default: 143 return nil, ErrUnsupportedInOperand.New(right) 144 } 145 } 146 147 // WithChildren implements the Expression interface. 148 func (in *InTuple) WithChildren(children ...sql.Expression) (sql.Expression, error) { 149 if len(children) != 2 { 150 return nil, sql.ErrInvalidChildrenNumber.New(in, len(children), 2) 151 } 152 return NewInTuple(children[0], children[1]), nil 153 } 154 155 func (in *InTuple) String() string { 156 // scalar expression must round-trip 157 return fmt.Sprintf("(%s IN %s)", in.Left(), in.Right()) 158 } 159 160 func (in *InTuple) DebugString() string { 161 pr := sql.NewTreePrinter() 162 _ = pr.WriteNode("IN") 163 children := []string{fmt.Sprintf("left: %s", sql.DebugString(in.Left())), fmt.Sprintf("right: %s", sql.DebugString(in.Right()))} 164 _ = pr.WriteChildren(children...) 165 return pr.String() 166 } 167 168 // Children implements the Expression interface. 169 func (in *InTuple) Children() []sql.Expression { 170 return []sql.Expression{in.Left(), in.Right()} 171 } 172 173 // NewNotInTuple creates a new NotInTuple expression. 174 func NewNotInTuple(left sql.Expression, right sql.Expression) sql.Expression { 175 return NewNot(NewInTuple(left, right)) 176 } 177 178 // HashInTuple is an expression that checks an expression is inside a list of expressions using a hashmap. 179 type HashInTuple struct { 180 in *InTuple 181 cmp map[uint64]sql.Expression 182 hasNull bool 183 } 184 185 var _ Comparer = (*HashInTuple)(nil) 186 var _ sql.CollationCoercible = (*HashInTuple)(nil) 187 var _ sql.Expression = (*HashInTuple)(nil) 188 189 // NewHashInTuple creates an InTuple expression. 190 func NewHashInTuple(ctx *sql.Context, left, right sql.Expression) (*HashInTuple, error) { 191 rightTup, ok := right.(Tuple) 192 if !ok { 193 return nil, ErrUnsupportedInOperand.New(right) 194 } 195 196 cmp, hasNull, err := newInMap(ctx, rightTup, left.Type()) 197 if err != nil { 198 return nil, err 199 } 200 201 return &HashInTuple{in: NewInTuple(left, right), cmp: cmp, hasNull: hasNull}, nil 202 } 203 204 // newInMap hashes static expressions in the right child Tuple of a InTuple node 205 func newInMap(ctx *sql.Context, right Tuple, lType sql.Type) (map[uint64]sql.Expression, bool, error) { 206 if lType == types.Null { 207 return nil, true, nil 208 } 209 210 elements := make(map[uint64]sql.Expression) 211 hasNull := false 212 lColumnCount := types.NumColumns(lType) 213 214 for _, el := range right { 215 rType := el.Type().Promote() 216 rColumnCount := types.NumColumns(rType) 217 if rColumnCount != lColumnCount { 218 return nil, false, sql.ErrInvalidOperandColumns.New(lColumnCount, rColumnCount) 219 } 220 221 if rType == types.Null { 222 hasNull = true 223 continue 224 } 225 i, err := el.Eval(ctx, sql.Row{}) 226 if err != nil { 227 return nil, hasNull, err 228 } 229 if i == nil { 230 hasNull = true 231 continue 232 } 233 234 var key uint64 235 if types.IsDecimal(rType) || types.IsFloat(rType) { 236 key, err = hashOfSimple(ctx, i, rType) 237 } else { 238 key, err = hashOfSimple(ctx, i, lType) 239 } 240 if err != nil { 241 return nil, false, err 242 } 243 elements[key] = el 244 } 245 246 return elements, hasNull, nil 247 } 248 249 func hashOfSimple(ctx *sql.Context, i interface{}, t sql.Type) (uint64, error) { 250 if i == nil { 251 return 0, nil 252 } 253 254 var str string 255 coll := sql.Collation_Default 256 if types.IsTextOnly(t) { 257 coll = t.(sql.StringType).Collation() 258 if s, ok := i.(string); ok { 259 str = s 260 } else { 261 converted, err := convertOrTruncate(ctx, i, t) 262 if err != nil { 263 return 0, err 264 } 265 str = converted.(string) 266 } 267 } else { 268 x, err := convertOrTruncate(ctx, i, t.Promote()) 269 if err != nil { 270 return 0, err 271 } 272 273 // Remove trailing 0s from floats 274 switch v := x.(type) { 275 case float32: 276 str = strconv.FormatFloat(float64(v), 'f', -1, 32) 277 if str == "-0" { 278 str = "0" 279 } 280 case float64: 281 str = strconv.FormatFloat(v, 'f', -1, 64) 282 if str == "-0" { 283 str = "0" 284 } 285 default: 286 str = fmt.Sprintf("%v", v) 287 } 288 } 289 290 // Collated strings that are equivalent may have different runes, so we must make them hash to the same value 291 return coll.HashToUint(str) 292 } 293 294 // Eval implements the Expression interface. 295 func (hit *HashInTuple) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { 296 leftElems := types.NumColumns(hit.in.Left().Type().Promote()) 297 298 leftVal, err := hit.in.Left().Eval(ctx, row) 299 if err != nil { 300 return nil, err 301 } 302 303 if leftVal == nil { 304 return nil, nil 305 } 306 307 key, err := hashOfSimple(ctx, leftVal, hit.in.Left().Type()) 308 if err != nil { 309 return nil, err 310 } 311 312 right, ok := hit.cmp[key] 313 if !ok { 314 if hit.hasNull { 315 return nil, nil 316 } 317 return false, nil 318 } 319 320 if types.NumColumns(right.Type().Promote()) != leftElems { 321 return nil, sql.ErrInvalidOperandColumns.New(leftElems, types.NumColumns(right.Type().Promote())) 322 } 323 324 return true, nil 325 } 326 327 // convertOrTruncate converts the value |i| to type |t| and returns the converted value; if the value does not convert 328 // cleanly and the type is automatically coerced (i.e. string and numeric types), then a warning is logged and the 329 // value is truncated to the Zero value for type |t|. If the value does not convert and the type is not automatically 330 // coerced, then an error is returned. 331 func convertOrTruncate(ctx *sql.Context, i interface{}, t sql.Type) (interface{}, error) { 332 converted, _, err := t.Convert(i) 333 if err == nil { 334 return converted, nil 335 } 336 337 // If a value can't be converted to an enum or set type, truncate it to a value that is guaranteed 338 // to not match any enum value. 339 if types.IsEnum(t) || types.IsSet(t) { 340 return nil, nil 341 } 342 343 // Values for numeric and string types are automatically coerced. For all other types, if they 344 // don't convert cleanly, it's an error. 345 if err != nil && !(types.IsNumber(t) || types.IsTextOnly(t)) { 346 return nil, err 347 } 348 349 // For numeric and string types, if the value can't be cleanly converted, truncate to the zero value for 350 // the type and log a warning in the session. 351 warning := sql.Warning{ 352 Level: "Warning", 353 Message: fmt.Sprintf("Truncated incorrect %s value: %v", t.String(), i), 354 Code: 1292, 355 } 356 357 if ctx != nil && ctx.Session != nil { 358 ctx.Session.Warn(&warning) 359 } 360 361 return t.Zero(), nil 362 } 363 364 func (hit *HashInTuple) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { 365 return hit.in.CollationCoercibility(ctx) 366 } 367 368 func (hit *HashInTuple) Resolved() bool { 369 return hit.in.Resolved() 370 } 371 372 func (hit *HashInTuple) Type() sql.Type { 373 return hit.in.Type() 374 } 375 376 func (hit *HashInTuple) IsNullable() bool { 377 return hit.in.IsNullable() 378 } 379 380 func (hit *HashInTuple) Children() []sql.Expression { 381 return hit.in.Children() 382 } 383 384 func (hit *HashInTuple) WithChildren(children ...sql.Expression) (sql.Expression, error) { 385 if len(children) != 2 { 386 return nil, sql.ErrInvalidChildrenNumber.New(hit, len(children), 2) 387 } 388 ret := *hit 389 newIn, err := ret.in.WithChildren(children...) 390 ret.in = newIn.(*InTuple) 391 return &ret, err 392 } 393 394 func (hit *HashInTuple) Compare(ctx *sql.Context, row sql.Row) (int, error) { 395 return hit.in.Compare(ctx, row) 396 } 397 398 func (hit *HashInTuple) Left() sql.Expression { 399 return hit.in.Left() 400 } 401 402 func (hit *HashInTuple) Right() sql.Expression { 403 return hit.in.Right() 404 } 405 406 func (hit *HashInTuple) String() string { 407 return fmt.Sprintf("(%s HASH IN %s)", hit.in.Left(), hit.in.Right()) 408 } 409 410 func (hit *HashInTuple) DebugString() string { 411 pr := sql.NewTreePrinter() 412 _ = pr.WriteNode("HashIn") 413 children := []string{sql.DebugString(hit.in.Left()), sql.DebugString(hit.in.Right())} 414 _ = pr.WriteChildren(children...) 415 return pr.String() 416 }