github.com/cloudwego/frugal@v0.1.15/internal/atm/ssa/pass_phiprop.go (about)

     1  /*
     2   * Copyright 2022 ByteDance Inc.
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package ssa
    18  
    19  import (
    20      `fmt`
    21  
    22      `github.com/cloudwego/frugal/internal/rt`
    23      `gonum.org/v1/gonum/graph`
    24      `gonum.org/v1/gonum/graph/simple`
    25      `gonum.org/v1/gonum/graph/topo`
    26  )
    27  
    28  const (
    29      _W_likely   = 7.0 / 8.0         // must be exactly representable to avoid float precision loss
    30      _W_unlikely = 1.0 - _W_likely
    31  )
    32  
// _WeightTab holds the edge weight associated with each branch likeliness.
// PhiProp.Apply indexes it with the value returned by the successor
// iterator's Likeliness() method.
var _WeightTab = [...]float64 {
    Likely   : _W_likely,
    Unlikely : _W_unlikely,
}
    37  
// PhiProp propagates Phi nodes into their source blocks,
// essentially getting rid of them.
//
// For every Phi node, the source register that arrives from the most
// probable predecessor is substituted for the Phi result, and register
// copies are appended to the remaining predecessors.
//
// The CFG is no longer in SSA form after this pass.
type PhiProp struct{}
    42  
    43  func (self PhiProp) dfs(dag *simple.DirectedGraph, bb *BasicBlock, vis map[int]*BasicBlock, path map[int]struct{}) {
    44      vis[bb.Id] = bb
    45      path[bb.Id] = struct{}{}
    46  
    47      /* traverse all the sucessors */
    48      for it := bb.Term.Successors(); it.Next(); {
    49          v := it.Block()
    50          s, d := bb.Id, v.Id
    51  
    52          /* back edge */
    53          if _, ok := path[d]; ok {
    54              continue
    55          }
    56  
    57          /* forward or cross edge */
    58          p, _ := dag.NodeWithID(int64(s))
    59          q, _ := dag.NodeWithID(int64(d))
    60          dag.SetEdge(dag.NewEdge(p, q))
    61  
    62          /* visit the successor if not already */
    63          if _, ok := vis[d]; !ok {
    64              self.dfs(dag, v, vis, path)
    65          }
    66      }
    67  
    68      /* remove the node from path */
    69      if _, ok := path[bb.Id]; !ok {
    70          panic("phiprop: corrupted DFS stack")
    71      } else {
    72          delete(path, bb.Id)
    73      }
    74  }
    75  
    76  func (self PhiProp) Apply(cfg *CFG) {
    77      var err error
    78      var ord []graph.Node
    79  
    80      /* convert to DAG by removing back edges (assuming they never takes) */
    81      // FIXME: this might cause inaccuracy, loops don't affect path probabilities
    82      //  if they are looked as a whole.
    83      dag := simple.NewDirectedGraph()
    84      bbs := make(map[int]*BasicBlock, cfg.MaxBlock())
    85      self.dfs(dag, cfg.Root, bbs, make(map[int]struct{}, cfg.MaxBlock()))
    86  
    87      /* topologically sort the DAG */
    88      if ord, err = topo.Sort(dag); err != nil {
    89          panic("phiprop: topology sort: " + err.Error())
    90      }
    91  
    92      /* weight from block to another block */
    93      subs := make(map[Reg]Reg)
    94      bias := make(map[int]float64)
    95      weight := make(map[int]map[int]float64, cfg.MaxBlock())
    96  
    97      /* calculate block weights in topological order */
    98      for _, p := range ord {
    99          var in float64
   100          var sum float64
   101  
   102          /* find the basic block */
   103          id := p.ID()
   104          bb := bbs[int(id)]
   105          tr := bb.Term.Successors()
   106  
   107          /* special case for root node */
   108          if len(bb.Pred) == 0 {
   109              in = 1.0
   110          }
   111  
   112          /* add all the incoming weights */
   113          for _, v := range bb.Pred {
   114              if dag.HasEdgeFromTo(int64(v.Id), id) {
   115                  in += weight[v.Id][bb.Id]
   116              }
   117          }
   118  
   119          /* allocate the output probability map */
   120          weight[bb.Id] = make(map[int]float64)
   121          rt.MapClear(bias)
   122  
   123          /* calculate the output bias factor */
   124          for tr.Next() {
   125              if vv := tr.Block(); dag.HasEdgeFromTo(id, int64(vv.Id)) {
   126                  w := _WeightTab[tr.Likeliness()]
   127                  sum += w
   128                  bias[vv.Id] = w
   129              }
   130          }
   131  
   132          /* bias the weight with the bias factor */
   133          for i, v := range bias {
   134              weight[bb.Id][i] = in * (v / sum)
   135          }
   136      }
   137  
   138      /* choose the register with highest probability as the "primary" register */
   139      cfg.PostOrder().ForEach(func(bb *BasicBlock) {
   140          for _, p := range bb.Phi {
   141              var rs Reg
   142              var ps float64
   143              var pp float64
   144  
   145              /* find the branch with the hightest probability */
   146              for b, r := range p.V {
   147                  if pp = weight[b.Id][bb.Id]; ps < pp {
   148                      rs = *r
   149                      ps = pp
   150                  }
   151              }
   152  
   153              /* mark the substitution */
   154              if _, ok := subs[p.R]; ok {
   155                  panic(fmt.Sprintf("phiprop: duplicated substitution: %s -> %s", p.R, rs))
   156              } else {
   157                  subs[p.R] = rs
   158              }
   159          }
   160      })
   161  
   162      /* register substitution routine */
   163      substitute := func(rr []*Reg) {
   164          for _, r := range rr {
   165              if d, ok := subs[*r]; ok {
   166                  for subs[d] != 0 { d = subs[d] }
   167                  *r = d
   168              }
   169          }
   170      }
   171  
   172      /* substitute every register */
   173      cfg.PostOrder().ForEach(func(bb *BasicBlock) {
   174          var ok  bool
   175          var use IrUsages
   176          var def IrDefinitions
   177  
   178          /* process Phi nodes */
   179          for _, v := range bb.Phi {
   180              substitute(v.Usages())
   181              substitute(v.Definitions())
   182          }
   183  
   184          /* process instructions */
   185          for _, v := range bb.Ins {
   186              if use, ok = v.(IrUsages)      ; ok { substitute(use.Usages()) }
   187              if def, ok = v.(IrDefinitions) ; ok { substitute(def.Definitions()) }
   188          }
   189  
   190          /* process the terminator */
   191          if use, ok = bb.Term.(IrUsages); ok {
   192              substitute(use.Usages())
   193          }
   194      })
   195  
   196      /* propagate Phi nodes upward */
   197      cfg.PostOrder().ForEach(func(bb *BasicBlock) {
   198          pp := bb.Phi
   199          bb.Phi = nil
   200  
   201          /* process every Phi node */
   202          for _, p := range pp {
   203              for b, r := range p.V {
   204                  if *r != p.R {
   205                      b.Ins = append(b.Ins, IrArchCopy(p.R, *r))
   206                  }
   207              }
   208          }
   209      })
   210  }