github.com/bir3/gocompiler@v0.9.2202/src/cmd/compile/internal/ssa/cse.go (about) 1 // Copyright 2015 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package ssa 6 7 import ( 8 "github.com/bir3/gocompiler/src/cmd/compile/internal/types" 9 "github.com/bir3/gocompiler/src/cmd/internal/src" 10 "fmt" 11 "sort" 12 ) 13 14 // cse does common-subexpression elimination on the Function. 15 // Values are just relinked, nothing is deleted. A subsequent deadcode 16 // pass is required to actually remove duplicate expressions. 17 func cse(f *Func) { 18 // Two values are equivalent if they satisfy the following definition: 19 // equivalent(v, w): 20 // v.op == w.op 21 // v.type == w.type 22 // v.aux == w.aux 23 // v.auxint == w.auxint 24 // len(v.args) == len(w.args) 25 // v.block == w.block if v.op == OpPhi 26 // equivalent(v.args[i], w.args[i]) for i in 0..len(v.args)-1 27 28 // The algorithm searches for a partition of f's values into 29 // equivalence classes using the above definition. 30 // It starts with a coarse partition and iteratively refines it 31 // until it reaches a fixed point. 32 33 // Make initial coarse partitions by using a subset of the conditions above. 34 a := f.Cache.allocValueSlice(f.NumValues()) 35 defer func() { f.Cache.freeValueSlice(a) }() // inside closure to use final value of a 36 a = a[:0] 37 if f.auxmap == nil { 38 f.auxmap = auxmap{} 39 } 40 for _, b := range f.Blocks { 41 for _, v := range b.Values { 42 if v.Type.IsMemory() { 43 continue // memory values can never cse 44 } 45 if f.auxmap[v.Aux] == 0 { 46 f.auxmap[v.Aux] = int32(len(f.auxmap)) + 1 47 } 48 a = append(a, v) 49 } 50 } 51 partition := partitionValues(a, f.auxmap) 52 53 // map from value id back to eqclass id 54 valueEqClass := f.Cache.allocIDSlice(f.NumValues()) 55 defer f.Cache.freeIDSlice(valueEqClass) 56 for _, b := range f.Blocks { 57 for _, v := range b.Values { 58 // Use negative equivalence class #s for unique values. 59 valueEqClass[v.ID] = -v.ID 60 } 61 } 62 var pNum ID = 1 63 for _, e := range partition { 64 if f.pass.debug > 1 && len(e) > 500 { 65 fmt.Printf("CSE.large partition (%d): ", len(e)) 66 for j := 0; j < 3; j++ { 67 fmt.Printf("%s ", e[j].LongString()) 68 } 69 fmt.Println() 70 } 71 72 for _, v := range e { 73 valueEqClass[v.ID] = pNum 74 } 75 if f.pass.debug > 2 && len(e) > 1 { 76 fmt.Printf("CSE.partition #%d:", pNum) 77 for _, v := range e { 78 fmt.Printf(" %s", v.String()) 79 } 80 fmt.Printf("\n") 81 } 82 pNum++ 83 } 84 85 // Split equivalence classes at points where they have 86 // non-equivalent arguments. Repeat until we can't find any 87 // more splits. 88 var splitPoints []int 89 byArgClass := new(partitionByArgClass) // reusable partitionByArgClass to reduce allocations 90 for { 91 changed := false 92 93 // partition can grow in the loop. By not using a range loop here, 94 // we process new additions as they arrive, avoiding O(n^2) behavior. 95 for i := 0; i < len(partition); i++ { 96 e := partition[i] 97 98 if opcodeTable[e[0].Op].commutative { 99 // Order the first two args before comparison. 100 for _, v := range e { 101 if valueEqClass[v.Args[0].ID] > valueEqClass[v.Args[1].ID] { 102 v.Args[0], v.Args[1] = v.Args[1], v.Args[0] 103 } 104 } 105 } 106 107 // Sort by eq class of arguments. 108 byArgClass.a = e 109 byArgClass.eqClass = valueEqClass 110 sort.Sort(byArgClass) 111 112 // Find split points. 113 splitPoints = append(splitPoints[:0], 0) 114 for j := 1; j < len(e); j++ { 115 v, w := e[j-1], e[j] 116 // Note: commutative args already correctly ordered by byArgClass. 117 eqArgs := true 118 for k, a := range v.Args { 119 b := w.Args[k] 120 if valueEqClass[a.ID] != valueEqClass[b.ID] { 121 eqArgs = false 122 break 123 } 124 } 125 if !eqArgs { 126 splitPoints = append(splitPoints, j) 127 } 128 } 129 if len(splitPoints) == 1 { 130 continue // no splits, leave equivalence class alone. 131 } 132 133 // Move another equivalence class down in place of e. 134 partition[i] = partition[len(partition)-1] 135 partition = partition[:len(partition)-1] 136 i-- 137 138 // Add new equivalence classes for the parts of e we found. 139 splitPoints = append(splitPoints, len(e)) 140 for j := 0; j < len(splitPoints)-1; j++ { 141 f := e[splitPoints[j]:splitPoints[j+1]] 142 if len(f) == 1 { 143 // Don't add singletons. 144 valueEqClass[f[0].ID] = -f[0].ID 145 continue 146 } 147 for _, v := range f { 148 valueEqClass[v.ID] = pNum 149 } 150 pNum++ 151 partition = append(partition, f) 152 } 153 changed = true 154 } 155 156 if !changed { 157 break 158 } 159 } 160 161 sdom := f.Sdom() 162 163 // Compute substitutions we would like to do. We substitute v for w 164 // if v and w are in the same equivalence class and v dominates w. 165 rewrite := f.Cache.allocValueSlice(f.NumValues()) 166 defer f.Cache.freeValueSlice(rewrite) 167 byDom := new(partitionByDom) // reusable partitionByDom to reduce allocs 168 for _, e := range partition { 169 byDom.a = e 170 byDom.sdom = sdom 171 sort.Sort(byDom) 172 for i := 0; i < len(e)-1; i++ { 173 // e is sorted by domorder, so a maximal dominant element is first in the slice 174 v := e[i] 175 if v == nil { 176 continue 177 } 178 179 e[i] = nil 180 // Replace all elements of e which v dominates 181 for j := i + 1; j < len(e); j++ { 182 w := e[j] 183 if w == nil { 184 continue 185 } 186 if sdom.IsAncestorEq(v.Block, w.Block) { 187 rewrite[w.ID] = v 188 e[j] = nil 189 } else { 190 // e is sorted by domorder, so v.Block doesn't dominate any subsequent blocks in e 191 break 192 } 193 } 194 } 195 } 196 197 rewrites := int64(0) 198 199 // Apply substitutions 200 for _, b := range f.Blocks { 201 for _, v := range b.Values { 202 for i, w := range v.Args { 203 if x := rewrite[w.ID]; x != nil { 204 if w.Pos.IsStmt() == src.PosIsStmt { 205 // about to lose a statement marker, w 206 // w is an input to v; if they're in the same block 207 // and the same line, v is a good-enough new statement boundary. 208 if w.Block == v.Block && w.Pos.Line() == v.Pos.Line() { 209 v.Pos = v.Pos.WithIsStmt() 210 w.Pos = w.Pos.WithNotStmt() 211 } // TODO and if this fails? 212 } 213 v.SetArg(i, x) 214 rewrites++ 215 } 216 } 217 } 218 for i, v := range b.ControlValues() { 219 if x := rewrite[v.ID]; x != nil { 220 if v.Op == OpNilCheck { 221 // nilcheck pass will remove the nil checks and log 222 // them appropriately, so don't mess with them here. 223 continue 224 } 225 b.ReplaceControl(i, x) 226 } 227 } 228 } 229 230 if f.pass.stats > 0 { 231 f.LogStat("CSE REWRITES", rewrites) 232 } 233 } 234 235 // An eqclass approximates an equivalence class. During the 236 // algorithm it may represent the union of several of the 237 // final equivalence classes. 238 type eqclass []*Value 239 240 // partitionValues partitions the values into equivalence classes 241 // based on having all the following features match: 242 // - opcode 243 // - type 244 // - auxint 245 // - aux 246 // - nargs 247 // - block # if a phi op 248 // - first two arg's opcodes and auxint 249 // - NOT first two arg's aux; that can break CSE. 250 // 251 // partitionValues returns a list of equivalence classes, each 252 // being a sorted by ID list of *Values. The eqclass slices are 253 // backed by the same storage as the input slice. 254 // Equivalence classes of size 1 are ignored. 255 func partitionValues(a []*Value, auxIDs auxmap) []eqclass { 256 sort.Sort(sortvalues{a, auxIDs}) 257 258 var partition []eqclass 259 for len(a) > 0 { 260 v := a[0] 261 j := 1 262 for ; j < len(a); j++ { 263 w := a[j] 264 if cmpVal(v, w, auxIDs) != types.CMPeq { 265 break 266 } 267 } 268 if j > 1 { 269 partition = append(partition, a[:j]) 270 } 271 a = a[j:] 272 } 273 274 return partition 275 } 276 func lt2Cmp(isLt bool) types.Cmp { 277 if isLt { 278 return types.CMPlt 279 } 280 return types.CMPgt 281 } 282 283 type auxmap map[Aux]int32 284 285 func cmpVal(v, w *Value, auxIDs auxmap) types.Cmp { 286 // Try to order these comparison by cost (cheaper first) 287 if v.Op != w.Op { 288 return lt2Cmp(v.Op < w.Op) 289 } 290 if v.AuxInt != w.AuxInt { 291 return lt2Cmp(v.AuxInt < w.AuxInt) 292 } 293 if len(v.Args) != len(w.Args) { 294 return lt2Cmp(len(v.Args) < len(w.Args)) 295 } 296 if v.Op == OpPhi && v.Block != w.Block { 297 return lt2Cmp(v.Block.ID < w.Block.ID) 298 } 299 if v.Type.IsMemory() { 300 // We will never be able to CSE two values 301 // that generate memory. 302 return lt2Cmp(v.ID < w.ID) 303 } 304 // OpSelect is a pseudo-op. We need to be more aggressive 305 // regarding CSE to keep multiple OpSelect's of the same 306 // argument from existing. 307 if v.Op != OpSelect0 && v.Op != OpSelect1 && v.Op != OpSelectN { 308 if tc := v.Type.Compare(w.Type); tc != types.CMPeq { 309 return tc 310 } 311 } 312 313 if v.Aux != w.Aux { 314 if v.Aux == nil { 315 return types.CMPlt 316 } 317 if w.Aux == nil { 318 return types.CMPgt 319 } 320 return lt2Cmp(auxIDs[v.Aux] < auxIDs[w.Aux]) 321 } 322 323 return types.CMPeq 324 } 325 326 // Sort values to make the initial partition. 327 type sortvalues struct { 328 a []*Value // array of values 329 auxIDs auxmap // aux -> aux ID map 330 } 331 332 func (sv sortvalues) Len() int { return len(sv.a) } 333 func (sv sortvalues) Swap(i, j int) { sv.a[i], sv.a[j] = sv.a[j], sv.a[i] } 334 func (sv sortvalues) Less(i, j int) bool { 335 v := sv.a[i] 336 w := sv.a[j] 337 if cmp := cmpVal(v, w, sv.auxIDs); cmp != types.CMPeq { 338 return cmp == types.CMPlt 339 } 340 341 // Sort by value ID last to keep the sort result deterministic. 342 return v.ID < w.ID 343 } 344 345 type partitionByDom struct { 346 a []*Value // array of values 347 sdom SparseTree 348 } 349 350 func (sv partitionByDom) Len() int { return len(sv.a) } 351 func (sv partitionByDom) Swap(i, j int) { sv.a[i], sv.a[j] = sv.a[j], sv.a[i] } 352 func (sv partitionByDom) Less(i, j int) bool { 353 v := sv.a[i] 354 w := sv.a[j] 355 return sv.sdom.domorder(v.Block) < sv.sdom.domorder(w.Block) 356 } 357 358 type partitionByArgClass struct { 359 a []*Value // array of values 360 eqClass []ID // equivalence class IDs of values 361 } 362 363 func (sv partitionByArgClass) Len() int { return len(sv.a) } 364 func (sv partitionByArgClass) Swap(i, j int) { sv.a[i], sv.a[j] = sv.a[j], sv.a[i] } 365 func (sv partitionByArgClass) Less(i, j int) bool { 366 v := sv.a[i] 367 w := sv.a[j] 368 for i, a := range v.Args { 369 b := w.Args[i] 370 if sv.eqClass[a.ID] < sv.eqClass[b.ID] { 371 return true 372 } 373 if sv.eqClass[a.ID] > sv.eqClass[b.ID] { 374 return false 375 } 376 } 377 return false 378 }