github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/libraries/doltcore/sqle/expreval/expression_evaluator.go (about) 1 // Copyright 2020 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package expreval 16 17 import ( 18 "context" 19 20 "github.com/dolthub/go-mysql-server/sql" 21 "github.com/dolthub/go-mysql-server/sql/expression" 22 gmstypes "github.com/dolthub/go-mysql-server/sql/types" 23 "gopkg.in/src-d/go-errors.v1" 24 25 "github.com/dolthub/dolt/go/libraries/doltcore/schema" 26 "github.com/dolthub/dolt/go/store/types" 27 ) 28 29 var errUnsupportedComparisonType = errors.NewKind("Unsupported Comparison Type.") 30 var errUnknownColumn = errors.NewKind("Column %s not found.") 31 var errInvalidConversion = errors.NewKind("Could not convert %s from %s to %s.") 32 var errNotImplemented = errors.NewKind("Not Implemented: %s") 33 34 // ExpressionFunc is a function that takes a map of tag to value and returns whether some set of criteria are true for 35 // the set of values 36 type ExpressionFunc func(ctx context.Context, vals map[uint64]types.Value) (bool, error) 37 38 // ExpressionFuncFromSQLExpressions returns an ExpressionFunc which represents the slice of sql.Expressions passed in 39 func ExpressionFuncFromSQLExpressions(vr types.ValueReader, sch schema.Schema, expressions []sql.Expression) (ExpressionFunc, error) { 40 var root ExpressionFunc 41 for _, exp := range expressions { 42 expFunc, err := getExpFunc(vr, sch, exp) 43 44 if err != nil { 45 return nil, err 46 } 47 48 if root == nil { 49 root = expFunc 50 } else { 51 root = newAndFunc(root, expFunc) 52 } 53 } 54 55 if root == nil { 56 root = func(ctx context.Context, vals map[uint64]types.Value) (bool, error) { 57 return true, nil 58 } 59 } 60 61 return root, nil 62 } 63 64 func getExpFunc(vr types.ValueReader, sch schema.Schema, exp sql.Expression) (ExpressionFunc, error) { 65 switch typedExpr := exp.(type) { 66 case *expression.Equals: 67 return newComparisonFunc(EqualsOp{}, typedExpr, sch) 68 case *expression.GreaterThan: 69 return newComparisonFunc(GreaterOp{vr}, typedExpr, sch) 70 case *expression.GreaterThanOrEqual: 71 return newComparisonFunc(GreaterEqualOp{vr}, typedExpr, sch) 72 case *expression.LessThan: 73 return newComparisonFunc(LessOp{vr}, typedExpr, sch) 74 case *expression.LessThanOrEqual: 75 return newComparisonFunc(LessEqualOp{vr}, typedExpr, sch) 76 case *expression.Or: 77 leftFunc, err := getExpFunc(vr, sch, typedExpr.Left()) 78 79 if err != nil { 80 return nil, err 81 } 82 83 rightFunc, err := getExpFunc(vr, sch, typedExpr.Right()) 84 85 if err != nil { 86 return nil, err 87 } 88 89 return newOrFunc(leftFunc, rightFunc), nil 90 case *expression.And: 91 leftFunc, err := getExpFunc(vr, sch, typedExpr.Left()) 92 93 if err != nil { 94 return nil, err 95 } 96 97 rightFunc, err := getExpFunc(vr, sch, typedExpr.Right()) 98 99 if err != nil { 100 return nil, err 101 } 102 103 return newAndFunc(leftFunc, rightFunc), nil 104 case *expression.InTuple: 105 return newComparisonFunc(EqualsOp{}, typedExpr, sch) 106 case *expression.Not: 107 expFunc, err := getExpFunc(vr, sch, typedExpr.Child) 108 if err != nil { 109 return nil, err 110 } 111 return newNotFunc(expFunc), nil 112 case *expression.IsNull: 113 return newComparisonFunc(EqualsOp{}, expression.NewNullSafeEquals(typedExpr.Child, expression.NewLiteral(nil, gmstypes.Null)), sch) 114 } 115 116 return nil, errNotImplemented.New(exp.Type().String()) 117 } 118 119 func newOrFunc(left ExpressionFunc, right ExpressionFunc) ExpressionFunc { 120 return func(ctx context.Context, vals map[uint64]types.Value) (b bool, err error) { 121 lRes, err := left(ctx, vals) 122 123 if err != nil { 124 return false, err 125 } 126 127 if lRes { 128 return true, nil 129 } 130 131 return right(ctx, vals) 132 } 133 } 134 135 func newAndFunc(left ExpressionFunc, right ExpressionFunc) ExpressionFunc { 136 return func(ctx context.Context, vals map[uint64]types.Value) (b bool, err error) { 137 lRes, err := left(ctx, vals) 138 139 if err != nil { 140 return false, err 141 } 142 143 if !lRes { 144 return false, nil 145 } 146 147 return right(ctx, vals) 148 } 149 } 150 151 func newNotFunc(exp ExpressionFunc) ExpressionFunc { 152 return func(ctx context.Context, vals map[uint64]types.Value) (b bool, err error) { 153 res, err := exp(ctx, vals) 154 if err != nil { 155 return false, err 156 } 157 158 return !res, nil 159 } 160 } 161 162 type ComparisonType int 163 164 const ( 165 InvalidCompare ComparisonType = iota 166 VariableConstCompare 167 VariableVariableCompare 168 VariableInLiteralList 169 ConstConstCompare 170 ) 171 172 // GetComparisonType looks at a go-mysql-server BinaryExpression classifies the left and right arguments 173 // as variables or constants. 174 func GetComparisonType(be expression.BinaryExpression) ([]*expression.GetField, []*expression.Literal, ComparisonType, error) { 175 var variables []*expression.GetField 176 var consts []*expression.Literal 177 178 for _, curr := range []sql.Expression{be.Left(), be.Right()} { 179 // need to remove this and handle properly 180 if conv, ok := curr.(*expression.Convert); ok { 181 curr = conv.Child 182 } 183 184 switch v := curr.(type) { 185 case *expression.GetField: 186 variables = append(variables, v) 187 case *expression.Literal: 188 consts = append(consts, v) 189 case expression.Tuple: 190 children := v.Children() 191 for _, currChild := range children { 192 lit, ok := currChild.(*expression.Literal) 193 if !ok { 194 return nil, nil, InvalidCompare, errUnsupportedComparisonType.New() 195 } 196 consts = append(consts, lit) 197 } 198 default: 199 return nil, nil, InvalidCompare, errUnsupportedComparisonType.New() 200 } 201 } 202 203 var compType ComparisonType 204 if len(variables) == 2 { 205 compType = VariableVariableCompare 206 } else if len(variables) == 1 { 207 if len(consts) == 1 { 208 compType = VariableConstCompare 209 } else if len(consts) > 1 { 210 compType = VariableInLiteralList 211 } 212 } else if len(consts) == 2 { 213 compType = ConstConstCompare 214 } 215 216 return variables, consts, compType, nil 217 } 218 219 var trueFunc = func(ctx context.Context, vals map[uint64]types.Value) (b bool, err error) { return true, nil } 220 var falseFunc = func(ctx context.Context, vals map[uint64]types.Value) (b bool, err error) { return false, nil } 221 222 func newComparisonFunc(op CompareOp, exp expression.BinaryExpression, sch schema.Schema) (ExpressionFunc, error) { 223 vars, consts, compType, err := GetComparisonType(exp) 224 225 if err != nil { 226 return nil, err 227 } 228 229 if compType == ConstConstCompare { 230 res, err := op.CompareLiterals(consts[0], consts[1]) 231 232 if err != nil { 233 return nil, err 234 } 235 236 if res { 237 return trueFunc, nil 238 } else { 239 return falseFunc, nil 240 } 241 } else if compType == VariableConstCompare { 242 colName := vars[0].Name() 243 col, ok := sch.GetAllCols().GetByNameCaseInsensitive(colName) 244 245 if !ok { 246 return nil, errUnknownColumn.New(colName) 247 } 248 249 tag := col.Tag 250 nomsVal, err := LiteralToNomsValue(col.Kind, consts[0]) 251 252 if err != nil { 253 return nil, err 254 } 255 256 compareNomsValues := op.CompareNomsValues 257 compareToNil := op.CompareToNil 258 259 return func(ctx context.Context, vals map[uint64]types.Value) (b bool, err error) { 260 colVal, ok := vals[tag] 261 262 if ok && !types.IsNull(colVal) { 263 return compareNomsValues(ctx, colVal, nomsVal) 264 } else { 265 return compareToNil(nomsVal) 266 } 267 }, nil 268 } else if compType == VariableVariableCompare { 269 col1Name := vars[0].Name() 270 col1, ok := sch.GetAllCols().GetByNameCaseInsensitive(col1Name) 271 272 if !ok { 273 return nil, errUnknownColumn.New(col1Name) 274 } 275 276 col2Name := vars[1].Name() 277 col2, ok := sch.GetAllCols().GetByNameCaseInsensitive(col2Name) 278 279 if !ok { 280 return nil, errUnknownColumn.New(col2Name) 281 } 282 283 compareNomsValues := op.CompareNomsValues 284 compareToNull := op.CompareToNil 285 286 tag1, tag2 := col1.Tag, col2.Tag 287 return func(ctx context.Context, vals map[uint64]types.Value) (b bool, err error) { 288 v1 := vals[tag1] 289 v2 := vals[tag2] 290 291 if types.IsNull(v1) { 292 return compareToNull(v2) 293 } else { 294 return compareNomsValues(ctx, v1, v2) 295 } 296 }, nil 297 } else if compType == VariableInLiteralList { 298 colName := vars[0].Name() 299 col, ok := sch.GetAllCols().GetByNameCaseInsensitive(colName) 300 301 if !ok { 302 return nil, errUnknownColumn.New(colName) 303 } 304 305 tag := col.Tag 306 307 // Get all the noms values 308 nomsVals := make([]types.Value, len(consts)) 309 for i, c := range consts { 310 nomsVal, err := LiteralToNomsValue(col.Kind, c) 311 if err != nil { 312 return nil, err 313 } 314 nomsVals[i] = nomsVal 315 } 316 317 compareNomsValues := op.CompareNomsValues 318 compareToNil := op.CompareToNil 319 320 return func(ctx context.Context, vals map[uint64]types.Value) (b bool, err error) { 321 colVal, ok := vals[tag] 322 323 for _, nv := range nomsVals { 324 var lb bool 325 if ok && !types.IsNull(colVal) { 326 lb, err = compareNomsValues(ctx, colVal, nv) 327 } else { 328 lb, err = compareToNil(nv) 329 } 330 331 if err != nil { 332 return false, err 333 } 334 if lb { 335 return true, nil 336 } 337 } 338 339 return false, nil 340 }, nil 341 } else { 342 return nil, errUnsupportedComparisonType.New() 343 } 344 }