github.com/bir3/gocompiler@v0.3.205/src/cmd/compile/internal/ssa/loopbce.go (about)

     1  // Copyright 2018 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package ssa
     6  
     7  import (
     8  	"github.com/bir3/gocompiler/src/cmd/compile/internal/base"
     9  	"fmt"
    10  	"math"
    11  )
    12  
// indVarFlags is a bit set describing whether each bound of an indVar is
// inclusive or exclusive. The zero value means "minimum inclusive, maximum
// exclusive", i.e. the half-open range min <= ind < max.
type indVarFlags uint8

const (
	indVarMinExc indVarFlags = 1 << iota // minimum value is exclusive (default: inclusive)
	indVarMaxInc                         // maximum value is inclusive (default: exclusive)
)
    19  
// indVar describes one induction variable discovered by findIndVar, together
// with the bounds that are known to hold for it inside the loop body.
type indVar struct {
	ind   *Value // induction variable
	min   *Value // minimum value, inclusive/exclusive depends on flags
	max   *Value // maximum value, inclusive/exclusive depends on flags
	entry *Block // entry block in the loop.
	flags indVarFlags
	// Invariant: for all blocks strictly dominated by entry:
	//	min <= ind <  max    [if flags == 0]
	//	min <  ind <  max    [if flags == indVarMinExc]
	//	min <= ind <= max    [if flags == indVarMaxInc]
	//	min <  ind <= max    [if flags == indVarMinExc|indVarMaxInc]
}
    32  
    33  // parseIndVar checks whether the SSA value passed as argument is a valid induction
    34  // variable, and, if so, extracts:
    35  //   - the minimum bound
    36  //   - the increment value
    37  //   - the "next" value (SSA value that is Phi'd into the induction variable every loop)
    38  //
    39  // Currently, we detect induction variables that match (Phi min nxt),
    40  // with nxt being (Add inc ind).
    41  // If it can't parse the induction variable correctly, it returns (nil, nil, nil).
    42  func parseIndVar(ind *Value) (min, inc, nxt *Value) {
    43  	if ind.Op != OpPhi {
    44  		return
    45  	}
    46  
    47  	if n := ind.Args[0]; n.Op == OpAdd64 && (n.Args[0] == ind || n.Args[1] == ind) {
    48  		min, nxt = ind.Args[1], n
    49  	} else if n := ind.Args[1]; n.Op == OpAdd64 && (n.Args[0] == ind || n.Args[1] == ind) {
    50  		min, nxt = ind.Args[0], n
    51  	} else {
    52  		// Not a recognized induction variable.
    53  		return
    54  	}
    55  
    56  	if nxt.Args[0] == ind { // nxt = ind + inc
    57  		inc = nxt.Args[1]
    58  	} else if nxt.Args[1] == ind { // nxt = inc + ind
    59  		inc = nxt.Args[0]
    60  	} else {
    61  		panic("unreachable") // one of the cases must be true from the above.
    62  	}
    63  
    64  	return
    65  }
    66  
