github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/test/lib/TestDialect/TestPatterns.cpp (about)

     1  //===- TestPatterns.cpp - Test dialect pattern driver ---------------------===//
     2  //
     3  // Copyright 2019 The MLIR Authors.
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //   http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  // =============================================================================
    17  
    18  #include "TestDialect.h"
    19  #include "mlir/IR/PatternMatch.h"
    20  #include "mlir/Pass/Pass.h"
    21  #include "mlir/Transforms/DialectConversion.h"
    22  using namespace mlir;
    23  
    24  // Native function for testing NativeCodeCall
    25  static Value *chooseOperand(Value *input1, Value *input2, BoolAttr choice) {
    26    return choice.getValue() ? input1 : input2;
    27  }
    28  
    29  namespace {
    30  #include "TestPatterns.inc"
    31  } // end anonymous namespace
    32  
    33  //===----------------------------------------------------------------------===//
    34  // Canonicalizer Driver.
    35  //===----------------------------------------------------------------------===//
    36  
    37  namespace {
    38  struct TestPatternDriver : public FunctionPass<TestPatternDriver> {
    39    void runOnFunction() override {
    40      mlir::OwningRewritePatternList patterns;
    41      populateWithGenerated(&getContext(), &patterns);
    42  
    43      // Verify named pattern is generated with expected name.
    44      patterns.insert<TestNamedPatternRule>(&getContext());
    45  
    46      applyPatternsGreedily(getFunction(), patterns);
    47    }
    48  };
    49  } // end anonymous namespace
    50  
    51  static mlir::PassRegistration<TestPatternDriver>
    52      pass("test-patterns", "Run test dialect patterns");
    53  
    54  //===----------------------------------------------------------------------===//
    55  // Legalization Driver.
    56  //===----------------------------------------------------------------------===//
    57  
    58  namespace {
    59  /// This pattern is a simple pattern that inlines the first region of a given
    60  /// operation into the parent region.
    61  struct TestRegionRewriteBlockMovement : public ConversionPattern {
    62    TestRegionRewriteBlockMovement(MLIRContext *ctx)
    63        : ConversionPattern("test.region", 1, ctx) {}
    64  
    65    PatternMatchResult
    66    matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
    67                    ConversionPatternRewriter &rewriter) const final {
    68      // Inline this region into the parent region.
    69      auto &parentRegion = *op->getParentRegion();
    70      rewriter.inlineRegionBefore(op->getRegion(0), parentRegion,
    71                                  parentRegion.end());
    72  
    73      // Drop this operation.
    74      rewriter.replaceOp(op, llvm::None);
    75      return matchSuccess();
    76    }
    77  };
    78  /// This pattern is a simple pattern that generates a region containing an
    79  /// illegal operation.
    80  struct TestRegionRewriteUndo : public RewritePattern {
    81    TestRegionRewriteUndo(MLIRContext *ctx)
    82        : RewritePattern("test.region_builder", 1, ctx) {}
    83  
    84    PatternMatchResult matchAndRewrite(Operation *op,
    85                                       PatternRewriter &rewriter) const final {
    86      // Create the region operation with an entry block containing arguments.
    87      OperationState newRegion(op->getLoc(), "test.region");
    88      newRegion.addRegion();
    89      auto *regionOp = rewriter.createOperation(newRegion);
    90      auto *entryBlock = rewriter.createBlock(&regionOp->getRegion(0));
    91      entryBlock->addArgument(rewriter.getIntegerType(64));
    92  
    93      // Add an explicitly illegal operation to ensure the conversion fails.
    94      rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32));
    95      rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value *>());
    96  
    97      // Drop this operation.
    98      rewriter.replaceOp(op, llvm::None);
    99      return matchSuccess();
   100    }
   101  };
   102  /// This pattern simply erases the given operation.
   103  struct TestDropOp : public ConversionPattern {
   104    TestDropOp(MLIRContext *ctx) : ConversionPattern("test.drop_op", 1, ctx) {}
   105    PatternMatchResult
   106    matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
   107                    ConversionPatternRewriter &rewriter) const final {
   108      rewriter.replaceOp(op, llvm::None);
   109      return matchSuccess();
   110    }
   111  };
   112  /// This pattern simply updates the operands of the given operation.
   113  struct TestPassthroughInvalidOp : public ConversionPattern {
   114    TestPassthroughInvalidOp(MLIRContext *ctx)
   115        : ConversionPattern("test.invalid", 1, ctx) {}
   116    PatternMatchResult
   117    matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
   118                    ConversionPatternRewriter &rewriter) const final {
   119      rewriter.replaceOpWithNewOp<TestValidOp>(op, llvm::None, operands,
   120                                               llvm::None);
   121      return matchSuccess();
   122    }
   123  };
   124  /// This pattern handles the case of a split return value.
   125  struct TestSplitReturnType : public ConversionPattern {
   126    TestSplitReturnType(MLIRContext *ctx)
   127        : ConversionPattern("test.return", 1, ctx) {}
   128    PatternMatchResult
   129    matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
   130                    ConversionPatternRewriter &rewriter) const final {
   131      // Check for a return of F32.
   132      if (op->getNumOperands() != 1 || !op->getOperand(0)->getType().isF32())
   133        return matchFailure();
   134  
   135      // Check if the first operation is a cast operation, if it is we use the
   136      // results directly.
   137      auto *defOp = operands[0]->getDefiningOp();
   138      if (auto packerOp = llvm::dyn_cast_or_null<TestCastOp>(defOp)) {
   139        SmallVector<Value *, 2> returnOperands(packerOp.getOperands());
   140        rewriter.replaceOpWithNewOp<TestReturnOp>(op, returnOperands);
   141        return matchSuccess();
   142      }
   143  
   144      // Otherwise, fail to match.
   145      return matchFailure();
   146    }
   147  };
   148  } // namespace
   149  
   150  namespace {
   151  struct TestTypeConverter : public TypeConverter {
   152    using TypeConverter::TypeConverter;
   153  
   154    LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) override {
   155      // Drop I16 types.
   156      if (t.isInteger(16))
   157        return success();
   158  
   159      // Convert I64 to F64.
   160      if (t.isInteger(64)) {
   161        results.push_back(FloatType::getF64(t.getContext()));
   162        return success();
   163      }
   164  
   165      // Split F32 into F16,F16.
   166      if (t.isF32()) {
   167        results.assign(2, FloatType::getF16(t.getContext()));
   168        return success();
   169      }
   170  
   171      // Otherwise, convert the type directly.
   172      results.push_back(t);
   173      return success();
   174    }
   175  
   176    /// Override the hook to materialize a conversion. This is necessary because
   177    /// we generate 1->N type mappings.
   178    Operation *materializeConversion(PatternRewriter &rewriter, Type resultType,
   179                                     ArrayRef<Value *> inputs,
   180                                     Location loc) override {
   181      return rewriter.create<TestCastOp>(loc, resultType, inputs);
   182    }
   183  };
   184  
   185  struct TestLegalizePatternDriver
   186      : public ModulePass<TestLegalizePatternDriver> {
   187    /// The mode of conversion to use with the driver.
   188    enum class ConversionMode { Analysis, Partial };
   189  
   190    TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {}
   191  
   192    void runOnModule() override {
   193      TestTypeConverter converter;
   194      mlir::OwningRewritePatternList patterns;
   195      populateWithGenerated(&getContext(), &patterns);
   196      patterns.insert<TestRegionRewriteBlockMovement, TestRegionRewriteUndo,
   197                      TestDropOp, TestPassthroughInvalidOp, TestSplitReturnType>(
   198          &getContext());
   199      mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(),
   200                                                converter);
   201  
   202      // Define the conversion target used for the test.
   203      ConversionTarget target(getContext());
   204      target.addLegalOp<LegalOpA, TestCastOp, TestValidOp>();
   205      target.addIllegalOp<ILLegalOpF, TestRegionBuilderOp>();
   206      target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) {
   207        // Don't allow F32 operands.
   208        return llvm::none_of(op.getOperandTypes(),
   209                             [](Type type) { return type.isF32(); });
   210      });
   211      target.addDynamicallyLegalOp<FuncOp>(
   212          [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
   213  
   214      // Handle a partial conversion.
   215      if (mode == ConversionMode::Partial) {
   216        (void)applyPartialConversion(getModule(), target, patterns, &converter);
   217        return;
   218      }
   219  
   220      // Otherwise, handle an analysis conversion.
   221      assert(mode == ConversionMode::Analysis);
   222  
   223      // Analyze the convertible operations.
   224      DenseSet<Operation *> legalizedOps;
   225      if (failed(applyAnalysisConversion(getModule(), target, patterns,
   226                                         legalizedOps, &converter)))
   227        return signalPassFailure();
   228  
   229      // Emit remarks for each legalizable operation.
   230      for (auto *op : legalizedOps)
   231        op->emitRemark() << "op '" << op->getName() << "' is legalizable";
   232    }
   233  
   234    /// The mode of conversion to use.
   235    ConversionMode mode;
   236  };
   237  } // end anonymous namespace
   238  
   239  static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode>
   240      legalizerConversionMode(
   241          "test-legalize-mode",
   242          llvm::cl::desc("The legalization mode to use with the test driver"),
   243          llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial),
   244          llvm::cl::values(
   245              clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis,
   246                         "analysis", "Perform an analysis conversion"),
   247              clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial,
   248                         "partial", "Perform a partial conversion")));
   249  
   250  static mlir::PassRegistration<TestLegalizePatternDriver>
   251      legalizer_pass("test-legalize-patterns",
   252                     "Run test dialect legalization patterns", [] {
   253                       return std::make_unique<TestLegalizePatternDriver>(
   254                           legalizerConversionMode);
   255                     });