src.elv.sh@v0.21.0-dev.0.20240515223629-06979efb9a2a/pkg/mods/math/math.go (about)

     1  // Package math exposes functionality from Go's math package as an elvish
     2  // module.
     3  package math
     4  
     5  import (
     6  	"math"
     7  	"math/big"
     8  
     9  	"src.elv.sh/pkg/eval"
    10  	"src.elv.sh/pkg/eval/errs"
    11  	"src.elv.sh/pkg/eval/vals"
    12  	"src.elv.sh/pkg/eval/vars"
    13  )
    14  
    15  // Ns is the namespace for the math: module.
    16  var Ns = eval.BuildNsNamed("math").
    17  	AddVars(map[string]vars.Var{
    18  		"e":  vars.NewReadOnly(math.E),
    19  		"pi": vars.NewReadOnly(math.Pi),
    20  	}).
    21  	AddGoFns(map[string]any{
    22  		"abs":           abs,
    23  		"acos":          math.Acos,
    24  		"acosh":         math.Acosh,
    25  		"asin":          math.Asin,
    26  		"asinh":         math.Asinh,
    27  		"atan":          math.Atan,
    28  		"atan2":         math.Atan2,
    29  		"atanh":         math.Atanh,
    30  		"ceil":          ceil,
    31  		"cos":           math.Cos,
    32  		"cosh":          math.Cosh,
    33  		"floor":         floor,
    34  		"is-inf":        isInf,
    35  		"is-nan":        isNaN,
    36  		"log":           math.Log,
    37  		"log10":         math.Log10,
    38  		"log2":          math.Log2,
    39  		"max":           max,
    40  		"min":           min,
    41  		"pow":           pow,
    42  		"round":         round,
    43  		"round-to-even": roundToEven,
    44  		"sin":           math.Sin,
    45  		"sinh":          math.Sinh,
    46  		"sqrt":          math.Sqrt,
    47  		"tan":           math.Tan,
    48  		"tanh":          math.Tanh,
    49  		"trunc":         trunc,
    50  	}).Ns()
    51  
    52  const (
    53  	maxInt = int(^uint(0) >> 1)
    54  	minInt = -maxInt - 1
    55  )
    56  
    57  var absMinInt = new(big.Int).Abs(big.NewInt(int64(minInt)))
    58  
    59  func abs(n vals.Num) vals.Num {
    60  	switch n := n.(type) {
    61  	case int:
    62  		if n < 0 {
    63  			if n == minInt {
    64  				return absMinInt
    65  			}
    66  			return -n
    67  		}
    68  		return n
    69  	case *big.Int:
    70  		if n.Sign() < 0 {
    71  			return new(big.Int).Abs(n)
    72  		}
    73  		return n
    74  	case *big.Rat:
    75  		if n.Sign() < 0 {
    76  			return new(big.Rat).Abs(n)
    77  		}
    78  		return n
    79  	case float64:
    80  		return math.Abs(n)
    81  	default:
    82  		panic("unreachable")
    83  	}
    84  }
    85  
    86  var (
    87  	big1 = big.NewInt(1)
    88  	big2 = big.NewInt(2)
    89  )
    90  
    91  func ceil(n vals.Num) vals.Num {
    92  	return integerize(n,
    93  		math.Ceil,
    94  		func(n *big.Rat) *big.Int {
    95  			q := new(big.Int).Div(n.Num(), n.Denom())
    96  			return q.Add(q, big1)
    97  		})
    98  }
    99  
   100  func floor(n vals.Num) vals.Num {
   101  	return integerize(n,
   102  		math.Floor,
   103  		func(n *big.Rat) *big.Int {
   104  			return new(big.Int).Div(n.Num(), n.Denom())
   105  		})
   106  }
   107  
   108  type isInfOpts struct{ Sign int }
   109  
   110  func (opts *isInfOpts) SetDefaultOptions() { opts.Sign = 0 }
   111  
   112  func isInf(opts isInfOpts, n vals.Num) bool {
   113  	if f, ok := n.(float64); ok {
   114  		return math.IsInf(f, opts.Sign)
   115  	}
   116  	return false
   117  }
   118  
   119  func isNaN(n vals.Num) bool {
   120  	if f, ok := n.(float64); ok {
   121  		return math.IsNaN(f)
   122  	}
   123  	return false
   124  }
   125  
   126  func max(rawNums ...vals.Num) (vals.Num, error) {
   127  	if len(rawNums) == 0 {
   128  		return nil, errs.ArityMismatch{What: "arguments", ValidLow: 1, ValidHigh: -1, Actual: 0}
   129  	}
   130  	nums := vals.UnifyNums(rawNums, 0)
   131  	switch nums := nums.(type) {
   132  	case []int:
   133  		n := nums[0]
   134  		for i := 1; i < len(nums); i++ {
   135  			if n < nums[i] {
   136  				n = nums[i]
   137  			}
   138  		}
   139  		return n, nil
   140  	case []*big.Int:
   141  		n := nums[0]
   142  		for i := 1; i < len(nums); i++ {
   143  			if n.Cmp(nums[i]) < 0 {
   144  				n = nums[i]
   145  			}
   146  		}
   147  		return n, nil
   148  	case []*big.Rat:
   149  		n := nums[0]
   150  		for i := 1; i < len(nums); i++ {
   151  			if n.Cmp(nums[i]) < 0 {
   152  				n = nums[i]
   153  			}
   154  		}
   155  		return n, nil
   156  	case []float64:
   157  		n := nums[0]
   158  		for i := 1; i < len(nums); i++ {
   159  			n = math.Max(n, nums[i])
   160  		}
   161  		return n, nil
   162  	default:
   163  		panic("unreachable")
   164  	}
   165  }
   166  
   167  func min(rawNums ...vals.Num) (vals.Num, error) {
   168  	if len(rawNums) == 0 {
   169  		return nil, errs.ArityMismatch{What: "arguments", ValidLow: 1, ValidHigh: -1, Actual: 0}
   170  	}
   171  	nums := vals.UnifyNums(rawNums, 0)
   172  	switch nums := nums.(type) {
   173  	case []int:
   174  		n := nums[0]
   175  		for i := 1; i < len(nums); i++ {
   176  			if n > nums[i] {
   177  				n = nums[i]
   178  			}
   179  		}
   180  		return n, nil
   181  	case []*big.Int:
   182  		n := nums[0]
   183  		for i := 1; i < len(nums); i++ {
   184  			if n.Cmp(nums[i]) > 0 {
   185  				n = nums[i]
   186  			}
   187  		}
   188  		return n, nil
   189  	case []*big.Rat:
   190  		n := nums[0]
   191  		for i := 1; i < len(nums); i++ {
   192  			if n.Cmp(nums[i]) > 0 {
   193  				n = nums[i]
   194  			}
   195  		}
   196  		return n, nil
   197  	case []float64:
   198  		n := nums[0]
   199  		for i := 1; i < len(nums); i++ {
   200  			n = math.Min(n, nums[i])
   201  		}
   202  		return n, nil
   203  	default:
   204  		panic("unreachable")
   205  	}
   206  }
   207  
   208  func pow(base, exp vals.Num) vals.Num {
   209  	if isExact(base) && isExactInt(exp) {
   210  		// Produce exact result
   211  		switch exp {
   212  		case 0:
   213  			return 1
   214  		case 1:
   215  			return base
   216  		case -1:
   217  			return new(big.Rat).Inv(vals.PromoteToBigRat(base))
   218  		}
   219  		exp := vals.PromoteToBigInt(exp)
   220  		if isExactInt(base) && exp.Sign() > 0 {
   221  			base := vals.PromoteToBigInt(base)
   222  			return new(big.Int).Exp(base, exp, nil)
   223  		}
   224  		base := vals.PromoteToBigRat(base)
   225  		if exp.Sign() < 0 {
   226  			base = new(big.Rat).Inv(base)
   227  			exp = new(big.Int).Neg(exp)
   228  		}
   229  		return new(big.Rat).SetFrac(
   230  			new(big.Int).Exp(base.Num(), exp, nil),
   231  			new(big.Int).Exp(base.Denom(), exp, nil))
   232  	}
   233  
   234  	// Produce inexact result
   235  	basef := vals.ConvertToFloat64(base)
   236  	expf := vals.ConvertToFloat64(exp)
   237  	return math.Pow(basef, expf)
   238  }
   239  
   240  func isExact(n vals.Num) bool {
   241  	switch n.(type) {
   242  	case int, *big.Int, *big.Rat:
   243  		return true
   244  	default:
   245  		return false
   246  	}
   247  }
   248  
   249  func isExactInt(n vals.Num) bool {
   250  	switch n.(type) {
   251  	case int, *big.Int:
   252  		return true
   253  	default:
   254  		return false
   255  	}
   256  }
   257  
   258  func round(n vals.Num) vals.Num {
   259  	return integerize(n,
   260  		math.Round,
   261  		func(n *big.Rat) *big.Int {
   262  			q, m := new(big.Int).QuoRem(n.Num(), n.Denom(), new(big.Int))
   263  			m = m.Mul(m, big2)
   264  			if m.CmpAbs(n.Denom()) < 0 {
   265  				return q
   266  			}
   267  			if n.Sign() < 0 {
   268  				return q.Sub(q, big1)
   269  			}
   270  			return q.Add(q, big1)
   271  		})
   272  }
   273  
   274  func roundToEven(n vals.Num) vals.Num {
   275  	return integerize(n,
   276  		math.RoundToEven,
   277  		func(n *big.Rat) *big.Int {
   278  			q, m := new(big.Int).QuoRem(n.Num(), n.Denom(), new(big.Int))
   279  			m = m.Mul(m, big2)
   280  			if diff := m.CmpAbs(n.Denom()); diff < 0 || diff == 0 && q.Bit(0) == 0 {
   281  				return q
   282  			}
   283  			if n.Sign() < 0 {
   284  				return q.Sub(q, big1)
   285  			}
   286  			return q.Add(q, big1)
   287  		})
   288  }
   289  
   290  func trunc(n vals.Num) vals.Num {
   291  	return integerize(n,
   292  		math.Trunc,
   293  		func(n *big.Rat) *big.Int {
   294  			return new(big.Int).Quo(n.Num(), n.Denom())
   295  		})
   296  }
   297  
   298  func integerize(n vals.Num, fnFloat func(float64) float64, fnRat func(*big.Rat) *big.Int) vals.Num {
   299  	switch n := n.(type) {
   300  	case int:
   301  		return n
   302  	case *big.Int:
   303  		return n
   304  	case *big.Rat:
   305  		if n.Denom().IsInt64() && n.Denom().Int64() == 1 {
   306  			// Elvish always normalizes *big.Rat with a denominator of 1 to
   307  			// *big.Int, but we still try to be defensive here.
   308  			return n.Num()
   309  		}
   310  		return fnRat(n)
   311  	case float64:
   312  		return fnFloat(n)
   313  	default:
   314  		panic("unreachable")
   315  	}
   316  }