github.com/hasnat/dolt/go@v0.0.0-20210628190320-9eb5d843fbb7/libraries/doltcore/sqle/expreval/expression_evaluator.go (about)

     1  // Copyright 2020 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package expreval
    16  
    17  import (
    18  	"context"
    19  
    20  	"github.com/dolthub/go-mysql-server/sql"
    21  	"github.com/dolthub/go-mysql-server/sql/expression"
    22  	"gopkg.in/src-d/go-errors.v1"
    23  
    24  	"github.com/dolthub/dolt/go/libraries/doltcore/schema"
    25  	"github.com/dolthub/dolt/go/store/types"
    26  )
    27  
    28  var errUnsupportedComparisonType = errors.NewKind("Unsupported Comparison Type.")
    29  var errUnknownColumn = errors.NewKind("Column %s not found.")
    30  var errInvalidConversion = errors.NewKind("Could not convert %s from %s to %s.")
    31  var errNotImplemented = errors.NewKind("Not Implemented: %s")
    32  
    33  // ExpressionFunc is a function that takes a map of tag to value and returns whether some set of criteria are true for
    34  // the set of values
    35  type ExpressionFunc func(ctx context.Context, vals map[uint64]types.Value) (bool, error)
    36  
    37  // ExpressionFuncFromSQLExpressions returns an ExpressionFunc which represents the slice of sql.Expressions passed in
    38  func ExpressionFuncFromSQLExpressions(nbf *types.NomsBinFormat, sch schema.Schema, expressions []sql.Expression) (ExpressionFunc, error) {
    39  	var root ExpressionFunc
    40  	for _, exp := range expressions {
    41  		expFunc, err := getExpFunc(nbf, sch, exp)
    42  
    43  		if err != nil {
    44  			return nil, err
    45  		}
    46  
    47  		if root == nil {
    48  			root = expFunc
    49  		} else {
    50  			root = newAndFunc(root, expFunc)
    51  		}
    52  	}
    53  
    54  	if root == nil {
    55  		root = func(ctx context.Context, vals map[uint64]types.Value) (bool, error) {
    56  			return true, nil
    57  		}
    58  	}
    59  
    60  	return root, nil
    61  }
    62  
    63  func getExpFunc(nbf *types.NomsBinFormat, sch schema.Schema, exp sql.Expression) (ExpressionFunc, error) {
    64  	switch typedExpr := exp.(type) {
    65  	case *expression.Equals:
    66  		return newComparisonFunc(EqualsOp{}, typedExpr.BinaryExpression, sch)
    67  	case *expression.GreaterThan:
    68  		return newComparisonFunc(GreaterOp{nbf}, typedExpr.BinaryExpression, sch)
    69  	case *expression.GreaterThanOrEqual:
    70  		return newComparisonFunc(GreaterEqualOp{nbf}, typedExpr.BinaryExpression, sch)
    71  	case *expression.LessThan:
    72  		return newComparisonFunc(LessOp{nbf}, typedExpr.BinaryExpression, sch)
    73  	case *expression.LessThanOrEqual:
    74  		return newComparisonFunc(LessEqualOp{nbf}, typedExpr.BinaryExpression, sch)
    75  	case *expression.Or:
    76  		leftFunc, err := getExpFunc(nbf, sch, typedExpr.Left)
    77  
    78  		if err != nil {
    79  			return nil, err
    80  		}
    81  
    82  		rightFunc, err := getExpFunc(nbf, sch, typedExpr.Right)
    83  
    84  		if err != nil {
    85  			return nil, err
    86  		}
    87  
    88  		return newOrFunc(leftFunc, rightFunc), nil
    89  	case *expression.And:
    90  		leftFunc, err := getExpFunc(nbf, sch, typedExpr.Left)
    91  
    92  		if err != nil {
    93  			return nil, err
    94  		}
    95  
    96  		rightFunc, err := getExpFunc(nbf, sch, typedExpr.Right)
    97  
    98  		if err != nil {
    99  			return nil, err
   100  		}
   101  
   102  		return newAndFunc(leftFunc, rightFunc), nil
   103  	case *expression.InTuple:
   104  		return newComparisonFunc(EqualsOp{}, typedExpr.BinaryExpression, sch)
   105  	}
   106  
   107  	return nil, errNotImplemented.New(exp.Type().String())
   108  }
   109  
   110  func newOrFunc(left ExpressionFunc, right ExpressionFunc) ExpressionFunc {
   111  	return func(ctx context.Context, vals map[uint64]types.Value) (b bool, err error) {
   112  		lRes, err := left(ctx, vals)
   113  
   114  		if err != nil {
   115  			return false, err
   116  		}
   117  
   118  		if lRes {
   119  			return true, nil
   120  		}
   121  
   122  		return right(ctx, vals)
   123  	}
   124  }
   125  
   126  func newAndFunc(left ExpressionFunc, right ExpressionFunc) ExpressionFunc {
   127  	return func(ctx context.Context, vals map[uint64]types.Value) (b bool, err error) {
   128  		lRes, err := left(ctx, vals)
   129  
   130  		if err != nil {
   131  			return false, err
   132  		}
   133  
   134  		if !lRes {
   135  			return false, nil
   136  		}
   137  
   138  		return right(ctx, vals)
   139  	}
   140  }
   141  
   142  type ComparisonType int
   143  
   144  const (
   145  	InvalidCompare ComparisonType = iota
   146  	VariableConstCompare
   147  	VariableVariableCompare
   148  	VariableInLiteralList
   149  	ConstConstCompare
   150  )
   151  
   152  // GetComparisonType looks at a go-mysql-server BinaryExpression classifies the left and right arguments
   153  // as variables or constants.
   154  func GetComparisonType(be expression.BinaryExpression) ([]*expression.GetField, []*expression.Literal, ComparisonType, error) {
   155  	var variables []*expression.GetField
   156  	var consts []*expression.Literal
   157  
   158  	for _, curr := range []sql.Expression{be.Left, be.Right} {
   159  		// need to remove this and handle properly
   160  		if conv, ok := curr.(*expression.Convert); ok {
   161  			curr = conv.Child
   162  		}
   163  
   164  		switch v := curr.(type) {
   165  		case *expression.GetField:
   166  			variables = append(variables, v)
   167  		case *expression.Literal:
   168  			consts = append(consts, v)
   169  		case expression.Tuple:
   170  			children := v.Children()
   171  			for _, currChild := range children {
   172  				lit, ok := currChild.(*expression.Literal)
   173  				if !ok {
   174  					return nil, nil, InvalidCompare, errUnsupportedComparisonType.New()
   175  				}
   176  				consts = append(consts, lit)
   177  			}
   178  		default:
   179  			return nil, nil, InvalidCompare, errUnsupportedComparisonType.New()
   180  		}
   181  	}
   182  
   183  	var compType ComparisonType
   184  	if len(variables) == 2 {
   185  		compType = VariableVariableCompare
   186  	} else if len(variables) == 1 {
   187  		if len(consts) == 1 {
   188  			compType = VariableConstCompare
   189  		} else if len(consts) > 1 {
   190  			compType = VariableInLiteralList
   191  		}
   192  	} else if len(consts) == 2 {
   193  		compType = ConstConstCompare
   194  	}
   195  
   196  	return variables, consts, compType, nil
   197  }
   198  
   199  var trueFunc = func(ctx context.Context, vals map[uint64]types.Value) (b bool, err error) { return true, nil }
   200  var falseFunc = func(ctx context.Context, vals map[uint64]types.Value) (b bool, err error) { return false, nil }
   201  
   202  func newComparisonFunc(op CompareOp, exp expression.BinaryExpression, sch schema.Schema) (ExpressionFunc, error) {
   203  	vars, consts, compType, err := GetComparisonType(exp)
   204  
   205  	if err != nil {
   206  		return nil, err
   207  	}
   208  
   209  	if compType == ConstConstCompare {
   210  		res, err := op.CompareLiterals(consts[0], consts[1])
   211  
   212  		if err != nil {
   213  			return nil, err
   214  		}
   215  
   216  		if res {
   217  			return trueFunc, nil
   218  		} else {
   219  			return falseFunc, nil
   220  		}
   221  	} else if compType == VariableConstCompare {
   222  		colName := vars[0].Name()
   223  		col, ok := sch.GetAllCols().GetByNameCaseInsensitive(colName)
   224  
   225  		if !ok {
   226  			return nil, errUnknownColumn.New(colName)
   227  		}
   228  
   229  		tag := col.Tag
   230  		nomsVal, err := LiteralToNomsValue(col.Kind, consts[0])
   231  
   232  		if err != nil {
   233  			return nil, err
   234  		}
   235  
   236  		compareNomsValues := op.CompareNomsValues
   237  		compareToNil := op.CompareToNil
   238  
   239  		return func(ctx context.Context, vals map[uint64]types.Value) (b bool, err error) {
   240  			colVal, ok := vals[tag]
   241  
   242  			if ok && !types.IsNull(colVal) {
   243  				return compareNomsValues(colVal, nomsVal)
   244  			} else {
   245  				return compareToNil(nomsVal)
   246  			}
   247  		}, nil
   248  	} else if compType == VariableVariableCompare {
   249  		col1Name := vars[0].Name()
   250  		col1, ok := sch.GetAllCols().GetByNameCaseInsensitive(col1Name)
   251  
   252  		if !ok {
   253  			return nil, errUnknownColumn.New(col1Name)
   254  		}
   255  
   256  		col2Name := vars[1].Name()
   257  		col2, ok := sch.GetAllCols().GetByNameCaseInsensitive(col2Name)
   258  
   259  		if !ok {
   260  			return nil, errUnknownColumn.New(col2Name)
   261  		}
   262  
   263  		compareNomsValues := op.CompareNomsValues
   264  		compareToNull := op.CompareToNil
   265  
   266  		tag1, tag2 := col1.Tag, col2.Tag
   267  		return func(ctx context.Context, vals map[uint64]types.Value) (b bool, err error) {
   268  			v1 := vals[tag1]
   269  			v2 := vals[tag2]
   270  
   271  			if types.IsNull(v1) {
   272  				return compareToNull(v2)
   273  			} else {
   274  				return compareNomsValues(v1, v2)
   275  			}
   276  		}, nil
   277  	} else if compType == VariableInLiteralList {
   278  		colName := vars[0].Name()
   279  		col, ok := sch.GetAllCols().GetByNameCaseInsensitive(colName)
   280  
   281  		if !ok {
   282  			return nil, errUnknownColumn.New(colName)
   283  		}
   284  
   285  		tag := col.Tag
   286  
   287  		// Get all the noms values
   288  		nomsVals := make([]types.Value, len(consts))
   289  		for i, c := range consts {
   290  			nomsVal, err := LiteralToNomsValue(col.Kind, c)
   291  			if err != nil {
   292  				return nil, err
   293  			}
   294  			nomsVals[i] = nomsVal
   295  		}
   296  
   297  		compareNomsValues := op.CompareNomsValues
   298  		compareToNil := op.CompareToNil
   299  
   300  		return func(ctx context.Context, vals map[uint64]types.Value) (b bool, err error) {
   301  			colVal, ok := vals[tag]
   302  
   303  			for _, nv := range nomsVals {
   304  				var lb bool
   305  				if ok && !types.IsNull(colVal) {
   306  					lb, err = compareNomsValues(colVal, nv)
   307  				} else {
   308  					lb, err = compareToNil(nv)
   309  				}
   310  
   311  				if err != nil {
   312  					return false, err
   313  				}
   314  				if lb {
   315  					return true, nil
   316  				}
   317  			}
   318  
   319  			return false, nil
   320  		}, nil
   321  	} else {
   322  		return nil, errUnsupportedComparisonType.New()
   323  	}
   324  }