github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Transforms/DialectConversion.cpp (about) 1 //===- DialectConversion.cpp - MLIR dialect conversion generic pass -------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 18 #include "mlir/Transforms/DialectConversion.h" 19 #include "mlir/IR/Block.h" 20 #include "mlir/IR/BlockAndValueMapping.h" 21 #include "mlir/IR/Builders.h" 22 #include "mlir/IR/Function.h" 23 #include "mlir/IR/Module.h" 24 #include "mlir/Transforms/Utils.h" 25 #include "llvm/ADT/SetVector.h" 26 #include "llvm/ADT/SmallPtrSet.h" 27 #include "llvm/Support/Debug.h" 28 #include "llvm/Support/raw_ostream.h" 29 30 using namespace mlir; 31 using namespace mlir::detail; 32 33 #define DEBUG_TYPE "dialect-conversion" 34 35 //===----------------------------------------------------------------------===// 36 // ArgConverter 37 //===----------------------------------------------------------------------===// 38 namespace { 39 /// This class provides a simple interface for converting the types of block 40 /// arguments. This is done by inserting fake cast operations that map from the 41 /// illegal type to the original type to allow for undoing pending rewrites in 42 /// the case of failure. 43 struct ArgConverter { 44 ArgConverter(TypeConverter *typeConverter, PatternRewriter &rewriter) 45 : castOpName(kCastName, rewriter.getContext()), 46 loc(rewriter.getUnknownLoc()), typeConverter(typeConverter), 47 rewriter(rewriter) {} 48 49 /// Erase any rewrites registered for arguments to blocks within the given 50 /// region. This function is called when the given region is to be destroyed. 51 void cancelPendingRewrites(Block *block); 52 53 /// Cleanup and undo any generated conversions for the arguments of block. 54 /// This method differs from 'cancelPendingRewrites' in that it returns the 55 /// block signature to its original state. 56 void discardPendingRewrites(Block *block); 57 58 /// Replace usages of the cast operations with the argument directly. 59 void applyRewrites(); 60 61 /// Return if the signature of the given block has already been converted. 62 bool hasBeenConverted(Block *block) const { return argMapping.count(block); } 63 64 /// Attempt to convert the signature of the given block. 65 LogicalResult convertSignature(Block *block, BlockAndValueMapping &mapping); 66 67 /// Apply the given signature conversion on the given block. 68 void applySignatureConversion( 69 Block *block, TypeConverter::SignatureConversion &signatureConversion, 70 BlockAndValueMapping &mapping); 71 72 /// Convert the given block argument given the provided set of new argument 73 /// values that are to replace it. This function returns the operation used 74 /// to perform the conversion. 75 Operation *convertArgument(BlockArgument *origArg, 76 ArrayRef<Value *> newValues, 77 BlockAndValueMapping &mapping); 78 79 /// A utility function used to create a conversion cast operation with the 80 /// given input and result types. 81 Operation *createCast(ArrayRef<Value *> inputs, Type outputType); 82 83 /// This is an operation name for a fake operation that is inserted during the 84 /// conversion process. Operations of this type are guaranteed to never escape 85 /// the converter. 86 static constexpr StringLiteral kCastName = "__mlir_conversion.cast"; 87 OperationName castOpName; 88 89 /// This is a collection of cast operations that were generated during the 90 /// conversion process when converting the types of block arguments. 91 llvm::MapVector<Block *, SmallVector<Operation *, 4>> argMapping; 92 93 /// An instance of the unknown location that is used when generating 94 /// producers. 95 Location loc; 96 97 /// The type converter to use when changing types. 98 TypeConverter *typeConverter; 99 100 /// The pattern rewriter to use when materializing conversions. 101 PatternRewriter &rewriter; 102 }; 103 } // end anonymous namespace 104 105 constexpr StringLiteral ArgConverter::kCastName; 106 107 /// Erase any rewrites registered for arguments to the given block. 108 void ArgConverter::cancelPendingRewrites(Block *block) { 109 auto it = argMapping.find(block); 110 if (it == argMapping.end()) 111 return; 112 for (auto *op : it->second) { 113 op->dropAllDefinedValueUses(); 114 op->erase(); 115 } 116 argMapping.erase(it); 117 } 118 119 /// Cleanup and undo any generated conversions for the arguments of block. 120 /// This method differs from 'cancelPendingRewrites' in that it returns the 121 /// block signature to its original state. 122 void ArgConverter::discardPendingRewrites(Block *block) { 123 auto it = argMapping.find(block); 124 if (it == argMapping.end()) 125 return; 126 127 // Erase all of the new arguments. 128 for (int i = block->getNumArguments() - 1; i >= 0; --i) { 129 block->getArgument(i)->dropAllUses(); 130 block->eraseArgument(i, /*updatePredTerms=*/false); 131 } 132 133 // Re-instate the old arguments. 134 auto &mapping = it->second; 135 for (unsigned i = 0, e = mapping.size(); i != e; ++i) { 136 auto *op = mapping[i]; 137 auto *arg = block->addArgument(op->getResult(0)->getType()); 138 op->getResult(0)->replaceAllUsesWith(arg); 139 140 // If this operation is within a block, it will be cleaned up automatically. 141 if (!op->getBlock()) 142 op->erase(); 143 } 144 argMapping.erase(it); 145 } 146 147 /// Replace usages of the cast operations with the argument directly. 148 void ArgConverter::applyRewrites() { 149 Block *block; 150 ArrayRef<Operation *> argOps; 151 for (auto &mapping : argMapping) { 152 std::tie(block, argOps) = mapping; 153 154 // Process the remapping for each of the original arguments. 155 for (unsigned i = 0, e = argOps.size(); i != e; ++i) { 156 auto *op = argOps[i]; 157 158 // Handle the case of a 1->N value mapping. 159 if (op->getNumOperands() > 1) { 160 // If all of the uses were removed, we can drop this op. Otherwise, 161 // keep the operation alive and let the user handle any remaining 162 // usages. 163 if (op->use_empty()) 164 op->erase(); 165 continue; 166 } 167 168 // If mapping is 1-1, replace the remaining uses and drop the cast 169 // operation. 170 // FIXME(riverriddle) This should check that the result type and operand 171 // type are the same, otherwise it should force a conversion to be 172 // materialized. This works around a current limitation with regards to 173 // region entry argument type conversion. 174 if (op->getNumOperands() == 1) { 175 op->getResult(0)->replaceAllUsesWith(op->getOperand(0)); 176 op->destroy(); 177 continue; 178 } 179 180 // Otherwise, if there are any dangling uses then replace the fake 181 // conversion operation with one generated by the type converter. This 182 // is necessary as the cast must persist in the IR after conversion. 183 auto *opResult = op->getResult(0); 184 if (!opResult->use_empty()) { 185 rewriter.setInsertionPointToStart(block); 186 SmallVector<Value *, 1> operands(op->getOperands()); 187 auto *newOp = typeConverter->materializeConversion( 188 rewriter, opResult->getType(), operands, op->getLoc()); 189 opResult->replaceAllUsesWith(newOp->getResult(0)); 190 } 191 op->destroy(); 192 } 193 } 194 } 195 196 /// Converts the signature of the given entry block. 197 LogicalResult ArgConverter::convertSignature(Block *block, 198 BlockAndValueMapping &mapping) { 199 if (auto conversion = typeConverter->convertBlockSignature(block)) 200 return applySignatureConversion(block, *conversion, mapping), success(); 201 return failure(); 202 } 203 204 /// Apply the given signature conversion on the given block. 205 void ArgConverter::applySignatureConversion( 206 Block *block, TypeConverter::SignatureConversion &signatureConversion, 207 BlockAndValueMapping &mapping) { 208 unsigned origArgCount = block->getNumArguments(); 209 auto convertedTypes = signatureConversion.getConvertedTypes(); 210 if (origArgCount == 0 && convertedTypes.empty()) 211 return; 212 213 SmallVector<Value *, 4> newArgRange(block->addArguments(convertedTypes)); 214 ArrayRef<Value *> newArgRef(newArgRange); 215 216 // Remap each of the original arguments as determined by the signature 217 // conversion. 218 auto &newArgMapping = argMapping[block]; 219 rewriter.setInsertionPointToStart(block); 220 for (unsigned i = 0; i != origArgCount; ++i) { 221 ArrayRef<Value *> remappedValues; 222 if (auto inputMap = signatureConversion.getInputMapping(i)) 223 remappedValues = newArgRef.slice(inputMap->inputNo, inputMap->size); 224 225 BlockArgument *arg = block->getArgument(i); 226 newArgMapping.push_back(convertArgument(arg, remappedValues, mapping)); 227 } 228 229 // Erase all of the original arguments. 230 for (unsigned i = 0; i != origArgCount; ++i) 231 block->eraseArgument(0, /*updatePredTerms=*/false); 232 } 233 234 /// Convert the given block argument given the provided set of new argument 235 /// values that are to replace it. This function returns the operation used 236 /// to perform the conversion. 237 Operation *ArgConverter::convertArgument(BlockArgument *origArg, 238 ArrayRef<Value *> newValues, 239 BlockAndValueMapping &mapping) { 240 // Handle the cases of 1->0 or 1->1 mappings. 241 if (newValues.size() < 2) { 242 // Create a temporary producer for the argument during the conversion 243 // process. 244 auto *cast = createCast(newValues, origArg->getType()); 245 origArg->replaceAllUsesWith(cast->getResult(0)); 246 247 // Insert a mapping between this argument and the one that is replacing 248 // it. 249 if (!newValues.empty()) 250 mapping.map(cast->getResult(0), newValues[0]); 251 return cast; 252 } 253 254 // Otherwise, this is a 1->N mapping. Call into the provided type converter 255 // to pack the new values. 256 auto *cast = typeConverter->materializeConversion( 257 rewriter, origArg->getType(), newValues, loc); 258 assert(cast->getNumResults() == 1 && 259 cast->getNumOperands() == newValues.size()); 260 origArg->replaceAllUsesWith(cast->getResult(0)); 261 return cast; 262 } 263 264 /// A utility function used to create a conversion cast operation with the 265 /// given input and result types. 266 Operation *ArgConverter::createCast(ArrayRef<Value *> inputs, Type outputType) { 267 return Operation::create(loc, castOpName, inputs, outputType, llvm::None, 268 llvm::None, 0, false); 269 } 270 271 //===----------------------------------------------------------------------===// 272 // ConversionPatternRewriterImpl 273 //===----------------------------------------------------------------------===// 274 namespace { 275 /// This class contains a snapshot of the current conversion rewriter state. 276 /// This is useful when saving and undoing a set of rewrites. 277 struct RewriterState { 278 RewriterState(unsigned numCreatedOperations, unsigned numReplacements, 279 unsigned numBlockActions) 280 : numCreatedOperations(numCreatedOperations), 281 numReplacements(numReplacements), numBlockActions(numBlockActions) {} 282 283 /// The current number of created operations. 284 unsigned numCreatedOperations; 285 286 /// The current number of replacements queued. 287 unsigned numReplacements; 288 289 /// The current number of block actions performed. 290 unsigned numBlockActions; 291 }; 292 } // end anonymous namespace 293 294 namespace mlir { 295 namespace detail { 296 struct ConversionPatternRewriterImpl { 297 /// This class represents one requested operation replacement via 'replaceOp'. 298 struct OpReplacement { 299 OpReplacement() = default; 300 OpReplacement(Operation *op, ArrayRef<Value *> newValues) 301 : op(op), newValues(newValues.begin(), newValues.end()) {} 302 303 Operation *op; 304 SmallVector<Value *, 2> newValues; 305 }; 306 307 /// The kind of the block action performed during the rewrite. Actions can be 308 /// undone if the conversion fails. 309 enum class BlockActionKind { Split, Move, TypeConversion }; 310 311 /// Original position of the given block in its parent region. We cannot use 312 /// a region iterator because it could have been invalidated by other region 313 /// operations since the position was stored. 314 struct BlockPosition { 315 Region *region; 316 Region::iterator::difference_type position; 317 }; 318 319 /// The storage class for an undoable block action (one of BlockActionKind), 320 /// contains the information necessary to undo this action. 321 struct BlockAction { 322 static BlockAction getSplit(Block *block, Block *originalBlock) { 323 BlockAction action{BlockActionKind::Split, block, {}}; 324 action.originalBlock = originalBlock; 325 return action; 326 } 327 static BlockAction getMove(Block *block, BlockPosition originalPos) { 328 return {BlockActionKind::Move, block, {originalPos}}; 329 } 330 static BlockAction getTypeConversion(Block *block) { 331 return BlockAction{BlockActionKind::TypeConversion, block, {}}; 332 } 333 334 // The action kind. 335 BlockActionKind kind; 336 337 // A pointer to the block that was created by the action. 338 Block *block; 339 340 union { 341 // In use if kind == BlockActionKind::Move and contains a pointer to the 342 // region that originally contained the block as well as the position of 343 // the block in that region. 344 BlockPosition originalPosition; 345 // In use if kind == BlockActionKind::Split and contains a pointer to the 346 // block that was split into two parts. 347 Block *originalBlock; 348 }; 349 }; 350 351 ConversionPatternRewriterImpl(PatternRewriter &rewriter, 352 TypeConverter *converter) 353 : argConverter(converter, rewriter) {} 354 355 /// Return the current state of the rewriter. 356 RewriterState getCurrentState(); 357 358 /// Reset the state of the rewriter to a previously saved point. 359 void resetState(RewriterState state); 360 361 /// Undo the block actions (motions, splits) one by one in reverse order until 362 /// "numActionsToKeep" actions remains. 363 void undoBlockActions(unsigned numActionsToKeep = 0); 364 365 /// Cleanup and destroy any generated rewrite operations. This method is 366 /// invoked when the conversion process fails. 367 void discardRewrites(); 368 369 /// Apply all requested operation rewrites. This method is invoked when the 370 /// conversion process succeeds. 371 void applyRewrites(); 372 373 /// Convert the signature of the given block. 374 LogicalResult convertBlockSignature(Block *block); 375 376 /// Apply a signature conversion on the given region. 377 void applySignatureConversion(Region *region, 378 TypeConverter::SignatureConversion &conversion); 379 380 /// PatternRewriter hook for replacing the results of an operation. 381 void replaceOp(Operation *op, ArrayRef<Value *> newValues, 382 ArrayRef<Value *> valuesToRemoveIfDead); 383 384 /// Notifies that a block was split. 385 void notifySplitBlock(Block *block, Block *continuation); 386 387 /// Notifies that the blocks of a region are about to be moved. 388 void notifyRegionIsBeingInlinedBefore(Region ®ion, Region &parent, 389 Region::iterator before); 390 391 /// Remap the given operands to those with potentially different types. 392 void remapValues(Operation::operand_range operands, 393 SmallVectorImpl<Value *> &remapped); 394 395 // Mapping between replaced values that differ in type. This happens when 396 // replacing a value with one of a different type. 397 BlockAndValueMapping mapping; 398 399 /// Utility used to convert block arguments. 400 ArgConverter argConverter; 401 402 /// Ordered vector of all of the newly created operations during conversion. 403 SmallVector<Operation *, 4> createdOps; 404 405 /// Ordered vector of any requested operation replacements. 406 SmallVector<OpReplacement, 4> replacements; 407 408 /// Ordered list of block operations (creations, splits, motions). 409 SmallVector<BlockAction, 4> blockActions; 410 }; 411 } // end namespace detail 412 } // end namespace mlir 413 414 RewriterState ConversionPatternRewriterImpl::getCurrentState() { 415 return RewriterState(createdOps.size(), replacements.size(), 416 blockActions.size()); 417 } 418 419 void ConversionPatternRewriterImpl::resetState(RewriterState state) { 420 // Undo any block actions. 421 undoBlockActions(state.numBlockActions); 422 423 // Reset any replaced operations and undo any saved mappings. 424 for (auto &repl : llvm::drop_begin(replacements, state.numReplacements)) 425 for (auto *result : repl.op->getResults()) 426 mapping.erase(result); 427 replacements.resize(state.numReplacements); 428 429 // Pop all of the newly created operations. 430 while (createdOps.size() != state.numCreatedOperations) 431 createdOps.pop_back_val()->erase(); 432 } 433 434 void ConversionPatternRewriterImpl::undoBlockActions( 435 unsigned numActionsToKeep) { 436 for (auto &action : 437 llvm::reverse(llvm::drop_begin(blockActions, numActionsToKeep))) { 438 switch (action.kind) { 439 // Merge back the block that was split out. 440 case BlockActionKind::Split: { 441 action.originalBlock->getOperations().splice( 442 action.originalBlock->end(), action.block->getOperations()); 443 action.block->erase(); 444 break; 445 } 446 // Move the block back to its original position. 447 case BlockActionKind::Move: { 448 Region *originalRegion = action.originalPosition.region; 449 originalRegion->getBlocks().splice( 450 std::next(originalRegion->begin(), action.originalPosition.position), 451 action.block->getParent()->getBlocks(), action.block); 452 break; 453 } 454 // Undo the type conversion. 455 case BlockActionKind::TypeConversion: { 456 argConverter.discardPendingRewrites(action.block); 457 break; 458 } 459 } 460 } 461 blockActions.resize(numActionsToKeep); 462 } 463 464 void ConversionPatternRewriterImpl::discardRewrites() { 465 undoBlockActions(); 466 467 // Remove any newly created ops. 468 for (auto *op : createdOps) { 469 op->dropAllDefinedValueUses(); 470 op->erase(); 471 } 472 } 473 474 void ConversionPatternRewriterImpl::applyRewrites() { 475 // Apply all of the rewrites replacements requested during conversion. 476 for (auto &repl : replacements) { 477 for (unsigned i = 0, e = repl.newValues.size(); i != e; ++i) 478 repl.op->getResult(i)->replaceAllUsesWith( 479 mapping.lookupOrDefault(repl.newValues[i])); 480 481 // If this operation defines any regions, drop any pending argument 482 // rewrites. 483 if (argConverter.typeConverter && repl.op->getNumRegions()) { 484 for (auto ®ion : repl.op->getRegions()) 485 for (auto &block : region) 486 argConverter.cancelPendingRewrites(&block); 487 } 488 } 489 490 // In a second pass, erase all of the replaced operations in reverse. This 491 // allows processing nested operations before their parent region is 492 // destroyed. 493 for (auto &repl : llvm::reverse(replacements)) 494 repl.op->erase(); 495 496 argConverter.applyRewrites(); 497 } 498 499 LogicalResult 500 ConversionPatternRewriterImpl::convertBlockSignature(Block *block) { 501 // Check to see if this block should not be converted: 502 // * There is no type converter. 503 // * The block has already been converted. 504 // * This is an entry block, these are converted explicitly via patterns. 505 if (!argConverter.typeConverter || argConverter.hasBeenConverted(block) || 506 block->isEntryBlock()) 507 return success(); 508 509 // Otherwise, try to convert the block signature. 510 if (failed(argConverter.convertSignature(block, mapping))) 511 return failure(); 512 blockActions.push_back(BlockAction::getTypeConversion(block)); 513 return success(); 514 } 515 516 void ConversionPatternRewriterImpl::applySignatureConversion( 517 Region *region, TypeConverter::SignatureConversion &conversion) { 518 if (!region->empty()) { 519 argConverter.applySignatureConversion(®ion->front(), conversion, 520 mapping); 521 blockActions.push_back(BlockAction::getTypeConversion(®ion->front())); 522 } 523 } 524 525 void ConversionPatternRewriterImpl::replaceOp( 526 Operation *op, ArrayRef<Value *> newValues, 527 ArrayRef<Value *> valuesToRemoveIfDead) { 528 assert(newValues.size() == op->getNumResults()); 529 530 // Create mappings for each of the new result values. 531 for (unsigned i = 0, e = newValues.size(); i < e; ++i) { 532 assert((newValues[i] || op->getResult(i)->use_empty()) && 533 "result value has remaining uses that must be replaced"); 534 if (newValues[i]) 535 mapping.map(op->getResult(i), newValues[i]); 536 } 537 538 // Record the requested operation replacement. 539 replacements.emplace_back(op, newValues); 540 } 541 542 void ConversionPatternRewriterImpl::notifySplitBlock(Block *block, 543 Block *continuation) { 544 blockActions.push_back(BlockAction::getSplit(continuation, block)); 545 } 546 547 void ConversionPatternRewriterImpl::notifyRegionIsBeingInlinedBefore( 548 Region ®ion, Region &parent, Region::iterator before) { 549 for (auto &pair : llvm::enumerate(region)) { 550 Block &block = pair.value(); 551 unsigned position = pair.index(); 552 blockActions.push_back(BlockAction::getMove(&block, {®ion, position})); 553 } 554 } 555 556 void ConversionPatternRewriterImpl::remapValues( 557 Operation::operand_range operands, SmallVectorImpl<Value *> &remapped) { 558 remapped.reserve(llvm::size(operands)); 559 for (Value *operand : operands) 560 remapped.push_back(mapping.lookupOrDefault(operand)); 561 } 562 563 //===----------------------------------------------------------------------===// 564 // ConversionPatternRewriter 565 //===----------------------------------------------------------------------===// 566 567 ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx, 568 TypeConverter *converter) 569 : PatternRewriter(ctx), 570 impl(new detail::ConversionPatternRewriterImpl(*this, converter)) {} 571 ConversionPatternRewriter::~ConversionPatternRewriter() {} 572 573 /// PatternRewriter hook for replacing the results of an operation. 574 void ConversionPatternRewriter::replaceOp( 575 Operation *op, ArrayRef<Value *> newValues, 576 ArrayRef<Value *> valuesToRemoveIfDead) { 577 impl->replaceOp(op, newValues, valuesToRemoveIfDead); 578 } 579 580 /// Apply a signature conversion to the entry block of the given region. 581 void ConversionPatternRewriter::applySignatureConversion( 582 Region *region, TypeConverter::SignatureConversion &conversion) { 583 impl->applySignatureConversion(region, conversion); 584 } 585 586 /// Clone the given operation without cloning its regions. 587 Operation *ConversionPatternRewriter::cloneWithoutRegions(Operation *op) { 588 Operation *newOp = OpBuilder::cloneWithoutRegions(*op); 589 impl->createdOps.push_back(newOp); 590 return newOp; 591 } 592 593 /// PatternRewriter hook for splitting a block into two parts. 594 Block *ConversionPatternRewriter::splitBlock(Block *block, 595 Block::iterator before) { 596 auto *continuation = PatternRewriter::splitBlock(block, before); 597 impl->notifySplitBlock(block, continuation); 598 return continuation; 599 } 600 601 /// PatternRewriter hook for moving blocks out of a region. 602 void ConversionPatternRewriter::inlineRegionBefore(Region ®ion, 603 Region &parent, 604 Region::iterator before) { 605 impl->notifyRegionIsBeingInlinedBefore(region, parent, before); 606 PatternRewriter::inlineRegionBefore(region, parent, before); 607 } 608 609 /// PatternRewriter hook for creating a new operation. 610 Operation * 611 ConversionPatternRewriter::createOperation(const OperationState &state) { 612 auto *result = OpBuilder::createOperation(state); 613 impl->createdOps.push_back(result); 614 return result; 615 } 616 617 /// PatternRewriter hook for updating the root operation in-place. 618 void ConversionPatternRewriter::notifyRootUpdated(Operation *op) { 619 // The rewriter caches changes to the IR to allow for operating in-place and 620 // backtracking. The rewriter is currently not capable of backtracking 621 // in-place modifications. 622 llvm_unreachable("in-place operation updates are not supported"); 623 } 624 625 /// Return a reference to the internal implementation. 626 detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() { 627 return *impl; 628 } 629 630 //===----------------------------------------------------------------------===// 631 // Conversion Patterns 632 //===----------------------------------------------------------------------===// 633 634 /// Attempt to match and rewrite the IR root at the specified operation. 635 PatternMatchResult 636 ConversionPattern::matchAndRewrite(Operation *op, 637 PatternRewriter &rewriter) const { 638 SmallVector<Value *, 4> operands; 639 auto &dialectRewriter = static_cast<ConversionPatternRewriter &>(rewriter); 640 dialectRewriter.getImpl().remapValues(op->getOperands(), operands); 641 642 // If this operation has no successors, invoke the rewrite directly. 643 if (op->getNumSuccessors() == 0) 644 return matchAndRewrite(op, operands, dialectRewriter); 645 646 // Otherwise, we need to remap the successors. 647 SmallVector<Block *, 2> destinations; 648 destinations.reserve(op->getNumSuccessors()); 649 650 SmallVector<ArrayRef<Value *>, 2> operandsPerDestination; 651 unsigned firstSuccessorOperand = op->getSuccessorOperandIndex(0); 652 for (unsigned i = 0, seen = 0, e = op->getNumSuccessors(); i < e; ++i) { 653 destinations.push_back(op->getSuccessor(i)); 654 655 // Lookup the successors operands. 656 unsigned n = op->getNumSuccessorOperands(i); 657 operandsPerDestination.push_back( 658 llvm::makeArrayRef(operands.data() + firstSuccessorOperand + seen, n)); 659 seen += n; 660 } 661 662 // Rewrite the operation. 663 return matchAndRewrite( 664 op, 665 llvm::makeArrayRef(operands.data(), 666 operands.data() + firstSuccessorOperand), 667 destinations, operandsPerDestination, dialectRewriter); 668 } 669 670 //===----------------------------------------------------------------------===// 671 // OperationLegalizer 672 //===----------------------------------------------------------------------===// 673 674 namespace { 675 /// A set of rewrite patterns that can be used to legalize a given operation. 676 using LegalizationPatterns = SmallVector<RewritePattern *, 1>; 677 678 /// This class defines a recursive operation legalizer. 679 class OperationLegalizer { 680 public: 681 using LegalizationAction = ConversionTarget::LegalizationAction; 682 683 OperationLegalizer(ConversionTarget &targetInfo, 684 const OwningRewritePatternList &patterns) 685 : target(targetInfo) { 686 buildLegalizationGraph(patterns); 687 computeLegalizationGraphBenefit(); 688 } 689 690 /// Returns if the given operation is known to be illegal on the target. 691 bool isIllegal(Operation *op) const; 692 693 /// Attempt to legalize the given operation. Returns success if the operation 694 /// was legalized, failure otherwise. 695 LogicalResult legalize(Operation *op, ConversionPatternRewriter &rewriter); 696 697 private: 698 /// Attempt to legalize the given operation by applying the provided pattern. 699 /// Returns success if the operation was legalized, failure otherwise. 700 LogicalResult legalizePattern(Operation *op, RewritePattern *pattern, 701 ConversionPatternRewriter &rewriter); 702 703 /// Build an optimistic legalization graph given the provided patterns. This 704 /// function populates 'legalizerPatterns' with the operations that are not 705 /// directly legal, but may be transitively legal for the current target given 706 /// the provided patterns. 707 void buildLegalizationGraph(const OwningRewritePatternList &patterns); 708 709 /// Compute the benefit of each node within the computed legalization graph. 710 /// This orders the patterns within 'legalizerPatterns' based upon two 711 /// criteria: 712 /// 1) Prefer patterns that have the lowest legalization depth, i.e. 713 /// represent the more direct mapping to the target. 714 /// 2) When comparing patterns with the same legalization depth, prefer the 715 /// pattern with the highest PatternBenefit. This allows for users to 716 /// prefer specific legalizations over others. 717 void computeLegalizationGraphBenefit(); 718 719 /// The current set of patterns that have been applied. 720 llvm::SmallPtrSet<RewritePattern *, 8> appliedPatterns; 721 722 /// The set of legality information for operations transitively supported by 723 /// the target. 724 DenseMap<OperationName, LegalizationPatterns> legalizerPatterns; 725 726 /// The legalization information provided by the target. 727 ConversionTarget ⌖ 728 }; 729 } // namespace 730 731 bool OperationLegalizer::isIllegal(Operation *op) const { 732 // Check if the target explicitly marked this operation as illegal. 733 if (auto action = target.getOpAction(op->getName())) 734 return action == LegalizationAction::Illegal; 735 return false; 736 } 737 738 LogicalResult 739 OperationLegalizer::legalize(Operation *op, 740 ConversionPatternRewriter &rewriter) { 741 LLVM_DEBUG(llvm::dbgs() << "Legalizing operation : " << op->getName() 742 << "\n"); 743 744 // Check if this operation is legal on the target. 745 if (target.isLegal(op)) { 746 LLVM_DEBUG(llvm::dbgs() 747 << "-- Success : Operation marked legal by the target\n"); 748 return success(); 749 } 750 751 // Otherwise, we need to apply a legalization pattern to this operation. 752 auto it = legalizerPatterns.find(op->getName()); 753 if (it == legalizerPatterns.end()) { 754 LLVM_DEBUG(llvm::dbgs() << "-- FAIL : no known legalization path.\n"); 755 return failure(); 756 } 757 758 // The patterns are sorted by expected benefit, so try to apply each in-order. 759 for (auto *pattern : it->second) 760 if (succeeded(legalizePattern(op, pattern, rewriter))) 761 return success(); 762 763 LLVM_DEBUG(llvm::dbgs() << "-- FAIL : no matched legalization pattern.\n"); 764 return failure(); 765 } 766 767 LogicalResult 768 OperationLegalizer::legalizePattern(Operation *op, RewritePattern *pattern, 769 ConversionPatternRewriter &rewriter) { 770 LLVM_DEBUG({ 771 llvm::dbgs() << "-* Applying rewrite pattern '" << op->getName() << " -> ("; 772 interleaveComma(pattern->getGeneratedOps(), llvm::dbgs()); 773 llvm::dbgs() << ")'.\n"; 774 }); 775 776 // Ensure that we don't cycle by not allowing the same pattern to be 777 // applied twice in the same recursion stack. 778 // TODO(riverriddle) We could eventually converge, but that requires more 779 // complicated analysis. 780 if (!appliedPatterns.insert(pattern).second) { 781 LLVM_DEBUG(llvm::dbgs() << "-- FAIL: Pattern was already applied.\n"); 782 return failure(); 783 } 784 785 auto &rewriterImpl = rewriter.getImpl(); 786 RewriterState curState = rewriterImpl.getCurrentState(); 787 auto cleanupFailure = [&] { 788 // Reset the rewriter state and pop this pattern. 789 rewriterImpl.resetState(curState); 790 appliedPatterns.erase(pattern); 791 return failure(); 792 }; 793 794 // Try to rewrite with the given pattern. 795 rewriter.setInsertionPoint(op); 796 if (!pattern->matchAndRewrite(op, rewriter)) { 797 LLVM_DEBUG(llvm::dbgs() << "-- FAIL: Pattern failed to match.\n"); 798 return cleanupFailure(); 799 } 800 801 // If the pattern moved any blocks, try to legalize their types. This ensures 802 // that the types of the block arguments are legal for the region they were 803 // moved into. 804 for (unsigned i = curState.numBlockActions, 805 e = rewriterImpl.blockActions.size(); 806 i != e; ++i) { 807 auto &action = rewriterImpl.blockActions[i]; 808 if (action.kind != ConversionPatternRewriterImpl::BlockActionKind::Move) 809 continue; 810 811 // Convert the block signature. 812 if (failed(rewriterImpl.convertBlockSignature(action.block))) { 813 LLVM_DEBUG(llvm::dbgs() 814 << "-- FAIL: failed to convert types of moved block.\n"); 815 return cleanupFailure(); 816 } 817 } 818 819 // Recursively legalize each of the new operations. 820 for (unsigned i = curState.numCreatedOperations, 821 e = rewriterImpl.createdOps.size(); 822 i != e; ++i) { 823 if (failed(legalize(rewriterImpl.createdOps[i], rewriter))) { 824 LLVM_DEBUG(llvm::dbgs() << "-- FAIL: Generated operation was illegal.\n"); 825 return cleanupFailure(); 826 } 827 } 828 829 appliedPatterns.erase(pattern); 830 return success(); 831 } 832 833 void OperationLegalizer::buildLegalizationGraph( 834 const OwningRewritePatternList &patterns) { 835 // A mapping between an operation and a set of operations that can be used to 836 // generate it. 837 DenseMap<OperationName, SmallPtrSet<OperationName, 2>> parentOps; 838 // A mapping between an operation and any currently invalid patterns it has. 839 DenseMap<OperationName, SmallPtrSet<RewritePattern *, 2>> invalidPatterns; 840 // A worklist of patterns to consider for legality. 841 llvm::SetVector<RewritePattern *> patternWorklist; 842 843 // Build the mapping from operations to the parent ops that may generate them. 844 for (auto &pattern : patterns) { 845 auto root = pattern->getRootKind(); 846 847 // Skip operations that are always known to be legal. 848 if (target.getOpAction(root) == LegalizationAction::Legal) 849 continue; 850 851 // Add this pattern to the invalid set for the root op and record this root 852 // as a parent for any generated operations. 853 invalidPatterns[root].insert(pattern.get()); 854 for (auto op : pattern->getGeneratedOps()) 855 parentOps[op].insert(root); 856 857 // Add this pattern to the worklist. 858 patternWorklist.insert(pattern.get()); 859 } 860 861 while (!patternWorklist.empty()) { 862 auto *pattern = patternWorklist.pop_back_val(); 863 864 // Check to see if any of the generated operations are invalid. 865 if (llvm::any_of(pattern->getGeneratedOps(), [&](OperationName op) { 866 auto action = target.getOpAction(op); 867 return !legalizerPatterns.count(op) && 868 (!action || action == LegalizationAction::Illegal); 869 })) 870 continue; 871 872 // Otherwise, if all of the generated operation are valid, this op is now 873 // legal so add all of the child patterns to the worklist. 874 legalizerPatterns[pattern->getRootKind()].push_back(pattern); 875 invalidPatterns[pattern->getRootKind()].erase(pattern); 876 877 // Add any invalid patterns of the parent operations to see if they have now 878 // become legal. 879 for (auto op : parentOps[pattern->getRootKind()]) 880 patternWorklist.set_union(invalidPatterns[op]); 881 } 882 } 883 884 void OperationLegalizer::computeLegalizationGraphBenefit() { 885 // The smallest pattern depth, when legalizing an operation. 886 DenseMap<OperationName, unsigned> minPatternDepth; 887 888 // Compute the minimum legalization depth for a given operation. 889 std::function<unsigned(OperationName)> computeDepth = [&](OperationName op) { 890 // Check for existing depth. 891 auto depthIt = minPatternDepth.find(op); 892 if (depthIt != minPatternDepth.end()) 893 return depthIt->second; 894 895 // If a mapping for this operation does not exist, then this operation 896 // is always legal. Return 0 as the depth for a directly legal operation. 897 auto opPatternsIt = legalizerPatterns.find(op); 898 if (opPatternsIt == legalizerPatterns.end()) 899 return 0u; 900 901 auto &minDepth = minPatternDepth[op]; 902 if (opPatternsIt->second.empty()) 903 return minDepth; 904 905 // Initialize the depth to the maximum value. 906 minDepth = std::numeric_limits<unsigned>::max(); 907 908 // Compute the depth for each pattern used to legalize this operation. 909 SmallVector<std::pair<RewritePattern *, unsigned>, 4> patternsByDepth; 910 patternsByDepth.reserve(opPatternsIt->second.size()); 911 for (RewritePattern *pattern : opPatternsIt->second) { 912 unsigned depth = 0; 913 for (auto generatedOp : pattern->getGeneratedOps()) 914 depth = std::max(depth, computeDepth(generatedOp) + 1); 915 patternsByDepth.emplace_back(pattern, depth); 916 917 // Update the min depth for this operation. 918 minDepth = std::min(minDepth, depth); 919 } 920 921 // If the operation only has one legalization pattern, there is no need to 922 // sort them. 923 if (patternsByDepth.size() == 1) 924 return minDepth; 925 926 // Sort the patterns by those likely to be the most beneficial. 927 llvm::array_pod_sort( 928 patternsByDepth.begin(), patternsByDepth.end(), 929 [](const std::pair<RewritePattern *, unsigned> *lhs, 930 const std::pair<RewritePattern *, unsigned> *rhs) { 931 // First sort by the smaller pattern legalization depth. 932 if (lhs->second != rhs->second) 933 return llvm::array_pod_sort_comparator<unsigned>(&lhs->second, 934 &rhs->second); 935 936 // Then sort by the larger pattern benefit. 937 auto lhsBenefit = lhs->first->getBenefit(); 938 auto rhsBenefit = rhs->first->getBenefit(); 939 return llvm::array_pod_sort_comparator<PatternBenefit>(&rhsBenefit, 940 &lhsBenefit); 941 }); 942 943 // Update the legalization pattern to use the new sorted list. 944 opPatternsIt->second.clear(); 945 for (auto &patternIt : patternsByDepth) 946 opPatternsIt->second.push_back(patternIt.first); 947 948 return minDepth; 949 }; 950 951 // For each operation that is transitively legal, compute a cost for it. 952 for (auto &opIt : legalizerPatterns) 953 if (!minPatternDepth.count(opIt.first)) 954 computeDepth(opIt.first); 955 } 956 957 //===----------------------------------------------------------------------===// 958 // OperationConverter 959 //===----------------------------------------------------------------------===// 960 namespace { 961 enum OpConversionMode { 962 // In this mode, the conversion will ignore failed conversions to allow 963 // illegal operations to co-exist in the IR. 964 Partial, 965 966 // In this mode, all operations must be legal for the given target for the 967 // conversion to succeeed. 968 Full, 969 970 // In this mode, operations are analyzed for legality. No actual rewrites are 971 // applied to the operations on success. 972 Analysis, 973 }; 974 975 // This class converts operations to a given conversion target via a set of 976 // rewrite patterns. The conversion behaves differently depending on the 977 // conversion mode. 978 struct OperationConverter { 979 explicit OperationConverter(ConversionTarget &target, 980 const OwningRewritePatternList &patterns, 981 OpConversionMode mode, 982 DenseSet<Operation *> *legalizableOps = nullptr) 983 : opLegalizer(target, patterns), mode(mode), 984 legalizableOps(legalizableOps) {} 985 986 /// Converts the given operations to the conversion target. 987 LogicalResult convertOperations(ArrayRef<Operation *> ops, 988 TypeConverter *typeConverter); 989 990 private: 991 /// Converts an operation with the given rewriter. 992 LogicalResult convert(ConversionPatternRewriter &rewriter, Operation *op); 993 994 /// Recursively collect all of the operations to convert from within 'region'. 995 LogicalResult computeConversionSet(Region ®ion, 996 std::vector<Operation *> &toConvert); 997 998 /// Converts the type signatures of the blocks nested within 'op'. 999 LogicalResult convertBlockSignatures(ConversionPatternRewriter &rewriter, 1000 Operation *op); 1001 1002 /// The legalizer to use when converting operations. 1003 OperationLegalizer opLegalizer; 1004 1005 /// The conversion mode to use when legalizing operations. 1006 OpConversionMode mode; 1007 1008 /// A set of pre-existing operations that were found to be legalizable to the 1009 /// target. This field is only used when mode == OpConversionMode::Analysis. 1010 DenseSet<Operation *> *legalizableOps; 1011 }; 1012 } // end anonymous namespace 1013 1014 LogicalResult 1015 OperationConverter::convertBlockSignatures(ConversionPatternRewriter &rewriter, 1016 Operation *op) { 1017 // Check to see if type signatures need to be converted. 1018 if (!rewriter.getImpl().argConverter.typeConverter) 1019 return success(); 1020 1021 for (auto ®ion : op->getRegions()) { 1022 for (auto &block : region) 1023 if (failed(rewriter.getImpl().convertBlockSignature(&block))) 1024 return failure(); 1025 } 1026 return success(); 1027 } 1028 1029 LogicalResult 1030 OperationConverter::computeConversionSet(Region ®ion, 1031 std::vector<Operation *> &toConvert) { 1032 if (region.empty()) 1033 return success(); 1034 1035 // Traverse starting from the entry block. 1036 SmallVector<Block *, 16> worklist(1, ®ion.front()); 1037 DenseSet<Block *> visitedBlocks; 1038 visitedBlocks.insert(®ion.front()); 1039 while (!worklist.empty()) { 1040 auto *block = worklist.pop_back_val(); 1041 1042 // Compute the conversion set of each of the nested operations. 1043 for (auto &op : *block) { 1044 toConvert.emplace_back(&op); 1045 for (auto ®ion : op.getRegions()) 1046 computeConversionSet(region, toConvert); 1047 } 1048 1049 // Recurse to children that haven't been visited. 1050 for (Block *succ : block->getSuccessors()) 1051 if (visitedBlocks.insert(succ).second) 1052 worklist.push_back(succ); 1053 } 1054 1055 // Check that all blocks in the region were visited. 1056 if (llvm::any_of(llvm::drop_begin(region.getBlocks(), 1), 1057 [&](Block &block) { return !visitedBlocks.count(&block); })) 1058 return emitError(region.getLoc(), "unreachable blocks were not converted"); 1059 return success(); 1060 } 1061 1062 LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter, 1063 Operation *op) { 1064 // Legalize the given operation. 1065 if (failed(opLegalizer.legalize(op, rewriter))) { 1066 // Handle the case of a failed conversion for each of the different modes. 1067 /// Full conversions expect all operations to be converted. 1068 if (mode == OpConversionMode::Full) 1069 return op->emitError() 1070 << "failed to legalize operation '" << op->getName() << "'"; 1071 /// Partial conversions allow conversions to fail iff the operation was not 1072 /// explicitly marked as illegal. 1073 if (mode == OpConversionMode::Partial && opLegalizer.isIllegal(op)) 1074 return op->emitError() 1075 << "failed to legalize operation '" << op->getName() 1076 << "' that was explicitly marked illegal"; 1077 } else { 1078 /// Analysis conversions don't fail if any operations fail to legalize, 1079 /// they are only interested in the operations that were successfully 1080 /// legalized. 1081 if (mode == OpConversionMode::Analysis) 1082 legalizableOps->insert(op); 1083 1084 // If legalization succeeded, convert the types any of the blocks within 1085 // this operation. 1086 if (failed(convertBlockSignatures(rewriter, op))) 1087 return failure(); 1088 } 1089 return success(); 1090 } 1091 1092 LogicalResult 1093 OperationConverter::convertOperations(ArrayRef<Operation *> ops, 1094 TypeConverter *typeConverter) { 1095 if (ops.empty()) 1096 return success(); 1097 1098 /// Compute the set of operations and blocks to convert. 1099 std::vector<Operation *> toConvert; 1100 for (auto *op : ops) { 1101 toConvert.emplace_back(op); 1102 for (auto ®ion : op->getRegions()) 1103 if (failed(computeConversionSet(region, toConvert))) 1104 return failure(); 1105 } 1106 1107 // Convert each operation and discard rewrites on failure. 1108 ConversionPatternRewriter rewriter(ops.front()->getContext(), typeConverter); 1109 for (auto *op : toConvert) 1110 if (failed(convert(rewriter, op))) 1111 return rewriter.getImpl().discardRewrites(), failure(); 1112 1113 // Otherwise, the body conversion succeeded. Apply rewrites if this is not an 1114 // analysis conversion. 1115 if (mode == OpConversionMode::Analysis) 1116 rewriter.getImpl().discardRewrites(); 1117 else 1118 rewriter.getImpl().applyRewrites(); 1119 return success(); 1120 } 1121 1122 //===----------------------------------------------------------------------===// 1123 // Type Conversion 1124 //===----------------------------------------------------------------------===// 1125 1126 /// Remap an input of the original signature with a new set of types. The 1127 /// new types are appended to the new signature conversion. 1128 void TypeConverter::SignatureConversion::addInputs(unsigned origInputNo, 1129 ArrayRef<Type> types) { 1130 assert(!types.empty() && "expected valid types"); 1131 remapInput(origInputNo, /*newInputNo=*/argTypes.size(), types.size()); 1132 addInputs(types); 1133 } 1134 1135 /// Append new input types to the signature conversion, this should only be 1136 /// used if the new types are not intended to remap an existing input. 1137 void TypeConverter::SignatureConversion::addInputs(ArrayRef<Type> types) { 1138 assert(!types.empty() && 1139 "1->0 type remappings don't need to be added explicitly"); 1140 argTypes.append(types.begin(), types.end()); 1141 } 1142 1143 /// Remap an input of the original signature with a range of types in the 1144 /// new signature. 1145 void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo, 1146 unsigned newInputNo, 1147 unsigned newInputCount) { 1148 assert(!remappedInputs[origInputNo] && "input has already been remapped"); 1149 assert(newInputCount != 0 && "expected valid input count"); 1150 remappedInputs[origInputNo] = InputMapping{newInputNo, newInputCount}; 1151 } 1152 1153 /// This hooks allows for converting a type. 1154 LogicalResult TypeConverter::convertType(Type t, 1155 SmallVectorImpl<Type> &results) { 1156 if (auto newT = convertType(t)) { 1157 results.push_back(newT); 1158 return success(); 1159 } 1160 return failure(); 1161 } 1162 1163 /// Convert the given set of types, filling 'results' as necessary. This 1164 /// returns failure if the conversion of any of the types fails, success 1165 /// otherwise. 1166 LogicalResult TypeConverter::convertTypes(ArrayRef<Type> types, 1167 SmallVectorImpl<Type> &results) { 1168 for (auto type : types) 1169 if (failed(convertType(type, results))) 1170 return failure(); 1171 return success(); 1172 } 1173 1174 /// Return true if the given type is legal for this type converter, i.e. the 1175 /// type converts to itself. 1176 bool TypeConverter::isLegal(Type type) { 1177 SmallVector<Type, 1> results; 1178 return succeeded(convertType(type, results)) && results.size() == 1 && 1179 results.front() == type; 1180 } 1181 1182 /// Return true if the inputs and outputs of the given function type are 1183 /// legal. 1184 bool TypeConverter::isSignatureLegal(FunctionType funcType) { 1185 return llvm::all_of( 1186 llvm::concat<const Type>(funcType.getInputs(), funcType.getResults()), 1187 [this](Type type) { return isLegal(type); }); 1188 } 1189 1190 /// This hook allows for converting a specific argument of a signature. 1191 LogicalResult TypeConverter::convertSignatureArg(unsigned inputNo, Type type, 1192 SignatureConversion &result) { 1193 // Try to convert the given input type. 1194 SmallVector<Type, 1> convertedTypes; 1195 if (failed(convertType(type, convertedTypes))) 1196 return failure(); 1197 1198 // If this argument is being dropped, there is nothing left to do. 1199 if (convertedTypes.empty()) 1200 return success(); 1201 1202 // Otherwise, add the new inputs. 1203 result.addInputs(inputNo, convertedTypes); 1204 return success(); 1205 } 1206 1207 /// Create a default conversion pattern that rewrites the type signature of a 1208 /// FuncOp. 1209 namespace { 1210 struct FuncOpSignatureConversion : public ConversionPattern { 1211 FuncOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter) 1212 : ConversionPattern(FuncOp::getOperationName(), 1, ctx), 1213 converter(converter) {} 1214 1215 /// Hook for derived classes to implement combined matching and rewriting. 1216 PatternMatchResult 1217 matchAndRewrite(Operation *op, ArrayRef<Value *> operands, 1218 ConversionPatternRewriter &rewriter) const override { 1219 auto funcOp = cast<FuncOp>(op); 1220 FunctionType type = funcOp.getType(); 1221 1222 // Convert the original function arguments. 1223 TypeConverter::SignatureConversion result(type.getNumInputs()); 1224 for (unsigned i = 0, e = type.getNumInputs(); i != e; ++i) 1225 if (failed(converter.convertSignatureArg(i, type.getInput(i), result))) 1226 return matchFailure(); 1227 1228 // Convert the original function results. 1229 SmallVector<Type, 1> convertedResults; 1230 if (failed(converter.convertTypes(type.getResults(), convertedResults))) 1231 return matchFailure(); 1232 1233 // Create a new function with an updated signature. 1234 auto newFuncOp = rewriter.cloneWithoutRegions(funcOp); 1235 rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), 1236 newFuncOp.end()); 1237 newFuncOp.setType(FunctionType::get(result.getConvertedTypes(), 1238 convertedResults, funcOp.getContext())); 1239 1240 // Tell the rewriter to convert the region signature. 1241 rewriter.applySignatureConversion(&newFuncOp.getBody(), result); 1242 rewriter.replaceOp(op, llvm::None); 1243 return matchSuccess(); 1244 } 1245 1246 /// The type converter to use when rewriting the signature. 1247 TypeConverter &converter; 1248 }; 1249 } // end anonymous namespace 1250 1251 void mlir::populateFuncOpTypeConversionPattern( 1252 OwningRewritePatternList &patterns, MLIRContext *ctx, 1253 TypeConverter &converter) { 1254 patterns.insert<FuncOpSignatureConversion>(ctx, converter); 1255 } 1256 1257 /// This function converts the type signature of the given block, by invoking 1258 /// 'convertSignatureArg' for each argument. This function should return a valid 1259 /// conversion for the signature on success, None otherwise. 1260 auto TypeConverter::convertBlockSignature(Block *block) 1261 -> llvm::Optional<SignatureConversion> { 1262 SignatureConversion conversion(block->getNumArguments()); 1263 for (unsigned i = 0, e = block->getNumArguments(); i != e; ++i) 1264 if (failed(convertSignatureArg(i, block->getArgument(i)->getType(), 1265 conversion))) 1266 return llvm::None; 1267 return conversion; 1268 } 1269 1270 //===----------------------------------------------------------------------===// 1271 // ConversionTarget 1272 //===----------------------------------------------------------------------===// 1273 1274 /// Register a legality action for the given operation. 1275 void ConversionTarget::setOpAction(OperationName op, 1276 LegalizationAction action) { 1277 legalOperations[op] = action; 1278 } 1279 1280 /// Register a legality action for the given dialects. 1281 void ConversionTarget::setDialectAction(ArrayRef<StringRef> dialectNames, 1282 LegalizationAction action) { 1283 for (StringRef dialect : dialectNames) 1284 legalDialects[dialect] = action; 1285 } 1286 1287 /// Get the legality action for the given operation. 1288 auto ConversionTarget::getOpAction(OperationName op) const 1289 -> llvm::Optional<LegalizationAction> { 1290 // Check for an action for this specific operation. 1291 auto it = legalOperations.find(op); 1292 if (it != legalOperations.end()) 1293 return it->second; 1294 // Otherwise, default to checking for an action on the parent dialect. 1295 auto dialectIt = legalDialects.find(op.getDialect()); 1296 if (dialectIt != legalDialects.end()) 1297 return dialectIt->second; 1298 return llvm::None; 1299 } 1300 1301 /// Return if the given operation instance is legal on this target. 1302 bool ConversionTarget::isLegal(Operation *op) const { 1303 auto action = getOpAction(op->getName()); 1304 1305 // Handle dynamic legality. 1306 if (action == LegalizationAction::Dynamic) { 1307 // Check for callbacks on the operation or dialect. 1308 auto opFn = opLegalityFns.find(op->getName()); 1309 if (opFn != opLegalityFns.end()) 1310 return opFn->second(op); 1311 auto dialectFn = dialectLegalityFns.find(op->getName().getDialect()); 1312 if (dialectFn != dialectLegalityFns.end()) 1313 return dialectFn->second(op); 1314 1315 // Otherwise, invoke the hook on the derived instance. 1316 return isDynamicallyLegal(op); 1317 } 1318 1319 // Otherwise, the operation is only legal if it was marked 'Legal'. 1320 return action == LegalizationAction::Legal; 1321 } 1322 1323 /// Set the dynamic legality callback for the given operation. 1324 void ConversionTarget::setLegalityCallback( 1325 OperationName name, const DynamicLegalityCallbackFn &callback) { 1326 assert(callback && "expected valid legality callback"); 1327 opLegalityFns[name] = callback; 1328 } 1329 1330 /// Set the dynamic legality callback for the given dialects. 1331 void ConversionTarget::setLegalityCallback( 1332 ArrayRef<StringRef> dialects, const DynamicLegalityCallbackFn &callback) { 1333 assert(callback && "expected valid legality callback"); 1334 for (StringRef dialect : dialects) 1335 dialectLegalityFns[dialect] = callback; 1336 } 1337 1338 //===----------------------------------------------------------------------===// 1339 // Op Conversion Entry Points 1340 //===----------------------------------------------------------------------===// 1341 1342 /// Apply a partial conversion on the given operations, and all nested 1343 /// operations. This method converts as many operations to the target as 1344 /// possible, ignoring operations that failed to legalize. 1345 LogicalResult mlir::applyPartialConversion( 1346 ArrayRef<Operation *> ops, ConversionTarget &target, 1347 const OwningRewritePatternList &patterns, TypeConverter *converter) { 1348 OperationConverter opConverter(target, patterns, OpConversionMode::Partial); 1349 return opConverter.convertOperations(ops, converter); 1350 } 1351 LogicalResult 1352 mlir::applyPartialConversion(Operation *op, ConversionTarget &target, 1353 const OwningRewritePatternList &patterns, 1354 TypeConverter *converter) { 1355 return applyPartialConversion(llvm::makeArrayRef(op), target, patterns, 1356 converter); 1357 } 1358 1359 /// Apply a complete conversion on the given operations, and all nested 1360 /// operations. This method will return failure if the conversion of any 1361 /// operation fails. 1362 LogicalResult 1363 mlir::applyFullConversion(ArrayRef<Operation *> ops, ConversionTarget &target, 1364 const OwningRewritePatternList &patterns, 1365 TypeConverter *converter) { 1366 OperationConverter opConverter(target, patterns, OpConversionMode::Full); 1367 return opConverter.convertOperations(ops, converter); 1368 } 1369 LogicalResult 1370 mlir::applyFullConversion(Operation *op, ConversionTarget &target, 1371 const OwningRewritePatternList &patterns, 1372 TypeConverter *converter) { 1373 return applyFullConversion(llvm::makeArrayRef(op), target, patterns, 1374 converter); 1375 } 1376 1377 /// Apply an analysis conversion on the given operations, and all nested 1378 /// operations. This method analyzes which operations would be successfully 1379 /// converted to the target if a conversion was applied. All operations that 1380 /// were found to be legalizable to the given 'target' are placed within the 1381 /// provided 'convertedOps' set; note that no actual rewrites are applied to the 1382 /// operations on success and only pre-existing operations are added to the set. 1383 LogicalResult mlir::applyAnalysisConversion( 1384 ArrayRef<Operation *> ops, ConversionTarget &target, 1385 const OwningRewritePatternList &patterns, 1386 DenseSet<Operation *> &convertedOps, TypeConverter *converter) { 1387 OperationConverter opConverter(target, patterns, OpConversionMode::Analysis, 1388 &convertedOps); 1389 return opConverter.convertOperations(ops, converter); 1390 } 1391 LogicalResult 1392 mlir::applyAnalysisConversion(Operation *op, ConversionTarget &target, 1393 const OwningRewritePatternList &patterns, 1394 DenseSet<Operation *> &convertedOps, 1395 TypeConverter *converter) { 1396 return applyAnalysisConversion(llvm::makeArrayRef(op), target, patterns, 1397 convertedOps, converter); 1398 }