github.com/rminnich/u-root@v7.0.0+incompatible/pkg/pogosh/arithmetic.go (about)

     1  // Copyright 2020 the u-root Authors. All rights reserved
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package pogosh
     6  
     7  import (
     8  	"fmt"
     9  	"math/big"
    10  )
    11  
// After variable substitution has occurred, this file
// evaluates the arithmetic expression.
    14  
// These rules come from the ISO C standard, section 6.5 (Expressions) and
// section 6.4.4.1 (Integer constants). The grammar is LL(3).
    18  
    19  // Params:
    20  //	 map[string]Variable
    21  // Returns:
    22  //   big.Int: value
    23  //   map[string]big.Int: assignments
    24  // Panics: on parse error
    25  
// Arithmetic computes the value of a $((...)) arithmetic expression whose
// text has already undergone variable substitution.
// TODO: why public?
type Arithmetic struct {
	// getVar returns the value of the named variable. It is called once for
	// every identifier the parser reads.
	// NOTE(review): what it returns for an unset variable is not visible
	// here — confirm whether it may return nil.
	getVar func(string) *big.Int
	// setVar stores a value into the named variable. It is not used by the
	// code in this file; presumably it is meant for the assignment operators
	// that are still TODO.
	setVar func(string, *big.Int)

	// Initial string: the full expression text to evaluate.
	input string

	// Remaining unparsed string. During evaluation this is a suffix of input
	// followed by the NUL padding that evalExpression appends.
	rem string
}
    38  
    39  // if x == 0, 0
    40  // otherwise, 1
    41  func bigBool(x *big.Int) *big.Int {
    42  	if x.BitLen() == 0 {
    43  		return big.NewInt(0)
    44  	}
    45  	return big.NewInt(1)
    46  }
    47  
    48  func asBig(b bool) *big.Int {
    49  	if b {
    50  		return big.NewInt(1)
    51  	}
    52  	return big.NewInt(0)
    53  }
    54  
    55  // Wrapper which should be used outside this file.
    56  func (a *Arithmetic) evalExpression() *big.Int {
    57  	// TODO: this might not be standard and might be inefficient
    58  	// Augment for LL(3)
    59  	a.rem = a.input + "\000\000\000"
    60  	val := a.evalAssignmentExpression()
    61  	a.evalSpaces()
    62  	if a.rem != "\000\000\000" {
    63  		panic("Expected EOF at " + a.rem)
    64  	}
    65  	return val
    66  }
    67  
    68  func (a *Arithmetic) evalSpaces() {
    69  	for a.rem[0] == ' ' || a.rem[0] == '\t' || a.rem[0] == '\n' { // TODO: more space characters
    70  		a.rem = a.rem[1:]
    71  	}
    72  }
    73  
    74  // [_0-9a-zA-Z]
    75  func isIdentifierChar(char byte) bool {
    76  	return char == '_' || ('a' <= char && char <= 'z') || ('A' <= char && char <= 'Z') || ('0' <= char && char <= '9')
    77  }
    78  
    79  func isDecimal(char byte) bool {
    80  	return '0' <= char && char <= '9'
    81  }
    82  
    83  func isHex(char byte) bool {
    84  	return ('0' <= char && char <= '9') || ('a' <= char && char <= 'f') || ('A' <= char && char <= 'F')
    85  }
    86  
    87  func isOctal(char byte) bool {
    88  	return ('0' <= char && char <= '7')
    89  }
    90  
    91  // Identifier ::= [_a-zA-Z][_0-9a-zA-Z]*
    92  func (a *Arithmetic) evalIdentifier() *big.Int {
    93  	// This does not check the first character is not numeric because it has
    94  	// already been done as part of the callee's FIRST set.
    95  	i := 0
    96  	for isIdentifierChar(a.rem[i]) {
    97  		i++
    98  	}
    99  
   100  	identifier := a.rem[:i]
   101  	a.rem = a.rem[i:]
   102  	return a.getVar(identifier)
   103  }
   104  
   105  // Constant ::= DecimalConstant | OctalConstant | HexadecimalConstant
   106  // DecimalConstant ::= [1-9][0-9]*
   107  // OctalConstant ::= 0[0-9]*
   108  // HexadecimalConstant ::= 0x[0-9]* | 0X[0-9]*
   109  func (a *Arithmetic) evalConstant() *big.Int {
   110  	// Get length of constant.
   111  	i := 0
   112  	if a.rem[i] == '0' {
   113  		i++
   114  		if a.rem[i] == 'x' || a.rem[i] == 'X' {
   115  			i++
   116  			for isHex(a.rem[i]) {
   117  				i++
   118  			}
   119  		} else {
   120  			for isOctal(a.rem[i]) {
   121  				i++
   122  			}
   123  		}
   124  	} else {
   125  		for isDecimal(a.rem[i]) {
   126  			i++
   127  		}
   128  	}
   129  
   130  	var val big.Int
   131  	_, ok := val.SetString(a.rem[:i], 0)
   132  	if !ok {
   133  		panic("Not a valid constant")
   134  	}
   135  	a.rem = a.rem[i:]
   136  	return &val
   137  }
   138  
   139  // PrimaryExpression ::= Identifier | Constant | '(' AssignmentExpression ')'
   140  func (a *Arithmetic) evalPrimaryExpression() *big.Int {
   141  	a.evalSpaces()
   142  	char := a.rem[0]
   143  	switch {
   144  	case '0' <= char && char <= '9':
   145  		return a.evalConstant()
   146  	case isIdentifierChar(char):
   147  		return a.evalIdentifier()
   148  	case char == '(':
   149  		a.rem = a.rem[1:]
   150  		val := a.evalAssignmentExpression()
   151  		a.evalSpaces()
   152  		if a.rem[0] != ')' {
   153  			panic("No matching closing parenthesis")
   154  		}
   155  		a.rem = a.rem[1:]
   156  		return val
   157  	default:
   158  		panic(fmt.Sprintf("Expected identifier or constant at %c", char))
   159  	}
   160  }
   161  
   162  // UnaryExpression ::= PrimaryExpression
   163  // UnaryExpression ::= UnaryOperator UnaryExpression
   164  // UnaryOperator ::= '+' | '-' | '~' | '!'
   165  func (a *Arithmetic) evalUnaryExpression() *big.Int {
   166  	val := big.NewInt(0)
   167  	a.evalSpaces()
   168  	switch a.rem[0] {
   169  	case '+':
   170  		a.rem = a.rem[1:]
   171  		val = a.evalUnaryExpression()
   172  	case '-':
   173  		a.rem = a.rem[1:]
   174  		val.Neg(a.evalUnaryExpression())
   175  	case '~':
   176  		a.rem = a.rem[1:]
   177  		val.Not(a.evalUnaryExpression())
   178  	case '!':
   179  		a.rem = a.rem[1:]
   180  		val.Xor(big.NewInt(1), bigBool(a.evalUnaryExpression()))
   181  	default:
   182  		val = a.evalPrimaryExpression()
   183  	}
   184  	return val
   185  }
   186  
   187  // MultiplicativeExpression ::= UnaryExpression MultiplicativeExpression2
   188  // MultiplicativeExpression2 ::= MultiplicativeOperator MultiplicativeExpression |
   189  // MultiplicativeOperator ::= '*' | '/' | '%'
   190  func (a *Arithmetic) evalMultiplicativeExpression() *big.Int {
   191  	val := a.evalUnaryExpression()
   192  	for {
   193  		a.evalSpaces()
   194  		switch a.rem[0] {
   195  		case '*':
   196  			a.rem = a.rem[1:]
   197  			val.Mul(val, a.evalUnaryExpression())
   198  		case '/':
   199  			a.rem = a.rem[1:]
   200  			val.Div(val, a.evalUnaryExpression())
   201  		case '%':
   202  			a.rem = a.rem[1:]
   203  			val.Mod(val, a.evalUnaryExpression())
   204  		default:
   205  			return val
   206  		}
   207  	}
   208  }
   209  
   210  // AdditiveExpression ::= MultiplicativeExpression AdditiveExpression2
   211  // AdditiveExpression2 ::= AdditiveOperator AdditiveExpression |
   212  // AdditiveOperator ::= '+' | '-'
   213  func (a *Arithmetic) evalAdditiveExpression() *big.Int {
   214  	val := a.evalMultiplicativeExpression()
   215  	for {
   216  		a.evalSpaces()
   217  		switch a.rem[0] {
   218  		case '+':
   219  			a.rem = a.rem[1:]
   220  			val.Add(val, a.evalMultiplicativeExpression())
   221  		case '-':
   222  			a.rem = a.rem[1:]
   223  			val.Sub(val, a.evalMultiplicativeExpression())
   224  		default:
   225  			return val
   226  		}
   227  	}
   228  }
   229  
   230  // ShiftExpression ::= AdditiveExpression ShiftExpression2
   231  // ShiftExpression2 ::= ShiftOperator ShiftExpression |
   232  // ShiftOperator ::= '<<' | '>>'
   233  func (a *Arithmetic) evalShiftExpression() *big.Int {
   234  	val := a.evalAdditiveExpression()
   235  	for {
   236  		a.evalSpaces()
   237  		switch a.rem[:2] {
   238  		case "<<":
   239  			a.rem = a.rem[2:]
   240  			// TODO: might be undefined if > UINT64_MAX
   241  			val.Lsh(val, uint(a.evalAdditiveExpression().Uint64()))
   242  		case ">>":
   243  			a.rem = a.rem[2:]
   244  			// TODO: might be undefined if > UINT64_MAX
   245  			val.Rsh(val, uint(a.evalAdditiveExpression().Uint64()))
   246  		default:
   247  			return val
   248  		}
   249  	}
   250  }
   251  
   252  // RelationalExpression ::= ShiftExpression RelationalExpression2
   253  // RelationalExpression2 ::= RelationalOperator RelationalExpression |
   254  // RelationalOperator ::= '<' | '>' | '<=' | '>='
   255  func (a *Arithmetic) evalRelationalExpression() *big.Int {
   256  	val := a.evalShiftExpression()
   257  	for {
   258  		a.evalSpaces()
   259  		switch {
   260  		case a.rem[:2] == "<=":
   261  			a.rem = a.rem[2:]
   262  			val = asBig(val.Cmp(a.evalShiftExpression()) <= 0)
   263  		case a.rem[:2] == ">=":
   264  			a.rem = a.rem[2:]
   265  			val = asBig(val.Cmp(a.evalShiftExpression()) >= 0)
   266  		case a.rem[0] == '<' && a.rem[1] != '<':
   267  			a.rem = a.rem[2:]
   268  			val = asBig(val.Cmp(a.evalShiftExpression()) < 0)
   269  		case a.rem[0] == '>' && a.rem[1] != '>':
   270  			a.rem = a.rem[2:]
   271  			val = asBig(val.Cmp(a.evalShiftExpression()) > 0)
   272  		default:
   273  			return val
   274  		}
   275  	}
   276  }
   277  
   278  // EqualityExpression ::= RelationalExpression EqualityExpression2
   279  // EqualityExpression2 ::= EqualityOperator EqualityExpression
   280  // EqualityOperator ::= '==' | '!='
   281  func (a *Arithmetic) evalEqualityExpression() *big.Int {
   282  	val := a.evalRelationalExpression()
   283  	for {
   284  		a.evalSpaces()
   285  		switch a.rem[:2] {
   286  		case "==":
   287  			a.rem = a.rem[2:]
   288  			val = asBig(val.Cmp(a.evalRelationalExpression()) == 0)
   289  		case "!=":
   290  			a.rem = a.rem[2:]
   291  			val = asBig(val.Cmp(a.evalRelationalExpression()) != 0)
   292  		default:
   293  			return val
   294  		}
   295  	}
   296  }
   297  
   298  // ANDExpression ::= EqualityExpression AndExpression2
   299  // ANDExpression2 ::= '&' ANDExpression |
   300  func (a *Arithmetic) evalANDExpression() *big.Int {
   301  	val := a.evalEqualityExpression()
   302  	for {
   303  		a.evalSpaces()
   304  		switch {
   305  		case a.rem[0] == '&' && a.rem[1] != '&':
   306  			a.rem = a.rem[2:]
   307  			val.And(val, bigBool(a.evalEqualityExpression()))
   308  		default:
   309  			return val
   310  		}
   311  	}
   312  }
   313  
   314  // ExclusiveORExpression ::= ANDExpression ExclusiveORExpression2
   315  // ExclusiveORExpression2 ::= '^' ExclusiveORExpression |
   316  func (a *Arithmetic) evalExclusiveORExpression() *big.Int {
   317  	val := a.evalANDExpression()
   318  	for {
   319  		a.evalSpaces()
   320  		switch a.rem[0] {
   321  		case '^':
   322  			a.rem = a.rem[1:]
   323  			val.Xor(val, bigBool(a.evalANDExpression()))
   324  		default:
   325  			return val
   326  		}
   327  	}
   328  }
   329  
   330  // InclusiveORExpression ::= ExclusiveORExpression InclusiveORExpression2
   331  // InclusiveORExpression2 ::= '|' InclusiveORExpression |
   332  func (a *Arithmetic) evalInclusiveORExpression() *big.Int {
   333  	val := a.evalExclusiveORExpression()
   334  	for {
   335  		a.evalSpaces()
   336  		switch {
   337  		case a.rem[0] == '|' && a.rem[1] != '|':
   338  			a.rem = a.rem[2:]
   339  			val.Or(val, bigBool(a.evalExclusiveORExpression()))
   340  		default:
   341  			return val
   342  		}
   343  	}
   344  }
   345  
   346  // LogicalANDExpression ::= InclusiveORExpression LogicalANDExpression2
   347  // LogicalANDExpression2 ::= '&&' InclusiveORExpression |
   348  func (a *Arithmetic) evalLogicalANDExpression() *big.Int {
   349  	val := a.evalInclusiveORExpression()
   350  	for {
   351  		a.evalSpaces()
   352  		switch a.rem[:2] {
   353  		case "&&":
   354  			a.rem = a.rem[2:]
   355  			val.And(bigBool(val), bigBool(a.evalInclusiveORExpression()))
   356  		default:
   357  			return val
   358  		}
   359  	}
   360  }
   361  
   362  // LogicalORExpression ::= LogicalANDExpression LogicalORExpression2
   363  // LogicalORExpression2 ::= '||' LogicalORExpression |
   364  func (a *Arithmetic) evalLogicalORExpression() *big.Int {
   365  	val := a.evalLogicalANDExpression()
   366  	for {
   367  		a.evalSpaces()
   368  		switch a.rem[:2] {
   369  		case "||":
   370  			a.rem = a.rem[2:]
   371  			val.Or(bigBool(val), bigBool(a.evalLogicalANDExpression()))
   372  		default:
   373  			return val
   374  		}
   375  	}
   376  }
   377  
   378  // ConditionalExpression ::= LogicalORExpression ConditionalExpression2
   379  // ConditionalExpression2 ::= '?' AssignmentExpression ':' ConditionalExpression |
   380  func (a *Arithmetic) evalConditionalExpression() *big.Int {
   381  	val := a.evalLogicalORExpression()
   382  
   383  	a.evalSpaces()
   384  	if a.rem[0] != '?' {
   385  		return val
   386  	}
   387  	a.rem = a.rem[1:]
   388  	trueVal := a.evalAssignmentExpression()
   389  
   390  	if a.rem[0] != ':' {
   391  		panic("Bad conditional expression")
   392  	}
   393  	a.rem = a.rem[1:]
   394  	falseVal := a.evalConditionalExpression()
   395  
   396  	if val.BitLen() == 0 {
   397  		return falseVal
   398  	}
   399  	return trueVal
   400  }
   401  
   402  // AssignmentExpression ::= ConditionalExpression
   403  // AssignmentExpression ::= Identifier AssignmentOperator AssignmentExpression
   404  // AssignmentOperator ::= '=' | '*=' | '/=' | '%=' | '+=' | '-=' | '<<='
   405  //                      | '>>=' | '&=' | '^=' | '|='
   406  func (a *Arithmetic) evalAssignmentExpression() *big.Int {
   407  	val := a.evalConditionalExpression()
   408  	// TODO: assignment
   409  	// TODO: some other follow sets need to be updated
   410  	return val
   411  }