github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Transforms/DialectConversion.cpp (about)

     1  //===- DialectConversion.cpp - MLIR dialect conversion generic pass -------===//
     2  //
     3  // Copyright 2019 The MLIR Authors.
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //   http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  // =============================================================================
    17  
    18  #include "mlir/Transforms/DialectConversion.h"
    19  #include "mlir/IR/Block.h"
    20  #include "mlir/IR/BlockAndValueMapping.h"
    21  #include "mlir/IR/Builders.h"
    22  #include "mlir/IR/Function.h"
    23  #include "mlir/IR/Module.h"
    24  #include "mlir/Transforms/Utils.h"
    25  #include "llvm/ADT/SetVector.h"
    26  #include "llvm/ADT/SmallPtrSet.h"
    27  #include "llvm/Support/Debug.h"
    28  #include "llvm/Support/raw_ostream.h"
    29  
    30  using namespace mlir;
    31  using namespace mlir::detail;
    32  
    33  #define DEBUG_TYPE "dialect-conversion"
    34  
    35  //===----------------------------------------------------------------------===//
    36  // ArgConverter
    37  //===----------------------------------------------------------------------===//
    38  namespace {
    39  /// This class provides a simple interface for converting the types of block
    40  /// arguments. This is done by inserting fake cast operations that map from the
    41  /// illegal type to the original type to allow for undoing pending rewrites in
    42  /// the case of failure.
    43  struct ArgConverter {
    44    ArgConverter(TypeConverter *typeConverter, PatternRewriter &rewriter)
    45        : castOpName(kCastName, rewriter.getContext()),
    46          loc(rewriter.getUnknownLoc()), typeConverter(typeConverter),
    47          rewriter(rewriter) {}
    48  
    49    /// Erase any rewrites registered for arguments to blocks within the given
    50    /// region. This function is called when the given region is to be destroyed.
    51    void cancelPendingRewrites(Block *block);
    52  
    53    /// Cleanup and undo any generated conversions for the arguments of block.
    54    /// This method differs from 'cancelPendingRewrites' in that it returns the
    55    /// block signature to its original state.
    56    void discardPendingRewrites(Block *block);
    57  
    58    /// Replace usages of the cast operations with the argument directly.
    59    void applyRewrites();
    60  
    61    /// Return if the signature of the given block has already been converted.
    62    bool hasBeenConverted(Block *block) const { return argMapping.count(block); }
    63  
    64    /// Attempt to convert the signature of the given block.
    65    LogicalResult convertSignature(Block *block, BlockAndValueMapping &mapping);
    66  
    67    /// Apply the given signature conversion on the given block.
    68    void applySignatureConversion(
    69        Block *block, TypeConverter::SignatureConversion &signatureConversion,
    70        BlockAndValueMapping &mapping);
    71  
    72    /// Convert the given block argument given the provided set of new argument
    73    /// values that are to replace it. This function returns the operation used
    74    /// to perform the conversion.
    75    Operation *convertArgument(BlockArgument *origArg,
    76                               ArrayRef<Value *> newValues,
    77                               BlockAndValueMapping &mapping);
    78  
    79    /// A utility function used to create a conversion cast operation with the
    80    /// given input and result types.
    81    Operation *createCast(ArrayRef<Value *> inputs, Type outputType);
    82  
    83    /// This is an operation name for a fake operation that is inserted during the
    84    /// conversion process. Operations of this type are guaranteed to never escape
    85    /// the converter.
    86    static constexpr StringLiteral kCastName = "__mlir_conversion.cast";
    87    OperationName castOpName;
    88  
    89    /// This is a collection of cast operations that were generated during the
    90    /// conversion process when converting the types of block arguments.
    91    llvm::MapVector<Block *, SmallVector<Operation *, 4>> argMapping;
    92  
    93    /// An instance of the unknown location that is used when generating
    94    /// producers.
    95    Location loc;
    96  
    97    /// The type converter to use when changing types.
    98    TypeConverter *typeConverter;
    99  
   100    /// The pattern rewriter to use when materializing conversions.
   101    PatternRewriter &rewriter;
   102  };
   103  } // end anonymous namespace
   104  
   105  constexpr StringLiteral ArgConverter::kCastName;
   106  
   107  /// Erase any rewrites registered for arguments to the given block.
   108  void ArgConverter::cancelPendingRewrites(Block *block) {
   109    auto it = argMapping.find(block);
   110    if (it == argMapping.end())
   111      return;
   112    for (auto *op : it->second) {
   113      op->dropAllDefinedValueUses();
   114      op->erase();
   115    }
   116    argMapping.erase(it);
   117  }
   118  
   119  /// Cleanup and undo any generated conversions for the arguments of block.
   120  /// This method differs from 'cancelPendingRewrites' in that it returns the
   121  /// block signature to its original state.
   122  void ArgConverter::discardPendingRewrites(Block *block) {
   123    auto it = argMapping.find(block);
   124    if (it == argMapping.end())
   125      return;
   126  
   127    // Erase all of the new arguments.
   128    for (int i = block->getNumArguments() - 1; i >= 0; --i) {
   129      block->getArgument(i)->dropAllUses();
   130      block->eraseArgument(i, /*updatePredTerms=*/false);
   131    }
   132  
   133    // Re-instate the old arguments.
   134    auto &mapping = it->second;
   135    for (unsigned i = 0, e = mapping.size(); i != e; ++i) {
   136      auto *op = mapping[i];
   137      auto *arg = block->addArgument(op->getResult(0)->getType());
   138      op->getResult(0)->replaceAllUsesWith(arg);
   139  
   140      // If this operation is within a block, it will be cleaned up automatically.
   141      if (!op->getBlock())
   142        op->erase();
   143    }
   144    argMapping.erase(it);
   145  }
   146  
   147  /// Replace usages of the cast operations with the argument directly.
   148  void ArgConverter::applyRewrites() {
   149    Block *block;
   150    ArrayRef<Operation *> argOps;
   151    for (auto &mapping : argMapping) {
   152      std::tie(block, argOps) = mapping;
   153  
   154      // Process the remapping for each of the original arguments.
   155      for (unsigned i = 0, e = argOps.size(); i != e; ++i) {
   156        auto *op = argOps[i];
   157  
   158        // Handle the case of a 1->N value mapping.
   159        if (op->getNumOperands() > 1) {
   160          // If all of the uses were removed, we can drop this op. Otherwise,
   161          // keep the operation alive and let the user handle any remaining
   162          // usages.
   163          if (op->use_empty())
   164            op->erase();
   165          continue;
   166        }
   167  
   168        // If mapping is 1-1, replace the remaining uses and drop the cast
   169        // operation.
   170        // FIXME(riverriddle) This should check that the result type and operand
   171        // type are the same, otherwise it should force a conversion to be
   172        // materialized. This works around a current limitation with regards to
   173        // region entry argument type conversion.
   174        if (op->getNumOperands() == 1) {
   175          op->getResult(0)->replaceAllUsesWith(op->getOperand(0));
   176          op->destroy();
   177          continue;
   178        }
   179  
   180        // Otherwise, if there are any dangling uses then replace the fake
   181        // conversion operation with one generated by the type converter. This
   182        // is necessary as the cast must persist in the IR after conversion.
   183        auto *opResult = op->getResult(0);
   184        if (!opResult->use_empty()) {
   185          rewriter.setInsertionPointToStart(block);
   186          SmallVector<Value *, 1> operands(op->getOperands());
   187          auto *newOp = typeConverter->materializeConversion(
   188              rewriter, opResult->getType(), operands, op->getLoc());
   189          opResult->replaceAllUsesWith(newOp->getResult(0));
   190        }
   191        op->destroy();
   192      }
   193    }
   194  }
   195  
   196  /// Converts the signature of the given entry block.
   197  LogicalResult ArgConverter::convertSignature(Block *block,
   198                                               BlockAndValueMapping &mapping) {
   199    if (auto conversion = typeConverter->convertBlockSignature(block))
   200      return applySignatureConversion(block, *conversion, mapping), success();
   201    return failure();
   202  }
   203  
   204  /// Apply the given signature conversion on the given block.
   205  void ArgConverter::applySignatureConversion(
   206      Block *block, TypeConverter::SignatureConversion &signatureConversion,
   207      BlockAndValueMapping &mapping) {
   208    unsigned origArgCount = block->getNumArguments();
   209    auto convertedTypes = signatureConversion.getConvertedTypes();
   210    if (origArgCount == 0 && convertedTypes.empty())
   211      return;
   212  
   213    SmallVector<Value *, 4> newArgRange(block->addArguments(convertedTypes));
   214    ArrayRef<Value *> newArgRef(newArgRange);
   215  
   216    // Remap each of the original arguments as determined by the signature
   217    // conversion.
   218    auto &newArgMapping = argMapping[block];
   219    rewriter.setInsertionPointToStart(block);
   220    for (unsigned i = 0; i != origArgCount; ++i) {
   221      ArrayRef<Value *> remappedValues;
   222      if (auto inputMap = signatureConversion.getInputMapping(i))
   223        remappedValues = newArgRef.slice(inputMap->inputNo, inputMap->size);
   224  
   225      BlockArgument *arg = block->getArgument(i);
   226      newArgMapping.push_back(convertArgument(arg, remappedValues, mapping));
   227    }
   228  
   229    // Erase all of the original arguments.
   230    for (unsigned i = 0; i != origArgCount; ++i)
   231      block->eraseArgument(0, /*updatePredTerms=*/false);
   232  }
   233  
   234  /// Convert the given block argument given the provided set of new argument
   235  /// values that are to replace it. This function returns the operation used
   236  /// to perform the conversion.
   237  Operation *ArgConverter::convertArgument(BlockArgument *origArg,
   238                                           ArrayRef<Value *> newValues,
   239                                           BlockAndValueMapping &mapping) {
   240    // Handle the cases of 1->0 or 1->1 mappings.
   241    if (newValues.size() < 2) {
   242      // Create a temporary producer for the argument during the conversion
   243      // process.
   244      auto *cast = createCast(newValues, origArg->getType());
   245      origArg->replaceAllUsesWith(cast->getResult(0));
   246  
   247      // Insert a mapping between this argument and the one that is replacing
   248      // it.
   249      if (!newValues.empty())
   250        mapping.map(cast->getResult(0), newValues[0]);
   251      return cast;
   252    }
   253  
   254    // Otherwise, this is a 1->N mapping. Call into the provided type converter
   255    // to pack the new values.
   256    auto *cast = typeConverter->materializeConversion(
   257        rewriter, origArg->getType(), newValues, loc);
   258    assert(cast->getNumResults() == 1 &&
   259           cast->getNumOperands() == newValues.size());
   260    origArg->replaceAllUsesWith(cast->getResult(0));
   261    return cast;
   262  }
   263  
   264  /// A utility function used to create a conversion cast operation with the
   265  /// given input and result types.
   266  Operation *ArgConverter::createCast(ArrayRef<Value *> inputs, Type outputType) {
   267    return Operation::create(loc, castOpName, inputs, outputType, llvm::None,
   268                             llvm::None, 0, false);
   269  }
   270  
   271  //===----------------------------------------------------------------------===//
   272  // ConversionPatternRewriterImpl
   273  //===----------------------------------------------------------------------===//
   274  namespace {
   275  /// This class contains a snapshot of the current conversion rewriter state.
   276  /// This is useful when saving and undoing a set of rewrites.
   277  struct RewriterState {
   278    RewriterState(unsigned numCreatedOperations, unsigned numReplacements,
   279                  unsigned numBlockActions)
   280        : numCreatedOperations(numCreatedOperations),
   281          numReplacements(numReplacements), numBlockActions(numBlockActions) {}
   282  
   283    /// The current number of created operations.
   284    unsigned numCreatedOperations;
   285  
   286    /// The current number of replacements queued.
   287    unsigned numReplacements;
   288  
   289    /// The current number of block actions performed.
   290    unsigned numBlockActions;
   291  };
   292  } // end anonymous namespace
   293  
   294  namespace mlir {
   295  namespace detail {
   296  struct ConversionPatternRewriterImpl {
   297    /// This class represents one requested operation replacement via 'replaceOp'.
   298    struct OpReplacement {
   299      OpReplacement() = default;
   300      OpReplacement(Operation *op, ArrayRef<Value *> newValues)
   301          : op(op), newValues(newValues.begin(), newValues.end()) {}
   302  
   303      Operation *op;
   304      SmallVector<Value *, 2> newValues;
   305    };
   306  
   307    /// The kind of the block action performed during the rewrite.  Actions can be
   308    /// undone if the conversion fails.
   309    enum class BlockActionKind { Split, Move, TypeConversion };
   310  
   311    /// Original position of the given block in its parent region.  We cannot use
   312    /// a region iterator because it could have been invalidated by other region
   313    /// operations since the position was stored.
   314    struct BlockPosition {
   315      Region *region;
   316      Region::iterator::difference_type position;
   317    };
   318  
   319    /// The storage class for an undoable block action (one of BlockActionKind),
   320    /// contains the information necessary to undo this action.
   321    struct BlockAction {
   322      static BlockAction getSplit(Block *block, Block *originalBlock) {
   323        BlockAction action{BlockActionKind::Split, block, {}};
   324        action.originalBlock = originalBlock;
   325        return action;
   326      }
   327      static BlockAction getMove(Block *block, BlockPosition originalPos) {
   328        return {BlockActionKind::Move, block, {originalPos}};
   329      }
   330      static BlockAction getTypeConversion(Block *block) {
   331        return BlockAction{BlockActionKind::TypeConversion, block, {}};
   332      }
   333  
   334      // The action kind.
   335      BlockActionKind kind;
   336  
   337      // A pointer to the block that was created by the action.
   338      Block *block;
   339  
   340      union {
   341        // In use if kind == BlockActionKind::Move and contains a pointer to the
   342        // region that originally contained the block as well as the position of
   343        // the block in that region.
   344        BlockPosition originalPosition;
   345        // In use if kind == BlockActionKind::Split and contains a pointer to the
   346        // block that was split into two parts.
   347        Block *originalBlock;
   348      };
   349    };
   350  
   351    ConversionPatternRewriterImpl(PatternRewriter &rewriter,
   352                                  TypeConverter *converter)
   353        : argConverter(converter, rewriter) {}
   354  
   355    /// Return the current state of the rewriter.
   356    RewriterState getCurrentState();
   357  
   358    /// Reset the state of the rewriter to a previously saved point.
   359    void resetState(RewriterState state);
   360  
   361    /// Undo the block actions (motions, splits) one by one in reverse order until
   362    /// "numActionsToKeep" actions remains.
   363    void undoBlockActions(unsigned numActionsToKeep = 0);
   364  
   365    /// Cleanup and destroy any generated rewrite operations. This method is
   366    /// invoked when the conversion process fails.
   367    void discardRewrites();
   368  
   369    /// Apply all requested operation rewrites. This method is invoked when the
   370    /// conversion process succeeds.
   371    void applyRewrites();
   372  
   373    /// Convert the signature of the given block.
   374    LogicalResult convertBlockSignature(Block *block);
   375  
   376    /// Apply a signature conversion on the given region.
   377    void applySignatureConversion(Region *region,
   378                                  TypeConverter::SignatureConversion &conversion);
   379  
   380    /// PatternRewriter hook for replacing the results of an operation.
   381    void replaceOp(Operation *op, ArrayRef<Value *> newValues,
   382                   ArrayRef<Value *> valuesToRemoveIfDead);
   383  
   384    /// Notifies that a block was split.
   385    void notifySplitBlock(Block *block, Block *continuation);
   386  
   387    /// Notifies that the blocks of a region are about to be moved.
   388    void notifyRegionIsBeingInlinedBefore(Region &region, Region &parent,
   389                                          Region::iterator before);
   390  
   391    /// Remap the given operands to those with potentially different types.
   392    void remapValues(Operation::operand_range operands,
   393                     SmallVectorImpl<Value *> &remapped);
   394  
   395    // Mapping between replaced values that differ in type. This happens when
   396    // replacing a value with one of a different type.
   397    BlockAndValueMapping mapping;
   398  
   399    /// Utility used to convert block arguments.
   400    ArgConverter argConverter;
   401  
   402    /// Ordered vector of all of the newly created operations during conversion.
   403    SmallVector<Operation *, 4> createdOps;
   404  
   405    /// Ordered vector of any requested operation replacements.
   406    SmallVector<OpReplacement, 4> replacements;
   407  
   408    /// Ordered list of block operations (creations, splits, motions).
   409    SmallVector<BlockAction, 4> blockActions;
   410  };
   411  } // end namespace detail
   412  } // end namespace mlir
   413  
   414  RewriterState ConversionPatternRewriterImpl::getCurrentState() {
   415    return RewriterState(createdOps.size(), replacements.size(),
   416                         blockActions.size());
   417  }
   418  
   419  void ConversionPatternRewriterImpl::resetState(RewriterState state) {
   420    // Undo any block actions.
   421    undoBlockActions(state.numBlockActions);
   422  
   423    // Reset any replaced operations and undo any saved mappings.
   424    for (auto &repl : llvm::drop_begin(replacements, state.numReplacements))
   425      for (auto *result : repl.op->getResults())
   426        mapping.erase(result);
   427    replacements.resize(state.numReplacements);
   428  
   429    // Pop all of the newly created operations.
   430    while (createdOps.size() != state.numCreatedOperations)
   431      createdOps.pop_back_val()->erase();
   432  }
   433  
   434  void ConversionPatternRewriterImpl::undoBlockActions(
   435      unsigned numActionsToKeep) {
   436    for (auto &action :
   437         llvm::reverse(llvm::drop_begin(blockActions, numActionsToKeep))) {
   438      switch (action.kind) {
   439      // Merge back the block that was split out.
   440      case BlockActionKind::Split: {
   441        action.originalBlock->getOperations().splice(
   442            action.originalBlock->end(), action.block->getOperations());
   443        action.block->erase();
   444        break;
   445      }
   446      // Move the block back to its original position.
   447      case BlockActionKind::Move: {
   448        Region *originalRegion = action.originalPosition.region;
   449        originalRegion->getBlocks().splice(
   450            std::next(originalRegion->begin(), action.originalPosition.position),
   451            action.block->getParent()->getBlocks(), action.block);
   452        break;
   453      }
   454      // Undo the type conversion.
   455      case BlockActionKind::TypeConversion: {
   456        argConverter.discardPendingRewrites(action.block);
   457        break;
   458      }
   459      }
   460    }
   461    blockActions.resize(numActionsToKeep);
   462  }
   463  
   464  void ConversionPatternRewriterImpl::discardRewrites() {
   465    undoBlockActions();
   466  
   467    // Remove any newly created ops.
   468    for (auto *op : createdOps) {
   469      op->dropAllDefinedValueUses();
   470      op->erase();
   471    }
   472  }
   473  
   474  void ConversionPatternRewriterImpl::applyRewrites() {
   475    // Apply all of the rewrites replacements requested during conversion.
   476    for (auto &repl : replacements) {
   477      for (unsigned i = 0, e = repl.newValues.size(); i != e; ++i)
   478        repl.op->getResult(i)->replaceAllUsesWith(
   479            mapping.lookupOrDefault(repl.newValues[i]));
   480  
   481      // If this operation defines any regions, drop any pending argument
   482      // rewrites.
   483      if (argConverter.typeConverter && repl.op->getNumRegions()) {
   484        for (auto &region : repl.op->getRegions())
   485          for (auto &block : region)
   486            argConverter.cancelPendingRewrites(&block);
   487      }
   488    }
   489  
   490    // In a second pass, erase all of the replaced operations in reverse. This
   491    // allows processing nested operations before their parent region is
   492    // destroyed.
   493    for (auto &repl : llvm::reverse(replacements))
   494      repl.op->erase();
   495  
   496    argConverter.applyRewrites();
   497  }
   498  
   499  LogicalResult
   500  ConversionPatternRewriterImpl::convertBlockSignature(Block *block) {
   501    // Check to see if this block should not be converted:
   502    // * There is no type converter.
   503    // * The block has already been converted.
   504    // * This is an entry block, these are converted explicitly via patterns.
   505    if (!argConverter.typeConverter || argConverter.hasBeenConverted(block) ||
   506        block->isEntryBlock())
   507      return success();
   508  
   509    // Otherwise, try to convert the block signature.
   510    if (failed(argConverter.convertSignature(block, mapping)))
   511      return failure();
   512    blockActions.push_back(BlockAction::getTypeConversion(block));
   513    return success();
   514  }
   515  
   516  void ConversionPatternRewriterImpl::applySignatureConversion(
   517      Region *region, TypeConverter::SignatureConversion &conversion) {
   518    if (!region->empty()) {
   519      argConverter.applySignatureConversion(&region->front(), conversion,
   520                                            mapping);
   521      blockActions.push_back(BlockAction::getTypeConversion(&region->front()));
   522    }
   523  }
   524  
   525  void ConversionPatternRewriterImpl::replaceOp(
   526      Operation *op, ArrayRef<Value *> newValues,
   527      ArrayRef<Value *> valuesToRemoveIfDead) {
   528    assert(newValues.size() == op->getNumResults());
   529  
   530    // Create mappings for each of the new result values.
   531    for (unsigned i = 0, e = newValues.size(); i < e; ++i) {
   532      assert((newValues[i] || op->getResult(i)->use_empty()) &&
   533             "result value has remaining uses that must be replaced");
   534      if (newValues[i])
   535        mapping.map(op->getResult(i), newValues[i]);
   536    }
   537  
   538    // Record the requested operation replacement.
   539    replacements.emplace_back(op, newValues);
   540  }
   541  
   542  void ConversionPatternRewriterImpl::notifySplitBlock(Block *block,
   543                                                       Block *continuation) {
   544    blockActions.push_back(BlockAction::getSplit(continuation, block));
   545  }
   546  
   547  void ConversionPatternRewriterImpl::notifyRegionIsBeingInlinedBefore(
   548      Region &region, Region &parent, Region::iterator before) {
   549    for (auto &pair : llvm::enumerate(region)) {
   550      Block &block = pair.value();
   551      unsigned position = pair.index();
   552      blockActions.push_back(BlockAction::getMove(&block, {&region, position}));
   553    }
   554  }
   555  
   556  void ConversionPatternRewriterImpl::remapValues(
   557      Operation::operand_range operands, SmallVectorImpl<Value *> &remapped) {
   558    remapped.reserve(llvm::size(operands));
   559    for (Value *operand : operands)
   560      remapped.push_back(mapping.lookupOrDefault(operand));
   561  }
   562  
   563  //===----------------------------------------------------------------------===//
   564  // ConversionPatternRewriter
   565  //===----------------------------------------------------------------------===//
   566  
   567  ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx,
   568                                                       TypeConverter *converter)
   569      : PatternRewriter(ctx),
   570        impl(new detail::ConversionPatternRewriterImpl(*this, converter)) {}
   571  ConversionPatternRewriter::~ConversionPatternRewriter() {}
   572  
   573  /// PatternRewriter hook for replacing the results of an operation.
   574  void ConversionPatternRewriter::replaceOp(
   575      Operation *op, ArrayRef<Value *> newValues,
   576      ArrayRef<Value *> valuesToRemoveIfDead) {
   577    impl->replaceOp(op, newValues, valuesToRemoveIfDead);
   578  }
   579  
   580  /// Apply a signature conversion to the entry block of the given region.
   581  void ConversionPatternRewriter::applySignatureConversion(
   582      Region *region, TypeConverter::SignatureConversion &conversion) {
   583    impl->applySignatureConversion(region, conversion);
   584  }
   585  
   586  /// Clone the given operation without cloning its regions.
   587  Operation *ConversionPatternRewriter::cloneWithoutRegions(Operation *op) {
   588    Operation *newOp = OpBuilder::cloneWithoutRegions(*op);
   589    impl->createdOps.push_back(newOp);
   590    return newOp;
   591  }
   592  
   593  /// PatternRewriter hook for splitting a block into two parts.
   594  Block *ConversionPatternRewriter::splitBlock(Block *block,
   595                                               Block::iterator before) {
   596    auto *continuation = PatternRewriter::splitBlock(block, before);
   597    impl->notifySplitBlock(block, continuation);
   598    return continuation;
   599  }
   600  
   601  /// PatternRewriter hook for moving blocks out of a region.
   602  void ConversionPatternRewriter::inlineRegionBefore(Region &region,
   603                                                     Region &parent,
   604                                                     Region::iterator before) {
   605    impl->notifyRegionIsBeingInlinedBefore(region, parent, before);
   606    PatternRewriter::inlineRegionBefore(region, parent, before);
   607  }
   608  
   609  /// PatternRewriter hook for creating a new operation.
   610  Operation *
   611  ConversionPatternRewriter::createOperation(const OperationState &state) {
   612    auto *result = OpBuilder::createOperation(state);
   613    impl->createdOps.push_back(result);
   614    return result;
   615  }
   616  
   617  /// PatternRewriter hook for updating the root operation in-place.
   618  void ConversionPatternRewriter::notifyRootUpdated(Operation *op) {
   619    // The rewriter caches changes to the IR to allow for operating in-place and
   620    // backtracking. The rewriter is currently not capable of backtracking
   621    // in-place modifications.
   622    llvm_unreachable("in-place operation updates are not supported");
   623  }
   624  
   625  /// Return a reference to the internal implementation.
   626  detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
   627    return *impl;
   628  }
   629  
   630  //===----------------------------------------------------------------------===//
   631  // Conversion Patterns
   632  //===----------------------------------------------------------------------===//
   633  
   634  /// Attempt to match and rewrite the IR root at the specified operation.
   635  PatternMatchResult
   636  ConversionPattern::matchAndRewrite(Operation *op,
   637                                     PatternRewriter &rewriter) const {
   638    SmallVector<Value *, 4> operands;
   639    auto &dialectRewriter = static_cast<ConversionPatternRewriter &>(rewriter);
   640    dialectRewriter.getImpl().remapValues(op->getOperands(), operands);
   641  
   642    // If this operation has no successors, invoke the rewrite directly.
   643    if (op->getNumSuccessors() == 0)
   644      return matchAndRewrite(op, operands, dialectRewriter);
   645  
   646    // Otherwise, we need to remap the successors.
   647    SmallVector<Block *, 2> destinations;
   648    destinations.reserve(op->getNumSuccessors());
   649  
   650    SmallVector<ArrayRef<Value *>, 2> operandsPerDestination;
   651    unsigned firstSuccessorOperand = op->getSuccessorOperandIndex(0);
   652    for (unsigned i = 0, seen = 0, e = op->getNumSuccessors(); i < e; ++i) {
   653      destinations.push_back(op->getSuccessor(i));
   654  
   655      // Lookup the successors operands.
   656      unsigned n = op->getNumSuccessorOperands(i);
   657      operandsPerDestination.push_back(
   658          llvm::makeArrayRef(operands.data() + firstSuccessorOperand + seen, n));
   659      seen += n;
   660    }
   661  
   662    // Rewrite the operation.
   663    return matchAndRewrite(
   664        op,
   665        llvm::makeArrayRef(operands.data(),
   666                           operands.data() + firstSuccessorOperand),
   667        destinations, operandsPerDestination, dialectRewriter);
   668  }
   669  
   670  //===----------------------------------------------------------------------===//
   671  // OperationLegalizer
   672  //===----------------------------------------------------------------------===//
   673  
   674  namespace {
   675  /// A set of rewrite patterns that can be used to legalize a given operation.
   676  using LegalizationPatterns = SmallVector<RewritePattern *, 1>;
   677  
   678  /// This class defines a recursive operation legalizer.
   679  class OperationLegalizer {
   680  public:
   681    using LegalizationAction = ConversionTarget::LegalizationAction;
   682  
   683    OperationLegalizer(ConversionTarget &targetInfo,
   684                       const OwningRewritePatternList &patterns)
   685        : target(targetInfo) {
   686      buildLegalizationGraph(patterns);
   687      computeLegalizationGraphBenefit();
   688    }
   689  
   690    /// Returns if the given operation is known to be illegal on the target.
   691    bool isIllegal(Operation *op) const;
   692  
   693    /// Attempt to legalize the given operation. Returns success if the operation
   694    /// was legalized, failure otherwise.
   695    LogicalResult legalize(Operation *op, ConversionPatternRewriter &rewriter);
   696  
   697  private:
   698    /// Attempt to legalize the given operation by applying the provided pattern.
   699    /// Returns success if the operation was legalized, failure otherwise.
   700    LogicalResult legalizePattern(Operation *op, RewritePattern *pattern,
   701                                  ConversionPatternRewriter &rewriter);
   702  
   703    /// Build an optimistic legalization graph given the provided patterns. This
   704    /// function populates 'legalizerPatterns' with the operations that are not
   705    /// directly legal, but may be transitively legal for the current target given
   706    /// the provided patterns.
   707    void buildLegalizationGraph(const OwningRewritePatternList &patterns);
   708  
   709    /// Compute the benefit of each node within the computed legalization graph.
   710    /// This orders the patterns within 'legalizerPatterns' based upon two
   711    /// criteria:
   712    ///  1) Prefer patterns that have the lowest legalization depth, i.e.
   713    ///     represent the more direct mapping to the target.
   714    ///  2) When comparing patterns with the same legalization depth, prefer the
   715    ///     pattern with the highest PatternBenefit. This allows for users to
   716    ///     prefer specific legalizations over others.
   717    void computeLegalizationGraphBenefit();
   718  
   719    /// The current set of patterns that have been applied.
   720    llvm::SmallPtrSet<RewritePattern *, 8> appliedPatterns;
   721  
   722    /// The set of legality information for operations transitively supported by
   723    /// the target.
   724    DenseMap<OperationName, LegalizationPatterns> legalizerPatterns;
   725  
   726    /// The legalization information provided by the target.
   727    ConversionTarget &target;
   728  };
   729  } // namespace
   730  
   731  bool OperationLegalizer::isIllegal(Operation *op) const {
   732    // Check if the target explicitly marked this operation as illegal.
   733    if (auto action = target.getOpAction(op->getName()))
   734      return action == LegalizationAction::Illegal;
   735    return false;
   736  }
   737  
   738  LogicalResult
   739  OperationLegalizer::legalize(Operation *op,
   740                               ConversionPatternRewriter &rewriter) {
   741    LLVM_DEBUG(llvm::dbgs() << "Legalizing operation : " << op->getName()
   742                            << "\n");
   743  
   744    // Check if this operation is legal on the target.
   745    if (target.isLegal(op)) {
   746      LLVM_DEBUG(llvm::dbgs()
   747                 << "-- Success : Operation marked legal by the target\n");
   748      return success();
   749    }
   750  
   751    // Otherwise, we need to apply a legalization pattern to this operation.
   752    auto it = legalizerPatterns.find(op->getName());
   753    if (it == legalizerPatterns.end()) {
   754      LLVM_DEBUG(llvm::dbgs() << "-- FAIL : no known legalization path.\n");
   755      return failure();
   756    }
   757  
   758    // The patterns are sorted by expected benefit, so try to apply each in-order.
   759    for (auto *pattern : it->second)
   760      if (succeeded(legalizePattern(op, pattern, rewriter)))
   761        return success();
   762  
   763    LLVM_DEBUG(llvm::dbgs() << "-- FAIL : no matched legalization pattern.\n");
   764    return failure();
   765  }
   766  
