github.com/joomcode/cue@v0.4.4-0.20221111115225-539fe3512047/cue/literal/num.go (about)

     1  // Copyright 2020 CUE Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package literal
    16  
    17  import (
    18  	"github.com/joomcode/cue/cue/errors"
    19  	"github.com/joomcode/cue/cue/token"
    20  	"github.com/cockroachdb/apd/v2"
    21  )
    22  
    23  var baseContext apd.Context
    24  
    25  func init() {
    26  	baseContext = apd.BaseContext
    27  	baseContext.Precision = 24
    28  }
    29  
// NumInfo contains information about a parsed number.
//
// Reusing a NumInfo across parses may avoid memory allocations.
type NumInfo struct {
	pos token.Pos // position attached to errors created by errorf
	src string    // the literal being parsed
	p   int       // index into src of the next byte to read
	ch  byte      // current byte; 0 once src is exhausted
	buf []byte    // normalized literal: optional '-', digits without '_', '.', and 'e' exponent

	mul     Multiplier // multiplier suffix (such as K or Mi), or 0 if none
	base    byte       // radix of the digits in buf (2, 8, 10 or 16 after a successful scan)
	neg     bool       // whether the literal started with '-'
	UseSep  bool       // set when a '_' digit separator was encountered
	isFloat bool       // set for a fraction or exponent; cleared when a multiplier is used
	err     error      // deferred scanning error (e.g. a misplaced '_'), reported by ParseNum
}
    47  
    48  // String returns a canonical string representation of the number so that
    49  // it can be parsed with math.Float.Parse.
    50  func (p *NumInfo) String() string {
    51  	if len(p.buf) > 0 && p.base == 10 && p.mul == 0 {
    52  		return string(p.buf)
    53  	}
    54  	var d apd.Decimal
    55  	_ = p.decimal(&d)
    56  	return d.String()
    57  }
    58  
    59  type decimal = apd.Decimal
    60  
    61  // Decimal is for internal use.
    62  func (p *NumInfo) Decimal(v *decimal) error {
    63  	return p.decimal(v)
    64  }
    65  
    66  func (p *NumInfo) decimal(v *apd.Decimal) error {
    67  	if p.base != 10 {
    68  		_, _, _ = v.SetString("0")
    69  		b := p.buf
    70  		if p.buf[0] == '-' {
    71  			v.Negative = p.neg
    72  			b = p.buf[1:]
    73  		}
    74  		v.Coeff.SetString(string(b), int(p.base))
    75  		return nil
    76  	}
    77  	_ = v.UnmarshalText(p.buf)
    78  	if p.mul != 0 {
    79  		_, _ = baseContext.Mul(v, v, mulToRat[p.mul])
    80  		cond, _ := baseContext.RoundToIntegralExact(v, v)
    81  		if cond.Inexact() {
    82  			return p.errorf("number cannot be represented as int")
    83  		}
    84  	}
    85  	return nil
    86  }
    87  
    88  // Multiplier reports which multiplier was used in an integral number.
    89  func (p *NumInfo) Multiplier() Multiplier {
    90  	return p.mul
    91  }
    92  
    93  // IsInt reports whether the number is an integral number.
    94  func (p *NumInfo) IsInt() bool {
    95  	return !p.isFloat
    96  }
    97  
    98  // ParseNum parses s and populates NumInfo with the result.
    99  func ParseNum(s string, n *NumInfo) error {
   100  	*n = NumInfo{pos: n.pos, src: s, buf: n.buf[:0]}
   101  	if !n.next() {
   102  		return n.errorf("invalid number %q", s)
   103  	}
   104  	if n.ch == '-' {
   105  		n.neg = true
   106  		n.buf = append(n.buf, '-')
   107  		n.next()
   108  	}
   109  	seenDecimalPoint := false
   110  	if n.ch == '.' {
   111  		n.next()
   112  		seenDecimalPoint = true
   113  	}
   114  	err := n.scanNumber(seenDecimalPoint)
   115  	if err != nil {
   116  		return err
   117  	}
   118  	if n.err != nil {
   119  		return n.err
   120  	}
   121  	if n.p < len(n.src) {
   122  		return n.errorf("invalid number %q", s)
   123  	}
   124  	if len(n.buf) == 0 {
   125  		n.buf = append(n.buf, '0')
   126  	}
   127  	return nil
   128  }
   129  
   130  func (p *NumInfo) errorf(format string, args ...interface{}) error {
   131  	return errors.Newf(p.pos, format, args...)
   132  }
   133  
