src.elv.sh@v0.21.0-dev.0.20240515223629-06979efb9a2a/pkg/eval/builtin_fn_num.go (about)

     1  package eval
     2  
     3  import (
     4  	"fmt"
     5  	"math"
     6  	"math/big"
     7  	"math/rand"
     8  	"strconv"
     9  	"sync"
    10  	"time"
    11  
    12  	"src.elv.sh/pkg/eval/errs"
    13  	"src.elv.sh/pkg/eval/vals"
    14  )
    15  
    16  // Numerical operations.
    17  
    18  func init() {
    19  	addBuiltinFns(map[string]any{
    20  		// Constructor
    21  		"num":         num,
    22  		"exact-num":   exactNum,
    23  		"inexact-num": inexactNum,
    24  
    25  		// Comparison
    26  		"<":  lt,
    27  		"<=": le,
    28  		"==": eqNum,
    29  		"!=": ne,
    30  		">":  gt,
    31  		">=": ge,
    32  
    33  		// Arithmetic
    34  		"+": add,
    35  		"-": sub,
    36  		"*": mul,
    37  		// Also handles cd /
    38  		"/": slash,
    39  		"%": rem,
    40  
    41  		// Random
    42  		"rand":      randFn,
    43  		"randint":   randint,
    44  		"-randseed": randseed,
    45  
    46  		"range": rangeFn,
    47  	})
    48  
    49  }
    50  
    51  func num(n vals.Num) vals.Num {
    52  	// Conversion is actually handled in vals/conversion.go.
    53  	return n
    54  }
    55  
    56  func exactNum(n vals.Num) (vals.Num, error) {
    57  	if f, ok := n.(float64); ok {
    58  		r := new(big.Rat).SetFloat64(f)
    59  		if r == nil {
    60  			return nil, errs.BadValue{What: "argument here",
    61  				Valid: "finite float", Actual: vals.ToString(f)}
    62  		}
    63  		return r, nil
    64  	}
    65  	return n, nil
    66  }
    67  
    68  func inexactNum(f float64) float64 {
    69  	return f
    70  }
    71  
    72  func lt(nums ...vals.Num) bool {
    73  	return chainCompareNums(nums,
    74  		func(a, b int) bool { return a < b },
    75  		func(a, b *big.Int) bool { return a.Cmp(b) < 0 },
    76  		func(a, b *big.Rat) bool { return a.Cmp(b) < 0 },
    77  		func(a, b float64) bool { return a < b })
    78  
    79  }
    80  
    81  func le(nums ...vals.Num) bool {
    82  	return chainCompareNums(nums,
    83  		func(a, b int) bool { return a <= b },
    84  		func(a, b *big.Int) bool { return a.Cmp(b) <= 0 },
    85  		func(a, b *big.Rat) bool { return a.Cmp(b) <= 0 },
    86  		func(a, b float64) bool { return a <= b })
    87  }
    88  
    89  func eqNum(nums ...vals.Num) bool {
    90  	return chainCompareNums(nums,
    91  		func(a, b int) bool { return a == b },
    92  		func(a, b *big.Int) bool { return a.Cmp(b) == 0 },
    93  		func(a, b *big.Rat) bool { return a.Cmp(b) == 0 },
    94  		func(a, b float64) bool { return a == b })
    95  }
    96  
    97  func ne(a, b vals.Num) bool {
    98  	return unifyNums2And(a, b,
    99  		func(a, b int) bool { return a != b },
   100  		func(a, b *big.Int) bool { return a.Cmp(b) != 0 },
   101  		func(a, b *big.Rat) bool { return a.Cmp(b) != 0 },
   102  		func(a, b float64) bool { return a != b })
   103  }
   104  
   105  func gt(nums ...vals.Num) bool {
   106  	return chainCompareNums(nums,
   107  		func(a, b int) bool { return a > b },
   108  		func(a, b *big.Int) bool { return a.Cmp(b) > 0 },
   109  		func(a, b *big.Rat) bool { return a.Cmp(b) > 0 },
   110  		func(a, b float64) bool { return a > b })
   111  }
   112  
   113  func ge(nums ...vals.Num) bool {
   114  	return chainCompareNums(nums,
   115  		func(a, b int) bool { return a >= b },
   116  		func(a, b *big.Int) bool { return a.Cmp(b) >= 0 },
   117  		func(a, b *big.Rat) bool { return a.Cmp(b) >= 0 },
   118  		func(a, b float64) bool { return a >= b })
   119  }
   120  
   121  func chainCompareNums(nums []vals.Num,
   122  	pInt func(a, b int) bool, pBigInt func(a, b *big.Int) bool,
   123  	pBigRat func(a, b *big.Rat) bool, pFloat64 func(a, b float64) bool) bool {
   124  
   125  	for i := 0; i < len(nums)-1; i++ {
   126  		r := unifyNums2And(nums[i], nums[i+1], pInt, pBigInt, pBigRat, pFloat64)
   127  		if !r {
   128  			return false
   129  		}
   130  	}
   131  	return true
   132  }
   133  
   134  func unifyNums2And[T any](a, b vals.Num,
   135  	fInt func(a, b int) T, fBigInt func(a, b *big.Int) T,
   136  	fBigRat func(a, b *big.Rat) T, fFloat64 func(a, b float64) T) T {
   137  
   138  	a, b = vals.UnifyNums2(a, b, 0)
   139  	switch a := a.(type) {
   140  	case int:
   141  		return fInt(a, b.(int))
   142  	case *big.Int:
   143  		return fBigInt(a, b.(*big.Int))
   144  	case *big.Rat:
   145  		return fBigRat(a, b.(*big.Rat))
   146  	case float64:
   147  		return fFloat64(a, b.(float64))
   148  	default:
   149  		panic("unreachable")
   150  	}
   151  }
   152  
   153  func add(rawNums ...vals.Num) vals.Num {
   154  	nums := vals.UnifyNums(rawNums, vals.BigInt)
   155  	switch nums := nums.(type) {
   156  	case []*big.Int:
   157  		acc := big.NewInt(0)
   158  		for _, num := range nums {
   159  			acc.Add(acc, num)
   160  		}
   161  		return vals.NormalizeBigInt(acc)
   162  	case []*big.Rat:
   163  		acc := big.NewRat(0, 1)
   164  		for _, num := range nums {
   165  			acc.Add(acc, num)
   166  		}
   167  		return vals.NormalizeBigRat(acc)
   168  	case []float64:
   169  		acc := float64(0)
   170  		for _, num := range nums {
   171  			acc += num
   172  		}
   173  		return acc
   174  	default:
   175  		panic("unreachable")
   176  	}
   177  }
   178  
   179  func sub(rawNums ...vals.Num) (vals.Num, error) {
   180  	if len(rawNums) == 0 {
   181  		return nil, errs.ArityMismatch{What: "arguments", ValidLow: 1, ValidHigh: -1, Actual: 0}
   182  	}
   183  
   184  	nums := vals.UnifyNums(rawNums, vals.BigInt)
   185  	switch nums := nums.(type) {
   186  	case []*big.Int:
   187  		acc := &big.Int{}
   188  		if len(nums) == 1 {
   189  			acc.Neg(nums[0])
   190  			return acc, nil
   191  		}
   192  		acc.Set(nums[0])
   193  		for _, num := range nums[1:] {
   194  			acc.Sub(acc, num)
   195  		}
   196  		return acc, nil
   197  	case []*big.Rat:
   198  		acc := &big.Rat{}
   199  		if len(nums) == 1 {
   200  			acc.Neg(nums[0])
   201  			return acc, nil
   202  		}
   203  		acc.Set(nums[0])
   204  		for _, num := range nums[1:] {
   205  			acc.Sub(acc, num)
   206  		}
   207  		return acc, nil
   208  	case []float64:
   209  		if len(nums) == 1 {
   210  			return -nums[0], nil
   211  		}
   212  		acc := nums[0]
   213  		for _, num := range nums[1:] {
   214  			acc -= num
   215  		}
   216  		return acc, nil
   217  	default:
   218  		panic("unreachable")
   219  	}
   220  }
   221  
   222  func mul(rawNums ...vals.Num) vals.Num {
   223  	hasExact0 := false
   224  	hasInf := false
   225  	for _, num := range rawNums {
   226  		if num == 0 {
   227  			hasExact0 = true
   228  		}
   229  		if f, ok := num.(float64); ok && math.IsInf(f, 0) {
   230  			hasInf = true
   231  			break
   232  		}
   233  	}
   234  	if hasExact0 && !hasInf {
   235  		return 0
   236  	}
   237  
   238  	nums := vals.UnifyNums(rawNums, vals.BigInt)
   239  	switch nums := nums.(type) {
   240  	case []*big.Int:
   241  		acc := big.NewInt(1)
   242  		for _, num := range nums {
   243  			acc.Mul(acc, num)
   244  		}
   245  		return vals.NormalizeBigInt(acc)
   246  	case []*big.Rat:
   247  		acc := big.NewRat(1, 1)
   248  		for _, num := range nums {
   249  			acc.Mul(acc, num)
   250  		}
   251  		return vals.NormalizeBigRat(acc)
   252  	case []float64:
   253  		acc := float64(1)
   254  		for _, num := range nums {
   255  			acc *= num
   256  		}
   257  		return acc
   258  	default:
   259  		panic("unreachable")
   260  	}
   261  }
   262  
   263  func slash(fm *Frame, args ...vals.Num) error {
   264  	if len(args) == 0 {
   265  		fm.Deprecate("implicit cd is deprecated; use cd or location mode instead", fm.traceback.Head, 21)
   266  		// cd /
   267  		return fm.Evaler.Chdir("/")
   268  	}
   269  	// Division
   270  	result, err := div(args...)
   271  	if err != nil {
   272  		return err
   273  	}
   274  	return fm.ValueOutput().Put(vals.FromGo(result))
   275  }
   276  
   277  // ErrDivideByZero is thrown when attempting to divide by zero.
   278  var ErrDivideByZero = errs.BadValue{
   279  	What: "divisor", Valid: "number other than exact 0", Actual: "exact 0"}
   280  
   281  func div(rawNums ...vals.Num) (vals.Num, error) {
   282  	for _, num := range rawNums[1:] {
   283  		if num == 0 {
   284  			return nil, ErrDivideByZero
   285  		}
   286  	}
   287  	if rawNums[0] == 0 {
   288  		return 0, nil
   289  	}
   290  	nums := vals.UnifyNums(rawNums, vals.BigRat)
   291  	switch nums := nums.(type) {
   292  	case []*big.Rat:
   293  		acc := &big.Rat{}
   294  		acc.Set(nums[0])
   295  		if len(nums) == 1 {
   296  			acc.Inv(acc)
   297  			return acc, nil
   298  		}
   299  		for _, num := range nums[1:] {
   300  			acc.Quo(acc, num)
   301  		}
   302  		return acc, nil
   303  	case []float64:
   304  		acc := nums[0]
   305  		if len(nums) == 1 {
   306  			return 1 / acc, nil
   307  		}
   308  		for _, num := range nums[1:] {
   309  			acc /= num
   310  		}
   311  		return acc, nil
   312  	default:
   313  		panic("unreachable")
   314  	}
   315  }
   316  
   317  func rem(a, b vals.Num) (vals.Num, error) {
   318  	if err := checkExactIntArg(a); err != nil {
   319  		return 0, err
   320  	}
   321  	if err := checkExactIntArg(b); err != nil {
   322  		return 0, err
   323  	}
   324  	if b == 0 {
   325  		return 0, ErrDivideByZero
   326  	}
   327  	if a, ok := a.(int); ok {
   328  		if b, ok := b.(int); ok {
   329  			return a % b, nil
   330  		}
   331  	}
   332  	return new(big.Int).Rem(vals.PromoteToBigInt(a), vals.PromoteToBigInt(b)), nil
   333  }
   334  
   335  func checkExactIntArg(a vals.Num) error {
   336  	switch a.(type) {
   337  	case int, *big.Int:
   338  		return nil
   339  	default:
   340  		return errs.BadValue{What: "argument", Valid: "exact integer", Actual: vals.ReprPlain(a)}
   341  	}
   342  }
   343  
   344  func randFn() float64 { return withRand((*rand.Rand).Float64) }
   345  
   346  func randint(args ...vals.Num) (vals.Num, error) {
   347  	if len(args) == 0 || len(args) > 2 {
   348  		return -1, errs.ArityMismatch{What: "arguments",
   349  			ValidLow: 1, ValidHigh: 2, Actual: len(args)}
   350  	}
   351  	allInt := true
   352  	for _, arg := range args {
   353  		if err := checkExactIntArg(arg); err != nil {
   354  			return nil, err
   355  		}
   356  		if _, ok := arg.(*big.Int); ok {
   357  			allInt = false
   358  		}
   359  	}
   360  	if allInt {
   361  		var low, high int
   362  		if len(args) == 1 {
   363  			low, high = 0, args[0].(int)
   364  		} else {
   365  			low, high = args[0].(int), args[1].(int)
   366  		}
   367  		if high <= low {
   368  			return 0, errs.BadValue{What: "high value",
   369  				Valid: fmt.Sprint("larger than ", low), Actual: strconv.Itoa(high)}
   370  		}
   371  		x := withRand(func(r *rand.Rand) int { return r.Intn(high - low) })
   372  		return low + x, nil
   373  	}
   374  	var low, high *big.Int
   375  	if len(args) == 1 {
   376  		low, high = big.NewInt(0), args[0].(*big.Int)
   377  	} else {
   378  		low, high = args[0].(*big.Int), args[1].(*big.Int)
   379  	}
   380  	if high.Cmp(low) <= 0 {
   381  		return 0, errs.BadValue{What: "high value",
   382  			Valid: fmt.Sprint("larger than ", low), Actual: high.String()}
   383  	}
   384  	if low.Sign() == 0 {
   385  		return withRand(func(r *rand.Rand) *big.Int {
   386  			return new(big.Int).Rand(r, high)
   387  		}), nil
   388  	} else {
   389  		diff := new(big.Int).Sub(high, low)
   390  		z := withRand(func(r *rand.Rand) *big.Int {
   391  			return new(big.Int).Rand(r, diff)
   392  		})
   393  		z.Add(z, low)
   394  		return z, nil
   395  	}
   396  }
   397  
   398  func randseed(x int) {
   399  	withRandNullary(func(r *rand.Rand) { r.Seed(int64(x)) })
   400  }
   401  
   402  var (
   403  	randMutex    sync.Mutex
   404  	randInstance *rand.Rand
   405  )
   406  
   407  func withRand[T any](f func(*rand.Rand) T) T {
   408  	randMutex.Lock()
   409  	defer randMutex.Unlock()
   410  	if randInstance == nil {
   411  		randInstance = rand.New(rand.NewSource(time.Now().UnixNano()))
   412  	}
   413  	return f(randInstance)
   414  }
   415  
   416  func withRandNullary(f func(*rand.Rand)) {
   417  	withRand(func(r *rand.Rand) struct{} {
   418  		f(r)
   419  		return struct{}{}
   420  	})
   421  }
   422  
   423  type rangeOpts struct{ Step vals.Num }
   424  
   425  // TODO: The default value can only be used implicitly; passing "range
   426  // &step=nil" results in an error.
   427  func (o *rangeOpts) SetDefaultOptions() { o.Step = nil }
   428  
   429  func rangeFn(fm *Frame, opts rangeOpts, args ...vals.Num) error {
   430  	var rawNums []vals.Num
   431  	switch len(args) {
   432  	case 1:
   433  		rawNums = []vals.Num{0, args[0]}
   434  	case 2:
   435  		rawNums = []vals.Num{args[0], args[1]}
   436  	default:
   437  		return errs.ArityMismatch{What: "arguments", ValidLow: 1, ValidHigh: 2, Actual: len(args)}
   438  	}
   439  	if opts.Step != nil {
   440  		rawNums = append(rawNums, opts.Step)
   441  	}
   442  	nums := vals.UnifyNums(rawNums, vals.Int)
   443  
   444  	out := fm.ValueOutput()
   445  
   446  	switch nums := nums.(type) {
   447  	case []int:
   448  		return rangeBuiltinNum(nums, out)
   449  	case []*big.Int:
   450  		return rangeBigNum(nums, out, bigIntDesc)
   451  	case []*big.Rat:
   452  		return rangeBigNum(nums, out, bigRatDesc)
   453  	case []float64:
   454  		return rangeBuiltinNum(nums, out)
   455  	default:
   456  		panic("unreachable")
   457  	}
   458  }
   459  
   460  type builtinNum interface{ int | float64 }
   461  
   462  func rangeBuiltinNum[T builtinNum](nums []T, out ValueOutput) error {
   463  	start, end := nums[0], nums[1]
   464  	var step T
   465  	if start <= end {
   466  		if len(nums) == 3 {
   467  			step = nums[2]
   468  			if step <= 0 {
   469  				return errs.BadValue{
   470  					What: "step", Valid: "positive", Actual: vals.ToString(step)}
   471  			}
   472  		} else {
   473  			step = 1
   474  		}
   475  		for cur := start; cur < end; cur += step {
   476  			err := out.Put(vals.FromGo(cur))
   477  			if err != nil {
   478  				return err
   479  			}
   480  			if cur+step <= cur {
   481  				break
   482  			}
   483  		}
   484  	} else {
   485  		if len(nums) == 3 {
   486  			step = nums[2]
   487  			if step >= 0 {
   488  				return errs.BadValue{
   489  					What: "step", Valid: "negative", Actual: vals.ToString(step)}
   490  			}
   491  		} else {
   492  			step = -1
   493  		}
   494  		for cur := start; cur > end; cur += step {
   495  			err := out.Put(vals.FromGo(cur))
   496  			if err != nil {
   497  				return err
   498  			}
   499  			if cur+step >= cur {
   500  				break
   501  			}
   502  		}
   503  	}
   504  	return nil
   505  }
   506  
   507  type bigNum[T any] interface {
   508  	Cmp(T) int
   509  	Sign() int
   510  	Add(T, T) T
   511  }
   512  
   513  type bigNumDesc[T any] struct {
   514  	one     T
   515  	negOne  T
   516  	newZero func() T
   517  }
   518  
   519  var bigIntDesc = bigNumDesc[*big.Int]{
   520  	one:     big.NewInt(1),
   521  	negOne:  big.NewInt(-1),
   522  	newZero: func() *big.Int { return &big.Int{} },
   523  }
   524  
   525  var bigRatDesc = bigNumDesc[*big.Rat]{
   526  	one:     big.NewRat(1, 1),
   527  	negOne:  big.NewRat(-1, 1),
   528  	newZero: func() *big.Rat { return &big.Rat{} },
   529  }
   530  
   531  func rangeBigNum[T bigNum[T]](nums []T, out ValueOutput, d bigNumDesc[T]) error {
   532  	start, end := nums[0], nums[1]
   533  	var step T
   534  	if start.Cmp(end) <= 0 {
   535  		if len(nums) == 3 {
   536  			step = nums[2]
   537  			if step.Sign() <= 0 {
   538  				return errs.BadValue{
   539  					What: "step", Valid: "positive", Actual: vals.ToString(step)}
   540  			}
   541  		} else {
   542  			step = d.one
   543  		}
   544  		var cur, next T
   545  		for cur = start; cur.Cmp(end) < 0; cur = next {
   546  			err := out.Put(vals.FromGo(cur))
   547  			if err != nil {
   548  				return err
   549  			}
   550  			next = d.newZero()
   551  			next.Add(cur, step)
   552  			cur = next
   553  		}
   554  	} else {
   555  		if len(nums) == 3 {
   556  			step = nums[2]
   557  			if step.Sign() >= 0 {
   558  				return errs.BadValue{
   559  					What: "step", Valid: "negative", Actual: vals.ToString(step)}
   560  			}
   561  		} else {
   562  			step = d.negOne
   563  		}
   564  		var cur, next T
   565  		for cur = start; cur.Cmp(end) > 0; cur = next {
   566  			err := out.Put(vals.FromGo(cur))
   567  			if err != nil {
   568  				return err
   569  			}
   570  			next = d.newZero()
   571  			next.Add(cur, step)
   572  			cur = next
   573  		}
   574  	}
   575  	return nil
   576  }