// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Source path: github.com/go-asm/go@v1.21.1-0.20240213172139-40c5ead50c48/cmd/compile/ssa/loopbce.go

package ssa

import (
	"fmt"

	"github.com/go-asm/go/cmd/compile/base"
	"github.com/go-asm/go/cmd/compile/types"
)

// indVarFlags is a bit set that qualifies the bounds of an indVar
// (inclusive vs. exclusive) and the direction in which it iterates.
type indVarFlags uint8

const (
	indVarMinExc    indVarFlags = 1 << iota // minimum value is exclusive (default: inclusive)
	indVarMaxInc                            // maximum value is inclusive (default: exclusive)
	indVarCountDown                         // if set the iteration starts at max and counts towards min (default: min towards max)
)

// indVar describes one induction variable found by findIndVar together
// with the bounds that are known to hold inside the loop body.
type indVar struct {
	ind   *Value // induction variable
	nxt   *Value // the incremented variable
	min   *Value // minimum value, inclusive/exclusive depends on flags
	max   *Value // maximum value, inclusive/exclusive depends on flags
	entry *Block // entry block in the loop.
	flags indVarFlags
	// Invariant: for all blocks strictly dominated by entry:
	//	min <= ind <  max    [if flags == 0]
	//	min <  ind <  max    [if flags == indVarMinExc]
	//	min <= ind <= max    [if flags == indVarMaxInc]
	//	min <  ind <= max    [if flags == indVarMinExc|indVarMaxInc]
}

// parseIndVar checks whether the SSA value passed as argument is a valid induction
// variable, and, if so, extracts:
//   - the minimum bound
//   - the increment value
//   - the "next" value (SSA value that is Phi'd into the induction variable every loop)
//
// Currently, we detect induction variables that match (Phi min nxt),
// with nxt being (Add inc ind). The Phi arguments may appear in either
// order, and so may the operands of the Add.
// If it can't parse the induction variable correctly, it returns (nil, nil, nil).
func parseIndVar(ind *Value) (min, inc, nxt *Value) {
	if ind.Op != OpPhi {
		return
	}

	// Find which Phi argument is the Add that feeds ind back into itself;
	// the other Phi argument is then the initial value.
	if n := ind.Args[0]; (n.Op == OpAdd64 || n.Op == OpAdd32 || n.Op == OpAdd16 || n.Op == OpAdd8) && (n.Args[0] == ind || n.Args[1] == ind) {
		min, nxt = ind.Args[1], n
	} else if n := ind.Args[1]; (n.Op == OpAdd64 || n.Op == OpAdd32 || n.Op == OpAdd16 || n.Op == OpAdd8) && (n.Args[0] == ind || n.Args[1] == ind) {
		min, nxt = ind.Args[0], n
	} else {
		// Not a recognized induction variable.
		return
	}

	// The increment is whichever operand of the Add is not ind.
	if nxt.Args[0] == ind { // nxt = ind + inc
		inc = nxt.Args[1]
	} else if nxt.Args[1] == ind { // nxt = inc + ind
		inc = nxt.Args[0]
	} else {
		panic("unreachable") // one of the cases must be true from the above.
	}

	return
}

