github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/IR/Operation.cpp (about)

     1  //===- Operation.cpp - Operation support code -----------------------------===//
     2  //
     3  // Copyright 2019 The MLIR Authors.
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //   http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  // =============================================================================
    17  
    18  #include "mlir/IR/Operation.h"
    19  #include "mlir/IR/BlockAndValueMapping.h"
    20  #include "mlir/IR/Diagnostics.h"
    21  #include "mlir/IR/Dialect.h"
    22  #include "mlir/IR/Function.h"
    23  #include "mlir/IR/MLIRContext.h"
    24  #include "mlir/IR/OpDefinition.h"
    25  #include "mlir/IR/OpImplementation.h"
    26  #include "mlir/IR/PatternMatch.h"
    27  #include "mlir/IR/StandardTypes.h"
    28  #include <numeric>
    29  using namespace mlir;
    30  
    31  /// Form the OperationName for an op with the specified string.  This either is
    32  /// a reference to an AbstractOperation if one is known, or a uniqued Identifier
    33  /// if not.
    34  OperationName::OperationName(StringRef name, MLIRContext *context) {
    35    if (auto *op = AbstractOperation::lookup(name, context))
    36      representation = op;
    37    else
    38      representation = Identifier::get(name, context);
    39  }
    40  
    41  /// Return the name of the dialect this operation is registered to.
    42  StringRef OperationName::getDialect() const {
    43    return getStringRef().split('.').first;
    44  }
    45  
    46  /// Return the name of this operation.  This always succeeds.
    47  StringRef OperationName::getStringRef() const {
    48    if (auto *op = representation.dyn_cast<const AbstractOperation *>())
    49      return op->name;
    50    return representation.get<Identifier>().strref();
    51  }
    52  
    53  const AbstractOperation *OperationName::getAbstractOperation() const {
    54    return representation.dyn_cast<const AbstractOperation *>();
    55  }
    56  
    57  OperationName OperationName::getFromOpaquePointer(void *pointer) {
    58    return OperationName(RepresentationUnion::getFromOpaqueValue(pointer));
    59  }
    60  
    61  OpAsmParser::~OpAsmParser() {}
    62  
    63  //===----------------------------------------------------------------------===//
    64  // OpResult
    65  //===----------------------------------------------------------------------===//
    66  
    67  /// Return the result number of this result.
    68  unsigned OpResult::getResultNumber() {
    69    // Results are always stored consecutively, so use pointer subtraction to
    70    // figure out what number this is.
    71    return this - &getOwner()->getOpResults()[0];
    72  }
    73  
    74  //===----------------------------------------------------------------------===//
    75  // OpOperand
    76  //===----------------------------------------------------------------------===//
    77  
    78  // TODO: This namespace is only required because of a bug in GCC<7.0.
    79  namespace mlir {
    80  /// Return which operand this is in the operand list.
    81  template <> unsigned OpOperand::getOperandNumber() {
    82    return this - &getOwner()->getOpOperands()[0];
    83  }
    84  } // end namespace mlir
    85  
    86  //===----------------------------------------------------------------------===//
    87  // BlockOperand
    88  //===----------------------------------------------------------------------===//
    89  
    90  // TODO: This namespace is only required because of a bug in GCC<7.0.
    91  namespace mlir {
    92  /// Return which operand this is in the operand list.
    93  template <> unsigned BlockOperand::getOperandNumber() {
    94    return this - &getOwner()->getBlockOperands()[0];
    95  }
    96  } // end namespace mlir
    97  
    98  //===----------------------------------------------------------------------===//
    99  // Operation
   100  //===----------------------------------------------------------------------===//
   101  
   102  /// Create a new Operation with the specific fields.
   103  Operation *Operation::create(Location location, OperationName name,
   104                               ArrayRef<Value *> operands,
   105                               ArrayRef<Type> resultTypes,
   106                               ArrayRef<NamedAttribute> attributes,
   107                               ArrayRef<Block *> successors, unsigned numRegions,
   108                               bool resizableOperandList) {
   109    return create(location, name, operands, resultTypes,
   110                  NamedAttributeList(attributes), successors, numRegions,
   111                  resizableOperandList);
   112  }
   113  
   114  /// Create a new Operation from operation state.
   115  Operation *Operation::create(const OperationState &state) {
   116    unsigned numRegions = state.regions.size();
   117    Operation *op = create(state.location, state.name, state.operands,
   118                           state.types, state.attributes, state.successors,
   119                           numRegions, state.resizableOperandList);
   120    for (unsigned i = 0; i < numRegions; ++i)
   121      if (state.regions[i])
   122        op->getRegion(i).takeBody(*state.regions[i]);
   123    return op;
   124  }
   125  
   126  /// Overload of create that takes an existing NamedAttributeList to avoid
   127  /// unnecessarily uniquing a list of attributes.
   128  Operation *Operation::create(Location location, OperationName name,
   129                               ArrayRef<Value *> operands,
   130                               ArrayRef<Type> resultTypes,
   131                               const NamedAttributeList &attributes,
   132                               ArrayRef<Block *> successors, unsigned numRegions,
   133                               bool resizableOperandList) {
   134    unsigned numSuccessors = successors.size();
   135  
   136    // Input operands are nullptr-separated for each successor, the null operands
   137    // aren't actually stored.
   138    unsigned numOperands = operands.size() - numSuccessors;
   139  
   140    // Compute the byte size for the operation and the operand storage.
   141    auto byteSize = totalSizeToAlloc<OpResult, BlockOperand, unsigned, Region,
   142                                     detail::OperandStorage>(
   143        resultTypes.size(), numSuccessors, numSuccessors, numRegions,
   144        /*detail::OperandStorage*/ 1);
   145    byteSize += llvm::alignTo(detail::OperandStorage::additionalAllocSize(
   146                                  numOperands, resizableOperandList),
   147                              alignof(Operation));
   148    void *rawMem = malloc(byteSize);
   149  
   150    // Create the new Operation.
   151    auto op = ::new (rawMem) Operation(location, name, resultTypes.size(),
   152                                       numSuccessors, numRegions, attributes);
   153  
   154    assert((numSuccessors == 0 || !op->isKnownNonTerminator()) &&
   155           "unexpected successors in a non-terminator operation");
   156  
   157    // Initialize the regions.
   158    for (unsigned i = 0; i != numRegions; ++i)
   159      new (&op->getRegion(i)) Region(op);
   160  
   161    // Initialize the results and operands.
   162    new (&op->getOperandStorage())
   163        detail::OperandStorage(numOperands, resizableOperandList);
   164  
   165    auto instResults = op->getOpResults();
   166    for (unsigned i = 0, e = resultTypes.size(); i != e; ++i)
   167      new (&instResults[i]) OpResult(resultTypes[i], op);
   168  
   169    auto opOperands = op->getOpOperands();
   170  
   171    // Initialize normal operands.
   172    unsigned operandIt = 0, operandE = operands.size();
   173    unsigned nextOperand = 0;
   174    for (; operandIt != operandE; ++operandIt) {
   175      // Null operands are used as sentinels between successor operand lists. If
   176      // we encounter one here, break and handle the successor operands lists
   177      // separately below.
   178      if (!operands[operandIt])
   179        break;
   180      new (&opOperands[nextOperand++]) OpOperand(op, operands[operandIt]);
   181    }
   182  
   183    unsigned currentSuccNum = 0;
   184    if (operandIt == operandE) {
   185      // Verify that the amount of sentinel operands is equivalent to the number
   186      // of successors.
   187      assert(currentSuccNum == numSuccessors);
   188      return op;
   189    }
   190  
   191    assert(!op->isKnownNonTerminator() &&
   192           "Unexpected nullptr in operand list when creating non-terminator.");
   193    auto instBlockOperands = op->getBlockOperands();
   194    unsigned *succOperandCountIt = op->getTrailingObjects<unsigned>();
   195    unsigned *succOperandCountE = succOperandCountIt + numSuccessors;
   196    (void)succOperandCountE;
   197  
   198    for (; operandIt != operandE; ++operandIt) {
   199      // If we encounter a sentinel branch to the next operand update the count
   200      // variable.
   201      if (!operands[operandIt]) {
   202        assert(currentSuccNum < numSuccessors);
   203  
   204        // After the first iteration update the successor operand count
   205        // variable.
   206        if (currentSuccNum != 0) {
   207          ++succOperandCountIt;
   208          assert(succOperandCountIt != succOperandCountE &&
   209                 "More sentinel operands than successors.");
   210        }
   211  
   212        new (&instBlockOperands[currentSuccNum])
   213            BlockOperand(op, successors[currentSuccNum]);
   214        *succOperandCountIt = 0;
   215        ++currentSuccNum;
   216        continue;
   217      }
   218      new (&opOperands[nextOperand++]) OpOperand(op, operands[operandIt]);
   219      ++(*succOperandCountIt);
   220    }
   221  
   222    // Verify that the amount of sentinel operands is equivalent to the number of
   223    // successors.
   224    assert(currentSuccNum == numSuccessors);
   225  
   226    return op;
   227  }
   228  
   229  Operation::Operation(Location location, OperationName name, unsigned numResults,
   230                       unsigned numSuccessors, unsigned numRegions,
   231                       const NamedAttributeList &attributes)
   232      : location(location), numResults(numResults), numSuccs(numSuccessors),
   233        numRegions(numRegions), name(name), attrs(attributes) {}
   234  
   235  // Operations are deleted through the destroy() member because they are
   236  // allocated via malloc.
   237  Operation::~Operation() {
   238    assert(block == nullptr && "operation destroyed but still in a block");
   239  
   240    // Explicitly run the destructors for the operands and results.
   241    getOperandStorage().~OperandStorage();
   242  
   243    for (auto &result : getOpResults())
   244      result.~OpResult();
   245  
   246    // Explicitly run the destructors for the successors.
   247    for (auto &successor : getBlockOperands())
   248      successor.~BlockOperand();
   249  
   250    // Explicitly destroy the regions.
   251    for (auto &region : getRegions())
   252      region.~Region();
   253  }
   254  
   255  /// Destroy this operation or one of its subclasses.
   256  void Operation::destroy() {
   257    this->~Operation();
   258    free(this);
   259  }
   260  
   261  /// Return the context this operation is associated with.
   262  MLIRContext *Operation::getContext() { return location->getContext(); }
   263  
   264  /// Return the dialact this operation is associated with, or nullptr if the
   265  /// associated dialect is not registered.
   266  Dialect *Operation::getDialect() {
   267    if (auto *abstractOp = getAbstractOperation())
   268      return &abstractOp->dialect;
   269  
   270    // If this operation hasn't been registered or doesn't have abstract
   271    // operation, try looking up the dialect name in the context.
   272    return getContext()->getRegisteredDialect(getName().getDialect());
   273  }
   274  
   275  Region *Operation::getParentRegion() {
   276    return block ? block->getParent() : nullptr;
   277  }
   278  
   279  Operation *Operation::getParentOp() {
   280    return block ? block->getParentOp() : nullptr;
   281  }
   282  
   283  /// Replace any uses of 'from' with 'to' within this operation.
   284  void Operation::replaceUsesOfWith(Value *from, Value *to) {
   285    if (from == to)
   286      return;
   287    for (auto &operand : getOpOperands())
   288      if (operand.get() == from)
   289        operand.set(to);
   290  }
   291  
   292  //===----------------------------------------------------------------------===//
   293  // Other
   294  //===----------------------------------------------------------------------===//
   295  
   296  /// Emit an error about fatal conditions with this operation, reporting up to
   297  /// any diagnostic handlers that may be listening.
   298  InFlightDiagnostic Operation::emitError(const Twine &message) {
   299    return mlir::emitError(getLoc(), message);
   300  }
   301  
   302  /// Emit a warning about this operation, reporting up to any diagnostic
   303  /// handlers that may be listening.
   304  InFlightDiagnostic Operation::emitWarning(const Twine &message) {
   305    return mlir::emitWarning(getLoc(), message);
   306  }
   307  
   308  /// Emit a remark about this operation, reporting up to any diagnostic
   309  /// handlers that may be listening.
   310  InFlightDiagnostic Operation::emitRemark(const Twine &message) {
   311    return mlir::emitRemark(getLoc(), message);
   312  }
   313  
   314  /// Given an operation 'other' that is within the same parent block, return
   315  /// whether the current operation is before 'other' in the operation list
   316  /// of the parent block.
   317  /// Note: This function has an average complexity of O(1), but worst case may
   318  /// take O(N) where N is the number of operations within the parent block.
   319  bool Operation::isBeforeInBlock(Operation *other) {
   320    assert(block && "Operations without parent blocks have no order.");
   321    assert(other && other->block == block &&
   322           "Expected other operation to have the same parent block.");
   323    // Recompute the parent ordering if necessary.
   324    if (!block->isInstOrderValid())
   325      block->recomputeInstOrder();
   326    return orderIndex < other->orderIndex;
   327  }
   328  
   329  //===----------------------------------------------------------------------===//
   330  // ilist_traits for Operation
   331  //===----------------------------------------------------------------------===//
   332  
   333  auto llvm::ilist_detail::SpecificNodeAccess<
   334      typename llvm::ilist_detail::compute_node_options<
   335          ::mlir::Operation>::type>::getNodePtr(pointer N) -> node_type * {
   336    return NodeAccess::getNodePtr<OptionsT>(N);
   337  }
   338  
   339  auto llvm::ilist_detail::SpecificNodeAccess<
   340      typename llvm::ilist_detail::compute_node_options<
   341          ::mlir::Operation>::type>::getNodePtr(const_pointer N)
   342      -> const node_type * {
   343    return NodeAccess::getNodePtr<OptionsT>(N);
   344  }
   345  
   346  auto llvm::ilist_detail::SpecificNodeAccess<
   347      typename llvm::ilist_detail::compute_node_options<
   348          ::mlir::Operation>::type>::getValuePtr(node_type *N) -> pointer {
   349    return NodeAccess::getValuePtr<OptionsT>(N);
   350  }
   351  
   352  auto llvm::ilist_detail::SpecificNodeAccess<
   353      typename llvm::ilist_detail::compute_node_options<
   354          ::mlir::Operation>::type>::getValuePtr(const node_type *N)
   355      -> const_pointer {
   356    return NodeAccess::getValuePtr<OptionsT>(N);
   357  }
   358  
   359  void llvm::ilist_traits<::mlir::Operation>::deleteNode(Operation *op) {
   360    op->destroy();
   361  }
   362  
   363  Block *llvm::ilist_traits<::mlir::Operation>::getContainingBlock() {
   364    size_t Offset(size_t(&((Block *)nullptr->*Block::getSublistAccess(nullptr))));
   365    iplist<Operation> *Anchor(static_cast<iplist<Operation> *>(this));
   366    return reinterpret_cast<Block *>(reinterpret_cast<char *>(Anchor) - Offset);
   367  }
   368  
   369  /// This is a trait method invoked when a operation is added to a block.  We
   370  /// keep the block pointer up to date.
   371  void llvm::ilist_traits<::mlir::Operation>::addNodeToList(Operation *op) {
   372    assert(!op->getBlock() && "already in a operation block!");
   373    op->block = getContainingBlock();
   374  
   375    // Invalidate the block ordering.
   376    op->block->invalidateInstOrder();
   377  }
   378  
   379  /// This is a trait method invoked when a operation is removed from a block.
   380  /// We keep the block pointer up to date.
   381  void llvm::ilist_traits<::mlir::Operation>::removeNodeFromList(Operation *op) {
   382    assert(op->block && "not already in a operation block!");
   383    op->block = nullptr;
   384  }
   385  
   386  /// This is a trait method invoked when a operation is moved from one block
   387  /// to another.  We keep the block pointer up to date.
   388  void llvm::ilist_traits<::mlir::Operation>::transferNodesFromList(
   389      ilist_traits<Operation> &otherList, op_iterator first, op_iterator last) {
   390    Block *curParent = getContainingBlock();
   391  
   392    // Invalidate the ordering of the parent block.
   393    curParent->invalidateInstOrder();
   394  
   395    // If we are transferring operations within the same block, the block
   396    // pointer doesn't need to be updated.
   397    if (curParent == otherList.getContainingBlock())
   398      return;
   399  
   400    // Update the 'block' member of each operation.
   401    for (; first != last; ++first)
   402      first->block = curParent;
   403  }
   404  
   405  /// Remove this operation (and its descendants) from its Block and delete
   406  /// all of them.
   407  void Operation::erase() {
   408    if (auto *parent = getBlock())
   409      parent->getOperations().erase(this);
   410    else
   411      destroy();
   412  }
   413  
   414  /// Unlink this operation from its current block and insert it right before
   415  /// `existingInst` which may be in the same or another block in the same
   416  /// function.
   417  void Operation::moveBefore(Operation *existingInst) {
   418    moveBefore(existingInst->getBlock(), existingInst->getIterator());
   419  }
   420  
   421  /// Unlink this operation from its current basic block and insert it right
   422  /// before `iterator` in the specified basic block.
   423  void Operation::moveBefore(Block *block,
   424                             llvm::iplist<Operation>::iterator iterator) {
   425    block->getOperations().splice(iterator, getBlock()->getOperations(),
   426                                  getIterator());
   427  }
   428  
   429  /// This drops all operand uses from this operation, which is an essential
   430  /// step in breaking cyclic dependences between references when they are to
   431  /// be deleted.
   432  void Operation::dropAllReferences() {
   433    for (auto &op : getOpOperands())
   434      op.drop();
   435  
   436    for (auto &region : getRegions())
   437      region.dropAllReferences();
   438  
   439    for (auto &dest : getBlockOperands())
   440      dest.drop();
   441  }
   442  
   443  /// This drops all uses of any values defined by this operation or its nested
   444  /// regions, wherever they are located.
   445  void Operation::dropAllDefinedValueUses() {
   446    for (auto &val : getOpResults())
   447      val.dropAllUses();
   448  
   449    for (auto &region : getRegions())
   450      for (auto &block : region)
   451        block.dropAllDefinedValueUses();
   452  }
   453  
   454  /// Return true if there are no users of any results of this operation.
   455  bool Operation::use_empty() {
   456    for (auto *result : getResults())
   457      if (!result->use_empty())
   458        return false;
   459    return true;
   460  }
   461  
   462  void Operation::setSuccessor(Block *block, unsigned index) {
   463    assert(index < getNumSuccessors());
   464    getBlockOperands()[index].set(block);
   465  }
   466  
   467  auto Operation::getNonSuccessorOperands() -> operand_range {
   468    return {operand_iterator(this, 0),
   469            operand_iterator(this, hasSuccessors() ? getSuccessorOperandIndex(0)
   470                                                   : getNumOperands())};
   471  }
   472  
   473  /// Get the index of the first operand of the successor at the provided
   474  /// index.
   475  unsigned Operation::getSuccessorOperandIndex(unsigned index) {
   476    assert(!isKnownNonTerminator() && "only terminators may have successors");
   477    assert(index < getNumSuccessors());
   478  
   479    // Count the number of operands for each of the successors after, and
   480    // including, the one at 'index'. This is based upon the assumption that all
   481    // non successor operands are placed at the beginning of the operand list.
   482    auto *successorOpCountBegin = getTrailingObjects<unsigned>();
   483    unsigned postSuccessorOpCount =
   484        std::accumulate(successorOpCountBegin + index,
   485                        successorOpCountBegin + getNumSuccessors(), 0u);
   486    return getNumOperands() - postSuccessorOpCount;
   487  }
   488  
   489  auto Operation::getSuccessorOperands(unsigned index) -> operand_range {
   490    unsigned succOperandIndex = getSuccessorOperandIndex(index);
   491    return {operand_iterator(this, succOperandIndex),
   492            operand_iterator(this,
   493                             succOperandIndex + getNumSuccessorOperands(index))};
   494  }
   495  
   496  /// Attempt to fold this operation using the Op's registered foldHook.
   497  LogicalResult Operation::fold(ArrayRef<Attribute> operands,
   498                                SmallVectorImpl<OpFoldResult> &results) {
   499    // If we have a registered operation definition matching this one, use it to
   500    // try to constant fold the operation.
   501    auto *abstractOp = getAbstractOperation();
   502    if (abstractOp && succeeded(abstractOp->foldHook(this, operands, results)))
   503      return success();
   504  
   505    // Otherwise, fall back on the dialect hook to handle it.
   506    Dialect *dialect = getDialect();
   507    if (!dialect)
   508      return failure();
   509  
   510    SmallVector<Attribute, 8> constants;
   511    if (failed(dialect->constantFoldHook(this, operands, constants)))
   512      return failure();
   513    results.assign(constants.begin(), constants.end());
   514    return success();
   515  }
   516  
   517  /// Emit an error with the op name prefixed, like "'dim' op " which is
   518  /// convenient for verifiers.
   519  InFlightDiagnostic Operation::emitOpError(const Twine &message) {
   520    return emitError() << "'" << getName() << "' op " << message;
   521  }
   522  
   523  //===----------------------------------------------------------------------===//
   524  // Operation Cloning
   525  //===----------------------------------------------------------------------===//
   526  
   527  /// Create a deep copy of this operation but keep the operation regions empty.
   528  /// Operands are remapped using `mapper` (if present), and `mapper` is updated
   529  /// to contain the results.
   530  Operation *Operation::cloneWithoutRegions(BlockAndValueMapping &mapper) {
   531    SmallVector<Value *, 8> operands;
   532    SmallVector<Block *, 2> successors;
   533  
   534    operands.reserve(getNumOperands() + getNumSuccessors());
   535  
   536    if (getNumSuccessors() == 0) {
   537      // Non-branching operations can just add all the operands.
   538      for (auto *opValue : getOperands())
   539        operands.push_back(mapper.lookupOrDefault(opValue));
   540    } else {
   541      // We add the operands separated by nullptr's for each successor.
   542      unsigned firstSuccOperand =
   543          getNumSuccessors() ? getSuccessorOperandIndex(0) : getNumOperands();
   544      auto opOperands = getOpOperands();
   545  
   546      unsigned i = 0;
   547      for (; i != firstSuccOperand; ++i)
   548        operands.push_back(mapper.lookupOrDefault(opOperands[i].get()));
   549  
   550      successors.reserve(getNumSuccessors());
   551      for (unsigned succ = 0, e = getNumSuccessors(); succ != e; ++succ) {
   552        successors.push_back(mapper.lookupOrDefault(getSuccessor(succ)));
   553  
   554        // Add sentinel to delineate successor operands.
   555        operands.push_back(nullptr);
   556  
   557        // Remap the successors operands.
   558        for (auto *operand : getSuccessorOperands(succ))
   559          operands.push_back(mapper.lookupOrDefault(operand));
   560      }
   561    }
   562  
   563    SmallVector<Type, 8> resultTypes(getResultTypes());
   564    unsigned numRegions = getNumRegions();
   565    auto *newOp =
   566        Operation::create(getLoc(), getName(), operands, resultTypes, attrs,
   567                          successors, numRegions, hasResizableOperandsList());
   568  
   569    // Remember the mapping of any results.
   570    for (unsigned i = 0, e = getNumResults(); i != e; ++i)
   571      mapper.map(getResult(i), newOp->getResult(i));
   572  
   573    return newOp;
   574  }
   575  
   576  Operation *Operation::cloneWithoutRegions() {
   577    BlockAndValueMapping mapper;
   578    return cloneWithoutRegions(mapper);
   579  }
   580  
   581  /// Create a deep copy of this operation, remapping any operands that use
   582  /// values outside of the operation using the map that is provided (leaving
   583  /// them alone if no entry is present).  Replaces references to cloned
   584  /// sub-operations to the corresponding operation that is copied, and adds
   585  /// those mappings to the map.
   586  Operation *Operation::clone(BlockAndValueMapping &mapper) {
   587    auto *newOp = cloneWithoutRegions(mapper);
   588  
   589    // Clone the regions.
   590    for (unsigned i = 0; i != numRegions; ++i)
   591      getRegion(i).cloneInto(&newOp->getRegion(i), mapper);
   592  
   593    return newOp;
   594  }
   595  
   596  Operation *Operation::clone() {
   597    BlockAndValueMapping mapper;
   598    return clone(mapper);
   599  }
   600  
   601  //===----------------------------------------------------------------------===//
   602  // OpState trait class.
   603  //===----------------------------------------------------------------------===//
   604  
   605  // The fallback for the parser is to reject the custom assembly form.
   606  ParseResult OpState::parse(OpAsmParser *parser, OperationState *result) {
   607    return parser->emitError(parser->getNameLoc(), "has no custom assembly form");
   608  }
   609  
   610  // The fallback for the printer is to print in the generic assembly form.
   611  void OpState::print(OpAsmPrinter *p) { p->printGenericOp(getOperation()); }
   612  
   613  /// Emit an error about fatal conditions with this operation, reporting up to
   614  /// any diagnostic handlers that may be listening.
   615  InFlightDiagnostic OpState::emitError(const Twine &message) {
   616    return getOperation()->emitError(message);
   617  }
   618  
   619  /// Emit an error with the op name prefixed, like "'dim' op " which is
   620  /// convenient for verifiers.
   621  InFlightDiagnostic OpState::emitOpError(const Twine &message) {
   622    return getOperation()->emitOpError(message);
   623  }
   624  
   625  /// Emit a warning about this operation, reporting up to any diagnostic
   626  /// handlers that may be listening.
   627  InFlightDiagnostic OpState::emitWarning(const Twine &message) {
   628    return getOperation()->emitWarning(message);
   629  }
   630  
   631  /// Emit a remark about this operation, reporting up to any diagnostic
   632  /// handlers that may be listening.
   633  InFlightDiagnostic OpState::emitRemark(const Twine &message) {
   634    return getOperation()->emitRemark(message);
   635  }
   636  
   637  //===----------------------------------------------------------------------===//
   638  // Op Trait implementations
   639  //===----------------------------------------------------------------------===//
   640  
   641  LogicalResult OpTrait::impl::verifyZeroOperands(Operation *op) {
   642    if (op->getNumOperands() != 0)
   643      return op->emitOpError() << "requires zero operands";
   644    return success();
   645  }
   646  
   647  LogicalResult OpTrait::impl::verifyOneOperand(Operation *op) {
   648    if (op->getNumOperands() != 1)
   649      return op->emitOpError() << "requires a single operand";
   650    return success();
   651  }
   652  
   653  LogicalResult OpTrait::impl::verifyNOperands(Operation *op,
   654                                               unsigned numOperands) {
   655    if (op->getNumOperands() != numOperands) {
   656      return op->emitOpError() << "expected " << numOperands
   657                               << " operands, but found " << op->getNumOperands();
   658    }
   659    return success();
   660  }
   661  
   662  LogicalResult OpTrait::impl::verifyAtLeastNOperands(Operation *op,
   663                                                      unsigned numOperands) {
   664    if (op->getNumOperands() < numOperands)
   665      return op->emitOpError()
   666             << "expected " << numOperands << " or more operands";
   667    return success();
   668  }
   669  
   670  /// If this is a vector type, or a tensor type, return the scalar element type
   671  /// that it is built around, otherwise return the type unmodified.
   672  static Type getTensorOrVectorElementType(Type type) {
   673    if (auto vec = type.dyn_cast<VectorType>())
   674      return vec.getElementType();
   675  
   676    // Look through tensor<vector<...>> to find the underlying element type.
   677    if (auto tensor = type.dyn_cast<TensorType>())
   678      return getTensorOrVectorElementType(tensor.getElementType());
   679    return type;
   680  }
   681  
   682  LogicalResult OpTrait::impl::verifyOperandsAreIntegerLike(Operation *op) {
   683    for (auto opType : op->getOperandTypes()) {
   684      auto type = getTensorOrVectorElementType(opType);
   685      if (!type.isIntOrIndex())
   686        return op->emitOpError() << "requires an integer or index type";
   687    }
   688    return success();
   689  }
   690  
   691  LogicalResult OpTrait::impl::verifyOperandsAreFloatLike(Operation *op) {
   692    for (auto opType : op->getOperandTypes()) {
   693      auto type = getTensorOrVectorElementType(opType);
   694      if (!type.isa<FloatType>())
   695        return op->emitOpError("requires a float type");
   696    }
   697    return success();
   698  }
   699  
   700  LogicalResult OpTrait::impl::verifySameTypeOperands(Operation *op) {
   701    // Zero or one operand always have the "same" type.
   702    unsigned nOperands = op->getNumOperands();
   703    if (nOperands < 2)
   704      return success();
   705  
   706    auto type = op->getOperand(0)->getType();
   707    for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1))
   708      if (opType != type)
   709        return op->emitOpError() << "requires all operands to have the same type";
   710    return success();
   711  }
   712  
   713  LogicalResult OpTrait::impl::verifyZeroResult(Operation *op) {
   714    if (op->getNumResults() != 0)
   715      return op->emitOpError() << "requires zero results";
   716    return success();
   717  }
   718  
   719  LogicalResult OpTrait::impl::verifyOneResult(Operation *op) {
   720    if (op->getNumResults() != 1)
   721      return op->emitOpError() << "requires one result";
   722    return success();
   723  }
   724  
   725  LogicalResult OpTrait::impl::verifyNResults(Operation *op,
   726                                              unsigned numOperands) {
   727    if (op->getNumResults() != numOperands)
   728      return op->emitOpError() << "expected " << numOperands << " results";
   729    return success();
   730  }
   731  
   732  LogicalResult OpTrait::impl::verifyAtLeastNResults(Operation *op,
   733                                                     unsigned numOperands) {
   734    if (op->getNumResults() < numOperands)
   735      return op->emitOpError()
   736             << "expected " << numOperands << " or more results";
   737    return success();
   738  }
   739  
   740  /// Returns success if the given two types have the same shape. That is,
   741  /// they are both scalars (not shaped), or they are both shaped types and at
   742  /// least one is unranked or they have the same shape. The element type does not
   743  /// matter.
   744  static LogicalResult verifyShapeMatch(Type type1, Type type2) {
   745    auto sType1 = type1.dyn_cast<ShapedType>();
   746    auto sType2 = type2.dyn_cast<ShapedType>();
   747  
   748    // Either both or neither type should be shaped.
   749    if (!sType1)
   750      return success(!sType2);
   751    if (!sType2)
   752      return failure();
   753  
   754    if (!sType1.hasRank() || !sType2.hasRank())
   755      return success();
   756  
   757    return success(sType1.getShape() == sType2.getShape());
   758  }
   759  
   760  LogicalResult OpTrait::impl::verifySameOperandsShape(Operation *op) {
   761    if (op->getNumOperands() == 0)
   762      return failure();
   763  
   764    auto type = op->getOperand(0)->getType();
   765    for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1)) {
   766      if (failed(verifyShapeMatch(opType, type)))
   767        return op->emitOpError() << "requires the same shape for all operands";
   768    }
   769    return success();
   770  }
   771  
   772  LogicalResult OpTrait::impl::verifySameOperandsAndResultShape(Operation *op) {
   773    if (op->getNumOperands() == 0 || op->getNumResults() == 0)
   774      return failure();
   775  
   776    auto type = op->getOperand(0)->getType();
   777    for (auto resultType : op->getResultTypes()) {
   778      if (failed(verifyShapeMatch(resultType, type)))
   779        return op->emitOpError()
   780               << "requires the same shape for all operands and results";
   781    }
   782    for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1)) {
   783      if (failed(verifyShapeMatch(opType, type)))
   784        return op->emitOpError()
   785               << "requires the same shape for all operands and results";
   786    }
   787    return success();
   788  }
   789  
   790  LogicalResult OpTrait::impl::verifySameOperandsElementType(Operation *op) {
   791    if (op->getNumOperands() == 0)
   792      return failure();
   793  
   794    auto type = op->getOperand(0)->getType().dyn_cast<ShapedType>();
   795    if (!type)
   796      return op->emitOpError("requires shaped type results");
   797    auto elementType = type.getElementType();
   798  
   799    for (auto operandType : llvm::drop_begin(op->getOperandTypes(), 1)) {
   800      auto shapedType = operandType.dyn_cast<ShapedType>();
   801      if (!shapedType)
   802        return op->emitOpError("requires shaped type operands");
   803      if (shapedType.getElementType() != elementType)
   804        return op->emitOpError("requires the same element type for all operands");
   805    }
   806  
   807    return success();
   808  }
   809  
   810  LogicalResult
   811  OpTrait::impl::verifySameOperandsAndResultElementType(Operation *op) {
   812    if (op->getNumOperands() == 0 || op->getNumResults() == 0)
   813      return failure();
   814  
   815    auto type = op->getResult(0)->getType().dyn_cast<ShapedType>();
   816    if (!type)
   817      return op->emitOpError("requires shaped type results");
   818    auto elementType = type.getElementType();
   819  
   820    // Verify result element type matches first result's element type.
   821    for (auto result : drop_begin(op->getResults(), 1)) {
   822      auto resultType = result->getType().dyn_cast<ShapedType>();
   823      if (!resultType)
   824        return op->emitOpError("requires shaped type results");
   825      if (resultType.getElementType() != elementType)
   826        return op->emitOpError(
   827            "requires the same element type for all operands and results");
   828    }
   829  
   830    // Verify operand's element type matches first result's element type.
   831    for (auto operand : op->getOperands()) {
   832      auto operandType = operand->getType().dyn_cast<ShapedType>();
   833      if (!operandType)
   834        return op->emitOpError("requires shaped type operands");
   835      if (operandType.getElementType() != elementType)
   836        return op->emitOpError(
   837            "requires the same element type for all operands and results");
   838    }
   839  
   840    return success();
   841  }
   842  
   843  LogicalResult OpTrait::impl::verifySameOperandsAndResultType(Operation *op) {
   844    if (op->getNumOperands() == 0 || op->getNumResults() == 0)
   845      return failure();
   846  
   847    auto type = op->getResult(0)->getType();
   848    for (auto resultType : llvm::drop_begin(op->getResultTypes(), 1)) {
   849      if (resultType != type)
   850        return op->emitOpError()
   851               << "requires the same type for all operands and results";
   852    }
   853    for (auto opType : op->getOperandTypes()) {
   854      if (opType != type)
   855        return op->emitOpError()
   856               << "requires the same type for all operands and results";
   857    }
   858    return success();
   859  }
   860  
   861  static LogicalResult verifyBBArguments(Operation::operand_range operands,
   862                                         Block *destBB, Operation *op) {
   863    unsigned operandCount = std::distance(operands.begin(), operands.end());
   864    if (operandCount != destBB->getNumArguments())
   865      return op->emitError() << "branch has " << operandCount
   866                             << " operands, but target block has "
   867                             << destBB->getNumArguments();
   868  
   869    auto operandIt = operands.begin();
   870    for (unsigned i = 0, e = operandCount; i != e; ++i, ++operandIt) {
   871      if ((*operandIt)->getType() != destBB->getArgument(i)->getType())
   872        return op->emitError() << "type mismatch in bb argument #" << i;
   873    }
   874  
   875    return success();
   876  }
   877  
   878  static LogicalResult verifyTerminatorSuccessors(Operation *op) {
   879    auto *parent = op->getParentRegion();
   880  
   881    // Verify that the operands lines up with the BB arguments in the successor.
   882    for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i) {
   883      auto *succ = op->getSuccessor(i);
   884      if (succ->getParent() != parent)
   885        return op->emitError("reference to block defined in another region");
   886      if (failed(verifyBBArguments(op->getSuccessorOperands(i), succ, op)))
   887        return failure();
   888    }
   889    return success();
   890  }
   891  
   892  LogicalResult OpTrait::impl::verifyIsTerminator(Operation *op) {
   893    Block *block = op->getBlock();
   894    // Verify that the operation is at the end of the respective parent block.
   895    if (!block || &block->back() != op)
   896      return op->emitOpError("must be the last operation in the parent block");
   897  
   898    // Verify the state of the successor blocks.
   899    if (op->getNumSuccessors() != 0 && failed(verifyTerminatorSuccessors(op)))
   900      return failure();
   901    return success();
   902  }
   903  
   904  LogicalResult OpTrait::impl::verifyResultsAreBoolLike(Operation *op) {
   905    for (auto resultType : op->getResultTypes()) {
   906      auto elementType = getTensorOrVectorElementType(resultType);
   907      bool isBoolType = elementType.isInteger(1);
   908      if (!isBoolType)
   909        return op->emitOpError() << "requires a bool result type";
   910    }
   911  
   912    return success();
   913  }
   914  
   915  LogicalResult OpTrait::impl::verifyResultsAreFloatLike(Operation *op) {
   916    for (auto resultType : op->getResultTypes())
   917      if (!getTensorOrVectorElementType(resultType).isa<FloatType>())
   918        return op->emitOpError() << "requires a floating point type";
   919  
   920    return success();
   921  }
   922  
   923  LogicalResult OpTrait::impl::verifyResultsAreIntegerLike(Operation *op) {
   924    for (auto resultType : op->getResultTypes())
   925      if (!getTensorOrVectorElementType(resultType).isIntOrIndex())
   926        return op->emitOpError() << "requires an integer or index type";
   927    return success();
   928  }
   929  
   930  //===----------------------------------------------------------------------===//
   931  // BinaryOp implementation
   932  //===----------------------------------------------------------------------===//
   933  
   934  // These functions are out-of-line implementations of the methods in BinaryOp,
   935  // which avoids them being template instantiated/duplicated.
   936  
   937  void impl::buildBinaryOp(Builder *builder, OperationState *result, Value *lhs,
   938                           Value *rhs) {
   939    assert(lhs->getType() == rhs->getType());
   940    result->addOperands({lhs, rhs});
   941    result->types.push_back(lhs->getType());
   942  }
   943  
   944  ParseResult impl::parseBinaryOp(OpAsmParser *parser, OperationState *result) {
   945    SmallVector<OpAsmParser::OperandType, 2> ops;
   946    Type type;
   947    return failure(parser->parseOperandList(ops, 2) ||
   948                   parser->parseOptionalAttributeDict(result->attributes) ||
   949                   parser->parseColonType(type) ||
   950                   parser->resolveOperands(ops, type, result->operands) ||
   951                   parser->addTypeToList(type, result->types));
   952  }
   953  
   954  void impl::printBinaryOp(Operation *op, OpAsmPrinter *p) {
   955    assert(op->getNumOperands() == 2 && "binary op should have two operands");
   956    assert(op->getNumResults() == 1 && "binary op should have one result");
   957  
   958    // If not all the operand and result types are the same, just use the
   959    // generic assembly form to avoid omitting information in printing.
   960    auto resultType = op->getResult(0)->getType();
   961    if (op->getOperand(0)->getType() != resultType ||
   962        op->getOperand(1)->getType() != resultType) {
   963      p->printGenericOp(op);
   964      return;
   965    }
   966  
   967    *p << op->getName() << ' ' << *op->getOperand(0) << ", "
   968       << *op->getOperand(1);
   969    p->printOptionalAttrDict(op->getAttrs());
   970    // Now we can output only one type for all operands and the result.
   971    *p << " : " << op->getResult(0)->getType();
   972  }
   973  
   974  //===----------------------------------------------------------------------===//
   975  // CastOp implementation
   976  //===----------------------------------------------------------------------===//
   977  
   978  void impl::buildCastOp(Builder *builder, OperationState *result, Value *source,
   979                         Type destType) {
   980    result->addOperands(source);
   981    result->addTypes(destType);
   982  }
   983  
   984  ParseResult impl::parseCastOp(OpAsmParser *parser, OperationState *result) {
   985    OpAsmParser::OperandType srcInfo;
   986    Type srcType, dstType;
   987    return failure(parser->parseOperand(srcInfo) ||
   988                   parser->parseOptionalAttributeDict(result->attributes) ||
   989                   parser->parseColonType(srcType) ||
   990                   parser->resolveOperand(srcInfo, srcType, result->operands) ||
   991                   parser->parseKeywordType("to", dstType) ||
   992                   parser->addTypeToList(dstType, result->types));
   993  }
   994  
   995  void impl::printCastOp(Operation *op, OpAsmPrinter *p) {
   996    *p << op->getName() << ' ' << *op->getOperand(0);
   997    p->printOptionalAttrDict(op->getAttrs());
   998    *p << " : " << op->getOperand(0)->getType() << " to "
   999       << op->getResult(0)->getType();
  1000  }
  1001  
  1002  Value *impl::foldCastOp(Operation *op) {
  1003    // Identity cast
  1004    if (op->getOperand(0)->getType() == op->getResult(0)->getType())
  1005      return op->getOperand(0);
  1006    return nullptr;
  1007  }
  1008  
  1009  //===----------------------------------------------------------------------===//
// Region terminator utilities
  1011  //===----------------------------------------------------------------------===//
  1012  
  1013  /// Insert an operation, generated by `buildTerminatorOp`, at the end of the
  1014  /// region's only block if it does not have a terminator already. If the region
  1015  /// is empty, insert a new block first. `buildTerminatorOp` should return the
  1016  /// terminator operation to insert.
  1017  void impl::ensureRegionTerminator(
  1018      Region &region, Location loc,
  1019      llvm::function_ref<Operation *()> buildTerminatorOp) {
  1020    if (region.empty())
  1021      region.push_back(new Block);
  1022  
  1023    Block &block = region.back();
  1024    if (!block.empty() && block.back().isKnownTerminator())
  1025      return;
  1026  
  1027    block.push_back(buildTerminatorOp());
  1028  }