github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/IR/PatternMatch.cpp (about)

     1  //===- PatternMatch.cpp - Base classes for pattern match ------------------===//
     2  //
     3  // Copyright 2019 The MLIR Authors.
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //   http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  // =============================================================================
    17  
    18  #include "mlir/IR/PatternMatch.h"
    19  #include "mlir/IR/Operation.h"
    20  #include "mlir/IR/Value.h"
    21  using namespace mlir;
    22  
    23  PatternBenefit::PatternBenefit(unsigned benefit) : representation(benefit) {
    24    assert(representation == benefit && benefit != ImpossibleToMatchSentinel &&
    25           "This pattern match benefit is too large to represent");
    26  }
    27  
    28  unsigned short PatternBenefit::getBenefit() const {
    29    assert(representation != ImpossibleToMatchSentinel &&
    30           "Pattern doesn't match");
    31    return representation;
    32  }
    33  
    34  //===----------------------------------------------------------------------===//
    35  // Pattern implementation
    36  //===----------------------------------------------------------------------===//
    37  
    38  Pattern::Pattern(StringRef rootName, PatternBenefit benefit,
    39                   MLIRContext *context)
    40      : rootKind(OperationName(rootName, context)), benefit(benefit) {}
    41  
    42  // Out-of-line vtable anchor.
    43  void Pattern::anchor() {}
    44  
    45  //===----------------------------------------------------------------------===//
    46  // RewritePattern and PatternRewriter implementation
    47  //===----------------------------------------------------------------------===//
    48  
    49  void RewritePattern::rewrite(Operation *op, std::unique_ptr<PatternState> state,
    50                               PatternRewriter &rewriter) const {
    51    rewrite(op, rewriter);
    52  }
    53  
    54  void RewritePattern::rewrite(Operation *op, PatternRewriter &rewriter) const {
    55    llvm_unreachable("need to implement either matchAndRewrite or one of the "
    56                     "rewrite functions!");
    57  }
    58  
    59  PatternMatchResult RewritePattern::match(Operation *op) const {
    60    llvm_unreachable("need to implement either match or matchAndRewrite!");
    61  }
    62  
    63  /// Patterns must specify the root operation name they match against, and can
    64  /// also specify the benefit of the pattern matching. They can also specify the
    65  /// names of operations that may be generated during a successful rewrite.
    66  RewritePattern::RewritePattern(StringRef rootName,
    67                                 ArrayRef<StringRef> generatedNames,
    68                                 PatternBenefit benefit, MLIRContext *context)
    69      : Pattern(rootName, benefit, context) {
    70    generatedOps.reserve(generatedNames.size());
    71    std::transform(generatedNames.begin(), generatedNames.end(),
    72                   std::back_inserter(generatedOps), [context](StringRef name) {
    73                     return OperationName(name, context);
    74                   });
    75  }
    76  
    77  PatternRewriter::~PatternRewriter() {
    78    // Out of line to provide a vtable anchor for the class.
    79  }
    80  
