# -*- indent-tabs-mode: nil -*-

# Parse output of "go build -gcflags=all=-S -a cmd/go >& /tmp/go.s" and
# compress register liveness maps in various ways.
#
# Usage: python stackmapcompress.py < /tmp/go.s
#
# Origin: github.com/aclements/go-misc@v0.0.0-20240129233631-2f6ede80790c/stackmapcompress.py

import re
import sys
import collections

if True:
    # Register maps
    FUNCDATA = "3"
    PCDATA = "2"
else:
    # Stack maps
    FUNCDATA = "1"  # Locals (not args)
    PCDATA = "0"


class Stackmap:
    """A table of liveness bitmaps that share one bit width (nbit).

    Bitmaps are Python ints. The non-compact encode() layout is a
    little-endian int32 count, a little-endian int32 nbit, then each
    bitmap padded to whole bytes; the Decoder-based constructor reads
    exactly that layout back.
    """

    def __init__(self, dec=None):
        if dec is None:
            self.n = self.nbit = 0
            self.bitmaps = []
        else:
            # Decode Go encoding of a runtime.stackmap.
            n = dec.int32()
            self.nbit = dec.int32()
            self.bitmaps = [dec.bitmap(self.nbit) for _ in range(n)]

    def clone(self):
        """Return an independent copy via an encode/decode round trip."""
        enc = Encoder()
        self.encode(enc)
        return Stackmap(Decoder(enc.b))

    def add(self, bitmap):
        """Intern bitmap (a non-negative int) and return its index.

        Widens nbit if bitmap needs more bits than any previous entry.
        """
        self.nbit = max(bitmap.bit_length(), self.nbit)
        try:
            return self.bitmaps.index(bitmap)
        except ValueError:
            self.bitmaps.append(bitmap)
            return len(self.bitmaps) - 1

    def sort(self):
        """Sort the bitmaps by value, in place.

        Returns remap with remap[oldIndex] == newIndex, so an index idx
        previously returned by add() must be rewritten as remap[idx].
        """
        # Stable sort of positions; ties (only possible if bitmaps were
        # not interned via add) keep their original relative order.
        order = sorted(range(len(self.bitmaps)), key=lambda i: self.bitmaps[i])
        self.bitmaps = [self.bitmaps[i] for i in order]
        remap = [0] * len(order)
        for newIdx, oldIdx in enumerate(order):
            remap[oldIdx] = newIdx
        return remap

    def encode(self, enc, compact=False):
        """Append this table to Encoder enc.

        If compact, nbit is stored in a single byte and the bitmaps are
        packed back to back with no per-bitmap byte padding. The compact
        form cannot be read back by the Decoder-based constructor.
        """
        enc.int32(len(self.bitmaps))
        if compact:
            enc.uint8(self.nbit)
            combined = 0
            for i, b in enumerate(self.bitmaps):
                combined |= b << (i * self.nbit)
            enc.bitmap(combined, len(self.bitmaps) * self.nbit)
        else:
            enc.int32(self.nbit)
            for b in self.bitmaps:
                enc.bitmap(b, self.nbit)


class PCData:
    """A pcdata table: an ordered list of (pc, value) pairs."""

    def __init__(self):
        self.pcdata = []

    def encode(self, enc):
        """Append the table to enc as (uvarint pc delta, svarint value
        delta) pairs, followed by a single 0 terminator byte."""
        last = (0, 0)
        for e in self.pcdata:
            enc.uvarint(e[0] - last[0])
            enc.svarint(e[1] - last[1])
            last = e
        enc.uint8(0)

    def huffSize(self, pcHuff, valHuff):
        """Return the byte size if pc deltas are coded with pcHuff and
        values with valHuff (tables as returned by huffman())."""
        bits = 0
        lastPC = 0
        for pc, val in self.pcdata:
            bits += pcHuff[pc - lastPC][1] + valHuff[val][1]
            lastPC = pc
        return (bits + 7) // 8

    def grSize(self, pcHuff, n):
        """Return the byte size if pc deltas are coded with pcHuff and
        values (offset by +1) with a Golomb-Rice code of parameter n."""
        bits = 0
        lastPC = 0
        for pc, val in self.pcdata:
            bits += pcHuff[pc - lastPC][1]
            lastPC = pc
            bits += grSize(val + 1, n)
        return (bits + 7) // 8


def grSize(val, n):
    """The number of bits in the Golomb-Rice coding of val in base 2^n."""
    return 1 + (val >> n) + n


