github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/IR/Operation.cpp (about) 1 //===- Operation.cpp - Operation support code -----------------------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 18 #include "mlir/IR/Operation.h" 19 #include "mlir/IR/BlockAndValueMapping.h" 20 #include "mlir/IR/Diagnostics.h" 21 #include "mlir/IR/Dialect.h" 22 #include "mlir/IR/Function.h" 23 #include "mlir/IR/MLIRContext.h" 24 #include "mlir/IR/OpDefinition.h" 25 #include "mlir/IR/OpImplementation.h" 26 #include "mlir/IR/PatternMatch.h" 27 #include "mlir/IR/StandardTypes.h" 28 #include <numeric> 29 using namespace mlir; 30 31 /// Form the OperationName for an op with the specified string. This either is 32 /// a reference to an AbstractOperation if one is known, or a uniqued Identifier 33 /// if not. 34 OperationName::OperationName(StringRef name, MLIRContext *context) { 35 if (auto *op = AbstractOperation::lookup(name, context)) 36 representation = op; 37 else 38 representation = Identifier::get(name, context); 39 } 40 41 /// Return the name of the dialect this operation is registered to. 42 StringRef OperationName::getDialect() const { 43 return getStringRef().split('.').first; 44 } 45 46 /// Return the name of this operation. This always succeeds. 47 StringRef OperationName::getStringRef() const { 48 if (auto *op = representation.dyn_cast<const AbstractOperation *>()) 49 return op->name; 50 return representation.get<Identifier>().strref(); 51 } 52 53 const AbstractOperation *OperationName::getAbstractOperation() const { 54 return representation.dyn_cast<const AbstractOperation *>(); 55 } 56 57 OperationName OperationName::getFromOpaquePointer(void *pointer) { 58 return OperationName(RepresentationUnion::getFromOpaqueValue(pointer)); 59 } 60 61 OpAsmParser::~OpAsmParser() {} 62 63 //===----------------------------------------------------------------------===// 64 // OpResult 65 //===----------------------------------------------------------------------===// 66 67 /// Return the result number of this result. 68 unsigned OpResult::getResultNumber() { 69 // Results are always stored consecutively, so use pointer subtraction to 70 // figure out what number this is. 71 return this - &getOwner()->getOpResults()[0]; 72 } 73 74 //===----------------------------------------------------------------------===// 75 // OpOperand 76 //===----------------------------------------------------------------------===// 77 78 // TODO: This namespace is only required because of a bug in GCC<7.0. 79 namespace mlir { 80 /// Return which operand this is in the operand list. 81 template <> unsigned OpOperand::getOperandNumber() { 82 return this - &getOwner()->getOpOperands()[0]; 83 } 84 } // end namespace mlir 85 86 //===----------------------------------------------------------------------===// 87 // BlockOperand 88 //===----------------------------------------------------------------------===// 89 90 // TODO: This namespace is only required because of a bug in GCC<7.0. 91 namespace mlir { 92 /// Return which operand this is in the operand list. 93 template <> unsigned BlockOperand::getOperandNumber() { 94 return this - &getOwner()->getBlockOperands()[0]; 95 } 96 } // end namespace mlir 97 98 //===----------------------------------------------------------------------===// 99 // Operation 100 //===----------------------------------------------------------------------===// 101 102 /// Create a new Operation with the specific fields. 103 Operation *Operation::create(Location location, OperationName name, 104 ArrayRef<Value *> operands, 105 ArrayRef<Type> resultTypes, 106 ArrayRef<NamedAttribute> attributes, 107 ArrayRef<Block *> successors, unsigned numRegions, 108 bool resizableOperandList) { 109 return create(location, name, operands, resultTypes, 110 NamedAttributeList(attributes), successors, numRegions, 111 resizableOperandList); 112 } 113 114 /// Create a new Operation from operation state. 115 Operation *Operation::create(const OperationState &state) { 116 unsigned numRegions = state.regions.size(); 117 Operation *op = create(state.location, state.name, state.operands, 118 state.types, state.attributes, state.successors, 119 numRegions, state.resizableOperandList); 120 for (unsigned i = 0; i < numRegions; ++i) 121 if (state.regions[i]) 122 op->getRegion(i).takeBody(*state.regions[i]); 123 return op; 124 } 125 126 /// Overload of create that takes an existing NamedAttributeList to avoid 127 /// unnecessarily uniquing a list of attributes. 128 Operation *Operation::create(Location location, OperationName name, 129 ArrayRef<Value *> operands, 130 ArrayRef<Type> resultTypes, 131 const NamedAttributeList &attributes, 132 ArrayRef<Block *> successors, unsigned numRegions, 133 bool resizableOperandList) { 134 unsigned numSuccessors = successors.size(); 135 136 // Input operands are nullptr-separated for each successor, the null operands 137 // aren't actually stored. 138 unsigned numOperands = operands.size() - numSuccessors; 139 140 // Compute the byte size for the operation and the operand storage. 141 auto byteSize = totalSizeToAlloc<OpResult, BlockOperand, unsigned, Region, 142 detail::OperandStorage>( 143 resultTypes.size(), numSuccessors, numSuccessors, numRegions, 144 /*detail::OperandStorage*/ 1); 145 byteSize += llvm::alignTo(detail::OperandStorage::additionalAllocSize( 146 numOperands, resizableOperandList), 147 alignof(Operation)); 148 void *rawMem = malloc(byteSize); 149 150 // Create the new Operation. 151 auto op = ::new (rawMem) Operation(location, name, resultTypes.size(), 152 numSuccessors, numRegions, attributes); 153 154 assert((numSuccessors == 0 || !op->isKnownNonTerminator()) && 155 "unexpected successors in a non-terminator operation"); 156 157 // Initialize the regions. 158 for (unsigned i = 0; i != numRegions; ++i) 159 new (&op->getRegion(i)) Region(op); 160 161 // Initialize the results and operands. 162 new (&op->getOperandStorage()) 163 detail::OperandStorage(numOperands, resizableOperandList); 164 165 auto instResults = op->getOpResults(); 166 for (unsigned i = 0, e = resultTypes.size(); i != e; ++i) 167 new (&instResults[i]) OpResult(resultTypes[i], op); 168 169 auto opOperands = op->getOpOperands(); 170 171 // Initialize normal operands. 172 unsigned operandIt = 0, operandE = operands.size(); 173 unsigned nextOperand = 0; 174 for (; operandIt != operandE; ++operandIt) { 175 // Null operands are used as sentinels between successor operand lists. If 176 // we encounter one here, break and handle the successor operands lists 177 // separately below. 178 if (!operands[operandIt]) 179 break; 180 new (&opOperands[nextOperand++]) OpOperand(op, operands[operandIt]); 181 } 182 183 unsigned currentSuccNum = 0; 184 if (operandIt == operandE) { 185 // Verify that the amount of sentinel operands is equivalent to the number 186 // of successors. 187 assert(currentSuccNum == numSuccessors); 188 return op; 189 } 190 191 assert(!op->isKnownNonTerminator() && 192 "Unexpected nullptr in operand list when creating non-terminator."); 193 auto instBlockOperands = op->getBlockOperands(); 194 unsigned *succOperandCountIt = op->getTrailingObjects<unsigned>(); 195 unsigned *succOperandCountE = succOperandCountIt + numSuccessors; 196 (void)succOperandCountE; 197 198 for (; operandIt != operandE; ++operandIt) { 199 // If we encounter a sentinel branch to the next operand update the count 200 // variable. 201 if (!operands[operandIt]) { 202 assert(currentSuccNum < numSuccessors); 203 204 // After the first iteration update the successor operand count 205 // variable. 206 if (currentSuccNum != 0) { 207 ++succOperandCountIt; 208 assert(succOperandCountIt != succOperandCountE && 209 "More sentinel operands than successors."); 210 } 211 212 new (&instBlockOperands[currentSuccNum]) 213 BlockOperand(op, successors[currentSuccNum]); 214 *succOperandCountIt = 0; 215 ++currentSuccNum; 216 continue; 217 } 218 new (&opOperands[nextOperand++]) OpOperand(op, operands[operandIt]); 219 ++(*succOperandCountIt); 220 } 221 222 // Verify that the amount of sentinel operands is equivalent to the number of 223 // successors. 224 assert(currentSuccNum == numSuccessors); 225 226 return op; 227 } 228 229 Operation::Operation(Location location, OperationName name, unsigned numResults, 230 unsigned numSuccessors, unsigned numRegions, 231 const NamedAttributeList &attributes) 232 : location(location), numResults(numResults), numSuccs(numSuccessors), 233 numRegions(numRegions), name(name), attrs(attributes) {} 234 235 // Operations are deleted through the destroy() member because they are 236 // allocated via malloc. 237 Operation::~Operation() { 238 assert(block == nullptr && "operation destroyed but still in a block"); 239 240 // Explicitly run the destructors for the operands and results. 241 getOperandStorage().~OperandStorage(); 242 243 for (auto &result : getOpResults()) 244 result.~OpResult(); 245 246 // Explicitly run the destructors for the successors. 247 for (auto &successor : getBlockOperands()) 248 successor.~BlockOperand(); 249 250 // Explicitly destroy the regions. 251 for (auto ®ion : getRegions()) 252 region.~Region(); 253 } 254 255 /// Destroy this operation or one of its subclasses. 256 void Operation::destroy() { 257 this->~Operation(); 258 free(this); 259 } 260 261 /// Return the context this operation is associated with. 262 MLIRContext *Operation::getContext() { return location->getContext(); } 263 264 /// Return the dialact this operation is associated with, or nullptr if the 265 /// associated dialect is not registered. 266 Dialect *Operation::getDialect() { 267 if (auto *abstractOp = getAbstractOperation()) 268 return &abstractOp->dialect; 269 270 // If this operation hasn't been registered or doesn't have abstract 271 // operation, try looking up the dialect name in the context. 272 return getContext()->getRegisteredDialect(getName().getDialect()); 273 } 274 275 Region *Operation::getParentRegion() { 276 return block ? block->getParent() : nullptr; 277 } 278 279 Operation *Operation::getParentOp() { 280 return block ? block->getParentOp() : nullptr; 281 } 282 283 /// Replace any uses of 'from' with 'to' within this operation. 284 void Operation::replaceUsesOfWith(Value *from, Value *to) { 285 if (from == to) 286 return; 287 for (auto &operand : getOpOperands()) 288 if (operand.get() == from) 289 operand.set(to); 290 } 291 292 //===----------------------------------------------------------------------===// 293 // Other 294 //===----------------------------------------------------------------------===// 295 296 /// Emit an error about fatal conditions with this operation, reporting up to 297 /// any diagnostic handlers that may be listening. 298 InFlightDiagnostic Operation::emitError(const Twine &message) { 299 return mlir::emitError(getLoc(), message); 300 } 301 302 /// Emit a warning about this operation, reporting up to any diagnostic 303 /// handlers that may be listening. 304 InFlightDiagnostic Operation::emitWarning(const Twine &message) { 305 return mlir::emitWarning(getLoc(), message); 306 } 307 308 /// Emit a remark about this operation, reporting up to any diagnostic 309 /// handlers that may be listening. 310 InFlightDiagnostic Operation::emitRemark(const Twine &message) { 311 return mlir::emitRemark(getLoc(), message); 312 } 313 314 /// Given an operation 'other' that is within the same parent block, return 315 /// whether the current operation is before 'other' in the operation list 316 /// of the parent block. 317 /// Note: This function has an average complexity of O(1), but worst case may 318 /// take O(N) where N is the number of operations within the parent block. 319 bool Operation::isBeforeInBlock(Operation *other) { 320 assert(block && "Operations without parent blocks have no order."); 321 assert(other && other->block == block && 322 "Expected other operation to have the same parent block."); 323 // Recompute the parent ordering if necessary. 324 if (!block->isInstOrderValid()) 325 block->recomputeInstOrder(); 326 return orderIndex < other->orderIndex; 327 } 328 329 //===----------------------------------------------------------------------===// 330 // ilist_traits for Operation 331 //===----------------------------------------------------------------------===// 332 333 auto llvm::ilist_detail::SpecificNodeAccess< 334 typename llvm::ilist_detail::compute_node_options< 335 ::mlir::Operation>::type>::getNodePtr(pointer N) -> node_type * { 336 return NodeAccess::getNodePtr<OptionsT>(N); 337 } 338 339 auto llvm::ilist_detail::SpecificNodeAccess< 340 typename llvm::ilist_detail::compute_node_options< 341 ::mlir::Operation>::type>::getNodePtr(const_pointer N) 342 -> const node_type * { 343 return NodeAccess::getNodePtr<OptionsT>(N); 344 } 345 346 auto llvm::ilist_detail::SpecificNodeAccess< 347 typename llvm::ilist_detail::compute_node_options< 348 ::mlir::Operation>::type>::getValuePtr(node_type *N) -> pointer { 349 return NodeAccess::getValuePtr<OptionsT>(N); 350 } 351 352 auto llvm::ilist_detail::SpecificNodeAccess< 353 typename llvm::ilist_detail::compute_node_options< 354 ::mlir::Operation>::type>::getValuePtr(const node_type *N) 355 -> const_pointer { 356 return NodeAccess::getValuePtr<OptionsT>(N); 357 } 358 359 void llvm::ilist_traits<::mlir::Operation>::deleteNode(Operation *op) { 360 op->destroy(); 361 } 362 363 Block *llvm::ilist_traits<::mlir::Operation>::getContainingBlock() { 364 size_t Offset(size_t(&((Block *)nullptr->*Block::getSublistAccess(nullptr)))); 365 iplist<Operation> *Anchor(static_cast<iplist<Operation> *>(this)); 366 return reinterpret_cast<Block *>(reinterpret_cast<char *>(Anchor) - Offset); 367 } 368 369 /// This is a trait method invoked when a operation is added to a block. We 370 /// keep the block pointer up to date. 371 void llvm::ilist_traits<::mlir::Operation>::addNodeToList(Operation *op) { 372 assert(!op->getBlock() && "already in a operation block!"); 373 op->block = getContainingBlock(); 374 375 // Invalidate the block ordering. 376 op->block->invalidateInstOrder(); 377 } 378 379 /// This is a trait method invoked when a operation is removed from a block. 380 /// We keep the block pointer up to date. 381 void llvm::ilist_traits<::mlir::Operation>::removeNodeFromList(Operation *op) { 382 assert(op->block && "not already in a operation block!"); 383 op->block = nullptr; 384 } 385 386 /// This is a trait method invoked when a operation is moved from one block 387 /// to another. We keep the block pointer up to date. 388 void llvm::ilist_traits<::mlir::Operation>::transferNodesFromList( 389 ilist_traits<Operation> &otherList, op_iterator first, op_iterator last) { 390 Block *curParent = getContainingBlock(); 391 392 // Invalidate the ordering of the parent block. 393 curParent->invalidateInstOrder(); 394 395 // If we are transferring operations within the same block, the block 396 // pointer doesn't need to be updated. 397 if (curParent == otherList.getContainingBlock()) 398 return; 399 400 // Update the 'block' member of each operation. 401 for (; first != last; ++first) 402 first->block = curParent; 403 } 404 405 /// Remove this operation (and its descendants) from its Block and delete 406 /// all of them. 407 void Operation::erase() { 408 if (auto *parent = getBlock()) 409 parent->getOperations().erase(this); 410 else 411 destroy(); 412 } 413 414 /// Unlink this operation from its current block and insert it right before 415 /// `existingInst` which may be in the same or another block in the same 416 /// function. 417 void Operation::moveBefore(Operation *existingInst) { 418 moveBefore(existingInst->getBlock(), existingInst->getIterator()); 419 } 420 421 /// Unlink this operation from its current basic block and insert it right 422 /// before `iterator` in the specified basic block. 423 void Operation::moveBefore(Block *block, 424 llvm::iplist<Operation>::iterator iterator) { 425 block->getOperations().splice(iterator, getBlock()->getOperations(), 426 getIterator()); 427 } 428 429 /// This drops all operand uses from this operation, which is an essential 430 /// step in breaking cyclic dependences between references when they are to 431 /// be deleted. 432 void Operation::dropAllReferences() { 433 for (auto &op : getOpOperands()) 434 op.drop(); 435 436 for (auto ®ion : getRegions()) 437 region.dropAllReferences(); 438 439 for (auto &dest : getBlockOperands()) 440 dest.drop(); 441 } 442 443 /// This drops all uses of any values defined by this operation or its nested 444 /// regions, wherever they are located. 445 void Operation::dropAllDefinedValueUses() { 446 for (auto &val : getOpResults()) 447 val.dropAllUses(); 448 449 for (auto ®ion : getRegions()) 450 for (auto &block : region) 451 block.dropAllDefinedValueUses(); 452 } 453 454 /// Return true if there are no users of any results of this operation. 455 bool Operation::use_empty() { 456 for (auto *result : getResults()) 457 if (!result->use_empty()) 458 return false; 459 return true; 460 } 461 462 void Operation::setSuccessor(Block *block, unsigned index) { 463 assert(index < getNumSuccessors()); 464 getBlockOperands()[index].set(block); 465 } 466 467 auto Operation::getNonSuccessorOperands() -> operand_range { 468 return {operand_iterator(this, 0), 469 operand_iterator(this, hasSuccessors() ? getSuccessorOperandIndex(0) 470 : getNumOperands())}; 471 } 472 473 /// Get the index of the first operand of the successor at the provided 474 /// index. 475 unsigned Operation::getSuccessorOperandIndex(unsigned index) { 476 assert(!isKnownNonTerminator() && "only terminators may have successors"); 477 assert(index < getNumSuccessors()); 478 479 // Count the number of operands for each of the successors after, and 480 // including, the one at 'index'. This is based upon the assumption that all 481 // non successor operands are placed at the beginning of the operand list. 482 auto *successorOpCountBegin = getTrailingObjects<unsigned>(); 483 unsigned postSuccessorOpCount = 484 std::accumulate(successorOpCountBegin + index, 485 successorOpCountBegin + getNumSuccessors(), 0u); 486 return getNumOperands() - postSuccessorOpCount; 487 } 488 489 auto Operation::getSuccessorOperands(unsigned index) -> operand_range { 490 unsigned succOperandIndex = getSuccessorOperandIndex(index); 491 return {operand_iterator(this, succOperandIndex), 492 operand_iterator(this, 493 succOperandIndex + getNumSuccessorOperands(index))}; 494 } 495 496 /// Attempt to fold this operation using the Op's registered foldHook. 497 LogicalResult Operation::fold(ArrayRef<Attribute> operands, 498 SmallVectorImpl<OpFoldResult> &results) { 499 // If we have a registered operation definition matching this one, use it to 500 // try to constant fold the operation. 501 auto *abstractOp = getAbstractOperation(); 502 if (abstractOp && succeeded(abstractOp->foldHook(this, operands, results))) 503 return success(); 504 505 // Otherwise, fall back on the dialect hook to handle it. 506 Dialect *dialect = getDialect(); 507 if (!dialect) 508 return failure(); 509 510 SmallVector<Attribute, 8> constants; 511 if (failed(dialect->constantFoldHook(this, operands, constants))) 512 return failure(); 513 results.assign(constants.begin(), constants.end()); 514 return success(); 515 } 516 517 /// Emit an error with the op name prefixed, like "'dim' op " which is 518 /// convenient for verifiers. 519 InFlightDiagnostic Operation::emitOpError(const Twine &message) { 520 return emitError() << "'" << getName() << "' op " << message; 521 } 522 523 //===----------------------------------------------------------------------===// 524 // Operation Cloning 525 //===----------------------------------------------------------------------===// 526 527 /// Create a deep copy of this operation but keep the operation regions empty. 528 /// Operands are remapped using `mapper` (if present), and `mapper` is updated 529 /// to contain the results. 530 Operation *Operation::cloneWithoutRegions(BlockAndValueMapping &mapper) { 531 SmallVector<Value *, 8> operands; 532 SmallVector<Block *, 2> successors; 533 534 operands.reserve(getNumOperands() + getNumSuccessors()); 535 536 if (getNumSuccessors() == 0) { 537 // Non-branching operations can just add all the operands. 538 for (auto *opValue : getOperands()) 539 operands.push_back(mapper.lookupOrDefault(opValue)); 540 } else { 541 // We add the operands separated by nullptr's for each successor. 542 unsigned firstSuccOperand = 543 getNumSuccessors() ? getSuccessorOperandIndex(0) : getNumOperands(); 544 auto opOperands = getOpOperands(); 545 546 unsigned i = 0; 547 for (; i != firstSuccOperand; ++i) 548 operands.push_back(mapper.lookupOrDefault(opOperands[i].get())); 549 550 successors.reserve(getNumSuccessors()); 551 for (unsigned succ = 0, e = getNumSuccessors(); succ != e; ++succ) { 552 successors.push_back(mapper.lookupOrDefault(getSuccessor(succ))); 553 554 // Add sentinel to delineate successor operands. 555 operands.push_back(nullptr); 556 557 // Remap the successors operands. 558 for (auto *operand : getSuccessorOperands(succ)) 559 operands.push_back(mapper.lookupOrDefault(operand)); 560 } 561 } 562 563 SmallVector<Type, 8> resultTypes(getResultTypes()); 564 unsigned numRegions = getNumRegions(); 565 auto *newOp = 566 Operation::create(getLoc(), getName(), operands, resultTypes, attrs, 567 successors, numRegions, hasResizableOperandsList()); 568 569 // Remember the mapping of any results. 570 for (unsigned i = 0, e = getNumResults(); i != e; ++i) 571 mapper.map(getResult(i), newOp->getResult(i)); 572 573 return newOp; 574 } 575 576 Operation *Operation::cloneWithoutRegions() { 577 BlockAndValueMapping mapper; 578 return cloneWithoutRegions(mapper); 579 } 580 581 /// Create a deep copy of this operation, remapping any operands that use 582 /// values outside of the operation using the map that is provided (leaving 583 /// them alone if no entry is present). Replaces references to cloned 584 /// sub-operations to the corresponding operation that is copied, and adds 585 /// those mappings to the map. 586 Operation *Operation::clone(BlockAndValueMapping &mapper) { 587 auto *newOp = cloneWithoutRegions(mapper); 588 589 // Clone the regions. 590 for (unsigned i = 0; i != numRegions; ++i) 591 getRegion(i).cloneInto(&newOp->getRegion(i), mapper); 592 593 return newOp; 594 } 595 596 Operation *Operation::clone() { 597 BlockAndValueMapping mapper; 598 return clone(mapper); 599 } 600 601 //===----------------------------------------------------------------------===// 602 // OpState trait class. 603 //===----------------------------------------------------------------------===// 604 605 // The fallback for the parser is to reject the custom assembly form. 606 ParseResult OpState::parse(OpAsmParser *parser, OperationState *result) { 607 return parser->emitError(parser->getNameLoc(), "has no custom assembly form"); 608 } 609 610 // The fallback for the printer is to print in the generic assembly form. 611 void OpState::print(OpAsmPrinter *p) { p->printGenericOp(getOperation()); } 612 613 /// Emit an error about fatal conditions with this operation, reporting up to 614 /// any diagnostic handlers that may be listening. 615 InFlightDiagnostic OpState::emitError(const Twine &message) { 616 return getOperation()->emitError(message); 617 } 618 619 /// Emit an error with the op name prefixed, like "'dim' op " which is 620 /// convenient for verifiers. 621 InFlightDiagnostic OpState::emitOpError(const Twine &message) { 622 return getOperation()->emitOpError(message); 623 } 624 625 /// Emit a warning about this operation, reporting up to any diagnostic 626 /// handlers that may be listening. 627 InFlightDiagnostic OpState::emitWarning(const Twine &message) { 628 return getOperation()->emitWarning(message); 629 } 630 631 /// Emit a remark about this operation, reporting up to any diagnostic 632 /// handlers that may be listening. 633 InFlightDiagnostic OpState::emitRemark(const Twine &message) { 634 return getOperation()->emitRemark(message); 635 } 636 637 //===----------------------------------------------------------------------===// 638 // Op Trait implementations 639 //===----------------------------------------------------------------------===// 640 641 LogicalResult OpTrait::impl::verifyZeroOperands(Operation *op) { 642 if (op->getNumOperands() != 0) 643 return op->emitOpError() << "requires zero operands"; 644 return success(); 645 } 646 647 LogicalResult OpTrait::impl::verifyOneOperand(Operation *op) { 648 if (op->getNumOperands() != 1) 649 return op->emitOpError() << "requires a single operand"; 650 return success(); 651 } 652 653 LogicalResult OpTrait::impl::verifyNOperands(Operation *op, 654 unsigned numOperands) { 655 if (op->getNumOperands() != numOperands) { 656 return op->emitOpError() << "expected " << numOperands 657 << " operands, but found " << op->getNumOperands(); 658 } 659 return success(); 660 } 661 662 LogicalResult OpTrait::impl::verifyAtLeastNOperands(Operation *op, 663 unsigned numOperands) { 664 if (op->getNumOperands() < numOperands) 665 return op->emitOpError() 666 << "expected " << numOperands << " or more operands"; 667 return success(); 668 } 669 670 /// If this is a vector type, or a tensor type, return the scalar element type 671 /// that it is built around, otherwise return the type unmodified. 672 static Type getTensorOrVectorElementType(Type type) { 673 if (auto vec = type.dyn_cast<VectorType>()) 674 return vec.getElementType(); 675 676 // Look through tensor<vector<...>> to find the underlying element type. 677 if (auto tensor = type.dyn_cast<TensorType>()) 678 return getTensorOrVectorElementType(tensor.getElementType()); 679 return type; 680 } 681 682 LogicalResult OpTrait::impl::verifyOperandsAreIntegerLike(Operation *op) { 683 for (auto opType : op->getOperandTypes()) { 684 auto type = getTensorOrVectorElementType(opType); 685 if (!type.isIntOrIndex()) 686 return op->emitOpError() << "requires an integer or index type"; 687 } 688 return success(); 689 } 690 691 LogicalResult OpTrait::impl::verifyOperandsAreFloatLike(Operation *op) { 692 for (auto opType : op->getOperandTypes()) { 693 auto type = getTensorOrVectorElementType(opType); 694 if (!type.isa<FloatType>()) 695 return op->emitOpError("requires a float type"); 696 } 697 return success(); 698 } 699 700 LogicalResult OpTrait::impl::verifySameTypeOperands(Operation *op) { 701 // Zero or one operand always have the "same" type. 702 unsigned nOperands = op->getNumOperands(); 703 if (nOperands < 2) 704 return success(); 705 706 auto type = op->getOperand(0)->getType(); 707 for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1)) 708 if (opType != type) 709 return op->emitOpError() << "requires all operands to have the same type"; 710 return success(); 711 } 712 713 LogicalResult OpTrait::impl::verifyZeroResult(Operation *op) { 714 if (op->getNumResults() != 0) 715 return op->emitOpError() << "requires zero results"; 716 return success(); 717 } 718 719 LogicalResult OpTrait::impl::verifyOneResult(Operation *op) { 720 if (op->getNumResults() != 1) 721 return op->emitOpError() << "requires one result"; 722 return success(); 723 } 724 725 LogicalResult OpTrait::impl::verifyNResults(Operation *op, 726 unsigned numOperands) { 727 if (op->getNumResults() != numOperands) 728 return op->emitOpError() << "expected " << numOperands << " results"; 729 return success(); 730 } 731 732 LogicalResult OpTrait::impl::verifyAtLeastNResults(Operation *op, 733 unsigned numOperands) { 734 if (op->getNumResults() < numOperands) 735 return op->emitOpError() 736 << "expected " << numOperands << " or more results"; 737 return success(); 738 } 739 740 /// Returns success if the given two types have the same shape. That is, 741 /// they are both scalars (not shaped), or they are both shaped types and at 742 /// least one is unranked or they have the same shape. The element type does not 743 /// matter. 744 static LogicalResult verifyShapeMatch(Type type1, Type type2) { 745 auto sType1 = type1.dyn_cast<ShapedType>(); 746 auto sType2 = type2.dyn_cast<ShapedType>(); 747 748 // Either both or neither type should be shaped. 749 if (!sType1) 750 return success(!sType2); 751 if (!sType2) 752 return failure(); 753 754 if (!sType1.hasRank() || !sType2.hasRank()) 755 return success(); 756 757 return success(sType1.getShape() == sType2.getShape()); 758 } 759 760 LogicalResult OpTrait::impl::verifySameOperandsShape(Operation *op) { 761 if (op->getNumOperands() == 0) 762 return failure(); 763 764 auto type = op->getOperand(0)->getType(); 765 for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1)) { 766 if (failed(verifyShapeMatch(opType, type))) 767 return op->emitOpError() << "requires the same shape for all operands"; 768 } 769 return success(); 770 } 771 772 LogicalResult OpTrait::impl::verifySameOperandsAndResultShape(Operation *op) { 773 if (op->getNumOperands() == 0 || op->getNumResults() == 0) 774 return failure(); 775 776 auto type = op->getOperand(0)->getType(); 777 for (auto resultType : op->getResultTypes()) { 778 if (failed(verifyShapeMatch(resultType, type))) 779 return op->emitOpError() 780 << "requires the same shape for all operands and results"; 781 } 782 for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1)) { 783 if (failed(verifyShapeMatch(opType, type))) 784 return op->emitOpError() 785 << "requires the same shape for all operands and results"; 786 } 787 return success(); 788 } 789 790 LogicalResult OpTrait::impl::verifySameOperandsElementType(Operation *op) { 791 if (op->getNumOperands() == 0) 792 return failure(); 793 794 auto type = op->getOperand(0)->getType().dyn_cast<ShapedType>(); 795 if (!type) 796 return op->emitOpError("requires shaped type results"); 797 auto elementType = type.getElementType(); 798 799 for (auto operandType : llvm::drop_begin(op->getOperandTypes(), 1)) { 800 auto shapedType = operandType.dyn_cast<ShapedType>(); 801 if (!shapedType) 802 return op->emitOpError("requires shaped type operands"); 803 if (shapedType.getElementType() != elementType) 804 return op->emitOpError("requires the same element type for all operands"); 805 } 806 807 return success(); 808 } 809 810 LogicalResult 811 OpTrait::impl::verifySameOperandsAndResultElementType(Operation *op) { 812 if (op->getNumOperands() == 0 || op->getNumResults() == 0) 813 return failure(); 814 815 auto type = op->getResult(0)->getType().dyn_cast<ShapedType>(); 816 if (!type) 817 return op->emitOpError("requires shaped type results"); 818 auto elementType = type.getElementType(); 819 820 // Verify result element type matches first result's element type. 821 for (auto result : drop_begin(op->getResults(), 1)) { 822 auto resultType = result->getType().dyn_cast<ShapedType>(); 823 if (!resultType) 824 return op->emitOpError("requires shaped type results"); 825 if (resultType.getElementType() != elementType) 826 return op->emitOpError( 827 "requires the same element type for all operands and results"); 828 } 829 830 // Verify operand's element type matches first result's element type. 831 for (auto operand : op->getOperands()) { 832 auto operandType = operand->getType().dyn_cast<ShapedType>(); 833 if (!operandType) 834 return op->emitOpError("requires shaped type operands"); 835 if (operandType.getElementType() != elementType) 836 return op->emitOpError( 837 "requires the same element type for all operands and results"); 838 } 839 840 return success(); 841 } 842 843 LogicalResult OpTrait::impl::verifySameOperandsAndResultType(Operation *op) { 844 if (op->getNumOperands() == 0 || op->getNumResults() == 0) 845 return failure(); 846 847 auto type = op->getResult(0)->getType(); 848 for (auto resultType : llvm::drop_begin(op->getResultTypes(), 1)) { 849 if (resultType != type) 850 return op->emitOpError() 851 << "requires the same type for all operands and results"; 852 } 853 for (auto opType : op->getOperandTypes()) { 854 if (opType != type) 855 return op->emitOpError() 856 << "requires the same type for all operands and results"; 857 } 858 return success(); 859 } 860 861 static LogicalResult verifyBBArguments(Operation::operand_range operands, 862 Block *destBB, Operation *op) { 863 unsigned operandCount = std::distance(operands.begin(), operands.end()); 864 if (operandCount != destBB->getNumArguments()) 865 return op->emitError() << "branch has " << operandCount 866 << " operands, but target block has " 867 << destBB->getNumArguments(); 868 869 auto operandIt = operands.begin(); 870 for (unsigned i = 0, e = operandCount; i != e; ++i, ++operandIt) { 871 if ((*operandIt)->getType() != destBB->getArgument(i)->getType()) 872 return op->emitError() << "type mismatch in bb argument #" << i; 873 } 874 875 return success(); 876 } 877 878 static LogicalResult verifyTerminatorSuccessors(Operation *op) { 879 auto *parent = op->getParentRegion(); 880 881 // Verify that the operands lines up with the BB arguments in the successor. 882 for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i) { 883 auto *succ = op->getSuccessor(i); 884 if (succ->getParent() != parent) 885 return op->emitError("reference to block defined in another region"); 886 if (failed(verifyBBArguments(op->getSuccessorOperands(i), succ, op))) 887 return failure(); 888 } 889 return success(); 890 } 891 892 LogicalResult OpTrait::impl::verifyIsTerminator(Operation *op) { 893 Block *block = op->getBlock(); 894 // Verify that the operation is at the end of the respective parent block. 895 if (!block || &block->back() != op) 896 return op->emitOpError("must be the last operation in the parent block"); 897 898 // Verify the state of the successor blocks. 899 if (op->getNumSuccessors() != 0 && failed(verifyTerminatorSuccessors(op))) 900 return failure(); 901 return success(); 902 } 903 904 LogicalResult OpTrait::impl::verifyResultsAreBoolLike(Operation *op) { 905 for (auto resultType : op->getResultTypes()) { 906 auto elementType = getTensorOrVectorElementType(resultType); 907 bool isBoolType = elementType.isInteger(1); 908 if (!isBoolType) 909 return op->emitOpError() << "requires a bool result type"; 910 } 911 912 return success(); 913 } 914 915 LogicalResult OpTrait::impl::verifyResultsAreFloatLike(Operation *op) { 916 for (auto resultType : op->getResultTypes()) 917 if (!getTensorOrVectorElementType(resultType).isa<FloatType>()) 918 return op->emitOpError() << "requires a floating point type"; 919 920 return success(); 921 } 922 923 LogicalResult OpTrait::impl::verifyResultsAreIntegerLike(Operation *op) { 924 for (auto resultType : op->getResultTypes()) 925 if (!getTensorOrVectorElementType(resultType).isIntOrIndex()) 926 return op->emitOpError() << "requires an integer or index type"; 927 return success(); 928 } 929 930 //===----------------------------------------------------------------------===// 931 // BinaryOp implementation 932 //===----------------------------------------------------------------------===// 933 934 // These functions are out-of-line implementations of the methods in BinaryOp, 935 // which avoids them being template instantiated/duplicated. 936 937 void impl::buildBinaryOp(Builder *builder, OperationState *result, Value *lhs, 938 Value *rhs) { 939 assert(lhs->getType() == rhs->getType()); 940 result->addOperands({lhs, rhs}); 941 result->types.push_back(lhs->getType()); 942 } 943 944 ParseResult impl::parseBinaryOp(OpAsmParser *parser, OperationState *result) { 945 SmallVector<OpAsmParser::OperandType, 2> ops; 946 Type type; 947 return failure(parser->parseOperandList(ops, 2) || 948 parser->parseOptionalAttributeDict(result->attributes) || 949 parser->parseColonType(type) || 950 parser->resolveOperands(ops, type, result->operands) || 951 parser->addTypeToList(type, result->types)); 952 } 953 954 void impl::printBinaryOp(Operation *op, OpAsmPrinter *p) { 955 assert(op->getNumOperands() == 2 && "binary op should have two operands"); 956 assert(op->getNumResults() == 1 && "binary op should have one result"); 957 958 // If not all the operand and result types are the same, just use the 959 // generic assembly form to avoid omitting information in printing. 960 auto resultType = op->getResult(0)->getType(); 961 if (op->getOperand(0)->getType() != resultType || 962 op->getOperand(1)->getType() != resultType) { 963 p->printGenericOp(op); 964 return; 965 } 966 967 *p << op->getName() << ' ' << *op->getOperand(0) << ", " 968 << *op->getOperand(1); 969 p->printOptionalAttrDict(op->getAttrs()); 970 // Now we can output only one type for all operands and the result. 971 *p << " : " << op->getResult(0)->getType(); 972 } 973 974 //===----------------------------------------------------------------------===// 975 // CastOp implementation 976 //===----------------------------------------------------------------------===// 977 978 void impl::buildCastOp(Builder *builder, OperationState *result, Value *source, 979 Type destType) { 980 result->addOperands(source); 981 result->addTypes(destType); 982 } 983 984 ParseResult impl::parseCastOp(OpAsmParser *parser, OperationState *result) { 985 OpAsmParser::OperandType srcInfo; 986 Type srcType, dstType; 987 return failure(parser->parseOperand(srcInfo) || 988 parser->parseOptionalAttributeDict(result->attributes) || 989 parser->parseColonType(srcType) || 990 parser->resolveOperand(srcInfo, srcType, result->operands) || 991 parser->parseKeywordType("to", dstType) || 992 parser->addTypeToList(dstType, result->types)); 993 } 994 995 void impl::printCastOp(Operation *op, OpAsmPrinter *p) { 996 *p << op->getName() << ' ' << *op->getOperand(0); 997 p->printOptionalAttrDict(op->getAttrs()); 998 *p << " : " << op->getOperand(0)->getType() << " to " 999 << op->getResult(0)->getType(); 1000 } 1001 1002 Value *impl::foldCastOp(Operation *op) { 1003 // Identity cast 1004 if (op->getOperand(0)->getType() == op->getResult(0)->getType()) 1005 return op->getOperand(0); 1006 return nullptr; 1007 } 1008 1009 //===----------------------------------------------------------------------===// 1010 // CastOp implementation 1011 //===----------------------------------------------------------------------===// 1012 1013 /// Insert an operation, generated by `buildTerminatorOp`, at the end of the 1014 /// region's only block if it does not have a terminator already. If the region 1015 /// is empty, insert a new block first. `buildTerminatorOp` should return the 1016 /// terminator operation to insert. 1017 void impl::ensureRegionTerminator( 1018 Region ®ion, Location loc, 1019 llvm::function_ref<Operation *()> buildTerminatorOp) { 1020 if (region.empty()) 1021 region.push_back(new Block); 1022 1023 Block &block = region.back(); 1024 if (!block.empty() && block.back().isKnownTerminator()) 1025 return; 1026 1027 block.push_back(buildTerminatorOp()); 1028 }