// findIndVar finds induction variables in a function.
//
// Look for variables and blocks that satisfy the following
//
//	loop:
//	  ind = (Phi min nxt),
//	  if ind < max
//	    then goto enter_loop
//	    else goto exit_loop
//
//	enter_loop:
//	  do something
//	  nxt = inc + ind
//	  goto loop
//
//	exit_loop:
//
// The comparison may also be <=, and the induction variable may be on the
// right-hand side of the comparison (a loop that counts down). The
// increment must be a nonzero constant whose sign matches the direction of
// the comparison. One indVar is returned for every loop header that passes
// all checks.
func findIndVar(f *Func) []indVar {
	var iv []indVar
	sdom := f.Sdom()

	for _, b := range f.Blocks {
		// Only two-predecessor If blocks are considered as loop headers.
		if b.Kind != BlockIf || len(b.Preds) != 2 {
			continue
		}

		var ind *Value   // induction variable
		var init *Value  // starting value
		var limit *Value // ending value

		// Check that the control is either ind </<= limit or limit </<= ind.
		// TODO: Handle unsigned comparisons?
		c := b.Controls[0]
		inclusive := false
		switch c.Op {
		case OpLeq64, OpLeq32, OpLeq16, OpLeq8:
			inclusive = true
			fallthrough
		case OpLess64, OpLess32, OpLess16, OpLess8:
			ind, limit = c.Args[0], c.Args[1]
		default:
			continue
		}

		// See if this is really an induction variable
		less := true
		init, inc, nxt := parseIndVar(ind)
		if init == nil {
			// We failed to parse the induction variable. Before punting, we want to check
			// whether the control op was written with the induction variable on the RHS
			// instead of the LHS. This happens for the downwards case, like:
			//     for i := len(n)-1; i >= 0; i--
			init, inc, nxt = parseIndVar(limit)
			if init == nil {
				// No recognized induction variable on either operand
				continue
			}

			// Ok, the arguments were reversed. Swap them, and remember that we're
			// looking at an ind >/>= loop (so the induction must be decrementing).
			ind, limit = limit, ind
			less = false
		}

		if ind.Block != b {
			// TODO: Could be extended to include disjointed loop headers.
			// I don't think this is causing missed optimizations in real world code often.
			// See https://go.dev/issue/63955
			continue
		}

		// Expect the increment to be a nonzero constant.
		if !inc.isGenericIntConst() {
			continue
		}
		step := inc.AuxInt
		if step == 0 {
			continue
		}

		// Increment sign must match comparison direction.
		// When incrementing, the termination comparison must be ind </<= limit.
		// When decrementing, the termination comparison must be ind >/>= limit.
		// See issue 26116.
		if step > 0 && !less {
			continue
		}
		if step < 0 && less {
			continue
		}

		// Up to now we extracted the induction variable (ind),
		// the increment delta (inc), the temporary sum (nxt),
		// the initial value (init) and the limiting value (limit).
		//
		// We also know that ind has the form (Phi init nxt) where
		// nxt is (Add inc ind) which means: 1) inc dominates nxt
		// and 2) there is a loop starting at inc and containing nxt.
		//
		// We need to prove that the induction variable is incremented
		// only when it's smaller than the limiting value.
		// Two conditions must happen listed below to accept ind
		// as an induction variable.

		// First condition: loop entry has a single predecessor, which
		// is the header block. This implies that b.Succs[0] is
		// reached iff ind < limit.
		if len(b.Succs[0].b.Preds) != 1 {
			// b.Succs[1] must exit the loop.
			continue
		}

		// Second condition: b.Succs[0] dominates nxt so that
		// nxt is computed only after the loop condition held.
		if !sdom.IsAncestorEq(b.Succs[0].b, nxt.Block) {
			// inc+ind can only be reached through the branch that enters the loop.
			continue
		}

		// Check for overflow/underflow. We need to make sure that inc never causes
		// the induction variable to wrap around.
		// We use a function wrapper here for easy return true / return false / keep going logic.
		// This function returns true if the increment will never overflow/underflow.
		// Note that the closure may replace the captured limit with a tighter
		// constant and set inclusive to true as a side effect.
		ok := func() bool {
			if step > 0 {
				if limit.isGenericIntConst() {
					// Figure out the actual largest value.
					v := limit.AuxInt
					if !inclusive {
						if v == minSignedValue(limit.Type) {
							return false // < minint is never satisfiable.
						}
						v--
					}
					if init.isGenericIntConst() {
						// Use stride to compute a better upper limit:
						// the largest value init + n*step that is <= v.
						if init.AuxInt > v {
							return false
						}
						v = addU(init.AuxInt, diff(v, init.AuxInt)/uint64(step)*uint64(step))
					}
					// NOTE(review): addWillOverflow tests for int64 wraparound only;
					// confirm that 8/16/32-bit induction variables cannot wrap here.
					if addWillOverflow(v, step) {
						return false
					}
					if inclusive && v != limit.AuxInt || !inclusive && v+1 != limit.AuxInt {
						// We know a better limit than the programmer did. Use our limit instead.
						limit = f.constVal(limit.Op, limit.Type, v, true)
						inclusive = true
					}
					return true
				}
				if step == 1 && !inclusive {
					// Can't overflow because maxint is never a possible value.
					return true
				}
				// If the limit is not a constant, check to see if it is a
				// negative offset from a known non-negative value.
				knn, k := findKNN(limit)
				if knn == nil || k < 0 {
					return false
				}
				// limit == (something nonnegative) - k. That subtraction can't underflow, so
				// we can trust it.
				if inclusive {
					// ind <= knn - k cannot overflow if step is at most k
					return step <= k
				}
				// ind < knn - k cannot overflow if step is at most k+1
				return step <= k+1 && k != maxSignedValue(limit.Type)
			} else { // step < 0
				// NOTE(review): the counting-up branch accepts any generic int
				// constant while this one accepts only OpConst64 — confirm
				// whether the asymmetry is intentional.
				if limit.Op == OpConst64 {
					// Figure out the actual smallest value.
					v := limit.AuxInt
					if !inclusive {
						if v == maxSignedValue(limit.Type) {
							return false // > maxint is never satisfiable.
						}
						v++
					}
					if init.isGenericIntConst() {
						// Use stride to compute a better lower limit:
						// the smallest value init - n*(-step) that is >= v.
						if init.AuxInt < v {
							return false
						}
						v = subU(init.AuxInt, diff(init.AuxInt, v)/uint64(-step)*uint64(-step))
					}
					if subWillUnderflow(v, -step) {
						return false
					}
					if inclusive && v != limit.AuxInt || !inclusive && v-1 != limit.AuxInt {
						// We know a better limit than the programmer did. Use our limit instead.
						limit = f.constVal(limit.Op, limit.Type, v, true)
						inclusive = true
					}
					return true
				}
				if step == -1 && !inclusive {
					// Can't underflow because minint is never a possible value.
					return true
				}
			}
			return false

		}

		if ok() {
			flags := indVarFlags(0)
			var min, max *Value
			if step > 0 {
				// Counting up: init is the inclusive minimum, limit the maximum.
				min = init
				max = limit
				if inclusive {
					flags |= indVarMaxInc
				}
			} else {
				// Counting down: the roles are swapped. The maximum (init) is
				// always inclusive; the minimum (limit) is exclusive for a
				// strict comparison.
				min = limit
				max = init
				flags |= indVarMaxInc
				if !inclusive {
					flags |= indVarMinExc
				}
				flags |= indVarCountDown
				// From here on step holds the magnitude of the increment.
				step = -step
			}
			if f.pass.debug >= 1 {
				printIndVar(b, ind, min, max, step, flags)
			}

			iv = append(iv, indVar{
				ind:   ind,
				nxt:   nxt,
				min:   min,
				max:   max,
				entry: b.Succs[0].b,
				flags: flags,
			})
			b.Logf("found induction variable %v (inc = %v, min = %v, max = %v)\n", ind, inc, min, max)
		}

		// TODO: other unrolling idioms
		// for i := 0; i < KNN - KNN % k ; i += k
		// for i := 0; i < KNN&^(k-1) ; i += k // k a power of 2
		// for i := 0; i < KNN&(-k) ; i += k // k a power of 2
	}

	return iv
}

