github.com/cloudwego/frugal@v0.1.7/internal/atm/ssa/pass_branchelim.go (about)

     1  /*
     2   * Copyright 2022 ByteDance Inc.
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package ssa
    18  
    19  import (
    20      `fmt`
    21      `sort`
    22      `strings`
    23  
    24      `github.com/oleiade/lane`
    25  )
    26  
    27  type _Term interface {
    28      fmt.Stringer
    29      term()
    30  }
    31  
    32  type (
    33      _TrRel     uint8
    34      _RegTerm   Reg
    35      _ValueTerm Int65
    36  )
    37  
    38  func (_Stmt)      term() {}
    39  func (_RegTerm)   term() {}
    40  func (_ValueTerm) term() {}
    41  
    42  func (self _RegTerm)   String() string { return Reg(self).String() }
    43  func (self _ValueTerm) String() string { return Int65(self).String() }
    44  
    45  const (
    46      _R_eq _TrRel = iota
    47      _R_ne
    48      _R_lt
    49      _R_ltu
    50      _R_ge
    51      _R_geu
    52  )
    53  
    54  func (self _TrRel) String() string {
    55      switch self {
    56          case _R_eq  : return "=="
    57          case _R_ne  : return "!="
    58          case _R_lt  : return "<"
    59          case _R_ltu : return "<#"
    60          case _R_ge  : return ">="
    61          case _R_geu : return ">=#"
    62          default     : panic("unreachable")
    63      }
    64  }
    65  
    66  type _Edge struct {
    67      bb *BasicBlock
    68      to *BasicBlock
    69  }
    70  
    71  func (self _Edge) String() string {
    72      return fmt.Sprintf("bb_%d => bb_%d", self.bb.Id, self.to.Id)
    73  }
    74  
    75  type _Stmt struct {
    76      lhs Reg
    77      rhs _Term
    78      rel _TrRel
    79  }
    80  
    81  func (self _Stmt) String() string {
    82      return fmt.Sprintf("%s %s %s", self.lhs, self.rel, self.rhs)
    83  }
    84  
    85  func (self _Stmt) negated() _Stmt {
    86      switch self.rel {
    87          case _R_eq  : return _Stmt { self.lhs, self.rhs, _R_ne }
    88          case _R_ne  : return _Stmt { self.lhs, self.rhs, _R_eq }
    89          case _R_lt  : return _Stmt { self.lhs, self.rhs, _R_ge }
    90          case _R_ltu : return _Stmt { self.lhs, self.rhs, _R_geu }
    91          case _R_ge  : return _Stmt { self.lhs, self.rhs, _R_lt }
    92          case _R_geu : return _Stmt { self.lhs, self.rhs, _R_ltu }
    93          default     : panic("unreachable")
    94      }
    95  }
    96  
    97  func (self _Stmt) condition(cond bool) _Stmt {
    98      if cond {
    99          return self
   100      } else {
   101          return self.negated()
   102      }
   103  }
   104  
   105  type _Range struct {
   106      rr []Int65
   107  }
   108  
   109  func newRange(lower Int65, upper Int65) *_Range {
   110      return &_Range {
   111          rr: []Int65 { lower, upper },
   112      }
   113  }
   114  
   115  func (self *_Range) lower() Int65 {
   116      if len(self.rr) == 0 {
   117          panic("empty range")
   118      } else {
   119          return self.rr[0]
   120      }
   121  }
   122  
   123  func (self *_Range) upper() Int65 {
   124      if n := len(self.rr); n == 0 {
   125          panic("empty range")
   126      } else {
   127          return self.rr[n - 1]
   128      }
   129  }
   130  
   131  func (self *_Range) truth() (bool, bool) {
   132      var lower Int65
   133      var upper Int65
   134  
   135      /* empty range */
   136      if len(self.rr) == 0 {
   137          return false, false
   138      }
   139  
   140      /* fast path: there is only one range */
   141      if len(self.rr) == 2 {
   142          if self.rr[0].CompareZero() == 0 && self.rr[1].CompareZero() == 0 {
   143              return false, true
   144          } else if self.rr[0].CompareZero() > 0 || self.rr[1].CompareZero() < 0 {
   145              return true, true
   146          } else {
   147              return false, false
   148          }
   149      }
   150  
   151      /* check if any range contains the zero */
   152      for i := 0; i < len(self.rr); i += 2 {
   153          lower = self.rr[i]
   154          upper = self.rr[i + 1]
   155  
   156          /* the range contains zero, the truth cannot be determained */
   157          if lower.CompareZero() <= 0 && upper.CompareZero() >= 0 {
   158              return false, false
   159          }
   160      }
   161  
   162      /* no, the range can be interpreted as true */
   163      return true, true
   164  }
   165  
   166  func (self *_Range) remove(lower Int65, upper Int65) {
   167      for i := 0; i < len(self.rr); i += 2 {
   168          l := self.rr[i]
   169          u := self.rr[i + 1]
   170  
   171          /* not intersecting */
   172          if lower.Compare(u) > 0 { break }
   173          if upper.Compare(l) < 0 { continue }
   174  
   175          /* splicing */
   176          if lower.Compare(l) > 0 && upper.Compare(u) < 0 {
   177              next := []Int65 { l, lower.OneLess(), upper.OneMore(), u }
   178              self.rr = append(self.rr[:i], append(next, self.rr[i + 2:]...)...)
   179              i += 2
   180              break
   181          }
   182  
   183          /* remove the upper half */
   184          if lower.Compare(l) > 0 {
   185              self.rr[i + 1] = lower.OneLess()
   186              continue
   187          }
   188  
   189          /* remove the lower half */
   190          if upper.Compare(u) < 0 {
   191              self.rr[i] = upper.OneMore()
   192              break
   193          }
   194  
   195          /* remove the entire range */
   196          copy(self.rr[i:], self.rr[i + 2:])
   197          self.rr = self.rr[:len(self.rr) - 2]
   198          i -= 2
   199      }
   200  }
   201  
   202  func (self *_Range) intersect(lower Int65, upper Int65) {
   203      if lower != MinInt65 { self.remove(MinInt65, lower.OneLess()) }
   204      if upper != MaxInt65 { self.remove(upper.OneMore(), MaxInt65) }
   205  }
   206  
   207  func (self *_Range) removeRange(r *_Range) {
   208      for i := 0; i < len(r.rr); i += 2 {
   209          self.remove(r.rr[i], r.rr[i + 1])
   210      }
   211  }
   212  
   213  func (self *_Range) intersectRange(r *_Range) {
   214      for i := 0; i < len(r.rr); i += 2 {
   215          self.intersect(r.rr[i], r.rr[i + 1])
   216      }
   217  }
   218  
   219  func (self *_Range) String() string {
   220      nb := len(self.rr)
   221      rb := make([]string, nb / 2)
   222  
   223      /* empty ranges */
   224      if nb == 0 {
   225          return "{ (empty) }"
   226      }
   227  
   228      /* dump every range */
   229      for i := 0; i < nb; i += 2 {
   230          l := self.rr[i]
   231          u := self.rr[i + 1]
   232          s := new(strings.Builder)
   233  
   234          /* lower bounds */
   235          if s.WriteRune('['); l == MinInt65 {
   236              s.WriteString("-∞")
   237          } else {
   238              s.WriteString(fmt.Sprint(l))
   239          }
   240  
   241          /* upper bounds */
   242          if s.WriteString(", "); u == MaxInt65 {
   243              s.WriteString("+∞")
   244          } else {
   245              s.WriteString(fmt.Sprint(u))
   246          }
   247  
   248          /* build the range */
   249          s.WriteRune(']')
   250          rb[i / 2] = s.String()
   251      }
   252  
   253      /* join them together */
   254      return fmt.Sprintf(
   255          "{ %s }",
   256          strings.Join(rb, " ∪ "),
   257      )
   258  }
   259  
   260  type _Ranges struct {
   261      rr map[Reg]*_Range
   262  }
   263  
   264  func newRanges(nb int) (r _Ranges) {
   265      r.rr = make(map[Reg]*_Range, nb)
   266      r.rr[Rz] = newRange(Int65{}, Int65{})
   267      return
   268  }
   269  
   270  func (self _Ranges) of(reg Reg) (r *_Range) {
   271      var ok bool
   272      var rr *_Range
   273  
   274      /* check for existing range */
   275      if rr, ok = self.rr[reg]; ok {
   276          return rr
   277      }
   278  
   279      /* create a new one if needed */
   280      rr = newRange(MinInt65, MaxInt65)
   281      self.rr[reg] = rr
   282      return rr
   283  }
   284  
   285  func (self _Ranges) at(reg Reg) (r *_Range, ok bool) {
   286      r, ok = self.rr[reg]
   287      return
   288  }
   289  
   290  type _Proof struct {
   291      cp []int
   292      st []_Stmt
   293  }
   294  
   295  func (self *_Proof) define(lhs Reg, rhs _Term, rel _TrRel) {
   296      self.st = append(self.st, _Stmt { lhs, rhs, rel })
   297  }
   298  
   299  func (self *_Proof) assume(ref Reg, lhs Reg, rhs _Term, rel _TrRel) {
   300      self.st = append(self.st, _Stmt {
   301          lhs: ref,
   302          rel: _R_eq,
   303          rhs: _Stmt { lhs, rhs, rel },
   304      })
   305  }
   306  
   307  func (self *_Proof) restore() {
   308      p := len(self.cp) - 1
   309      self.st, self.cp = self.st[:self.cp[p]], self.cp[:p]
   310  }
   311  
   312  func (self *_Proof) checkpoint() {
   313      self.cp = append(self.cp, len(self.st))
   314  }
   315  
   316  func (self *_Proof) isContradiction(st _Stmt) (ret bool) {
   317      self.checkpoint()
   318      self.st = append(self.st, st)
   319      ret = !self.verifyCorrectness()
   320      self.restore()
   321      return
   322  }
   323  
   324  func (self *_Proof) verifyCorrectness() bool {
   325      rt := true
   326      rr := newRanges(len(self.st))
   327      sp := make([]_Stmt, 0, len(self.st))
   328      st := append([]_Stmt(nil), self.st...)
   329  
   330      /* calculate ranges for every variable */
   331      for rt {
   332          rt = false
   333          sp, st = st, sp[:0]
   334  
   335          /* update all the ranges */
   336          for _, v := range sp {
   337              var f bool
   338              var p _ValueTerm
   339  
   340              /* must be a value term */
   341              if p, f = v.rhs.(_ValueTerm); !f {
   342                  continue
   343              }
   344  
   345              /* evaluate the range */
   346              switch x := Int65(p); v.rel {
   347                  default: {
   348                      panic("unreachable")
   349                  }
   350  
   351                  /* simple ranges */
   352                  case _R_ne: rr.of(v.lhs).remove(x, x)
   353                  case _R_eq: rr.of(v.lhs).intersect(x, x)
   354                  case _R_ge: rr.of(v.lhs).intersect(x, MaxInt65)
   355  
   356                  /* signed less-than */
   357                  case _R_lt: {
   358                      if x == MinInt65 {
   359                          return false
   360                      } else {
   361                          rr.of(v.lhs).intersect(MinInt65, x.OneLess())
   362                      }
   363                  }
   364  
   365                  /* unsigned greater-than-or-equal-to */
   366                  case _R_geu: {
   367                      if x.CompareZero() < 0 {
   368                          panic(fmt.Sprintf("unsigned comparison to a negative value %s", x))
   369                      } else {
   370                          rr.of(v.lhs).intersect(x, MaxInt65)
   371                      }
   372                  }
   373  
   374                  /* unsigned less-than */
   375                  case _R_ltu: {
   376                      if x.CompareZero() <= 0 {
   377                          panic(fmt.Sprintf("unsigned comparison to a non-positive value %s", x))
   378                      } else {
   379                          rr.of(v.lhs).intersect(Int65{}, x.OneLess())
   380                      }
   381                  }
   382              }
   383          }
   384  
   385          /* expand all the definations */
   386          for _, v := range sp {
   387              if p, ok := v.rhs.(_Stmt); ok {
   388                  if r, rk := rr.at(v.lhs); rk {
   389                      if t, tk := r.truth(); tk {
   390                          rt = true
   391                          st = append(st, p.condition(t))
   392                      }
   393                  }
   394              }
   395          }
   396  
   397          /* evaluate all the registers */
   398          for _, v := range sp {
   399              var f bool
   400              var x Int65
   401              var r *_Range
   402              var t _RegTerm
   403  
   404              /* must be a register term with a valid range */
   405              if t, f = v.rhs.(_RegTerm) ; !f { continue }
   406              if r, f = rr.at(Reg(t))    ; !f { continue }
   407  
   408              /* empty range, already found contradictions */
   409              if len(r.rr) == 0 {
   410                  return false
   411              }
   412  
   413              /* update the ranges */
   414              switch v.rel {
   415                  default: {
   416                      panic("unreachable")
   417                  }
   418  
   419                  /* equality and inequality */
   420                  case _R_ne: rr.of(v.lhs).removeRange(r)
   421                  case _R_eq: rr.of(v.lhs).intersectRange(r)
   422  
   423                  /* signed less-than */
   424                  case _R_lt: {
   425                      rt = true
   426                      st = append(st, _Stmt { v.lhs, _ValueTerm(r.upper()), _R_lt })
   427                  }
   428  
   429                  /* signed greater-than */
   430                  case _R_ge: {
   431                      rt = true
   432                      st = append(st, _Stmt { v.lhs, _ValueTerm(r.lower()), _R_ge })
   433                  }
   434  
   435                  /* unsigned less-than */
   436                  case _R_ltu: {
   437                      if x, rt = r.upper(), true; x.CompareZero() > 0 {
   438                          st = append(st, _Stmt { v.lhs, _ValueTerm(x), _R_ltu })
   439                      } else {
   440                          return false
   441                      }
   442                  }
   443  
   444                  /* unsigned greater-than-or-equal-to */
   445                  case _R_geu: {
   446                      if x, rt = r.lower(), true; x.CompareZero() >= 0 {
   447                          st = append(st, _Stmt { v.lhs, _ValueTerm(x), _R_geu })
   448                      } else {
   449                          st = append(st, _Stmt { v.lhs, _ValueTerm(Int65{}), _R_geu })
   450                      }
   451                  }
   452              }
   453          }
   454      }
   455  
   456      /* the statements are valid iff there are no empty ranges */
   457      for _, r := range rr.rr {
   458          if len(r.rr) == 0 {
   459              return false
   460          }
   461      }
   462  
   463      /* all checked fine */
   464      return true
   465  }
   466  
   467  // BranchElim removes branches that can be proved unreachable.
   468  type BranchElim struct{}
   469  
   470  func (self BranchElim) dfs(cfg *CFG, bb *BasicBlock, ps *_Proof) {
   471      var ok bool
   472      var sw *IrSwitch
   473  
   474      /* add facts for this basic block */
   475      for _, v := range bb.Ins {
   476          switch p := v.(type) {
   477              default: {
   478                  break
   479              }
   480  
   481              /* integer constant */
   482              case *IrConstInt: {
   483                  ps.define(p.R, _ValueTerm(Int65i(p.V)), _R_eq)
   484              }
   485  
   486              /* binary operators */
   487              case *IrBinaryExpr: {
   488                  switch p.Op {
   489                      case IrCmpEq  : ps.assume(p.R, p.X, _RegTerm(p.Y), _R_eq)
   490                      case IrCmpNe  : ps.assume(p.R, p.X, _RegTerm(p.Y), _R_ne)
   491                      case IrCmpLt  : ps.assume(p.R, p.X, _RegTerm(p.Y), _R_lt)
   492                      case IrCmpLtu : ps.assume(p.R, p.X, _RegTerm(p.Y), _R_ltu)
   493                      case IrCmpGeu : ps.assume(p.R, p.X, _RegTerm(p.Y), _R_geu)
   494                  }
   495              }
   496          }
   497      }
   498  
   499      /* only care about switches */
   500      if sw, ok = bb.Term.(*IrSwitch); !ok {
   501          return
   502      }
   503  
   504      /* edges to be removed */
   505      rem := lane.NewQueue()
   506      del := make(map[_Edge]bool)
   507      val := make([]int32, 0, len(sw.Br))
   508  
   509      /* prove every branch */
   510      for v, p := range sw.Br {
   511          if val = append(val, v); ps.isContradiction(_Stmt { sw.V, _ValueTerm(Int65i(int64(v))), _R_eq }) {
   512              delete(sw.Br, v)
   513              rem.Enqueue(_Edge { bb, p.To })
   514          }
   515      }
   516  
   517      /* create a save-point */
   518      ps.checkpoint()
   519      sort.Slice(val, func(i int, j int) bool { return val[i] < val[j] })
   520  
   521      /* add all the negated conditions */
   522      for _, i := range val {
   523          ps.define(sw.V, _ValueTerm(Int65i(int64(i))), _R_ne)
   524      }
   525  
   526      /* prove the default branch */
   527      reachable := ps.verifyCorrectness()
   528      ps.restore()
   529  
   530      /* check for reachability */
   531      if !reachable {
   532          if rem.Enqueue(_Edge { bb, sw.Ln.To }); len(sw.Br) != 1 {
   533              sw.Ln = IrUnlikely(cfg.CreateUnreachable(bb))
   534          } else {
   535              sw.Ln, sw.Br = sw.Br[val[0]], make(map[int32]*IrBranch)
   536          }
   537      }
   538  
   539      /* clear register reference if needed */
   540      if len(sw.Br) == 0 {
   541          sw.V = Rz
   542      }
   543  
   544      /* adjust all the edges */
   545      for !rem.Empty() {
   546          e := rem.Pop().(_Edge)
   547          del[e] = true
   548  
   549          /* adjust Phi nodes in the target block */
   550          for _, v := range e.to.Phi {
   551              delete(v.V, e.bb)
   552          }
   553  
   554          /* remove predecessors from the target block */
   555          for i, p := range e.to.Pred {
   556              if p == e.bb {
   557                  e.to.Pred = append(e.to.Pred[:i], e.to.Pred[i + 1:]...)
   558                  break
   559              }
   560          }
   561  
   562          /* remove the entire block if no more entry edges left */
   563          if len(e.to.Pred) == 0 {
   564              for it := e.to.Term.Successors(); it.Next(); {
   565                  rem.Enqueue(_Edge {
   566                      bb: e.to,
   567                      to: it.Block(),
   568                  })
   569              }
   570          }
   571      }
   572  
   573      /* DFS the dominator tree */
   574      for _, p := range cfg.DominatorOf[bb.Id] {
   575          var f bool
   576          var v _ValueTerm
   577  
   578          /* no need to recurse into unreachable branches */
   579          if del[_Edge { bb, p }] {
   580              continue
   581          }
   582  
   583          /* find the branch value */
   584          for i, b := range sw.Br {
   585              if b.To == p {
   586                  f, v = true, _ValueTerm(Int65i(int64(i)))
   587                  break
   588              }
   589          }
   590  
   591          /* it is not a direct successor, just pass all the facts down */
   592          if !f {
   593              self.dfs(cfg, p, ps)
   594              continue
   595          }
   596  
   597          /* add the fact and recurse into the node */
   598          ps.checkpoint()
   599          ps.define(sw.V, v, _R_eq)
   600          self.dfs(cfg, p, ps)
   601          ps.restore()
   602      }
   603  }
   604  
   605  func (self BranchElim) Apply(cfg *CFG) {
   606      self.dfs(cfg, cfg.Root, new(_Proof))
   607      cfg.Rebuild()
   608  }