github.com/unigraph-dev/dgraph@v1.1.1-0.20200923154953-8b52b426f765/gql/math.go (about)

     1  /*
     2   * Copyright 2017-2018 Dgraph Labs, Inc. and Contributors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package gql
    18  
    19  import (
    20  	"bytes"
    21  	"strconv"
    22  	"strings"
    23  
    24  	"github.com/dgraph-io/dgraph/lex"
    25  	"github.com/dgraph-io/dgraph/types"
    26  	"github.com/dgraph-io/dgraph/x"
    27  	"github.com/pkg/errors"
    28  )
    29  
    30  type mathTreeStack struct{ a []*MathTree }
    31  
    32  func (s *mathTreeStack) empty() bool      { return len(s.a) == 0 }
    33  func (s *mathTreeStack) size() int        { return len(s.a) }
    34  func (s *mathTreeStack) push(t *MathTree) { s.a = append(s.a, t) }
    35  
    36  func (s *mathTreeStack) popAssert() *MathTree {
    37  	x.AssertTruef(!s.empty(), "Expected a non-empty stack")
    38  	last := s.a[len(s.a)-1]
    39  	s.a = s.a[:len(s.a)-1]
    40  	return last
    41  }
    42  
    43  func (s *mathTreeStack) pop() (*MathTree, error) {
    44  	if s.empty() {
    45  		return nil, errors.Errorf("Empty stack")
    46  	}
    47  	last := s.a[len(s.a)-1]
    48  	s.a = s.a[:len(s.a)-1]
    49  	return last, nil
    50  }
    51  
    52  func (s *mathTreeStack) peek() *MathTree {
    53  	x.AssertTruef(!s.empty(), "Trying to peek empty stack")
    54  	return s.a[len(s.a)-1]
    55  }
    56  
    57  // MathTree represents math operations in tree form for evaluation.
    58  type MathTree struct {
    59  	Fn    string
    60  	Var   string
    61  	Const types.Val // This will always be parsed as a float value
    62  	Val   map[uint64]types.Val
    63  	Child []*MathTree
    64  }
    65  
    66  func isUnary(f string) bool {
    67  	return f == "exp" || f == "ln" || f == "u-" || f == "sqrt" ||
    68  		f == "floor" || f == "ceil" || f == "since"
    69  }
    70  
    71  func isBinaryMath(f string) bool {
    72  	return f == "*" || f == "+" || f == "-" || f == "/" || f == "%"
    73  }
    74  
    75  func isTernary(f string) bool {
    76  	return f == "cond"
    77  }
    78  
    79  func isZero(f string, rval types.Val) bool {
    80  	if rval.Tid != types.FloatID {
    81  		return false
    82  	}
    83  	g, ok := rval.Value.(float64)
    84  	if !ok {
    85  		return false
    86  	}
    87  	switch f {
    88  	case "floor":
    89  		return g >= 0 && g < 1.0
    90  	case "/", "%", "ceil", "sqrt", "u-":
    91  		return g == 0
    92  	case "ln":
    93  		return g == 1
    94  	}
    95  	return false
    96  }
    97  
    98  func evalMathStack(opStack, valueStack *mathTreeStack) error {
    99  	topOp, err := opStack.pop()
   100  	if err != nil {
   101  		return errors.Errorf("Invalid Math expression")
   102  	}
   103  	if isUnary(topOp.Fn) {
   104  		// Since "not" is a unary operator, just pop one value.
   105  		topVal, err := valueStack.pop()
   106  		if err != nil {
   107  			return errors.Errorf("Invalid math statement. Expected 1 operands")
   108  		}
   109  		if opStack.size() > 1 {
   110  			peek := opStack.peek().Fn
   111  			if (peek == "/" || peek == "%") && isZero(topOp.Fn, topVal.Const) {
   112  				return errors.Errorf("Division by zero")
   113  			}
   114  		}
   115  		topOp.Child = []*MathTree{topVal}
   116  
   117  	} else if isTernary(topOp.Fn) {
   118  		if valueStack.size() < 3 {
   119  			return errors.Errorf("Invalid Math expression. Expected 3 operands")
   120  		}
   121  		topVal1 := valueStack.popAssert()
   122  		topVal2 := valueStack.popAssert()
   123  		topVal3 := valueStack.popAssert()
   124  		topOp.Child = []*MathTree{topVal3, topVal2, topVal1}
   125  
   126  	} else {
   127  		if valueStack.size() < 2 {
   128  			return errors.Errorf("Invalid Math expression. Expected 2 operands")
   129  		}
   130  		if isZero(topOp.Fn, valueStack.peek().Const) {
   131  			return errors.Errorf("Division by zero.")
   132  		}
   133  		topVal1 := valueStack.popAssert()
   134  		topVal2 := valueStack.popAssert()
   135  		topOp.Child = []*MathTree{topVal2, topVal1}
   136  
   137  	}
   138  	// Push the new value (tree) into the valueStack.
   139  	valueStack.push(topOp)
   140  	return nil
   141  }
   142  
   143  func isMathFunc(f string) bool {
   144  	// While adding an op, also add it to the corresponding function type.
   145  	return f == "*" || f == "%" || f == "+" || f == "-" || f == "/" ||
   146  		f == "exp" || f == "ln" || f == "cond" ||
   147  		f == "<" || f == ">" || f == ">=" || f == "<=" ||
   148  		f == "==" || f == "!=" ||
   149  		f == "min" || f == "max" || f == "sqrt" ||
   150  		f == "pow" || f == "logbase" || f == "floor" || f == "ceil" ||
   151  		f == "since"
   152  }
   153  
   154  func parseMathFunc(it *lex.ItemIterator, again bool) (*MathTree, bool, error) {
   155  	if !again {
   156  		it.Next()
   157  		item := it.Item()
   158  		if item.Typ != itemLeftRound {
   159  			return nil, false, errors.Errorf("Expected ( after math")
   160  		}
   161  	}
   162  
   163  	// opStack is used to collect the operators in right order.
   164  	opStack := new(mathTreeStack)
   165  	opStack.push(&MathTree{Fn: "("}) // Push ( onto operator stack.
   166  	// valueStack is used to collect the values.
   167  	valueStack := new(mathTreeStack)
   168  
   169  	for it.Next() {
   170  		item := it.Item()
   171  		lval := strings.ToLower(item.Val)
   172  		if isMathFunc(lval) {
   173  			op := lval
   174  			it.Prev()
   175  			lastItem := it.Item()
   176  			it.Next()
   177  			if op == "-" &&
   178  				(lastItem.Val == "(" || lastItem.Val == "," || isBinaryMath(lastItem.Val)) {
   179  				op = "u-" // This is a unary -
   180  			}
   181  			opPred := mathOpPrecedence[op]
   182  			x.AssertTruef(opPred > 0, "Expected opPred > 0 for %v: %d", op, opPred)
   183  			// Evaluate the stack until we see an operator with strictly lower pred.
   184  			for !opStack.empty() {
   185  				topOp := opStack.peek()
   186  				if mathOpPrecedence[topOp.Fn] < opPred {
   187  					break
   188  				}
   189  				err := evalMathStack(opStack, valueStack)
   190  				if err != nil {
   191  					return nil, false, err
   192  				}
   193  			}
   194  			opStack.push(&MathTree{Fn: op}) // Push current operator.
   195  			peekIt, err := it.Peek(1)
   196  			if err != nil {
   197  				return nil, false, err
   198  			}
   199  			if peekIt[0].Typ == itemLeftRound {
   200  				again := false
   201  				var child *MathTree
   202  				for {
   203  					child, again, err = parseMathFunc(it, again)
   204  					if err != nil {
   205  						return nil, false, err
   206  					}
   207  					valueStack.push(child)
   208  					if !again {
   209  						break
   210  					}
   211  				}
   212  			}
   213  		} else if item.Typ == itemName { // Value.
   214  			peekIt, err := it.Peek(1)
   215  			if err != nil {
   216  				return nil, false, err
   217  			}
   218  			if peekIt[0].Typ == itemLeftRound {
   219  				again := false
   220  				if !isMathFunc(item.Val) {
   221  					return nil, false, errors.Errorf("Unknown math function: %v", item.Val)
   222  				}
   223  				var child *MathTree
   224  				for {
   225  					child, again, err = parseMathFunc(it, again)
   226  					if err != nil {
   227  						return nil, false, err
   228  					}
   229  					valueStack.push(child)
   230  					if !again {
   231  						break
   232  					}
   233  				}
   234  				continue
   235  			}
   236  			// Try to parse it as a constant.
   237  			child := &MathTree{}
   238  			v, err := strconv.ParseFloat(item.Val, 64)
   239  			if err != nil {
   240  				child.Var = item.Val
   241  			} else {
   242  				child.Const = types.Val{
   243  					Tid:   types.FloatID,
   244  					Value: v,
   245  				}
   246  			}
   247  			valueStack.push(child)
   248  		} else if item.Typ == itemLeftRound { // Just push to op stack.
   249  			opStack.push(&MathTree{Fn: "("})
   250  
   251  		} else if item.Typ == itemComma {
   252  			for !opStack.empty() {
   253  				topOp := opStack.peek()
   254  				if topOp.Fn == "(" {
   255  					break
   256  				}
   257  				err := evalMathStack(opStack, valueStack)
   258  				if err != nil {
   259  					return nil, false, err
   260  				}
   261  			}
   262  			_, err := opStack.pop() // Pop away the (.
   263  			if err != nil {
   264  				return nil, false, errors.Errorf("Invalid Math expression")
   265  			}
   266  			if !opStack.empty() {
   267  				return nil, false, errors.Errorf("Invalid math expression.")
   268  			}
   269  			if valueStack.size() != 1 {
   270  				return nil, false, errors.Errorf("Expected one item in value stack, but got %d",
   271  					valueStack.size())
   272  			}
   273  			res, err := valueStack.pop()
   274  			if err != nil {
   275  				return nil, false, err
   276  			}
   277  			return res, true, nil
   278  		} else if item.Typ == itemRightRound { // Pop op stack until we see a (.
   279  			for !opStack.empty() {
   280  				topOp := opStack.peek()
   281  				if topOp.Fn == "(" {
   282  					break
   283  				}
   284  				err := evalMathStack(opStack, valueStack)
   285  				if err != nil {
   286  					return nil, false, err
   287  				}
   288  			}
   289  			_, err := opStack.pop() // Pop away the (.
   290  			if err != nil {
   291  				return nil, false, errors.Errorf("Invalid Math expression")
   292  			}
   293  			if opStack.empty() {
   294  				// The parentheses are balanced out. Let's break.
   295  				break
   296  			}
   297  		} else {
   298  			return nil, false, errors.Errorf("Unexpected item while parsing math expression: %v", item)
   299  		}
   300  	}
   301  
   302  	// For math Expressions, we start with ( and end with ). We expect to break out of loop
   303  	// when the parentheses balance off, and at that point, opStack should be empty.
   304  	// For other applications, typically after all items are
   305  	// consumed, we will run a loop like "while opStack is nonempty, evalStack".
   306  	// This is not needed here.
   307  	x.AssertTruef(opStack.empty(), "Op stack should be empty when we exit")
   308  
   309  	if valueStack.empty() {
   310  		// This happens when we have math(). We can either return an error or
   311  		// ignore. Currently, let's just ignore and pretend there is no expression.
   312  		return nil, false, errors.Errorf("Empty () not allowed in math block.")
   313  	}
   314  
   315  	if valueStack.size() != 1 {
   316  		return nil, false, errors.Errorf("Expected one item in value stack, but got %d",
   317  			valueStack.size())
   318  	}
   319  	res, err := valueStack.pop()
   320  	return res, false, err
   321  }
   322  
   323  // debugString converts mathTree to a string. Good for testing, debugging.
   324  // nolint: unused
   325  func (t *MathTree) debugString() string {
   326  	buf := bytes.NewBuffer(make([]byte, 0, 20))
   327  	t.stringHelper(buf)
   328  	return buf.String()
   329  }
   330  
   331  // stringHelper does simple DFS to convert MathTree to string.
   332  // nolint: unused
   333  func (t *MathTree) stringHelper(buf *bytes.Buffer) {
   334  	x.AssertTruef(t != nil, "Nil Math tree")
   335  	if t.Var != "" {
   336  		// Leaf node.
   337  		buf.WriteString(t.Var)
   338  		return
   339  	}
   340  	if t.Const.Value != nil {
   341  		// Leaf node.
   342  		buf.WriteString(strconv.FormatFloat(t.Const.Value.(float64), 'E', -1, 64))
   343  		return
   344  	}
   345  	// Non-leaf node.
   346  	buf.WriteRune('(')
   347  	switch t.Fn {
   348  	case "+", "-", "/", "*", "%", "exp", "ln", "cond", "min",
   349  		"sqrt", "max", "<", ">", "<=", ">=", "==", "!=", "u-",
   350  		"logbase", "pow":
   351  		buf.WriteString(t.Fn)
   352  	default:
   353  		x.Fatalf("Unknown operator: %q", t.Fn)
   354  	}
   355  
   356  	for _, c := range t.Child {
   357  		buf.WriteRune(' ')
   358  		c.stringHelper(buf)
   359  	}
   360  	buf.WriteRune(')')
   361  }