github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Transforms/LoopFusion.cpp (about) 1 //===- LoopFusion.cpp - Code to perform loop fusion -----------------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 // 18 // This file implements loop fusion. 19 // 20 //===----------------------------------------------------------------------===// 21 22 #include "mlir/Analysis/AffineAnalysis.h" 23 #include "mlir/Analysis/AffineStructures.h" 24 #include "mlir/Analysis/LoopAnalysis.h" 25 #include "mlir/Analysis/Utils.h" 26 #include "mlir/Dialect/AffineOps/AffineOps.h" 27 #include "mlir/Dialect/StandardOps/Ops.h" 28 #include "mlir/IR/AffineExpr.h" 29 #include "mlir/IR/AffineMap.h" 30 #include "mlir/IR/Builders.h" 31 #include "mlir/Pass/Pass.h" 32 #include "mlir/Transforms/LoopFusionUtils.h" 33 #include "mlir/Transforms/LoopUtils.h" 34 #include "mlir/Transforms/Passes.h" 35 #include "mlir/Transforms/Utils.h" 36 #include "llvm/ADT/DenseMap.h" 37 #include "llvm/ADT/DenseSet.h" 38 #include "llvm/ADT/SetVector.h" 39 #include "llvm/Support/CommandLine.h" 40 #include "llvm/Support/Debug.h" 41 #include "llvm/Support/raw_ostream.h" 42 #include <iomanip> 43 #include <sstream> 44 #define DEBUG_TYPE "affine-loop-fusion" 45 46 using llvm::SetVector; 47 48 using namespace mlir; 49 50 static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options"); 51 52 /// Disables fusion profitability check and fuses if valid. Ignore any 53 /// additional (redundant) computation tolerance threshold 54 /// that would have prevented fusion. 55 static llvm::cl::opt<bool> 56 clMaximalLoopFusion("fusion-maximal", 57 llvm::cl::desc("Enables maximal loop fusion"), 58 llvm::cl::cat(clOptionsCategory)); 59 60 /// A threshold in percent of additional computation allowed when fusing. 61 static llvm::cl::opt<double> clFusionAddlComputeTolerance( 62 "fusion-compute-tolerance", 63 llvm::cl::desc("Fractional increase in additional " 64 "computation tolerated while fusing"), 65 llvm::cl::cat(clOptionsCategory)); 66 67 static llvm::cl::opt<unsigned> clFusionFastMemorySpace( 68 "fusion-fast-mem-space", 69 llvm::cl::desc("Faster memory space number to promote fusion buffers to"), 70 llvm::cl::cat(clOptionsCategory)); 71 72 // A local buffer of size less than or equal to this size is automatically 73 // promoted to fast memory after producer-consumer fusion. 74 static llvm::cl::opt<unsigned long long> clFusionLocalBufThreshold( 75 "fusion-local-buf-threshold", 76 llvm::cl::desc("Threshold size (KiB) for promoting local buffers to fast " 77 "memory space"), 78 llvm::cl::cat(clOptionsCategory)); 79 80 namespace { 81 82 /// Loop fusion pass. This pass currently supports a greedy fusion policy, 83 /// which fuses loop nests with single-writer/single-reader memref dependences 84 /// with the goal of improving locality. 
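// For intuition, a minimal sketch (simplified, abbreviated IR; not taken from
// the test suite) of the producer-consumer case this pass targets:
//
//     %m = alloc() : memref<100xf32>
//     affine.for %i = 0 to 100 {        // producer: writes %m
//       %v = ...                        // some computed value
//       affine.store %v, %m[%i] : memref<100xf32>
//     }
//     affine.for %j = 0 to 100 {        // consumer: reads %m
//       %u = affine.load %m[%j] : memref<100xf32>
//     }
//
// Fusion inserts a slice of the producer's computation into the consumer loop
// so that each iteration computes the value it consumes; the intermediary %m
// can then often be replaced by a small private buffer (see
// createPrivateMemRef below).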
85 86 // TODO(andydavis) Support fusion of source loop nests which write to multiple 87 // memrefs, where each memref can have multiple users (if profitable). 88 // TODO(andydavis) Extend this pass to check for fusion preventing dependences, 89 // and add support for more general loop fusion algorithms. 90 91 struct LoopFusion : public FunctionPass<LoopFusion> { 92 LoopFusion(unsigned fastMemorySpace = 0, uint64_t localBufSizeThreshold = 0, 93 bool maximalFusion = false) 94 : localBufSizeThreshold(localBufSizeThreshold), 95 fastMemorySpace(fastMemorySpace), maximalFusion(maximalFusion) {} 96 97 void runOnFunction() override; 98 99 // Any local buffers smaller than this size (in bytes) will be created in 100 // `fastMemorySpace` if provided. 101 uint64_t localBufSizeThreshold; 102 Optional<unsigned> fastMemorySpace = None; 103 // If true, ignore any additional (redundant) computation tolerance threshold 104 // that would have prevented fusion. 105 bool maximalFusion; 106 107 // The amount of additional computation that is tolerated while fusing 108 // pair-wise as a fraction of the total computation. 109 constexpr static double kComputeToleranceThreshold = 0.30f; 110 }; 111 112 } // end anonymous namespace 113 114 std::unique_ptr<FunctionPassBase> 115 mlir::createLoopFusionPass(unsigned fastMemorySpace, 116 uint64_t localBufSizeThreshold, bool maximalFusion) { 117 return std::make_unique<LoopFusion>(fastMemorySpace, localBufSizeThreshold, 118 maximalFusion); 119 } 120 121 namespace { 122 123 // LoopNestStateCollector walks loop nests and collects load and store 124 // operations, and whether or not an IfInst was encountered in the loop nest. 125 struct LoopNestStateCollector { 126 SmallVector<AffineForOp, 4> forOps; 127 SmallVector<Operation *, 4> loadOpInsts; 128 SmallVector<Operation *, 4> storeOpInsts; 129 bool hasNonForRegion = false; 130 131 void collect(Operation *opToWalk) { 132 opToWalk->walk([&](Operation *op) { 133 if (isa<AffineForOp>(op)) 134 forOps.push_back(cast<AffineForOp>(op)); 135 else if (op->getNumRegions() != 0) 136 hasNonForRegion = true; 137 else if (isa<AffineLoadOp>(op)) 138 loadOpInsts.push_back(op); 139 else if (isa<AffineStoreOp>(op)) 140 storeOpInsts.push_back(op); 141 }); 142 } 143 }; 144 145 // TODO(b/117228571) Replace when this is modeled through side-effects/op traits 146 static bool isMemRefDereferencingOp(Operation &op) { 147 if (isa<AffineLoadOp>(op) || isa<AffineStoreOp>(op) || 148 isa<AffineDmaStartOp>(op) || isa<AffineDmaWaitOp>(op)) 149 return true; 150 return false; 151 } 152 153 // MemRefDependenceGraph is a graph data structure where graph nodes are 154 // top-level operations in a FuncOp which contain load/store ops, and edges 155 // are memref dependences between the nodes. 156 // TODO(andydavis) Add a more flexible dependece graph representation. 157 // TODO(andydavis) Add a depth parameter to dependence graph construction. 158 struct MemRefDependenceGraph { 159 public: 160 // Node represents a node in the graph. A Node is either an entire loop nest 161 // rooted at the top level which contains loads/stores, or a top level 162 // load/store. 163 struct Node { 164 // The unique identifier of this node in the graph. 165 unsigned id; 166 // The top-level statement which is (or contains) a load/store. 167 Operation *op; 168 // List of load operations. 169 SmallVector<Operation *, 4> loads; 170 // List of store op insts. 
171 SmallVector<Operation *, 4> stores; 172 Node(unsigned id, Operation *op) : id(id), op(op) {} 173 174 // Returns the load op count for 'memref'. 175 unsigned getLoadOpCount(Value *memref) { 176 unsigned loadOpCount = 0; 177 for (auto *loadOpInst : loads) { 178 if (memref == cast<AffineLoadOp>(loadOpInst).getMemRef()) 179 ++loadOpCount; 180 } 181 return loadOpCount; 182 } 183 184 // Returns the store op count for 'memref'. 185 unsigned getStoreOpCount(Value *memref) { 186 unsigned storeOpCount = 0; 187 for (auto *storeOpInst : stores) { 188 if (memref == cast<AffineStoreOp>(storeOpInst).getMemRef()) 189 ++storeOpCount; 190 } 191 return storeOpCount; 192 } 193 194 // Returns all store ops in 'storeOps' which access 'memref'. 195 void getStoreOpsForMemref(Value *memref, 196 SmallVectorImpl<Operation *> *storeOps) { 197 for (auto *storeOpInst : stores) { 198 if (memref == cast<AffineStoreOp>(storeOpInst).getMemRef()) 199 storeOps->push_back(storeOpInst); 200 } 201 } 202 203 // Returns all load ops in 'loadOps' which access 'memref'. 204 void getLoadOpsForMemref(Value *memref, 205 SmallVectorImpl<Operation *> *loadOps) { 206 for (auto *loadOpInst : loads) { 207 if (memref == cast<AffineLoadOp>(loadOpInst).getMemRef()) 208 loadOps->push_back(loadOpInst); 209 } 210 } 211 212 // Returns all memrefs in 'loadAndStoreMemrefSet' for which this node 213 // has at least one load and store operation. 214 void getLoadAndStoreMemrefSet(DenseSet<Value *> *loadAndStoreMemrefSet) { 215 llvm::SmallDenseSet<Value *, 2> loadMemrefs; 216 for (auto *loadOpInst : loads) { 217 loadMemrefs.insert(cast<AffineLoadOp>(loadOpInst).getMemRef()); 218 } 219 for (auto *storeOpInst : stores) { 220 auto *memref = cast<AffineStoreOp>(storeOpInst).getMemRef(); 221 if (loadMemrefs.count(memref) > 0) 222 loadAndStoreMemrefSet->insert(memref); 223 } 224 } 225 }; 226 227 // Edge represents a data dependece between nodes in the graph. 228 struct Edge { 229 // The id of the node at the other end of the edge. 230 // If this edge is stored in Edge = Node.inEdges[i], then 231 // 'Node.inEdges[i].id' is the identifier of the source node of the edge. 232 // If this edge is stored in Edge = Node.outEdges[i], then 233 // 'Node.outEdges[i].id' is the identifier of the dest node of the edge. 234 unsigned id; 235 // The SSA value on which this edge represents a dependence. 236 // If the value is a memref, then the dependence is between graph nodes 237 // which contain accesses to the same memref 'value'. If the value is a 238 // non-memref value, then the dependence is between a graph node which 239 // defines an SSA value and another graph node which uses the SSA value 240 // (e.g. a constant operation defining a value which is used inside a loop 241 // nest). 242 Value *value; 243 }; 244 245 // Map from node id to Node. 246 DenseMap<unsigned, Node> nodes; 247 // Map from node id to list of input edges. 248 DenseMap<unsigned, SmallVector<Edge, 2>> inEdges; 249 // Map from node id to list of output edges. 250 DenseMap<unsigned, SmallVector<Edge, 2>> outEdges; 251 // Map from memref to a count on the dependence edges associated with that 252 // memref. 253 DenseMap<Value *, unsigned> memrefEdgeCount; 254 // The next unique identifier to use for newly created graph nodes. 255 unsigned nextNodeId = 0; 256 257 MemRefDependenceGraph() {} 258 259 // Initializes the dependence graph based on operations in 'f'. 260 // Returns true on success, false otherwise. 261 bool init(FuncOp f); 262 263 // Returns the graph node for 'id'. 
264 Node *getNode(unsigned id) { 265 auto it = nodes.find(id); 266 assert(it != nodes.end()); 267 return &it->second; 268 } 269 270 // Returns the graph node for 'forOp'. 271 Node *getForOpNode(AffineForOp forOp) { 272 for (auto &idAndNode : nodes) 273 if (idAndNode.second.op == forOp.getOperation()) 274 return &idAndNode.second; 275 return nullptr; 276 } 277 278 // Adds a node with 'op' to the graph and returns its unique identifier. 279 unsigned addNode(Operation *op) { 280 Node node(nextNodeId++, op); 281 nodes.insert({node.id, node}); 282 return node.id; 283 } 284 285 // Remove node 'id' (and its associated edges) from graph. 286 void removeNode(unsigned id) { 287 // Remove each edge in 'inEdges[id]'. 288 if (inEdges.count(id) > 0) { 289 SmallVector<Edge, 2> oldInEdges = inEdges[id]; 290 for (auto &inEdge : oldInEdges) { 291 removeEdge(inEdge.id, id, inEdge.value); 292 } 293 } 294 // Remove each edge in 'outEdges[id]'. 295 if (outEdges.count(id) > 0) { 296 SmallVector<Edge, 2> oldOutEdges = outEdges[id]; 297 for (auto &outEdge : oldOutEdges) { 298 removeEdge(id, outEdge.id, outEdge.value); 299 } 300 } 301 // Erase remaining node state. 302 inEdges.erase(id); 303 outEdges.erase(id); 304 nodes.erase(id); 305 } 306 307 // Returns true if node 'id' writes to any memref which escapes (or is an 308 // argument to) the function/block. Returns false otherwise. 309 bool writesToLiveInOrEscapingMemrefs(unsigned id) { 310 Node *node = getNode(id); 311 for (auto *storeOpInst : node->stores) { 312 auto *memref = cast<AffineStoreOp>(storeOpInst).getMemRef(); 313 auto *op = memref->getDefiningOp(); 314 // Return true if 'memref' is a block argument. 315 if (!op) 316 return true; 317 // Return true if any use of 'memref' escapes the function. 318 for (auto *user : memref->getUsers()) 319 if (!isMemRefDereferencingOp(*user)) 320 return true; 321 } 322 return false; 323 } 324 325 // Returns true if node 'id' can be removed from the graph. Returns false 326 // otherwise. A node can be removed from the graph iff the following 327 // conditions are met: 328 // *) The node does not write to any memref which escapes (or is a 329 // function/block argument). 330 // *) The node has no successors in the dependence graph. 331 bool canRemoveNode(unsigned id) { 332 if (writesToLiveInOrEscapingMemrefs(id)) 333 return false; 334 Node *node = getNode(id); 335 for (auto *storeOpInst : node->stores) { 336 // Return false if there exist out edges from 'id' on 'memref'. 337 if (getOutEdgeCount(id, cast<AffineStoreOp>(storeOpInst).getMemRef()) > 0) 338 return false; 339 } 340 return true; 341 } 342 343 // Returns true iff there is an edge from node 'srcId' to node 'dstId' which 344 // is for 'value' if non-null, or for any value otherwise. Returns false 345 // otherwise. 346 bool hasEdge(unsigned srcId, unsigned dstId, Value *value = nullptr) { 347 if (outEdges.count(srcId) == 0 || inEdges.count(dstId) == 0) { 348 return false; 349 } 350 bool hasOutEdge = llvm::any_of(outEdges[srcId], [=](Edge &edge) { 351 return edge.id == dstId && (!value || edge.value == value); 352 }); 353 bool hasInEdge = llvm::any_of(inEdges[dstId], [=](Edge &edge) { 354 return edge.id == srcId && (!value || edge.value == value); 355 }); 356 return hasOutEdge && hasInEdge; 357 } 358 359 // Adds an edge from node 'srcId' to node 'dstId' for 'value'. 
360 void addEdge(unsigned srcId, unsigned dstId, Value *value) { 361 if (!hasEdge(srcId, dstId, value)) { 362 outEdges[srcId].push_back({dstId, value}); 363 inEdges[dstId].push_back({srcId, value}); 364 if (value->getType().isa<MemRefType>()) 365 memrefEdgeCount[value]++; 366 } 367 } 368 369 // Removes an edge from node 'srcId' to node 'dstId' for 'value'. 370 void removeEdge(unsigned srcId, unsigned dstId, Value *value) { 371 assert(inEdges.count(dstId) > 0); 372 assert(outEdges.count(srcId) > 0); 373 if (value->getType().isa<MemRefType>()) { 374 assert(memrefEdgeCount.count(value) > 0); 375 memrefEdgeCount[value]--; 376 } 377 // Remove 'srcId' from 'inEdges[dstId]'. 378 for (auto it = inEdges[dstId].begin(); it != inEdges[dstId].end(); ++it) { 379 if ((*it).id == srcId && (*it).value == value) { 380 inEdges[dstId].erase(it); 381 break; 382 } 383 } 384 // Remove 'dstId' from 'outEdges[srcId]'. 385 for (auto it = outEdges[srcId].begin(); it != outEdges[srcId].end(); ++it) { 386 if ((*it).id == dstId && (*it).value == value) { 387 outEdges[srcId].erase(it); 388 break; 389 } 390 } 391 } 392 393 // Returns true if there is a path in the dependence graph from node 'srcId' 394 // to node 'dstId'. Returns false otherwise. 395 bool hasDependencePath(unsigned srcId, unsigned dstId) { 396 // Worklist state is: <node-id, next-output-edge-index-to-visit> 397 SmallVector<std::pair<unsigned, unsigned>, 4> worklist; 398 worklist.push_back({srcId, 0}); 399 // Run DFS traversal to see if 'dstId' is reachable from 'srcId'. 400 while (!worklist.empty()) { 401 auto &idAndIndex = worklist.back(); 402 // Return true if we have reached 'dstId'. 403 if (idAndIndex.first == dstId) 404 return true; 405 // Pop and continue if node has no out edges, or if all out edges have 406 // already been visited. 407 if (outEdges.count(idAndIndex.first) == 0 || 408 idAndIndex.second == outEdges[idAndIndex.first].size()) { 409 worklist.pop_back(); 410 continue; 411 } 412 // Get graph edge to traverse. 413 Edge edge = outEdges[idAndIndex.first][idAndIndex.second]; 414 // Increment next output edge index for 'idAndIndex'. 415 ++idAndIndex.second; 416 // Add node at 'edge.id' to worklist. 417 worklist.push_back({edge.id, 0}); 418 } 419 return false; 420 } 421 422 // Returns the input edge count for node 'id' and 'memref' from src nodes 423 // which access 'memref' with a store operation. 424 unsigned getIncomingMemRefAccesses(unsigned id, Value *memref) { 425 unsigned inEdgeCount = 0; 426 if (inEdges.count(id) > 0) 427 for (auto &inEdge : inEdges[id]) 428 if (inEdge.value == memref) { 429 Node *srcNode = getNode(inEdge.id); 430 // Only count in edges from 'srcNode' if 'srcNode' accesses 'memref' 431 if (srcNode->getStoreOpCount(memref) > 0) 432 ++inEdgeCount; 433 } 434 return inEdgeCount; 435 } 436 437 // Returns the output edge count for node 'id' and 'memref' (if non-null), 438 // otherwise returns the total output edge count from node 'id'. 439 unsigned getOutEdgeCount(unsigned id, Value *memref = nullptr) { 440 unsigned outEdgeCount = 0; 441 if (outEdges.count(id) > 0) 442 for (auto &outEdge : outEdges[id]) 443 if (!memref || outEdge.value == memref) 444 ++outEdgeCount; 445 return outEdgeCount; 446 } 447 448 // Computes and returns an insertion point operation, before which the 449 // the fused <srcId, dstId> loop nest can be inserted while preserving 450 // dependences. Returns nullptr if no such insertion point is found. 
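// For example, with top-level block order [src, P, Q, dst], where P depends on
// src (an out edge from 'srcId') and dst depends on Q (an in edge to 'dstId'):
// if P appears before Q then firstSrcDepPos <= lastDstDepPos and no insertion
// point can preserve both dependences (nullptr is returned); if Q appears
// before P, the fused nest is inserted just before P (i.e. after Q), which
// keeps Q before the fused dst computation and the fused src computation
// before P.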
451 Operation *getFusedLoopNestInsertionPoint(unsigned srcId, unsigned dstId) { 452 if (outEdges.count(srcId) == 0) 453 return getNode(dstId)->op; 454 455 // Build set of insts in range (srcId, dstId) which depend on 'srcId'. 456 SmallPtrSet<Operation *, 2> srcDepInsts; 457 for (auto &outEdge : outEdges[srcId]) 458 if (outEdge.id != dstId) 459 srcDepInsts.insert(getNode(outEdge.id)->op); 460 461 // Build set of insts in range (srcId, dstId) on which 'dstId' depends. 462 SmallPtrSet<Operation *, 2> dstDepInsts; 463 for (auto &inEdge : inEdges[dstId]) 464 if (inEdge.id != srcId) 465 dstDepInsts.insert(getNode(inEdge.id)->op); 466 467 Operation *srcNodeInst = getNode(srcId)->op; 468 Operation *dstNodeInst = getNode(dstId)->op; 469 470 // Computing insertion point: 471 // *) Walk all operation positions in Block operation list in the 472 // range (src, dst). For each operation 'op' visited in this search: 473 // *) Store in 'firstSrcDepPos' the first position where 'op' has a 474 // dependence edge from 'srcNode'. 475 // *) Store in 'lastDstDepPost' the last position where 'op' has a 476 // dependence edge to 'dstNode'. 477 // *) Compare 'firstSrcDepPos' and 'lastDstDepPost' to determine the 478 // operation insertion point (or return null pointer if no such 479 // insertion point exists: 'firstSrcDepPos' <= 'lastDstDepPos'). 480 SmallVector<Operation *, 2> depInsts; 481 Optional<unsigned> firstSrcDepPos; 482 Optional<unsigned> lastDstDepPos; 483 unsigned pos = 0; 484 for (Block::iterator it = std::next(Block::iterator(srcNodeInst)); 485 it != Block::iterator(dstNodeInst); ++it) { 486 Operation *op = &(*it); 487 if (srcDepInsts.count(op) > 0 && firstSrcDepPos == None) 488 firstSrcDepPos = pos; 489 if (dstDepInsts.count(op) > 0) 490 lastDstDepPos = pos; 491 depInsts.push_back(op); 492 ++pos; 493 } 494 495 if (firstSrcDepPos.hasValue()) { 496 if (lastDstDepPos.hasValue()) { 497 if (firstSrcDepPos.getValue() <= lastDstDepPos.getValue()) { 498 // No valid insertion point exists which preserves dependences. 499 return nullptr; 500 } 501 } 502 // Return the insertion point at 'firstSrcDepPos'. 503 return depInsts[firstSrcDepPos.getValue()]; 504 } 505 // No dependence targets in range (or only dst deps in range), return 506 // 'dstNodInst' insertion point. 507 return dstNodeInst; 508 } 509 510 // Updates edge mappings from node 'srcId' to node 'dstId' after 'oldMemRef' 511 // has been replaced in node at 'dstId' by a private memref. 512 void updateEdges(unsigned srcId, unsigned dstId, Value *oldMemRef) { 513 // For each edge in 'inEdges[srcId]': add new edge remaping to 'dstId'. 514 if (inEdges.count(srcId) > 0) { 515 SmallVector<Edge, 2> oldInEdges = inEdges[srcId]; 516 for (auto &inEdge : oldInEdges) { 517 // Add edge from 'inEdge.id' to 'dstId' if not for 'oldMemRef'. 518 if (inEdge.value != oldMemRef) 519 addEdge(inEdge.id, dstId, inEdge.value); 520 } 521 } 522 // For each edge in 'outEdges[srcId]': remove edge from 'srcId' to 'dstId'. 523 if (outEdges.count(srcId) > 0) { 524 SmallVector<Edge, 2> oldOutEdges = outEdges[srcId]; 525 for (auto &outEdge : oldOutEdges) { 526 // Remove any out edges from 'srcId' to 'dstId' across memrefs. 527 if (outEdge.id == dstId) 528 removeEdge(srcId, outEdge.id, outEdge.value); 529 } 530 } 531 // Remove any edges in 'inEdges[dstId]' on 'oldMemRef' (which is being 532 // replaced by a private memref). These edges could come from nodes 533 // other than 'srcId' which were removed in the previous step. 
534 if (inEdges.count(dstId) > 0) { 535 SmallVector<Edge, 2> oldInEdges = inEdges[dstId]; 536 for (auto &inEdge : oldInEdges) 537 if (inEdge.value == oldMemRef) 538 removeEdge(inEdge.id, dstId, inEdge.value); 539 } 540 } 541 542 // Update edge mappings for nodes 'sibId' and 'dstId' to reflect fusion 543 // of sibling node 'sidId' into node 'dstId'. 544 void updateEdges(unsigned sibId, unsigned dstId) { 545 // For each edge in 'inEdges[sibId]': 546 // *) Add new edge from source node 'inEdge.id' to 'dstNode'. 547 // *) Remove edge from source node 'inEdge.id' to 'sibNode'. 548 if (inEdges.count(sibId) > 0) { 549 SmallVector<Edge, 2> oldInEdges = inEdges[sibId]; 550 for (auto &inEdge : oldInEdges) { 551 addEdge(inEdge.id, dstId, inEdge.value); 552 removeEdge(inEdge.id, sibId, inEdge.value); 553 } 554 } 555 556 // For each edge in 'outEdges[sibId]' to node 'id' 557 // *) Add new edge from 'dstId' to 'outEdge.id'. 558 // *) Remove edge from 'sibId' to 'outEdge.id'. 559 if (outEdges.count(sibId) > 0) { 560 SmallVector<Edge, 2> oldOutEdges = outEdges[sibId]; 561 for (auto &outEdge : oldOutEdges) { 562 addEdge(dstId, outEdge.id, outEdge.value); 563 removeEdge(sibId, outEdge.id, outEdge.value); 564 } 565 } 566 } 567 568 // Adds ops in 'loads' and 'stores' to node at 'id'. 569 void addToNode(unsigned id, const SmallVectorImpl<Operation *> &loads, 570 const SmallVectorImpl<Operation *> &stores) { 571 Node *node = getNode(id); 572 for (auto *loadOpInst : loads) 573 node->loads.push_back(loadOpInst); 574 for (auto *storeOpInst : stores) 575 node->stores.push_back(storeOpInst); 576 } 577 578 void clearNodeLoadAndStores(unsigned id) { 579 Node *node = getNode(id); 580 node->loads.clear(); 581 node->stores.clear(); 582 } 583 584 // Calls 'callback' for each input edge incident to node 'id' which carries a 585 // memref dependence. 586 void forEachMemRefInputEdge(unsigned id, 587 const std::function<void(Edge)> &callback) { 588 if (inEdges.count(id) > 0) 589 forEachMemRefEdge(inEdges[id], callback); 590 } 591 592 // Calls 'callback' for each output edge from node 'id' which carries a 593 // memref dependence. 594 void forEachMemRefOutputEdge(unsigned id, 595 const std::function<void(Edge)> &callback) { 596 if (outEdges.count(id) > 0) 597 forEachMemRefEdge(outEdges[id], callback); 598 } 599 600 // Calls 'callback' for each edge in 'edges' which carries a memref 601 // dependence. 602 void forEachMemRefEdge(ArrayRef<Edge> edges, 603 const std::function<void(Edge)> &callback) { 604 for (auto &edge : edges) { 605 // Skip if 'edge' is not a memref dependence edge. 606 if (!edge.value->getType().isa<MemRefType>()) 607 continue; 608 assert(nodes.count(edge.id) > 0); 609 // Skip if 'edge.id' is not a loop nest. 610 if (!isa<AffineForOp>(getNode(edge.id)->op)) 611 continue; 612 // Visit current input edge 'edge'. 
613 callback(edge); 614 } 615 } 616 617 void print(raw_ostream &os) const { 618 os << "\nMemRefDependenceGraph\n"; 619 os << "\nNodes:\n"; 620 for (auto &idAndNode : nodes) { 621 os << "Node: " << idAndNode.first << "\n"; 622 auto it = inEdges.find(idAndNode.first); 623 if (it != inEdges.end()) { 624 for (const auto &e : it->second) 625 os << " InEdge: " << e.id << " " << e.value << "\n"; 626 } 627 it = outEdges.find(idAndNode.first); 628 if (it != outEdges.end()) { 629 for (const auto &e : it->second) 630 os << " OutEdge: " << e.id << " " << e.value << "\n"; 631 } 632 } 633 } 634 void dump() const { print(llvm::errs()); } 635 }; 636 637 // Intializes the data dependence graph by walking operations in 'f'. 638 // Assigns each node in the graph a node id based on program order in 'f'. 639 // TODO(andydavis) Add support for taking a Block arg to construct the 640 // dependence graph at a different depth. 641 bool MemRefDependenceGraph::init(FuncOp f) { 642 DenseMap<Value *, SetVector<unsigned>> memrefAccesses; 643 644 // TODO: support multi-block functions. 645 if (f.getBlocks().size() != 1) 646 return false; 647 648 DenseMap<Operation *, unsigned> forToNodeMap; 649 for (auto &op : f.front()) { 650 if (auto forOp = dyn_cast<AffineForOp>(op)) { 651 // Create graph node 'id' to represent top-level 'forOp' and record 652 // all loads and store accesses it contains. 653 LoopNestStateCollector collector; 654 collector.collect(&op); 655 // Return false if a non 'affine.for' region was found (not currently 656 // supported). 657 if (collector.hasNonForRegion) 658 return false; 659 Node node(nextNodeId++, &op); 660 for (auto *opInst : collector.loadOpInsts) { 661 node.loads.push_back(opInst); 662 auto *memref = cast<AffineLoadOp>(opInst).getMemRef(); 663 memrefAccesses[memref].insert(node.id); 664 } 665 for (auto *opInst : collector.storeOpInsts) { 666 node.stores.push_back(opInst); 667 auto *memref = cast<AffineStoreOp>(opInst).getMemRef(); 668 memrefAccesses[memref].insert(node.id); 669 } 670 forToNodeMap[&op] = node.id; 671 nodes.insert({node.id, node}); 672 } else if (auto loadOp = dyn_cast<AffineLoadOp>(op)) { 673 // Create graph node for top-level load op. 674 Node node(nextNodeId++, &op); 675 node.loads.push_back(&op); 676 auto *memref = cast<AffineLoadOp>(op).getMemRef(); 677 memrefAccesses[memref].insert(node.id); 678 nodes.insert({node.id, node}); 679 } else if (auto storeOp = dyn_cast<AffineStoreOp>(op)) { 680 // Create graph node for top-level store op. 681 Node node(nextNodeId++, &op); 682 node.stores.push_back(&op); 683 auto *memref = cast<AffineStoreOp>(op).getMemRef(); 684 memrefAccesses[memref].insert(node.id); 685 nodes.insert({node.id, node}); 686 } else if (op.getNumRegions() != 0) { 687 // Return false if another region is found (not currently supported). 688 return false; 689 } else if (op.getNumResults() > 0 && !op.use_empty()) { 690 // Create graph node for top-level producer of SSA values, which 691 // could be used by loop nest nodes. 692 Node node(nextNodeId++, &op); 693 nodes.insert({node.id, node}); 694 } 695 } 696 697 // Add dependence edges between nodes which produce SSA values and their 698 // users. 
699 for (auto &idAndNode : nodes) { 700 const Node &node = idAndNode.second; 701 if (!node.loads.empty() || !node.stores.empty()) 702 continue; 703 auto *opInst = node.op; 704 for (auto *value : opInst->getResults()) { 705 for (auto *user : value->getUsers()) { 706 SmallVector<AffineForOp, 4> loops; 707 getLoopIVs(*user, &loops); 708 if (loops.empty()) 709 continue; 710 assert(forToNodeMap.count(loops[0].getOperation()) > 0); 711 unsigned userLoopNestId = forToNodeMap[loops[0].getOperation()]; 712 addEdge(node.id, userLoopNestId, value); 713 } 714 } 715 } 716 717 // Walk memref access lists and add graph edges between dependent nodes. 718 for (auto &memrefAndList : memrefAccesses) { 719 unsigned n = memrefAndList.second.size(); 720 for (unsigned i = 0; i < n; ++i) { 721 unsigned srcId = memrefAndList.second[i]; 722 bool srcHasStore = 723 getNode(srcId)->getStoreOpCount(memrefAndList.first) > 0; 724 for (unsigned j = i + 1; j < n; ++j) { 725 unsigned dstId = memrefAndList.second[j]; 726 bool dstHasStore = 727 getNode(dstId)->getStoreOpCount(memrefAndList.first) > 0; 728 if (srcHasStore || dstHasStore) 729 addEdge(srcId, dstId, memrefAndList.first); 730 } 731 } 732 } 733 return true; 734 } 735 736 // Removes load operations from 'srcLoads' which operate on 'memref', and 737 // adds them to 'dstLoads'. 738 static void moveLoadsAccessingMemrefTo(Value *memref, 739 SmallVectorImpl<Operation *> *srcLoads, 740 SmallVectorImpl<Operation *> *dstLoads) { 741 dstLoads->clear(); 742 SmallVector<Operation *, 4> srcLoadsToKeep; 743 for (auto *load : *srcLoads) { 744 if (cast<AffineLoadOp>(load).getMemRef() == memref) 745 dstLoads->push_back(load); 746 else 747 srcLoadsToKeep.push_back(load); 748 } 749 srcLoads->swap(srcLoadsToKeep); 750 } 751 752 // Returns the innermost common loop depth for the set of operations in 'ops'. 753 static unsigned getInnermostCommonLoopDepth(ArrayRef<Operation *> ops) { 754 unsigned numOps = ops.size(); 755 assert(numOps > 0); 756 757 std::vector<SmallVector<AffineForOp, 4>> loops(numOps); 758 unsigned loopDepthLimit = std::numeric_limits<unsigned>::max(); 759 for (unsigned i = 0; i < numOps; ++i) { 760 getLoopIVs(*ops[i], &loops[i]); 761 loopDepthLimit = 762 std::min(loopDepthLimit, static_cast<unsigned>(loops[i].size())); 763 } 764 765 unsigned loopDepth = 0; 766 for (unsigned d = 0; d < loopDepthLimit; ++d) { 767 unsigned i; 768 for (i = 1; i < numOps; ++i) { 769 if (loops[i - 1][d] != loops[i][d]) 770 break; 771 } 772 if (i != numOps) 773 break; 774 ++loopDepth; 775 } 776 return loopDepth; 777 } 778 779 // Returns the maximum loop depth at which no dependences between 'loadOpInsts' 780 // and 'storeOpInsts' are satisfied. 781 static unsigned getMaxLoopDepth(ArrayRef<Operation *> loadOpInsts, 782 ArrayRef<Operation *> storeOpInsts) { 783 // Merge loads and stores into the same array. 784 SmallVector<Operation *, 2> ops(loadOpInsts.begin(), loadOpInsts.end()); 785 ops.append(storeOpInsts.begin(), storeOpInsts.end()); 786 787 // Compute the innermost common loop depth for loads and stores. 788 unsigned loopDepth = getInnermostCommonLoopDepth(ops); 789 790 // Return common loop depth for loads if there are no store ops. 791 if (storeOpInsts.empty()) 792 return loopDepth; 793 794 // Check dependences on all pairs of ops in 'ops' and store the minimum 795 // loop depth at which a dependence is satisfied. 
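// For example, if the ops in 'ops' share three common surrounding loops and
// the shallowest depth at which any pair has a dependence is d == 2, then the
// deepest level at which the slice can be inserted without reordering the
// dependent accesses is d - 1 == 1, so 'loopDepth' is clamped to 1 below.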
796 for (unsigned i = 0, e = ops.size(); i < e; ++i) { 797 auto *srcOpInst = ops[i]; 798 MemRefAccess srcAccess(srcOpInst); 799 for (unsigned j = 0; j < e; ++j) { 800 auto *dstOpInst = ops[j]; 801 MemRefAccess dstAccess(dstOpInst); 802 803 unsigned numCommonLoops = 804 getNumCommonSurroundingLoops(*srcOpInst, *dstOpInst); 805 for (unsigned d = 1; d <= numCommonLoops + 1; ++d) { 806 FlatAffineConstraints dependenceConstraints; 807 // TODO(andydavis) Cache dependence analysis results, check cache here. 808 DependenceResult result = checkMemrefAccessDependence( 809 srcAccess, dstAccess, d, &dependenceConstraints, 810 /*dependenceComponents=*/nullptr); 811 if (hasDependence(result)) { 812 // Store minimum loop depth and break because we want the min 'd' at 813 // which there is a dependence. 814 loopDepth = std::min(loopDepth, d - 1); 815 break; 816 } 817 } 818 } 819 } 820 return loopDepth; 821 } 822 823 // Sinks all sequential loops to the innermost levels (while preserving 824 // relative order among them) and moves all parallel loops to the 825 // outermost (while again preserving relative order among them). 826 // This can increase the loop depth at which we can fuse a slice, since we are 827 // pushing loop carried dependence to a greater depth in the loop nest. 828 static void sinkSequentialLoops(MemRefDependenceGraph::Node *node) { 829 assert(isa<AffineForOp>(node->op)); 830 AffineForOp newRootForOp = sinkSequentialLoops(cast<AffineForOp>(node->op)); 831 node->op = newRootForOp.getOperation(); 832 } 833 834 // TODO(mlir-team): improve/complete this when we have target data. 835 unsigned getMemRefEltSizeInBytes(MemRefType memRefType) { 836 auto elementType = memRefType.getElementType(); 837 838 unsigned sizeInBits; 839 if (elementType.isIntOrFloat()) { 840 sizeInBits = elementType.getIntOrFloatBitWidth(); 841 } else { 842 auto vectorType = elementType.cast<VectorType>(); 843 sizeInBits = 844 vectorType.getElementTypeBitWidth() * vectorType.getNumElements(); 845 } 846 return llvm::divideCeil(sizeInBits, 8); 847 } 848 849 // Creates and returns a private (single-user) memref for fused loop rooted 850 // at 'forOp', with (potentially reduced) memref size based on the 851 // MemRefRegion written to by 'srcStoreOpInst' at depth 'dstLoopDepth'. 852 // TODO(bondhugula): consider refactoring the common code from generateDma and 853 // this one. 854 static Value *createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst, 855 unsigned dstLoopDepth, 856 Optional<unsigned> fastMemorySpace, 857 uint64_t localBufSizeThreshold) { 858 auto *forInst = forOp.getOperation(); 859 860 // Create builder to insert alloc op just before 'forOp'. 861 OpBuilder b(forInst); 862 // Builder to create constants at the top level. 863 OpBuilder top(forInst->getParentOfType<FuncOp>().getBody()); 864 // Create new memref type based on slice bounds. 865 auto *oldMemRef = cast<AffineStoreOp>(srcStoreOpInst).getMemRef(); 866 auto oldMemRefType = oldMemRef->getType().cast<MemRefType>(); 867 unsigned rank = oldMemRefType.getRank(); 868 869 // Compute MemRefRegion for 'srcStoreOpInst' at depth 'dstLoopDepth'. 
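// For example (illustrative shapes only): if 'srcStoreOpInst' writes
// %m : memref<100x100xf32> but, at 'dstLoopDepth', each destination iteration
// only touches a single 100-element row, the region below yields newShape
// {1, 100}; the private buffer is then allocated as memref<1x100xf32> and the
// original accesses are remapped by subtracting the region's lower bound
// offsets from each index.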
870 MemRefRegion region(srcStoreOpInst->getLoc()); 871 bool validRegion = succeeded(region.compute(srcStoreOpInst, dstLoopDepth)); 872 (void)validRegion; 873 assert(validRegion && "unexpected memref region failure"); 874 SmallVector<int64_t, 4> newShape; 875 std::vector<SmallVector<int64_t, 4>> lbs; 876 SmallVector<int64_t, 8> lbDivisors; 877 lbs.reserve(rank); 878 // Query 'region' for 'newShape' and lower bounds of MemRefRegion accessed 879 // by 'srcStoreOpInst' at depth 'dstLoopDepth'. 880 Optional<int64_t> numElements = 881 region.getConstantBoundingSizeAndShape(&newShape, &lbs, &lbDivisors); 882 assert(numElements.hasValue() && 883 "non-constant number of elts in local buffer"); 884 885 const FlatAffineConstraints *cst = region.getConstraints(); 886 // 'outerIVs' holds the values that this memory region is symbolic/paramteric 887 // on; this would correspond to loop IVs surrounding the level at which the 888 // slice is being materialized. 889 SmallVector<Value *, 8> outerIVs; 890 cst->getIdValues(rank, cst->getNumIds(), &outerIVs); 891 892 // Build 'rank' AffineExprs from MemRefRegion 'lbs' 893 SmallVector<AffineExpr, 4> offsets; 894 offsets.reserve(rank); 895 for (unsigned d = 0; d < rank; ++d) { 896 assert(lbs[d].size() == cst->getNumCols() - rank && "incorrect bound size"); 897 898 AffineExpr offset = top.getAffineConstantExpr(0); 899 for (unsigned j = 0, e = cst->getNumCols() - rank - 1; j < e; j++) { 900 offset = offset + lbs[d][j] * top.getAffineDimExpr(j); 901 } 902 assert(lbDivisors[d] > 0); 903 offset = 904 (offset + lbs[d][cst->getNumCols() - 1 - rank]).floorDiv(lbDivisors[d]); 905 offsets.push_back(offset); 906 } 907 908 // Create 'newMemRefType' using 'newShape' from MemRefRegion accessed 909 // by 'srcStoreOpInst'. 910 uint64_t bufSize = 911 getMemRefEltSizeInBytes(oldMemRefType) * numElements.getValue(); 912 unsigned newMemSpace; 913 if (bufSize <= localBufSizeThreshold && fastMemorySpace.hasValue()) { 914 newMemSpace = fastMemorySpace.getValue(); 915 } else { 916 newMemSpace = oldMemRefType.getMemorySpace(); 917 } 918 auto newMemRefType = top.getMemRefType( 919 newShape, oldMemRefType.getElementType(), {}, newMemSpace); 920 // Gather alloc operands for the dynamic dimensions of the memref. 921 SmallVector<Value *, 4> allocOperands; 922 unsigned dynamicDimCount = 0; 923 for (auto dimSize : oldMemRefType.getShape()) { 924 if (dimSize == -1) 925 allocOperands.push_back( 926 top.create<DimOp>(forOp.getLoc(), oldMemRef, dynamicDimCount++)); 927 } 928 929 // Create new private memref for fused loop 'forOp'. 930 // TODO(andydavis) Create/move alloc ops for private memrefs closer to their 931 // consumer loop nests to reduce their live range. Currently they are added 932 // at the beginning of the function, because loop nests can be reordered 933 // during the fusion pass. 934 Value *newMemRef = 935 top.create<AllocOp>(forOp.getLoc(), newMemRefType, allocOperands); 936 937 // Build an AffineMap to remap access functions based on lower bound offsets. 938 SmallVector<AffineExpr, 4> remapExprs; 939 remapExprs.reserve(rank); 940 unsigned zeroOffsetCount = 0; 941 for (unsigned i = 0; i < rank; i++) { 942 if (auto constExpr = offsets[i].dyn_cast<AffineConstantExpr>()) 943 if (constExpr.getValue() == 0) 944 ++zeroOffsetCount; 945 auto dimExpr = b.getAffineDimExpr(outerIVs.size() + i); 946 947 auto remapExpr = 948 simplifyAffineExpr(dimExpr - offsets[i], outerIVs.size() + rank, 0); 949 remapExprs.push_back(remapExpr); 950 } 951 auto indexRemap = zeroOffsetCount == rank 952 ? 
AffineMap() 953 : b.getAffineMap(outerIVs.size() + rank, 0, remapExprs); 954 // Replace all users of 'oldMemRef' with 'newMemRef'. 955 LogicalResult res = 956 replaceAllMemRefUsesWith(oldMemRef, newMemRef, {}, indexRemap, 957 /*extraOperands=*/outerIVs, 958 /*domInstFilter=*/&*forOp.getBody()->begin()); 959 assert(succeeded(res) && 960 "replaceAllMemrefUsesWith should always succeed here"); 961 (void)res; 962 return newMemRef; 963 } 964 965 // Checks if node 'srcId' (which writes to a live out memref), can be safely 966 // fused into node 'dstId'. Returns true if the following conditions are met: 967 // *) 'srcNode' only writes to live out 'memref'. 968 // *) 'srcNode' has exactly one output edge on 'memref' (which is to 'dstId'). 969 // *) 'dstNode's read/write region to 'memref' is a super set of 'srcNode's 970 // write region to 'memref'. 971 // TODO(andydavis) Generalize this to handle more live in/out cases. 972 static bool canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId, 973 Value *memref, 974 MemRefDependenceGraph *mdg) { 975 auto *srcNode = mdg->getNode(srcId); 976 auto *dstNode = mdg->getNode(dstId); 977 978 // Gather all memrefs from 'srcNode' store ops. 979 DenseSet<Value *> storeMemrefs; 980 for (auto *storeOpInst : srcNode->stores) { 981 storeMemrefs.insert(cast<AffineStoreOp>(storeOpInst).getMemRef()); 982 } 983 // Return false if any of the following are true: 984 // *) 'srcNode' writes to a live in/out memref other than 'memref'. 985 // *) 'srcNode' has more than one output edge on 'memref'. 986 // Check that all stores are to the same memref. 987 if (storeMemrefs.size() != 1 || 988 mdg->getOutEdgeCount(srcNode->id, memref) != 1) 989 return false; 990 // Compute MemRefRegion 'srcWriteRegion' for 'srcStoreOpInst' on 'memref'. 991 auto *srcStoreOpInst = srcNode->stores.front(); 992 MemRefRegion srcWriteRegion(srcStoreOpInst->getLoc()); 993 if (failed(srcWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0))) { 994 LLVM_DEBUG(llvm::dbgs() 995 << "Unable to compute MemRefRegion for source operation\n."); 996 return false; 997 } 998 SmallVector<int64_t, 4> srcShape; 999 // Query 'srcWriteRegion' for 'srcShape' and 'srcNumElements'. 1000 // by 'srcStoreOpInst' at depth 'dstLoopDepth'. 1001 Optional<int64_t> srcNumElements = 1002 srcWriteRegion.getConstantBoundingSizeAndShape(&srcShape); 1003 if (!srcNumElements.hasValue()) 1004 return false; 1005 1006 // Compute MemRefRegion 'dstRegion' for 'dstStore/LoadOpInst' on 'memref'. 1007 // TODO(andydavis) Compute 'unionboundingbox' of all write regions (one for 1008 // each store op in 'dstStoreOps'). 1009 SmallVector<Operation *, 2> dstStoreOps; 1010 dstNode->getStoreOpsForMemref(memref, &dstStoreOps); 1011 SmallVector<Operation *, 2> dstLoadOps; 1012 dstNode->getLoadOpsForMemref(memref, &dstLoadOps); 1013 1014 auto *dstOpInst = dstStoreOps.empty() ? dstLoadOps[0] : dstStoreOps[0]; 1015 MemRefRegion dstRegion(dstOpInst->getLoc()); 1016 if (failed(dstRegion.compute(dstOpInst, /*loopDepth=*/0))) { 1017 LLVM_DEBUG(llvm::dbgs() 1018 << "Unable to compute MemRefRegion for dest operation\n."); 1019 return false; 1020 } 1021 SmallVector<int64_t, 4> dstShape; 1022 // Query 'dstRegion' for 'dstShape' and 'dstNumElements'. 1023 // by 'dstOpInst' at depth 'dstLoopDepth'. 1024 Optional<int64_t> dstNumElements = 1025 dstRegion.getConstantBoundingSizeAndShape(&dstShape); 1026 if (!dstNumElements.hasValue()) 1027 return false; 1028 1029 // Return false if write region is not a superset of 'srcNodes' write 1030 // region to 'memref'. 
1031 // TODO(andydavis) Check the shape and lower bounds here too. 1032 if (srcNumElements != dstNumElements) 1033 return false; 1034 return true; 1035 } 1036 1037 // Checks the profitability of fusing a backwards slice of the loop nest 1038 // surrounding 'srcOpInst' into the loop nest surrounding 'dstLoadOpInsts'. 1039 // The argument 'srcStoreOpInst' is used to calculate the storage reduction on 1040 // the memref being produced and consumed, which is an input to the cost model. 1041 // For producer-consumer fusion, 'srcStoreOpInst' will be the same as 1042 // 'srcOpInst', as we are slicing w.r.t. that producer. 1043 // For input-reuse fusion, 'srcOpInst' will be the src loop nest LoadOp which 1044 // reads from the same memref as dst loop nest load ops, and 'srcStoreOpInst' 1045 // will be the unique store op in the src node, which will be used to check 1046 // that the write region is the same after input-reuse fusion. 1047 // Returns true if it is profitable to fuse the candidate loop nests. Returns 1048 // false otherwise. `dstLoopDepth` is set to the most profitable depth at which 1049 // to materialize the source loop nest slice. 1050 // The profitability model executes the following steps: 1051 // *) Computes the backward computation slice at 'srcOpInst'. This 1052 // computation slice of the loop nest surrounding 'srcOpInst' is 1053 // represented by modified src loop bounds in 'sliceState', which are 1054 // functions of loop IVs in the loop nest surrounding 'srcOpInst'. 1055 // *) Computes the cost of unfused src/dst loop nests (currently the cost of a 1056 // loop nest is the total number of dynamic operation instances in the loop 1057 // nest). 1058 // *) Computes the cost of fusing a slice of the src loop nest into the dst 1059 // loop nest at various values of dst loop depth, attempting to fuse 1060 // the largest computation slice at the maximal dst loop depth (closest to the 1061 // load) to minimize reuse distance and potentially enable subsequent 1062 // load/store forwarding. 1063 // NOTE: If the dst loop nest includes multiple loads in 'dstLoadOpInsts' for 1064 // the same memref as is written by 'srcOpInst', then the union of slice 1065 // loop bounds is used to compute the slice and associated slice cost. 1066 // NOTE: 'dstLoopDepth' refers to the loop depth within the destination loop 1067 // nest, at which the src computation slice is inserted/fused. 1068 // NOTE: We attempt to maximize the dst loop depth, but there are cases 1069 // where a particular setting for 'dstLoopDepth' might fuse an unsliced 1070 // loop (within the src computation slice) at a depth which results in 1071 // excessive recomputation (see unit tests for examples). 1072 // *) Compares the total cost of the unfused loop nests to the min cost fused 1073 // loop nest computed in the previous step, and returns true if the latter 1074 // is lower. 1075 static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst, 1076 ArrayRef<Operation *> dstLoadOpInsts, 1077 ArrayRef<Operation *> dstStoreOpInsts, 1078 ComputationSliceState *sliceState, 1079 unsigned *dstLoopDepth, bool maximalFusion) { 1080 LLVM_DEBUG({ 1081 llvm::dbgs() << "Checking whether fusion is profitable between:\n"; 1082 llvm::dbgs() << " " << *srcOpInst << " and \n"; 1083 for (auto dstOpInst : dstLoadOpInsts) { 1084 llvm::dbgs() << " " << *dstOpInst << "\n"; 1085 }; 1086 }); 1087 1088 // Compute cost of sliced and unsliced src loop nest.
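// As a reminder of the cost model above: the cost of a loop nest is its total
// number of dynamic operation instances. For example (illustrative trip
// counts), a 2-d nest with trip counts 100 and 50 and two ops in the innermost
// body costs 100 * 50 * 2 = 10000.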
1089 SmallVector<AffineForOp, 4> srcLoopIVs; 1090 getLoopIVs(*srcOpInst, &srcLoopIVs); 1091 unsigned numSrcLoopIVs = srcLoopIVs.size(); 1092 1093 // Walk src loop nest and collect stats. 1094 LoopNestStats srcLoopNestStats; 1095 if (!getLoopNestStats(srcLoopIVs[0], &srcLoopNestStats)) 1096 return false; 1097 1098 // Compute cost of dst loop nest. 1099 SmallVector<AffineForOp, 4> dstLoopIVs; 1100 getLoopIVs(*dstLoadOpInsts[0], &dstLoopIVs); 1101 1102 LoopNestStats dstLoopNestStats; 1103 if (!getLoopNestStats(dstLoopIVs[0], &dstLoopNestStats)) 1104 return false; 1105 1106 // Compute the maximum loop depth at which we can insert the src slice 1107 // and still satisfy dest loop nest dependences, for producer-consumer fusion. 1108 unsigned maxDstLoopDepth = 1109 (srcOpInst == srcStoreOpInst) 1110 ? getMaxLoopDepth(dstLoadOpInsts, dstStoreOpInsts) 1111 : dstLoopIVs.size(); 1112 if (maxDstLoopDepth == 0) { 1113 LLVM_DEBUG(llvm::dbgs() << "Can't fuse: maxDstLoopDepth == 0.\n"); 1114 return false; 1115 } 1116 1117 // Search for min cost value for 'dstLoopDepth'. At each value of 1118 // 'dstLoopDepth' from 'maxDstLoopDepth' to '1', compute computation slice 1119 // bounds between 'srcOpInst' and each op in 'dstLoadOpInsts' (taking the union 1120 // of these bounds). Next the union slice bounds are used to calculate 1121 // the cost of the slice and the cost of the slice inserted into the dst 1122 // loop nest at 'dstLoopDepth'. 1123 uint64_t minFusedLoopNestComputeCost = std::numeric_limits<uint64_t>::max(); 1124 double maxStorageReduction = 0.0; 1125 Optional<uint64_t> sliceMemEstimate = None; 1126 1127 SmallVector<ComputationSliceState, 4> sliceStates; 1128 sliceStates.resize(maxDstLoopDepth); 1129 // The best loop depth at which to materialize the slice. 1130 Optional<unsigned> bestDstLoopDepth = None; 1131 1132 // Compute op instance count for the src loop nest without iteration slicing. 1133 uint64_t srcLoopNestCost = getComputeCost(srcLoopIVs[0], srcLoopNestStats); 1134 1135 // Compute src loop nest write region size. 1136 MemRefRegion srcWriteRegion(srcStoreOpInst->getLoc()); 1137 if (failed(srcWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0))) { 1138 LLVM_DEBUG(llvm::dbgs() 1139 << "Unable to compute MemRefRegion for source operation.\n"); 1140 return false; 1141 } 1142 1143 Optional<int64_t> maybeSrcWriteRegionSizeBytes = 1144 srcWriteRegion.getRegionSize(); 1145 if (!maybeSrcWriteRegionSizeBytes.hasValue()) 1146 return false; 1147 int64_t srcWriteRegionSizeBytes = maybeSrcWriteRegionSizeBytes.getValue(); 1148 1149 // Compute op instance count for the dst loop nest. 1150 uint64_t dstLoopNestCost = getComputeCost(dstLoopIVs[0], dstLoopNestStats); 1151 1152 // Evaluate all depth choices for materializing the slice in the destination 1153 // loop nest. 1154 for (unsigned i = maxDstLoopDepth; i >= 1; --i) { 1155 // Compute the union of slice bounds of all ops in 'dstLoadOpInsts'.
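// For example (simplified), if the src nest writes %m[%i] for %i in [0, 100)
// and a dst load at depth i == 1 reads %m[%j], the backward slice gives the
// src loop the bounds [%j, %j + 1), i.e. a single-iteration slice that
// produces exactly the element consumed by that dst iteration; with several
// dst loads, the union of the per-load bounds is taken.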
1156 if (failed(mlir::computeSliceUnion({srcOpInst}, dstLoadOpInsts, 1157 /*loopDepth=*/i, 1158 /*numCommonLoops=*/0, 1159 /*isBackwardSlice=*/true, 1160 &sliceStates[i - 1]))) { 1161 LLVM_DEBUG(llvm::dbgs() 1162 << "computeSliceUnion failed for loopDepth: " << i << "\n"); 1163 continue; 1164 } 1165 1166 int64_t fusedLoopNestComputeCost; 1167 if (!getFusionComputeCost(srcLoopIVs[0], srcLoopNestStats, dstLoopIVs[0], 1168 dstLoopNestStats, &sliceStates[i - 1], 1169 &fusedLoopNestComputeCost)) { 1170 LLVM_DEBUG(llvm::dbgs() << "Unable to compute fusion compute cost.\n."); 1171 continue; 1172 } 1173 1174 double additionalComputeFraction = 1175 fusedLoopNestComputeCost / 1176 (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) - 1177 1; 1178 1179 // Determine what the slice write MemRefRegion would be, if the src loop 1180 // nest slice 'sliceStates[i - 1]' were to be inserted into the dst loop 1181 // nest at loop depth 'i' 1182 MemRefRegion sliceWriteRegion(srcStoreOpInst->getLoc()); 1183 if (failed(sliceWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0, 1184 &sliceStates[i - 1]))) { 1185 LLVM_DEBUG(llvm::dbgs() 1186 << "Failed to compute slice write region at loopDepth: " << i 1187 << "\n"); 1188 continue; 1189 } 1190 1191 Optional<int64_t> maybeSliceWriteRegionSizeBytes = 1192 sliceWriteRegion.getRegionSize(); 1193 if (!maybeSliceWriteRegionSizeBytes.hasValue() || 1194 maybeSliceWriteRegionSizeBytes.getValue() == 0) { 1195 LLVM_DEBUG(llvm::dbgs() 1196 << "Failed to get slice write region size at loopDepth: " << i 1197 << "\n"); 1198 continue; 1199 } 1200 int64_t sliceWriteRegionSizeBytes = 1201 maybeSliceWriteRegionSizeBytes.getValue(); 1202 1203 // If we are fusing for reuse, check that write regions remain the same. 1204 // TODO(andydavis) Write region check should check sizes and offsets in 1205 // each dimension, so that we are sure they are covering the same memref 1206 // region. Also, move this out to a isMemRefRegionSuperSet helper function. 1207 if (srcOpInst != srcStoreOpInst && 1208 sliceWriteRegionSizeBytes != srcWriteRegionSizeBytes) 1209 continue; 1210 1211 double storageReduction = static_cast<double>(srcWriteRegionSizeBytes) / 1212 static_cast<double>(sliceWriteRegionSizeBytes); 1213 1214 LLVM_DEBUG({ 1215 std::stringstream msg; 1216 msg << " evaluating fusion profitability at depth : " << i << "\n" 1217 << std::fixed << std::setprecision(2) 1218 << " additional compute fraction: " 1219 << 100.0 * additionalComputeFraction << "%\n" 1220 << " storage reduction factor: " << storageReduction << "x\n" 1221 << " fused nest cost: " << fusedLoopNestComputeCost << "\n" 1222 << " src write region size: " << srcWriteRegionSizeBytes << "\n" 1223 << " slice write region size: " << sliceWriteRegionSizeBytes 1224 << "\n"; 1225 llvm::dbgs() << msg.str(); 1226 }); 1227 1228 double computeToleranceThreshold = 1229 clFusionAddlComputeTolerance.getNumOccurrences() > 0 1230 ? clFusionAddlComputeTolerance 1231 : LoopFusion::kComputeToleranceThreshold; 1232 1233 // TODO(b/123247369): This is a placeholder cost model. 1234 // Among all choices that add an acceptable amount of redundant computation 1235 // (as per computeToleranceThreshold), we will simply pick the one that 1236 // reduces the intermediary size the most. 
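// For example (illustrative numbers): with srcLoopNestCost == 10000,
// dstLoopNestCost == 20000 and fusedLoopNestComputeCost == 33000, the
// additional compute fraction is 33000 / (10000 + 20000) - 1 == 0.10, which is
// under the default kComputeToleranceThreshold of 0.30; among all depths that
// pass this check, the one with the largest 'storageReduction' wins.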
1237 if ((storageReduction > maxStorageReduction) && 1238 (maximalFusion || 1239 (additionalComputeFraction < computeToleranceThreshold))) { 1240 maxStorageReduction = storageReduction; 1241 bestDstLoopDepth = i; 1242 minFusedLoopNestComputeCost = fusedLoopNestComputeCost; 1243 sliceMemEstimate = sliceWriteRegionSizeBytes; 1244 } 1245 } 1246 1247 // A simple cost model: fuse if it reduces the memory footprint. If 1248 // -maximal-fusion is set, fuse nevertheless. 1249 1250 if (!maximalFusion && !bestDstLoopDepth.hasValue()) { 1251 LLVM_DEBUG( 1252 llvm::dbgs() 1253 << "All fusion choices involve more than the threshold amount of " 1254 "redundant computation; NOT fusing.\n"); 1255 return false; 1256 } 1257 1258 if (!bestDstLoopDepth.hasValue()) { 1259 LLVM_DEBUG(llvm::dbgs() << "no fusion depth could be evaluated.\n"); 1260 return false; 1261 } 1262 1263 // Set dstLoopDepth based on best values from search. 1264 *dstLoopDepth = bestDstLoopDepth.getValue(); 1265 1266 LLVM_DEBUG( 1267 llvm::dbgs() << " LoopFusion fusion stats:" 1268 << "\n best loop depth: " << bestDstLoopDepth 1269 << "\n src loop nest compute cost: " << srcLoopNestCost 1270 << "\n dst loop nest compute cost: " << dstLoopNestCost 1271 << "\n fused loop nest compute cost: " 1272 << minFusedLoopNestComputeCost << "\n"); 1273 1274 auto dstMemSize = getMemoryFootprintBytes(dstLoopIVs[0]); 1275 auto srcMemSize = getMemoryFootprintBytes(srcLoopIVs[0]); 1276 1277 Optional<double> storageReduction = None; 1278 1279 if (!maximalFusion) { 1280 if (!dstMemSize.hasValue() || !srcMemSize.hasValue()) { 1281 LLVM_DEBUG( 1282 llvm::dbgs() 1283 << " fusion memory benefit cannot be evaluated; NOT fusing.\n"); 1284 return false; 1285 } 1286 1287 auto srcMemSizeVal = srcMemSize.getValue(); 1288 auto dstMemSizeVal = dstMemSize.getValue(); 1289 1290 assert(sliceMemEstimate.hasValue() && "expected value"); 1291 auto fusedMem = dstMemSizeVal + sliceMemEstimate.getValue(); 1292 1293 LLVM_DEBUG(llvm::dbgs() << " src mem: " << srcMemSizeVal << "\n" 1294 << " dst mem: " << dstMemSizeVal << "\n" 1295 << " fused mem: " << fusedMem << "\n" 1296 << " slice mem: " << sliceMemEstimate << "\n"); 1297 1298 if (static_cast<long>(fusedMem) > srcMemSizeVal + dstMemSizeVal) { 1299 LLVM_DEBUG(llvm::dbgs() << "Fusion is not profitable; NOT fusing.\n"); 1300 return false; 1301 } 1302 storageReduction = 1303 100.0 * 1304 (1.0 - fusedMem / (static_cast<double>(srcMemSizeVal) + dstMemSizeVal)); 1305 } 1306 1307 double additionalComputeFraction = 1308 100.0 * (minFusedLoopNestComputeCost / 1309 (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) - 1310 1); 1311 (void)additionalComputeFraction; 1312 LLVM_DEBUG({ 1313 std::stringstream msg; 1314 msg << " fusion is most profitable at depth " << *dstLoopDepth << " with " 1315 << std::setprecision(2) << additionalComputeFraction 1316 << "% redundant computation and a "; 1317 msg << (storageReduction.hasValue() 1318 ? std::to_string(storageReduction.getValue()) 1319 : "<unknown>"); 1320 msg << "% storage reduction.\n"; 1321 llvm::dbgs() << msg.str(); 1322 }); 1323 1324 // Update return parameter 'sliceState' with 'bestSliceState'. 1325 ComputationSliceState *bestSliceState = &sliceStates[*dstLoopDepth - 1]; 1326 sliceState->lbs = bestSliceState->lbs; 1327 sliceState->ubs = bestSliceState->ubs; 1328 sliceState->lbOperands = bestSliceState->lbOperands; 1329 sliceState->ubOperands = bestSliceState->ubOperands; 1330 1331 // Canonicalize slice bound affine maps. 
1332 for (unsigned i = 0; i < numSrcLoopIVs; ++i) { 1333 if (sliceState->lbs[i] != AffineMap()) { 1334 canonicalizeMapAndOperands(&sliceState->lbs[i], 1335 &sliceState->lbOperands[i]); 1336 } 1337 if (sliceState->ubs[i] != AffineMap()) { 1338 canonicalizeMapAndOperands(&sliceState->ubs[i], 1339 &sliceState->ubOperands[i]); 1340 } 1341 } 1342 return true; 1343 } 1344 1345 // GreedyFusion greedily fuses loop nests which have a producer/consumer or 1346 // input-reuse relationship on a memref, with the goal of improving locality. 1347 // 1348 // The steps of the producer-consumer fusion algorithm are as follows: 1349 // 1350 // *) A worklist is initialized with node ids from the dependence graph. 1351 // *) For each node id in the worklist: 1352 // *) Pop an AffineForOp of the worklist. This 'dstAffineForOp' will be a 1353 // candidate destination AffineForOp into which fusion will be attempted. 1354 // *) Add each LoadOp currently in 'dstAffineForOp' into list 'dstLoadOps'. 1355 // *) For each LoadOp in 'dstLoadOps' do: 1356 // *) Look up dependent loop nests which have a single store op to the same 1357 // memref. 1358 // *) Check if dependences would be violated by the fusion. 1359 // *) Get a computation slice of 'srcLoopNest', which adjusts its loop 1360 // bounds to be functions of 'dstLoopNest' IVs and symbols. 1361 // *) Fuse the 'srcLoopNest' computation slice into the 'dstLoopNest', 1362 // at a loop depth determined by the cost model in 'isFusionProfitable'. 1363 // *) Add the newly fused load/store operations to the state, 1364 // and also add newly fused load ops to 'dstLoopOps' to be considered 1365 // as fusion dst load ops in another iteration. 1366 // *) Remove old src loop nest and its associated state. 1367 // 1368 // The steps of the input-reuse fusion algorithm are as follows: 1369 // 1370 // *) Initialize 'worklist' with node ids from the dependence graph. 1371 // *) For each 'dstNode' in the worklist: 1372 // *) Find a candidate sibling node 'sibNode' to fuse with 'dstNode' which 1373 // loads from the same memref, but which has no dependence paths to/from. 1374 // *) Get a computation slice of 'sibLoopNest', which adjusts its loop 1375 // bounds to be functions of 'dstLoopNest' IVs and symbols. 1376 // *) Fuse the 'sibLoopNest' computation slice into the 'dstLoopNest', 1377 // at a loop depth determined by the cost model in 'isFusionProfitable'. 1378 // This function also checks that the memref write region of 'sibLoopNest', 1379 // is preserved in the fused loop nest. 1380 // *) Update graph state to reflect the fusion of 'sibNode' into 'dstNode'. 1381 // 1382 // Given a graph where top-level operations are vertices in the set 'V' and 1383 // edges in the set 'E' are dependences between vertices, this algorithm 1384 // takes O(V) time for initialization, and has runtime O(V + E). 1385 // 1386 // This greedy algorithm is not 'maximal' due to the current restriction of 1387 // fusing along single producer consumer edges, but there is a TODO to fix this. 1388 // 1389 // TODO(andydavis) Experiment with other fusion policies. 1390 struct GreedyFusion { 1391 public: 1392 // The data dependence graph to traverse during fusion. 1393 MemRefDependenceGraph *mdg; 1394 // Worklist of graph nodes visited during the fusion pass. 1395 SmallVector<unsigned, 8> worklist; 1396 // Set of graph nodes which are present on the worklist. 1397 llvm::SmallDenseSet<unsigned, 16> worklistSet; 1398 // Parameter for local buffer size threshold. 
  unsigned localBufSizeThreshold;
  // Parameter for fast memory space.
  Optional<unsigned> fastMemorySpace;
  // If true, ignore any additional (redundant) computation tolerance threshold
  // that would have prevented fusion.
  bool maximalFusion;

  using Node = MemRefDependenceGraph::Node;

  GreedyFusion(MemRefDependenceGraph *mdg, unsigned localBufSizeThreshold,
               Optional<unsigned> fastMemorySpace, bool maximalFusion)
      : mdg(mdg), localBufSizeThreshold(localBufSizeThreshold),
        fastMemorySpace(fastMemorySpace), maximalFusion(maximalFusion) {}

  // Initializes 'worklist' with nodes from 'mdg'.
  void init() {
    // TODO(andydavis) Add a priority queue for prioritizing nodes by different
    // metrics (e.g. arithmetic intensity/flops-to-bytes ratio).
    worklist.clear();
    worklistSet.clear();
    for (auto &idAndNode : mdg->nodes) {
      const Node &node = idAndNode.second;
      worklist.push_back(node.id);
      worklistSet.insert(node.id);
    }
  }

  // Run the GreedyFusion pass.
  // *) First pass through the nodes fuses single-use producer nodes into their
  //    unique consumer.
  // *) Second pass fuses sibling nodes which share no dependence edges.
  // *) Third pass fuses any remaining producer nodes into their users.
  void run() {
    // TODO(andydavis) Run this repeatedly until a fixed-point is reached.
    fuseProducerConsumerNodes(/*maxSrcUserCount=*/1);
    fuseSiblingNodes();
    fuseProducerConsumerNodes(
        /*maxSrcUserCount=*/std::numeric_limits<unsigned>::max());
    eraseUnusedMemRefAllocations();
  }

  void fuseProducerConsumerNodes(unsigned maxSrcUserCount) {
    init();
    while (!worklist.empty()) {
      unsigned dstId = worklist.back();
      worklist.pop_back();
      worklistSet.erase(dstId);

      // Skip if this node was removed (fused into another node).
      if (mdg->nodes.count(dstId) == 0)
        continue;
      // Get 'dstNode' into which to attempt fusion.
      auto *dstNode = mdg->getNode(dstId);
      // Skip if 'dstNode' is not a loop nest.
      if (!isa<AffineForOp>(dstNode->op))
        continue;
      // Sink sequential loops in 'dstNode' (and thus raise parallel loops)
      // while preserving relative order. This can increase the maximum loop
      // depth at which we can fuse a slice of a producer loop nest into a
      // consumer loop nest.
      sinkSequentialLoops(dstNode);

      SmallVector<Operation *, 4> loads = dstNode->loads;
      SmallVector<Operation *, 4> dstLoadOpInsts;
      DenseSet<Value *> visitedMemrefs;
      while (!loads.empty()) {
        // Get memref of load on top of the stack.
        auto *memref = cast<AffineLoadOp>(loads.back()).getMemRef();
        if (visitedMemrefs.count(memref) > 0)
          continue;
        visitedMemrefs.insert(memref);
        // Move all loads in 'loads' accessing 'memref' to 'dstLoadOpInsts'.
        moveLoadsAccessingMemrefTo(memref, &loads, &dstLoadOpInsts);
        // Skip if no input edges along which to fuse.
        if (mdg->inEdges.count(dstId) == 0)
          continue;
        // Iterate through in-edges for 'dstId' and collect src node ids for
        // any edges on 'memref'.
        SmallVector<unsigned, 2> srcNodeIds;
        for (auto &srcEdge : mdg->inEdges[dstId]) {
          // Skip 'srcEdge' if not for 'memref'.
          if (srcEdge.value != memref)
            continue;
          srcNodeIds.push_back(srcEdge.id);
        }
        for (unsigned srcId : srcNodeIds) {
          // Skip if this node was removed (fused into another node).
          if (mdg->nodes.count(srcId) == 0)
            continue;
          // Get 'srcNode' from which to attempt fusion into 'dstNode'.
          auto *srcNode = mdg->getNode(srcId);
          // Skip if 'srcNode' is not a loop nest.
          if (!isa<AffineForOp>(srcNode->op))
            continue;
          // Skip if 'srcNode' has more than one store to any memref.
          // TODO(andydavis) Support fusing multi-output src loop nests.
          if (srcNode->stores.size() != 1)
            continue;

          // Skip if 'srcNode' writes to any live in or escaping memrefs,
          // and cannot be fused.
          bool writesToLiveInOrOut =
              mdg->writesToLiveInOrEscapingMemrefs(srcNode->id);
          if (writesToLiveInOrOut &&
              !canFuseSrcWhichWritesToLiveOut(srcId, dstId, memref, mdg))
            continue;

          // Skip if 'srcNode' out edge count on 'memref' > 'maxSrcUserCount'.
          if (mdg->getOutEdgeCount(srcNode->id, memref) > maxSrcUserCount)
            continue;

          // Compute an operation list insertion point for the fused loop
          // nest which preserves dependences.
          Operation *insertPointInst =
              mdg->getFusedLoopNestInsertionPoint(srcNode->id, dstNode->id);
          if (insertPointInst == nullptr)
            continue;

          // Get unique 'srcNode' store op.
          auto *srcStoreOpInst = srcNode->stores.front();
          // Gather 'dstNode' store ops to 'memref'.
          SmallVector<Operation *, 2> dstStoreOpInsts;
          for (auto *storeOpInst : dstNode->stores)
            if (cast<AffineStoreOp>(storeOpInst).getMemRef() == memref)
              dstStoreOpInsts.push_back(storeOpInst);

          unsigned bestDstLoopDepth;
          mlir::ComputationSliceState sliceState;
          // Check if fusion would be profitable.
          if (!isFusionProfitable(srcStoreOpInst, srcStoreOpInst,
                                  dstLoadOpInsts, dstStoreOpInsts, &sliceState,
                                  &bestDstLoopDepth, maximalFusion))
            continue;
          // TODO(andydavis) Remove the following test code when canFuseLoops
          // is fully functional.
          mlir::ComputationSliceState sliceUnion;
          if (!maximalFusion) {
            FusionResult result = mlir::canFuseLoops(
                cast<AffineForOp>(srcNode->op), cast<AffineForOp>(dstNode->op),
                bestDstLoopDepth, &sliceUnion);
            assert(result.value == FusionResult::Success);
            (void)result;
          }
          // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'.
          auto sliceLoopNest = mlir::insertBackwardComputationSlice(
              srcStoreOpInst, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState);
          if (sliceLoopNest) {
            LLVM_DEBUG(llvm::dbgs() << "\tslice loop nest:\n"
                                    << *sliceLoopNest.getOperation() << "\n");
            // Move 'dstAffineForOp' before 'insertPointInst' if needed.
            auto dstAffineForOp = cast<AffineForOp>(dstNode->op);
            if (insertPointInst != dstAffineForOp.getOperation()) {
              dstAffineForOp.getOperation()->moveBefore(insertPointInst);
            }
            // Update edges between 'srcNode' and 'dstNode'.
            mdg->updateEdges(srcNode->id, dstNode->id, memref);

            // Collect slice loop stats.
            LoopNestStateCollector sliceCollector;
            sliceCollector.collect(sliceLoopNest.getOperation());
            // Promote single iteration slice loops to single IV value.
            for (auto forOp : sliceCollector.forOps) {
              promoteIfSingleIteration(forOp);
            }
            if (!writesToLiveInOrOut) {
              // Create private memref for 'memref' in 'dstAffineForOp'.
              SmallVector<Operation *, 4> storesForMemref;
              for (auto *storeOpInst : sliceCollector.storeOpInsts) {
                if (cast<AffineStoreOp>(storeOpInst).getMemRef() == memref)
                  storesForMemref.push_back(storeOpInst);
              }
              assert(storesForMemref.size() == 1);
              auto *newMemRef = createPrivateMemRef(
                  dstAffineForOp, storesForMemref[0], bestDstLoopDepth,
                  fastMemorySpace, localBufSizeThreshold);
              visitedMemrefs.insert(newMemRef);
              // Create new node in dependence graph for 'newMemRef' alloc op.
              unsigned newMemRefNodeId =
                  mdg->addNode(newMemRef->getDefiningOp());
              // Add edge from 'newMemRef' node to dstNode.
              mdg->addEdge(newMemRefNodeId, dstId, newMemRef);
            }

            // Collect dst loop stats after memref privatization
            // transformation.
            LoopNestStateCollector dstLoopCollector;
            dstLoopCollector.collect(dstAffineForOp.getOperation());

            // Add new load ops to current Node load op list 'loads' to
            // continue fusing based on new operands.
            for (auto *loadOpInst : dstLoopCollector.loadOpInsts) {
              auto *loadMemRef = cast<AffineLoadOp>(loadOpInst).getMemRef();
              if (visitedMemrefs.count(loadMemRef) == 0)
                loads.push_back(loadOpInst);
            }

            // Clear and add back loads and stores.
            mdg->clearNodeLoadAndStores(dstNode->id);
            mdg->addToNode(dstId, dstLoopCollector.loadOpInsts,
                           dstLoopCollector.storeOpInsts);
            // Remove old src loop nest if it no longer has outgoing dependence
            // edges, and if it does not write to a memref which escapes the
            // function. If 'writesToLiveInOrOut' is true, then 'srcNode' has
            // been fused into 'dstNode', the write region of 'dstNode' covers
            // the write region of 'srcNode', and 'srcNode' has no other users,
            // so it is safe to remove.
            if (writesToLiveInOrOut || mdg->canRemoveNode(srcNode->id)) {
              mdg->removeNode(srcNode->id);
              srcNode->op->erase();
            } else {
              // Add remaining users of 'oldMemRef' back on the worklist (if
              // not already there), as its replacement with a local/private
              // memref has reduced dependences on 'oldMemRef' which may have
              // created new fusion opportunities.
              if (mdg->outEdges.count(srcNode->id) > 0) {
                SmallVector<MemRefDependenceGraph::Edge, 2> oldOutEdges =
                    mdg->outEdges[srcNode->id];
                for (auto &outEdge : oldOutEdges) {
                  if (outEdge.value == memref &&
                      worklistSet.count(outEdge.id) == 0) {
                    worklist.push_back(outEdge.id);
                    worklistSet.insert(outEdge.id);
                  }
                }
              }
            }
          }
        }
      }
    }
  }

  // Visits each node in the graph, and for each node, attempts to fuse it with
  // its sibling nodes (nodes which share a parent, but no dependence edges).
  void fuseSiblingNodes() {
    init();
    while (!worklist.empty()) {
      unsigned dstId = worklist.back();
      worklist.pop_back();
      worklistSet.erase(dstId);

      // Skip if this node was removed (fused into another node).
      if (mdg->nodes.count(dstId) == 0)
        continue;
      // Get 'dstNode' into which to attempt fusion.
      auto *dstNode = mdg->getNode(dstId);
      // Skip if 'dstNode' is not a loop nest.
      if (!isa<AffineForOp>(dstNode->op))
        continue;
      // Attempt to fuse 'dstNode' with its sibling nodes in the graph.
      fuseWithSiblingNodes(dstNode);
    }
  }

  // Attempt to fuse 'dstNode' with sibling nodes in the graph.
  void fuseWithSiblingNodes(Node *dstNode) {
    DenseSet<unsigned> visitedSibNodeIds;
    std::pair<unsigned, Value *> idAndMemref;
    while (findSiblingNodeToFuse(dstNode, &visitedSibNodeIds, &idAndMemref)) {
      unsigned sibId = idAndMemref.first;
      Value *memref = idAndMemref.second;
      // TODO(andydavis) Check that 'sibStoreOpInst' post-dominates all other
      // stores to the same memref in 'sibNode' loop nest.
      auto *sibNode = mdg->getNode(sibId);
      // Compute an operation list insertion point for the fused loop
      // nest which preserves dependences.
      assert(sibNode->op->getBlock() == dstNode->op->getBlock());
      Operation *insertPointInst =
          sibNode->op->isBeforeInBlock(dstNode->op)
              ? mdg->getFusedLoopNestInsertionPoint(sibNode->id, dstNode->id)
              : mdg->getFusedLoopNestInsertionPoint(dstNode->id, sibNode->id);
      if (insertPointInst == nullptr)
        continue;

      // Check if fusion would be profitable and at what depth.

      // Get unique 'sibNode' load op to 'memref'.
      SmallVector<Operation *, 2> sibLoadOpInsts;
      sibNode->getLoadOpsForMemref(memref, &sibLoadOpInsts);
      // Currently findSiblingNodeToFuse searches for siblings with one load.
      assert(sibLoadOpInsts.size() == 1);
      Operation *sibLoadOpInst = sibLoadOpInsts[0];
      assert(!sibNode->stores.empty());
      // TODO(andydavis) Choose the store which postdominates all other stores.
      auto *sibStoreOpInst = sibNode->stores.back();

      // Gather 'dstNode' load ops to 'memref'.
      SmallVector<Operation *, 2> dstLoadOpInsts;
      dstNode->getLoadOpsForMemref(memref, &dstLoadOpInsts);

      // Gather 'dstNode' store ops to 'memref'.
      SmallVector<Operation *, 2> dstStoreOpInsts;
      dstNode->getStoreOpsForMemref(memref, &dstStoreOpInsts);

      unsigned bestDstLoopDepth;
      mlir::ComputationSliceState sliceState;

      // Check if fusion would be profitable.
      if (!isFusionProfitable(sibLoadOpInst, sibStoreOpInst, dstLoadOpInsts,
                              dstStoreOpInsts, &sliceState, &bestDstLoopDepth,
                              maximalFusion))
        continue;

      // Fuse computation slice of 'sibLoopNest' into 'dstLoopNest'.
      auto sliceLoopNest = mlir::insertBackwardComputationSlice(
          sibLoadOpInst, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState);
      if (sliceLoopNest != nullptr) {
        auto dstForInst = cast<AffineForOp>(dstNode->op);
        // Update operation position of fused loop nest (if needed).
        if (insertPointInst != dstForInst.getOperation()) {
          dstForInst.getOperation()->moveBefore(insertPointInst);
        }
        // Update data dependence graph state post fusion.
        updateStateAfterSiblingFusion(sliceLoopNest, sibNode, dstNode);
      }
    }
  }

  // Searches function argument uses and the graph from 'dstNode' looking for a
  // fusion candidate sibling node which shares no dependences with 'dstNode'
  // but which loads from the same memref. Returns true and sets
  // 'idAndMemrefToFuse' on success. Returns false otherwise.
  bool findSiblingNodeToFuse(Node *dstNode,
                             DenseSet<unsigned> *visitedSibNodeIds,
                             std::pair<unsigned, Value *> *idAndMemrefToFuse) {
    // Returns true if 'sibNode' can be fused with 'dstNode' for input reuse
    // on 'memref'.
    auto canFuseWithSibNode = [&](Node *sibNode, Value *memref) {
      // Skip if 'sibNode' is not part of a read-after-write dependence.
      // TODO(andydavis) Remove the single load op restriction.
      if (sibNode->getLoadOpCount(memref) != 1)
        return false;
      // Skip if there exists a path of dependent edges between
      // 'sibNode' and 'dstNode'.
      if (mdg->hasDependencePath(sibNode->id, dstNode->id) ||
          mdg->hasDependencePath(dstNode->id, sibNode->id))
        return false;
      // Skip sib node if it loads from (and stores to) the same memref on
      // which it also has an input dependence edge.
      DenseSet<Value *> loadAndStoreMemrefSet;
      sibNode->getLoadAndStoreMemrefSet(&loadAndStoreMemrefSet);
      if (llvm::any_of(loadAndStoreMemrefSet, [=](Value *memref) {
            return mdg->getIncomingMemRefAccesses(sibNode->id, memref) > 0;
          }))
        return false;

      // Check that all stores are to the same memref.
      DenseSet<Value *> storeMemrefs;
      for (auto *storeOpInst : sibNode->stores) {
        storeMemrefs.insert(cast<AffineStoreOp>(storeOpInst).getMemRef());
      }
      if (storeMemrefs.size() != 1)
        return false;
      return true;
    };

    // Search for siblings which load the same memref function argument.
    auto fn = dstNode->op->getParentOfType<FuncOp>();
    for (unsigned i = 0, e = fn.getNumArguments(); i != e; ++i) {
      for (auto *user : fn.getArgument(i)->getUsers()) {
        if (auto loadOp = dyn_cast<AffineLoadOp>(user)) {
          // Gather loops surrounding 'use'.
          SmallVector<AffineForOp, 4> loops;
          getLoopIVs(*user, &loops);
          // Skip 'use' if it is not within a loop nest.
          if (loops.empty())
            continue;
          Node *sibNode = mdg->getForOpNode(loops[0]);
          assert(sibNode != nullptr);
          // Skip 'use' if it is not a sibling to 'dstNode'.
          if (sibNode->id == dstNode->id)
            continue;
          // Skip 'use' if it has been visited.
          if (visitedSibNodeIds->count(sibNode->id) > 0)
            continue;
          // Skip 'use' if it does not load from the same memref as 'dstNode'.
          auto *memref = loadOp.getMemRef();
          if (dstNode->getLoadOpCount(memref) == 0)
            continue;
          // Check if 'sibNode/dstNode' can be input-reuse fused on 'memref'.
          if (canFuseWithSibNode(sibNode, memref)) {
            visitedSibNodeIds->insert(sibNode->id);
            idAndMemrefToFuse->first = sibNode->id;
            idAndMemrefToFuse->second = memref;
            return true;
          }
        }
      }
    }

    // Search for siblings by following edges through an intermediate src node.
    // Collect candidate 'dstNode' input edges in 'inEdges'.
    SmallVector<MemRefDependenceGraph::Edge, 2> inEdges;
    mdg->forEachMemRefInputEdge(
        dstNode->id, [&](MemRefDependenceGraph::Edge inEdge) {
          // Add 'inEdge' if it is a read-after-write dependence.
          if (dstNode->getLoadOpCount(inEdge.value) > 0 &&
              mdg->getNode(inEdge.id)->getStoreOpCount(inEdge.value) > 0)
            inEdges.push_back(inEdge);
        });

    // Search for sibling nodes to fuse by visiting output edges from each
    // input edge in 'inEdges'.
    for (auto &inEdge : inEdges) {
      // Collect candidate output edges from each node 'inEdge.id' in
      // 'inEdges'.
      SmallVector<MemRefDependenceGraph::Edge, 2> outEdges;
      mdg->forEachMemRefOutputEdge(
          inEdge.id, [&](MemRefDependenceGraph::Edge outEdge) {
            unsigned sibNodeId = outEdge.id;
            if (visitedSibNodeIds->count(sibNodeId) > 0)
              return;
            // Skip output edge if not a sibling using the same memref.
            if (outEdge.id == dstNode->id || outEdge.value != inEdge.value)
              return;
            auto *sibNode = mdg->getNode(sibNodeId);
            if (!isa<AffineForOp>(sibNode->op))
              return;
            // Check if 'sibNode/dstNode' can be input-reuse fused on 'memref'.
            if (canFuseWithSibNode(sibNode, outEdge.value)) {
              // Add candidate 'outEdge' to sibling node.
              outEdges.push_back(outEdge);
            }
          });

      // Add first candidate if any were returned.
      if (!outEdges.empty()) {
        visitedSibNodeIds->insert(outEdges[0].id);
        idAndMemrefToFuse->first = outEdges[0].id;
        idAndMemrefToFuse->second = outEdges[0].value;
        return true;
      }
    }
    return false;
  }

  void updateStateAfterSiblingFusion(AffineForOp sliceLoopNest, Node *sibNode,
                                     Node *dstNode) {
    // Update 'sibNode' and 'dstNode' input/output edges to reflect fusion.
    mdg->updateEdges(sibNode->id, dstNode->id);

    // Collect slice loop stats.
    LoopNestStateCollector sliceCollector;
    sliceCollector.collect(sliceLoopNest.getOperation());
    // Promote single iteration slice loops to single IV value.
    for (auto forOp : sliceCollector.forOps) {
      promoteIfSingleIteration(forOp);
    }

    // Collect dst loop stats after memref privatization transformation.
    auto dstForInst = cast<AffineForOp>(dstNode->op);
    LoopNestStateCollector dstLoopCollector;
    dstLoopCollector.collect(dstForInst.getOperation());
    // Clear and add back loads and stores.
    mdg->clearNodeLoadAndStores(dstNode->id);
    mdg->addToNode(dstNode->id, dstLoopCollector.loadOpInsts,
                   dstLoopCollector.storeOpInsts);
    // Remove old sibling loop nest if it no longer has outgoing dependence
    // edges, and it does not write to a memref which escapes the
    // function.
    if (mdg->getOutEdgeCount(sibNode->id) == 0) {
      mdg->removeNode(sibNode->id);
      sibNode->op->erase();
    }
  }

  // Clean up any allocs with no users.
  void eraseUnusedMemRefAllocations() {
    for (auto &pair : mdg->memrefEdgeCount) {
      if (pair.second > 0)
        continue;
      auto *memref = pair.first;
      // Skip if there exist other uses (return operation or function calls).
      if (!memref->use_empty())
        continue;
      // Use list expected to match the dep graph info.
      auto *op = memref->getDefiningOp();
      if (isa_and_nonnull<AllocOp>(op))
        op->erase();
    }
  }
};

} // end anonymous namespace

void LoopFusion::runOnFunction() {
  // Override if a command line argument was provided.
  if (clFusionFastMemorySpace.getNumOccurrences() > 0) {
    fastMemorySpace = clFusionFastMemorySpace.getValue();
  }

  // Override if a command line argument was provided.
  if (clFusionLocalBufThreshold.getNumOccurrences() > 0) {
    localBufSizeThreshold = clFusionLocalBufThreshold * 1024;
  }

  if (clMaximalLoopFusion.getNumOccurrences() > 0)
    maximalFusion = clMaximalLoopFusion;

  MemRefDependenceGraph g;
  if (g.init(getFunction()))
    GreedyFusion(&g, localBufSizeThreshold, fastMemorySpace, maximalFusion)
        .run();
}

static PassRegistration<LoopFusion> pass("affine-loop-fusion",
                                         "Fuse loop nests");
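
The profitability check in isFusionProfitable above reduces to two quantities: the fractional increase in computation caused by re-executing the source slice inside the destination nest (compared against 'computeToleranceThreshold', 0.30 by default, unless maximal fusion is requested), and the percentage reduction in memory footprint when the destination footprint plus the slice write region replaces the two separate footprints. The following is a minimal standalone sketch of that arithmetic with hypothetical numbers; none of these helper names are part of the pass.

// Standalone sketch (not part of the pass): the two quantities the cost
// model trades off, with an illustrative worked example.
#include <cassert>
#include <cstdint>
#include <iostream>

// Fractional increase in computation from fusing:
// fusedCost / (srcCost + dstCost) - 1. Zero means no redundant computation.
static double additionalComputeFraction(uint64_t srcCost, uint64_t dstCost,
                                        uint64_t fusedCost) {
  return static_cast<double>(fusedCost) /
             (static_cast<double>(srcCost) + static_cast<double>(dstCost)) -
         1.0;
}

// Storage reduction in percent: how much smaller the fused footprint
// (dst footprint + slice write region) is than src + dst footprints.
static double storageReductionPercent(uint64_t srcMem, uint64_t dstMem,
                                      uint64_t sliceMem) {
  double fusedMem = static_cast<double>(dstMem + sliceMem);
  return 100.0 * (1.0 - fusedMem / (static_cast<double>(srcMem) + dstMem));
}

int main() {
  // Hypothetical numbers: fusion re-executes 20% of the work, but the slice
  // only needs a 512-byte private buffer instead of a separate 4 KiB memref.
  double extra = additionalComputeFraction(1000, 1000, 2400); // 0.20
  double saved = storageReductionPercent(4096, 4096, 512);    // 43.75
  assert(extra < 0.30 && "within the default compute tolerance");
  std::cout << "redundant compute fraction: " << extra << "\n"
            << "storage reduction: " << saved << "%\n";
  return 0;
}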
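
The sibling (input-reuse) search in findSiblingNodeToFuse accepts a candidate only when it loads the shared memref exactly once, has no dependence path to or from the destination nest in either direction, and stores to a single memref. Below is a simplified standalone model of that predicate over plain integer ids; the types and helpers are hypothetical stand-ins, and the real check additionally rejects siblings that have incoming dependences on memrefs they both load and store.

// Simplified model (not MLIR code) of the canFuseWithSibNode eligibility test.
#include <functional>
#include <set>
#include <vector>

struct SimpleNode {
  int id;
  std::vector<int> loadedMemrefs; // memref ids this loop nest loads from
  std::vector<int> storedMemrefs; // memref ids this loop nest stores to
};

// 'hasPath' is a caller-provided reachability query over dependence edges.
static bool canFuseForInputReuse(const SimpleNode &sib, const SimpleNode &dst,
                                 int memref,
                                 const std::function<bool(int, int)> &hasPath) {
  // 1) 'sib' must load 'memref' exactly once (mirrors the single-load TODO).
  int loadCount = 0;
  for (int m : sib.loadedMemrefs)
    if (m == memref)
      ++loadCount;
  if (loadCount != 1)
    return false;
  // 2) No dependence path in either direction between 'sib' and 'dst'.
  if (hasPath(sib.id, dst.id) || hasPath(dst.id, sib.id))
    return false;
  // 3) All of 'sib's stores must target a single memref.
  std::set<int> storeTargets(sib.storedMemrefs.begin(),
                             sib.storedMemrefs.end());
  return storeTargets.size() == 1;
}

int main() {
  SimpleNode sib{1, /*loadedMemrefs=*/{7}, /*storedMemrefs=*/{9}};
  SimpleNode dst{2, /*loadedMemrefs=*/{7}, /*storedMemrefs=*/{8}};
  auto noPath = [](int, int) { return false; }; // assume no dependence path
  return canFuseForInputReuse(sib, dst, /*memref=*/7, noPath) ? 0 : 1;
}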