// findIndVar finds induction variables in a function.
//
// Look for variables and blocks that satisfy the following
//
//	 loop:
//	   ind = (Phi min nxt),
//	   if ind < max
//	     then goto enter_loop
//	     else goto exit_loop
//
//	   enter_loop:
//		do something
//	      nxt = inc + ind
//		goto loop
//
//	 exit_loop:
//
// Only BlockIf blocks with exactly two predecessors whose control value is a
// signed 64-bit Less/Leq comparison are examined. Both increasing loops
// (ind </<= limit with a positive constant step) and decreasing loops
// (limit </<= ind with a negative constant step) are accepted, provided the
// step can be shown never to wrap the induction variable around.
//
// It returns one indVar for every loop header that passes all checks.
//
// TODO: handle 32 bit operations
func findIndVar(f *Func) []indVar {
	var iv []indVar
	sdom := f.Sdom()

	for _, b := range f.Blocks {
		if b.Kind != BlockIf || len(b.Preds) != 2 {
			continue
		}

		var ind *Value   // induction variable
		var init *Value  // starting value
		var limit *Value // ending value

		// Check that the control is either ind </<= limit or limit </<= ind.
		// TODO: Handle 32-bit comparisons.
		// TODO: Handle unsigned comparisons?
		c := b.Controls[0]
		inclusive := false
		switch c.Op {
		case OpLeq64:
			inclusive = true
			fallthrough
		case OpLess64:
			ind, limit = c.Args[0], c.Args[1]
		default:
			continue
		}

		// See if this is really an induction variable
		less := true
		init, inc, nxt := parseIndVar(ind)
		if init == nil {
			// We failed to parse the induction variable. Before punting, we want to check
			// whether the control op was written with the induction variable on the RHS
			// instead of the LHS. This happens for the downwards case, like:
			//     for i := len(n)-1; i >= 0; i--
			init, inc, nxt = parseIndVar(limit)
			if init == nil {
				// No recognized induction variable on either operand
				continue
			}

			// Ok, the arguments were reversed. Swap them, and remember that we're
			// looking at a ind >/>= loop (so the induction must be decrementing).
			ind, limit = limit, ind
			less = false
		}

		// Expect the increment to be a nonzero constant.
		if inc.Op != OpConst64 {
			continue
		}
		step := inc.AuxInt
		if step == 0 {
			continue
		}

		// Increment sign must match comparison direction.
		// When incrementing, the termination comparison must be ind </<= limit.
		// When decrementing, the termination comparison must be ind >/>= limit.
		// See issue 26116.
		if step > 0 && !less {
			continue
		}
		if step < 0 && less {
			continue
		}

		// Up to now we extracted the induction variable (ind),
		// the increment delta (inc), the temporary sum (nxt),
		// the initial value (init) and the limiting value (limit).
		//
		// We also know that ind has the form (Phi init nxt) where
		// nxt is (Add inc ind) which means: 1) inc dominates nxt
		// and 2) there is a loop starting at inc and containing nxt.
		//
		// We need to prove that the induction variable is incremented
		// only when it's smaller than the limiting value.
		// Two conditions must happen listed below to accept ind
		// as an induction variable.

		// First condition: loop entry has a single predecessor, which
		// is the header block.  This implies that b.Succs[0] is
		// reached iff ind < limit.
		if len(b.Succs[0].b.Preds) != 1 {
			// b.Succs[1] must exit the loop.
			continue
		}

		// Second condition: b.Succs[0] dominates nxt so that
		// nxt is computed when ind < limit.
		if !sdom.IsAncestorEq(b.Succs[0].b, nxt.Block) {
			// inc+ind can only be reached through the branch that enters the loop.
			continue
		}

		// Check for overflow/underflow. We need to make sure that inc never causes
		// the induction variable to wrap around.
		// We use a function wrapper here for easy return true / return false / keep going logic.
		// This function returns true if the increment will never overflow/underflow.
		// Note that, as a side effect, it may replace the captured limit with a
		// tighter constant and switch inclusive to true.
		ok := func() bool {
			if step > 0 {
				if limit.Op == OpConst64 {
					// Figure out the actual largest value.
					v := limit.AuxInt
					if !inclusive {
						if v == math.MinInt64 {
							return false // < minint is never satisfiable.
						}
						v--
					}
					if init.Op == OpConst64 {
						// Use stride to compute a better upper limit:
						// the largest value of the form init + n*step that is <= v.
						if init.AuxInt > v {
							return false
						}
						v = addU(init.AuxInt, diff(v, init.AuxInt)/uint64(step)*uint64(step))
					}
					if addWillOverflow(v, step) {
						return false
					}
					if inclusive && v != limit.AuxInt || !inclusive && v+1 != limit.AuxInt {
						// We know a better limit than the programmer did. Use our limit instead.
						limit = f.ConstInt64(f.Config.Types.Int64, v)
						inclusive = true
					}
					return true
				}
				if step == 1 && !inclusive {
					// Can't overflow because maxint is never a possible value.
					return true
				}
				// If the limit is not a constant, check to see if it is a
				// negative offset from a known non-negative value.
				knn, k := findKNN(limit)
				if knn == nil || k < 0 {
					return false
				}
				// limit == (something nonnegative) - k. That subtraction can't underflow, so
				// we can trust it.
				if inclusive {
					// ind <= knn - k cannot overflow if step is at most k
					return step <= k
				}
				// ind < knn - k cannot overflow if step is at most k+1
				return step <= k+1 && k != math.MaxInt64
			} else { // step < 0
				if limit.Op == OpConst64 {
					// Figure out the actual smallest value.
					v := limit.AuxInt
					if !inclusive {
						if v == math.MaxInt64 {
							return false // > maxint is never satisfiable.
						}
						v++
					}
					if init.Op == OpConst64 {
						// Use stride to compute a better lower limit:
						// the smallest value of the form init - n*(-step) that is >= v.
						if init.AuxInt < v {
							return false
						}
						v = subU(init.AuxInt, diff(init.AuxInt, v)/uint64(-step)*uint64(-step))
					}
					if subWillUnderflow(v, -step) {
						return false
					}
					if inclusive && v != limit.AuxInt || !inclusive && v-1 != limit.AuxInt {
						// We know a better limit than the programmer did. Use our limit instead.
						limit = f.ConstInt64(f.Config.Types.Int64, v)
						inclusive = true
					}
					return true
				}
				if step == -1 && !inclusive {
					// Can't underflow because minint is never a possible value.
					return true
				}
			}
			return false

		}

		if ok() {
			flags := indVarFlags(0)
			var min, max *Value
			if step > 0 {
				// Counting up: init is the inclusive minimum, limit is the maximum.
				min = init
				max = limit
				if inclusive {
					flags |= indVarMaxInc
				}
			} else {
				// Counting down: limit is the minimum and init, the largest value
				// ind takes, is an inclusive maximum.
				min = limit
				max = init
				flags |= indVarMaxInc
				if !inclusive {
					flags |= indVarMinExc
				}
				// Report the step as a positive magnitude from here on.
				step = -step
			}
			if f.pass.debug >= 1 {
				printIndVar(b, ind, min, max, step, flags)
			}

			iv = append(iv, indVar{
				ind:   ind,
				min:   min,
				max:   max,
				entry: b.Succs[0].b,
				flags: flags,
			})
			b.Logf("found induction variable %v (inc = %v, min = %v, max = %v)\n", ind, inc, min, max)
		}

		// TODO: other unrolling idioms
		// for i := 0; i < KNN - KNN % k ; i += k
		// for i := 0; i < KNN&^(k-1) ; i += k // k a power of 2
		// for i := 0; i < KNN&(-k) ; i += k // k a power of 2
	}

	return iv
}
   307  
   308  // addWillOverflow reports whether x+y would result in a value more than maxint.
   309  func addWillOverflow(x, y int64) bool {
   310  	return x+y < x
   311  }
   312  
   313  // subWillUnderflow reports whether x-y would result in a value less than minint.
   314  func subWillUnderflow(x, y int64) bool {
   315  	return x-y > x
   316  }
   317  
   318  // diff returns x-y as a uint64. Requires x>=y.
   319  func diff(x, y int64) uint64 {
   320  	if x < y {
   321  		base.Fatalf("diff %d - %d underflowed", x, y)
   322  	}
   323  	return uint64(x - y)
   324  }
   325  
   326  // addU returns x+y. Requires that x+y does not overflow an int64.
   327  func addU(x int64, y uint64) int64 {
   328  	if y >= 1<<63 {
   329  		if x >= 0 {
   330  			base.Fatalf("addU overflowed %d + %d", x, y)
   331  		}
   332  		x += 1<<63 - 1
   333  		x += 1
   334  		y -= 1 << 63
   335  	}
   336  	if addWillOverflow(x, int64(y)) {
   337  		base.Fatalf("addU overflowed %d + %d", x, y)
   338  	}
   339  	return x + int64(y)
   340  }
   341  
   342  // subU returns x-y. Requires that x-y does not underflow an int64.
   343  func subU(x int64, y uint64) int64 {
   344  	if y >= 1<<63 {
   345  		if x < 0 {
   346  			base.Fatalf("subU underflowed %d - %d", x, y)
   347  		}
   348  		x -= 1<<63 - 1
   349  		x -= 1
   350  		y -= 1 << 63
   351  	}
   352  	if subWillUnderflow(x, int64(y)) {
   353  		base.Fatalf("subU underflowed %d - %d", x, y)
   354  	}
   355  	return x - int64(y)
   356  }
   357  
   358  // if v is known to be x - c, where x is known to be nonnegative and c is a
   359  // constant, return x, c. Otherwise return nil, 0.
   360  func findKNN(v *Value) (*Value, int64) {
   361  	var x, y *Value
   362  	x = v
   363  	switch v.Op {
   364  	case OpSub64:
   365  		x = v.Args[0]
   366  		y = v.Args[1]
   367  
   368  	case OpAdd64:
   369  		x = v.Args[0]
   370  		y = v.Args[1]
   371  		if x.Op == OpConst64 {
   372  			x, y = y, x
   373  		}
   374  	}
   375  	switch x.Op {
   376  	case OpSliceLen, OpStringLen, OpSliceCap:
   377  	default:
   378  		return nil, 0
   379  	}
   380  	if y == nil {
   381  		return x, 0
   382  	}
   383  	if y.Op != OpConst64 {
   384  		return nil, 0
   385  	}
   386  	if v.Op == OpAdd64 {
   387  		return x, -y.AuxInt
   388  	}
   389  	return x, y.AuxInt
   390  }
   391  
   392  func printIndVar(b *Block, i, min, max *Value, inc int64, flags indVarFlags) {
   393  	mb1, mb2 := "[", "]"
   394  	if flags&indVarMinExc != 0 {
   395  		mb1 = "("
   396  	}
   397  	if flags&indVarMaxInc == 0 {
   398  		mb2 = ")"
   399  	}
   400  
   401  	mlim1, mlim2 := fmt.Sprint(min.AuxInt), fmt.Sprint(max.AuxInt)
   402  	if !min.isGenericIntConst() {
   403  		if b.Func.pass.debug >= 2 {
   404  			mlim1 = fmt.Sprint(min)
   405  		} else {
   406  			mlim1 = "?"
   407  		}
   408  	}
   409  	if !max.isGenericIntConst() {
   410  		if b.Func.pass.debug >= 2 {
   411  			mlim2 = fmt.Sprint(max)
   412  		} else {
   413  			mlim2 = "?"
   414  		}
   415  	}
   416  	extra := ""
   417  	if b.Func.pass.debug >= 2 {
   418  		extra = fmt.Sprintf(" (%s)", i)
   419  	}
   420  	b.Func.Warnl(b.Pos, "Induction variable: limits %v%v,%v%v, increment %d%s", mb1, mlim1, mlim2, mb2, inc, extra)
   421  }