github.com/hasnat/dolt/go@v0.0.0-20210628190320-9eb5d843fbb7/libraries/doltcore/sqle/expreval/expression_evaluator.go (about) 1 // Copyright 2020 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package expreval 16 17 import ( 18 "context" 19 20 "github.com/dolthub/go-mysql-server/sql" 21 "github.com/dolthub/go-mysql-server/sql/expression" 22 "gopkg.in/src-d/go-errors.v1" 23 24 "github.com/dolthub/dolt/go/libraries/doltcore/schema" 25 "github.com/dolthub/dolt/go/store/types" 26 ) 27 28 var errUnsupportedComparisonType = errors.NewKind("Unsupported Comparison Type.") 29 var errUnknownColumn = errors.NewKind("Column %s not found.") 30 var errInvalidConversion = errors.NewKind("Could not convert %s from %s to %s.") 31 var errNotImplemented = errors.NewKind("Not Implemented: %s") 32 33 // ExpressionFunc is a function that takes a map of tag to value and returns whether some set of criteria are true for 34 // the set of values 35 type ExpressionFunc func(ctx context.Context, vals map[uint64]types.Value) (bool, error) 36 37 // ExpressionFuncFromSQLExpressions returns an ExpressionFunc which represents the slice of sql.Expressions passed in 38 func ExpressionFuncFromSQLExpressions(nbf *types.NomsBinFormat, sch schema.Schema, expressions []sql.Expression) (ExpressionFunc, error) { 39 var root ExpressionFunc 40 for _, exp := range expressions { 41 expFunc, err := getExpFunc(nbf, sch, exp) 42 43 if err != nil { 44 return nil, err 45 } 46 47 if root == nil { 48 root = expFunc 49 } else { 50 root = newAndFunc(root, expFunc) 51 } 52 } 53 54 if root == nil { 55 root = func(ctx context.Context, vals map[uint64]types.Value) (bool, error) { 56 return true, nil 57 } 58 } 59 60 return root, nil 61 } 62 63 func getExpFunc(nbf *types.NomsBinFormat, sch schema.Schema, exp sql.Expression) (ExpressionFunc, error) { 64 switch typedExpr := exp.(type) { 65 case *expression.Equals: 66 return newComparisonFunc(EqualsOp{}, typedExpr.BinaryExpression, sch) 67 case *expression.GreaterThan: 68 return newComparisonFunc(GreaterOp{nbf}, typedExpr.BinaryExpression, sch) 69 case *expression.GreaterThanOrEqual: 70 return newComparisonFunc(GreaterEqualOp{nbf}, typedExpr.BinaryExpression, sch) 71 case *expression.LessThan: 72 return newComparisonFunc(LessOp{nbf}, typedExpr.BinaryExpression, sch) 73 case *expression.LessThanOrEqual: 74 return newComparisonFunc(LessEqualOp{nbf}, typedExpr.BinaryExpression, sch) 75 case *expression.Or: 76 leftFunc, err := getExpFunc(nbf, sch, typedExpr.Left) 77 78 if err != nil { 79 return nil, err 80 } 81 82 rightFunc, err := getExpFunc(nbf, sch, typedExpr.Right) 83 84 if err != nil { 85 return nil, err 86 } 87 88 return newOrFunc(leftFunc, rightFunc), nil 89 case *expression.And: 90 leftFunc, err := getExpFunc(nbf, sch, typedExpr.Left) 91 92 if err != nil { 93 return nil, err 94 } 95 96 rightFunc, err := getExpFunc(nbf, sch, typedExpr.Right) 97 98 if err != nil { 99 return nil, err 100 } 101 102 return newAndFunc(leftFunc, rightFunc), nil 103 case *expression.InTuple: 104 return newComparisonFunc(EqualsOp{}, typedExpr.BinaryExpression, sch) 105 } 106 107 return nil, errNotImplemented.New(exp.Type().String()) 108 } 109 110 func newOrFunc(left ExpressionFunc, right ExpressionFunc) ExpressionFunc { 111 return func(ctx context.Context, vals map[uint64]types.Value) (b bool, err error) { 112 lRes, err := left(ctx, vals) 113 114 if err != nil { 115 return false, err 116 } 117 118 if lRes { 119 return true, nil 120 } 121 122 return right(ctx, vals) 123 } 124 } 125 126 func newAndFunc(left ExpressionFunc, right ExpressionFunc) ExpressionFunc { 127 return func(ctx context.Context, vals map[uint64]types.Value) (b bool, err error) { 128 lRes, err := left(ctx, vals) 129 130 if err != nil { 131 return false, err 132 } 133 134 if !lRes { 135 return false, nil 136 } 137 138 return right(ctx, vals) 139 } 140 } 141 142 type ComparisonType int 143 144 const ( 145 InvalidCompare ComparisonType = iota 146 VariableConstCompare 147 VariableVariableCompare 148 VariableInLiteralList 149 ConstConstCompare 150 ) 151 152 // GetComparisonType looks at a go-mysql-server BinaryExpression classifies the left and right arguments 153 // as variables or constants. 154 func GetComparisonType(be expression.BinaryExpression) ([]*expression.GetField, []*expression.Literal, ComparisonType, error) { 155 var variables []*expression.GetField 156 var consts []*expression.Literal 157 158 for _, curr := range []sql.Expression{be.Left, be.Right} { 159 // need to remove this and handle properly 160 if conv, ok := curr.(*expression.Convert); ok { 161 curr = conv.Child 162 } 163 164 switch v := curr.(type) { 165 case *expression.GetField: 166 variables = append(variables, v) 167 case *expression.Literal: 168 consts = append(consts, v) 169 case expression.Tuple: 170 children := v.Children() 171 for _, currChild := range children { 172 lit, ok := currChild.(*expression.Literal) 173 if !ok { 174 return nil, nil, InvalidCompare, errUnsupportedComparisonType.New() 175 } 176 consts = append(consts, lit) 177 } 178 default: 179 return nil, nil, InvalidCompare, errUnsupportedComparisonType.New() 180 } 181 } 182 183 var compType ComparisonType 184 if len(variables) == 2 { 185 compType = VariableVariableCompare 186 } else if len(variables) == 1 { 187 if len(consts) == 1 { 188 compType = VariableConstCompare 189 } else if len(consts) > 1 { 190 compType = VariableInLiteralList 191 } 192 } else if len(consts) == 2 { 193 compType = ConstConstCompare 194 } 195 196 return variables, consts, compType, nil 197 } 198 199 var trueFunc = func(ctx context.Context, vals map[uint64]types.Value) (b bool, err error) { return true, nil } 200 var falseFunc = func(ctx context.Context, vals map[uint64]types.Value) (b bool, err error) { return false, nil } 201 202 func newComparisonFunc(op CompareOp, exp expression.BinaryExpression, sch schema.Schema) (ExpressionFunc, error) { 203 vars, consts, compType, err := GetComparisonType(exp) 204 205 if err != nil { 206 return nil, err 207 } 208 209 if compType == ConstConstCompare { 210 res, err := op.CompareLiterals(consts[0], consts[1]) 211 212 if err != nil { 213 return nil, err 214 } 215 216 if res { 217 return trueFunc, nil 218 } else { 219 return falseFunc, nil 220 } 221 } else if compType == VariableConstCompare { 222 colName := vars[0].Name() 223 col, ok := sch.GetAllCols().GetByNameCaseInsensitive(colName) 224 225 if !ok { 226 return nil, errUnknownColumn.New(colName) 227 } 228 229 tag := col.Tag 230 nomsVal, err := LiteralToNomsValue(col.Kind, consts[0]) 231 232 if err != nil { 233 return nil, err 234 } 235 236 compareNomsValues := op.CompareNomsValues 237 compareToNil := op.CompareToNil 238 239 return func(ctx context.Context, vals map[uint64]types.Value) (b bool, err error) { 240 colVal, ok := vals[tag] 241 242 if ok && !types.IsNull(colVal) { 243 return compareNomsValues(colVal, nomsVal) 244 } else { 245 return compareToNil(nomsVal) 246 } 247 }, nil 248 } else if compType == VariableVariableCompare { 249 col1Name := vars[0].Name() 250 col1, ok := sch.GetAllCols().GetByNameCaseInsensitive(col1Name) 251 252 if !ok { 253 return nil, errUnknownColumn.New(col1Name) 254 } 255 256 col2Name := vars[1].Name() 257 col2, ok := sch.GetAllCols().GetByNameCaseInsensitive(col2Name) 258 259 if !ok { 260 return nil, errUnknownColumn.New(col2Name) 261 } 262 263 compareNomsValues := op.CompareNomsValues 264 compareToNull := op.CompareToNil 265 266 tag1, tag2 := col1.Tag, col2.Tag 267 return func(ctx context.Context, vals map[uint64]types.Value) (b bool, err error) { 268 v1 := vals[tag1] 269 v2 := vals[tag2] 270 271 if types.IsNull(v1) { 272 return compareToNull(v2) 273 } else { 274 return compareNomsValues(v1, v2) 275 } 276 }, nil 277 } else if compType == VariableInLiteralList { 278 colName := vars[0].Name() 279 col, ok := sch.GetAllCols().GetByNameCaseInsensitive(colName) 280 281 if !ok { 282 return nil, errUnknownColumn.New(colName) 283 } 284 285 tag := col.Tag 286 287 // Get all the noms values 288 nomsVals := make([]types.Value, len(consts)) 289 for i, c := range consts { 290 nomsVal, err := LiteralToNomsValue(col.Kind, c) 291 if err != nil { 292 return nil, err 293 } 294 nomsVals[i] = nomsVal 295 } 296 297 compareNomsValues := op.CompareNomsValues 298 compareToNil := op.CompareToNil 299 300 return func(ctx context.Context, vals map[uint64]types.Value) (b bool, err error) { 301 colVal, ok := vals[tag] 302 303 for _, nv := range nomsVals { 304 var lb bool 305 if ok && !types.IsNull(colVal) { 306 lb, err = compareNomsValues(colVal, nv) 307 } else { 308 lb, err = compareToNil(nv) 309 } 310 311 if err != nil { 312 return false, err 313 } 314 if lb { 315 return true, nil 316 } 317 } 318 319 return false, nil 320 }, nil 321 } else { 322 return nil, errUnsupportedComparisonType.New() 323 } 324 }