github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Transforms/LoopFusion.cpp (about)

     1  //===- LoopFusion.cpp - Code to perform loop fusion -----------------------===//
     2  //
     3  // Copyright 2019 The MLIR Authors.
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //   http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  // =============================================================================
    17  //
    18  // This file implements loop fusion.
    19  //
    20  //===----------------------------------------------------------------------===//
    21  
    22  #include "mlir/Analysis/AffineAnalysis.h"
    23  #include "mlir/Analysis/AffineStructures.h"
    24  #include "mlir/Analysis/LoopAnalysis.h"
    25  #include "mlir/Analysis/Utils.h"
    26  #include "mlir/Dialect/AffineOps/AffineOps.h"
    27  #include "mlir/Dialect/StandardOps/Ops.h"
    28  #include "mlir/IR/AffineExpr.h"
    29  #include "mlir/IR/AffineMap.h"
    30  #include "mlir/IR/Builders.h"
    31  #include "mlir/Pass/Pass.h"
    32  #include "mlir/Transforms/LoopFusionUtils.h"
    33  #include "mlir/Transforms/LoopUtils.h"
    34  #include "mlir/Transforms/Passes.h"
    35  #include "mlir/Transforms/Utils.h"
    36  #include "llvm/ADT/DenseMap.h"
    37  #include "llvm/ADT/DenseSet.h"
    38  #include "llvm/ADT/SetVector.h"
    39  #include "llvm/Support/CommandLine.h"
    40  #include "llvm/Support/Debug.h"
    41  #include "llvm/Support/raw_ostream.h"
    42  #include <iomanip>
    43  #include <sstream>
    44  #define DEBUG_TYPE "affine-loop-fusion"
    45  
    46  using llvm::SetVector;
    47  
    48  using namespace mlir;
    49  
    50  static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options");
    51  
    52  /// Disables fusion profitability check and fuses if valid. Ignore any
    53  /// additional (redundant) computation tolerance threshold
    54  /// that would have prevented fusion.
    55  static llvm::cl::opt<bool>
    56      clMaximalLoopFusion("fusion-maximal",
    57                          llvm::cl::desc("Enables maximal loop fusion"),
    58                          llvm::cl::cat(clOptionsCategory));
    59  
    60  /// A threshold in percent of additional computation allowed when fusing.
    61  static llvm::cl::opt<double> clFusionAddlComputeTolerance(
    62      "fusion-compute-tolerance",
    63      llvm::cl::desc("Fractional increase in additional "
    64                     "computation tolerated while fusing"),
    65      llvm::cl::cat(clOptionsCategory));
    66  
    67  static llvm::cl::opt<unsigned> clFusionFastMemorySpace(
    68      "fusion-fast-mem-space",
    69      llvm::cl::desc("Faster memory space number to promote fusion buffers to"),
    70      llvm::cl::cat(clOptionsCategory));
    71  
    72  // A local buffer of size less than or equal to this size is automatically
    73  // promoted to fast memory after producer-consumer fusion.
    74  static llvm::cl::opt<unsigned long long> clFusionLocalBufThreshold(
    75      "fusion-local-buf-threshold",
    76      llvm::cl::desc("Threshold size (KiB) for promoting local buffers to fast "
    77                     "memory space"),
    78      llvm::cl::cat(clOptionsCategory));
    79  
    80  namespace {
    81  
    82  /// Loop fusion pass. This pass currently supports a greedy fusion policy,
    83  /// which fuses loop nests with single-writer/single-reader memref dependences
    84  /// with the goal of improving locality.
    85  
    86  // TODO(andydavis) Support fusion of source loop nests which write to multiple
    87  // memrefs, where each memref can have multiple users (if profitable).
    88  // TODO(andydavis) Extend this pass to check for fusion preventing dependences,
    89  // and add support for more general loop fusion algorithms.
    90  
    91  struct LoopFusion : public FunctionPass<LoopFusion> {
    92    LoopFusion(unsigned fastMemorySpace = 0, uint64_t localBufSizeThreshold = 0,
    93               bool maximalFusion = false)
    94        : localBufSizeThreshold(localBufSizeThreshold),
    95          fastMemorySpace(fastMemorySpace), maximalFusion(maximalFusion) {}
    96  
    97    void runOnFunction() override;
    98  
    99    // Any local buffers smaller than this size (in bytes) will be created in
   100    // `fastMemorySpace` if provided.
   101    uint64_t localBufSizeThreshold;
   102    Optional<unsigned> fastMemorySpace = None;
   103    // If true, ignore any additional (redundant) computation tolerance threshold
   104    // that would have prevented fusion.
   105    bool maximalFusion;
   106  
   107    // The amount of additional computation that is tolerated while fusing
   108    // pair-wise as a fraction of the total computation.
   109    constexpr static double kComputeToleranceThreshold = 0.30f;
   110  };
   111  
   112  } // end anonymous namespace
   113  
   114  std::unique_ptr<FunctionPassBase>
   115  mlir::createLoopFusionPass(unsigned fastMemorySpace,
   116                             uint64_t localBufSizeThreshold, bool maximalFusion) {
   117    return std::make_unique<LoopFusion>(fastMemorySpace, localBufSizeThreshold,
   118                                        maximalFusion);
   119  }
   120  
   121  namespace {
   122  
   123  // LoopNestStateCollector walks loop nests and collects load and store
   124  // operations, and whether or not an IfInst was encountered in the loop nest.
   125  struct LoopNestStateCollector {
   126    SmallVector<AffineForOp, 4> forOps;
   127    SmallVector<Operation *, 4> loadOpInsts;
   128    SmallVector<Operation *, 4> storeOpInsts;
   129    bool hasNonForRegion = false;
   130  
   131    void collect(Operation *opToWalk) {
   132      opToWalk->walk([&](Operation *op) {
   133        if (isa<AffineForOp>(op))
   134          forOps.push_back(cast<AffineForOp>(op));
   135        else if (op->getNumRegions() != 0)
   136          hasNonForRegion = true;
   137        else if (isa<AffineLoadOp>(op))
   138          loadOpInsts.push_back(op);
   139        else if (isa<AffineStoreOp>(op))
   140          storeOpInsts.push_back(op);
   141      });
   142    }
   143  };
   144  
   145  // TODO(b/117228571) Replace when this is modeled through side-effects/op traits
   146  static bool isMemRefDereferencingOp(Operation &op) {
   147    if (isa<AffineLoadOp>(op) || isa<AffineStoreOp>(op) ||
   148        isa<AffineDmaStartOp>(op) || isa<AffineDmaWaitOp>(op))
   149      return true;
   150    return false;
   151  }
   152  
   153  // MemRefDependenceGraph is a graph data structure where graph nodes are
   154  // top-level operations in a FuncOp which contain load/store ops, and edges
   155  // are memref dependences between the nodes.
   156  // TODO(andydavis) Add a more flexible dependece graph representation.
   157  // TODO(andydavis) Add a depth parameter to dependence graph construction.
   158  struct MemRefDependenceGraph {
   159  public:
   160    // Node represents a node in the graph. A Node is either an entire loop nest
   161    // rooted at the top level which contains loads/stores, or a top level
   162    // load/store.
   163    struct Node {
   164      // The unique identifier of this node in the graph.
   165      unsigned id;
   166      // The top-level statement which is (or contains) a load/store.
   167      Operation *op;
   168      // List of load operations.
   169      SmallVector<Operation *, 4> loads;
   170      // List of store op insts.
   171      SmallVector<Operation *, 4> stores;
   172      Node(unsigned id, Operation *op) : id(id), op(op) {}
   173  
   174      // Returns the load op count for 'memref'.
   175      unsigned getLoadOpCount(Value *memref) {
   176        unsigned loadOpCount = 0;
   177        for (auto *loadOpInst : loads) {
   178          if (memref == cast<AffineLoadOp>(loadOpInst).getMemRef())
   179            ++loadOpCount;
   180        }
   181        return loadOpCount;
   182      }
   183  
   184      // Returns the store op count for 'memref'.
   185      unsigned getStoreOpCount(Value *memref) {
   186        unsigned storeOpCount = 0;
   187        for (auto *storeOpInst : stores) {
   188          if (memref == cast<AffineStoreOp>(storeOpInst).getMemRef())
   189            ++storeOpCount;
   190        }
   191        return storeOpCount;
   192      }
   193  
   194      // Returns all store ops in 'storeOps' which access 'memref'.
   195      void getStoreOpsForMemref(Value *memref,
   196                                SmallVectorImpl<Operation *> *storeOps) {
   197        for (auto *storeOpInst : stores) {
   198          if (memref == cast<AffineStoreOp>(storeOpInst).getMemRef())
   199            storeOps->push_back(storeOpInst);
   200        }
   201      }
   202  
   203      // Returns all load ops in 'loadOps' which access 'memref'.
   204      void getLoadOpsForMemref(Value *memref,
   205                               SmallVectorImpl<Operation *> *loadOps) {
   206        for (auto *loadOpInst : loads) {
   207          if (memref == cast<AffineLoadOp>(loadOpInst).getMemRef())
   208            loadOps->push_back(loadOpInst);
   209        }
   210      }
   211  
   212      // Returns all memrefs in 'loadAndStoreMemrefSet' for which this node
   213      // has at least one load and store operation.
   214      void getLoadAndStoreMemrefSet(DenseSet<Value *> *loadAndStoreMemrefSet) {
   215        llvm::SmallDenseSet<Value *, 2> loadMemrefs;
   216        for (auto *loadOpInst : loads) {
   217          loadMemrefs.insert(cast<AffineLoadOp>(loadOpInst).getMemRef());
   218        }
   219        for (auto *storeOpInst : stores) {
   220          auto *memref = cast<AffineStoreOp>(storeOpInst).getMemRef();
   221          if (loadMemrefs.count(memref) > 0)
   222            loadAndStoreMemrefSet->insert(memref);
   223        }
   224      }
   225    };
   226  
   227    // Edge represents a data dependece between nodes in the graph.
   228    struct Edge {
   229      // The id of the node at the other end of the edge.
   230      // If this edge is stored in Edge = Node.inEdges[i], then
   231      // 'Node.inEdges[i].id' is the identifier of the source node of the edge.
   232      // If this edge is stored in Edge = Node.outEdges[i], then
   233      // 'Node.outEdges[i].id' is the identifier of the dest node of the edge.
   234      unsigned id;
   235      // The SSA value on which this edge represents a dependence.
   236      // If the value is a memref, then the dependence is between graph nodes
   237      // which contain accesses to the same memref 'value'. If the value is a
   238      // non-memref value, then the dependence is between a graph node which
   239      // defines an SSA value and another graph node which uses the SSA value
   240      // (e.g. a constant operation defining a value which is used inside a loop
   241      // nest).
   242      Value *value;
   243    };
   244  
   245    // Map from node id to Node.
   246    DenseMap<unsigned, Node> nodes;
   247    // Map from node id to list of input edges.
   248    DenseMap<unsigned, SmallVector<Edge, 2>> inEdges;
   249    // Map from node id to list of output edges.
   250    DenseMap<unsigned, SmallVector<Edge, 2>> outEdges;
   251    // Map from memref to a count on the dependence edges associated with that
   252    // memref.
   253    DenseMap<Value *, unsigned> memrefEdgeCount;
   254    // The next unique identifier to use for newly created graph nodes.
   255    unsigned nextNodeId = 0;
   256  
   257    MemRefDependenceGraph() {}
   258  
   259    // Initializes the dependence graph based on operations in 'f'.
   260    // Returns true on success, false otherwise.
   261    bool init(FuncOp f);
   262  
   263    // Returns the graph node for 'id'.
   264    Node *getNode(unsigned id) {
   265      auto it = nodes.find(id);
   266      assert(it != nodes.end());
   267      return &it->second;
   268    }
   269  
   270    // Returns the graph node for 'forOp'.
   271    Node *getForOpNode(AffineForOp forOp) {
   272      for (auto &idAndNode : nodes)
   273        if (idAndNode.second.op == forOp.getOperation())
   274          return &idAndNode.second;
   275      return nullptr;
   276    }
   277  
   278    // Adds a node with 'op' to the graph and returns its unique identifier.
   279    unsigned addNode(Operation *op) {
   280      Node node(nextNodeId++, op);
   281      nodes.insert({node.id, node});
   282      return node.id;
   283    }
   284  
   285    // Remove node 'id' (and its associated edges) from graph.
   286    void removeNode(unsigned id) {
   287      // Remove each edge in 'inEdges[id]'.
   288      if (inEdges.count(id) > 0) {
   289        SmallVector<Edge, 2> oldInEdges = inEdges[id];
   290        for (auto &inEdge : oldInEdges) {
   291          removeEdge(inEdge.id, id, inEdge.value);
   292        }
   293      }
   294      // Remove each edge in 'outEdges[id]'.
   295      if (outEdges.count(id) > 0) {
   296        SmallVector<Edge, 2> oldOutEdges = outEdges[id];
   297        for (auto &outEdge : oldOutEdges) {
   298          removeEdge(id, outEdge.id, outEdge.value);
   299        }
   300      }
   301      // Erase remaining node state.
   302      inEdges.erase(id);
   303      outEdges.erase(id);
   304      nodes.erase(id);
   305    }
   306  
   307    // Returns true if node 'id' writes to any memref which escapes (or is an
   308    // argument to) the function/block. Returns false otherwise.
   309    bool writesToLiveInOrEscapingMemrefs(unsigned id) {
   310      Node *node = getNode(id);
   311      for (auto *storeOpInst : node->stores) {
   312        auto *memref = cast<AffineStoreOp>(storeOpInst).getMemRef();
   313        auto *op = memref->getDefiningOp();
   314        // Return true if 'memref' is a block argument.
   315        if (!op)
   316          return true;
   317        // Return true if any use of 'memref' escapes the function.
   318        for (auto *user : memref->getUsers())
   319          if (!isMemRefDereferencingOp(*user))
   320            return true;
   321      }
   322      return false;
   323    }
   324  
   325    // Returns true if node 'id' can be removed from the graph. Returns false
   326    // otherwise. A node can be removed from the graph iff the following
   327    // conditions are met:
   328    // *) The node does not write to any memref which escapes (or is a
   329    //    function/block argument).
   330    // *) The node has no successors in the dependence graph.
   331    bool canRemoveNode(unsigned id) {
   332      if (writesToLiveInOrEscapingMemrefs(id))
   333        return false;
   334      Node *node = getNode(id);
   335      for (auto *storeOpInst : node->stores) {
   336        // Return false if there exist out edges from 'id' on 'memref'.
   337        if (getOutEdgeCount(id, cast<AffineStoreOp>(storeOpInst).getMemRef()) > 0)
   338          return false;
   339      }
   340      return true;
   341    }
   342  
   343    // Returns true iff there is an edge from node 'srcId' to node 'dstId' which
   344    // is for 'value' if non-null, or for any value otherwise. Returns false
   345    // otherwise.
   346    bool hasEdge(unsigned srcId, unsigned dstId, Value *value = nullptr) {
   347      if (outEdges.count(srcId) == 0 || inEdges.count(dstId) == 0) {
   348        return false;
   349      }
   350      bool hasOutEdge = llvm::any_of(outEdges[srcId], [=](Edge &edge) {
   351        return edge.id == dstId && (!value || edge.value == value);
   352      });
   353      bool hasInEdge = llvm::any_of(inEdges[dstId], [=](Edge &edge) {
   354        return edge.id == srcId && (!value || edge.value == value);
   355      });
   356      return hasOutEdge && hasInEdge;
   357    }
   358  
   359    // Adds an edge from node 'srcId' to node 'dstId' for 'value'.
   360    void addEdge(unsigned srcId, unsigned dstId, Value *value) {
   361      if (!hasEdge(srcId, dstId, value)) {
   362        outEdges[srcId].push_back({dstId, value});
   363        inEdges[dstId].push_back({srcId, value});
   364        if (value->getType().isa<MemRefType>())
   365          memrefEdgeCount[value]++;
   366      }
   367    }
   368  
   369    // Removes an edge from node 'srcId' to node 'dstId' for 'value'.
   370    void removeEdge(unsigned srcId, unsigned dstId, Value *value) {
   371      assert(inEdges.count(dstId) > 0);
   372      assert(outEdges.count(srcId) > 0);
   373      if (value->getType().isa<MemRefType>()) {
   374        assert(memrefEdgeCount.count(value) > 0);
   375        memrefEdgeCount[value]--;
   376      }
   377      // Remove 'srcId' from 'inEdges[dstId]'.
   378      for (auto it = inEdges[dstId].begin(); it != inEdges[dstId].end(); ++it) {
   379        if ((*it).id == srcId && (*it).value == value) {
   380          inEdges[dstId].erase(it);
   381          break;
   382        }
   383      }
   384      // Remove 'dstId' from 'outEdges[srcId]'.
   385      for (auto it = outEdges[srcId].begin(); it != outEdges[srcId].end(); ++it) {
   386        if ((*it).id == dstId && (*it).value == value) {
   387          outEdges[srcId].erase(it);
   388          break;
   389        }
   390      }
   391    }
   392  
   393    // Returns true if there is a path in the dependence graph from node 'srcId'
   394    // to node 'dstId'. Returns false otherwise.
   395    bool hasDependencePath(unsigned srcId, unsigned dstId) {
   396      // Worklist state is: <node-id, next-output-edge-index-to-visit>
   397      SmallVector<std::pair<unsigned, unsigned>, 4> worklist;
   398      worklist.push_back({srcId, 0});
   399      // Run DFS traversal to see if 'dstId' is reachable from 'srcId'.
   400      while (!worklist.empty()) {
   401        auto &idAndIndex = worklist.back();
   402        // Return true if we have reached 'dstId'.
   403        if (idAndIndex.first == dstId)
   404          return true;
   405        // Pop and continue if node has no out edges, or if all out edges have
   406        // already been visited.
   407        if (outEdges.count(idAndIndex.first) == 0 ||
   408            idAndIndex.second == outEdges[idAndIndex.first].size()) {
   409          worklist.pop_back();
   410          continue;
   411        }
   412        // Get graph edge to traverse.
   413        Edge edge = outEdges[idAndIndex.first][idAndIndex.second];
   414        // Increment next output edge index for 'idAndIndex'.
   415        ++idAndIndex.second;
   416        // Add node at 'edge.id' to worklist.
   417        worklist.push_back({edge.id, 0});
   418      }
   419      return false;
   420    }
   421  
   422    // Returns the input edge count for node 'id' and 'memref' from src nodes
   423    // which access 'memref' with a store operation.
   424    unsigned getIncomingMemRefAccesses(unsigned id, Value *memref) {
   425      unsigned inEdgeCount = 0;
   426      if (inEdges.count(id) > 0)
   427        for (auto &inEdge : inEdges[id])
   428          if (inEdge.value == memref) {
   429            Node *srcNode = getNode(inEdge.id);
   430            // Only count in edges from 'srcNode' if 'srcNode' accesses 'memref'
   431            if (srcNode->getStoreOpCount(memref) > 0)
   432              ++inEdgeCount;
   433          }
   434      return inEdgeCount;
   435    }
   436  
   437    // Returns the output edge count for node 'id' and 'memref' (if non-null),
   438    // otherwise returns the total output edge count from node 'id'.
   439    unsigned getOutEdgeCount(unsigned id, Value *memref = nullptr) {
   440      unsigned outEdgeCount = 0;
   441      if (outEdges.count(id) > 0)
   442        for (auto &outEdge : outEdges[id])
   443          if (!memref || outEdge.value == memref)
   444            ++outEdgeCount;
   445      return outEdgeCount;
   446    }
   447  
   448    // Computes and returns an insertion point operation, before which the
   449    // the fused <srcId, dstId> loop nest can be inserted while preserving
   450    // dependences. Returns nullptr if no such insertion point is found.
   451    Operation *getFusedLoopNestInsertionPoint(unsigned srcId, unsigned dstId) {
   452      if (outEdges.count(srcId) == 0)
   453        return getNode(dstId)->op;
   454  
   455      // Build set of insts in range (srcId, dstId) which depend on 'srcId'.
   456      SmallPtrSet<Operation *, 2> srcDepInsts;
   457      for (auto &outEdge : outEdges[srcId])
   458        if (outEdge.id != dstId)
   459          srcDepInsts.insert(getNode(outEdge.id)->op);
   460  
   461      // Build set of insts in range (srcId, dstId) on which 'dstId' depends.
   462      SmallPtrSet<Operation *, 2> dstDepInsts;
   463      for (auto &inEdge : inEdges[dstId])
   464        if (inEdge.id != srcId)
   465          dstDepInsts.insert(getNode(inEdge.id)->op);
   466  
   467      Operation *srcNodeInst = getNode(srcId)->op;
   468      Operation *dstNodeInst = getNode(dstId)->op;
   469  
   470      // Computing insertion point:
   471      // *) Walk all operation positions in Block operation list in the
   472      //    range (src, dst). For each operation 'op' visited in this search:
   473      //   *) Store in 'firstSrcDepPos' the first position where 'op' has a
   474      //      dependence edge from 'srcNode'.
   475      //   *) Store in 'lastDstDepPost' the last position where 'op' has a
   476      //      dependence edge to 'dstNode'.
   477      // *) Compare 'firstSrcDepPos' and 'lastDstDepPost' to determine the
   478      //    operation insertion point (or return null pointer if no such
   479      //    insertion point exists: 'firstSrcDepPos' <= 'lastDstDepPos').
   480      SmallVector<Operation *, 2> depInsts;
   481      Optional<unsigned> firstSrcDepPos;
   482      Optional<unsigned> lastDstDepPos;
   483      unsigned pos = 0;
   484      for (Block::iterator it = std::next(Block::iterator(srcNodeInst));
   485           it != Block::iterator(dstNodeInst); ++it) {
   486        Operation *op = &(*it);
   487        if (srcDepInsts.count(op) > 0 && firstSrcDepPos == None)
   488          firstSrcDepPos = pos;
   489        if (dstDepInsts.count(op) > 0)
   490          lastDstDepPos = pos;
   491        depInsts.push_back(op);
   492        ++pos;
   493      }
   494  
   495      if (firstSrcDepPos.hasValue()) {
   496        if (lastDstDepPos.hasValue()) {
   497          if (firstSrcDepPos.getValue() <= lastDstDepPos.getValue()) {
   498            // No valid insertion point exists which preserves dependences.
   499            return nullptr;
   500          }
   501        }
   502        // Return the insertion point at 'firstSrcDepPos'.
   503        return depInsts[firstSrcDepPos.getValue()];
   504      }
   505      // No dependence targets in range (or only dst deps in range), return
   506      // 'dstNodInst' insertion point.
   507      return dstNodeInst;
   508    }
   509  
   510    // Updates edge mappings from node 'srcId' to node 'dstId' after 'oldMemRef'
   511    // has been replaced in node at 'dstId' by a private memref.
   512    void updateEdges(unsigned srcId, unsigned dstId, Value *oldMemRef) {
   513      // For each edge in 'inEdges[srcId]': add new edge remaping to 'dstId'.
   514      if (inEdges.count(srcId) > 0) {
   515        SmallVector<Edge, 2> oldInEdges = inEdges[srcId];
   516        for (auto &inEdge : oldInEdges) {
   517          // Add edge from 'inEdge.id' to 'dstId' if not for 'oldMemRef'.
   518          if (inEdge.value != oldMemRef)
   519            addEdge(inEdge.id, dstId, inEdge.value);
   520        }
   521      }
   522      // For each edge in 'outEdges[srcId]': remove edge from 'srcId' to 'dstId'.
   523      if (outEdges.count(srcId) > 0) {
   524        SmallVector<Edge, 2> oldOutEdges = outEdges[srcId];
   525        for (auto &outEdge : oldOutEdges) {
   526          // Remove any out edges from 'srcId' to 'dstId' across memrefs.
   527          if (outEdge.id == dstId)
   528            removeEdge(srcId, outEdge.id, outEdge.value);
   529        }
   530      }
   531      // Remove any edges in 'inEdges[dstId]' on 'oldMemRef' (which is being
   532      // replaced by a private memref). These edges could come from nodes
   533      // other than 'srcId' which were removed in the previous step.
   534      if (inEdges.count(dstId) > 0) {
   535        SmallVector<Edge, 2> oldInEdges = inEdges[dstId];
   536        for (auto &inEdge : oldInEdges)
   537          if (inEdge.value == oldMemRef)
   538            removeEdge(inEdge.id, dstId, inEdge.value);
   539      }
   540    }
   541  
   542    // Update edge mappings for nodes 'sibId' and 'dstId' to reflect fusion
   543    // of sibling node 'sidId' into node 'dstId'.
   544    void updateEdges(unsigned sibId, unsigned dstId) {
   545      // For each edge in 'inEdges[sibId]':
   546      // *) Add new edge from source node 'inEdge.id' to 'dstNode'.
   547      // *) Remove edge from source node 'inEdge.id' to 'sibNode'.
   548      if (inEdges.count(sibId) > 0) {
   549        SmallVector<Edge, 2> oldInEdges = inEdges[sibId];
   550        for (auto &inEdge : oldInEdges) {
   551          addEdge(inEdge.id, dstId, inEdge.value);
   552          removeEdge(inEdge.id, sibId, inEdge.value);
   553        }
   554      }
   555  
   556      // For each edge in 'outEdges[sibId]' to node 'id'
   557      // *) Add new edge from 'dstId' to 'outEdge.id'.
   558      // *) Remove edge from 'sibId' to 'outEdge.id'.
   559      if (outEdges.count(sibId) > 0) {
   560        SmallVector<Edge, 2> oldOutEdges = outEdges[sibId];
   561        for (auto &outEdge : oldOutEdges) {
   562          addEdge(dstId, outEdge.id, outEdge.value);
   563          removeEdge(sibId, outEdge.id, outEdge.value);
   564        }
   565      }
   566    }
   567  
   568    // Adds ops in 'loads' and 'stores' to node at 'id'.
   569    void addToNode(unsigned id, const SmallVectorImpl<Operation *> &loads,
   570                   const SmallVectorImpl<Operation *> &stores) {
   571      Node *node = getNode(id);
   572      for (auto *loadOpInst : loads)
   573        node->loads.push_back(loadOpInst);
   574      for (auto *storeOpInst : stores)
   575        node->stores.push_back(storeOpInst);
   576    }
   577  
   578    void clearNodeLoadAndStores(unsigned id) {
   579      Node *node = getNode(id);
   580      node->loads.clear();
   581      node->stores.clear();
   582    }
   583  
   584    // Calls 'callback' for each input edge incident to node 'id' which carries a
   585    // memref dependence.
   586    void forEachMemRefInputEdge(unsigned id,
   587                                const std::function<void(Edge)> &callback) {
   588      if (inEdges.count(id) > 0)
   589        forEachMemRefEdge(inEdges[id], callback);
   590    }
   591  
   592    // Calls 'callback' for each output edge from node 'id' which carries a
   593    // memref dependence.
   594    void forEachMemRefOutputEdge(unsigned id,
   595                                 const std::function<void(Edge)> &callback) {
   596      if (outEdges.count(id) > 0)
   597        forEachMemRefEdge(outEdges[id], callback);
   598    }
   599  
   600    // Calls 'callback' for each edge in 'edges' which carries a memref
   601    // dependence.
   602    void forEachMemRefEdge(ArrayRef<Edge> edges,
   603                           const std::function<void(Edge)> &callback) {
   604      for (auto &edge : edges) {
   605        // Skip if 'edge' is not a memref dependence edge.
   606        if (!edge.value->getType().isa<MemRefType>())
   607          continue;
   608        assert(nodes.count(edge.id) > 0);
   609        // Skip if 'edge.id' is not a loop nest.
   610        if (!isa<AffineForOp>(getNode(edge.id)->op))
   611          continue;
   612        // Visit current input edge 'edge'.
   613        callback(edge);
   614      }
   615    }
   616  
   617    void print(raw_ostream &os) const {
   618      os << "\nMemRefDependenceGraph\n";
   619      os << "\nNodes:\n";
   620      for (auto &idAndNode : nodes) {
   621        os << "Node: " << idAndNode.first << "\n";
   622        auto it = inEdges.find(idAndNode.first);
   623        if (it != inEdges.end()) {
   624          for (const auto &e : it->second)
   625            os << "  InEdge: " << e.id << " " << e.value << "\n";
   626        }
   627        it = outEdges.find(idAndNode.first);
   628        if (it != outEdges.end()) {
   629          for (const auto &e : it->second)
   630            os << "  OutEdge: " << e.id << " " << e.value << "\n";
   631        }
   632      }
   633    }
   634    void dump() const { print(llvm::errs()); }
   635  };
   636  
// Initializes the data dependence graph by walking operations in 'f'.
   638  // Assigns each node in the graph a node id based on program order in 'f'.
   639  // TODO(andydavis) Add support for taking a Block arg to construct the
   640  // dependence graph at a different depth.
   641  bool MemRefDependenceGraph::init(FuncOp f) {
   642    DenseMap<Value *, SetVector<unsigned>> memrefAccesses;
   643  
   644    // TODO: support multi-block functions.
   645    if (f.getBlocks().size() != 1)
   646      return false;
   647  
   648    DenseMap<Operation *, unsigned> forToNodeMap;
   649    for (auto &op : f.front()) {
   650      if (auto forOp = dyn_cast<AffineForOp>(op)) {
   651        // Create graph node 'id' to represent top-level 'forOp' and record
   652        // all loads and store accesses it contains.
   653        LoopNestStateCollector collector;
   654        collector.collect(&op);
   655        // Return false if a non 'affine.for' region was found (not currently
   656        // supported).
   657        if (collector.hasNonForRegion)
   658          return false;
   659        Node node(nextNodeId++, &op);
   660        for (auto *opInst : collector.loadOpInsts) {
   661          node.loads.push_back(opInst);
   662          auto *memref = cast<AffineLoadOp>(opInst).getMemRef();
   663          memrefAccesses[memref].insert(node.id);
   664        }
   665        for (auto *opInst : collector.storeOpInsts) {
   666          node.stores.push_back(opInst);
   667          auto *memref = cast<AffineStoreOp>(opInst).getMemRef();
   668          memrefAccesses[memref].insert(node.id);
   669        }
   670        forToNodeMap[&op] = node.id;
   671        nodes.insert({node.id, node});
   672      } else if (auto loadOp = dyn_cast<AffineLoadOp>(op)) {
   673        // Create graph node for top-level load op.
   674        Node node(nextNodeId++, &op);
   675        node.loads.push_back(&op);
   676        auto *memref = cast<AffineLoadOp>(op).getMemRef();
   677        memrefAccesses[memref].insert(node.id);
   678        nodes.insert({node.id, node});
   679      } else if (auto storeOp = dyn_cast<AffineStoreOp>(op)) {
   680        // Create graph node for top-level store op.
   681        Node node(nextNodeId++, &op);
   682        node.stores.push_back(&op);
   683        auto *memref = cast<AffineStoreOp>(op).getMemRef();
   684        memrefAccesses[memref].insert(node.id);
   685        nodes.insert({node.id, node});
   686      } else if (op.getNumRegions() != 0) {
   687        // Return false if another region is found (not currently supported).
   688        return false;
   689      } else if (op.getNumResults() > 0 && !op.use_empty()) {
   690        // Create graph node for top-level producer of SSA values, which
   691        // could be used by loop nest nodes.
   692        Node node(nextNodeId++, &op);
   693        nodes.insert({node.id, node});
   694      }
   695    }
   696  
   697    // Add dependence edges between nodes which produce SSA values and their
   698    // users.
   699    for (auto &idAndNode : nodes) {
   700      const Node &node = idAndNode.second;
   701      if (!node.loads.empty() || !node.stores.empty())
   702        continue;
   703      auto *opInst = node.op;
   704      for (auto *value : opInst->getResults()) {
   705        for (auto *user : value->getUsers()) {
   706          SmallVector<AffineForOp, 4> loops;
   707          getLoopIVs(*user, &loops);
   708          if (loops.empty())
   709            continue;
   710          assert(forToNodeMap.count(loops[0].getOperation()) > 0);
   711          unsigned userLoopNestId = forToNodeMap[loops[0].getOperation()];
   712          addEdge(node.id, userLoopNestId, value);
   713        }
   714      }
   715    }
   716  
   717    // Walk memref access lists and add graph edges between dependent nodes.
   718    for (auto &memrefAndList : memrefAccesses) {
   719      unsigned n = memrefAndList.second.size();
   720      for (unsigned i = 0; i < n; ++i) {
   721        unsigned srcId = memrefAndList.second[i];
   722        bool srcHasStore =
   723            getNode(srcId)->getStoreOpCount(memrefAndList.first) > 0;
   724        for (unsigned j = i + 1; j < n; ++j) {
   725          unsigned dstId = memrefAndList.second[j];
   726          bool dstHasStore =
   727              getNode(dstId)->getStoreOpCount(memrefAndList.first) > 0;
   728          if (srcHasStore || dstHasStore)
   729            addEdge(srcId, dstId, memrefAndList.first);
   730        }
   731      }
   732    }
   733    return true;
   734  }
   735  
   736  // Removes load operations from 'srcLoads' which operate on 'memref', and
   737  // adds them to 'dstLoads'.
   738  static void moveLoadsAccessingMemrefTo(Value *memref,
   739                                         SmallVectorImpl<Operation *> *srcLoads,
   740                                         SmallVectorImpl<Operation *> *dstLoads) {
   741    dstLoads->clear();
   742    SmallVector<Operation *, 4> srcLoadsToKeep;
   743    for (auto *load : *srcLoads) {
   744      if (cast<AffineLoadOp>(load).getMemRef() == memref)
   745        dstLoads->push_back(load);
   746      else
   747        srcLoadsToKeep.push_back(load);
   748    }
   749    srcLoads->swap(srcLoadsToKeep);
   750  }
   751  
   752  // Returns the innermost common loop depth for the set of operations in 'ops'.
   753  static unsigned getInnermostCommonLoopDepth(ArrayRef<Operation *> ops) {
   754    unsigned numOps = ops.size();
   755    assert(numOps > 0);
   756  
   757    std::vector<SmallVector<AffineForOp, 4>> loops(numOps);
   758    unsigned loopDepthLimit = std::numeric_limits<unsigned>::max();
   759    for (unsigned i = 0; i < numOps; ++i) {
   760      getLoopIVs(*ops[i], &loops[i]);
   761      loopDepthLimit =
   762          std::min(loopDepthLimit, static_cast<unsigned>(loops[i].size()));
   763    }
   764  
   765    unsigned loopDepth = 0;
   766    for (unsigned d = 0; d < loopDepthLimit; ++d) {
   767      unsigned i;
   768      for (i = 1; i < numOps; ++i) {
   769        if (loops[i - 1][d] != loops[i][d])
   770          break;
   771      }
   772      if (i != numOps)
   773        break;
   774      ++loopDepth;
   775    }
   776    return loopDepth;
   777  }
   778  
   779  // Returns the maximum loop depth at which no dependences between 'loadOpInsts'
   780  // and 'storeOpInsts' are satisfied.
   781  static unsigned getMaxLoopDepth(ArrayRef<Operation *> loadOpInsts,
   782                                  ArrayRef<Operation *> storeOpInsts) {
   783    // Merge loads and stores into the same array.
   784    SmallVector<Operation *, 2> ops(loadOpInsts.begin(), loadOpInsts.end());
   785    ops.append(storeOpInsts.begin(), storeOpInsts.end());
   786  
   787    // Compute the innermost common loop depth for loads and stores.
   788    unsigned loopDepth = getInnermostCommonLoopDepth(ops);
   789  
   790    // Return common loop depth for loads if there are no store ops.
   791    if (storeOpInsts.empty())
   792      return loopDepth;
   793  
   794    // Check dependences on all pairs of ops in 'ops' and store the minimum
   795    // loop depth at which a dependence is satisfied.
   796    for (unsigned i = 0, e = ops.size(); i < e; ++i) {
   797      auto *srcOpInst = ops[i];
   798      MemRefAccess srcAccess(srcOpInst);
   799      for (unsigned j = 0; j < e; ++j) {
   800        auto *dstOpInst = ops[j];
   801        MemRefAccess dstAccess(dstOpInst);
   802  
   803        unsigned numCommonLoops =
   804            getNumCommonSurroundingLoops(*srcOpInst, *dstOpInst);
   805        for (unsigned d = 1; d <= numCommonLoops + 1; ++d) {
   806          FlatAffineConstraints dependenceConstraints;
   807          // TODO(andydavis) Cache dependence analysis results, check cache here.
   808          DependenceResult result = checkMemrefAccessDependence(
   809              srcAccess, dstAccess, d, &dependenceConstraints,
   810              /*dependenceComponents=*/nullptr);
   811          if (hasDependence(result)) {
   812            // Store minimum loop depth and break because we want the min 'd' at
   813            // which there is a dependence.
   814            loopDepth = std::min(loopDepth, d - 1);
   815            break;
   816          }
   817        }
   818      }
   819    }
   820    return loopDepth;
   821  }
   822  
   823  // Sinks all sequential loops to the innermost levels (while preserving
   824  // relative order among them) and moves all parallel loops to the
   825  // outermost (while again preserving relative order among them).
   826  // This can increase the loop depth at which we can fuse a slice, since we are
   827  // pushing loop carried dependence to a greater depth in the loop nest.
   828  static void sinkSequentialLoops(MemRefDependenceGraph::Node *node) {
   829    assert(isa<AffineForOp>(node->op));
   830    AffineForOp newRootForOp = sinkSequentialLoops(cast<AffineForOp>(node->op));
   831    node->op = newRootForOp.getOperation();
   832  }
   833  
   834  //  TODO(mlir-team): improve/complete this when we have target data.
   835  unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
   836    auto elementType = memRefType.getElementType();
   837  
   838    unsigned sizeInBits;
   839    if (elementType.isIntOrFloat()) {
   840      sizeInBits = elementType.getIntOrFloatBitWidth();
   841    } else {
   842      auto vectorType = elementType.cast<VectorType>();
   843      sizeInBits =
   844          vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
   845    }
   846    return llvm::divideCeil(sizeInBits, 8);
   847  }
   848  
   849  // Creates and returns a private (single-user) memref for fused loop rooted
   850  // at 'forOp', with (potentially reduced) memref size based on the
   851  // MemRefRegion written to by 'srcStoreOpInst' at depth 'dstLoopDepth'.
   852  // TODO(bondhugula): consider refactoring the common code from generateDma and
   853  // this one.
   854  static Value *createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
   855                                    unsigned dstLoopDepth,
   856                                    Optional<unsigned> fastMemorySpace,
   857                                    uint64_t localBufSizeThreshold) {
   858    auto *forInst = forOp.getOperation();
   859  
   860    // Create builder to insert alloc op just before 'forOp'.
   861    OpBuilder b(forInst);
   862    // Builder to create constants at the top level.
   863    OpBuilder top(forInst->getParentOfType<FuncOp>().getBody());
   864    // Create new memref type based on slice bounds.
   865    auto *oldMemRef = cast<AffineStoreOp>(srcStoreOpInst).getMemRef();
   866    auto oldMemRefType = oldMemRef->getType().cast<MemRefType>();
   867    unsigned rank = oldMemRefType.getRank();
   868  
   869    // Compute MemRefRegion for 'srcStoreOpInst' at depth 'dstLoopDepth'.
   870    MemRefRegion region(srcStoreOpInst->getLoc());
   871    bool validRegion = succeeded(region.compute(srcStoreOpInst, dstLoopDepth));
   872    (void)validRegion;
   873    assert(validRegion && "unexpected memref region failure");
   874    SmallVector<int64_t, 4> newShape;
   875    std::vector<SmallVector<int64_t, 4>> lbs;
   876    SmallVector<int64_t, 8> lbDivisors;
   877    lbs.reserve(rank);
   878    // Query 'region' for 'newShape' and lower bounds of MemRefRegion accessed
   879    // by 'srcStoreOpInst' at depth 'dstLoopDepth'.
   880    Optional<int64_t> numElements =
   881        region.getConstantBoundingSizeAndShape(&newShape, &lbs, &lbDivisors);
   882    assert(numElements.hasValue() &&
   883           "non-constant number of elts in local buffer");
   884  
   885    const FlatAffineConstraints *cst = region.getConstraints();
   886    // 'outerIVs' holds the values that this memory region is symbolic/paramteric
   887    // on; this would correspond to loop IVs surrounding the level at which the
   888    // slice is being materialized.
   889    SmallVector<Value *, 8> outerIVs;
   890    cst->getIdValues(rank, cst->getNumIds(), &outerIVs);
   891  
   892    // Build 'rank' AffineExprs from MemRefRegion 'lbs'
   893    SmallVector<AffineExpr, 4> offsets;
   894    offsets.reserve(rank);
   895    for (unsigned d = 0; d < rank; ++d) {
   896      assert(lbs[d].size() == cst->getNumCols() - rank && "incorrect bound size");
   897  
   898      AffineExpr offset = top.getAffineConstantExpr(0);
   899      for (unsigned j = 0, e = cst->getNumCols() - rank - 1; j < e; j++) {
   900        offset = offset + lbs[d][j] * top.getAffineDimExpr(j);
   901      }
   902      assert(lbDivisors[d] > 0);
   903      offset =
   904          (offset + lbs[d][cst->getNumCols() - 1 - rank]).floorDiv(lbDivisors[d]);
   905      offsets.push_back(offset);
   906    }
   907  
   908    // Create 'newMemRefType' using 'newShape' from MemRefRegion accessed
   909    // by 'srcStoreOpInst'.
   910    uint64_t bufSize =
   911        getMemRefEltSizeInBytes(oldMemRefType) * numElements.getValue();
   912    unsigned newMemSpace;
   913    if (bufSize <= localBufSizeThreshold && fastMemorySpace.hasValue()) {
   914      newMemSpace = fastMemorySpace.getValue();
   915    } else {
   916      newMemSpace = oldMemRefType.getMemorySpace();
   917    }
   918    auto newMemRefType = top.getMemRefType(
   919        newShape, oldMemRefType.getElementType(), {}, newMemSpace);
   920    // Gather alloc operands for the dynamic dimensions of the memref.
   921    SmallVector<Value *, 4> allocOperands;
   922    unsigned dynamicDimCount = 0;
   923    for (auto dimSize : oldMemRefType.getShape()) {
   924      if (dimSize == -1)
   925        allocOperands.push_back(
   926            top.create<DimOp>(forOp.getLoc(), oldMemRef, dynamicDimCount++));
   927    }
   928  
   929    // Create new private memref for fused loop 'forOp'.
   930    // TODO(andydavis) Create/move alloc ops for private memrefs closer to their
   931    // consumer loop nests to reduce their live range. Currently they are added
   932    // at the beginning of the function, because loop nests can be reordered
   933    // during the fusion pass.
   934    Value *newMemRef =
   935        top.create<AllocOp>(forOp.getLoc(), newMemRefType, allocOperands);
   936  
   937    // Build an AffineMap to remap access functions based on lower bound offsets.
   938    SmallVector<AffineExpr, 4> remapExprs;
   939    remapExprs.reserve(rank);
   940    unsigned zeroOffsetCount = 0;
   941    for (unsigned i = 0; i < rank; i++) {
   942      if (auto constExpr = offsets[i].dyn_cast<AffineConstantExpr>())
   943        if (constExpr.getValue() == 0)
   944          ++zeroOffsetCount;
   945      auto dimExpr = b.getAffineDimExpr(outerIVs.size() + i);
   946  
   947      auto remapExpr =
   948          simplifyAffineExpr(dimExpr - offsets[i], outerIVs.size() + rank, 0);
   949      remapExprs.push_back(remapExpr);
   950    }
   951    auto indexRemap = zeroOffsetCount == rank
   952                          ? AffineMap()
   953                          : b.getAffineMap(outerIVs.size() + rank, 0, remapExprs);
   954    // Replace all users of 'oldMemRef' with 'newMemRef'.
   955    LogicalResult res =
   956        replaceAllMemRefUsesWith(oldMemRef, newMemRef, {}, indexRemap,
   957                                 /*extraOperands=*/outerIVs,
   958                                 /*domInstFilter=*/&*forOp.getBody()->begin());
   959    assert(succeeded(res) &&
   960           "replaceAllMemrefUsesWith should always succeed here");
   961    (void)res;
   962    return newMemRef;
   963  }
   964  
   965  // Checks if node 'srcId' (which writes to a live out memref), can be safely
   966  // fused into node 'dstId'. Returns true if the following conditions are met:
   967  // *) 'srcNode' only writes to live out 'memref'.
   968  // *) 'srcNode' has exactly one output edge on 'memref' (which is to 'dstId').
   969  // *) 'dstNode's read/write region to 'memref' is a super set of 'srcNode's
   970  //    write region to 'memref'.
   971  // TODO(andydavis) Generalize this to handle more live in/out cases.
   972  static bool canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId,
   973                                             Value *memref,
   974                                             MemRefDependenceGraph *mdg) {
   975    auto *srcNode = mdg->getNode(srcId);
   976    auto *dstNode = mdg->getNode(dstId);
   977  
   978    // Gather all memrefs from 'srcNode' store ops.
   979    DenseSet<Value *> storeMemrefs;
   980    for (auto *storeOpInst : srcNode->stores) {
   981      storeMemrefs.insert(cast<AffineStoreOp>(storeOpInst).getMemRef());
   982    }
   983    // Return false if any of the following are true:
   984    // *) 'srcNode' writes to a live in/out memref other than 'memref'.
   985    // *) 'srcNode' has more than one output edge on 'memref'.
   986    // Check that all stores are to the same memref.
   987    if (storeMemrefs.size() != 1 ||
   988        mdg->getOutEdgeCount(srcNode->id, memref) != 1)
   989      return false;
   990    // Compute MemRefRegion 'srcWriteRegion' for 'srcStoreOpInst' on 'memref'.
   991    auto *srcStoreOpInst = srcNode->stores.front();
   992    MemRefRegion srcWriteRegion(srcStoreOpInst->getLoc());
   993    if (failed(srcWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0))) {
   994      LLVM_DEBUG(llvm::dbgs()
   995                 << "Unable to compute MemRefRegion for source operation\n.");
   996      return false;
   997    }
   998    SmallVector<int64_t, 4> srcShape;
   999    // Query 'srcWriteRegion' for 'srcShape' and 'srcNumElements'.
  1000    // by 'srcStoreOpInst' at depth 'dstLoopDepth'.
  1001    Optional<int64_t> srcNumElements =
  1002        srcWriteRegion.getConstantBoundingSizeAndShape(&srcShape);
  1003    if (!srcNumElements.hasValue())
  1004      return false;
  1005  
  1006    // Compute MemRefRegion 'dstRegion' for 'dstStore/LoadOpInst' on 'memref'.
  1007    // TODO(andydavis) Compute 'unionboundingbox' of all write regions (one for
  1008    // each store op in 'dstStoreOps').
  1009    SmallVector<Operation *, 2> dstStoreOps;
  1010    dstNode->getStoreOpsForMemref(memref, &dstStoreOps);
  1011    SmallVector<Operation *, 2> dstLoadOps;
  1012    dstNode->getLoadOpsForMemref(memref, &dstLoadOps);
  1013  
  1014    auto *dstOpInst = dstStoreOps.empty() ? dstLoadOps[0] : dstStoreOps[0];
  1015    MemRefRegion dstRegion(dstOpInst->getLoc());
  1016    if (failed(dstRegion.compute(dstOpInst, /*loopDepth=*/0))) {
  1017      LLVM_DEBUG(llvm::dbgs()
  1018                 << "Unable to compute MemRefRegion for dest operation\n.");
  1019      return false;
  1020    }
  1021    SmallVector<int64_t, 4> dstShape;
  1022    // Query 'dstRegion' for 'dstShape' and 'dstNumElements'.
  1023    // by 'dstOpInst' at depth 'dstLoopDepth'.
  1024    Optional<int64_t> dstNumElements =
  1025        dstRegion.getConstantBoundingSizeAndShape(&dstShape);
  1026    if (!dstNumElements.hasValue())
  1027      return false;
  1028  
  1029    // Return false if write region is not a superset of 'srcNodes' write
  1030    // region to 'memref'.
  1031    // TODO(andydavis) Check the shape and lower bounds here too.
  1032    if (srcNumElements != dstNumElements)
  1033      return false;
  1034    return true;
  1035  }
  1036  
  1037  // Checks the profitability of fusing a backwards slice of the loop nest
  1038  // surrounding 'srcOpInst' into the loop nest surrounding 'dstLoadOpInsts'.
  1039  // The argument 'srcStoreOpInst' is used to calculate the storage reduction on
  1040  // the memref being produced and consumed, which is an input to the cost model.
  1041  // For producer-constumer fusion, 'srcStoreOpInst' will be the same as
  1042  // 'srcOpInst', as we are slicing w.r.t to that producer.
  1043  // For input-reuse fusion, 'srcOpInst' will be the src loop nest LoadOp which
  1044  // reads from the same memref as dst loop nest load ops, and 'srcStoreOpInst'
  1045  // will be the unique store op in the src node, which will be used to check
  1046  // that the write region is the same after input-reuse fusion.
  1047  // Returns true if it is profitable to fuse the candidate loop nests. Returns
  1048  // false otherwise. `dstLoopDepth` is set to the most profitable depth at which
  1049  // to materialize the source loop nest slice.
  1050  // The profitability model executes the following steps:
  1051  // *) Computes the backward computation slice at 'srcOpInst'. This
  1052  //    computation slice of the loop nest surrounding 'srcOpInst' is
  1053  //    represented by modified src loop bounds in 'sliceState', which are
  1054  //    functions of loop IVs in the loop nest surrounding 'srcOpInst'.
  1055  // *) Computes the cost of unfused src/dst loop nests (currently the cost of a
  1056  //    loop nest is the total number of dynamic operation instances in the loop
  1057  //    nest).
  1058  // *) Computes the cost of fusing a slice of the src loop nest into the dst
  1059  //    loop nest at various values of dst loop depth, attempting to fuse
  1060  //    the largest compution slice at the maximal dst loop depth (closest to the
  1061  //    load) to minimize reuse distance and potentially enable subsequent
  1062  //    load/store forwarding.
  1063  //    NOTE: If the dst loop nest includes multiple loads in 'dstLoadOpInsts' for
  1064  //    the same memref as is written by 'srcOpInst', then the union of slice
  1065  //    loop bounds is used to compute the slice and associated slice cost.
  1066  //    NOTE: 'dstLoopDepth' refers to the loop depth within the destination loop
  1067  //    nest, at which the src computation slice is inserted/fused.
  1068  //    NOTE: We attempt to maximize the dst loop depth, but there are cases
  1069  //    where a particular setting for 'dstLoopNest' might fuse an unsliced
  1070  //    loop (within the src computation slice) at a depth which results in
  1071  //    execessive recomputation (see unit tests for examples).
  1072  // *) Compares the total cost of the unfused loop nests to the min cost fused
  1073  //    loop nest computed in the previous step, and returns true if the latter
  1074  //    is lower.
  1075  static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
  1076                                 ArrayRef<Operation *> dstLoadOpInsts,
  1077                                 ArrayRef<Operation *> dstStoreOpInsts,
  1078                                 ComputationSliceState *sliceState,
  1079                                 unsigned *dstLoopDepth, bool maximalFusion) {
  1080    LLVM_DEBUG({
  1081      llvm::dbgs() << "Checking whether fusion is profitable between:\n";
  1082      llvm::dbgs() << " " << *srcOpInst << " and \n";
  1083      for (auto dstOpInst : dstLoadOpInsts) {
  1084        llvm::dbgs() << " " << *dstOpInst << "\n";
  1085      };
  1086    });
  1087  
  1088    // Compute cost of sliced and unsliced src loop nest.
  1089    SmallVector<AffineForOp, 4> srcLoopIVs;
  1090    getLoopIVs(*srcOpInst, &srcLoopIVs);
  1091    unsigned numSrcLoopIVs = srcLoopIVs.size();
  1092  
  1093    // Walk src loop nest and collect stats.
  1094    LoopNestStats srcLoopNestStats;
  1095    if (!getLoopNestStats(srcLoopIVs[0], &srcLoopNestStats))
  1096      return false;
  1097  
  1098    // Compute cost of dst loop nest.
  1099    SmallVector<AffineForOp, 4> dstLoopIVs;
  1100    getLoopIVs(*dstLoadOpInsts[0], &dstLoopIVs);
  1101  
  1102    LoopNestStats dstLoopNestStats;
  1103    if (!getLoopNestStats(dstLoopIVs[0], &dstLoopNestStats))
  1104      return false;
  1105  
  1106    // Compute the maximum loop depth at which we can can insert the src slice
  1107    // and still satisfy dest loop nest dependences, for producer-consumer fusion.
  1108    unsigned maxDstLoopDepth =
  1109        (srcOpInst == srcStoreOpInst)
  1110            ? getMaxLoopDepth(dstLoadOpInsts, dstStoreOpInsts)
  1111            : dstLoopIVs.size();
  1112    if (maxDstLoopDepth == 0) {
  1113      LLVM_DEBUG(llvm::dbgs() << "Can't fuse: maxDstLoopDepth == 0 .\n");
  1114      return false;
  1115    }
  1116  
  1117    // Search for min cost value for 'dstLoopDepth'. At each value of
  1118    // 'dstLoopDepth' from 'maxDstLoopDepth' to '1', compute computation slice
  1119    // bounds between 'srcOpInst' and each op in 'dstOpinsts' (taking the union
  1120    // of these bounds). Next the union slice bounds are used to calculate
  1121    // the cost of the slice and the cost of the slice inserted into the dst
  1122    // loop nest at 'dstLoopDepth'.
  1123    uint64_t minFusedLoopNestComputeCost = std::numeric_limits<uint64_t>::max();
  1124    double maxStorageReduction = 0.0;
  1125    Optional<uint64_t> sliceMemEstimate = None;
  1126  
  1127    SmallVector<ComputationSliceState, 4> sliceStates;
  1128    sliceStates.resize(maxDstLoopDepth);
  1129    // The best loop depth at which to materialize the slice.
  1130    Optional<unsigned> bestDstLoopDepth = None;
  1131  
  1132    // Compute op instance count for the src loop nest without iteration slicing.
  1133    uint64_t srcLoopNestCost = getComputeCost(srcLoopIVs[0], srcLoopNestStats);
  1134  
  1135    // Compute src loop nest write region size.
  1136    MemRefRegion srcWriteRegion(srcStoreOpInst->getLoc());
  1137    if (failed(srcWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0))) {
  1138      LLVM_DEBUG(llvm::dbgs()
  1139                 << "Unable to compute MemRefRegion for source operation\n.");
  1140      return false;
  1141    }
  1142  
  1143    Optional<int64_t> maybeSrcWriteRegionSizeBytes =
  1144        srcWriteRegion.getRegionSize();
  1145    if (!maybeSrcWriteRegionSizeBytes.hasValue())
  1146      return false;
  1147    int64_t srcWriteRegionSizeBytes = maybeSrcWriteRegionSizeBytes.getValue();
  1148  
  1149    // Compute op instance count for the src loop nest.
  1150    uint64_t dstLoopNestCost = getComputeCost(dstLoopIVs[0], dstLoopNestStats);
  1151  
  1152    // Evaluate all depth choices for materializing the slice in the destination
  1153    // loop nest.
  1154    for (unsigned i = maxDstLoopDepth; i >= 1; --i) {
  1155      // Compute the union of slice bounds of all ops in 'dstLoadOpInsts'.
  1156      if (failed(mlir::computeSliceUnion({srcOpInst}, dstLoadOpInsts,
  1157                                         /*loopDepth=*/i,
  1158                                         /*numCommonLoops=*/0,
  1159                                         /*isBackwardSlice=*/true,
  1160                                         &sliceStates[i - 1]))) {
  1161        LLVM_DEBUG(llvm::dbgs()
  1162                   << "computeSliceUnion failed for loopDepth: " << i << "\n");
  1163        continue;
  1164      }
  1165  
  1166      int64_t fusedLoopNestComputeCost;
  1167      if (!getFusionComputeCost(srcLoopIVs[0], srcLoopNestStats, dstLoopIVs[0],
  1168                                dstLoopNestStats, &sliceStates[i - 1],
  1169                                &fusedLoopNestComputeCost)) {
  1170        LLVM_DEBUG(llvm::dbgs() << "Unable to compute fusion compute cost.\n.");
  1171        continue;
  1172      }
  1173  
  1174      double additionalComputeFraction =
  1175          fusedLoopNestComputeCost /
  1176              (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
  1177          1;
  1178  
  1179      // Determine what the slice write MemRefRegion would be, if the src loop
  1180      // nest slice 'sliceStates[i - 1]' were to be inserted into the dst loop
  1181      // nest at loop depth 'i'
  1182      MemRefRegion sliceWriteRegion(srcStoreOpInst->getLoc());
  1183      if (failed(sliceWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0,
  1184                                          &sliceStates[i - 1]))) {
  1185        LLVM_DEBUG(llvm::dbgs()
  1186                   << "Failed to compute slice write region at loopDepth: " << i
  1187                   << "\n");
  1188        continue;
  1189      }
  1190  
  1191      Optional<int64_t> maybeSliceWriteRegionSizeBytes =
  1192          sliceWriteRegion.getRegionSize();
  1193      if (!maybeSliceWriteRegionSizeBytes.hasValue() ||
  1194          maybeSliceWriteRegionSizeBytes.getValue() == 0) {
  1195        LLVM_DEBUG(llvm::dbgs()
  1196                   << "Failed to get slice write region size at loopDepth: " << i
  1197                   << "\n");
  1198        continue;
  1199      }
  1200      int64_t sliceWriteRegionSizeBytes =
  1201          maybeSliceWriteRegionSizeBytes.getValue();
  1202  
  1203      // If we are fusing for reuse, check that write regions remain the same.
  1204      // TODO(andydavis) Write region check should check sizes and offsets in
  1205      // each dimension, so that we are sure they are covering the same memref
  1206      // region. Also, move this out to a isMemRefRegionSuperSet helper function.
  1207      if (srcOpInst != srcStoreOpInst &&
  1208          sliceWriteRegionSizeBytes != srcWriteRegionSizeBytes)
  1209        continue;
  1210  
  1211      double storageReduction = static_cast<double>(srcWriteRegionSizeBytes) /
  1212                                static_cast<double>(sliceWriteRegionSizeBytes);
  1213  
  1214      LLVM_DEBUG({
  1215        std::stringstream msg;
  1216        msg << "  evaluating fusion profitability at depth : " << i << "\n"
  1217            << std::fixed << std::setprecision(2)
  1218            << "   additional compute fraction: "
  1219            << 100.0 * additionalComputeFraction << "%\n"
  1220            << "   storage reduction factor: " << storageReduction << "x\n"
  1221            << "   fused nest cost: " << fusedLoopNestComputeCost << "\n"
  1222            << "   src write region size: " << srcWriteRegionSizeBytes << "\n"
  1223            << "   slice write region size: " << sliceWriteRegionSizeBytes
  1224            << "\n";
  1225        llvm::dbgs() << msg.str();
  1226      });
  1227  
  1228      double computeToleranceThreshold =
  1229          clFusionAddlComputeTolerance.getNumOccurrences() > 0
  1230              ? clFusionAddlComputeTolerance
  1231              : LoopFusion::kComputeToleranceThreshold;
  1232  
  1233      // TODO(b/123247369): This is a placeholder cost model.
  1234      // Among all choices that add an acceptable amount of redundant computation
  1235      // (as per computeToleranceThreshold), we will simply pick the one that
  1236      // reduces the intermediary size the most.
  1237      if ((storageReduction > maxStorageReduction) &&
  1238          (maximalFusion ||
  1239           (additionalComputeFraction < computeToleranceThreshold))) {
  1240        maxStorageReduction = storageReduction;
  1241        bestDstLoopDepth = i;
  1242        minFusedLoopNestComputeCost = fusedLoopNestComputeCost;
  1243        sliceMemEstimate = sliceWriteRegionSizeBytes;
  1244      }
  1245    }
  1246  
  1247    // A simple cost model: fuse if it reduces the memory footprint. If
  1248    // -maximal-fusion is set, fuse nevertheless.
  1249  
  1250    if (!maximalFusion && !bestDstLoopDepth.hasValue()) {
  1251      LLVM_DEBUG(
  1252          llvm::dbgs()
  1253          << "All fusion choices involve more than the threshold amount of "
  1254             "redundant computation; NOT fusing.\n");
  1255      return false;
  1256    }
  1257  
  1258    if (!bestDstLoopDepth.hasValue()) {
  1259      LLVM_DEBUG(llvm::dbgs() << "no fusion depth could be evaluated.\n");
  1260      return false;
  1261    }
  1262  
  1263    // Set dstLoopDepth based on best values from search.
  1264    *dstLoopDepth = bestDstLoopDepth.getValue();
  1265  
  1266    LLVM_DEBUG(
  1267        llvm::dbgs() << " LoopFusion fusion stats:"
  1268                     << "\n  best loop depth: " << bestDstLoopDepth
  1269                     << "\n  src loop nest compute cost: " << srcLoopNestCost
  1270                     << "\n  dst loop nest compute cost: " << dstLoopNestCost
  1271                     << "\n  fused loop nest compute cost: "
  1272                     << minFusedLoopNestComputeCost << "\n");
  1273  
  1274    auto dstMemSize = getMemoryFootprintBytes(dstLoopIVs[0]);
  1275    auto srcMemSize = getMemoryFootprintBytes(srcLoopIVs[0]);
  1276  
  1277    Optional<double> storageReduction = None;
  1278  
  1279    if (!maximalFusion) {
  1280      if (!dstMemSize.hasValue() || !srcMemSize.hasValue()) {
  1281        LLVM_DEBUG(
  1282            llvm::dbgs()
  1283            << "  fusion memory benefit cannot be evaluated; NOT fusing.\n");
  1284        return false;
  1285      }
  1286  
  1287      auto srcMemSizeVal = srcMemSize.getValue();
  1288      auto dstMemSizeVal = dstMemSize.getValue();
  1289  
  1290      assert(sliceMemEstimate.hasValue() && "expected value");
  1291      auto fusedMem = dstMemSizeVal + sliceMemEstimate.getValue();
  1292  
  1293      LLVM_DEBUG(llvm::dbgs() << "   src mem: " << srcMemSizeVal << "\n"
  1294                              << "   dst mem: " << dstMemSizeVal << "\n"
  1295                              << "   fused mem: " << fusedMem << "\n"
  1296                              << "   slice mem: " << sliceMemEstimate << "\n");
  1297  
  1298      if (static_cast<long>(fusedMem) > srcMemSizeVal + dstMemSizeVal) {
  1299        LLVM_DEBUG(llvm::dbgs() << "Fusion is not profitable; NOT fusing.\n");
  1300        return false;
  1301      }
  1302      storageReduction =
  1303          100.0 *
  1304          (1.0 - fusedMem / (static_cast<double>(srcMemSizeVal) + dstMemSizeVal));
  1305    }
  1306  
  1307    double additionalComputeFraction =
  1308        100.0 * (minFusedLoopNestComputeCost /
  1309                     (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
  1310                 1);
  1311    (void)additionalComputeFraction;
  1312    LLVM_DEBUG({
  1313      std::stringstream msg;
  1314      msg << " fusion is most profitable at depth " << *dstLoopDepth << " with "
  1315          << std::setprecision(2) << additionalComputeFraction
  1316          << "% redundant computation and a ";
  1317      msg << (storageReduction.hasValue()
  1318                  ? std::to_string(storageReduction.getValue())
  1319                  : "<unknown>");
  1320      msg << "% storage reduction.\n";
  1321      llvm::dbgs() << msg.str();
  1322    });
  1323  
  1324    // Update return parameter 'sliceState' with 'bestSliceState'.
  1325    ComputationSliceState *bestSliceState = &sliceStates[*dstLoopDepth - 1];
  1326    sliceState->lbs = bestSliceState->lbs;
  1327    sliceState->ubs = bestSliceState->ubs;
  1328    sliceState->lbOperands = bestSliceState->lbOperands;
  1329    sliceState->ubOperands = bestSliceState->ubOperands;
  1330  
  1331    // Canonicalize slice bound affine maps.
  1332    for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
  1333      if (sliceState->lbs[i] != AffineMap()) {
  1334        canonicalizeMapAndOperands(&sliceState->lbs[i],
  1335                                   &sliceState->lbOperands[i]);
  1336      }
  1337      if (sliceState->ubs[i] != AffineMap()) {
  1338        canonicalizeMapAndOperands(&sliceState->ubs[i],
  1339                                   &sliceState->ubOperands[i]);
  1340      }
  1341    }
  1342    return true;
  1343  }
  1344  
  1345  // GreedyFusion greedily fuses loop nests which have a producer/consumer or
  1346  // input-reuse relationship on a memref, with the goal of improving locality.
  1347  //
  1348  // The steps of the producer-consumer fusion algorithm are as follows:
  1349  //
  1350  // *) A worklist is initialized with node ids from the dependence graph.
  1351  // *) For each node id in the worklist:
  1352  //   *) Pop an AffineForOp of the worklist. This 'dstAffineForOp' will be a
  1353  //      candidate destination AffineForOp into which fusion will be attempted.
  1354  //   *) Add each LoadOp currently in 'dstAffineForOp' into list 'dstLoadOps'.
  1355  //   *) For each LoadOp in 'dstLoadOps' do:
  1356  //      *) Look up dependent loop nests which have a single store op to the same
  1357  //         memref.
  1358  //      *) Check if dependences would be violated by the fusion.
  1359  //      *) Get a computation slice of 'srcLoopNest', which adjusts its loop
  1360  //         bounds to be functions of 'dstLoopNest' IVs and symbols.
  1361  //      *) Fuse the 'srcLoopNest' computation slice into the 'dstLoopNest',
  1362  //         at a loop depth determined by the cost model in 'isFusionProfitable'.
  1363  //      *) Add the newly fused load/store operations to the state,
  1364  //         and also add newly fused load ops to 'dstLoopOps' to be considered
  1365  //         as fusion dst load ops in another iteration.
  1366  //      *) Remove old src loop nest and its associated state.
  1367  //
  1368  // The steps of the input-reuse fusion algorithm are as follows:
  1369  //
  1370  // *) Initialize 'worklist' with node ids from the dependence graph.
  1371  // *) For each 'dstNode' in the worklist:
  1372  //   *) Find a candidate sibling node 'sibNode' to fuse with 'dstNode' which
  1373  //      loads from the same memref, but which has no dependence paths to/from.
  1374  //   *) Get a computation slice of 'sibLoopNest', which adjusts its loop
  1375  //      bounds to be functions of 'dstLoopNest' IVs and symbols.
  1376  //   *) Fuse the 'sibLoopNest' computation slice into the 'dstLoopNest',
  1377  //      at a loop depth determined by the cost model in 'isFusionProfitable'.
  1378  //      This function also checks that the memref write region of 'sibLoopNest',
  1379  //      is preserved in the fused loop nest.
  1380  //   *) Update graph state to reflect the fusion of 'sibNode' into 'dstNode'.
  1381  //
  1382  // Given a graph where top-level operations are vertices in the set 'V' and
  1383  // edges in the set 'E' are dependences between vertices, this algorithm
  1384  // takes O(V) time for initialization, and has runtime O(V + E).
  1385  //
  1386  // This greedy algorithm is not 'maximal' due to the current restriction of
  1387  // fusing along single producer consumer edges, but there is a TODO to fix this.
  1388  //
  1389  // TODO(andydavis) Experiment with other fusion policies.
  1390  struct GreedyFusion {
  1391  public:
  1392    // The data dependence graph to traverse during fusion.
  1393    MemRefDependenceGraph *mdg;
  1394    // Worklist of graph nodes visited during the fusion pass.
  1395    SmallVector<unsigned, 8> worklist;
  1396    // Set of graph nodes which are present on the worklist.
  1397    llvm::SmallDenseSet<unsigned, 16> worklistSet;
  1398    // Parameter for local buffer size threshold.
  1399    unsigned localBufSizeThreshold;
  1400    // Parameter for fast memory space.
  1401    Optional<unsigned> fastMemorySpace;
  1402    // If true, ignore any additional (redundant) computation tolerance threshold
  1403    // that would have prevented fusion.
  1404    bool maximalFusion;
  1405  
  1406    using Node = MemRefDependenceGraph::Node;
  1407  
  1408    GreedyFusion(MemRefDependenceGraph *mdg, unsigned localBufSizeThreshold,
  1409                 Optional<unsigned> fastMemorySpace, bool maximalFusion)
  1410        : mdg(mdg), localBufSizeThreshold(localBufSizeThreshold),
  1411          fastMemorySpace(fastMemorySpace), maximalFusion(maximalFusion) {}
  1412  
  1413    // Initializes 'worklist' with nodes from 'mdg'
  1414    void init() {
  1415      // TODO(andydavis) Add a priority queue for prioritizing nodes by different
  1416      // metrics (e.g. arithmetic intensity/flops-to-bytes ratio).
  1417      worklist.clear();
  1418      worklistSet.clear();
  1419      for (auto &idAndNode : mdg->nodes) {
  1420        const Node &node = idAndNode.second;
  1421        worklist.push_back(node.id);
  1422        worklistSet.insert(node.id);
  1423      }
  1424    }
  1425  
  1426    // Run the GreedyFusion pass.
  1427    // *) First pass through the nodes fuses single-use producer nodes into their
  1428    //    unique consumer.
  1429    // *) Second pass fuses sibling nodes which share no dependence edges.
  1430    // *) Third pass fuses any remaining producer nodes into their users.
  1431    void run() {
  1432      // TODO(andydavis) Run this repeatedly until a fixed-point is reached.
  1433      fuseProducerConsumerNodes(/*maxSrcUserCount=*/1);
  1434      fuseSiblingNodes();
  1435      fuseProducerConsumerNodes(
  1436          /*maxSrcUserCount=*/std::numeric_limits<unsigned>::max());
  1437      eraseUnusedMemRefAllocations();
  1438    }
  1439  
  1440    void fuseProducerConsumerNodes(unsigned maxSrcUserCount) {
  1441      init();
  1442      while (!worklist.empty()) {
  1443        unsigned dstId = worklist.back();
  1444        worklist.pop_back();
  1445        worklistSet.erase(dstId);
  1446  
  1447        // Skip if this node was removed (fused into another node).
  1448        if (mdg->nodes.count(dstId) == 0)
  1449          continue;
  1450        // Get 'dstNode' into which to attempt fusion.
  1451        auto *dstNode = mdg->getNode(dstId);
  1452        // Skip if 'dstNode' is not a loop nest.
  1453        if (!isa<AffineForOp>(dstNode->op))
  1454          continue;
  1455        // Sink sequential loops in 'dstNode' (and thus raise parallel loops)
  1456        // while preserving relative order. This can increase the maximum loop
  1457        // depth at which we can fuse a slice of a producer loop nest into a
  1458        // consumer loop nest.
  1459        sinkSequentialLoops(dstNode);
  1460  
  1461        SmallVector<Operation *, 4> loads = dstNode->loads;
  1462        SmallVector<Operation *, 4> dstLoadOpInsts;
  1463        DenseSet<Value *> visitedMemrefs;
  1464        while (!loads.empty()) {
  1465          // Get memref of load on top of the stack.
  1466          auto *memref = cast<AffineLoadOp>(loads.back()).getMemRef();
  1467          if (visitedMemrefs.count(memref) > 0)
  1468            continue;
  1469          visitedMemrefs.insert(memref);
  1470          // Move all loads in 'loads' accessing 'memref' to 'dstLoadOpInsts'.
  1471          moveLoadsAccessingMemrefTo(memref, &loads, &dstLoadOpInsts);
  1472          // Skip if no input edges along which to fuse.
  1473          if (mdg->inEdges.count(dstId) == 0)
  1474            continue;
  1475          // Iterate through in-edges for 'dstId' and src node id for any
  1476          // edges on 'memref'.
  1477          SmallVector<unsigned, 2> srcNodeIds;
  1478          for (auto &srcEdge : mdg->inEdges[dstId]) {
  1479            // Skip 'srcEdge' if not for 'memref'.
  1480            if (srcEdge.value != memref)
  1481              continue;
  1482            srcNodeIds.push_back(srcEdge.id);
  1483          }
  1484          for (unsigned srcId : srcNodeIds) {
  1485            // Skip if this node was removed (fused into another node).
  1486            if (mdg->nodes.count(srcId) == 0)
  1487              continue;
  1488            // Get 'srcNode' from which to attempt fusion into 'dstNode'.
  1489            auto *srcNode = mdg->getNode(srcId);
  1490            // Skip if 'srcNode' is not a loop nest.
  1491            if (!isa<AffineForOp>(srcNode->op))
  1492              continue;
  1493            // Skip if 'srcNode' has more than one store to any memref.
  1494            // TODO(andydavis) Support fusing multi-output src loop nests.
  1495            if (srcNode->stores.size() != 1)
  1496              continue;
  1497  
  1498            // Skip if 'srcNode' writes to any live in or escaping memrefs,
  1499            // and cannot be fused.
  1500            bool writesToLiveInOrOut =
  1501                mdg->writesToLiveInOrEscapingMemrefs(srcNode->id);
  1502            if (writesToLiveInOrOut &&
  1503                !canFuseSrcWhichWritesToLiveOut(srcId, dstId, memref, mdg))
  1504              continue;
  1505  
  1506            // Skip if 'srcNode' out edge count on 'memref' > 'maxSrcUserCount'.
  1507            if (mdg->getOutEdgeCount(srcNode->id, memref) > maxSrcUserCount)
  1508              continue;
  1509  
  1510            // Compute an operation list insertion point for the fused loop
  1511            // nest which preserves dependences.
  1512            Operation *insertPointInst =
  1513                mdg->getFusedLoopNestInsertionPoint(srcNode->id, dstNode->id);
  1514            if (insertPointInst == nullptr)
  1515              continue;
  1516  
  1517            // Get unique 'srcNode' store op.
  1518            auto *srcStoreOpInst = srcNode->stores.front();
  1519            // Gather 'dstNode' store ops to 'memref'.
  1520            SmallVector<Operation *, 2> dstStoreOpInsts;
  1521            for (auto *storeOpInst : dstNode->stores)
  1522              if (cast<AffineStoreOp>(storeOpInst).getMemRef() == memref)
  1523                dstStoreOpInsts.push_back(storeOpInst);
  1524  
  1525            unsigned bestDstLoopDepth;
  1526            mlir::ComputationSliceState sliceState;
  1527            // Check if fusion would be profitable.
  1528            if (!isFusionProfitable(srcStoreOpInst, srcStoreOpInst,
  1529                                    dstLoadOpInsts, dstStoreOpInsts, &sliceState,
  1530                                    &bestDstLoopDepth, maximalFusion))
  1531              continue;
  1532            // TODO(andydavis) Remove the following test code when canFuseLoops
  1533            // is fully functional.
  1534            mlir::ComputationSliceState sliceUnion;
  1535            if (!maximalFusion) {
  1536              FusionResult result = mlir::canFuseLoops(
  1537                  cast<AffineForOp>(srcNode->op), cast<AffineForOp>(dstNode->op),
  1538                  bestDstLoopDepth, &sliceUnion);
  1539              assert(result.value == FusionResult::Success);
  1540              (void)result;
  1541            }
  1542            // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'.
  1543            auto sliceLoopNest = mlir::insertBackwardComputationSlice(
  1544                srcStoreOpInst, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState);
  1545            if (sliceLoopNest) {
  1546              LLVM_DEBUG(llvm::dbgs() << "\tslice loop nest:\n"
  1547                                      << *sliceLoopNest.getOperation() << "\n");
  1548              // Move 'dstAffineForOp' before 'insertPointInst' if needed.
  1549              auto dstAffineForOp = cast<AffineForOp>(dstNode->op);
  1550              if (insertPointInst != dstAffineForOp.getOperation()) {
  1551                dstAffineForOp.getOperation()->moveBefore(insertPointInst);
  1552              }
  1553              // Update edges between 'srcNode' and 'dstNode'.
  1554              mdg->updateEdges(srcNode->id, dstNode->id, memref);
  1555  
  1556              // Collect slice loop stats.
  1557              LoopNestStateCollector sliceCollector;
  1558              sliceCollector.collect(sliceLoopNest.getOperation());
  1559              // Promote single iteration slice loops to single IV value.
  1560              for (auto forOp : sliceCollector.forOps) {
  1561                promoteIfSingleIteration(forOp);
  1562              }
  1563              if (!writesToLiveInOrOut) {
  1564                // Create private memref for 'memref' in 'dstAffineForOp'.
  1565                SmallVector<Operation *, 4> storesForMemref;
  1566                for (auto *storeOpInst : sliceCollector.storeOpInsts) {
  1567                  if (cast<AffineStoreOp>(storeOpInst).getMemRef() == memref)
  1568                    storesForMemref.push_back(storeOpInst);
  1569                }
  1570                assert(storesForMemref.size() == 1);
  1571                auto *newMemRef = createPrivateMemRef(
  1572                    dstAffineForOp, storesForMemref[0], bestDstLoopDepth,
  1573                    fastMemorySpace, localBufSizeThreshold);
  1574                visitedMemrefs.insert(newMemRef);
  1575                // Create new node in dependence graph for 'newMemRef' alloc op.
  1576                unsigned newMemRefNodeId =
  1577                    mdg->addNode(newMemRef->getDefiningOp());
  1578                // Add edge from 'newMemRef' node to dstNode.
  1579                mdg->addEdge(newMemRefNodeId, dstId, newMemRef);
  1580              }
  1581  
  1582              // Collect dst loop stats after memref privatizaton transformation.
  1583              LoopNestStateCollector dstLoopCollector;
  1584              dstLoopCollector.collect(dstAffineForOp.getOperation());
  1585  
  1586              // Add new load ops to current Node load op list 'loads' to
  1587              // continue fusing based on new operands.
  1588              for (auto *loadOpInst : dstLoopCollector.loadOpInsts) {
  1589                auto *loadMemRef = cast<AffineLoadOp>(loadOpInst).getMemRef();
  1590                if (visitedMemrefs.count(loadMemRef) == 0)
  1591                  loads.push_back(loadOpInst);
  1592              }
  1593  
  1594              // Clear and add back loads and stores.
  1595              mdg->clearNodeLoadAndStores(dstNode->id);
  1596              mdg->addToNode(dstId, dstLoopCollector.loadOpInsts,
  1597                             dstLoopCollector.storeOpInsts);
  1598              // Remove old src loop nest if it no longer has outgoing dependence
  1599              // edges, and if it does not write to a memref which escapes the
  1600              // function. If 'writesToLiveInOrOut' is true, then 'srcNode' has
  1601              // been fused into 'dstNode' and write region of 'dstNode' covers
  1602              // the write region of 'srcNode', and 'srcNode' has no other users
  1603              // so it is safe to remove.
  1604              if (writesToLiveInOrOut || mdg->canRemoveNode(srcNode->id)) {
  1605                mdg->removeNode(srcNode->id);
  1606                srcNode->op->erase();
  1607              } else {
  1608                // Add remaining users of 'oldMemRef' back on the worklist (if not
  1609                // already there), as its replacement with a local/private memref
  1610                // has reduced dependences on 'oldMemRef' which may have created
  1611                // new fusion opportunities.
  1612                if (mdg->outEdges.count(srcNode->id) > 0) {
  1613                  SmallVector<MemRefDependenceGraph::Edge, 2> oldOutEdges =
  1614                      mdg->outEdges[srcNode->id];
  1615                  for (auto &outEdge : oldOutEdges) {
  1616                    if (outEdge.value == memref &&
  1617                        worklistSet.count(outEdge.id) == 0) {
  1618                      worklist.push_back(outEdge.id);
  1619                      worklistSet.insert(outEdge.id);
  1620                    }
  1621                  }
  1622                }
  1623              }
  1624            }
  1625          }
  1626        }
  1627      }
  1628    }
  1629  
  1630    // Visits each node in the graph, and for each node, attempts to fuse it with
  1631    // its sibling nodes (nodes which share a parent, but no dependence edges).
  1632    void fuseSiblingNodes() {
  1633      init();
  1634      while (!worklist.empty()) {
  1635        unsigned dstId = worklist.back();
  1636        worklist.pop_back();
  1637        worklistSet.erase(dstId);
  1638  
  1639        // Skip if this node was removed (fused into another node).
  1640        if (mdg->nodes.count(dstId) == 0)
  1641          continue;
  1642        // Get 'dstNode' into which to attempt fusion.
  1643        auto *dstNode = mdg->getNode(dstId);
  1644        // Skip if 'dstNode' is not a loop nest.
  1645        if (!isa<AffineForOp>(dstNode->op))
  1646          continue;
  1647        // Attempt to fuse 'dstNode' with its sibling nodes in the graph.
  1648        fuseWithSiblingNodes(dstNode);
  1649      }
  1650    }
  1651  
  1652    // Attempt to fuse 'dstNode' with sibling nodes in the graph.
  1653    void fuseWithSiblingNodes(Node *dstNode) {
  1654      DenseSet<unsigned> visitedSibNodeIds;
  1655      std::pair<unsigned, Value *> idAndMemref;
  1656      while (findSiblingNodeToFuse(dstNode, &visitedSibNodeIds, &idAndMemref)) {
  1657        unsigned sibId = idAndMemref.first;
  1658        Value *memref = idAndMemref.second;
  1659        // TODO(andydavis) Check that 'sibStoreOpInst' post-dominates all other
  1660        // stores to the same memref in 'sibNode' loop nest.
  1661        auto *sibNode = mdg->getNode(sibId);
  1662        // Compute an operation list insertion point for the fused loop
  1663        // nest which preserves dependences.
  1664        assert(sibNode->op->getBlock() == dstNode->op->getBlock());
  1665        Operation *insertPointInst =
  1666            sibNode->op->isBeforeInBlock(dstNode->op)
  1667                ? mdg->getFusedLoopNestInsertionPoint(sibNode->id, dstNode->id)
  1668                : mdg->getFusedLoopNestInsertionPoint(dstNode->id, sibNode->id);
  1669        if (insertPointInst == nullptr)
  1670          continue;
  1671  
  1672        // Check if fusion would be profitable and at what depth.
  1673  
  1674        // Get unique 'sibNode' load op to 'memref'.
  1675        SmallVector<Operation *, 2> sibLoadOpInsts;
  1676        sibNode->getLoadOpsForMemref(memref, &sibLoadOpInsts);
  1677        // Currently findSiblingNodeToFuse searches for siblings with one load.
  1678        assert(sibLoadOpInsts.size() == 1);
  1679        Operation *sibLoadOpInst = sibLoadOpInsts[0];
  1680        assert(!sibNode->stores.empty());
  1681        // TODO(andydavis) Choose the store which postdominates all other stores.
  1682        auto *sibStoreOpInst = sibNode->stores.back();
  1683  
  1684        // Gather 'dstNode' load ops to 'memref'.
  1685        SmallVector<Operation *, 2> dstLoadOpInsts;
  1686        dstNode->getLoadOpsForMemref(memref, &dstLoadOpInsts);
  1687  
  1688        // Gather 'dstNode' store ops to 'memref'.
  1689        SmallVector<Operation *, 2> dstStoreOpInsts;
  1690        dstNode->getStoreOpsForMemref(memref, &dstStoreOpInsts);
  1691  
  1692        unsigned bestDstLoopDepth;
  1693        mlir::ComputationSliceState sliceState;
  1694  
  1695        // Check if fusion would be profitable.
  1696        if (!isFusionProfitable(sibLoadOpInst, sibStoreOpInst, dstLoadOpInsts,
  1697                                dstStoreOpInsts, &sliceState, &bestDstLoopDepth,
  1698                                maximalFusion))
  1699          continue;
  1700  
  1701        // Fuse computation slice of 'sibLoopNest' into 'dstLoopNest'.
  1702        auto sliceLoopNest = mlir::insertBackwardComputationSlice(
  1703            sibLoadOpInst, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState);
  1704        if (sliceLoopNest != nullptr) {
  1705          auto dstForInst = cast<AffineForOp>(dstNode->op);
  1706          // Update operation position of fused loop nest (if needed).
  1707          if (insertPointInst != dstForInst.getOperation()) {
  1708            dstForInst.getOperation()->moveBefore(insertPointInst);
  1709          }
  1710          // Update data dependence graph state post fusion.
  1711          updateStateAfterSiblingFusion(sliceLoopNest, sibNode, dstNode);
  1712        }
  1713      }
  1714    }
  1715  
  1716    // Searches function argument uses and the graph from 'dstNode' looking for a
  1717    // fusion candidate sibling node which shares no dependences with 'dstNode'
  1718    // but which loads from the same memref. Returns true and sets
  1719    // 'idAndMemrefToFuse' on success. Returns false otherwise.
  1720    bool findSiblingNodeToFuse(Node *dstNode,
  1721                               DenseSet<unsigned> *visitedSibNodeIds,
  1722                               std::pair<unsigned, Value *> *idAndMemrefToFuse) {
  1723      // Returns true if 'sibNode' can be fused with 'dstNode' for input reuse
  1724      // on 'memref'.
  1725      auto canFuseWithSibNode = [&](Node *sibNode, Value *memref) {
  1726        // Skip if 'outEdge' is not a read-after-write dependence.
  1727        // TODO(andydavis) Remove restrict to single load op restriction.
  1728        if (sibNode->getLoadOpCount(memref) != 1)
  1729          return false;
  1730        // Skip if there exists a path of dependent edges between
  1731        // 'sibNode' and 'dstNode'.
  1732        if (mdg->hasDependencePath(sibNode->id, dstNode->id) ||
  1733            mdg->hasDependencePath(dstNode->id, sibNode->id))
  1734          return false;
  1735        // Skip sib node if it loads to (and stores from) the same memref on
  1736        // which it also has an input dependence edge.
  1737        DenseSet<Value *> loadAndStoreMemrefSet;
  1738        sibNode->getLoadAndStoreMemrefSet(&loadAndStoreMemrefSet);
  1739        if (llvm::any_of(loadAndStoreMemrefSet, [=](Value *memref) {
  1740              return mdg->getIncomingMemRefAccesses(sibNode->id, memref) > 0;
  1741            }))
  1742          return false;
  1743  
  1744        // Check that all stores are to the same memref.
  1745        DenseSet<Value *> storeMemrefs;
  1746        for (auto *storeOpInst : sibNode->stores) {
  1747          storeMemrefs.insert(cast<AffineStoreOp>(storeOpInst).getMemRef());
  1748        }
  1749        if (storeMemrefs.size() != 1)
  1750          return false;
  1751        return true;
  1752      };
  1753  
  1754      // Search for siblings which load the same memref function argument.
  1755      auto fn = dstNode->op->getParentOfType<FuncOp>();
  1756      for (unsigned i = 0, e = fn.getNumArguments(); i != e; ++i) {
  1757        for (auto *user : fn.getArgument(i)->getUsers()) {
  1758          if (auto loadOp = dyn_cast<AffineLoadOp>(user)) {
  1759            // Gather loops surrounding 'use'.
  1760            SmallVector<AffineForOp, 4> loops;
  1761            getLoopIVs(*user, &loops);
  1762            // Skip 'use' if it is not within a loop nest.
  1763            if (loops.empty())
  1764              continue;
  1765            Node *sibNode = mdg->getForOpNode(loops[0]);
  1766            assert(sibNode != nullptr);
  1767            // Skip 'use' if it not a sibling to 'dstNode'.
  1768            if (sibNode->id == dstNode->id)
  1769              continue;
  1770            // Skip 'use' if it has been visited.
  1771            if (visitedSibNodeIds->count(sibNode->id) > 0)
  1772              continue;
  1773            // Skip 'use' if it does not load from the same memref as 'dstNode'.
  1774            auto *memref = loadOp.getMemRef();
  1775            if (dstNode->getLoadOpCount(memref) == 0)
  1776              continue;
  1777            // Check if 'sibNode/dstNode' can be input-reuse fused on 'memref'.
  1778            if (canFuseWithSibNode(sibNode, memref)) {
  1779              visitedSibNodeIds->insert(sibNode->id);
  1780              idAndMemrefToFuse->first = sibNode->id;
  1781              idAndMemrefToFuse->second = memref;
  1782              return true;
  1783            }
  1784          }
  1785        }
  1786      }
  1787  
  1788      // Search for siblings by following edges through an intermediate src node.
  1789      // Collect candidate 'dstNode' input edges in 'inEdges'.
  1790      SmallVector<MemRefDependenceGraph::Edge, 2> inEdges;
  1791      mdg->forEachMemRefInputEdge(
  1792          dstNode->id, [&](MemRefDependenceGraph::Edge inEdge) {
  1793            // Add 'inEdge' if it is a read-after-write dependence.
  1794            if (dstNode->getLoadOpCount(inEdge.value) > 0 &&
  1795                mdg->getNode(inEdge.id)->getStoreOpCount(inEdge.value) > 0)
  1796              inEdges.push_back(inEdge);
  1797          });
  1798  
  1799      // Search for sibling nodes to fuse by visiting output edges from each input
  1800      // edge in 'inEdges'.
  1801      for (auto &inEdge : inEdges) {
  1802        // Collect candidate output edges from each node 'inEdge.id' in 'inEdges'.
  1803        SmallVector<MemRefDependenceGraph::Edge, 2> outEdges;
  1804        mdg->forEachMemRefOutputEdge(
  1805            inEdge.id, [&](MemRefDependenceGraph::Edge outEdge) {
  1806              unsigned sibNodeId = outEdge.id;
  1807              if (visitedSibNodeIds->count(sibNodeId) > 0)
  1808                return;
  1809              // Skip output edge if not a sibling using the same memref.
  1810              if (outEdge.id == dstNode->id || outEdge.value != inEdge.value)
  1811                return;
  1812              auto *sibNode = mdg->getNode(sibNodeId);
  1813              if (!isa<AffineForOp>(sibNode->op))
  1814                return;
  1815              // Check if 'sibNode/dstNode' can be input-reuse fused on 'memref'.
  1816              if (canFuseWithSibNode(sibNode, outEdge.value)) {
  1817                // Add candidate 'outEdge' to sibling node.
  1818                outEdges.push_back(outEdge);
  1819              }
  1820            });
  1821  
  1822        // Add first candidate if any were returned.
  1823        if (!outEdges.empty()) {
  1824          visitedSibNodeIds->insert(outEdges[0].id);
  1825          idAndMemrefToFuse->first = outEdges[0].id;
  1826          idAndMemrefToFuse->second = outEdges[0].value;
  1827          return true;
  1828        }
  1829      }
  1830      return false;
  1831    }
  1832  
  1833    void updateStateAfterSiblingFusion(AffineForOp sliceLoopNest, Node *sibNode,
  1834                                       Node *dstNode) {
  1835      // Update 'sibNode' and 'dstNode' input/output edges to reflect fusion.
  1836      mdg->updateEdges(sibNode->id, dstNode->id);
  1837  
  1838      // Collect slice loop stats.
  1839      LoopNestStateCollector sliceCollector;
  1840      sliceCollector.collect(sliceLoopNest.getOperation());
  1841      // Promote single iteration slice loops to single IV value.
  1842      for (auto forOp : sliceCollector.forOps) {
  1843        promoteIfSingleIteration(forOp);
  1844      }
  1845  
  1846      // Collect dst loop stats after memref privatizaton transformation.
  1847      auto dstForInst = cast<AffineForOp>(dstNode->op);
  1848      LoopNestStateCollector dstLoopCollector;
  1849      dstLoopCollector.collect(dstForInst.getOperation());
  1850      // Clear and add back loads and stores
  1851      mdg->clearNodeLoadAndStores(dstNode->id);
  1852      mdg->addToNode(dstNode->id, dstLoopCollector.loadOpInsts,
  1853                     dstLoopCollector.storeOpInsts);
  1854      // Remove old sibling loop nest if it no longer has outgoing dependence
  1855      // edges, and it does not write to a memref which escapes the
  1856      // function.
  1857      if (mdg->getOutEdgeCount(sibNode->id) == 0) {
  1858        mdg->removeNode(sibNode->id);
  1859        sibNode->op->erase();
  1860      }
  1861    }
  1862  
  1863    // Clean up any allocs with no users.
  1864    void eraseUnusedMemRefAllocations() {
  1865      for (auto &pair : mdg->memrefEdgeCount) {
  1866        if (pair.second > 0)
  1867          continue;
  1868        auto *memref = pair.first;
  1869        // Skip if there exist other uses (return operation or function calls).
  1870        if (!memref->use_empty())
  1871          continue;
  1872        // Use list expected to match the dep graph info.
  1873        auto *op = memref->getDefiningOp();
  1874        if (isa_and_nonnull<AllocOp>(op))
  1875          op->erase();
  1876      }
  1877    }
  1878  };
  1879  
  1880  } // end anonymous namespace
  1881  
  1882  void LoopFusion::runOnFunction() {
  1883    // Override if a command line argument was provided.
  1884    if (clFusionFastMemorySpace.getNumOccurrences() > 0) {
  1885      fastMemorySpace = clFusionFastMemorySpace.getValue();
  1886    }
  1887  
  1888    // Override if a command line argument was provided.
  1889    if (clFusionLocalBufThreshold.getNumOccurrences() > 0) {
  1890      localBufSizeThreshold = clFusionLocalBufThreshold * 1024;
  1891    }
  1892  
  1893    if (clMaximalLoopFusion.getNumOccurrences() > 0)
  1894      maximalFusion = clMaximalLoopFusion;
  1895  
  1896    MemRefDependenceGraph g;
  1897    if (g.init(getFunction()))
  1898      GreedyFusion(&g, localBufSizeThreshold, fastMemorySpace, maximalFusion)
  1899          .run();
  1900  }
  1901  
  1902  static PassRegistration<LoopFusion> pass("affine-loop-fusion",
  1903                                           "Fuse loop nests");