github.com/primecitizens/pcz/std@v0.2.1/math/fma.go (about)

     1  // Copyright 2019 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package math
     6  
     7  import (
     8  	"github.com/primecitizens/pcz/std/core/bits"
     9  )
    10  
    11  func zero(x uint64) uint64 {
    12  	if x == 0 {
    13  		return 1
    14  	}
    15  	return 0
    16  	// branchless:
    17  	// return ((x>>1 | x&1) - 1) >> 63
    18  }
    19  
    20  func nonzero(x uint64) uint64 {
    21  	if x != 0 {
    22  		return 1
    23  	}
    24  	return 0
    25  	// branchless:
    26  	// return 1 - ((x>>1|x&1)-1)>>63
    27  }
    28  
    29  func shl(u1, u2 uint64, n uint) (r1, r2 uint64) {
    30  	r1 = u1<<n | u2>>(64-n) | u2<<(n-64)
    31  	r2 = u2 << n
    32  	return
    33  }
    34  
    35  func shr(u1, u2 uint64, n uint) (r1, r2 uint64) {
    36  	r2 = u2>>n | u1<<(64-n) | u1>>(n-64)
    37  	r1 = u1 >> n
    38  	return
    39  }
    40  
    41  // shrcompress compresses the bottom n+1 bits of the two-word
    42  // value into a single bit. the result is equal to the value
    43  // shifted to the right by n, except the result's 0th bit is
    44  // set to the bitwise OR of the bottom n+1 bits.
    45  func shrcompress(u1, u2 uint64, n uint) (r1, r2 uint64) {
    46  	// TODO: Performance here is really sensitive to the
    47  	// order/placement of these branches. n == 0 is common
    48  	// enough to be in the fast path. Perhaps more measurement
    49  	// needs to be done to find the optimal order/placement?
    50  	switch {
    51  	case n == 0:
    52  		return u1, u2
    53  	case n == 64:
    54  		return 0, u1 | nonzero(u2)
    55  	case n >= 128:
    56  		return 0, nonzero(u1 | u2)
    57  	case n < 64:
    58  		r1, r2 = shr(u1, u2, n)
    59  		r2 |= nonzero(u2 & (1<<n - 1))
    60  	case n < 128:
    61  		r1, r2 = shr(u1, u2, n)
    62  		r2 |= nonzero(u1&(1<<(n-64)-1) | u2)
    63  	}
    64  	return
    65  }
    66  
    67  func lz(u1, u2 uint64) (l int32) {
    68  	l = int32(bits.LeadingZeros64(u1))
    69  	if l == 64 {
    70  		l += int32(bits.LeadingZeros64(u2))
    71  	}
    72  	return l
    73  }
    74  
    75  // split splits b into sign, biased exponent, and mantissa.
    76  // It adds the implicit 1 bit to the mantissa for normal values,
    77  // and normalizes subnormal values.
    78  func split(b uint64) (sign uint32, exp int32, mantissa uint64) {
    79  	sign = uint32(b >> 63)
    80  	exp = int32(b>>52) & mask
    81  	mantissa = b & fracMask
    82  
    83  	if exp == 0 {
    84  		// Normalize value if subnormal.
    85  		shift := uint(bits.LeadingZeros64(mantissa) - 11)
    86  		mantissa <<= shift
    87  		exp = 1 - int32(shift)
    88  	} else {
    89  		// Add implicit 1 bit
    90  		mantissa |= 1 << 52
    91  	}
    92  	return
    93  }
    94  
    95  // FMA returns x * y + z, computed with only one rounding.
    96  // (That is, FMA returns the fused multiply-add of x, y, and z.)
    97  func FMA(x, y, z float64) float64 {
    98  	bx, by, bz := Float64bits(x), Float64bits(y), Float64bits(z)
    99  
   100  	// Inf or NaN or zero involved. At most one rounding will occur.
   101  	if x == 0.0 || y == 0.0 || z == 0.0 || bx&uvinf == uvinf || by&uvinf == uvinf {
   102  		return x*y + z
   103  	}
   104  	// Handle non-finite z separately. Evaluating x*y+z where
   105  	// x and y are finite, but z is infinite, should always result in z.
   106  	if bz&uvinf == uvinf {
   107  		return z
   108  	}
   109  
   110  	// Inputs are (sub)normal.
   111  	// Split x, y, z into sign, exponent, mantissa.
   112  	xs, xe, xm := split(bx)
   113  	ys, ye, ym := split(by)
   114  	zs, ze, zm := split(bz)
   115  
   116  	// Compute product p = x*y as sign, exponent, two-word mantissa.
   117  	// Start with exponent. "is normal" bit isn't subtracted yet.
   118  	pe := xe + ye - bias + 1
   119  
   120  	// pm1:pm2 is the double-word mantissa for the product p.
   121  	// Shift left to leave top bit in product. Effectively
   122  	// shifts the 106-bit product to the left by 21.
   123  	pm1, pm2 := bits.Mul64(xm<<10, ym<<11)
   124  	zm1, zm2 := zm<<10, uint64(0)
   125  	ps := xs ^ ys // product sign
   126  
   127  	// normalize to 62nd bit
   128  	is62zero := uint((^pm1 >> 62) & 1)
   129  	pm1, pm2 = shl(pm1, pm2, is62zero)
   130  	pe -= int32(is62zero)
   131  
   132  	// Swap addition operands so |p| >= |z|
   133  	if pe < ze || pe == ze && pm1 < zm1 {
   134  		ps, pe, pm1, pm2, zs, ze, zm1, zm2 = zs, ze, zm1, zm2, ps, pe, pm1, pm2
   135  	}
   136  
   137  	// Align significands
   138  	zm1, zm2 = shrcompress(zm1, zm2, uint(pe-ze))
   139  
   140  	// Compute resulting significands, normalizing if necessary.
   141  	var m, c uint64
   142  	if ps == zs {
   143  		// Adding (pm1:pm2) + (zm1:zm2)
   144  		pm2, c = bits.Add64(pm2, zm2, 0)
   145  		pm1, _ = bits.Add64(pm1, zm1, c)
   146  		pe -= int32(^pm1 >> 63)
   147  		pm1, m = shrcompress(pm1, pm2, uint(64+pm1>>63))
   148  	} else {
   149  		// Subtracting (pm1:pm2) - (zm1:zm2)
   150  		// TODO: should we special-case cancellation?
   151  		pm2, c = bits.Sub64(pm2, zm2, 0)
   152  		pm1, _ = bits.Sub64(pm1, zm1, c)
   153  		nz := lz(pm1, pm2)
   154  		pe -= nz
   155  		m, pm2 = shl(pm1, pm2, uint(nz-1))
   156  		m |= nonzero(pm2)
   157  	}
   158  
   159  	// Round and break ties to even
   160  	if pe > 1022+bias || pe == 1022+bias && (m+1<<9)>>63 == 1 {
   161  		// rounded value overflows exponent range
   162  		return Float64frombits(uint64(ps)<<63 | uvinf)
   163  	}
   164  	if pe < 0 {
   165  		n := uint(-pe)
   166  		m = m>>n | nonzero(m&(1<<n-1))
   167  		pe = 0
   168  	}
   169  	m = ((m + 1<<9) >> 10) & ^zero((m&(1<<10-1))^1<<9)
   170  	pe &= -int32(nonzero(m))
   171  	return Float64frombits(uint64(ps)<<63 + uint64(pe)<<52 + m)
   172  }