// A Multiplier indicates which multiplier suffix, if any, was used in a
// number literal; 0 means none.
//
// The low bits (mul1 through mul8) hold the power of the multiplier. scanNumber
// additionally sets exactly one of mulDec or mulBin, selecting powers of 1000
// or 1024 respectively (see mulToRat).
type Multiplier byte

const (
	// Powers 1 through 8, matching the suffix letters K, M, G, T, P, E, Z
	// and Y in charToMul.
	mul1 Multiplier = 1 + iota
	mul2
	mul3
	mul4
	mul5
	mul6
	mul7
	mul8

	mulBin = 0x10 // binary multiplier flag (powers of 1024), e.g. Ki
	mulDec = 0x20 // decimal multiplier flag (powers of 1000), e.g. K

	// Decimal multipliers: K is 1000, M is 1000^2, and so on. scanNumber
	// only recognizes the K, M, G, T and P suffixes; E, Z and Y are defined
	// here but never produced by it.
	K = mulDec | mul1
	M = mulDec | mul2
	G = mulDec | mul3
	T = mulDec | mul4
	P = mulDec | mul5
	E = mulDec | mul6
	Z = mulDec | mul7
	Y = mulDec | mul8

	// Binary multipliers: Ki is 1024, Mi is 1024^2, and so on. The same
	// restriction on which suffixes scanNumber produces applies.
	Ki = mulBin | mul1
	Mi = mulBin | mul2
	Gi = mulBin | mul3
	Ti = mulBin | mul4
	Pi = mulBin | mul5
	Ei = mulBin | mul6
	Zi = mulBin | mul7
	Yi = mulBin | mul8
)
   168  
   169  func (p *NumInfo) next() bool {
   170  	if p.p >= len(p.src) {
   171  		p.ch = 0
   172  		return false
   173  	}
   174  	p.ch = p.src[p.p]
   175  	p.p++
   176  	if p.ch == '.' {
   177  		if len(p.buf) == 0 {
   178  			p.buf = append(p.buf, '0')
   179  		}
   180  		p.buf = append(p.buf, '.')
   181  	}
   182  	return true
   183  }
   184  
   185  func (p *NumInfo) digitVal(ch byte) (d int) {
   186  	switch {
   187  	case '0' <= ch && ch <= '9':
   188  		d = int(ch - '0')
   189  	case ch == '_':
   190  		p.UseSep = true
   191  		return 0
   192  	case 'a' <= ch && ch <= 'f':
   193  		d = int(ch - 'a' + 10)
   194  	case 'A' <= ch && ch <= 'F':
   195  		d = int(ch - 'A' + 10)
   196  	default:
   197  		return 16 // larger than any legal digit val
   198  	}
   199  	return d
   200  }
   201  
   202  func (p *NumInfo) scanMantissa(base int) bool {
   203  	hasDigit := false
   204  	var last byte
   205  	for p.digitVal(p.ch) < base {
   206  		if p.ch != '_' {
   207  			p.buf = append(p.buf, p.ch)
   208  			hasDigit = true
   209  		}
   210  		last = p.ch
   211  		p.next()
   212  	}
   213  	if last == '_' {
   214  		p.err = p.errorf("illegal '_' in number")
   215  	}
   216  	return hasDigit
   217  }
   218  
   219  func (p *NumInfo) scanNumber(seenDecimalPoint bool) error {
   220  	p.base = 10
   221  
   222  	if seenDecimalPoint {
   223  		p.isFloat = true
   224  		if !p.scanMantissa(10) {
   225  			return p.errorf("illegal fraction %q", p.src)
   226  		}
   227  		goto exponent
   228  	}
   229  
   230  	if p.ch == '0' {
   231  		// int or float
   232  		p.next()
   233  		switch p.ch {
   234  		case 'x', 'X':
   235  			p.base = 16
   236  			// hexadecimal int
   237  			p.next()
   238  			if !p.scanMantissa(16) {
   239  				// only scanned "0x" or "0X"
   240  				return p.errorf("illegal hexadecimal number %q", p.src)
   241  			}
   242  		case 'b':
   243  			p.base = 2
   244  			// binary int
   245  			p.next()
   246  			if !p.scanMantissa(2) {
   247  				// only scanned "0b"
   248  				return p.errorf("illegal binary number %q", p.src)
   249  			}
   250  		case 'o':
   251  			p.base = 8
   252  			// octal int
   253  			p.next()
   254  			if !p.scanMantissa(8) {
   255  				// only scanned "0o"
   256  				return p.errorf("illegal octal number %q", p.src)
   257  			}
   258  		default:
   259  			// int (base 8 or 10) or float
   260  			p.scanMantissa(8)
   261  			if p.ch == '8' || p.ch == '9' {
   262  				p.scanMantissa(10)
   263  				if p.ch != '.' && p.ch != 'e' && p.ch != 'E' {
   264  					return p.errorf("illegal integer number %q", p.src)
   265  				}
   266  			}
   267  			switch p.ch {
   268  			case 'e', 'E':
   269  				if len(p.buf) == 0 {
   270  					p.buf = append(p.buf, '0')
   271  				}
   272  				fallthrough
   273  			case '.':
   274  				goto fraction
   275  			}
   276  			if len(p.buf) > 0 {
   277  				p.base = 8
   278  			}
   279  		}
   280  		goto exit
   281  	}
   282  
   283  	// decimal int or float
   284  	if !p.scanMantissa(10) {
   285  		return p.errorf("illegal number start %q", p.src)
   286  	}
   287  
   288  fraction:
   289  	if p.ch == '.' {
   290  		p.isFloat = true
   291  		p.next()
   292  		p.scanMantissa(10)
   293  	}
   294  
   295  exponent:
   296  	switch p.ch {
   297  	case 'K', 'M', 'G', 'T', 'P':
   298  		p.mul = charToMul[p.ch]
   299  		p.next()
   300  		if p.ch == 'i' {
   301  			p.mul |= mulBin
   302  			p.next()
   303  		} else {
   304  			p.mul |= mulDec
   305  		}
   306  		var v apd.Decimal
   307  		p.isFloat = false
   308  		return p.decimal(&v)
   309  
   310  	case 'e', 'E':
   311  		p.isFloat = true
   312  		p.next()
   313  		p.buf = append(p.buf, 'e')
   314  		if p.ch == '-' || p.ch == '+' {
   315  			p.buf = append(p.buf, p.ch)
   316  			p.next()
   317  		}
   318  		if !p.scanMantissa(10) {
   319  			return p.errorf("illegal exponent %q", p.src)
   320  		}
   321  	}
   322  
   323  exit:
   324  	return nil
   325  }
   326  
   327  var charToMul = map[byte]Multiplier{
   328  	'K': mul1,
   329  	'M': mul2,
   330  	'G': mul3,
   331  	'T': mul4,
   332  	'P': mul5,
   333  	'E': mul6,
   334  	'Z': mul7,
   335  	'Y': mul8,
   336  }
   337  
   338  var mulToRat = map[Multiplier]*apd.Decimal{}
   339  
   340  func init() {
   341  	d := apd.New(1, 0)
   342  	b := apd.New(1, 0)
   343  	dm := apd.New(1000, 0)
   344  	bm := apd.New(1024, 0)
   345  
   346  	c := apd.BaseContext
   347  	for i := Multiplier(1); int(i) < len(charToMul); i++ {
   348  		// TODO: may we write to one of the sources?
   349  		var bn, dn apd.Decimal
   350  		_, _ = c.Mul(&dn, d, dm)
   351  		d = &dn
   352  		_, _ = c.Mul(&bn, b, bm)
   353  		b = &bn
   354  		mulToRat[mulDec|i] = d
   355  		mulToRat[mulBin|i] = b
   356  	}
   357  }