github.com/cloudwego/frugal@v0.1.15/internal/atm/ssa/pass_return_spread.go (about)

     1  /*
     2   * Copyright 2022 ByteDance Inc.
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package ssa
    18  
    19  // ReturnSpread spreads the return block to all it's
    20  // successors, in order to shorten register live ranges.
    21  type ReturnSpread struct{}
    22  
    23  func (ReturnSpread) Apply(cfg *CFG) {
    24      more := true
    25      rets := make([]*BasicBlock, 0, 1)
    26  
    27      /* register replacer */
    28      replaceregs := func(rr map[Reg]Reg, ins IrNode) {
    29          var v Reg
    30          var ok bool
    31          var use IrUsages
    32          var def IrDefinitions
    33  
    34          /* replace register usages */
    35          if use, ok = ins.(IrUsages); ok {
    36              for _, r := range use.Usages() {
    37                  if v, ok = rr[*r]; ok {
    38                      *r = v
    39                  }
    40              }
    41          }
    42  
    43          /* replace register definitions */
    44          if def, ok = ins.(IrDefinitions); ok {
    45              for _, r := range def.Definitions() {
    46                  if v, ok = rr[*r]; ok {
    47                      *r = v
    48                  }
    49              }
    50          }
    51      }
    52  
    53      /* loop until no more modifications */
    54      for more {
    55          more = false
    56          rets = rets[:0]
    57  
    58          /* Phase 1: Find the return blocks that has more than one predecessors */
    59          for _, bb := range cfg.PostOrder().Reversed() {
    60              if _, ok := bb.Term.(*IrReturn); ok && len(bb.Pred) > 1 {
    61                  more = true
    62                  rets = append(rets, bb)
    63              }
    64          }
    65  
    66          /* Phase 2: Spread the blocks to it's predecessors */
    67          for _, bb := range rets {
    68              for _, pred := range bb.Pred {
    69                  var ok bool
    70                  var sw *IrSwitch
    71  
    72                  /* register mappings */
    73                  rr := make(map[Reg]Reg)
    74                  nb := len(bb.Phi) + len(bb.Ins)
    75  
    76                  /* allocate registers for Phi definitions */
    77                  for _, phi := range bb.Phi {
    78                      rr[phi.R] = cfg.CreateRegister(phi.R.Ptr())
    79                  }
    80  
    81                  /* allocate registers for instruction definitions */
    82                  for _, ins := range bb.Ins {
    83                      if def, ok := ins.(IrDefinitions); ok {
    84                          for _, r := range def.Definitions() {
    85                              rr[*r] = cfg.CreateRegister(r.Ptr())
    86                          }
    87                      }
    88                  }
    89  
    90                  /* create a new basic block */
    91                  ret := cfg.CreateBlock()
    92                  ret.Ins = make([]IrNode, 0, nb)
    93                  ret.Pred = []*BasicBlock { pred }
    94  
    95                  /* add copy instruction for Phi nodes */
    96                  for _, phi := range bb.Phi {
    97                      ret.Ins = append(ret.Ins, IrCopy(rr[phi.R], *phi.V[pred]))
    98                  }
    99  
   100                  /* copy all instructions */
   101                  for _, ins := range bb.Ins {
   102                      ins = ins.Clone()
   103                      ret.Ins = append(ret.Ins, ins)
   104                      replaceregs(rr, ins)
   105                  }
   106  
   107                  /* copy the terminator */
   108                  ret.Term = bb.Term.Clone().(IrTerminator)
   109                  replaceregs(rr, ret.Term)
   110  
   111                  /* link to the predecessor */
   112                  if sw, ok = pred.Term.(*IrSwitch); !ok {
   113                      panic("invalid block terminator: " + pred.Term.String())
   114                  }
   115  
   116                  /* check for default branch */
   117                  if sw.Ln.To == bb {
   118                      sw.Ln.To = ret
   119                      continue
   120                  }
   121  
   122                  /* replace the switch targets */
   123                  for v, b := range sw.Br {
   124                      if b.To == bb {
   125                          sw.Br[v] = &IrBranch {
   126                              To         : ret,
   127                              Likeliness : b.Likeliness,
   128                          }
   129                      }
   130                  }
   131              }
   132          }
   133  
   134          /* rebuild & cleanup the graph if needed */
   135          if more {
   136              cfg.Rebuild()
   137              new(BlockMerge).Apply(cfg)
   138          }
   139      }
   140  }