github.com/ericlagergren/ctb@v0.0.0-20220810041818-96749d9c394d/lll/big.go (about)

     1  package lll
     2  
     3  import (
     4  	"fmt"
     5  	"math/big"
     6  )
     7  
     8  // T is an integer.
     9  //
    10  // Unlike math/big, T is a value type.
    11  type T interface {
    12  	Sign() int
    13  	Cmp(T) int
    14  	CmpAbs(T) int
    15  	Add(T) T
    16  	Mul(T) T
    17  	Sub(T) T
    18  	Quo(T) T
    19  	String() string
    20  }
    21  
    22  func SetInt(z *big.Int, x T) {
    23  	switch x := x.(type) {
    24  	case *Int:
    25  		z.Set(&x.x)
    26  	case *Frac:
    27  		if !x.x.IsInt() {
    28  			SetInt(z, round(x))
    29  		} else {
    30  			z.Set(x.x.Num())
    31  		}
    32  	}
    33  }
    34  
    35  // Int is an integer.
    36  //
    37  // Int implements T.
    38  type Int struct {
    39  	x big.Int
    40  }
    41  
    42  var _ T = (*Int)(nil)
    43  
    44  // I64 creates an Int from x.
    45  func I64(x int64) T {
    46  	var z Int
    47  	z.x.SetInt64(x)
    48  	return &z
    49  }
    50  
    51  // I copies x into a new Int.
    52  func I(x *big.Int) T {
    53  	var z Int
    54  	z.x.Set(x)
    55  	return &z
    56  }
    57  
    58  func (x *Int) String() string {
    59  	return x.x.String()
    60  }
    61  
    62  func (x *Int) Sign() int {
    63  	return x.x.Sign()
    64  }
    65  
    66  func (x *Int) CmpAbs(y T) int {
    67  	switch {
    68  	case x.Sign() == 0 && y.Sign() == 0:
    69  		return 0
    70  	}
    71  	switch y := y.(type) {
    72  	case *Int:
    73  		return x.x.CmpAbs(&y.x)
    74  	case *Frac:
    75  		if y.x.IsInt() {
    76  			return x.x.CmpAbs(y.x.Num())
    77  		}
    78  		var tmp big.Rat
    79  		tmp.SetInt(&x.x)
    80  		// Set sign(x) = sign(y)
    81  		if y.Sign() < 0 {
    82  			tmp.Neg(&tmp)
    83  		} else {
    84  			tmp.Abs(&tmp)
    85  		}
    86  		return tmp.Cmp(&y.x)
    87  	default:
    88  		panic(fmt.Sprintf("unknown type: %T", y))
    89  	}
    90  }
    91  
    92  func (x *Int) Cmp(y T) int {
    93  	switch {
    94  	case x.Sign() < y.Sign():
    95  		return -1
    96  	case x.Sign() > y.Sign():
    97  		return +1
    98  	case x.Sign() == 0 && y.Sign() == 0:
    99  		return 0
   100  	}
   101  	switch y := y.(type) {
   102  	case *Int:
   103  		return x.x.Cmp(&y.x)
   104  	case *Frac:
   105  		var tmp big.Rat
   106  		tmp.SetInt(&x.x)
   107  		return tmp.Cmp(&y.x)
   108  	default:
   109  		panic(fmt.Sprintf("unknown type: %T", y))
   110  	}
   111  }
   112  
   113  func (x *Int) Add(y T) T {
   114  	switch y := y.(type) {
   115  	case *Int:
   116  		var z Int
   117  		z.x.Add(&x.x, &y.x)
   118  		return &z
   119  	case *Frac:
   120  		if y.x.IsInt() {
   121  			var z Int
   122  			z.x.Add(&x.x, y.x.Num())
   123  			return &z
   124  		}
   125  		var tmp big.Rat
   126  		tmp.SetInt(&x.x)
   127  		var z Frac
   128  		z.x.Add(&tmp, &y.x)
   129  		return &z
   130  	default:
   131  		panic(fmt.Sprintf("unknown type: %T", y))
   132  	}
   133  }
   134  
   135  func (x *Int) Mul(y T) T {
   136  	switch y := y.(type) {
   137  	case *Int:
   138  		var z Int
   139  		z.x.Mul(&x.x, &y.x)
   140  		return &z
   141  	case *Frac:
   142  		if y.x.IsInt() {
   143  			var z Int
   144  			z.x.Mul(&x.x, y.x.Num())
   145  			return &z
   146  		}
   147  		var z Frac
   148  		var tmp big.Rat
   149  		tmp.SetInt(&x.x)
   150  		z.x.Mul(&tmp, &y.x)
   151  		return &z
   152  	default:
   153  		panic(fmt.Sprintf("unknown type: %T", y))
   154  	}
   155  }
   156  
   157  func (x *Int) Sub(y T) T {
   158  	switch y := y.(type) {
   159  	case *Int:
   160  		var z Int
   161  		z.x.Sub(&x.x, &y.x)
   162  		return &z
   163  	case *Frac:
   164  		if y.x.IsInt() {
   165  			var z Int
   166  			z.x.Sub(&x.x, y.x.Num())
   167  			return &z
   168  		}
   169  		var tmp big.Rat
   170  		tmp.SetInt(&x.x)
   171  		var z Frac
   172  		z.x.Sub(&tmp, &y.x)
   173  		return &z
   174  	default:
   175  		panic(fmt.Sprintf("unknown type: %T", y))
   176  	}
   177  }
   178  
   179  func (x *Int) Quo(y T) T {
   180  	switch y := y.(type) {
   181  	case *Int:
   182  		var z Frac
   183  		z.x.SetFrac(&x.x, &y.x)
   184  		return &z
   185  	case *Frac:
   186  		var tmp big.Rat
   187  		tmp.SetInt(&x.x)
   188  		var z Frac
   189  		z.x.Quo(&tmp, &y.x)
   190  		return &z
   191  	default:
   192  		panic(fmt.Sprintf("unknown type: %T", y))
   193  	}
   194  }
   195  
   196  // Frac is a fraction (rational) number.
   197  //
   198  // Frac implements T.
   199  type Frac struct {
   200  	x big.Rat
   201  }
   202  
   203  var _ T = (*Frac)(nil)
   204  
   205  // F64 creates a Frac from a numerateor and denominator.
   206  func F64(n, d int64) T {
   207  	if d == 1 {
   208  		return I64(n)
   209  	}
   210  	var z Frac
   211  	z.x.SetFrac64(n, d)
   212  	return &z
   213  }
   214  
   215  // F copies the numerator and denominator into a Frac.
   216  func F(n, d *big.Int) T {
   217  	var z Frac
   218  	z.x.SetFrac(n, d)
   219  	return &z
   220  }
   221  
   222  func (x *Frac) Sign() int {
   223  	return x.x.Sign()
   224  }
   225  
   226  func (x *Frac) Cmp(y T) int {
   227  	switch {
   228  	case x.Sign() < y.Sign():
   229  		return -1
   230  	case x.Sign() > y.Sign():
   231  		return +1
   232  	case x.Sign() == 0 && y.Sign() == 0:
   233  		return 0
   234  	}
   235  	switch y := y.(type) {
   236  	case *Int:
   237  		var tmp big.Rat
   238  		tmp.SetInt(&y.x)
   239  		return x.x.Cmp(&tmp)
   240  	case *Frac:
   241  		return x.x.Cmp(&y.x)
   242  	default:
   243  		panic(fmt.Sprintf("unknown type: %T", y))
   244  	}
   245  }
   246  
   247  func (x *Frac) CmpAbs(y T) int {
   248  	switch {
   249  	case x.Sign() == 0 && y.Sign() == 0:
   250  		return 0
   251  	}
   252  	switch y := y.(type) {
   253  	case *Int:
   254  		if x.x.IsInt() {
   255  			return y.x.CmpAbs(x.x.Num())
   256  		}
   257  		r := +1
   258  		var tmp big.Rat
   259  		tmp.SetInt(&y.x)
   260  		// Set sign(y) = sign(x)
   261  		if x.Sign() < 0 {
   262  			r = -1
   263  			tmp.Neg(&tmp)
   264  		} else {
   265  			tmp.Abs(&tmp)
   266  		}
   267  		return x.x.Cmp(&tmp) * r
   268  	case *Frac:
   269  		if x.Sign() == y.Sign() {
   270  			return x.x.Cmp(&y.x)
   271  		}
   272  		r := +1
   273  		var tmp big.Rat
   274  		// Set sign(y) = sign(x)
   275  		if x.Sign() < 0 {
   276  			r = -1
   277  			tmp.Neg(&y.x)
   278  		} else {
   279  			tmp.Abs(&y.x)
   280  		}
   281  		return x.x.Cmp(&tmp) * r
   282  	default:
   283  		panic(fmt.Sprintf("unknown type: %T", y))
   284  	}
   285  }
   286  
   287  func (x *Frac) Add(y T) T {
   288  	switch y := y.(type) {
   289  	case *Frac:
   290  		var z Frac
   291  		z.x.Add(&x.x, &y.x)
   292  		return &z
   293  	case *Int:
   294  		var tmp big.Rat
   295  		tmp.SetInt(&y.x)
   296  		var z Frac
   297  		z.x.Add(&x.x, &tmp)
   298  		return &z
   299  	default:
   300  		panic(fmt.Sprintf("unknown type: %T", y))
   301  	}
   302  }
   303  
   304  func (x *Frac) Mul(y T) T {
   305  	switch y := y.(type) {
   306  	case *Frac:
   307  		var z Frac
   308  		z.x.Mul(&x.x, &y.x)
   309  		return &z
   310  	case *Int:
   311  		var tmp big.Rat
   312  		tmp.SetInt(&y.x)
   313  		var z Frac
   314  		z.x.Mul(&x.x, &tmp)
   315  		return &z
   316  	default:
   317  		panic(fmt.Sprintf("unknown type: %T", y))
   318  	}
   319  }
   320  
   321  func (x *Frac) Sub(y T) T {
   322  	switch y := y.(type) {
   323  	case *Frac:
   324  		var z Frac
   325  		z.x.Sub(&x.x, &y.x)
   326  		return &z
   327  	case *Int:
   328  		var tmp big.Rat
   329  		tmp.SetInt(&y.x)
   330  		var z Frac
   331  		z.x.Sub(&x.x, &tmp)
   332  		return &z
   333  	default:
   334  		panic(fmt.Sprintf("unknown type: %T", y))
   335  	}
   336  }
   337  
   338  func (x *Frac) Quo(y T) T {
   339  	switch y := y.(type) {
   340  	case *Frac:
   341  		var z Frac
   342  		z.x.Quo(&x.x, &y.x)
   343  		return &z
   344  	case *Int:
   345  		var tmp big.Rat
   346  		tmp.SetInt(&y.x)
   347  		var z Frac
   348  		z.x.Quo(&x.x, &tmp)
   349  		return &z
   350  	default:
   351  		panic(fmt.Sprintf("unknown type: %T", y))
   352  	}
   353  }
   354  
   355  func (f *Frac) String() string {
   356  	return f.x.String()
   357  }
   358  
   359  var bigOne = big.NewInt(1)
   360  
   361  func round(x T) T {
   362  	switch x := x.(type) {
   363  	case *Int:
   364  		return x
   365  	case *Frac:
   366  		if x.x.IsInt() {
   367  			return x
   368  		}
   369  
   370  		var z Int     // result
   371  		var r big.Int // scratch
   372  
   373  		n := x.x.Num()
   374  		d := x.x.Denom()
   375  
   376  		// Rats are always normalized, meaning the following
   377  		// holds:
   378  		//    if x.IsInt then n.Cmp(d) != 0
   379  		if n.CmpAbs(d) < 0 {
   380  			// Proper fraction.
   381  			if r.Add(n, n).CmpAbs(d) >= 0 {
   382  				z.x.Add(&z.x, bigOne)
   383  			}
   384  			// Round down to zero.
   385  			return &z
   386  		}
   387  
   388  		// Improper fraction.
   389  		z.x.QuoRem(n, d, &r)
   390  		// Is r >= 0.5? If so, round up away from zero.
   391  		if r.Add(&r, &r).CmpAbs(d) >= 0 {
   392  			if x.Sign() < 0 {
   393  				z.x.Sub(&z.x, bigOne)
   394  			} else {
   395  				z.x.Add(&z.x, bigOne)
   396  			}
   397  		}
   398  		return &z
   399  	default:
   400  		panic(fmt.Sprintf("unknown type: %T", x))
   401  	}
   402  }
   403  
   404  func sq(x T) T {
   405  	return x.Mul(x)
   406  }
   407  
   408  var (
   409  	one   = I64(1)
   410  	half  = F64(1, 2)
   411  	quart = F64(1, 4)
   412  )
   413  
   414  // Reduction computes the Lenstra–Lenstra–Lovász
   415  // lattice basis reduction algorithm.
   416  //
   417  // B is a lattice basis
   418  //    b0, b1, ... bn in Z^m
   419  // delta must be in (1/4, 1), typically 3/4.
   420  func Reduction(delta T, B [][]T) [][]T {
   421  	if delta.Cmp(quart) < 0 || delta.Cmp(one) >= 0 {
   422  		panic("delta out of range")
   423  	}
   424  	Bstar := gramSchmidt(nil, B)
   425  	mu := func(i, j int) T {
   426  		return projCoff(Bstar[j], B[i])
   427  	}
   428  	n := len(B)
   429  	k := 1
   430  	for k < n {
   431  		for j := k - 1; j >= 0; j-- {
   432  			mukj := mu(k, j)
   433  			if mukj.CmpAbs(half) > 0 {
   434  				bj := scale(nil, B[j], round(mukj))
   435  				B[k] = sub(B[k], B[k], bj)
   436  				Bstar = gramSchmidt(Bstar, B)
   437  			}
   438  		}
   439  		dmksq := delta.Sub(sq(mu(k, k-1)))
   440  		pbsk1 := sdot(Bstar[k-1])
   441  		if sdot(Bstar[k]).Cmp(dmksq.Mul(pbsk1)) >= 0 {
   442  			k++
   443  		} else {
   444  			B[k], B[k-1] = B[k-1], B[k]
   445  			Bstar = gramSchmidt(Bstar, B)
   446  			k--
   447  			if k < 1 {
   448  				k = 1
   449  			}
   450  		}
   451  	}
   452  	return B
   453  }
   454  
   455  func gramSchmidt(u, v [][]T) [][]T {
   456  	u = u[:0]
   457  	for _, vi := range v {
   458  		ui := vi
   459  		for _, uj := range u {
   460  			// ui -= uj*vi
   461  			uj = proj(nil, uj, vi)
   462  			ui = sub(nil, ui, uj)
   463  		}
   464  		if len(ui) > 0 {
   465  			u = append(u, ui)
   466  		}
   467  	}
   468  	return u
   469  }
   470  
   471  // scale is
   472  //    for i := range x {
   473  //        z[i] = x[i] * c
   474  //    }
   475  func scale(z, x []T, c T) []T {
   476  	z = zmake(z, len(x))
   477  	for i := range x {
   478  		z[i] = x[i].Mul(c)
   479  	}
   480  	return z
   481  }
   482  
   483  // mul is
   484  //    for i := range x {
   485  //        z[i] = x[i] * y[i]
   486  //    }
   487  func mul(z, x, y []T) []T {
   488  	z = zmake(z, len(x))
   489  	for i := range x {
   490  		z[i] = x[i].Mul(y[i])
   491  	}
   492  	return z
   493  }
   494  
   495  // sub is
   496  //    for i := range x {
   497  //        z[i] = x[i] - y[i]
   498  //    }
   499  func sub(z, x, y []T) []T {
   500  	z = zmake(z, len(x))
   501  	for i := range x {
   502  		z[i] = x[i].Sub(y[i])
   503  	}
   504  	return z
   505  }
   506  
   507  // proj is
   508  //    c := projCoff(x, y)
   509  //    scale(z, x, c)
   510  func proj(z, x, y []T) []T {
   511  	z = zmake(z, len(x))
   512  	return scale(z, x, projCoff(x, y))
   513  }
   514  
   515  // projCoff is
   516  //    dot(x, y) / sdot(x)
   517  func projCoff(x, y []T) T {
   518  	return dot(x, y).Quo(sdot(x))
   519  }
   520  
   521  // dot is
   522  //    for i := range x {
   523  //        sum += x[i] * y[i]
   524  //    }
   525  func dot(x, y []T) T {
   526  	sum := I64(0)
   527  	for i := range x {
   528  		sum = sum.Add(x[i].Mul(y[i]))
   529  	}
   530  	return sum
   531  }
   532  
   533  // sdot is
   534  //    dot(x, x)
   535  func sdot(x []T) T {
   536  	return dot(x, x)
   537  }
   538  
   539  func zmake(z []T, n int) []T {
   540  	if n <= cap(z) {
   541  		return z[:n]
   542  	}
   543  	return make([]T, n)
   544  }