github.com/aclements/go-misc@v0.0.0-20240129233631-2f6ede80790c/stackmapcompress.py (about)

     1  # -*- indent-tabs-mode: nil -*-
     2  
     3  # Parse output of "go build -gcflags=all=-S -a cmd/go >& /tmp/go.s" and
     4  # compress register liveness maps in various ways.
     5  
     6  import re
     7  import sys
     8  import collections
     9  
    10  if True:
    11      # Register maps
    12      FUNCDATA = "3"
    13      PCDATA = "2"
    14  else:
    15      # Stack maps
    16      FUNCDATA = "1" # Locals (not args)
    17      PCDATA = "0"
    18  
    19  class Stackmap:
    20      def __init__(self, dec=None):
    21          if dec is None:
    22              self.n = self.nbit = 0
    23              self.bitmaps = []
    24          else:
    25              # Decode Go encoding of a runtime.stackmap.
    26              n = dec.int32()
    27              self.nbit = dec.int32()
    28              self.bitmaps = [dec.bitmap(self.nbit) for i in range(n)]
    29  
    30      def clone(self):
    31          enc = Encoder()
    32          self.encode(enc)
    33          return Stackmap(Decoder(enc.b))
    34  
    35      def add(self, bitmap):
    36          nbit, b2 = 0, bitmap
    37          while b2 != 0:
    38              nbit += 1
    39              b2 >>= 1
    40          self.nbit = max(nbit, self.nbit)
    41          for i, b2 in enumerate(self.bitmaps):
    42              if bitmap == b2:
    43                  return i
    44          self.bitmaps.append(bitmap)
    45          return len(self.bitmaps)-1
    46  
    47      def sort(self):
    48          s = sorted((b, i) for i, b in enumerate(self.bitmaps))
    49          self.bitmaps = [b for b, i in s]
    50          return [i for b, i in s]
    51  
    52      def encode(self, enc, compact=False):
    53          enc.int32(len(self.bitmaps))
    54          if compact:
    55              enc.uint8(self.nbit)
    56              combined = 0
    57              for i, b in enumerate(self.bitmaps):
    58                  combined |= b << (i * self.nbit)
    59              enc.bitmap(combined, len(self.bitmaps) * self.nbit)
    60          else:
    61              enc.int32(self.nbit)
    62              for b in self.bitmaps:
    63                  enc.bitmap(b, self.nbit)
    64  
    65  class PCData:
    66      def __init__(self):
    67          self.pcdata = []
    68  
    69      def encode(self, enc):
    70          last = (0, 0)
    71          for e in self.pcdata:
    72              enc.uvarint(e[0] - last[0])
    73              enc.svarint(e[1] - last[1])
    74              last = e
    75          enc.uint8(0)
    76  
    77      def huffSize(self, pcHuff, valHuff):
    78          bits = 0
    79          lastPC = 0
    80          for pc, val in self.pcdata:
    81              bits += pcHuff[pc - lastPC][1] + valHuff[val][1]
    82              lastPC = pc
    83          return (bits + 7) // 8
    84  
    85      def grSize(self, pcHuff, n):
    86          bits = 0
    87          lastPC = 0
    88          for pc, val in self.pcdata:
    89              bits += pcHuff[pc - lastPC][1]
    90              lastPC = pc
    91              bits += grSize(val + 1, n)
    92          return (bits + 7) // 8
    93  
    94  def grSize(val, n):
    95      """The number of bits in the Golomb-Rice coding of val in base 2^n."""
    96      return 1 + (val >> n) + n
    97  
    98  class Decoder:
    99      def __init__(self, b):
   100          self.b = memoryview(b)
   101  
   102      def int32(self):
   103          b = self.b
   104          self.b = b[4:]
   105          return b[0] + (b[1] << 8) + (b[2] << 16) + (b[3] << 24)
   106  
   107      def bitmap(self, nbits):
   108          bitmap = 0
   109          nbytes = (nbits + 7) // 8
   110          for i in range(nbytes):
   111              bitmap = bitmap | (self.b[i] << (i*8))
   112          self.b = self.b[nbytes:]
   113          return bitmap
   114  
   115  class Encoder:
   116      def __init__(self):
   117          self.b = bytearray()
   118  
   119      def uint8(self, i):
   120          self.b.append(i)
   121  
   122      def int32(self, i):
   123          self.b.extend([i&0xFF, (i>>8)&0xFF, (i>>16)&0xFF, (i>>24)&0xFF])
   124  
   125      def bitmap(self, bits, nbits):
   126          for i in range((nbits + 7) // 8):
   127              self.b.append((bits >> (i*8)) & 0xFF)
   128  
   129      def uvarint(self, v):
   130          if v < 0:
   131              raise ValueError("negative unsigned varint", v)
   132          while v > 0x7f:
   133              self.b.append((v & 0x7f) | 0x80)
   134              v >>= 7
   135          self.b.append(v)
   136  
   137      def svarint(self, v):
   138          ux = v << 1
   139          if v < 0:
   140              ux = ~ux
   141          self.uvarint(ux)
   142  
   143  def parse(stream):
   144      import parseasm
   145      objs = parseasm.parse(stream)
   146      fns = []
   147      for obj in objs.values():
   148          if not isinstance(obj, parseasm.Func):
   149              continue
   150          fns.append(obj)
   151          obj.regMaps = []        # [(pc, register bitmap)]
   152          regMap = None
   153          for inst in obj.insts:
   154              if inst.asm.startswith("FUNCDATA\t$"+FUNCDATA+", "):
   155                  regMapSym = inst.asm.split(" ")[1][:-4]
   156                  regMap = Stackmap(Decoder(objs[regMapSym].data))
   157              elif inst.asm.startswith("PCDATA\t$"+PCDATA+", "):
   158                  idx = int(inst.asm.split(" ")[1][1:])
   159                  obj.regMaps.append((inst.pc, regMap.bitmaps[idx]))
   160      return fns
   161  
   162  def genStackMaps(fns, padToByte=True, dedup=True, sortBitmaps=False):
   163      regMapSet = {}
   164  
   165      for fn in fns:
   166          # Create pcdata and register map for fn.
   167          fn.pcdataRegs = PCData()
   168          fn.funcdataRegMap = Stackmap()
   169          for (pc, bitmap) in fn.regMaps:
   170              fn.pcdataRegs.pcdata.append((pc, fn.funcdataRegMap.add(bitmap)))
   171  
   172          if sortBitmaps:
   173              remap = regMap.sort()
   174              pcdata.pcdata = [(pc, remap[idx]) for pc, idx in pcdata.pcdata]
   175  
   176          # Encode and dedup register maps.
   177          if dedup:
   178              e = Encoder()
   179              fn.funcdataRegMap.encode(e, not padToByte)
   180              regMap = bytes(e.b)
   181              if regMap in regMapSet:
   182                  fn.funcdataRegMap = regMapSet[regMap]
   183              else:
   184                  regMapSet[regMap] = fn.funcdataRegMap
   185          else:
   186              regMapSet[fn] = fn.funcdataRegMap
   187  
   188      return regMapSet.values()
   189  
   190  def likeStackMap(fns, padToByte=True, dedup=True, sortBitmaps=None, huffmanPcdata=False, grPcdata=False):
   191      regMapSet = set()
   192      regMaps = bytearray()
   193      pcdatas = [] #Encoder()
   194      extra = 0
   195      for fn in fns:
   196          # Create pcdata and register map for fn.
   197          pcdata = PCData()
   198          regMap = Stackmap()
   199          if sortBitmaps == "freq":
   200              # Pre-populate regMap in frequency order.
   201              regMapFreq = collections.Counter()
   202              for pc, bitmap in fn.regMaps:
   203                  regMapFreq[bitmap] += 1
   204              for bitmap, freq in sorted(regMapFreq.items(), key=lambda item: item[1], reverse=True):
   205                  regMap.add(bitmap)
   206          for pc, bitmap in fn.regMaps:
   207              pcdata.pcdata.append((pc, regMap.add(bitmap)))
   208  
   209          if sortBitmaps == "value":
   210              remap = regMap.sort()
   211              pcdata.pcdata = [(pc, remap[idx]) for pc, idx in pcdata.pcdata]
   212  
   213          pcdatas.append(pcdata)
   214  
   215          # Encode register map and dedup.
   216          e = Encoder()
   217          regMap.encode(e, not padToByte)
   218          regMap = bytes(e.b)
   219          if not dedup or regMap not in regMapSet:
   220              regMapSet.add(regMap)
   221              regMaps.extend(regMap)
   222  
   223          extra += 8 + 4 # funcdata pointer, pcdata table offset
   224  
   225      # Encode pcdata.
   226      pcdataEnc = Encoder()
   227      if huffmanPcdata or grPcdata:
   228          pcDeltas, _ = countDeltas(fns)
   229          pcdataHist = collections.Counter()
   230          for pcdata in pcdatas:
   231              for _, idx in pcdata.pcdata:
   232                  pcdataHist[idx] += 1
   233          pcHuff = huffman(pcDeltas)
   234          pcdataHuff = huffman(pcdataHist)
   235          size = 0
   236          for pcdata in pcdatas:
   237              if huffmanPcdata:
   238                  size += pcdata.huffSize(pcHuff, pcdataHuff)
   239              elif grPcdata:
   240                  size += pcdata.grSize(pcHuff, grPcdata)
   241          pcdataEnc.b = "\0" * size # Whatever
   242      else:
   243          for pcdata in pcdatas:
   244              pcdata.encode(pcdataEnc)
   245  
   246      return {"gclocals": len(regMaps), "pcdata": len(pcdataEnc.b), "extra": extra}
   247  
   248  def filterLiveToDead(fns):
   249      # Only emit pcdata if something becomes newly-live (this is a
   250      # lower bound on what the "don't care" optimization could
   251      # achieve).
   252      for fn in fns:
   253          newRegMaps = []
   254          prevBitmap = 0
   255          for (pc, bitmap) in fn.regMaps:
   256              if bitmap is None:
   257                  newRegIdx.append((pc, None))
   258                  prevBitmap = 0
   259                  continue
   260              if bitmap & ~prevBitmap != 0:
   261                  # New bits set.
   262                  newRegMaps.append((pc, bitmap))
   263              prevBitmap = bitmap
   264          fn.regMaps = newRegMaps
   265  
   266  def total(dct):
   267      dct["total"] = 0
   268      dct["total"] = sum(dct.values())
   269      return dct
   270  
   271  def iterDeltas(regMaps):
   272      prevPC = prevBitmap = 0
   273      for (pc, bitmap) in regMaps:
   274          pcDelta = pc - prevPC
   275          prevPC = pc
   276  
   277          if bitmap is None:
   278              bitmapDelta = None
   279              prevBitmap = 0
   280          else:
   281              bitmapDelta = bitmap ^ prevBitmap
   282              prevBitmap = bitmap
   283  
   284          yield pcDelta, bitmapDelta
   285  
   286  def countMaps(fns):
   287      maps = collections.Counter()
   288      for fn in fns:
   289          for _, bitmap in fn.regMaps:
   290              maps[bitmap] += 1
   291      return maps
   292  
   293  def countDeltas(fns):
   294      pcDeltas, deltas = collections.Counter(), collections.Counter()
   295      # This actually spreads out the head of the distribution quite a bit
   296      # because things are more likely to die in clumps and at the same time
   297      # as something else becomes live.
   298      #filterLiveToDead(fns)
   299      for fn in fns:
   300          for pcDelta, bitmapDelta in iterDeltas(fn.regMaps):
   301              pcDeltas[pcDelta] += 1
   302              deltas[bitmapDelta] += 1
   303      return pcDeltas, deltas
   304  
   305  def huffman(counts, streamAlign=1):
   306      code = [(count, val) for val, count in counts.items()]
   307      radix = 2**streamAlign
   308      while len(code) > 1:
   309          code.sort(key=lambda x: x[0], reverse=True)
   310          if len(code) < radix:
   311              children, code = code, []
   312          else:
   313              children, code = code[-radix:], code[:-radix]
   314          code.append((sum(child[0] for child in children),
   315                       [child[1] for child in children]))
   316      tree = {}
   317      def mktree(node, codeword, bits):
   318          if isinstance(node, list):
   319              for i, child in enumerate(node):
   320                  mktree(child, (codeword << streamAlign) + i, bits + streamAlign)
   321          else:
   322              tree[node] = (codeword, bits)
   323      mktree(code[0][1], 0, 0)
   324      return tree
   325  
   326  def huffmanCoded(fns, streamAlign=1):
   327      pcDeltas, maskDeltas = countDeltas(fns)
   328      hPCs = huffman(pcDeltas, streamAlign)
   329      hBitmaps = huffman(maskDeltas, streamAlign)
   330  
   331      pcdataBits = 0
   332      extra = 0
   333      for fn in fns:
   334          for pcDelta, bitmapDelta in iterDeltas(fn.regMaps):
   335              pcdataBits += hPCs[pcDelta][1] + hBitmaps[bitmapDelta][1]
   336          pcdataBits = (pcdataBits + 7) &~ 7 # Byte align
   337          extra += 4                         # PCDATA
   338      return {"pcdata": (pcdataBits + 7) // 8, "extra": extra}
   339  fns = parse(sys.stdin)
   340  
   341  if True:
   342      print(total(likeStackMap(fns)))
   343      # Linker dedup of gclocals reduces gclocals by >2X
   344      #print(total(likeStackMap(fns, dedup=False)))
   345      #print(total(likeStackMap(fns, sortBitmaps="value")))
   346      # 'total': 529225, 'pcdata': 292703, 'gclocals': 77558, 'extra': 158964
   347      print(total(likeStackMap(fns, huffmanPcdata=True)))
   348      print(total(likeStackMap(fns, huffmanPcdata=True, sortBitmaps="freq")))
   349      for n in range(0, 8):
   350          print(n, total(likeStackMap(fns, grPcdata=n, sortBitmaps="freq")))
   351      #print(total(likeStackMap(fns, compactBitmap=True)))
   352      # 'total': 407999, 'pcdata': 302023, 'extra': 105976
   353      print(total(huffmanCoded(fns)))
   354      print(total(huffmanCoded(fns, streamAlign=8)))
   355      # Only emitting on newly live reduces pcdata by 42%, gclocals by 10%
   356      filterLiveToDead(fns)
   357      print(total(likeStackMap(fns)))
   358  
   359  if False:
   360      # What do the bitmaps look like?
   361      counts = countMaps(fns)
   362      for bitmap, count in counts.items():
   363          print(count, bin(bitmap))
   364  
   365  if False:
   366      # What do the bitmap changes look like?
   367      _, deltas = countDeltas(fns)
   368      for delta, count in deltas.items():
   369          print(count, bin(delta))
   370  
   371  if False:
   372      # PC delta histogram
   373      pcDeltaHist, _ = countDeltas(fns)
   374      for delta, count in pcDeltaHist.items():
   375          print(count, delta)