github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp (about)

     1  //===- GreedyPatternRewriteDriver.cpp - A greedy rewriter -----------------===//
     2  //
     3  // Copyright 2019 The MLIR Authors.
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //   http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  // =============================================================================
    17  //
    18  // This file implements mlir::applyPatternsGreedily.
    19  //
    20  //===----------------------------------------------------------------------===//
    21  
    22  #include "mlir/Dialect/StandardOps/Ops.h"
    23  #include "mlir/IR/Builders.h"
    24  #include "mlir/IR/PatternMatch.h"
    25  #include "mlir/Transforms/FoldUtils.h"
    26  #include "llvm/ADT/DenseMap.h"
    27  #include "llvm/Support/CommandLine.h"
    28  #include "llvm/Support/Debug.h"
    29  #include "llvm/Support/raw_ostream.h"
    30  
    31  using namespace mlir;
    32  
    33  #define DEBUG_TYPE "pattern-matcher"
    34  
    35  static llvm::cl::opt<unsigned> maxPatternMatchIterations(
    36      "mlir-max-pattern-match-iterations",
    37      llvm::cl::desc("Max number of iterations scanning for pattern match"),
    38      llvm::cl::init(10));
    39  
    40  namespace {
    41  
    42  /// This is a worklist-driven driver for the PatternMatcher, which repeatedly
    43  /// applies the locally optimal patterns in a roughly "bottom up" way.
    44  class GreedyPatternRewriteDriver : public PatternRewriter {
    45  public:
    46    explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
    47                                        const OwningRewritePatternList &patterns)
    48        : PatternRewriter(ctx), matcher(patterns) {
    49      worklist.reserve(64);
    50    }
    51  
    52    /// Perform the rewrites. Return true if the rewrite converges in
    53    /// `maxIterations`.
    54    bool simplify(Operation *op, int maxIterations);
    55  
    56    void addToWorklist(Operation *op) {
    57      // Check to see if the worklist already contains this op.
    58      if (worklistMap.count(op))
    59        return;
    60  
    61      worklistMap[op] = worklist.size();
    62      worklist.push_back(op);
    63    }
    64  
    65    Operation *popFromWorklist() {
    66      auto *op = worklist.back();
    67      worklist.pop_back();
    68  
    69      // This operation is no longer in the worklist, keep worklistMap up to date.
    70      if (op)
    71        worklistMap.erase(op);
    72      return op;
    73    }
    74  
    75    /// If the specified operation is in the worklist, remove it.  If not, this is
    76    /// a no-op.
    77    void removeFromWorklist(Operation *op) {
    78      auto it = worklistMap.find(op);
    79      if (it != worklistMap.end()) {
    80        assert(worklist[it->second] == op && "malformed worklist data structure");
    81        worklist[it->second] = nullptr;
    82      }
    83    }
    84  
    85    // These are hooks implemented for PatternRewriter.
    86  protected:
    87    // Implement the hook for creating operations, and make sure that newly
    88    // created ops are added to the worklist for processing.
    89    Operation *createOperation(const OperationState &state) override {
    90      auto *result = OpBuilder::createOperation(state);
    91      addToWorklist(result);
    92      return result;
    93    }
    94  
    95    // If an operation is about to be removed, make sure it is not in our
    96    // worklist anymore because we'd get dangling references to it.
    97    void notifyOperationRemoved(Operation *op) override {
    98      addToWorklist(op->getOperands());
    99      op->walk([this](Operation *operation) {
   100        removeFromWorklist(operation);
   101        folder.notifyRemoval(operation);
   102      });
   103    }
   104  
   105    // When the root of a pattern is about to be replaced, it can trigger
   106    // simplifications to its users - make sure to add them to the worklist
   107    // before the root is changed.
   108    void notifyRootReplaced(Operation *op) override {
   109      for (auto *result : op->getResults())
   110        for (auto *user : result->getUsers())
   111          addToWorklist(user);
   112    }
   113  
   114  private:
   115    // Look over the provided operands for any defining operations that should
   116    // be re-added to the worklist. This function should be called when an
   117    // operation is modified or removed, as it may trigger further
   118    // simplifications.
   119    template <typename Operands> void addToWorklist(Operands &&operands) {
   120      for (Value *operand : operands) {
   121        // If the use count of this operand is now < 2, we re-add the defining
   122        // operation to the worklist.
   123        // TODO(riverriddle) This is based on the fact that zero use operations
   124        // may be deleted, and that single use values often have more
   125        // canonicalization opportunities.
   126        if (!operand->use_empty() && !operand->hasOneUse())
   127          continue;
   128        if (auto *defInst = operand->getDefiningOp())
   129          addToWorklist(defInst);
   130      }
   131    }
   132  
   133    /// The low-level pattern matcher.
   134    RewritePatternMatcher matcher;
   135  
   136    /// The worklist for this transformation keeps track of the operations that
   137    /// need to be revisited, plus their index in the worklist.  This allows us to
   138    /// efficiently remove operations from the worklist when they are erased, even
   139    /// if they aren't the root of a pattern.
   140    std::vector<Operation *> worklist;
   141    DenseMap<Operation *, unsigned> worklistMap;
   142  
   143    /// Non-pattern based folder for operations.
   144    OperationFolder folder;
   145  };
   146  } // end anonymous namespace
   147  
   148  /// Perform the rewrites.
   149  bool GreedyPatternRewriteDriver::simplify(Operation *op, int maxIterations) {
   150    // Add the given operation to the worklist.
   151    auto collectOps = [this](Operation *op) { addToWorklist(op); };
   152  
   153    bool changed = false;
   154    int i = 0;
   155    do {
   156      // Add all nested operations to the worklist.
   157      for (auto &region : op->getRegions())
   158        region.walk(collectOps);
   159  
   160      // These are scratch vectors used in the folding loop below.
   161      SmallVector<Value *, 8> originalOperands, resultValues;
   162  
   163      changed = false;
   164      while (!worklist.empty()) {
   165        auto *op = popFromWorklist();
   166  
   167        // Nulls get added to the worklist when operations are removed, ignore
   168        // them.
   169        if (op == nullptr)
   170          continue;
   171  
   172        // If the operation has no side effects, and no users, then it is
   173        // trivially dead - remove it.
   174        if (op->hasNoSideEffect() && op->use_empty()) {
   175          // Be careful to update bookkeeping.
   176          notifyOperationRemoved(op);
   177          op->erase();
   178          continue;
   179        }
   180  
   181        // Collects all the operands and result uses of the given `op` into work
   182        // list. Also remove `op` and nested ops from worklist.
   183        originalOperands.assign(op->operand_begin(), op->operand_end());
   184        auto preReplaceAction = [&](Operation *op) {
   185          // Add the operands to the worklist for visitation.
   186          addToWorklist(originalOperands);
   187  
   188          // Add all the users of the result to the worklist so we make sure
   189          // to revisit them.
   190          for (auto *result : op->getResults())
   191            for (auto *operand : result->getUsers())
   192              addToWorklist(operand);
   193  
   194          notifyOperationRemoved(op);
   195        };
   196  
   197        // Try to fold this op.
   198        if (succeeded(folder.tryToFold(op, collectOps, preReplaceAction))) {
   199          changed |= true;
   200          continue;
   201        }
   202  
   203        // Make sure that any new operations are inserted at this point.
   204        setInsertionPoint(op);
   205  
   206        // Try to match one of the patterns. The rewriter is automatically
   207        // notified of any necessary changes, so there is nothing else to do here.
   208        changed |= matcher.matchAndRewrite(op, *this);
   209      }
   210    } while (changed && ++i < maxIterations);
   211    // Whether the rewrite converges, i.e. wasn't changed in the last iteration.
   212    return !changed;
   213  }
   214  
   215  /// Rewrite the regions of the specified operation, which must be isolated from
   216  /// above, by repeatedly applying the highest benefit patterns in a greedy
   217  /// work-list driven manner. Return true if no more patterns can be matched in
   218  /// the result operation regions.
   219  /// Note: This does not apply patterns to the top-level operation itself.
   220  ///
   221  bool mlir::applyPatternsGreedily(Operation *op,
   222                                   const OwningRewritePatternList &patterns) {
   223    // The top-level operation must be known to be isolated from above to
   224    // prevent performing canonicalizations on operations defined at or above
   225    // the region containing 'op'.
   226    if (!op->isKnownIsolatedFromAbove())
   227      return false;
   228  
   229    GreedyPatternRewriteDriver driver(op->getContext(), patterns);
   230    bool converged = driver.simplify(op, maxPatternMatchIterations);
   231    LLVM_DEBUG(if (!converged) {
   232      llvm::dbgs() << "The pattern rewrite doesn't converge after scanning "
   233                   << maxPatternMatchIterations << " times";
   234    });
   235    return converged;
   236  }