github.com/go-asm/go@v1.21.1-0.20240213172139-40c5ead50c48/cmd/compile/ssa/magic.go (about)

     1  // Copyright 2016 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package ssa
     6  
     7  import (
     8  	"math/big"
     9  	"math/bits"
    10  )
    11  
    12  // So you want to compute x / c for some constant c?
    13  // Machine division instructions are slow, so we try to
    14  // compute this division with a multiplication + a few
    15  // other cheap instructions instead.
    16  // (We assume here that c != 0, +/- 1, or +/- 2^i.  Those
    17  // cases are easy to handle in different ways).
    18  
    19  // Technique from https://gmplib.org/~tege/divcnst-pldi94.pdf
    20  
    21  // First consider unsigned division.
    22  // Our strategy is to precompute 1/c then do
    23  //   ⎣x / c⎦ = ⎣x * (1/c)⎦.
    24  // 1/c is less than 1, so we can't compute it directly in
    25  // integer arithmetic.  Let's instead compute 2^e/c
    26  // for a value of e TBD (^ = exponentiation).  Then
    27  //   ⎣x / c⎦ = ⎣x * (2^e/c) / 2^e⎦.
    28  // Dividing by 2^e is easy.  2^e/c isn't an integer, unfortunately.
    29  // So we must approximate it.  Let's call its approximation m.
    30  // We'll then compute
    31  //   ⎣x * m / 2^e⎦
// We want this to equal ⎣x / c⎦ for all 0 <= x <= 2^n-1,
// where n is the word size.
// Setting x = c gives us c * m >= 2^e.
// We'll choose m = ⎡2^e/c⎤ to satisfy that inequality.
    36  // What remains is to choose e.
    37  // Let m = 2^e/c + delta, 0 <= delta < 1
    38  //   ⎣x * (2^e/c + delta) / 2^e⎦
    39  //   ⎣x / c + x * delta / 2^e⎦
    40  // We must have x * delta / 2^e < 1/c so that this
    41  // additional term never rounds differently than ⎣x / c⎦ does.
    42  // Rearranging,
    43  //   2^e > x * delta * c
    44  // x can be at most 2^n-1 and delta can be at most 1.
    45  // So it is sufficient to have 2^e >= 2^n*c.
    46  // So we'll choose e = n + s, with s = ⎡log2(c)⎤.
    47  //
    48  // An additional complication arises because m has n+1 bits in it.
    49  // Hardware restricts us to n bit by n bit multiplies.
    50  // We divide into 3 cases:
    51  //
    52  // Case 1: m is even.
    53  //   ⎣x / c⎦ = ⎣x * m / 2^(n+s)⎦
    54  //   ⎣x / c⎦ = ⎣x * (m/2) / 2^(n+s-1)⎦
    55  //   ⎣x / c⎦ = ⎣x * (m/2) / 2^n / 2^(s-1)⎦
    56  //   ⎣x / c⎦ = ⎣⎣x * (m/2) / 2^n⎦ / 2^(s-1)⎦
    57  //   multiply + shift
    58  //
    59  // Case 2: c is even.
    60  //   ⎣x / c⎦ = ⎣(x/2) / (c/2)⎦
    61  //   ⎣x / c⎦ = ⎣⎣x/2⎦ / (c/2)⎦
    62  //     This is just the original problem, with x' = ⎣x/2⎦, c' = c/2, n' = n-1.
    63  //       s' = s-1
    64  //       m' = ⎡2^(n'+s')/c'⎤
    65  //          = ⎡2^(n+s-1)/c⎤
    66  //          = ⎡m/2⎤
    67  //   ⎣x / c⎦ = ⎣x' * m' / 2^(n'+s')⎦
    68  //   ⎣x / c⎦ = ⎣⎣x/2⎦ * ⎡m/2⎤ / 2^(n+s-2)⎦
    69  //   ⎣x / c⎦ = ⎣⎣⎣x/2⎦ * ⎡m/2⎤ / 2^n⎦ / 2^(s-2)⎦
    70  //   shift + multiply + shift
    71  //
    72  // Case 3: everything else
    73  //   let k = m - 2^n. k fits in n bits.
    74  //   ⎣x / c⎦ = ⎣x * m / 2^(n+s)⎦
    75  //   ⎣x / c⎦ = ⎣x * (2^n + k) / 2^(n+s)⎦
    76  //   ⎣x / c⎦ = ⎣(x + x * k / 2^n) / 2^s⎦
