github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp (about) 1 //===- LoopFusionUtils.cpp ---- Utilities for loop fusion ----------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 // 18 // This file implements loop fusion transformation utility functions. 19 // 20 //===----------------------------------------------------------------------===// 21 22 #include "mlir/Transforms/LoopFusionUtils.h" 23 24 #include "mlir/Analysis/AffineAnalysis.h" 25 #include "mlir/Analysis/AffineStructures.h" 26 #include "mlir/Analysis/LoopAnalysis.h" 27 #include "mlir/Analysis/Utils.h" 28 #include "mlir/Dialect/AffineOps/AffineOps.h" 29 #include "mlir/Dialect/StandardOps/Ops.h" 30 #include "mlir/IR/AffineExpr.h" 31 #include "mlir/IR/AffineMap.h" 32 #include "mlir/IR/BlockAndValueMapping.h" 33 #include "mlir/IR/Builders.h" 34 #include "mlir/IR/Function.h" 35 #include "mlir/IR/Operation.h" 36 #include "llvm/ADT/DenseMap.h" 37 #include "llvm/ADT/SmallVector.h" 38 #include "llvm/Support/Debug.h" 39 #include "llvm/Support/raw_ostream.h" 40 41 #define DEBUG_TYPE "loop-fusion-utils" 42 43 using namespace mlir; 44 45 // Gathers all load and store memref accesses in 'opA' into 'values', where 46 // 'values[memref] == true' for each store operation. 47 static void getLoadAndStoreMemRefAccesses(Operation *opA, 48 DenseMap<Value *, bool> &values) { 49 opA->walk([&](Operation *op) { 50 if (auto loadOp = dyn_cast<AffineLoadOp>(op)) { 51 if (values.count(loadOp.getMemRef()) == 0) 52 values[loadOp.getMemRef()] = false; 53 } else if (auto storeOp = dyn_cast<AffineStoreOp>(op)) { 54 values[storeOp.getMemRef()] = true; 55 } 56 }); 57 } 58 59 // Returns true if 'op' is a load or store operation which access an memref 60 // accessed 'values' and at least one of the access is a store operation. 61 // Returns false otherwise. 62 static bool isDependentLoadOrStoreOp(Operation *op, 63 DenseMap<Value *, bool> &values) { 64 if (auto loadOp = dyn_cast<AffineLoadOp>(op)) { 65 return values.count(loadOp.getMemRef()) > 0 && 66 values[loadOp.getMemRef()] == true; 67 } else if (auto storeOp = dyn_cast<AffineStoreOp>(op)) { 68 return values.count(storeOp.getMemRef()) > 0; 69 } 70 return false; 71 } 72 73 // Returns the first operation in range ('opA', 'opB') which has a data 74 // dependence on 'opA'. Returns 'nullptr' of no dependence exists. 75 static Operation *getFirstDependentOpInRange(Operation *opA, Operation *opB) { 76 // Record memref values from all loads/store in loop nest rooted at 'opA'. 77 // Map from memref value to bool which is true if store, false otherwise. 78 DenseMap<Value *, bool> values; 79 getLoadAndStoreMemRefAccesses(opA, values); 80 81 // For each 'opX' in block in range ('opA', 'opB'), check if there is a data 82 // dependence from 'opA' to 'opX' ('opA' and 'opX' access the same memref 83 // and at least one of the accesses is a store). 84 Operation *firstDepOp = nullptr; 85 for (Block::iterator it = std::next(Block::iterator(opA)); 86 it != Block::iterator(opB); ++it) { 87 Operation *opX = &(*it); 88 opX->walk([&](Operation *op) { 89 if (!firstDepOp && isDependentLoadOrStoreOp(op, values)) 90 firstDepOp = opX; 91 }); 92 if (firstDepOp) 93 break; 94 } 95 return firstDepOp; 96 } 97 98 // Returns the last operation 'opX' in range ('opA', 'opB'), for which there 99 // exists a data dependence from 'opX' to 'opB'. 100 // Returns 'nullptr' of no dependence exists. 101 static Operation *getLastDependentOpInRange(Operation *opA, Operation *opB) { 102 // Record memref values from all loads/store in loop nest rooted at 'opB'. 103 // Map from memref value to bool which is true if store, false otherwise. 104 DenseMap<Value *, bool> values; 105 getLoadAndStoreMemRefAccesses(opB, values); 106 107 // For each 'opX' in block in range ('opA', 'opB') in reverse order, 108 // check if there is a data dependence from 'opX' to 'opB': 109 // *) 'opX' and 'opB' access the same memref and at least one of the accesses 110 // is a store. 111 // *) 'opX' produces an SSA Value which is used by 'opB'. 112 Operation *lastDepOp = nullptr; 113 for (Block::reverse_iterator it = std::next(Block::reverse_iterator(opB)); 114 it != Block::reverse_iterator(opA); ++it) { 115 Operation *opX = &(*it); 116 opX->walk([&](Operation *op) { 117 if (isa<AffineLoadOp>(op) || isa<AffineStoreOp>(op)) { 118 if (isDependentLoadOrStoreOp(op, values)) { 119 lastDepOp = opX; 120 return WalkResult::interrupt(); 121 } 122 return WalkResult::advance(); 123 } 124 for (auto *value : op->getResults()) { 125 for (auto *user : value->getUsers()) { 126 SmallVector<AffineForOp, 4> loops; 127 // Check if any loop in loop nest surrounding 'user' is 'opB'. 128 getLoopIVs(*user, &loops); 129 if (llvm::is_contained(loops, cast<AffineForOp>(opB))) { 130 lastDepOp = opX; 131 return WalkResult::interrupt(); 132 } 133 } 134 } 135 return WalkResult::advance(); 136 }); 137 if (lastDepOp) 138 break; 139 } 140 return lastDepOp; 141 } 142 143 // Computes and returns an insertion point operation, before which the 144 // the fused <srcForOp, dstForOp> loop nest can be inserted while preserving 145 // dependences. Returns nullptr if no such insertion point is found. 146 static Operation *getFusedLoopNestInsertionPoint(AffineForOp srcForOp, 147 AffineForOp dstForOp) { 148 bool isSrcForOpBeforeDstForOp = 149 srcForOp.getOperation()->isBeforeInBlock(dstForOp.getOperation()); 150 auto forOpA = isSrcForOpBeforeDstForOp ? srcForOp : dstForOp; 151 auto forOpB = isSrcForOpBeforeDstForOp ? dstForOp : srcForOp; 152 153 auto *firstDepOpA = 154 getFirstDependentOpInRange(forOpA.getOperation(), forOpB.getOperation()); 155 auto *lastDepOpB = 156 getLastDependentOpInRange(forOpA.getOperation(), forOpB.getOperation()); 157 // Block: 158 // ... 159 // |-- opA 160 // | ... 161 // | lastDepOpB --| 162 // | ... | 163 // |-> firstDepOpA | 164 // ... | 165 // opB <--------- 166 // 167 // Valid insertion point range: (lastDepOpB, firstDepOpA) 168 // 169 if (firstDepOpA != nullptr) { 170 if (lastDepOpB != nullptr) { 171 if (firstDepOpA->isBeforeInBlock(lastDepOpB) || firstDepOpA == lastDepOpB) 172 // No valid insertion point exists which preserves dependences. 173 return nullptr; 174 } 175 // Return insertion point in valid range closest to 'opB'. 176 // TODO(andydavis) Consider other insertion points in valid range. 177 return firstDepOpA; 178 } 179 // No dependences from 'opA' to operation in range ('opA', 'opB'), return 180 // 'opB' insertion point. 181 return forOpB.getOperation(); 182 } 183 184 // Gathers all load and store ops in loop nest rooted at 'forOp' into 185 // 'loadAndStoreOps'. 186 static bool 187 gatherLoadsAndStores(AffineForOp forOp, 188 SmallVectorImpl<Operation *> &loadAndStoreOps) { 189 bool hasIfOp = false; 190 forOp.getOperation()->walk([&](Operation *op) { 191 if (isa<AffineLoadOp>(op) || isa<AffineStoreOp>(op)) 192 loadAndStoreOps.push_back(op); 193 else if (isa<AffineIfOp>(op)) 194 hasIfOp = true; 195 }); 196 return !hasIfOp; 197 } 198 199 // TODO(andydavis) Prevent fusion of loop nests with side-effecting operations. 200 FusionResult mlir::canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp, 201 unsigned dstLoopDepth, 202 ComputationSliceState *srcSlice) { 203 // Return 'failure' if 'dstLoopDepth == 0'. 204 if (dstLoopDepth == 0) { 205 LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests at depth 0\n."); 206 return FusionResult::FailPrecondition; 207 } 208 // Return 'failure' if 'srcForOp' and 'dstForOp' are not in the same block. 209 auto *block = srcForOp.getOperation()->getBlock(); 210 if (block != dstForOp.getOperation()->getBlock()) { 211 LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests in different blocks\n."); 212 return FusionResult::FailPrecondition; 213 } 214 215 // Return 'failure' if no valid insertion point for fused loop nest in 'block' 216 // exists which would preserve dependences. 217 if (!getFusedLoopNestInsertionPoint(srcForOp, dstForOp)) { 218 LLVM_DEBUG(llvm::dbgs() << "Fusion would violate dependences in block\n."); 219 return FusionResult::FailBlockDependence; 220 } 221 222 // Check if 'srcForOp' precedeces 'dstForOp' in 'block'. 223 bool isSrcForOpBeforeDstForOp = 224 srcForOp.getOperation()->isBeforeInBlock(dstForOp.getOperation()); 225 // 'forOpA' executes before 'forOpB' in 'block'. 226 auto forOpA = isSrcForOpBeforeDstForOp ? srcForOp : dstForOp; 227 auto forOpB = isSrcForOpBeforeDstForOp ? dstForOp : srcForOp; 228 229 // Gather all load and store from 'forOpA' which precedes 'forOpB' in 'block'. 230 SmallVector<Operation *, 4> opsA; 231 if (!gatherLoadsAndStores(forOpA, opsA)) { 232 LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported.\n."); 233 return FusionResult::FailPrecondition; 234 } 235 236 // Gather all load and store from 'forOpB' which succeeds 'forOpA' in 'block'. 237 SmallVector<Operation *, 4> opsB; 238 if (!gatherLoadsAndStores(forOpB, opsB)) { 239 LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported.\n."); 240 return FusionResult::FailPrecondition; 241 } 242 243 // Calculate the number of common loops surrounding 'srcForOp' and 'dstForOp'. 244 unsigned numCommonLoops = mlir::getNumCommonSurroundingLoops( 245 *srcForOp.getOperation(), *dstForOp.getOperation()); 246 247 // Compute union of computation slices computed between all pairs of ops 248 // from 'forOpA' and 'forOpB'. 249 if (failed(mlir::computeSliceUnion(opsA, opsB, dstLoopDepth, numCommonLoops, 250 isSrcForOpBeforeDstForOp, srcSlice))) { 251 LLVM_DEBUG(llvm::dbgs() << "computeSliceUnion failed\n"); 252 return FusionResult::FailPrecondition; 253 } 254 255 return FusionResult::Success; 256 } 257 258 /// Collect loop nest statistics (eg. loop trip count and operation count) 259 /// in 'stats' for loop nest rooted at 'forOp'. Returns true on success, 260 /// returns false otherwise. 261 bool mlir::getLoopNestStats(AffineForOp forOpRoot, LoopNestStats *stats) { 262 auto walkResult = forOpRoot.walk([&](AffineForOp forOp) { 263 auto *childForOp = forOp.getOperation(); 264 auto *parentForOp = forOp.getOperation()->getParentOp(); 265 if (!llvm::isa<FuncOp>(parentForOp)) { 266 if (!isa<AffineForOp>(parentForOp)) { 267 LLVM_DEBUG(llvm::dbgs() << "Expected parent AffineForOp"); 268 return WalkResult::interrupt(); 269 } 270 // Add mapping to 'forOp' from its parent AffineForOp. 271 stats->loopMap[parentForOp].push_back(forOp); 272 } 273 274 // Record the number of op operations in the body of 'forOp'. 275 unsigned count = 0; 276 stats->opCountMap[childForOp] = 0; 277 for (auto &op : *forOp.getBody()) { 278 if (!isa<AffineForOp>(op) && !isa<AffineIfOp>(op)) 279 ++count; 280 } 281 stats->opCountMap[childForOp] = count; 282 283 // Record trip count for 'forOp'. Set flag if trip count is not 284 // constant. 285 Optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp); 286 if (!maybeConstTripCount.hasValue()) { 287 // Currently only constant trip count loop nests are supported. 288 LLVM_DEBUG(llvm::dbgs() << "Non-constant trip count unsupported"); 289 return WalkResult::interrupt(); 290 } 291 292 stats->tripCountMap[childForOp] = maybeConstTripCount.getValue(); 293 return WalkResult::advance(); 294 }); 295 return !walkResult.wasInterrupted(); 296 } 297 298 // Computes the total cost of the loop nest rooted at 'forOp'. 299 // Currently, the total cost is computed by counting the total operation 300 // instance count (i.e. total number of operations in the loop bodyloop 301 // operation count * loop trip count) for the entire loop nest. 302 // If 'tripCountOverrideMap' is non-null, overrides the trip count for loops 303 // specified in the map when computing the total op instance count. 304 // NOTEs: 1) This is used to compute the cost of computation slices, which are 305 // sliced along the iteration dimension, and thus reduce the trip count. 306 // If 'computeCostMap' is non-null, the total op count for forOps specified 307 // in the map is increased (not overridden) by adding the op count from the 308 // map to the existing op count for the for loop. This is done before 309 // multiplying by the loop's trip count, and is used to model the cost of 310 // inserting a sliced loop nest of known cost into the loop's body. 311 // 2) This is also used to compute the cost of fusing a slice of some loop nest 312 // within another loop. 313 static int64_t getComputeCostHelper( 314 Operation *forOp, LoopNestStats &stats, 315 llvm::SmallDenseMap<Operation *, uint64_t, 8> *tripCountOverrideMap, 316 DenseMap<Operation *, int64_t> *computeCostMap) { 317 // 'opCount' is the total number operations in one iteration of 'forOp' body, 318 // minus terminator op which is a no-op. 319 int64_t opCount = stats.opCountMap[forOp] - 1; 320 if (stats.loopMap.count(forOp) > 0) { 321 for (auto childForOp : stats.loopMap[forOp]) { 322 opCount += getComputeCostHelper(childForOp.getOperation(), stats, 323 tripCountOverrideMap, computeCostMap); 324 } 325 } 326 // Add in additional op instances from slice (if specified in map). 327 if (computeCostMap != nullptr) { 328 auto it = computeCostMap->find(forOp); 329 if (it != computeCostMap->end()) { 330 opCount += it->second; 331 } 332 } 333 // Override trip count (if specified in map). 334 int64_t tripCount = stats.tripCountMap[forOp]; 335 if (tripCountOverrideMap != nullptr) { 336 auto it = tripCountOverrideMap->find(forOp); 337 if (it != tripCountOverrideMap->end()) { 338 tripCount = it->second; 339 } 340 } 341 // Returns the total number of dynamic instances of operations in loop body. 342 return tripCount * opCount; 343 } 344 345 // TODO(andydavis,b/126426796): extend this to handle multiple result maps. 346 static Optional<uint64_t> getConstDifference(AffineMap lbMap, AffineMap ubMap) { 347 assert(lbMap.getNumResults() == 1 && "expected single result bound map"); 348 assert(ubMap.getNumResults() == 1 && "expected single result bound map"); 349 assert(lbMap.getNumDims() == ubMap.getNumDims()); 350 assert(lbMap.getNumSymbols() == ubMap.getNumSymbols()); 351 AffineExpr lbExpr(lbMap.getResult(0)); 352 AffineExpr ubExpr(ubMap.getResult(0)); 353 auto loopSpanExpr = simplifyAffineExpr(ubExpr - lbExpr, lbMap.getNumDims(), 354 lbMap.getNumSymbols()); 355 auto cExpr = loopSpanExpr.dyn_cast<AffineConstantExpr>(); 356 if (!cExpr) 357 return None; 358 return cExpr.getValue(); 359 } 360 361 // Return the number of iterations in the given slice. 362 static uint64_t getSliceIterationCount( 363 const llvm::SmallDenseMap<Operation *, uint64_t, 8> &sliceTripCountMap) { 364 uint64_t iterCount = 1; 365 for (const auto &count : sliceTripCountMap) { 366 iterCount *= count.second; 367 } 368 return iterCount; 369 } 370 371 // Builds a map 'tripCountMap' from AffineForOp to constant trip count for loop 372 // nest surrounding represented by slice loop bounds in 'slice'. 373 // Returns true on success, false otherwise (if a non-constant trip count 374 // was encountered). 375 // TODO(andydavis) Make this work with non-unit step loops. 376 static bool buildSliceTripCountMap( 377 ComputationSliceState *slice, 378 llvm::SmallDenseMap<Operation *, uint64_t, 8> *tripCountMap) { 379 unsigned numSrcLoopIVs = slice->ivs.size(); 380 // Populate map from AffineForOp -> trip count 381 for (unsigned i = 0; i < numSrcLoopIVs; ++i) { 382 AffineForOp forOp = getForInductionVarOwner(slice->ivs[i]); 383 auto *op = forOp.getOperation(); 384 AffineMap lbMap = slice->lbs[i]; 385 AffineMap ubMap = slice->ubs[i]; 386 if (lbMap == AffineMap() || ubMap == AffineMap()) { 387 // The iteration of src loop IV 'i' was not sliced. Use full loop bounds. 388 if (forOp.hasConstantLowerBound() && forOp.hasConstantUpperBound()) { 389 (*tripCountMap)[op] = 390 forOp.getConstantUpperBound() - forOp.getConstantLowerBound(); 391 continue; 392 } 393 Optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp); 394 if (maybeConstTripCount.hasValue()) { 395 (*tripCountMap)[op] = maybeConstTripCount.getValue(); 396 continue; 397 } 398 return false; 399 } 400 Optional<uint64_t> tripCount = getConstDifference(lbMap, ubMap); 401 // Slice bounds are created with a constant ub - lb difference. 402 if (!tripCount.hasValue()) 403 return false; 404 (*tripCountMap)[op] = tripCount.getValue(); 405 } 406 return true; 407 } 408 409 /// Computes the total cost of the loop nest rooted at 'forOp' using 'stats'. 410 /// Currently, the total cost is computed by counting the total operation 411 /// instance count (i.e. total number of operations in the loop body * loop 412 /// trip count) for the entire loop nest. 413 int64_t mlir::getComputeCost(AffineForOp forOp, LoopNestStats &stats) { 414 return getComputeCostHelper(forOp.getOperation(), stats, 415 /*tripCountOverrideMap=*/nullptr, 416 /*computeCostMap=*/nullptr); 417 } 418 419 /// Computes and returns in 'computeCost', the total compute cost of fusing the 420 /// 'slice' of the loop nest rooted at 'srcForOp' into 'dstForOp'. Currently, 421 /// the total cost is computed by counting the total operation instance count 422 /// (i.e. total number of operations in the loop body * loop trip count) for 423 /// the entire loop nest. 424 bool mlir::getFusionComputeCost(AffineForOp srcForOp, LoopNestStats &srcStats, 425 AffineForOp dstForOp, LoopNestStats &dstStats, 426 ComputationSliceState *slice, 427 int64_t *computeCost) { 428 llvm::SmallDenseMap<Operation *, uint64_t, 8> sliceTripCountMap; 429 DenseMap<Operation *, int64_t> computeCostMap; 430 431 // Build trip count map for computation slice. 432 if (!buildSliceTripCountMap(slice, &sliceTripCountMap)) 433 return false; 434 // Checks whether a store to load forwarding will happen. 435 int64_t sliceIterationCount = getSliceIterationCount(sliceTripCountMap); 436 assert(sliceIterationCount > 0); 437 bool storeLoadFwdGuaranteed = (sliceIterationCount == 1); 438 auto *insertPointParent = slice->insertPoint->getParentOp(); 439 440 // The store and loads to this memref will disappear. 441 // TODO(andydavis) Add load coalescing to memref data flow opt pass. 442 if (storeLoadFwdGuaranteed) { 443 // Subtract from operation count the loads/store we expect load/store 444 // forwarding to remove. 445 unsigned storeCount = 0; 446 llvm::SmallDenseSet<Value *, 4> storeMemrefs; 447 srcForOp.getOperation()->walk([&](Operation *op) { 448 if (auto storeOp = dyn_cast<AffineStoreOp>(op)) { 449 storeMemrefs.insert(storeOp.getMemRef()); 450 ++storeCount; 451 } 452 }); 453 // Subtract out any store ops in single-iteration src slice loop nest. 454 if (storeCount > 0) 455 computeCostMap[insertPointParent] = -storeCount; 456 // Subtract out any load users of 'storeMemrefs' nested below 457 // 'insertPointParent'. 458 for (auto *value : storeMemrefs) { 459 for (auto *user : value->getUsers()) { 460 if (auto loadOp = dyn_cast<AffineLoadOp>(user)) { 461 SmallVector<AffineForOp, 4> loops; 462 // Check if any loop in loop nest surrounding 'user' is 463 // 'insertPointParent'. 464 getLoopIVs(*user, &loops); 465 if (llvm::is_contained(loops, cast<AffineForOp>(insertPointParent))) { 466 if (auto forOp = 467 dyn_cast_or_null<AffineForOp>(user->getParentOp())) { 468 if (computeCostMap.count(forOp) == 0) 469 computeCostMap[forOp] = 0; 470 computeCostMap[forOp] -= 1; 471 } 472 } 473 } 474 } 475 } 476 } 477 478 // Compute op instance count for the src loop nest with iteration slicing. 479 int64_t sliceComputeCost = getComputeCostHelper( 480 srcForOp.getOperation(), srcStats, &sliceTripCountMap, &computeCostMap); 481 482 // Compute cost of fusion for this depth. 483 computeCostMap[insertPointParent] = sliceComputeCost; 484 485 *computeCost = 486 getComputeCostHelper(dstForOp.getOperation(), dstStats, 487 /*tripCountOverrideMap=*/nullptr, &computeCostMap); 488 return true; 489 }