github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/IR/PatternMatch.cpp (about) 1 //===- PatternMatch.cpp - Base classes for pattern match ------------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 18 #include "mlir/IR/PatternMatch.h" 19 #include "mlir/IR/Operation.h" 20 #include "mlir/IR/Value.h" 21 using namespace mlir; 22 23 PatternBenefit::PatternBenefit(unsigned benefit) : representation(benefit) { 24 assert(representation == benefit && benefit != ImpossibleToMatchSentinel && 25 "This pattern match benefit is too large to represent"); 26 } 27 28 unsigned short PatternBenefit::getBenefit() const { 29 assert(representation != ImpossibleToMatchSentinel && 30 "Pattern doesn't match"); 31 return representation; 32 } 33 34 //===----------------------------------------------------------------------===// 35 // Pattern implementation 36 //===----------------------------------------------------------------------===// 37 38 Pattern::Pattern(StringRef rootName, PatternBenefit benefit, 39 MLIRContext *context) 40 : rootKind(OperationName(rootName, context)), benefit(benefit) {} 41 42 // Out-of-line vtable anchor. 43 void Pattern::anchor() {} 44 45 //===----------------------------------------------------------------------===// 46 // RewritePattern and PatternRewriter implementation 47 //===----------------------------------------------------------------------===// 48 49 void RewritePattern::rewrite(Operation *op, std::unique_ptr<PatternState> state, 50 PatternRewriter &rewriter) const { 51 rewrite(op, rewriter); 52 } 53 54 void RewritePattern::rewrite(Operation *op, PatternRewriter &rewriter) const { 55 llvm_unreachable("need to implement either matchAndRewrite or one of the " 56 "rewrite functions!"); 57 } 58 59 PatternMatchResult RewritePattern::match(Operation *op) const { 60 llvm_unreachable("need to implement either match or matchAndRewrite!"); 61 } 62 63 /// Patterns must specify the root operation name they match against, and can 64 /// also specify the benefit of the pattern matching. They can also specify the 65 /// names of operations that may be generated during a successful rewrite. 66 RewritePattern::RewritePattern(StringRef rootName, 67 ArrayRef<StringRef> generatedNames, 68 PatternBenefit benefit, MLIRContext *context) 69 : Pattern(rootName, benefit, context) { 70 generatedOps.reserve(generatedNames.size()); 71 std::transform(generatedNames.begin(), generatedNames.end(), 72 std::back_inserter(generatedOps), [context](StringRef name) { 73 return OperationName(name, context); 74 }); 75 } 76 77 PatternRewriter::~PatternRewriter() { 78 // Out of line to provide a vtable anchor for the class. 79 } 80 81 /// This method performs the final replacement for a pattern, where the 82 /// results of the operation are updated to use the specified list of SSA 83 /// values. In addition to replacing and removing the specified operation, 84 /// clients can specify a list of other nodes that this replacement may make 85 /// (perhaps transitively) dead. If any of those ops are dead, this will 86 /// remove them as well. 87 void PatternRewriter::replaceOp(Operation *op, ArrayRef<Value *> newValues, 88 ArrayRef<Value *> valuesToRemoveIfDead) { 89 // Notify the rewriter subclass that we're about to replace this root. 90 notifyRootReplaced(op); 91 92 assert(op->getNumResults() == newValues.size() && 93 "incorrect # of replacement values"); 94 op->replaceAllUsesWith(newValues); 95 96 notifyOperationRemoved(op); 97 op->erase(); 98 99 // TODO: Process the valuesToRemoveIfDead list, removing things and calling 100 // the notifyOperationRemoved hook in the process. 101 } 102 103 /// op and newOp are known to have the same number of results, replace the 104 /// uses of op with uses of newOp 105 void PatternRewriter::replaceOpWithResultsOfAnotherOp( 106 Operation *op, Operation *newOp, ArrayRef<Value *> valuesToRemoveIfDead) { 107 assert(op->getNumResults() == newOp->getNumResults() && 108 "replacement op doesn't match results of original op"); 109 if (op->getNumResults() == 1) 110 return replaceOp(op, newOp->getResult(0), valuesToRemoveIfDead); 111 112 SmallVector<Value *, 8> newResults(newOp->getResults().begin(), 113 newOp->getResults().end()); 114 return replaceOp(op, newResults, valuesToRemoveIfDead); 115 } 116 117 /// Move the blocks that belong to "region" before the given position in 118 /// another region. The two regions must be different. The caller is in 119 /// charge to update create the operation transferring the control flow to the 120 /// region and pass it the correct block arguments. 121 void PatternRewriter::inlineRegionBefore(Region ®ion, Region &parent, 122 Region::iterator before) { 123 parent.getBlocks().splice(before, region.getBlocks()); 124 } 125 void PatternRewriter::inlineRegionBefore(Region ®ion, Block *before) { 126 inlineRegionBefore(region, *before->getParent(), before->getIterator()); 127 } 128 129 /// This method is used as the final notification hook for patterns that end 130 /// up modifying the pattern root in place, by changing its operands. This is 131 /// a minor efficiency win (it avoids creating a new operation and removing 132 /// the old one) but also often allows simpler code in the client. 133 /// 134 /// The opsToRemoveIfDead list is an optional list of nodes that the rewriter 135 /// should remove if they are dead at this point. 136 /// 137 void PatternRewriter::updatedRootInPlace( 138 Operation *op, ArrayRef<Value *> valuesToRemoveIfDead) { 139 // Notify the rewriter subclass that we're about to replace this root. 140 notifyRootUpdated(op); 141 142 // TODO: Process the valuesToRemoveIfDead list, removing things and calling 143 // the notifyOperationRemoved hook in the process. 144 } 145 146 //===----------------------------------------------------------------------===// 147 // PatternMatcher implementation 148 //===----------------------------------------------------------------------===// 149 150 RewritePatternMatcher::RewritePatternMatcher( 151 const OwningRewritePatternList &patterns) { 152 for (auto &pattern : patterns) 153 this->patterns.push_back(pattern.get()); 154 155 // Sort the patterns by benefit to simplify the matching logic. 156 std::stable_sort(this->patterns.begin(), this->patterns.end(), 157 [](RewritePattern *l, RewritePattern *r) { 158 return r->getBenefit() < l->getBenefit(); 159 }); 160 } 161 162 /// Try to match the given operation to a pattern and rewrite it. 163 bool RewritePatternMatcher::matchAndRewrite(Operation *op, 164 PatternRewriter &rewriter) { 165 for (auto *pattern : patterns) { 166 // Ignore patterns that are for the wrong root or are impossible to match. 167 if (pattern->getRootKind() != op->getName() || 168 pattern->getBenefit().isImpossibleToMatch()) 169 continue; 170 171 // Try to match and rewrite this pattern. The patterns are sorted by 172 // benefit, so if we match we can immediately rewrite and return. 173 if (pattern->matchAndRewrite(op, rewriter)) 174 return true; 175 } 176 return false; 177 }