//   ⎣x / c⎦ = ⎣(x + ⎣x * k / 2^n⎦) / 2^s⎦
    79  //   ⎣x / c⎦ = ⎣⎣(x + ⎣x * k / 2^n⎦) / 2⎦ / 2^(s-1)⎦
    80  //   multiply + avg + shift
    81  //
    82  // These can be implemented in hardware using:
    83  //  ⎣a * b / 2^n⎦ - aka high n bits of an n-bit by n-bit multiply.
    84  //  ⎣(a+b) / 2⎦   - aka "average" of two n-bit numbers.
    85  //                  (Not just a regular add & shift because the intermediate result
    86  //                   a+b has n+1 bits in it.  Nevertheless, can be done
    87  //                   in 2 instructions on x86.)
    88  
    89  // umagicOK reports whether we should strength reduce a n-bit divide by c.
    90  func umagicOK(n uint, c int64) bool {
    91  	// Convert from ConstX auxint values to the real uint64 constant they represent.
    92  	d := uint64(c) << (64 - n) >> (64 - n)
    93  
    94  	// Doesn't work for 0.
    95  	// Don't use for powers of 2.
    96  	return d&(d-1) != 0
    97  }
    98  
    99  // umagicOKn reports whether we should strength reduce an unsigned n-bit divide by c.
   100  // We can strength reduce when c != 0 and c is not a power of two.
   101  func umagicOK8(c int8) bool   { return c&(c-1) != 0 }
   102  func umagicOK16(c int16) bool { return c&(c-1) != 0 }
   103  func umagicOK32(c int32) bool { return c&(c-1) != 0 }
   104  func umagicOK64(c int64) bool { return c&(c-1) != 0 }
   105  
   106  type umagicData struct {
   107  	s int64  // ⎡log2(c)⎤
   108  	m uint64 // ⎡2^(n+s)/c⎤ - 2^n
   109  }
   110  
   111  // umagic computes the constants needed to strength reduce unsigned n-bit divides by the constant uint64(c).
   112  // The return values satisfy for all 0 <= x < 2^n
   113  //
   114  //	floor(x / uint64(c)) = x * (m + 2^n) >> (n+s)
   115  func umagic(n uint, c int64) umagicData {
   116  	// Convert from ConstX auxint values to the real uint64 constant they represent.
   117  	d := uint64(c) << (64 - n) >> (64 - n)
   118  
   119  	C := new(big.Int).SetUint64(d)
   120  	s := C.BitLen()
   121  	M := big.NewInt(1)
   122  	M.Lsh(M, n+uint(s))     // 2^(n+s)
   123  	M.Add(M, C)             // 2^(n+s)+c
   124  	M.Sub(M, big.NewInt(1)) // 2^(n+s)+c-1
   125  	M.Div(M, C)             // ⎡2^(n+s)/c⎤
   126  	if M.Bit(int(n)) != 1 {
   127  		panic("n+1st bit isn't set")
   128  	}
   129  	M.SetBit(M, int(n), 0)
   130  	m := M.Uint64()
   131  	return umagicData{s: int64(s), m: m}
   132  }
   133  
   134  func umagic8(c int8) umagicData   { return umagic(8, int64(c)) }
   135  func umagic16(c int16) umagicData { return umagic(16, int64(c)) }
   136  func umagic32(c int32) umagicData { return umagic(32, int64(c)) }
   137  func umagic64(c int64) umagicData { return umagic(64, c) }
   138  
   139  // For signed division, we use a similar strategy.
   140  // First, we enforce a positive c.
   141  //   x / c = -(x / (-c))
   142  // This will require an additional Neg op for c<0.
   143  //
   144  // If x is positive we're in a very similar state
   145  // to the unsigned case above.  We define:
   146  //   s = ⎡log2(c)⎤-1
   147  //   m = ⎡2^(n+s)/c⎤
   148  // Then
   149  //   ⎣x / c⎦ = ⎣x * m / 2^(n+s)⎦
   150  // If x is negative we have
   151  //   ⎡x / c⎤ = ⎣x * m / 2^(n+s)⎦ + 1
   152  // (TODO: derivation?)
   153  //
   154  // The multiply is a bit odd, as it is a signed n-bit value
   155  // times an unsigned n-bit value.  For n smaller than the
   156  // word size, we can extend x and m appropriately and use the
   157  // signed multiply instruction.  For n == word size,
   158  // we must use the signed multiply high and correct
   159  // the result by adding x*2^n.
   160  //
   161  // Adding 1 if x<0 is done by subtracting x>>(n-1).
   162  
   163  func smagicOK(n uint, c int64) bool {
   164  	if c < 0 {
   165  		// Doesn't work for negative c.
   166  		return false
   167  	}
   168  	// Doesn't work for 0.
   169  	// Don't use it for powers of 2.
   170  	return c&(c-1) != 0
   171  }
   172  
   173  // smagicOKn reports whether we should strength reduce a signed n-bit divide by c.
   174  func smagicOK8(c int8) bool   { return smagicOK(8, int64(c)) }
   175  func smagicOK16(c int16) bool { return smagicOK(16, int64(c)) }
   176  func smagicOK32(c int32) bool { return smagicOK(32, int64(c)) }
   177  func smagicOK64(c int64) bool { return smagicOK(64, c) }
   178  
   179  type smagicData struct {
   180  	s int64  // ⎡log2(c)⎤-1
   181  	m uint64 // ⎡2^(n+s)/c⎤
   182  }
   183  
   184  // smagic computes the constants needed to strength reduce signed n-bit divides by the constant c.
   185  // Must have c>0.
   186  // The return values satisfy for all -2^(n-1) <= x < 2^(n-1)
   187  //
   188  //	trunc(x / c) = x * m >> (n+s) + (x < 0 ? 1 : 0)
   189  func smagic(n uint, c int64) smagicData {
   190  	C := new(big.Int).SetInt64(c)
   191  	s := C.BitLen() - 1
   192  	M := big.NewInt(1)
   193  	M.Lsh(M, n+uint(s))     // 2^(n+s)
   194  	M.Add(M, C)             // 2^(n+s)+c
   195  	M.Sub(M, big.NewInt(1)) // 2^(n+s)+c-1
   196  	M.Div(M, C)             // ⎡2^(n+s)/c⎤
   197  	if M.Bit(int(n)) != 0 {
   198  		panic("n+1st bit is set")
   199  	}
   200  	if M.Bit(int(n-1)) == 0 {
   201  		panic("nth bit is not set")
   202  	}
   203  	m := M.Uint64()
   204  	return smagicData{s: int64(s), m: m}
   205  }
   206  
   207  func smagic8(c int8) smagicData   { return smagic(8, int64(c)) }
   208  func smagic16(c int16) smagicData { return smagic(16, int64(c)) }
   209  func smagic32(c int32) smagicData { return smagic(32, int64(c)) }
   210  func smagic64(c int64) smagicData { return smagic(64, c) }
   211  
   212  // Divisibility x%c == 0 can be checked more efficiently than directly computing
   213  // the modulus x%c and comparing against 0.
   214  //
   215  // The same "Division by invariant integers using multiplication" paper
   216  // by Granlund and Montgomery referenced above briefly mentions this method
   217  // and it is further elaborated in "Hacker's Delight" by Warren Section 10-17
   218  //
   219  // The first thing to note is that for odd integers, exact division can be computed
   220  // by using the modular inverse with respect to the word size 2^n.
   221  //
   222  // Given c, compute m such that (c * m) mod 2^n == 1
   223  // Then if c divides x (x%c ==0), the quotient is given by q = x/c == x*m mod 2^n
   224  //
   225  // x can range from 0, c, 2c, 3c, ... ⎣(2^n - 1)/c⎦ * c the maximum multiple
   226  // Thus, x*m mod 2^n is 0, 1, 2, 3, ... ⎣(2^n - 1)/c⎦
   227  // i.e. the quotient takes all values from zero up to max = ⎣(2^n - 1)/c⎦
   228  //
   229  // If x is not divisible by c, then x*m mod 2^n must take some larger value than max.
   230  //
   231  // This gives x*m mod 2^n <= ⎣(2^n - 1)/c⎦ as a test for divisibility
   232  // involving one multiplication and compare.
   233  //
   234  // To extend this to even integers, consider c = d0 * 2^k where d0 is odd.
   235  // We can test whether x is divisible by both d0 and 2^k.
   236  // For d0, the test is the same as above.  Let m be such that m*d0 mod 2^n == 1
   237  // Then x*m mod 2^n <= ⎣(2^n - 1)/d0⎦ is the first test.
   238  // The test for divisibility by 2^k is a check for k trailing zeroes.
   239  // Note that since d0 is odd, m is odd and thus x*m will have the same number of
   240  // trailing zeroes as x.  So the two tests are,
   241  //
   242  // x*m mod 2^n <= ⎣(2^n - 1)/d0⎦
   243  // and x*m ends in k zero bits
   244  //
   245  // These can be combined into a single comparison by the following
   246  // (theorem ZRU in Hacker's Delight) for unsigned integers.
   247  //
   248  // x <= a and x ends in k zero bits if and only if RotRight(x ,k) <= ⎣a/(2^k)⎦
   249  // Where RotRight(x ,k) is right rotation of x by k bits.
   250  //
   251  // To prove the first direction, x <= a -> ⎣x/(2^k)⎦ <= ⎣a/(2^k)⎦
   252  // But since x ends in k zeroes all the rotated bits would be zero too.
   253  // So RotRight(x, k) == ⎣x/(2^k)⎦ <= ⎣a/(2^k)⎦
   254  //
   255  // If x does not end in k zero bits, then RotRight(x, k)
   256  // has some non-zero bits in the k highest bits.
   257  // ⎣x/(2^k)⎦ has all zeroes in the k highest bits,
   258  // so RotRight(x, k) > ⎣x/(2^k)⎦
   259  //
   260  // Finally, if x > a and has k trailing zero bits, then RotRight(x, k) == ⎣x/(2^k)⎦
   261  // and ⎣x/(2^k)⎦ must be greater than ⎣a/(2^k)⎦, that is the top n-k bits of x must
   262  // be greater than the top n-k bits of a because the rest of x bits are zero.
   263  //
   264  // So the two conditions about can be replaced with the single test
   265  //
   266  // RotRight(x*m mod 2^n, k) <= ⎣(2^n - 1)/c⎦
   267  //
   268  // Where d0*2^k was replaced by c on the right hand side.
   269  
   270  // udivisibleOK reports whether we should strength reduce an unsigned n-bit divisibilty check by c.
   271  func udivisibleOK(n uint, c int64) bool {
   272  	// Convert from ConstX auxint values to the real uint64 constant they represent.
   273  	d := uint64(c) << (64 - n) >> (64 - n)
   274  
   275  	// Doesn't work for 0.
   276  	// Don't use for powers of 2.
   277  	return d&(d-1) != 0
   278  }
   279  
   280  func udivisibleOK8(c int8) bool   { return udivisibleOK(8, int64(c)) }
   281  func udivisibleOK16(c int16) bool { return udivisibleOK(16, int64(c)) }
   282  func udivisibleOK32(c int32) bool { return udivisibleOK(32, int64(c)) }
   283  func udivisibleOK64(c int64) bool { return udivisibleOK(64, c) }
   284  
   285  type udivisibleData struct {
   286  	k   int64  // trailingZeros(c)
   287  	m   uint64 // m * (c>>k) mod 2^n == 1 multiplicative inverse of odd portion modulo 2^n
   288  	max uint64 // ⎣(2^n - 1)/ c⎦ max value to for divisibility
   289  }
   290  
   291  func udivisible(n uint, c int64) udivisibleData {
   292  	// Convert from ConstX auxint values to the real uint64 constant they represent.
   293  	d := uint64(c) << (64 - n) >> (64 - n)
   294  
   295  	k := bits.TrailingZeros64(d)
   296  	d0 := d >> uint(k) // the odd portion of the divisor
   297  
   298  	mask := ^uint64(0) >> (64 - n)
   299  
   300  	// Calculate the multiplicative inverse via Newton's method.
   301  	// Quadratic convergence doubles the number of correct bits per iteration.
   302  	m := d0            // initial guess correct to 3-bits d0*d0 mod 8 == 1
   303  	m = m * (2 - m*d0) // 6-bits
   304  	m = m * (2 - m*d0) // 12-bits
   305  	m = m * (2 - m*d0) // 24-bits
   306  	m = m * (2 - m*d0) // 48-bits
   307  	m = m * (2 - m*d0) // 96-bits >= 64-bits
   308  	m = m & mask
   309  
   310  	max := mask / d
   311  
   312  	return udivisibleData{
   313  		k:   int64(k),
   314  		m:   m,
   315  		max: max,
   316  	}
   317  }
   318  
   319  func udivisible8(c int8) udivisibleData   { return udivisible(8, int64(c)) }
   320  func udivisible16(c int16) udivisibleData { return udivisible(16, int64(c)) }
   321  func udivisible32(c int32) udivisibleData { return udivisible(32, int64(c)) }
   322  func udivisible64(c int64) udivisibleData { return udivisible(64, c) }
   323  
   324  // For signed integers, a similar method follows.
   325  //
   326  // Given c > 1 and odd, compute m such that (c * m) mod 2^n == 1
   327  // Then if c divides x (x%c ==0), the quotient is given by q = x/c == x*m mod 2^n
   328  //
   329  // x can range from ⎡-2^(n-1)/c⎤ * c, ... -c, 0, c, ...  ⎣(2^(n-1) - 1)/c⎦ * c
   330  // Thus, x*m mod 2^n is ⎡-2^(n-1)/c⎤, ... -2, -1, 0, 1, 2, ... ⎣(2^(n-1) - 1)/c⎦
   331  //
   332  // So, x is a multiple of c if and only if:
   333  // ⎡-2^(n-1)/c⎤ <= x*m mod 2^n <= ⎣(2^(n-1) - 1)/c⎦
   334  //
   335  // Since c > 1 and odd, this can be simplified by
   336  // ⎡-2^(n-1)/c⎤ == ⎡(-2^(n-1) + 1)/c⎤ == -⎣(2^(n-1) - 1)/c⎦
   337  //
   338  // -⎣(2^(n-1) - 1)/c⎦ <= x*m mod 2^n <= ⎣(2^(n-1) - 1)/c⎦
   339  //
   340  // To extend this to even integers, consider c = d0 * 2^k where d0 is odd.
   341  // We can test whether x is divisible by both d0 and 2^k.
   342  //
   343  // Let m be such that (d0 * m) mod 2^n == 1.
   344  // Let q = x*m mod 2^n. Then c divides x if:
   345  //
   346  // -⎣(2^(n-1) - 1)/d0⎦ <= q <= ⎣(2^(n-1) - 1)/d0⎦ and q ends in at least k 0-bits
   347  //
   348  // To transform this to a single comparison, we use the following theorem (ZRS in Hacker's Delight).
   349  //
   350  // For a >= 0 the following conditions are equivalent:
   351  // 1) -a <= x <= a and x ends in at least k 0-bits
   352  // 2) RotRight(x+a', k) <= ⎣2a'/2^k⎦
   353  //
   354  // Where a' = a & -2^k (a with its right k bits set to zero)
   355  //
   356  // To see that 1 & 2 are equivalent, note that -a <= x <= a is equivalent to
   357  // -a' <= x <= a' if and only if x ends in at least k 0-bits.  Adding -a' to each side gives,
   358  // 0 <= x + a' <= 2a' and x + a' ends in at least k 0-bits if and only if x does since a' has
   359  // k 0-bits by definition.  We can use theorem ZRU above with x -> x + a' and a -> 2a' giving 1) == 2).
   360  //
   361  // Let m be such that (d0 * m) mod 2^n == 1.
   362  // Let q = x*m mod 2^n.
   363  // Let a' = ⎣(2^(n-1) - 1)/d0⎦ & -2^k
   364  //
   365  // Then the divisibility test is:
   366  //
   367  // RotRight(q+a', k) <= ⎣2a'/2^k⎦
   368  //
   369  // Note that the calculation is performed using unsigned integers.
   370  // Since a' can have n-1 bits, 2a' may have n bits and there is no risk of overflow.
   371  
   372  // sdivisibleOK reports whether we should strength reduce a signed n-bit divisibilty check by c.
   373  func sdivisibleOK(n uint, c int64) bool {
   374  	if c < 0 {
   375  		// Doesn't work for negative c.
   376  		return false
   377  	}
   378  	// Doesn't work for 0.
   379  	// Don't use it for powers of 2.
   380  	return c&(c-1) != 0
   381  }
   382  
   383  func sdivisibleOK8(c int8) bool   { return sdivisibleOK(8, int64(c)) }
   384  func sdivisibleOK16(c int16) bool { return sdivisibleOK(16, int64(c)) }
   385  func sdivisibleOK32(c int32) bool { return sdivisibleOK(32, int64(c)) }
   386  func sdivisibleOK64(c int64) bool { return sdivisibleOK(64, c) }
   387  
   388  type sdivisibleData struct {
   389  	k   int64  // trailingZeros(c)
   390  	m   uint64 // m * (c>>k) mod 2^n == 1 multiplicative inverse of odd portion modulo 2^n
   391  	a   uint64 // ⎣(2^(n-1) - 1)/ (c>>k)⎦ & -(1<<k) additive constant
   392  	max uint64 // ⎣(2 a) / (1<<k)⎦ max value to for divisibility
   393  }
   394  
   395  func sdivisible(n uint, c int64) sdivisibleData {
   396  	d := uint64(c)
   397  	k := bits.TrailingZeros64(d)
   398  	d0 := d >> uint(k) // the odd portion of the divisor
   399  
   400  	mask := ^uint64(0) >> (64 - n)
   401  
   402  	// Calculate the multiplicative inverse via Newton's method.
   403  	// Quadratic convergence doubles the number of correct bits per iteration.
   404  	m := d0            // initial guess correct to 3-bits d0*d0 mod 8 == 1
   405  	m = m * (2 - m*d0) // 6-bits
   406  	m = m * (2 - m*d0) // 12-bits
   407  	m = m * (2 - m*d0) // 24-bits
   408  	m = m * (2 - m*d0) // 48-bits
   409  	m = m * (2 - m*d0) // 96-bits >= 64-bits
   410  	m = m & mask
   411  
   412  	a := ((mask >> 1) / d0) & -(1 << uint(k))
   413  	max := (2 * a) >> uint(k)
   414  
   415  	return sdivisibleData{
   416  		k:   int64(k),
   417  		m:   m,
   418  		a:   a,
   419  		max: max,
   420  	}
   421  }
   422  
   423  func sdivisible8(c int8) sdivisibleData   { return sdivisible(8, int64(c)) }
   424  func sdivisible16(c int16) sdivisibleData { return sdivisible(16, int64(c)) }
   425  func sdivisible32(c int32) sdivisibleData { return sdivisible(32, int64(c)) }
   426  func sdivisible64(c int64) sdivisibleData { return sdivisible(64, c) }