/// Perform the final replacement for a pattern: redirect the uses of `op`'s
/// results to `newValues` (via Operation::replaceAllUsesWith) and then erase
/// `op`. `newValues` must contain exactly one value per result of `op`.
///
/// Subclass hooks fire in this order:
///   1. notifyRootReplaced(op), before any uses are rewritten;
///   2. notifyOperationRemoved(op), immediately before `op` is erased.
///
/// `valuesToRemoveIfDead` is meant to name other values that this replacement
/// may make (perhaps transitively) dead, so they could be cleaned up as well.
/// NOTE: the list is currently NOT processed; see the TODO at the end of the
/// body. Callers must not rely on those values being removed.
void PatternRewriter::replaceOp(Operation *op, ArrayRef<Value *> newValues,
                                ArrayRef<Value *> valuesToRemoveIfDead) {
  // Notify the rewriter subclass that we're about to replace this root.
  notifyRootReplaced(op);

  // One replacement value is required for each result of `op`.
  assert(op->getNumResults() == newValues.size() &&
         "incorrect # of replacement values");
  op->replaceAllUsesWith(newValues);

  // Let the subclass observe `op` while it still exists, then destroy it.
  notifyOperationRemoved(op);
  op->erase();

  // TODO: Process the valuesToRemoveIfDead list, removing things and calling
  // the notifyOperationRemoved hook in the process.
}
   102  
   103  /// op and newOp are known to have the same number of results, replace the
   104  /// uses of op with uses of newOp
   105  void PatternRewriter::replaceOpWithResultsOfAnotherOp(
   106      Operation *op, Operation *newOp, ArrayRef<Value *> valuesToRemoveIfDead) {
   107    assert(op->getNumResults() == newOp->getNumResults() &&
   108           "replacement op doesn't match results of original op");
   109    if (op->getNumResults() == 1)
   110      return replaceOp(op, newOp->getResult(0), valuesToRemoveIfDead);
   111  
   112    SmallVector<Value *, 8> newResults(newOp->getResults().begin(),
   113                                       newOp->getResults().end());
   114    return replaceOp(op, newResults, valuesToRemoveIfDead);
   115  }
   116  
   117  /// Move the blocks that belong to "region" before the given position in
   118  /// another region.  The two regions must be different.  The caller is in
   119  /// charge to update create the operation transferring the control flow to the
   120  /// region and pass it the correct block arguments.
   121  void PatternRewriter::inlineRegionBefore(Region &region, Region &parent,
   122                                           Region::iterator before) {
   123    parent.getBlocks().splice(before, region.getBlocks());
   124  }
   125  void PatternRewriter::inlineRegionBefore(Region &region, Block *before) {
   126    inlineRegionBefore(region, *before->getParent(), before->getIterator());
   127  }
   128  
   129  /// This method is used as the final notification hook for patterns that end
   130  /// up modifying the pattern root in place, by changing its operands.  This is
   131  /// a minor efficiency win (it avoids creating a new operation and removing
   132  /// the old one) but also often allows simpler code in the client.
   133  ///
   134  /// The opsToRemoveIfDead list is an optional list of nodes that the rewriter
   135  /// should remove if they are dead at this point.
   136  ///
   137  void PatternRewriter::updatedRootInPlace(
   138      Operation *op, ArrayRef<Value *> valuesToRemoveIfDead) {
   139    // Notify the rewriter subclass that we're about to replace this root.
   140    notifyRootUpdated(op);
   141  
   142    // TODO: Process the valuesToRemoveIfDead list, removing things and calling
   143    // the notifyOperationRemoved hook in the process.
   144  }
   145  
   146  //===----------------------------------------------------------------------===//
   147  // PatternMatcher implementation
   148  //===----------------------------------------------------------------------===//
   149  
   150  RewritePatternMatcher::RewritePatternMatcher(
   151      const OwningRewritePatternList &patterns) {
   152    for (auto &pattern : patterns)
   153      this->patterns.push_back(pattern.get());
   154  
   155    // Sort the patterns by benefit to simplify the matching logic.
   156    std::stable_sort(this->patterns.begin(), this->patterns.end(),
   157                     [](RewritePattern *l, RewritePattern *r) {
   158                       return r->getBenefit() < l->getBenefit();
   159                     });
   160  }
   161  
   162  /// Try to match the given operation to a pattern and rewrite it.
   163  bool RewritePatternMatcher::matchAndRewrite(Operation *op,
   164                                              PatternRewriter &rewriter) {
   165    for (auto *pattern : patterns) {
   166      // Ignore patterns that are for the wrong root or are impossible to match.
   167      if (pattern->getRootKind() != op->getName() ||
   168          pattern->getBenefit().isImpossibleToMatch())
   169        continue;
   170  
   171      // Try to match and rewrite this pattern. The patterns are sorted by
   172      // benefit, so if we match we can immediately rewrite and return.
   173      if (pattern->matchAndRewrite(op, rewriter))
   174        return true;
   175    }
   176    return false;
   177  }