class Decoder:
    """Sequential little-endian reader over a bytes-like object."""

    def __init__(self, b):
        self.b = memoryview(b)

    def int32(self):
        """Consume 4 bytes; return them as an unsigned little-endian int."""
        b = self.b
        self.b = b[4:]
        return b[0] + (b[1] << 8) + (b[2] << 16) + (b[3] << 24)

    def bitmap(self, nbits):
        """Consume ceil(nbits/8) bytes; return them as a little-endian int."""
        bitmap = 0
        nbytes = (nbits + 7) // 8
        for i in range(nbytes):
            bitmap = bitmap | (self.b[i] << (i * 8))
        self.b = self.b[nbytes:]
        return bitmap


class Encoder:
    """Append-only little-endian byte writer."""

    def __init__(self):
        self.b = bytearray()

    def uint8(self, i):
        self.b.append(i)

    def int32(self, i):
        self.b.extend([i & 0xFF, (i >> 8) & 0xFF, (i >> 16) & 0xFF, (i >> 24) & 0xFF])

    def bitmap(self, bits, nbits):
        """Write the low nbits of bits, padded up to whole bytes."""
        for i in range((nbits + 7) // 8):
            self.b.append((bits >> (i * 8)) & 0xFF)

    def uvarint(self, v):
        """Write v as a base-128 varint (low groups first)."""
        if v < 0:
            raise ValueError("negative unsigned varint", v)
        while v > 0x7f:
            self.b.append((v & 0x7f) | 0x80)
            v >>= 7
        self.b.append(v)

    def svarint(self, v):
        """Write v zig-zag encoded as a uvarint."""
        ux = v << 1
        if v < 0:
            ux = ~ux
        self.uvarint(ux)


def parse(stream):
    """Parse compiler assembly output into a list of parseasm.Func objects.

    Each returned function gets a regMaps attribute: a list of
    (pc, bitmap) pairs, one per PCDATA instruction for the selected
    PCDATA table, resolving the index against the most recent FUNCDATA
    table for the selected FUNCDATA slot.
    """
    import parseasm
    objs = parseasm.parse(stream)
    fns = []
    for obj in objs.values():
        if not isinstance(obj, parseasm.Func):
            continue
        fns.append(obj)
        obj.regMaps = []  # [(pc, register bitmap)]
        regMap = None
        for inst in obj.insts:
            if inst.asm.startswith("FUNCDATA\t$" + FUNCDATA + ", "):
                # Drops the last 4 characters of the operand -- presumably
                # the "(SB)" suffix of the symbol reference; confirm.
                regMapSym = inst.asm.split(" ")[1][:-4]
                regMap = Stackmap(Decoder(objs[regMapSym].data))
            elif inst.asm.startswith("PCDATA\t$" + PCDATA + ", "):
                idx = int(inst.asm.split(" ")[1][1:])
                # NOTE(review): a negative PCDATA value would silently
                # index from the end of bitmaps here -- confirm the
                # compiler never emits one for this table.
                obj.regMaps.append((inst.pc, regMap.bitmaps[idx]))
    return fns


def genStackMaps(fns, padToByte=True, dedup=True, sortBitmaps=False):
    """Build per-function pcdata and register-map tables.

    Sets fn.pcdataRegs (PCData of (pc, bitmap index)) and
    fn.funcdataRegMap (Stackmap) on every fn. If sortBitmaps, each
    function's bitmaps are ordered by value and its indices rewritten
    to match. If dedup, functions whose encoded maps are byte-identical
    share one Stackmap object. Returns a view of the distinct maps.
    """
    regMapSet = {}

    for fn in fns:
        # Create pcdata and register map for fn.
        fn.pcdataRegs = PCData()
        fn.funcdataRegMap = Stackmap()
        for (pc, bitmap) in fn.regMaps:
            fn.pcdataRegs.pcdata.append((pc, fn.funcdataRegMap.add(bitmap)))

        if sortBitmaps:
            remap = fn.funcdataRegMap.sort()
            fn.pcdataRegs.pcdata = [(pc, remap[idx]) for pc, idx in fn.pcdataRegs.pcdata]

        # Encode and dedup register maps.
        if dedup:
            e = Encoder()
            fn.funcdataRegMap.encode(e, not padToByte)
            key = bytes(e.b)
            if key in regMapSet:
                fn.funcdataRegMap = regMapSet[key]
            else:
                regMapSet[key] = fn.funcdataRegMap
        else:
            regMapSet[fn] = fn.funcdataRegMap

    return regMapSet.values()


def likeStackMap(fns, padToByte=True, dedup=True, sortBitmaps=None, huffmanPcdata=False, grPcdata=False):
    """Estimate table sizes for a runtime.stackmap-like encoding.

    Each function gets its own bitmap table (gclocals) plus a pcdata
    table mapping pcs to indices in that table.

    padToByte: pad each bitmap to a byte boundary (else pack them).
    dedup: count byte-identical gclocals tables only once.
    sortBitmaps: None, "freq" (most frequent bitmap gets the lowest
        index) or "value" (indices follow bitmap value order).
    huffmanPcdata: size pcdata with Huffman codes for pc deltas and
        indices (takes precedence over grPcdata).
    grPcdata: False/None to disable; otherwise the Golomb-Rice
        parameter for indices (0 is a valid parameter), with pc
        deltas Huffman coded.

    Returns {"gclocals": bytes, "pcdata": bytes, "extra": bytes}.
    """
    # 0 is a meaningful Golomb-Rice parameter, so don't test truthiness.
    useGR = grPcdata is not None and grPcdata is not False

    regMapSet = set()
    regMaps = bytearray()
    pcdatas = []
    extra = 0
    for fn in fns:
        # Create pcdata and register map for fn.
        pcdata = PCData()
        regMap = Stackmap()
        if sortBitmaps == "freq":
            # Pre-populate regMap in frequency order.
            regMapFreq = collections.Counter(bitmap for _, bitmap in fn.regMaps)
            for bitmap, _ in sorted(regMapFreq.items(), key=lambda item: item[1], reverse=True):
                regMap.add(bitmap)
        for pc, bitmap in fn.regMaps:
            pcdata.pcdata.append((pc, regMap.add(bitmap)))

        if sortBitmaps == "value":
            remap = regMap.sort()
            pcdata.pcdata = [(pc, remap[idx]) for pc, idx in pcdata.pcdata]

        pcdatas.append(pcdata)

        # Encode register map and dedup.
        e = Encoder()
        regMap.encode(e, not padToByte)
        encoded = bytes(e.b)
        if not dedup or encoded not in regMapSet:
            regMapSet.add(encoded)
            regMaps.extend(encoded)

        extra += 8 + 4  # funcdata pointer, pcdata table offset

    # Size pcdata.
    if huffmanPcdata or useGR:
        pcDeltas, _ = countDeltas(fns)
        pcHuff = huffman(pcDeltas)
        if huffmanPcdata:
            pcdataHist = collections.Counter(
                idx for pcdata in pcdatas for _, idx in pcdata.pcdata)
            pcdataHuff = huffman(pcdataHist)
            pcdataSize = sum(p.huffSize(pcHuff, pcdataHuff) for p in pcdatas)
        else:
            pcdataSize = sum(p.grSize(pcHuff, grPcdata) for p in pcdatas)
    else:
        pcdataEnc = Encoder()
        for pcdata in pcdatas:
            pcdata.encode(pcdataEnc)
        pcdataSize = len(pcdataEnc.b)

    return {"gclocals": len(regMaps), "pcdata": pcdataSize, "extra": extra}


def filterLiveToDead(fns):
    """Drop regMaps entries that make nothing newly live (in place).

    Only emit pcdata if something becomes newly-live (this is a lower
    bound on what the "don't care" optimization could achieve). None
    entries are kept and reset the running bitmap to empty.
    """
    for fn in fns:
        newRegMaps = []
        prevBitmap = 0
        for (pc, bitmap) in fn.regMaps:
            if bitmap is None:
                newRegMaps.append((pc, None))
                prevBitmap = 0
                continue
            if bitmap & ~prevBitmap != 0:
                # New bits set.
                newRegMaps.append((pc, bitmap))
            prevBitmap = bitmap
        fn.regMaps = newRegMaps


def total(dct):
    """Add a "total" key summing the other values; returns dct."""
    dct["total"] = 0  # so a stale "total" is not summed in
    dct["total"] = sum(dct.values())
    return dct


def iterDeltas(regMaps):
    """Yield (pc delta, bitmap XOR delta) for each (pc, bitmap) pair.

    A None bitmap yields a None delta and resets the previous bitmap
    to 0.
    """
    prevPC = prevBitmap = 0
    for (pc, bitmap) in regMaps:
        pcDelta = pc - prevPC
        prevPC = pc

        if bitmap is None:
            bitmapDelta = None
            prevBitmap = 0
        else:
            bitmapDelta = bitmap ^ prevBitmap
            prevBitmap = bitmap

        yield pcDelta, bitmapDelta


def countMaps(fns):
    """Return a Counter of how often each bitmap appears across fns."""
    return collections.Counter(bitmap for fn in fns for _, bitmap in fn.regMaps)


def countDeltas(fns):
    """Return (Counter of pc deltas, Counter of bitmap deltas) over fns."""
    pcDeltas, deltas = collections.Counter(), collections.Counter()
    # This actually spreads out the head of the distribution quite a bit
    # because things are more likely to die in clumps and at the same time
    # as something else becomes live.
    #filterLiveToDead(fns)
    for fn in fns:
        for pcDelta, bitmapDelta in iterDeltas(fn.regMaps):
            pcDeltas[pcDelta] += 1
            deltas[bitmapDelta] += 1
    return pcDeltas, deltas


def huffman(counts, streamAlign=1):
    """Build a Huffman code over the symbols in counts.

    counts maps symbol -> frequency. Tree nodes have up to
    2**streamAlign children, so every codeword length is a multiple of
    streamAlign bits. Returns {symbol: (codeword, nbits)}; a lone
    symbol gets a 0-bit code and empty counts give an empty table.

    NOTE: for radix > 2 no dummy symbols are padded in, so the code
    may be slightly longer than the optimal n-ary Huffman code.
    """
    if not counts:
        return {}
    code = [(count, val) for val, count in counts.items()]
    radix = 2**streamAlign
    while len(code) > 1:
        code.sort(key=lambda x: x[0], reverse=True)
        if len(code) < radix:
            children, code = code, []
        else:
            children, code = code[-radix:], code[:-radix]
        code.append((sum(child[0] for child in children),
                     [child[1] for child in children]))
    tree = {}

    def mktree(node, codeword, bits):
        if isinstance(node, list):
            for i, child in enumerate(node):
                mktree(child, (codeword << streamAlign) + i, bits + streamAlign)
        else:
            tree[node] = (codeword, bits)

    mktree(code[0][1], 0, 0)
    return tree


def huffmanCoded(fns, streamAlign=1):
    """Estimate sizes when pc and bitmap deltas are Huffman coded
    directly in pcdata, byte-aligning each function's stream."""
    pcDeltas, maskDeltas = countDeltas(fns)
    hPCs = huffman(pcDeltas, streamAlign)
    hBitmaps = huffman(maskDeltas, streamAlign)

    pcdataBits = 0
    extra = 0
    for fn in fns:
        for pcDelta, bitmapDelta in iterDeltas(fn.regMaps):
            pcdataBits += hPCs[pcDelta][1] + hBitmaps[bitmapDelta][1]
        pcdataBits = (pcdataBits + 7) & ~7  # Byte align
        extra += 4  # PCDATA
    return {"pcdata": (pcdataBits + 7) // 8, "extra": extra}


def main():
    """Read assembly from stdin and print size estimates."""
    fns = parse(sys.stdin)

    if True:
        print(total(likeStackMap(fns)))
        # Linker dedup of gclocals reduces gclocals by >2X
        #print(total(likeStackMap(fns, dedup=False)))
        #print(total(likeStackMap(fns, sortBitmaps="value")))
        # 'total': 529225, 'pcdata': 292703, 'gclocals': 77558, 'extra': 158964
        print(total(likeStackMap(fns, huffmanPcdata=True)))
        print(total(likeStackMap(fns, huffmanPcdata=True, sortBitmaps="freq")))
        for n in range(0, 8):
            print(n, total(likeStackMap(fns, grPcdata=n, sortBitmaps="freq")))
        #print(total(likeStackMap(fns, padToByte=False)))
        # 'total': 407999, 'pcdata': 302023, 'extra': 105976
        print(total(huffmanCoded(fns)))
        print(total(huffmanCoded(fns, streamAlign=8)))
        # Only emitting on newly live reduces pcdata by 42%, gclocals by 10%
        filterLiveToDead(fns)
        print(total(likeStackMap(fns)))

    if False:
        # What do the bitmaps look like?
        counts = countMaps(fns)
        for bitmap, count in counts.items():
            print(count, bin(bitmap))

    if False:
        # What do the bitmap changes look like?
        _, deltas = countDeltas(fns)
        for delta, count in deltas.items():
            print(count, bin(delta))

    if False:
        # PC delta histogram
        pcDeltaHist, _ = countDeltas(fns)
        for delta, count in pcDeltaHist.items():
            print(count, delta)


if __name__ == "__main__":
    main()