/// Attempt to legalize 'op' by applying 'pattern'.
///
/// The pattern is rejected if it is already active on the current recursion
/// stack (tracked in 'appliedPatterns'). Otherwise the pattern is applied.
/// Any blocks it moved have their signatures converted, and every operation
/// it created is legalized recursively. If any of these steps fails, the
/// rewriter is reset to the state captured before the pattern ran, and the
/// pattern is removed from 'appliedPatterns'. On success the pattern is also
/// removed from 'appliedPatterns'.
LogicalResult
OperationLegalizer::legalizePattern(Operation *op, RewritePattern *pattern,
                                    ConversionPatternRewriter &rewriter) {
  LLVM_DEBUG({
    llvm::dbgs() << "-* Applying rewrite pattern '" << op->getName() << " -> (";
    interleaveComma(pattern->getGeneratedOps(), llvm::dbgs());
    llvm::dbgs() << ")'.\n";
  });

  // Ensure that we don't cycle by not allowing the same pattern to be
  // applied twice in the same recursion stack.
  // TODO(riverriddle) We could eventually converge, but that requires more
  // complicated analysis.
  if (!appliedPatterns.insert(pattern).second) {
    LLVM_DEBUG(llvm::dbgs() << "-- FAIL: Pattern was already applied.\n");
    return failure();
  }

  // Snapshot the rewriter so that everything done from here on can be undone
  // if legalization through this pattern fails.
  auto &rewriterImpl = rewriter.getImpl();
  RewriterState curState = rewriterImpl.getCurrentState();
  auto cleanupFailure = [&] {
    // Reset the rewriter state and pop this pattern.
    rewriterImpl.resetState(curState);
    appliedPatterns.erase(pattern);
    return failure();
  };

  // Try to rewrite with the given pattern.
  rewriter.setInsertionPoint(op);
  if (!pattern->matchAndRewrite(op, rewriter)) {
    LLVM_DEBUG(llvm::dbgs() << "-- FAIL: Pattern failed to match.\n");
    return cleanupFailure();
  }

  // If the pattern moved any blocks, try to legalize their types. This ensures
  // that the types of the block arguments are legal for the region they were
  // moved into.
  // Only the block actions recorded since the snapshot are visited; the end
  // index 'e' is fixed before the loop starts.
  for (unsigned i = curState.numBlockActions,
                e = rewriterImpl.blockActions.size();
       i != e; ++i) {
    // NOTE: 'action' is not used after convertBlockSignature() returns, so it
    // does not matter if that call appends to 'blockActions'.
    auto &action = rewriterImpl.blockActions[i];
    if (action.kind != ConversionPatternRewriterImpl::BlockActionKind::Move)
      continue;

    // Convert the block signature.
    if (failed(rewriterImpl.convertBlockSignature(action.block))) {
      LLVM_DEBUG(llvm::dbgs()
                 << "-- FAIL: failed to convert types of moved block.\n");
      return cleanupFailure();
    }
  }

  // Recursively legalize each of the new operations.
  // Indices are used rather than iterators because 'createdOps' may grow
  // during the recursive legalize() calls (nested patterns presumably record
  // their own created operations there). The fixed end index 'e' limits this
  // loop to the operations created directly by 'pattern'.
  for (unsigned i = curState.numCreatedOperations,
                e = rewriterImpl.createdOps.size();
       i != e; ++i) {
    if (failed(legalize(rewriterImpl.createdOps[i], rewriter))) {
      LLVM_DEBUG(llvm::dbgs() << "-- FAIL: Generated operation was illegal.\n");
      return cleanupFailure();
    }
  }

  // The pattern is no longer on the recursion stack.
  appliedPatterns.erase(pattern);
  return success();
}
   832  
   833  void OperationLegalizer::buildLegalizationGraph(
   834      const OwningRewritePatternList &patterns) {
   835    // A mapping between an operation and a set of operations that can be used to
   836    // generate it.
   837    DenseMap<OperationName, SmallPtrSet<OperationName, 2>> parentOps;
   838    // A mapping between an operation and any currently invalid patterns it has.
   839    DenseMap<OperationName, SmallPtrSet<RewritePattern *, 2>> invalidPatterns;
   840    // A worklist of patterns to consider for legality.
   841    llvm::SetVector<RewritePattern *> patternWorklist;
   842  
   843    // Build the mapping from operations to the parent ops that may generate them.
   844    for (auto &pattern : patterns) {
   845      auto root = pattern->getRootKind();
   846  
   847      // Skip operations that are always known to be legal.
   848      if (target.getOpAction(root) == LegalizationAction::Legal)
   849        continue;
   850  
   851      // Add this pattern to the invalid set for the root op and record this root
   852      // as a parent for any generated operations.
   853      invalidPatterns[root].insert(pattern.get());
   854      for (auto op : pattern->getGeneratedOps())
   855        parentOps[op].insert(root);
   856  
   857      // Add this pattern to the worklist.
   858      patternWorklist.insert(pattern.get());
   859    }
   860  
   861    while (!patternWorklist.empty()) {
   862      auto *pattern = patternWorklist.pop_back_val();
   863  
   864      // Check to see if any of the generated operations are invalid.
   865      if (llvm::any_of(pattern->getGeneratedOps(), [&](OperationName op) {
   866            auto action = target.getOpAction(op);
   867            return !legalizerPatterns.count(op) &&
   868                   (!action || action == LegalizationAction::Illegal);
   869          }))
   870        continue;
   871  
   872      // Otherwise, if all of the generated operation are valid, this op is now
   873      // legal so add all of the child patterns to the worklist.
   874      legalizerPatterns[pattern->getRootKind()].push_back(pattern);
   875      invalidPatterns[pattern->getRootKind()].erase(pattern);
   876  
   877      // Add any invalid patterns of the parent operations to see if they have now
   878      // become legal.
   879      for (auto op : parentOps[pattern->getRootKind()])
   880        patternWorklist.set_union(invalidPatterns[op]);
   881    }
   882  }
   883  
   884  void OperationLegalizer::computeLegalizationGraphBenefit() {
   885    // The smallest pattern depth, when legalizing an operation.
   886    DenseMap<OperationName, unsigned> minPatternDepth;
   887  
   888    // Compute the minimum legalization depth for a given operation.
   889    std::function<unsigned(OperationName)> computeDepth = [&](OperationName op) {
   890      // Check for existing depth.
   891      auto depthIt = minPatternDepth.find(op);
   892      if (depthIt != minPatternDepth.end())
   893        return depthIt->second;
   894  
   895      // If a mapping for this operation does not exist, then this operation
   896      // is always legal. Return 0 as the depth for a directly legal operation.
   897      auto opPatternsIt = legalizerPatterns.find(op);
   898      if (opPatternsIt == legalizerPatterns.end())
   899        return 0u;
   900  
   901      auto &minDepth = minPatternDepth[op];
   902      if (opPatternsIt->second.empty())
   903        return minDepth;
   904  
   905      // Initialize the depth to the maximum value.
   906      minDepth = std::numeric_limits<unsigned>::max();
   907  
   908      // Compute the depth for each pattern used to legalize this operation.
   909      SmallVector<std::pair<RewritePattern *, unsigned>, 4> patternsByDepth;
   910      patternsByDepth.reserve(opPatternsIt->second.size());
   911      for (RewritePattern *pattern : opPatternsIt->second) {
   912        unsigned depth = 0;
   913        for (auto generatedOp : pattern->getGeneratedOps())
   914          depth = std::max(depth, computeDepth(generatedOp) + 1);
   915        patternsByDepth.emplace_back(pattern, depth);
   916  
   917        // Update the min depth for this operation.
   918        minDepth = std::min(minDepth, depth);
   919      }
   920  
   921      // If the operation only has one legalization pattern, there is no need to
   922      // sort them.
   923      if (patternsByDepth.size() == 1)
   924        return minDepth;
   925  
   926      // Sort the patterns by those likely to be the most beneficial.
   927      llvm::array_pod_sort(
   928          patternsByDepth.begin(), patternsByDepth.end(),
   929          [](const std::pair<RewritePattern *, unsigned> *lhs,
   930             const std::pair<RewritePattern *, unsigned> *rhs) {
   931            // First sort by the smaller pattern legalization depth.
   932            if (lhs->second != rhs->second)
   933              return llvm::array_pod_sort_comparator<unsigned>(&lhs->second,
   934                                                               &rhs->second);
   935  
   936            // Then sort by the larger pattern benefit.
   937            auto lhsBenefit = lhs->first->getBenefit();
   938            auto rhsBenefit = rhs->first->getBenefit();
   939            return llvm::array_pod_sort_comparator<PatternBenefit>(&rhsBenefit,
   940                                                                   &lhsBenefit);
   941          });
   942  
   943      // Update the legalization pattern to use the new sorted list.
   944      opPatternsIt->second.clear();
   945      for (auto &patternIt : patternsByDepth)
   946        opPatternsIt->second.push_back(patternIt.first);
   947  
   948      return minDepth;
   949    };
   950  
   951    // For each operation that is transitively legal, compute a cost for it.
   952    for (auto &opIt : legalizerPatterns)
   953      if (!minPatternDepth.count(opIt.first))
   954        computeDepth(opIt.first);
   955  }
   956  
   957  //===----------------------------------------------------------------------===//
   958  // OperationConverter
   959  //===----------------------------------------------------------------------===//
   960  namespace {
   961  enum OpConversionMode {
   962    // In this mode, the conversion will ignore failed conversions to allow
   963    // illegal operations to co-exist in the IR.
   964    Partial,
   965  
   966    // In this mode, all operations must be legal for the given target for the
   967    // conversion to succeeed.
   968    Full,
   969  
   970    // In this mode, operations are analyzed for legality. No actual rewrites are
   971    // applied to the operations on success.
   972    Analysis,
   973  };
   974  
   975  // This class converts operations to a given conversion target via a set of
   976  // rewrite patterns. The conversion behaves differently depending on the
   977  // conversion mode.
   978  struct OperationConverter {
   979    explicit OperationConverter(ConversionTarget &target,
   980                                const OwningRewritePatternList &patterns,
   981                                OpConversionMode mode,
   982                                DenseSet<Operation *> *legalizableOps = nullptr)
   983        : opLegalizer(target, patterns), mode(mode),
   984          legalizableOps(legalizableOps) {}
   985  
   986    /// Converts the given operations to the conversion target.
   987    LogicalResult convertOperations(ArrayRef<Operation *> ops,
   988                                    TypeConverter *typeConverter);
   989  
   990  private:
   991    /// Converts an operation with the given rewriter.
   992    LogicalResult convert(ConversionPatternRewriter &rewriter, Operation *op);
   993  
   994    /// Recursively collect all of the operations to convert from within 'region'.
   995    LogicalResult computeConversionSet(Region &region,
   996                                       std::vector<Operation *> &toConvert);
   997  
   998    /// Converts the type signatures of the blocks nested within 'op'.
   999    LogicalResult convertBlockSignatures(ConversionPatternRewriter &rewriter,
  1000                                         Operation *op);
  1001  
  1002    /// The legalizer to use when converting operations.
  1003    OperationLegalizer opLegalizer;
  1004  
  1005    /// The conversion mode to use when legalizing operations.
  1006    OpConversionMode mode;
  1007  
  1008    /// A set of pre-existing operations that were found to be legalizable to the
  1009    /// target. This field is only used when mode == OpConversionMode::Analysis.
  1010    DenseSet<Operation *> *legalizableOps;
  1011  };
  1012  } // end anonymous namespace
  1013  
  1014  LogicalResult
  1015  OperationConverter::convertBlockSignatures(ConversionPatternRewriter &rewriter,
  1016                                             Operation *op) {
  1017    // Check to see if type signatures need to be converted.
  1018    if (!rewriter.getImpl().argConverter.typeConverter)
  1019      return success();
  1020  
  1021    for (auto &region : op->getRegions()) {
  1022      for (auto &block : region)
  1023        if (failed(rewriter.getImpl().convertBlockSignature(&block)))
  1024          return failure();
  1025    }
  1026    return success();
  1027  }
  1028  
  1029  LogicalResult
  1030  OperationConverter::computeConversionSet(Region &region,
  1031                                           std::vector<Operation *> &toConvert) {
  1032    if (region.empty())
  1033      return success();
  1034  
  1035    // Traverse starting from the entry block.
  1036    SmallVector<Block *, 16> worklist(1, &region.front());
  1037    DenseSet<Block *> visitedBlocks;
  1038    visitedBlocks.insert(&region.front());
  1039    while (!worklist.empty()) {
  1040      auto *block = worklist.pop_back_val();
  1041  
  1042      // Compute the conversion set of each of the nested operations.
  1043      for (auto &op : *block) {
  1044        toConvert.emplace_back(&op);
  1045        for (auto &region : op.getRegions())
  1046          computeConversionSet(region, toConvert);
  1047      }
  1048  
  1049      // Recurse to children that haven't been visited.
  1050      for (Block *succ : block->getSuccessors())
  1051        if (visitedBlocks.insert(succ).second)
  1052          worklist.push_back(succ);
  1053    }
  1054  
  1055    // Check that all blocks in the region were visited.
  1056    if (llvm::any_of(llvm::drop_begin(region.getBlocks(), 1),
  1057                     [&](Block &block) { return !visitedBlocks.count(&block); }))
  1058      return emitError(region.getLoc(), "unreachable blocks were not converted");
  1059    return success();
  1060  }
  1061  
  1062  LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
  1063                                            Operation *op) {
  1064    // Legalize the given operation.
  1065    if (failed(opLegalizer.legalize(op, rewriter))) {
  1066      // Handle the case of a failed conversion for each of the different modes.
  1067      /// Full conversions expect all operations to be converted.
  1068      if (mode == OpConversionMode::Full)
  1069        return op->emitError()
  1070               << "failed to legalize operation '" << op->getName() << "'";
  1071      /// Partial conversions allow conversions to fail iff the operation was not
  1072      /// explicitly marked as illegal.
  1073      if (mode == OpConversionMode::Partial && opLegalizer.isIllegal(op))
  1074        return op->emitError()
  1075               << "failed to legalize operation '" << op->getName()
  1076               << "' that was explicitly marked illegal";
  1077    } else {
  1078      /// Analysis conversions don't fail if any operations fail to legalize,
  1079      /// they are only interested in the operations that were successfully
  1080      /// legalized.
  1081      if (mode == OpConversionMode::Analysis)
  1082        legalizableOps->insert(op);
  1083  
  1084      // If legalization succeeded, convert the types any of the blocks within
  1085      // this operation.
  1086      if (failed(convertBlockSignatures(rewriter, op)))
  1087        return failure();
  1088    }
  1089    return success();
  1090  }
  1091  
  1092  LogicalResult
  1093  OperationConverter::convertOperations(ArrayRef<Operation *> ops,
  1094                                        TypeConverter *typeConverter) {
  1095    if (ops.empty())
  1096      return success();
  1097  
  1098    /// Compute the set of operations and blocks to convert.
  1099    std::vector<Operation *> toConvert;
  1100    for (auto *op : ops) {
  1101      toConvert.emplace_back(op);
  1102      for (auto &region : op->getRegions())
  1103        if (failed(computeConversionSet(region, toConvert)))
  1104          return failure();
  1105    }
  1106  
  1107    // Convert each operation and discard rewrites on failure.
  1108    ConversionPatternRewriter rewriter(ops.front()->getContext(), typeConverter);
  1109    for (auto *op : toConvert)
  1110      if (failed(convert(rewriter, op)))
  1111        return rewriter.getImpl().discardRewrites(), failure();
  1112  
  1113    // Otherwise, the body conversion succeeded. Apply rewrites if this is not an
  1114    // analysis conversion.
  1115    if (mode == OpConversionMode::Analysis)
  1116      rewriter.getImpl().discardRewrites();
  1117    else
  1118      rewriter.getImpl().applyRewrites();
  1119    return success();
  1120  }
  1121  
  1122  //===----------------------------------------------------------------------===//
  1123  // Type Conversion
  1124  //===----------------------------------------------------------------------===//
  1125  
  1126  /// Remap an input of the original signature with a new set of types. The
  1127  /// new types are appended to the new signature conversion.
  1128  void TypeConverter::SignatureConversion::addInputs(unsigned origInputNo,
  1129                                                     ArrayRef<Type> types) {
  1130    assert(!types.empty() && "expected valid types");
  1131    remapInput(origInputNo, /*newInputNo=*/argTypes.size(), types.size());
  1132    addInputs(types);
  1133  }
  1134  
  1135  /// Append new input types to the signature conversion, this should only be
  1136  /// used if the new types are not intended to remap an existing input.
  1137  void TypeConverter::SignatureConversion::addInputs(ArrayRef<Type> types) {
  1138    assert(!types.empty() &&
  1139           "1->0 type remappings don't need to be added explicitly");
  1140    argTypes.append(types.begin(), types.end());
  1141  }
  1142  
  1143  /// Remap an input of the original signature with a range of types in the
  1144  /// new signature.
  1145  void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
  1146                                                      unsigned newInputNo,
  1147                                                      unsigned newInputCount) {
  1148    assert(!remappedInputs[origInputNo] && "input has already been remapped");
  1149    assert(newInputCount != 0 && "expected valid input count");
  1150    remappedInputs[origInputNo] = InputMapping{newInputNo, newInputCount};
  1151  }
  1152  
  1153  /// This hooks allows for converting a type.
  1154  LogicalResult TypeConverter::convertType(Type t,
  1155                                           SmallVectorImpl<Type> &results) {
  1156    if (auto newT = convertType(t)) {
  1157      results.push_back(newT);
  1158      return success();
  1159    }
  1160    return failure();
  1161  }
  1162  
  1163  /// Convert the given set of types, filling 'results' as necessary. This
  1164  /// returns failure if the conversion of any of the types fails, success
  1165  /// otherwise.
  1166  LogicalResult TypeConverter::convertTypes(ArrayRef<Type> types,
  1167                                            SmallVectorImpl<Type> &results) {
  1168    for (auto type : types)
  1169      if (failed(convertType(type, results)))
  1170        return failure();
  1171    return success();
  1172  }
  1173  
  1174  /// Return true if the given type is legal for this type converter, i.e. the
  1175  /// type converts to itself.
  1176  bool TypeConverter::isLegal(Type type) {
  1177    SmallVector<Type, 1> results;
  1178    return succeeded(convertType(type, results)) && results.size() == 1 &&
  1179           results.front() == type;
  1180  }
  1181  
  1182  /// Return true if the inputs and outputs of the given function type are
  1183  /// legal.
  1184  bool TypeConverter::isSignatureLegal(FunctionType funcType) {
  1185    return llvm::all_of(
  1186        llvm::concat<const Type>(funcType.getInputs(), funcType.getResults()),
  1187        [this](Type type) { return isLegal(type); });
  1188  }
  1189  
  1190  /// This hook allows for converting a specific argument of a signature.
  1191  LogicalResult TypeConverter::convertSignatureArg(unsigned inputNo, Type type,
  1192                                                   SignatureConversion &result) {
  1193    // Try to convert the given input type.
  1194    SmallVector<Type, 1> convertedTypes;
  1195    if (failed(convertType(type, convertedTypes)))
  1196      return failure();
  1197  
  1198    // If this argument is being dropped, there is nothing left to do.
  1199    if (convertedTypes.empty())
  1200      return success();
  1201  
  1202    // Otherwise, add the new inputs.
  1203    result.addInputs(inputNo, convertedTypes);
  1204    return success();
  1205  }
  1206  
  1207  /// Create a default conversion pattern that rewrites the type signature of a
  1208  /// FuncOp.
  1209  namespace {
  1210  struct FuncOpSignatureConversion : public ConversionPattern {
  1211    FuncOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter)
  1212        : ConversionPattern(FuncOp::getOperationName(), 1, ctx),
  1213          converter(converter) {}
  1214  
  1215    /// Hook for derived classes to implement combined matching and rewriting.
  1216    PatternMatchResult
  1217    matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
  1218                    ConversionPatternRewriter &rewriter) const override {
  1219      auto funcOp = cast<FuncOp>(op);
  1220      FunctionType type = funcOp.getType();
  1221  
  1222      // Convert the original function arguments.
  1223      TypeConverter::SignatureConversion result(type.getNumInputs());
  1224      for (unsigned i = 0, e = type.getNumInputs(); i != e; ++i)
  1225        if (failed(converter.convertSignatureArg(i, type.getInput(i), result)))
  1226          return matchFailure();
  1227  
  1228      // Convert the original function results.
  1229      SmallVector<Type, 1> convertedResults;
  1230      if (failed(converter.convertTypes(type.getResults(), convertedResults)))
  1231        return matchFailure();
  1232  
  1233      // Create a new function with an updated signature.
  1234      auto newFuncOp = rewriter.cloneWithoutRegions(funcOp);
  1235      rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
  1236                                  newFuncOp.end());
  1237      newFuncOp.setType(FunctionType::get(result.getConvertedTypes(),
  1238                                          convertedResults, funcOp.getContext()));
  1239  
  1240      // Tell the rewriter to convert the region signature.
  1241      rewriter.applySignatureConversion(&newFuncOp.getBody(), result);
  1242      rewriter.replaceOp(op, llvm::None);
  1243      return matchSuccess();
  1244    }
  1245  
  1246    /// The type converter to use when rewriting the signature.
  1247    TypeConverter &converter;
  1248  };
  1249  } // end anonymous namespace
  1250  
  1251  void mlir::populateFuncOpTypeConversionPattern(
  1252      OwningRewritePatternList &patterns, MLIRContext *ctx,
  1253      TypeConverter &converter) {
  1254    patterns.insert<FuncOpSignatureConversion>(ctx, converter);
  1255  }
  1256  
  1257  /// This function converts the type signature of the given block, by invoking
  1258  /// 'convertSignatureArg' for each argument. This function should return a valid
  1259  /// conversion for the signature on success, None otherwise.
  1260  auto TypeConverter::convertBlockSignature(Block *block)
  1261      -> llvm::Optional<SignatureConversion> {
  1262    SignatureConversion conversion(block->getNumArguments());
  1263    for (unsigned i = 0, e = block->getNumArguments(); i != e; ++i)
  1264      if (failed(convertSignatureArg(i, block->getArgument(i)->getType(),
  1265                                     conversion)))
  1266        return llvm::None;
  1267    return conversion;
  1268  }
  1269  
  1270  //===----------------------------------------------------------------------===//
  1271  // ConversionTarget
  1272  //===----------------------------------------------------------------------===//
  1273  
  1274  /// Register a legality action for the given operation.
  1275  void ConversionTarget::setOpAction(OperationName op,
  1276                                     LegalizationAction action) {
  1277    legalOperations[op] = action;
  1278  }
  1279  
  1280  /// Register a legality action for the given dialects.
  1281  void ConversionTarget::setDialectAction(ArrayRef<StringRef> dialectNames,
  1282                                          LegalizationAction action) {
  1283    for (StringRef dialect : dialectNames)
  1284      legalDialects[dialect] = action;
  1285  }
  1286  
  1287  /// Get the legality action for the given operation.
  1288  auto ConversionTarget::getOpAction(OperationName op) const
  1289      -> llvm::Optional<LegalizationAction> {
  1290    // Check for an action for this specific operation.
  1291    auto it = legalOperations.find(op);
  1292    if (it != legalOperations.end())
  1293      return it->second;
  1294    // Otherwise, default to checking for an action on the parent dialect.
  1295    auto dialectIt = legalDialects.find(op.getDialect());
  1296    if (dialectIt != legalDialects.end())
  1297      return dialectIt->second;
  1298    return llvm::None;
  1299  }
  1300  
  1301  /// Return if the given operation instance is legal on this target.
  1302  bool ConversionTarget::isLegal(Operation *op) const {
  1303    auto action = getOpAction(op->getName());
  1304  
  1305    // Handle dynamic legality.
  1306    if (action == LegalizationAction::Dynamic) {
  1307      // Check for callbacks on the operation or dialect.
  1308      auto opFn = opLegalityFns.find(op->getName());
  1309      if (opFn != opLegalityFns.end())
  1310        return opFn->second(op);
  1311      auto dialectFn = dialectLegalityFns.find(op->getName().getDialect());
  1312      if (dialectFn != dialectLegalityFns.end())
  1313        return dialectFn->second(op);
  1314  
  1315      // Otherwise, invoke the hook on the derived instance.
  1316      return isDynamicallyLegal(op);
  1317    }
  1318  
  1319    // Otherwise, the operation is only legal if it was marked 'Legal'.
  1320    return action == LegalizationAction::Legal;
  1321  }
  1322  
  1323  /// Set the dynamic legality callback for the given operation.
  1324  void ConversionTarget::setLegalityCallback(
  1325      OperationName name, const DynamicLegalityCallbackFn &callback) {
  1326    assert(callback && "expected valid legality callback");
  1327    opLegalityFns[name] = callback;
  1328  }
  1329  
  1330  /// Set the dynamic legality callback for the given dialects.
  1331  void ConversionTarget::setLegalityCallback(
  1332      ArrayRef<StringRef> dialects, const DynamicLegalityCallbackFn &callback) {
  1333    assert(callback && "expected valid legality callback");
  1334    for (StringRef dialect : dialects)
  1335      dialectLegalityFns[dialect] = callback;
  1336  }
  1337  
  1338  //===----------------------------------------------------------------------===//
  1339  // Op Conversion Entry Points
  1340  //===----------------------------------------------------------------------===//
  1341  
  1342  /// Apply a partial conversion on the given operations, and all nested
  1343  /// operations. This method converts as many operations to the target as
  1344  /// possible, ignoring operations that failed to legalize.
  1345  LogicalResult mlir::applyPartialConversion(
  1346      ArrayRef<Operation *> ops, ConversionTarget &target,
  1347      const OwningRewritePatternList &patterns, TypeConverter *converter) {
  1348    OperationConverter opConverter(target, patterns, OpConversionMode::Partial);
  1349    return opConverter.convertOperations(ops, converter);
  1350  }
  1351  LogicalResult
  1352  mlir::applyPartialConversion(Operation *op, ConversionTarget &target,
  1353                               const OwningRewritePatternList &patterns,
  1354                               TypeConverter *converter) {
  1355    return applyPartialConversion(llvm::makeArrayRef(op), target, patterns,
  1356                                  converter);
  1357  }
  1358  
  1359  /// Apply a complete conversion on the given operations, and all nested
  1360  /// operations. This method will return failure if the conversion of any
  1361  /// operation fails.
  1362  LogicalResult
  1363  mlir::applyFullConversion(ArrayRef<Operation *> ops, ConversionTarget &target,
  1364                            const OwningRewritePatternList &patterns,
  1365                            TypeConverter *converter) {
  1366    OperationConverter opConverter(target, patterns, OpConversionMode::Full);
  1367    return opConverter.convertOperations(ops, converter);
  1368  }
  1369  LogicalResult
  1370  mlir::applyFullConversion(Operation *op, ConversionTarget &target,
  1371                            const OwningRewritePatternList &patterns,
  1372                            TypeConverter *converter) {
  1373    return applyFullConversion(llvm::makeArrayRef(op), target, patterns,
  1374                               converter);
  1375  }
  1376  
  1377  /// Apply an analysis conversion on the given operations, and all nested
  1378  /// operations. This method analyzes which operations would be successfully
  1379  /// converted to the target if a conversion was applied. All operations that
  1380  /// were found to be legalizable to the given 'target' are placed within the
  1381  /// provided 'convertedOps' set; note that no actual rewrites are applied to the
  1382  /// operations on success and only pre-existing operations are added to the set.
  1383  LogicalResult mlir::applyAnalysisConversion(
  1384      ArrayRef<Operation *> ops, ConversionTarget &target,
  1385      const OwningRewritePatternList &patterns,
  1386      DenseSet<Operation *> &convertedOps, TypeConverter *converter) {
  1387    OperationConverter opConverter(target, patterns, OpConversionMode::Analysis,
  1388                                   &convertedOps);
  1389    return opConverter.convertOperations(ops, converter);
  1390  }
  1391  LogicalResult
  1392  mlir::applyAnalysisConversion(Operation *op, ConversionTarget &target,
  1393                                const OwningRewritePatternList &patterns,
  1394                                DenseSet<Operation *> &convertedOps,
  1395                                TypeConverter *converter) {
  1396    return applyAnalysisConversion(llvm::makeArrayRef(op), target, patterns,
  1397                                   convertedOps, converter);
  1398  }