src.elv.sh@v0.21.0-dev.0.20240515223629-06979efb9a2a/pkg/eval/builtin_fn_num.go (about) 1 package eval 2 3 import ( 4 "fmt" 5 "math" 6 "math/big" 7 "math/rand" 8 "strconv" 9 "sync" 10 "time" 11 12 "src.elv.sh/pkg/eval/errs" 13 "src.elv.sh/pkg/eval/vals" 14 ) 15 16 // Numerical operations. 17 18 func init() { 19 addBuiltinFns(map[string]any{ 20 // Constructor 21 "num": num, 22 "exact-num": exactNum, 23 "inexact-num": inexactNum, 24 25 // Comparison 26 "<": lt, 27 "<=": le, 28 "==": eqNum, 29 "!=": ne, 30 ">": gt, 31 ">=": ge, 32 33 // Arithmetic 34 "+": add, 35 "-": sub, 36 "*": mul, 37 // Also handles cd / 38 "/": slash, 39 "%": rem, 40 41 // Random 42 "rand": randFn, 43 "randint": randint, 44 "-randseed": randseed, 45 46 "range": rangeFn, 47 }) 48 49 } 50 51 func num(n vals.Num) vals.Num { 52 // Conversion is actually handled in vals/conversion.go. 53 return n 54 } 55 56 func exactNum(n vals.Num) (vals.Num, error) { 57 if f, ok := n.(float64); ok { 58 r := new(big.Rat).SetFloat64(f) 59 if r == nil { 60 return nil, errs.BadValue{What: "argument here", 61 Valid: "finite float", Actual: vals.ToString(f)} 62 } 63 return r, nil 64 } 65 return n, nil 66 } 67 68 func inexactNum(f float64) float64 { 69 return f 70 } 71 72 func lt(nums ...vals.Num) bool { 73 return chainCompareNums(nums, 74 func(a, b int) bool { return a < b }, 75 func(a, b *big.Int) bool { return a.Cmp(b) < 0 }, 76 func(a, b *big.Rat) bool { return a.Cmp(b) < 0 }, 77 func(a, b float64) bool { return a < b }) 78 79 } 80 81 func le(nums ...vals.Num) bool { 82 return chainCompareNums(nums, 83 func(a, b int) bool { return a <= b }, 84 func(a, b *big.Int) bool { return a.Cmp(b) <= 0 }, 85 func(a, b *big.Rat) bool { return a.Cmp(b) <= 0 }, 86 func(a, b float64) bool { return a <= b }) 87 } 88 89 func eqNum(nums ...vals.Num) bool { 90 return chainCompareNums(nums, 91 func(a, b int) bool { return a == b }, 92 func(a, b *big.Int) bool { return a.Cmp(b) == 0 }, 93 func(a, b *big.Rat) bool { return a.Cmp(b) == 0 }, 94 func(a, b float64) bool { return a == b }) 95 } 96 97 func ne(a, b vals.Num) bool { 98 return unifyNums2And(a, b, 99 func(a, b int) bool { return a != b }, 100 func(a, b *big.Int) bool { return a.Cmp(b) != 0 }, 101 func(a, b *big.Rat) bool { return a.Cmp(b) != 0 }, 102 func(a, b float64) bool { return a != b }) 103 } 104 105 func gt(nums ...vals.Num) bool { 106 return chainCompareNums(nums, 107 func(a, b int) bool { return a > b }, 108 func(a, b *big.Int) bool { return a.Cmp(b) > 0 }, 109 func(a, b *big.Rat) bool { return a.Cmp(b) > 0 }, 110 func(a, b float64) bool { return a > b }) 111 } 112 113 func ge(nums ...vals.Num) bool { 114 return chainCompareNums(nums, 115 func(a, b int) bool { return a >= b }, 116 func(a, b *big.Int) bool { return a.Cmp(b) >= 0 }, 117 func(a, b *big.Rat) bool { return a.Cmp(b) >= 0 }, 118 func(a, b float64) bool { return a >= b }) 119 } 120 121 func chainCompareNums(nums []vals.Num, 122 pInt func(a, b int) bool, pBigInt func(a, b *big.Int) bool, 123 pBigRat func(a, b *big.Rat) bool, pFloat64 func(a, b float64) bool) bool { 124 125 for i := 0; i < len(nums)-1; i++ { 126 r := unifyNums2And(nums[i], nums[i+1], pInt, pBigInt, pBigRat, pFloat64) 127 if !r { 128 return false 129 } 130 } 131 return true 132 } 133 134 func unifyNums2And[T any](a, b vals.Num, 135 fInt func(a, b int) T, fBigInt func(a, b *big.Int) T, 136 fBigRat func(a, b *big.Rat) T, fFloat64 func(a, b float64) T) T { 137 138 a, b = vals.UnifyNums2(a, b, 0) 139 switch a := a.(type) { 140 case int: 141 return fInt(a, b.(int)) 142 case *big.Int: 143 return fBigInt(a, b.(*big.Int)) 144 case *big.Rat: 145 return fBigRat(a, b.(*big.Rat)) 146 case float64: 147 return fFloat64(a, b.(float64)) 148 default: 149 panic("unreachable") 150 } 151 } 152 153 func add(rawNums ...vals.Num) vals.Num { 154 nums := vals.UnifyNums(rawNums, vals.BigInt) 155 switch nums := nums.(type) { 156 case []*big.Int: 157 acc := big.NewInt(0) 158 for _, num := range nums { 159 acc.Add(acc, num) 160 } 161 return vals.NormalizeBigInt(acc) 162 case []*big.Rat: 163 acc := big.NewRat(0, 1) 164 for _, num := range nums { 165 acc.Add(acc, num) 166 } 167 return vals.NormalizeBigRat(acc) 168 case []float64: 169 acc := float64(0) 170 for _, num := range nums { 171 acc += num 172 } 173 return acc 174 default: 175 panic("unreachable") 176 } 177 } 178 179 func sub(rawNums ...vals.Num) (vals.Num, error) { 180 if len(rawNums) == 0 { 181 return nil, errs.ArityMismatch{What: "arguments", ValidLow: 1, ValidHigh: -1, Actual: 0} 182 } 183 184 nums := vals.UnifyNums(rawNums, vals.BigInt) 185 switch nums := nums.(type) { 186 case []*big.Int: 187 acc := &big.Int{} 188 if len(nums) == 1 { 189 acc.Neg(nums[0]) 190 return acc, nil 191 } 192 acc.Set(nums[0]) 193 for _, num := range nums[1:] { 194 acc.Sub(acc, num) 195 } 196 return acc, nil 197 case []*big.Rat: 198 acc := &big.Rat{} 199 if len(nums) == 1 { 200 acc.Neg(nums[0]) 201 return acc, nil 202 } 203 acc.Set(nums[0]) 204 for _, num := range nums[1:] { 205 acc.Sub(acc, num) 206 } 207 return acc, nil 208 case []float64: 209 if len(nums) == 1 { 210 return -nums[0], nil 211 } 212 acc := nums[0] 213 for _, num := range nums[1:] { 214 acc -= num 215 } 216 return acc, nil 217 default: 218 panic("unreachable") 219 } 220 } 221 222 func mul(rawNums ...vals.Num) vals.Num { 223 hasExact0 := false 224 hasInf := false 225 for _, num := range rawNums { 226 if num == 0 { 227 hasExact0 = true 228 } 229 if f, ok := num.(float64); ok && math.IsInf(f, 0) { 230 hasInf = true 231 break 232 } 233 } 234 if hasExact0 && !hasInf { 235 return 0 236 } 237 238 nums := vals.UnifyNums(rawNums, vals.BigInt) 239 switch nums := nums.(type) { 240 case []*big.Int: 241 acc := big.NewInt(1) 242 for _, num := range nums { 243 acc.Mul(acc, num) 244 } 245 return vals.NormalizeBigInt(acc) 246 case []*big.Rat: 247 acc := big.NewRat(1, 1) 248 for _, num := range nums { 249 acc.Mul(acc, num) 250 } 251 return vals.NormalizeBigRat(acc) 252 case []float64: 253 acc := float64(1) 254 for _, num := range nums { 255 acc *= num 256 } 257 return acc 258 default: 259 panic("unreachable") 260 } 261 } 262 263 func slash(fm *Frame, args ...vals.Num) error { 264 if len(args) == 0 { 265 fm.Deprecate("implicit cd is deprecated; use cd or location mode instead", fm.traceback.Head, 21) 266 // cd / 267 return fm.Evaler.Chdir("/") 268 } 269 // Division 270 result, err := div(args...) 271 if err != nil { 272 return err 273 } 274 return fm.ValueOutput().Put(vals.FromGo(result)) 275 } 276 277 // ErrDivideByZero is thrown when attempting to divide by zero. 278 var ErrDivideByZero = errs.BadValue{ 279 What: "divisor", Valid: "number other than exact 0", Actual: "exact 0"} 280 281 func div(rawNums ...vals.Num) (vals.Num, error) { 282 for _, num := range rawNums[1:] { 283 if num == 0 { 284 return nil, ErrDivideByZero 285 } 286 } 287 if rawNums[0] == 0 { 288 return 0, nil 289 } 290 nums := vals.UnifyNums(rawNums, vals.BigRat) 291 switch nums := nums.(type) { 292 case []*big.Rat: 293 acc := &big.Rat{} 294 acc.Set(nums[0]) 295 if len(nums) == 1 { 296 acc.Inv(acc) 297 return acc, nil 298 } 299 for _, num := range nums[1:] { 300 acc.Quo(acc, num) 301 } 302 return acc, nil 303 case []float64: 304 acc := nums[0] 305 if len(nums) == 1 { 306 return 1 / acc, nil 307 } 308 for _, num := range nums[1:] { 309 acc /= num 310 } 311 return acc, nil 312 default: 313 panic("unreachable") 314 } 315 } 316 317 func rem(a, b vals.Num) (vals.Num, error) { 318 if err := checkExactIntArg(a); err != nil { 319 return 0, err 320 } 321 if err := checkExactIntArg(b); err != nil { 322 return 0, err 323 } 324 if b == 0 { 325 return 0, ErrDivideByZero 326 } 327 if a, ok := a.(int); ok { 328 if b, ok := b.(int); ok { 329 return a % b, nil 330 } 331 } 332 return new(big.Int).Rem(vals.PromoteToBigInt(a), vals.PromoteToBigInt(b)), nil 333 } 334 335 func checkExactIntArg(a vals.Num) error { 336 switch a.(type) { 337 case int, *big.Int: 338 return nil 339 default: 340 return errs.BadValue{What: "argument", Valid: "exact integer", Actual: vals.ReprPlain(a)} 341 } 342 } 343 344 func randFn() float64 { return withRand((*rand.Rand).Float64) } 345 346 func randint(args ...vals.Num) (vals.Num, error) { 347 if len(args) == 0 || len(args) > 2 { 348 return -1, errs.ArityMismatch{What: "arguments", 349 ValidLow: 1, ValidHigh: 2, Actual: len(args)} 350 } 351 allInt := true 352 for _, arg := range args { 353 if err := checkExactIntArg(arg); err != nil { 354 return nil, err 355 } 356 if _, ok := arg.(*big.Int); ok { 357 allInt = false 358 } 359 } 360 if allInt { 361 var low, high int 362 if len(args) == 1 { 363 low, high = 0, args[0].(int) 364 } else { 365 low, high = args[0].(int), args[1].(int) 366 } 367 if high <= low { 368 return 0, errs.BadValue{What: "high value", 369 Valid: fmt.Sprint("larger than ", low), Actual: strconv.Itoa(high)} 370 } 371 x := withRand(func(r *rand.Rand) int { return r.Intn(high - low) }) 372 return low + x, nil 373 } 374 var low, high *big.Int 375 if len(args) == 1 { 376 low, high = big.NewInt(0), args[0].(*big.Int) 377 } else { 378 low, high = args[0].(*big.Int), args[1].(*big.Int) 379 } 380 if high.Cmp(low) <= 0 { 381 return 0, errs.BadValue{What: "high value", 382 Valid: fmt.Sprint("larger than ", low), Actual: high.String()} 383 } 384 if low.Sign() == 0 { 385 return withRand(func(r *rand.Rand) *big.Int { 386 return new(big.Int).Rand(r, high) 387 }), nil 388 } else { 389 diff := new(big.Int).Sub(high, low) 390 z := withRand(func(r *rand.Rand) *big.Int { 391 return new(big.Int).Rand(r, diff) 392 }) 393 z.Add(z, low) 394 return z, nil 395 } 396 } 397 398 func randseed(x int) { 399 withRandNullary(func(r *rand.Rand) { r.Seed(int64(x)) }) 400 } 401 402 var ( 403 randMutex sync.Mutex 404 randInstance *rand.Rand 405 ) 406 407 func withRand[T any](f func(*rand.Rand) T) T { 408 randMutex.Lock() 409 defer randMutex.Unlock() 410 if randInstance == nil { 411 randInstance = rand.New(rand.NewSource(time.Now().UnixNano())) 412 } 413 return f(randInstance) 414 } 415 416 func withRandNullary(f func(*rand.Rand)) { 417 withRand(func(r *rand.Rand) struct{} { 418 f(r) 419 return struct{}{} 420 }) 421 } 422 423 type rangeOpts struct{ Step vals.Num } 424 425 // TODO: The default value can only be used implicitly; passing "range 426 // &step=nil" results in an error. 427 func (o *rangeOpts) SetDefaultOptions() { o.Step = nil } 428 429 func rangeFn(fm *Frame, opts rangeOpts, args ...vals.Num) error { 430 var rawNums []vals.Num 431 switch len(args) { 432 case 1: 433 rawNums = []vals.Num{0, args[0]} 434 case 2: 435 rawNums = []vals.Num{args[0], args[1]} 436 default: 437 return errs.ArityMismatch{What: "arguments", ValidLow: 1, ValidHigh: 2, Actual: len(args)} 438 } 439 if opts.Step != nil { 440 rawNums = append(rawNums, opts.Step) 441 } 442 nums := vals.UnifyNums(rawNums, vals.Int) 443 444 out := fm.ValueOutput() 445 446 switch nums := nums.(type) { 447 case []int: 448 return rangeBuiltinNum(nums, out) 449 case []*big.Int: 450 return rangeBigNum(nums, out, bigIntDesc) 451 case []*big.Rat: 452 return rangeBigNum(nums, out, bigRatDesc) 453 case []float64: 454 return rangeBuiltinNum(nums, out) 455 default: 456 panic("unreachable") 457 } 458 } 459 460 type builtinNum interface{ int | float64 } 461 462 func rangeBuiltinNum[T builtinNum](nums []T, out ValueOutput) error { 463 start, end := nums[0], nums[1] 464 var step T 465 if start <= end { 466 if len(nums) == 3 { 467 step = nums[2] 468 if step <= 0 { 469 return errs.BadValue{ 470 What: "step", Valid: "positive", Actual: vals.ToString(step)} 471 } 472 } else { 473 step = 1 474 } 475 for cur := start; cur < end; cur += step { 476 err := out.Put(vals.FromGo(cur)) 477 if err != nil { 478 return err 479 } 480 if cur+step <= cur { 481 break 482 } 483 } 484 } else { 485 if len(nums) == 3 { 486 step = nums[2] 487 if step >= 0 { 488 return errs.BadValue{ 489 What: "step", Valid: "negative", Actual: vals.ToString(step)} 490 } 491 } else { 492 step = -1 493 } 494 for cur := start; cur > end; cur += step { 495 err := out.Put(vals.FromGo(cur)) 496 if err != nil { 497 return err 498 } 499 if cur+step >= cur { 500 break 501 } 502 } 503 } 504 return nil 505 } 506 507 type bigNum[T any] interface { 508 Cmp(T) int 509 Sign() int 510 Add(T, T) T 511 } 512 513 type bigNumDesc[T any] struct { 514 one T 515 negOne T 516 newZero func() T 517 } 518 519 var bigIntDesc = bigNumDesc[*big.Int]{ 520 one: big.NewInt(1), 521 negOne: big.NewInt(-1), 522 newZero: func() *big.Int { return &big.Int{} }, 523 } 524 525 var bigRatDesc = bigNumDesc[*big.Rat]{ 526 one: big.NewRat(1, 1), 527 negOne: big.NewRat(-1, 1), 528 newZero: func() *big.Rat { return &big.Rat{} }, 529 } 530 531 func rangeBigNum[T bigNum[T]](nums []T, out ValueOutput, d bigNumDesc[T]) error { 532 start, end := nums[0], nums[1] 533 var step T 534 if start.Cmp(end) <= 0 { 535 if len(nums) == 3 { 536 step = nums[2] 537 if step.Sign() <= 0 { 538 return errs.BadValue{ 539 What: "step", Valid: "positive", Actual: vals.ToString(step)} 540 } 541 } else { 542 step = d.one 543 } 544 var cur, next T 545 for cur = start; cur.Cmp(end) < 0; cur = next { 546 err := out.Put(vals.FromGo(cur)) 547 if err != nil { 548 return err 549 } 550 next = d.newZero() 551 next.Add(cur, step) 552 cur = next 553 } 554 } else { 555 if len(nums) == 3 { 556 step = nums[2] 557 if step.Sign() >= 0 { 558 return errs.BadValue{ 559 What: "step", Valid: "negative", Actual: vals.ToString(step)} 560 } 561 } else { 562 step = d.negOne 563 } 564 var cur, next T 565 for cur = start; cur.Cmp(end) > 0; cur = next { 566 err := out.Put(vals.FromGo(cur)) 567 if err != nil { 568 return err 569 } 570 next = d.newZero() 571 next.Add(cur, step) 572 cur = next 573 } 574 } 575 return nil 576 }