github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/libraries/doltcore/sqle/expreval/expression_evaluator.go (about)

     1  // Copyright 2020 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package expreval
    16  
    17  import (
    18  	"context"
    19  
    20  	"github.com/dolthub/go-mysql-server/sql"
    21  	"github.com/dolthub/go-mysql-server/sql/expression"
    22  	gmstypes "github.com/dolthub/go-mysql-server/sql/types"
    23  	"gopkg.in/src-d/go-errors.v1"
    24  
    25  	"github.com/dolthub/dolt/go/libraries/doltcore/schema"
    26  	"github.com/dolthub/dolt/go/store/types"
    27  )
    28  
    29  var errUnsupportedComparisonType = errors.NewKind("Unsupported Comparison Type.")
    30  var errUnknownColumn = errors.NewKind("Column %s not found.")
    31  var errInvalidConversion = errors.NewKind("Could not convert %s from %s to %s.")
    32  var errNotImplemented = errors.NewKind("Not Implemented: %s")
    33  
    34  // ExpressionFunc is a function that takes a map of tag to value and returns whether some set of criteria are true for
    35  // the set of values
    36  type ExpressionFunc func(ctx context.Context, vals map[uint64]types.Value) (bool, error)
    37  
    38  // ExpressionFuncFromSQLExpressions returns an ExpressionFunc which represents the slice of sql.Expressions passed in
    39  func ExpressionFuncFromSQLExpressions(vr types.ValueReader, sch schema.Schema, expressions []sql.Expression) (ExpressionFunc, error) {
    40  	var root ExpressionFunc
    41  	for _, exp := range expressions {
    42  		expFunc, err := getExpFunc(vr, sch, exp)
    43  
    44  		if err != nil {
    45  			return nil, err
    46  		}
    47  
    48  		if root == nil {
    49  			root = expFunc
    50  		} else {
    51  			root = newAndFunc(root, expFunc)
    52  		}
    53  	}
    54  
    55  	if root == nil {
    56  		root = func(ctx context.Context, vals map[uint64]types.Value) (bool, error) {
    57  			return true, nil
    58  		}
    59  	}
    60  
    61  	return root, nil
    62  }
    63  
    64  func getExpFunc(vr types.ValueReader, sch schema.Schema, exp sql.Expression) (ExpressionFunc, error) {
    65  	switch typedExpr := exp.(type) {
    66  	case *expression.Equals:
    67  		return newComparisonFunc(EqualsOp{}, typedExpr, sch)
    68  	case *expression.GreaterThan:
    69  		return newComparisonFunc(GreaterOp{vr}, typedExpr, sch)
    70  	case *expression.GreaterThanOrEqual:
    71  		return newComparisonFunc(GreaterEqualOp{vr}, typedExpr, sch)
    72  	case *expression.LessThan:
    73  		return newComparisonFunc(LessOp{vr}, typedExpr, sch)
    74  	case *expression.LessThanOrEqual:
    75  		return newComparisonFunc(LessEqualOp{vr}, typedExpr, sch)
    76  	case *expression.Or:
    77  		leftFunc, err := getExpFunc(vr, sch, typedExpr.Left())
    78  
    79  		if err != nil {
    80  			return nil, err
    81  		}
    82  
    83  		rightFunc, err := getExpFunc(vr, sch, typedExpr.Right())
    84  
    85  		if err != nil {
    86  			return nil, err
    87  		}
    88  
    89  		return newOrFunc(leftFunc, rightFunc), nil
    90  	case *expression.And:
    91  		leftFunc, err := getExpFunc(vr, sch, typedExpr.Left())
    92  
    93  		if err != nil {
    94  			return nil, err
    95  		}
    96  
    97  		rightFunc, err := getExpFunc(vr, sch, typedExpr.Right())
    98  
    99  		if err != nil {
   100  			return nil, err
   101  		}
   102  
   103  		return newAndFunc(leftFunc, rightFunc), nil
   104  	case *expression.InTuple:
   105  		return newComparisonFunc(EqualsOp{}, typedExpr, sch)
   106  	case *expression.Not:
   107  		expFunc, err := getExpFunc(vr, sch, typedExpr.Child)
   108  		if err != nil {
   109  			return nil, err
   110  		}
   111  		return newNotFunc(expFunc), nil
   112  	case *expression.IsNull:
   113  		return newComparisonFunc(EqualsOp{}, expression.NewNullSafeEquals(typedExpr.Child, expression.NewLiteral(nil, gmstypes.Null)), sch)
   114  	}
   115  
   116  	return nil, errNotImplemented.New(exp.Type().String())
   117  }
   118  
   119  func newOrFunc(left ExpressionFunc, right ExpressionFunc) ExpressionFunc {
   120  	return func(ctx context.Context, vals map[uint64]types.Value) (b bool, err error) {
   121  		lRes, err := left(ctx, vals)
   122  
   123  		if err != nil {
   124  			return false, err
   125  		}
   126  
   127  		if lRes {
   128  			return true, nil
   129  		}
   130  
   131  		return right(ctx, vals)
   132  	}
   133  }
   134  
   135  func newAndFunc(left ExpressionFunc, right ExpressionFunc) ExpressionFunc {
   136  	return func(ctx context.Context, vals map[uint64]types.Value) (b bool, err error) {
   137  		lRes, err := left(ctx, vals)
   138  
   139  		if err != nil {
   140  			return false, err
   141  		}
   142  
   143  		if !lRes {
   144  			return false, nil
   145  		}
   146  
   147  		return right(ctx, vals)
   148  	}
   149  }
   150  
   151  func newNotFunc(exp ExpressionFunc) ExpressionFunc {
   152  	return func(ctx context.Context, vals map[uint64]types.Value) (b bool, err error) {
   153  		res, err := exp(ctx, vals)
   154  		if err != nil {
   155  			return false, err
   156  		}
   157  
   158  		return !res, nil
   159  	}
   160  }
   161  
   162  type ComparisonType int
   163  
   164  const (
   165  	InvalidCompare ComparisonType = iota
   166  	VariableConstCompare
   167  	VariableVariableCompare
   168  	VariableInLiteralList
   169  	ConstConstCompare
   170  )
   171  
   172  // GetComparisonType looks at a go-mysql-server BinaryExpression classifies the left and right arguments
   173  // as variables or constants.
   174  func GetComparisonType(be expression.BinaryExpression) ([]*expression.GetField, []*expression.Literal, ComparisonType, error) {
   175  	var variables []*expression.GetField
   176  	var consts []*expression.Literal
   177  
   178  	for _, curr := range []sql.Expression{be.Left(), be.Right()} {
   179  		// need to remove this and handle properly
   180  		if conv, ok := curr.(*expression.Convert); ok {
   181  			curr = conv.Child
   182  		}
   183  
   184  		switch v := curr.(type) {
   185  		case *expression.GetField:
   186  			variables = append(variables, v)
   187  		case *expression.Literal:
   188  			consts = append(consts, v)
   189  		case expression.Tuple:
   190  			children := v.Children()
   191  			for _, currChild := range children {
   192  				lit, ok := currChild.(*expression.Literal)
   193  				if !ok {
   194  					return nil, nil, InvalidCompare, errUnsupportedComparisonType.New()
   195  				}
   196  				consts = append(consts, lit)
   197  			}
   198  		default:
   199  			return nil, nil, InvalidCompare, errUnsupportedComparisonType.New()
   200  		}
   201  	}
   202  
   203  	var compType ComparisonType
   204  	if len(variables) == 2 {
   205  		compType = VariableVariableCompare
   206  	} else if len(variables) == 1 {
   207  		if len(consts) == 1 {
   208  			compType = VariableConstCompare
   209  		} else if len(consts) > 1 {
   210  			compType = VariableInLiteralList
   211  		}
   212  	} else if len(consts) == 2 {
   213  		compType = ConstConstCompare
   214  	}
   215  
   216  	return variables, consts, compType, nil
   217  }
   218  
   219  var trueFunc = func(ctx context.Context, vals map[uint64]types.Value) (b bool, err error) { return true, nil }
   220  var falseFunc = func(ctx context.Context, vals map[uint64]types.Value) (b bool, err error) { return false, nil }
   221  
   222  func newComparisonFunc(op CompareOp, exp expression.BinaryExpression, sch schema.Schema) (ExpressionFunc, error) {
   223  	vars, consts, compType, err := GetComparisonType(exp)
   224  
   225  	if err != nil {
   226  		return nil, err
   227  	}
   228  
   229  	if compType == ConstConstCompare {
   230  		res, err := op.CompareLiterals(consts[0], consts[1])
   231  
   232  		if err != nil {
   233  			return nil, err
   234  		}
   235  
   236  		if res {
   237  			return trueFunc, nil
   238  		} else {
   239  			return falseFunc, nil
   240  		}
   241  	} else if compType == VariableConstCompare {
   242  		colName := vars[0].Name()
   243  		col, ok := sch.GetAllCols().GetByNameCaseInsensitive(colName)
   244  
   245  		if !ok {
   246  			return nil, errUnknownColumn.New(colName)
   247  		}
   248  
   249  		tag := col.Tag
   250  		nomsVal, err := LiteralToNomsValue(col.Kind, consts[0])
   251  
   252  		if err != nil {
   253  			return nil, err
   254  		}
   255  
   256  		compareNomsValues := op.CompareNomsValues
   257  		compareToNil := op.CompareToNil
   258  
   259  		return func(ctx context.Context, vals map[uint64]types.Value) (b bool, err error) {
   260  			colVal, ok := vals[tag]
   261  
   262  			if ok && !types.IsNull(colVal) {
   263  				return compareNomsValues(ctx, colVal, nomsVal)
   264  			} else {
   265  				return compareToNil(nomsVal)
   266  			}
   267  		}, nil
   268  	} else if compType == VariableVariableCompare {
   269  		col1Name := vars[0].Name()
   270  		col1, ok := sch.GetAllCols().GetByNameCaseInsensitive(col1Name)
   271  
   272  		if !ok {
   273  			return nil, errUnknownColumn.New(col1Name)
   274  		}
   275  
   276  		col2Name := vars[1].Name()
   277  		col2, ok := sch.GetAllCols().GetByNameCaseInsensitive(col2Name)
   278  
   279  		if !ok {
   280  			return nil, errUnknownColumn.New(col2Name)
   281  		}
   282  
   283  		compareNomsValues := op.CompareNomsValues
   284  		compareToNull := op.CompareToNil
   285  
   286  		tag1, tag2 := col1.Tag, col2.Tag
   287  		return func(ctx context.Context, vals map[uint64]types.Value) (b bool, err error) {
   288  			v1 := vals[tag1]
   289  			v2 := vals[tag2]
   290  
   291  			if types.IsNull(v1) {
   292  				return compareToNull(v2)
   293  			} else {
   294  				return compareNomsValues(ctx, v1, v2)
   295  			}
   296  		}, nil
   297  	} else if compType == VariableInLiteralList {
   298  		colName := vars[0].Name()
   299  		col, ok := sch.GetAllCols().GetByNameCaseInsensitive(colName)
   300  
   301  		if !ok {
   302  			return nil, errUnknownColumn.New(colName)
   303  		}
   304  
   305  		tag := col.Tag
   306  
   307  		// Get all the noms values
   308  		nomsVals := make([]types.Value, len(consts))
   309  		for i, c := range consts {
   310  			nomsVal, err := LiteralToNomsValue(col.Kind, c)
   311  			if err != nil {
   312  				return nil, err
   313  			}
   314  			nomsVals[i] = nomsVal
   315  		}
   316  
   317  		compareNomsValues := op.CompareNomsValues
   318  		compareToNil := op.CompareToNil
   319  
   320  		return func(ctx context.Context, vals map[uint64]types.Value) (b bool, err error) {
   321  			colVal, ok := vals[tag]
   322  
   323  			for _, nv := range nomsVals {
   324  				var lb bool
   325  				if ok && !types.IsNull(colVal) {
   326  					lb, err = compareNomsValues(ctx, colVal, nv)
   327  				} else {
   328  					lb, err = compareToNil(nv)
   329  				}
   330  
   331  				if err != nil {
   332  					return false, err
   333  				}
   334  				if lb {
   335  					return true, nil
   336  				}
   337  			}
   338  
   339  			return false, nil
   340  		}, nil
   341  	} else {
   342  		return nil, errUnsupportedComparisonType.New()
   343  	}
   344  }