github.com/cloudwego/frugal@v0.1.7/internal/atm/ssa/pass_branchelim.go (about) 1 /* 2 * Copyright 2022 ByteDance Inc. 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package ssa 18 19 import ( 20 `fmt` 21 `sort` 22 `strings` 23 24 `github.com/oleiade/lane` 25 ) 26 27 type _Term interface { 28 fmt.Stringer 29 term() 30 } 31 32 type ( 33 _TrRel uint8 34 _RegTerm Reg 35 _ValueTerm Int65 36 ) 37 38 func (_Stmt) term() {} 39 func (_RegTerm) term() {} 40 func (_ValueTerm) term() {} 41 42 func (self _RegTerm) String() string { return Reg(self).String() } 43 func (self _ValueTerm) String() string { return Int65(self).String() } 44 45 const ( 46 _R_eq _TrRel = iota 47 _R_ne 48 _R_lt 49 _R_ltu 50 _R_ge 51 _R_geu 52 ) 53 54 func (self _TrRel) String() string { 55 switch self { 56 case _R_eq : return "==" 57 case _R_ne : return "!=" 58 case _R_lt : return "<" 59 case _R_ltu : return "<#" 60 case _R_ge : return ">=" 61 case _R_geu : return ">=#" 62 default : panic("unreachable") 63 } 64 } 65 66 type _Edge struct { 67 bb *BasicBlock 68 to *BasicBlock 69 } 70 71 func (self _Edge) String() string { 72 return fmt.Sprintf("bb_%d => bb_%d", self.bb.Id, self.to.Id) 73 } 74 75 type _Stmt struct { 76 lhs Reg 77 rhs _Term 78 rel _TrRel 79 } 80 81 func (self _Stmt) String() string { 82 return fmt.Sprintf("%s %s %s", self.lhs, self.rel, self.rhs) 83 } 84 85 func (self _Stmt) negated() _Stmt { 86 switch self.rel { 87 case _R_eq : return _Stmt { self.lhs, self.rhs, _R_ne } 88 case _R_ne : return _Stmt { self.lhs, self.rhs, _R_eq } 89 case _R_lt : return _Stmt { self.lhs, self.rhs, _R_ge } 90 case _R_ltu : return _Stmt { self.lhs, self.rhs, _R_geu } 91 case _R_ge : return _Stmt { self.lhs, self.rhs, _R_lt } 92 case _R_geu : return _Stmt { self.lhs, self.rhs, _R_ltu } 93 default : panic("unreachable") 94 } 95 } 96 97 func (self _Stmt) condition(cond bool) _Stmt { 98 if cond { 99 return self 100 } else { 101 return self.negated() 102 } 103 } 104 105 type _Range struct { 106 rr []Int65 107 } 108 109 func newRange(lower Int65, upper Int65) *_Range { 110 return &_Range { 111 rr: []Int65 { lower, upper }, 112 } 113 } 114 115 func (self *_Range) lower() Int65 { 116 if len(self.rr) == 0 { 117 panic("empty range") 118 } else { 119 return self.rr[0] 120 } 121 } 122 123 func (self *_Range) upper() Int65 { 124 if n := len(self.rr); n == 0 { 125 panic("empty range") 126 } else { 127 return self.rr[n - 1] 128 } 129 } 130 131 func (self *_Range) truth() (bool, bool) { 132 var lower Int65 133 var upper Int65 134 135 /* empty range */ 136 if len(self.rr) == 0 { 137 return false, false 138 } 139 140 /* fast path: there is only one range */ 141 if len(self.rr) == 2 { 142 if self.rr[0].CompareZero() == 0 && self.rr[1].CompareZero() == 0 { 143 return false, true 144 } else if self.rr[0].CompareZero() > 0 || self.rr[1].CompareZero() < 0 { 145 return true, true 146 } else { 147 return false, false 148 } 149 } 150 151 /* check if any range contains the zero */ 152 for i := 0; i < len(self.rr); i += 2 { 153 lower = self.rr[i] 154 upper = self.rr[i + 1] 155 156 /* the range contains zero, the truth cannot be determained */ 157 if lower.CompareZero() <= 0 && upper.CompareZero() >= 0 { 158 return false, false 159 } 160 } 161 162 /* no, the range can be interpreted as true */ 163 return true, true 164 } 165 166 func (self *_Range) remove(lower Int65, upper Int65) { 167 for i := 0; i < len(self.rr); i += 2 { 168 l := self.rr[i] 169 u := self.rr[i + 1] 170 171 /* not intersecting */ 172 if lower.Compare(u) > 0 { break } 173 if upper.Compare(l) < 0 { continue } 174 175 /* splicing */ 176 if lower.Compare(l) > 0 && upper.Compare(u) < 0 { 177 next := []Int65 { l, lower.OneLess(), upper.OneMore(), u } 178 self.rr = append(self.rr[:i], append(next, self.rr[i + 2:]...)...) 179 i += 2 180 break 181 } 182 183 /* remove the upper half */ 184 if lower.Compare(l) > 0 { 185 self.rr[i + 1] = lower.OneLess() 186 continue 187 } 188 189 /* remove the lower half */ 190 if upper.Compare(u) < 0 { 191 self.rr[i] = upper.OneMore() 192 break 193 } 194 195 /* remove the entire range */ 196 copy(self.rr[i:], self.rr[i + 2:]) 197 self.rr = self.rr[:len(self.rr) - 2] 198 i -= 2 199 } 200 } 201 202 func (self *_Range) intersect(lower Int65, upper Int65) { 203 if lower != MinInt65 { self.remove(MinInt65, lower.OneLess()) } 204 if upper != MaxInt65 { self.remove(upper.OneMore(), MaxInt65) } 205 } 206 207 func (self *_Range) removeRange(r *_Range) { 208 for i := 0; i < len(r.rr); i += 2 { 209 self.remove(r.rr[i], r.rr[i + 1]) 210 } 211 } 212 213 func (self *_Range) intersectRange(r *_Range) { 214 for i := 0; i < len(r.rr); i += 2 { 215 self.intersect(r.rr[i], r.rr[i + 1]) 216 } 217 } 218 219 func (self *_Range) String() string { 220 nb := len(self.rr) 221 rb := make([]string, nb / 2) 222 223 /* empty ranges */ 224 if nb == 0 { 225 return "{ (empty) }" 226 } 227 228 /* dump every range */ 229 for i := 0; i < nb; i += 2 { 230 l := self.rr[i] 231 u := self.rr[i + 1] 232 s := new(strings.Builder) 233 234 /* lower bounds */ 235 if s.WriteRune('['); l == MinInt65 { 236 s.WriteString("-∞") 237 } else { 238 s.WriteString(fmt.Sprint(l)) 239 } 240 241 /* upper bounds */ 242 if s.WriteString(", "); u == MaxInt65 { 243 s.WriteString("+∞") 244 } else { 245 s.WriteString(fmt.Sprint(u)) 246 } 247 248 /* build the range */ 249 s.WriteRune(']') 250 rb[i / 2] = s.String() 251 } 252 253 /* join them together */ 254 return fmt.Sprintf( 255 "{ %s }", 256 strings.Join(rb, " ∪ "), 257 ) 258 } 259 260 type _Ranges struct { 261 rr map[Reg]*_Range 262 } 263 264 func newRanges(nb int) (r _Ranges) { 265 r.rr = make(map[Reg]*_Range, nb) 266 r.rr[Rz] = newRange(Int65{}, Int65{}) 267 return 268 } 269 270 func (self _Ranges) of(reg Reg) (r *_Range) { 271 var ok bool 272 var rr *_Range 273 274 /* check for existing range */ 275 if rr, ok = self.rr[reg]; ok { 276 return rr 277 } 278 279 /* create a new one if needed */ 280 rr = newRange(MinInt65, MaxInt65) 281 self.rr[reg] = rr 282 return rr 283 } 284 285 func (self _Ranges) at(reg Reg) (r *_Range, ok bool) { 286 r, ok = self.rr[reg] 287 return 288 } 289 290 type _Proof struct { 291 cp []int 292 st []_Stmt 293 } 294 295 func (self *_Proof) define(lhs Reg, rhs _Term, rel _TrRel) { 296 self.st = append(self.st, _Stmt { lhs, rhs, rel }) 297 } 298 299 func (self *_Proof) assume(ref Reg, lhs Reg, rhs _Term, rel _TrRel) { 300 self.st = append(self.st, _Stmt { 301 lhs: ref, 302 rel: _R_eq, 303 rhs: _Stmt { lhs, rhs, rel }, 304 }) 305 } 306 307 func (self *_Proof) restore() { 308 p := len(self.cp) - 1 309 self.st, self.cp = self.st[:self.cp[p]], self.cp[:p] 310 } 311 312 func (self *_Proof) checkpoint() { 313 self.cp = append(self.cp, len(self.st)) 314 } 315 316 func (self *_Proof) isContradiction(st _Stmt) (ret bool) { 317 self.checkpoint() 318 self.st = append(self.st, st) 319 ret = !self.verifyCorrectness() 320 self.restore() 321 return 322 } 323 324 func (self *_Proof) verifyCorrectness() bool { 325 rt := true 326 rr := newRanges(len(self.st)) 327 sp := make([]_Stmt, 0, len(self.st)) 328 st := append([]_Stmt(nil), self.st...) 329 330 /* calculate ranges for every variable */ 331 for rt { 332 rt = false 333 sp, st = st, sp[:0] 334 335 /* update all the ranges */ 336 for _, v := range sp { 337 var f bool 338 var p _ValueTerm 339 340 /* must be a value term */ 341 if p, f = v.rhs.(_ValueTerm); !f { 342 continue 343 } 344 345 /* evaluate the range */ 346 switch x := Int65(p); v.rel { 347 default: { 348 panic("unreachable") 349 } 350 351 /* simple ranges */ 352 case _R_ne: rr.of(v.lhs).remove(x, x) 353 case _R_eq: rr.of(v.lhs).intersect(x, x) 354 case _R_ge: rr.of(v.lhs).intersect(x, MaxInt65) 355 356 /* signed less-than */ 357 case _R_lt: { 358 if x == MinInt65 { 359 return false 360 } else { 361 rr.of(v.lhs).intersect(MinInt65, x.OneLess()) 362 } 363 } 364 365 /* unsigned greater-than-or-equal-to */ 366 case _R_geu: { 367 if x.CompareZero() < 0 { 368 panic(fmt.Sprintf("unsigned comparison to a negative value %s", x)) 369 } else { 370 rr.of(v.lhs).intersect(x, MaxInt65) 371 } 372 } 373 374 /* unsigned less-than */ 375 case _R_ltu: { 376 if x.CompareZero() <= 0 { 377 panic(fmt.Sprintf("unsigned comparison to a non-positive value %s", x)) 378 } else { 379 rr.of(v.lhs).intersect(Int65{}, x.OneLess()) 380 } 381 } 382 } 383 } 384 385 /* expand all the definations */ 386 for _, v := range sp { 387 if p, ok := v.rhs.(_Stmt); ok { 388 if r, rk := rr.at(v.lhs); rk { 389 if t, tk := r.truth(); tk { 390 rt = true 391 st = append(st, p.condition(t)) 392 } 393 } 394 } 395 } 396 397 /* evaluate all the registers */ 398 for _, v := range sp { 399 var f bool 400 var x Int65 401 var r *_Range 402 var t _RegTerm 403 404 /* must be a register term with a valid range */ 405 if t, f = v.rhs.(_RegTerm) ; !f { continue } 406 if r, f = rr.at(Reg(t)) ; !f { continue } 407 408 /* empty range, already found contradictions */ 409 if len(r.rr) == 0 { 410 return false 411 } 412 413 /* update the ranges */ 414 switch v.rel { 415 default: { 416 panic("unreachable") 417 } 418 419 /* equality and inequality */ 420 case _R_ne: rr.of(v.lhs).removeRange(r) 421 case _R_eq: rr.of(v.lhs).intersectRange(r) 422 423 /* signed less-than */ 424 case _R_lt: { 425 rt = true 426 st = append(st, _Stmt { v.lhs, _ValueTerm(r.upper()), _R_lt }) 427 } 428 429 /* signed greater-than */ 430 case _R_ge: { 431 rt = true 432 st = append(st, _Stmt { v.lhs, _ValueTerm(r.lower()), _R_ge }) 433 } 434 435 /* unsigned less-than */ 436 case _R_ltu: { 437 if x, rt = r.upper(), true; x.CompareZero() > 0 { 438 st = append(st, _Stmt { v.lhs, _ValueTerm(x), _R_ltu }) 439 } else { 440 return false 441 } 442 } 443 444 /* unsigned greater-than-or-equal-to */ 445 case _R_geu: { 446 if x, rt = r.lower(), true; x.CompareZero() >= 0 { 447 st = append(st, _Stmt { v.lhs, _ValueTerm(x), _R_geu }) 448 } else { 449 st = append(st, _Stmt { v.lhs, _ValueTerm(Int65{}), _R_geu }) 450 } 451 } 452 } 453 } 454 } 455 456 /* the statements are valid iff there are no empty ranges */ 457 for _, r := range rr.rr { 458 if len(r.rr) == 0 { 459 return false 460 } 461 } 462 463 /* all checked fine */ 464 return true 465 } 466 467 // BranchElim removes branches that can be proved unreachable. 468 type BranchElim struct{} 469 470 func (self BranchElim) dfs(cfg *CFG, bb *BasicBlock, ps *_Proof) { 471 var ok bool 472 var sw *IrSwitch 473 474 /* add facts for this basic block */ 475 for _, v := range bb.Ins { 476 switch p := v.(type) { 477 default: { 478 break 479 } 480 481 /* integer constant */ 482 case *IrConstInt: { 483 ps.define(p.R, _ValueTerm(Int65i(p.V)), _R_eq) 484 } 485 486 /* binary operators */ 487 case *IrBinaryExpr: { 488 switch p.Op { 489 case IrCmpEq : ps.assume(p.R, p.X, _RegTerm(p.Y), _R_eq) 490 case IrCmpNe : ps.assume(p.R, p.X, _RegTerm(p.Y), _R_ne) 491 case IrCmpLt : ps.assume(p.R, p.X, _RegTerm(p.Y), _R_lt) 492 case IrCmpLtu : ps.assume(p.R, p.X, _RegTerm(p.Y), _R_ltu) 493 case IrCmpGeu : ps.assume(p.R, p.X, _RegTerm(p.Y), _R_geu) 494 } 495 } 496 } 497 } 498 499 /* only care about switches */ 500 if sw, ok = bb.Term.(*IrSwitch); !ok { 501 return 502 } 503 504 /* edges to be removed */ 505 rem := lane.NewQueue() 506 del := make(map[_Edge]bool) 507 val := make([]int32, 0, len(sw.Br)) 508 509 /* prove every branch */ 510 for v, p := range sw.Br { 511 if val = append(val, v); ps.isContradiction(_Stmt { sw.V, _ValueTerm(Int65i(int64(v))), _R_eq }) { 512 delete(sw.Br, v) 513 rem.Enqueue(_Edge { bb, p.To }) 514 } 515 } 516 517 /* create a save-point */ 518 ps.checkpoint() 519 sort.Slice(val, func(i int, j int) bool { return val[i] < val[j] }) 520 521 /* add all the negated conditions */ 522 for _, i := range val { 523 ps.define(sw.V, _ValueTerm(Int65i(int64(i))), _R_ne) 524 } 525 526 /* prove the default branch */ 527 reachable := ps.verifyCorrectness() 528 ps.restore() 529 530 /* check for reachability */ 531 if !reachable { 532 if rem.Enqueue(_Edge { bb, sw.Ln.To }); len(sw.Br) != 1 { 533 sw.Ln = IrUnlikely(cfg.CreateUnreachable(bb)) 534 } else { 535 sw.Ln, sw.Br = sw.Br[val[0]], make(map[int32]*IrBranch) 536 } 537 } 538 539 /* clear register reference if needed */ 540 if len(sw.Br) == 0 { 541 sw.V = Rz 542 } 543 544 /* adjust all the edges */ 545 for !rem.Empty() { 546 e := rem.Pop().(_Edge) 547 del[e] = true 548 549 /* adjust Phi nodes in the target block */ 550 for _, v := range e.to.Phi { 551 delete(v.V, e.bb) 552 } 553 554 /* remove predecessors from the target block */ 555 for i, p := range e.to.Pred { 556 if p == e.bb { 557 e.to.Pred = append(e.to.Pred[:i], e.to.Pred[i + 1:]...) 558 break 559 } 560 } 561 562 /* remove the entire block if no more entry edges left */ 563 if len(e.to.Pred) == 0 { 564 for it := e.to.Term.Successors(); it.Next(); { 565 rem.Enqueue(_Edge { 566 bb: e.to, 567 to: it.Block(), 568 }) 569 } 570 } 571 } 572 573 /* DFS the dominator tree */ 574 for _, p := range cfg.DominatorOf[bb.Id] { 575 var f bool 576 var v _ValueTerm 577 578 /* no need to recurse into unreachable branches */ 579 if del[_Edge { bb, p }] { 580 continue 581 } 582 583 /* find the branch value */ 584 for i, b := range sw.Br { 585 if b.To == p { 586 f, v = true, _ValueTerm(Int65i(int64(i))) 587 break 588 } 589 } 590 591 /* it is not a direct successor, just pass all the facts down */ 592 if !f { 593 self.dfs(cfg, p, ps) 594 continue 595 } 596 597 /* add the fact and recurse into the node */ 598 ps.checkpoint() 599 ps.define(sw.V, v, _R_eq) 600 self.dfs(cfg, p, ps) 601 ps.restore() 602 } 603 } 604 605 func (self BranchElim) Apply(cfg *CFG) { 606 self.dfs(cfg, cfg.Root, new(_Proof)) 607 cfg.Rebuild() 608 }