// addWillOverflow reports whether x+y would result in a value more than maxint.
// The comparison is only a valid overflow test for y >= 0.
func addWillOverflow(x, y int64) bool {
	return x+y < x
}

// subWillUnderflow reports whether x-y would result in a value less than minint.
// The comparison is only a valid underflow test for y >= 0.
func subWillUnderflow(x, y int64) bool {
	return x-y > x
}

// diff returns x-y as a uint64. Requires x>=y.
// It calls base.Fatalf if x < y.
func diff(x, y int64) uint64 {
	if x < y {
		base.Fatalf("diff %d - %d underflowed", x, y)
	}
	return uint64(x - y)
}

// addU returns x+y. Requires that x+y does not overflow an int64.
// It calls base.Fatalf if the sum would overflow.
func addU(x int64, y uint64) int64 {
	if y >= 1<<63 {
		// y does not fit in an int64. Move 1<<63 from y into x first; that
		// is only possible without overflow when x is negative.
		if x >= 0 {
			base.Fatalf("addU overflowed %d + %d", x, y)
		}
		// 1<<63 is not representable as an int64, so add it in two steps.
		x += 1<<63 - 1
		x += 1
		y -= 1 << 63
	}
	if addWillOverflow(x, int64(y)) {
		base.Fatalf("addU overflowed %d + %d", x, y)
	}
	return x + int64(y)
}

