github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/test/lib/TestDialect/TestPatterns.cpp (about) 1 //===- TestPatterns.cpp - Test dialect pattern driver ---------------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 18 #include "TestDialect.h" 19 #include "mlir/IR/PatternMatch.h" 20 #include "mlir/Pass/Pass.h" 21 #include "mlir/Transforms/DialectConversion.h" 22 using namespace mlir; 23 24 // Native function for testing NativeCodeCall 25 static Value *chooseOperand(Value *input1, Value *input2, BoolAttr choice) { 26 return choice.getValue() ? input1 : input2; 27 } 28 29 namespace { 30 #include "TestPatterns.inc" 31 } // end anonymous namespace 32 33 //===----------------------------------------------------------------------===// 34 // Canonicalizer Driver. 35 //===----------------------------------------------------------------------===// 36 37 namespace { 38 struct TestPatternDriver : public FunctionPass<TestPatternDriver> { 39 void runOnFunction() override { 40 mlir::OwningRewritePatternList patterns; 41 populateWithGenerated(&getContext(), &patterns); 42 43 // Verify named pattern is generated with expected name. 44 patterns.insert<TestNamedPatternRule>(&getContext()); 45 46 applyPatternsGreedily(getFunction(), patterns); 47 } 48 }; 49 } // end anonymous namespace 50 51 static mlir::PassRegistration<TestPatternDriver> 52 pass("test-patterns", "Run test dialect patterns"); 53 54 //===----------------------------------------------------------------------===// 55 // Legalization Driver. 56 //===----------------------------------------------------------------------===// 57 58 namespace { 59 /// This pattern is a simple pattern that inlines the first region of a given 60 /// operation into the parent region. 61 struct TestRegionRewriteBlockMovement : public ConversionPattern { 62 TestRegionRewriteBlockMovement(MLIRContext *ctx) 63 : ConversionPattern("test.region", 1, ctx) {} 64 65 PatternMatchResult 66 matchAndRewrite(Operation *op, ArrayRef<Value *> operands, 67 ConversionPatternRewriter &rewriter) const final { 68 // Inline this region into the parent region. 69 auto &parentRegion = *op->getParentRegion(); 70 rewriter.inlineRegionBefore(op->getRegion(0), parentRegion, 71 parentRegion.end()); 72 73 // Drop this operation. 74 rewriter.replaceOp(op, llvm::None); 75 return matchSuccess(); 76 } 77 }; 78 /// This pattern is a simple pattern that generates a region containing an 79 /// illegal operation. 80 struct TestRegionRewriteUndo : public RewritePattern { 81 TestRegionRewriteUndo(MLIRContext *ctx) 82 : RewritePattern("test.region_builder", 1, ctx) {} 83 84 PatternMatchResult matchAndRewrite(Operation *op, 85 PatternRewriter &rewriter) const final { 86 // Create the region operation with an entry block containing arguments. 87 OperationState newRegion(op->getLoc(), "test.region"); 88 newRegion.addRegion(); 89 auto *regionOp = rewriter.createOperation(newRegion); 90 auto *entryBlock = rewriter.createBlock(®ionOp->getRegion(0)); 91 entryBlock->addArgument(rewriter.getIntegerType(64)); 92 93 // Add an explicitly illegal operation to ensure the conversion fails. 94 rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32)); 95 rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value *>()); 96 97 // Drop this operation. 98 rewriter.replaceOp(op, llvm::None); 99 return matchSuccess(); 100 } 101 }; 102 /// This pattern simply erases the given operation. 103 struct TestDropOp : public ConversionPattern { 104 TestDropOp(MLIRContext *ctx) : ConversionPattern("test.drop_op", 1, ctx) {} 105 PatternMatchResult 106 matchAndRewrite(Operation *op, ArrayRef<Value *> operands, 107 ConversionPatternRewriter &rewriter) const final { 108 rewriter.replaceOp(op, llvm::None); 109 return matchSuccess(); 110 } 111 }; 112 /// This pattern simply updates the operands of the given operation. 113 struct TestPassthroughInvalidOp : public ConversionPattern { 114 TestPassthroughInvalidOp(MLIRContext *ctx) 115 : ConversionPattern("test.invalid", 1, ctx) {} 116 PatternMatchResult 117 matchAndRewrite(Operation *op, ArrayRef<Value *> operands, 118 ConversionPatternRewriter &rewriter) const final { 119 rewriter.replaceOpWithNewOp<TestValidOp>(op, llvm::None, operands, 120 llvm::None); 121 return matchSuccess(); 122 } 123 }; 124 /// This pattern handles the case of a split return value. 125 struct TestSplitReturnType : public ConversionPattern { 126 TestSplitReturnType(MLIRContext *ctx) 127 : ConversionPattern("test.return", 1, ctx) {} 128 PatternMatchResult 129 matchAndRewrite(Operation *op, ArrayRef<Value *> operands, 130 ConversionPatternRewriter &rewriter) const final { 131 // Check for a return of F32. 132 if (op->getNumOperands() != 1 || !op->getOperand(0)->getType().isF32()) 133 return matchFailure(); 134 135 // Check if the first operation is a cast operation, if it is we use the 136 // results directly. 137 auto *defOp = operands[0]->getDefiningOp(); 138 if (auto packerOp = llvm::dyn_cast_or_null<TestCastOp>(defOp)) { 139 SmallVector<Value *, 2> returnOperands(packerOp.getOperands()); 140 rewriter.replaceOpWithNewOp<TestReturnOp>(op, returnOperands); 141 return matchSuccess(); 142 } 143 144 // Otherwise, fail to match. 145 return matchFailure(); 146 } 147 }; 148 } // namespace 149 150 namespace { 151 struct TestTypeConverter : public TypeConverter { 152 using TypeConverter::TypeConverter; 153 154 LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) override { 155 // Drop I16 types. 156 if (t.isInteger(16)) 157 return success(); 158 159 // Convert I64 to F64. 160 if (t.isInteger(64)) { 161 results.push_back(FloatType::getF64(t.getContext())); 162 return success(); 163 } 164 165 // Split F32 into F16,F16. 166 if (t.isF32()) { 167 results.assign(2, FloatType::getF16(t.getContext())); 168 return success(); 169 } 170 171 // Otherwise, convert the type directly. 172 results.push_back(t); 173 return success(); 174 } 175 176 /// Override the hook to materialize a conversion. This is necessary because 177 /// we generate 1->N type mappings. 178 Operation *materializeConversion(PatternRewriter &rewriter, Type resultType, 179 ArrayRef<Value *> inputs, 180 Location loc) override { 181 return rewriter.create<TestCastOp>(loc, resultType, inputs); 182 } 183 }; 184 185 struct TestLegalizePatternDriver 186 : public ModulePass<TestLegalizePatternDriver> { 187 /// The mode of conversion to use with the driver. 188 enum class ConversionMode { Analysis, Partial }; 189 190 TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {} 191 192 void runOnModule() override { 193 TestTypeConverter converter; 194 mlir::OwningRewritePatternList patterns; 195 populateWithGenerated(&getContext(), &patterns); 196 patterns.insert<TestRegionRewriteBlockMovement, TestRegionRewriteUndo, 197 TestDropOp, TestPassthroughInvalidOp, TestSplitReturnType>( 198 &getContext()); 199 mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(), 200 converter); 201 202 // Define the conversion target used for the test. 203 ConversionTarget target(getContext()); 204 target.addLegalOp<LegalOpA, TestCastOp, TestValidOp>(); 205 target.addIllegalOp<ILLegalOpF, TestRegionBuilderOp>(); 206 target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) { 207 // Don't allow F32 operands. 208 return llvm::none_of(op.getOperandTypes(), 209 [](Type type) { return type.isF32(); }); 210 }); 211 target.addDynamicallyLegalOp<FuncOp>( 212 [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); 213 214 // Handle a partial conversion. 215 if (mode == ConversionMode::Partial) { 216 (void)applyPartialConversion(getModule(), target, patterns, &converter); 217 return; 218 } 219 220 // Otherwise, handle an analysis conversion. 221 assert(mode == ConversionMode::Analysis); 222 223 // Analyze the convertible operations. 224 DenseSet<Operation *> legalizedOps; 225 if (failed(applyAnalysisConversion(getModule(), target, patterns, 226 legalizedOps, &converter))) 227 return signalPassFailure(); 228 229 // Emit remarks for each legalizable operation. 230 for (auto *op : legalizedOps) 231 op->emitRemark() << "op '" << op->getName() << "' is legalizable"; 232 } 233 234 /// The mode of conversion to use. 235 ConversionMode mode; 236 }; 237 } // end anonymous namespace 238 239 static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode> 240 legalizerConversionMode( 241 "test-legalize-mode", 242 llvm::cl::desc("The legalization mode to use with the test driver"), 243 llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial), 244 llvm::cl::values( 245 clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis, 246 "analysis", "Perform an analysis conversion"), 247 clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial, 248 "partial", "Perform a partial conversion"))); 249 250 static mlir::PassRegistration<TestLegalizePatternDriver> 251 legalizer_pass("test-legalize-patterns", 252 "Run test dialect legalization patterns", [] { 253 return std::make_unique<TestLegalizePatternDriver>( 254 legalizerConversionMode); 255 });