// subU returns x-y. Requires that x-y does not underflow an int64.
// It calls base.Fatalf if the difference would underflow.
func subU(x int64, y uint64) int64 {
	if y >= 1<<63 {
		// y does not fit in an int64. Take 1<<63 off both x and y first;
		// that is only possible without underflow when x is non-negative.
		if x < 0 {
			base.Fatalf("subU underflowed %d - %d", x, y)
		}
		// 1<<63 is not representable as an int64, so subtract it in two steps.
		x -= 1<<63 - 1
		x -= 1
		y -= 1 << 63
	}
	if subWillUnderflow(x, int64(y)) {
		base.Fatalf("subU underflowed %d - %d", x, y)
	}
	return x - int64(y)
}

// if v is known to be x - c, where x is known to be nonnegative and c is a
// constant, return x, c. Otherwise return nil, 0.
//
// x is accepted only if it is an OpSliceLen, OpStringLen or OpSliceCap.
// v may be x itself (c == 0), a Sub of x and a constant, or an Add of x and
// a constant (in which case c is the negated constant).
func findKNN(v *Value) (*Value, int64) {
	var x, y *Value
	x = v
	switch v.Op {
	case OpSub64, OpSub32, OpSub16, OpSub8:
		x = v.Args[0]
		y = v.Args[1]

	case OpAdd64, OpAdd32, OpAdd16, OpAdd8:
		x = v.Args[0]
		y = v.Args[1]
		// Addition is commutative: make sure the constant, if any, ends up in y.
		if x.isGenericIntConst() {
			x, y = y, x
		}
	}
	switch x.Op {
	case OpSliceLen, OpStringLen, OpSliceCap:
	default:
		return nil, 0
	}
	if y == nil {
		// v itself is the nonnegative value.
		return x, 0
	}
	if !y.isGenericIntConst() {
		return nil, 0
	}
	if v.Op == OpAdd64 || v.Op == OpAdd32 || v.Op == OpAdd16 || v.Op == OpAdd8 {
		// x + c == x - (-c)
		return x, -y.AuxInt
	}
	return x, y.AuxInt
}

// printIndVar emits a warning at b.Pos that describes the induction variable i:
// its limits in interval notation ("[" / "]" for inclusive, "(" / ")" for
// exclusive bounds) and its increment inc. A non-constant limit is printed
// as "?" unless the pass debug level is at least 2, in which case the limit
// value itself is printed and i is appended to the message.
func printIndVar(b *Block, i, min, max *Value, inc int64, flags indVarFlags) {
	mb1, mb2 := "[", "]"
	if flags&indVarMinExc != 0 {
		mb1 = "("
	}
	if flags&indVarMaxInc == 0 {
		mb2 = ")"
	}

	mlim1, mlim2 := fmt.Sprint(min.AuxInt), fmt.Sprint(max.AuxInt)
	if !min.isGenericIntConst() {
		if b.Func.pass.debug >= 2 {
			mlim1 = fmt.Sprint(min)
		} else {
			mlim1 = "?"
		}
	}
	if !max.isGenericIntConst() {
		if b.Func.pass.debug >= 2 {
			mlim2 = fmt.Sprint(max)
		} else {
			mlim2 = "?"
		}
	}
	extra := ""
	if b.Func.pass.debug >= 2 {
		extra = fmt.Sprintf(" (%s)", i)
	}
	b.Func.Warnl(b.Pos, "Induction variable: limits %v%v,%v%v, increment %d%s", mb1, mlim1, mlim2, mb2, inc, extra)
}

// minSignedValue returns the smallest value of a signed integer that is
// t.Size()*8 bits wide.
func minSignedValue(t *types.Type) int64 {
	return -1 << (t.Size()*8 - 1)
}

// maxSignedValue returns the largest value of a signed integer that is
// t.Size()*8 bits wide.
func maxSignedValue(t *types.Type) int64 {
	return 1<<((t.Size()*8)-1) - 1
}