github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Dialect/StandardOps/Ops.cpp (about)

     1  //===- Ops.cpp - Standard MLIR Operations ---------------------------------===//
     2  //
     3  // Copyright 2019 The MLIR Authors.
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //   http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  // =============================================================================
    17  
    18  #include "mlir/Dialect/StandardOps/Ops.h"
    19  
    20  #include "mlir/IR/AffineExpr.h"
    21  #include "mlir/IR/AffineMap.h"
    22  #include "mlir/IR/Builders.h"
    23  #include "mlir/IR/Function.h"
    24  #include "mlir/IR/Matchers.h"
    25  #include "mlir/IR/Module.h"
    26  #include "mlir/IR/OpImplementation.h"
    27  #include "mlir/IR/PatternMatch.h"
    28  #include "mlir/IR/StandardTypes.h"
    29  #include "mlir/IR/Value.h"
    30  #include "mlir/Support/MathExtras.h"
    31  #include "mlir/Support/STLExtras.h"
    32  #include "llvm/ADT/StringSwitch.h"
    33  #include "llvm/Support/FormatVariadic.h"
    34  #include "llvm/Support/raw_ostream.h"
    35  using namespace mlir;
    36  
    37  //===----------------------------------------------------------------------===//
    38  // StandardOpsDialect Interfaces
    39  //===----------------------------------------------------------------------===//
    40  namespace {
    41  struct StdOpAsmInterface : public OpAsmDialectInterface {
    42    using OpAsmDialectInterface::OpAsmDialectInterface;
    43  
    44    /// Get a special name to use when printing the given operation. The desired
    45    /// name should be streamed into 'os'.
    46    void getOpResultName(Operation *op, raw_ostream &os) const final {
    47      if (ConstantOp constant = dyn_cast<ConstantOp>(op))
    48        return getConstantOpResultName(constant, os);
    49    }
    50  
    51    /// Get a special name to use when printing the given constant.
    52    static void getConstantOpResultName(ConstantOp op, raw_ostream &os) {
    53      Type type = op.getType();
    54      Attribute value = op.getValue();
    55      if (auto intCst = value.dyn_cast<IntegerAttr>()) {
    56        if (type.isIndex()) {
    57          os << 'c' << intCst.getInt();
    58        } else if (type.cast<IntegerType>().isInteger(1)) {
    59          // i1 constants get special names.
    60          os << (intCst.getInt() ? "true" : "false");
    61        } else {
    62          os << 'c' << intCst.getInt() << '_' << type;
    63        }
    64      } else if (type.isa<FunctionType>()) {
    65        os << 'f';
    66      } else {
    67        os << "cst";
    68      }
    69    }
    70  };
    71  } // end anonymous namespace
    72  
    73  //===----------------------------------------------------------------------===//
    74  // StandardOpsDialect
    75  //===----------------------------------------------------------------------===//
    76  
    77  /// A custom binary operation printer that omits the "std." prefix from the
    78  /// operation names.
    79  static void printStandardBinaryOp(Operation *op, OpAsmPrinter *p) {
    80    assert(op->getNumOperands() == 2 && "binary op should have two operands");
    81    assert(op->getNumResults() == 1 && "binary op should have one result");
    82  
    83    // If not all the operand and result types are the same, just use the
    84    // generic assembly form to avoid omitting information in printing.
    85    auto resultType = op->getResult(0)->getType();
    86    if (op->getOperand(0)->getType() != resultType ||
    87        op->getOperand(1)->getType() != resultType) {
    88      p->printGenericOp(op);
    89      return;
    90    }
    91  
    92    *p << op->getName().getStringRef().drop_front(strlen("std.")) << ' '
    93       << *op->getOperand(0) << ", " << *op->getOperand(1);
    94    p->printOptionalAttrDict(op->getAttrs());
    95  
    96    // Now we can output only one type for all operands and the result.
    97    *p << " : " << op->getResult(0)->getType();
    98  }
    99  
   100  /// A custom cast operation printer that omits the "std." prefix from the
   101  /// operation names.
   102  static void printStandardCastOp(Operation *op, OpAsmPrinter *p) {
   103    *p << op->getName().getStringRef().drop_front(strlen("std.")) << ' '
   104       << *op->getOperand(0) << " : " << op->getOperand(0)->getType() << " to "
   105       << op->getResult(0)->getType();
   106  }
   107  
   108  /// A custom cast operation verifier.
   109  template <typename T> static LogicalResult verifyCastOp(T op) {
   110    auto opType = op.getOperand()->getType();
   111    auto resType = op.getType();
   112    if (!T::areCastCompatible(opType, resType))
   113      return op.emitError("operand type ") << opType << " and result type "
   114                                           << resType << " are cast incompatible";
   115  
   116    return success();
   117  }
   118  
   119  StandardOpsDialect::StandardOpsDialect(MLIRContext *context)
   120      : Dialect(getDialectNamespace(), context) {
   121    addOperations<DmaStartOp, DmaWaitOp,
   122  #define GET_OP_LIST
   123  #include "mlir/Dialect/StandardOps/Ops.cpp.inc"
   124                  >();
   125    addInterfaces<StdOpAsmInterface>();
   126  }
   127  
   128  void mlir::printDimAndSymbolList(Operation::operand_iterator begin,
   129                                   Operation::operand_iterator end,
   130                                   unsigned numDims, OpAsmPrinter *p) {
   131    *p << '(';
   132    p->printOperands(begin, begin + numDims);
   133    *p << ')';
   134  
   135    if (begin + numDims != end) {
   136      *p << '[';
   137      p->printOperands(begin + numDims, end);
   138      *p << ']';
   139    }
   140  }
   141  
   142  // Parses dimension and symbol list, and sets 'numDims' to the number of
   143  // dimension operands parsed.
   144  // Returns 'false' on success and 'true' on error.
   145  ParseResult mlir::parseDimAndSymbolList(OpAsmParser *parser,
   146                                          SmallVector<Value *, 4> &operands,
   147                                          unsigned &numDims) {
   148    SmallVector<OpAsmParser::OperandType, 8> opInfos;
   149    if (parser->parseOperandList(opInfos, OpAsmParser::Delimiter::Paren))
   150      return failure();
   151    // Store number of dimensions for validation by caller.
   152    numDims = opInfos.size();
   153  
   154    // Parse the optional symbol operands.
   155    auto affineIntTy = parser->getBuilder().getIndexType();
   156    if (parser->parseOperandList(opInfos,
   157                                 OpAsmParser::Delimiter::OptionalSquare) ||
   158        parser->resolveOperands(opInfos, affineIntTy, operands))
   159      return failure();
   160    return success();
   161  }
   162  
   163  /// Matches a ConstantIndexOp.
   164  /// TODO: This should probably just be a general matcher that uses m_Constant
   165  /// and checks the operation for an index type.
   166  static detail::op_matcher<ConstantIndexOp> m_ConstantIndex() {
   167    return detail::op_matcher<ConstantIndexOp>();
   168  }
   169  
   170  //===----------------------------------------------------------------------===//
   171  // Common canonicalization pattern support logic
   172  //===----------------------------------------------------------------------===//
   173  
   174  namespace {
   175  /// This is a common class used for patterns of the form
   176  /// "someop(memrefcast) -> someop".  It folds the source of any memref_cast
   177  /// into the root operation directly.
   178  struct MemRefCastFolder : public RewritePattern {
   179    /// The rootOpName is the name of the root operation to match against.
   180    MemRefCastFolder(StringRef rootOpName, MLIRContext *context)
   181        : RewritePattern(rootOpName, 1, context) {}
   182  
   183    PatternMatchResult match(Operation *op) const override {
   184      for (auto *operand : op->getOperands())
   185        if (matchPattern(operand, m_Op<MemRefCastOp>()))
   186          return matchSuccess();
   187  
   188      return matchFailure();
   189    }
   190  
   191    void rewrite(Operation *op, PatternRewriter &rewriter) const override {
   192      for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i)
   193        if (auto *memref = op->getOperand(i)->getDefiningOp())
   194          if (auto cast = dyn_cast<MemRefCastOp>(memref))
   195            op->setOperand(i, cast.getOperand());
   196      rewriter.updatedRootInPlace(op);
   197    }
   198  };
   199  
   200  /// Performs const folding `calculate` with element-wise behavior on the two
   201  /// attributes in `operands` and returns the result if possible.
   202  template <class AttrElementT,
   203            class ElementValueT = typename AttrElementT::ValueType,
   204            class CalculationT =
   205                std::function<ElementValueT(ElementValueT, ElementValueT)>>
   206  Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
   207                              const CalculationT &calculate) {
   208    assert(operands.size() == 2 && "binary op takes two operands");
   209  
   210    if (auto lhs = operands[0].dyn_cast_or_null<AttrElementT>()) {
   211      auto rhs = operands[1].dyn_cast_or_null<AttrElementT>();
   212      if (!rhs || lhs.getType() != rhs.getType())
   213        return {};
   214  
   215      return AttrElementT::get(lhs.getType(),
   216                               calculate(lhs.getValue(), rhs.getValue()));
   217    } else if (auto lhs = operands[0].dyn_cast_or_null<SplatElementsAttr>()) {
   218      auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>();
   219      if (!rhs || lhs.getType() != rhs.getType())
   220        return {};
   221  
   222      auto elementResult = constFoldBinaryOp<AttrElementT>(
   223          {lhs.getSplatValue(), rhs.getSplatValue()}, calculate);
   224      if (!elementResult)
   225        return {};
   226  
   227      return DenseElementsAttr::get(lhs.getType(), elementResult);
   228    }
   229    return {};
   230  }
   231  } // end anonymous namespace.
   232  
   233  //===----------------------------------------------------------------------===//
   234  // AddFOp
   235  //===----------------------------------------------------------------------===//
   236  
   237  OpFoldResult AddFOp::fold(ArrayRef<Attribute> operands) {
   238    return constFoldBinaryOp<FloatAttr>(
   239        operands, [](APFloat a, APFloat b) { return a + b; });
   240  }
   241  
   242  //===----------------------------------------------------------------------===//
   243  // AddIOp
   244  //===----------------------------------------------------------------------===//
   245  
   246  OpFoldResult AddIOp::fold(ArrayRef<Attribute> operands) {
   247    /// addi(x, 0) -> x
   248    if (matchPattern(rhs(), m_Zero()))
   249      return lhs();
   250  
   251    return constFoldBinaryOp<IntegerAttr>(operands,
   252                                          [](APInt a, APInt b) { return a + b; });
   253  }
   254  
   255  //===----------------------------------------------------------------------===//
   256  // AllocOp
   257  //===----------------------------------------------------------------------===//
   258  
   259  static void print(OpAsmPrinter *p, AllocOp op) {
   260    *p << "alloc";
   261  
   262    // Print dynamic dimension operands.
   263    MemRefType type = op.getType();
   264    printDimAndSymbolList(op.operand_begin(), op.operand_end(),
   265                          type.getNumDynamicDims(), p);
   266    p->printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"map"});
   267    *p << " : " << type;
   268  }
   269  
   270  static ParseResult parseAllocOp(OpAsmParser *parser, OperationState *result) {
   271    MemRefType type;
   272  
   273    // Parse the dimension operands and optional symbol operands, followed by a
   274    // memref type.
   275    unsigned numDimOperands;
   276    if (parseDimAndSymbolList(parser, result->operands, numDimOperands) ||
   277        parser->parseOptionalAttributeDict(result->attributes) ||
   278        parser->parseColonType(type))
   279      return failure();
   280  
   281    // Check numDynamicDims against number of question marks in memref type.
   282    // Note: this check remains here (instead of in verify()), because the
   283    // partition between dim operands and symbol operands is lost after parsing.
   284    // Verification still checks that the total number of operands matches
   285    // the number of symbols in the affine map, plus the number of dynamic
   286    // dimensions in the memref.
   287    if (numDimOperands != type.getNumDynamicDims())
   288      return parser->emitError(parser->getNameLoc())
   289             << "dimension operand count does not equal memref dynamic dimension "
   290                "count";
   291    result->types.push_back(type);
   292    return success();
   293  }
   294  
   295  static LogicalResult verify(AllocOp op) {
   296    auto memRefType = op.getResult()->getType().dyn_cast<MemRefType>();
   297    if (!memRefType)
   298      return op.emitOpError("result must be a memref");
   299  
   300    unsigned numSymbols = 0;
   301    if (!memRefType.getAffineMaps().empty()) {
   302      // Store number of symbols used in affine map (used in subsequent check).
   303      AffineMap affineMap = memRefType.getAffineMaps()[0];
   304      numSymbols = affineMap.getNumSymbols();
   305    }
   306  
   307    // Check that the total number of operands matches the number of symbols in
   308    // the affine map, plus the number of dynamic dimensions specified in the
   309    // memref type.
   310    unsigned numDynamicDims = memRefType.getNumDynamicDims();
   311    if (op.getOperation()->getNumOperands() != numDynamicDims + numSymbols)
   312      return op.emitOpError(
   313          "operand count does not equal dimension plus symbol operand count");
   314  
   315    // Verify that all operands are of type Index.
   316    for (auto operandType : op.getOperandTypes())
   317      if (!operandType.isIndex())
   318        return op.emitOpError("requires operands to be of type Index");
   319    return success();
   320  }
   321  
   322  namespace {
   323  /// Fold constant dimensions into an alloc operation.
   324  struct SimplifyAllocConst : public OpRewritePattern<AllocOp> {
   325    using OpRewritePattern<AllocOp>::OpRewritePattern;
   326  
   327    PatternMatchResult matchAndRewrite(AllocOp alloc,
   328                                       PatternRewriter &rewriter) const override {
   329      // Check to see if any dimensions operands are constants.  If so, we can
   330      // substitute and drop them.
   331      if (llvm::none_of(alloc.getOperands(), [](Value *operand) {
   332            return matchPattern(operand, m_ConstantIndex());
   333          }))
   334        return matchFailure();
   335  
   336      auto memrefType = alloc.getType();
   337  
   338      // Ok, we have one or more constant operands.  Collect the non-constant ones
   339      // and keep track of the resultant memref type to build.
   340      SmallVector<int64_t, 4> newShapeConstants;
   341      newShapeConstants.reserve(memrefType.getRank());
   342      SmallVector<Value *, 4> newOperands;
   343      SmallVector<Value *, 4> droppedOperands;
   344  
   345      unsigned dynamicDimPos = 0;
   346      for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
   347        int64_t dimSize = memrefType.getDimSize(dim);
   348        // If this is already static dimension, keep it.
   349        if (dimSize != -1) {
   350          newShapeConstants.push_back(dimSize);
   351          continue;
   352        }
   353        auto *defOp = alloc.getOperand(dynamicDimPos)->getDefiningOp();
   354        if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
   355          // Dynamic shape dimension will be folded.
   356          newShapeConstants.push_back(constantIndexOp.getValue());
   357          // Record to check for zero uses later below.
   358          droppedOperands.push_back(constantIndexOp);
   359        } else {
   360          // Dynamic shape dimension not folded; copy operand from old memref.
   361          newShapeConstants.push_back(-1);
   362          newOperands.push_back(alloc.getOperand(dynamicDimPos));
   363        }
   364        dynamicDimPos++;
   365      }
   366  
   367      // Create new memref type (which will have fewer dynamic dimensions).
   368      auto newMemRefType = MemRefType::get(
   369          newShapeConstants, memrefType.getElementType(),
   370          memrefType.getAffineMaps(), memrefType.getMemorySpace());
   371      assert(static_cast<int64_t>(newOperands.size()) ==
   372             newMemRefType.getNumDynamicDims());
   373  
   374      // Create and insert the alloc op for the new memref.
   375      auto newAlloc =
   376          rewriter.create<AllocOp>(alloc.getLoc(), newMemRefType, newOperands);
   377      // Insert a cast so we have the same type as the old alloc.
   378      auto resultCast = rewriter.create<MemRefCastOp>(alloc.getLoc(), newAlloc,
   379                                                      alloc.getType());
   380  
   381      rewriter.replaceOp(alloc, {resultCast}, droppedOperands);
   382      return matchSuccess();
   383    }
   384  };
   385  
   386  /// Fold alloc operations with no uses. Alloc has side effects on the heap,
   387  /// but can still be deleted if it has zero uses.
   388  struct SimplifyDeadAlloc : public OpRewritePattern<AllocOp> {
   389    using OpRewritePattern<AllocOp>::OpRewritePattern;
   390  
   391    PatternMatchResult matchAndRewrite(AllocOp alloc,
   392                                       PatternRewriter &rewriter) const override {
   393      // Check if the alloc'ed value has any uses.
   394      if (!alloc.use_empty())
   395        return matchFailure();
   396  
   397      // If it doesn't, we can eliminate it.
   398      alloc.erase();
   399      return matchSuccess();
   400    }
   401  };
   402  } // end anonymous namespace.
   403  
   404  void AllocOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
   405                                            MLIRContext *context) {
   406    results.insert<SimplifyAllocConst, SimplifyDeadAlloc>(context);
   407  }
   408  
   409  //===----------------------------------------------------------------------===//
   410  // BranchOp
   411  //===----------------------------------------------------------------------===//
   412  
   413  static ParseResult parseBranchOp(OpAsmParser *parser, OperationState *result) {
   414    Block *dest;
   415    SmallVector<Value *, 4> destOperands;
   416    if (parser->parseSuccessorAndUseList(dest, destOperands))
   417      return failure();
   418    result->addSuccessor(dest, destOperands);
   419    return success();
   420  }
   421  
   422  static void print(OpAsmPrinter *p, BranchOp op) {
   423    *p << "br ";
   424    p->printSuccessorAndUseList(op.getOperation(), 0);
   425  }
   426  
   427  Block *BranchOp::getDest() { return getOperation()->getSuccessor(0); }
   428  
   429  void BranchOp::setDest(Block *block) {
   430    return getOperation()->setSuccessor(block, 0);
   431  }
   432  
   433  void BranchOp::eraseOperand(unsigned index) {
   434    getOperation()->eraseSuccessorOperand(0, index);
   435  }
   436  
   437  //===----------------------------------------------------------------------===//
   438  // CallOp
   439  //===----------------------------------------------------------------------===//
   440  
   441  static ParseResult parseCallOp(OpAsmParser *parser, OperationState *result) {
   442    SymbolRefAttr calleeAttr;
   443    FunctionType calleeType;
   444    SmallVector<OpAsmParser::OperandType, 4> operands;
   445    auto calleeLoc = parser->getNameLoc();
   446    if (parser->parseAttribute(calleeAttr, "callee", result->attributes) ||
   447        parser->parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
   448        parser->parseOptionalAttributeDict(result->attributes) ||
   449        parser->parseColonType(calleeType) ||
   450        parser->addTypesToList(calleeType.getResults(), result->types) ||
   451        parser->resolveOperands(operands, calleeType.getInputs(), calleeLoc,
   452                                result->operands))
   453      return failure();
   454  
   455    return success();
   456  }
   457  
   458  static void print(OpAsmPrinter *p, CallOp op) {
   459    *p << "call " << op.getAttr("callee") << '(';
   460    p->printOperands(op.getOperands());
   461    *p << ')';
   462    p->printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"});
   463    *p << " : ";
   464    p->printType(op.getCalleeType());
   465  }
   466  
   467  static LogicalResult verify(CallOp op) {
   468    // Check that the callee attribute was specified.
   469    auto fnAttr = op.getAttrOfType<SymbolRefAttr>("callee");
   470    if (!fnAttr)
   471      return op.emitOpError("requires a 'callee' symbol reference attribute");
   472    auto fn =
   473        op.getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(fnAttr.getValue());
   474    if (!fn)
   475      return op.emitOpError() << "'" << fnAttr.getValue()
   476                              << "' does not reference a valid function";
   477  
   478    // Verify that the operand and result types match the callee.
   479    auto fnType = fn.getType();
   480    if (fnType.getNumInputs() != op.getNumOperands())
   481      return op.emitOpError("incorrect number of operands for callee");
   482  
   483    for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
   484      if (op.getOperand(i)->getType() != fnType.getInput(i))
   485        return op.emitOpError("operand type mismatch");
   486  
   487    if (fnType.getNumResults() != op.getNumResults())
   488      return op.emitOpError("incorrect number of results for callee");
   489  
   490    for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
   491      if (op.getResult(i)->getType() != fnType.getResult(i))
   492        return op.emitOpError("result type mismatch");
   493  
   494    return success();
   495  }
   496  
   497  FunctionType CallOp::getCalleeType() {
   498    SmallVector<Type, 4> resultTypes(getResultTypes());
   499    SmallVector<Type, 8> argTypes(getOperandTypes());
   500    return FunctionType::get(argTypes, resultTypes, getContext());
   501  }
   502  
   503  //===----------------------------------------------------------------------===//
   504  // CallIndirectOp
   505  //===----------------------------------------------------------------------===//
   506  namespace {
   507  /// Fold indirect calls that have a constant function as the callee operand.
   508  struct SimplifyIndirectCallWithKnownCallee
   509      : public OpRewritePattern<CallIndirectOp> {
   510    using OpRewritePattern<CallIndirectOp>::OpRewritePattern;
   511  
   512    PatternMatchResult matchAndRewrite(CallIndirectOp indirectCall,
   513                                       PatternRewriter &rewriter) const override {
   514      // Check that the callee is a constant callee.
   515      SymbolRefAttr calledFn;
   516      if (!matchPattern(indirectCall.getCallee(), m_Constant(&calledFn)))
   517        return matchFailure();
   518  
   519      // Replace with a direct call.
   520      SmallVector<Type, 8> callResults(indirectCall.getResultTypes());
   521      SmallVector<Value *, 8> callOperands(indirectCall.getArgOperands());
   522      rewriter.replaceOpWithNewOp<CallOp>(indirectCall, calledFn.getValue(),
   523                                          callResults, callOperands);
   524      return matchSuccess();
   525    }
   526  };
   527  } // end anonymous namespace.
   528  
   529  static ParseResult parseCallIndirectOp(OpAsmParser *parser,
   530                                         OperationState *result) {
   531    FunctionType calleeType;
   532    OpAsmParser::OperandType callee;
   533    llvm::SMLoc operandsLoc;
   534    SmallVector<OpAsmParser::OperandType, 4> operands;
   535    return failure(
   536        parser->parseOperand(callee) ||
   537        parser->getCurrentLocation(&operandsLoc) ||
   538        parser->parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
   539        parser->parseOptionalAttributeDict(result->attributes) ||
   540        parser->parseColonType(calleeType) ||
   541        parser->resolveOperand(callee, calleeType, result->operands) ||
   542        parser->resolveOperands(operands, calleeType.getInputs(), operandsLoc,
   543                                result->operands) ||
   544        parser->addTypesToList(calleeType.getResults(), result->types));
   545  }
   546  
   547  static void print(OpAsmPrinter *p, CallIndirectOp op) {
   548    *p << "call_indirect ";
   549    p->printOperand(op.getCallee());
   550    *p << '(';
   551    p->printOperands(op.getArgOperands());
   552    *p << ')';
   553    p->printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"});
   554    *p << " : " << op.getCallee()->getType();
   555  }
   556  
   557  static LogicalResult verify(CallIndirectOp op) {
   558    // The callee must be a function.
   559    auto fnType = op.getCallee()->getType().dyn_cast<FunctionType>();
   560    if (!fnType)
   561      return op.emitOpError("callee must have function type");
   562  
   563    // Verify that the operand and result types match the callee.
   564    if (fnType.getNumInputs() != op.getNumOperands() - 1)
   565      return op.emitOpError("incorrect number of operands for callee");
   566  
   567    for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
   568      if (op.getOperand(i + 1)->getType() != fnType.getInput(i))
   569        return op.emitOpError("operand type mismatch");
   570  
   571    if (fnType.getNumResults() != op.getNumResults())
   572      return op.emitOpError("incorrect number of results for callee");
   573  
   574    for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
   575      if (op.getResult(i)->getType() != fnType.getResult(i))
   576        return op.emitOpError("result type mismatch");
   577  
   578    return success();
   579  }
   580  
   581  void CallIndirectOp::getCanonicalizationPatterns(
   582      OwningRewritePatternList &results, MLIRContext *context) {
   583    results.insert<SimplifyIndirectCallWithKnownCallee>(context);
   584  }
   585  
   586  //===----------------------------------------------------------------------===//
   587  // General helpers for comparison ops
   588  //===----------------------------------------------------------------------===//
   589  
   590  // Return the type of the same shape (scalar, vector or tensor) containing i1.
   591  static Type getCheckedI1SameShape(Builder *build, Type type) {
   592    auto i1Type = build->getI1Type();
   593    if (type.isIntOrIndexOrFloat())
   594      return i1Type;
   595    if (auto tensorType = type.dyn_cast<RankedTensorType>())
   596      return build->getTensorType(tensorType.getShape(), i1Type);
   597    if (type.isa<UnrankedTensorType>())
   598      return build->getTensorType(i1Type);
   599    if (auto vectorType = type.dyn_cast<VectorType>())
   600      return build->getVectorType(vectorType.getShape(), i1Type);
   601    return Type();
   602  }
   603  
   604  static Type getI1SameShape(Builder *build, Type type) {
   605    Type res = getCheckedI1SameShape(build, type);
   606    assert(res && "expected type with valid i1 shape");
   607    return res;
   608  }
   609  
   610  //===----------------------------------------------------------------------===//
   611  // CmpIOp
   612  //===----------------------------------------------------------------------===//
   613  
   614  // Returns an array of mnemonics for CmpIPredicates indexed by values thereof.
   615  static inline const char *const *getCmpIPredicateNames() {
   616    static const char *predicateNames[]{
   617        /*EQ*/ "eq",
   618        /*NE*/ "ne",
   619        /*SLT*/ "slt",
   620        /*SLE*/ "sle",
   621        /*SGT*/ "sgt",
   622        /*SGE*/ "sge",
   623        /*ULT*/ "ult",
   624        /*ULE*/ "ule",
   625        /*UGT*/ "ugt",
   626        /*UGE*/ "uge",
   627    };
   628    static_assert(std::extent<decltype(predicateNames)>::value ==
   629                      (size_t)CmpIPredicate::NumPredicates,
   630                  "wrong number of predicate names");
   631    return predicateNames;
   632  }
   633  
   634  // Returns a value of the predicate corresponding to the given mnemonic.
   635  // Returns NumPredicates (one-past-end) if there is no such mnemonic.
   636  CmpIPredicate CmpIOp::getPredicateByName(StringRef name) {
   637    return llvm::StringSwitch<CmpIPredicate>(name)
   638        .Case("eq", CmpIPredicate::EQ)
   639        .Case("ne", CmpIPredicate::NE)
   640        .Case("slt", CmpIPredicate::SLT)
   641        .Case("sle", CmpIPredicate::SLE)
   642        .Case("sgt", CmpIPredicate::SGT)
   643        .Case("sge", CmpIPredicate::SGE)
   644        .Case("ult", CmpIPredicate::ULT)
   645        .Case("ule", CmpIPredicate::ULE)
   646        .Case("ugt", CmpIPredicate::UGT)
   647        .Case("uge", CmpIPredicate::UGE)
   648        .Default(CmpIPredicate::NumPredicates);
   649  }
   650  
   651  static void buildCmpIOp(Builder *build, OperationState *result,
   652                          CmpIPredicate predicate, Value *lhs, Value *rhs) {
   653    result->addOperands({lhs, rhs});
   654    result->types.push_back(getI1SameShape(build, lhs->getType()));
   655    result->addAttribute(
   656        CmpIOp::getPredicateAttrName(),
   657        build->getI64IntegerAttr(static_cast<int64_t>(predicate)));
   658  }
   659  
   660  static ParseResult parseCmpIOp(OpAsmParser *parser, OperationState *result) {
   661    SmallVector<OpAsmParser::OperandType, 2> ops;
   662    SmallVector<NamedAttribute, 4> attrs;
   663    Attribute predicateNameAttr;
   664    Type type;
   665    if (parser->parseAttribute(predicateNameAttr, CmpIOp::getPredicateAttrName(),
   666                               attrs) ||
   667        parser->parseComma() || parser->parseOperandList(ops, 2) ||
   668        parser->parseOptionalAttributeDict(attrs) ||
   669        parser->parseColonType(type) ||
   670        parser->resolveOperands(ops, type, result->operands))
   671      return failure();
   672  
   673    if (!predicateNameAttr.isa<StringAttr>())
   674      return parser->emitError(parser->getNameLoc(),
   675                               "expected string comparison predicate attribute");
   676  
   677    // Rewrite string attribute to an enum value.
   678    StringRef predicateName = predicateNameAttr.cast<StringAttr>().getValue();
   679    auto predicate = CmpIOp::getPredicateByName(predicateName);
   680    if (predicate == CmpIPredicate::NumPredicates)
   681      return parser->emitError(parser->getNameLoc())
   682             << "unknown comparison predicate \"" << predicateName << "\"";
   683  
   684    auto builder = parser->getBuilder();
   685    Type i1Type = getCheckedI1SameShape(&builder, type);
   686    if (!i1Type)
   687      return parser->emitError(parser->getNameLoc(),
   688                               "expected type with valid i1 shape");
   689  
   690    attrs[0].second = builder.getI64IntegerAttr(static_cast<int64_t>(predicate));
   691    result->attributes = attrs;
   692  
   693    result->addTypes({i1Type});
   694    return success();
   695  }
   696  
   697  static void print(OpAsmPrinter *p, CmpIOp op) {
   698    *p << "cmpi ";
   699  
   700    auto predicateValue =
   701        op.getAttrOfType<IntegerAttr>(CmpIOp::getPredicateAttrName()).getInt();
   702    assert(predicateValue >= static_cast<int>(CmpIPredicate::FirstValidValue) &&
   703           predicateValue < static_cast<int>(CmpIPredicate::NumPredicates) &&
   704           "unknown predicate index");
   705    Builder b(op.getContext());
   706    auto predicateStringAttr =
   707        b.getStringAttr(getCmpIPredicateNames()[predicateValue]);
   708    p->printAttribute(predicateStringAttr);
   709  
   710    *p << ", ";
   711    p->printOperand(op.lhs());
   712    *p << ", ";
   713    p->printOperand(op.rhs());
   714    p->printOptionalAttrDict(op.getAttrs(),
   715                             /*elidedAttrs=*/{CmpIOp::getPredicateAttrName()});
   716    *p << " : " << op.lhs()->getType();
   717  }
   718  
   719  static LogicalResult verify(CmpIOp op) {
   720    auto predicateAttr =
   721        op.getAttrOfType<IntegerAttr>(CmpIOp::getPredicateAttrName());
   722    if (!predicateAttr)
   723      return op.emitOpError("requires an integer attribute named 'predicate'");
   724    auto predicate = predicateAttr.getInt();
   725    if (predicate < (int64_t)CmpIPredicate::FirstValidValue ||
   726        predicate >= (int64_t)CmpIPredicate::NumPredicates)
   727      return op.emitOpError("'predicate' attribute value out of range");
   728  
   729    return success();
   730  }
   731  
   732  // Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
   733  // comparison predicates.
   734  static bool applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs,
   735                                const APInt &rhs) {
   736    switch (predicate) {
   737    case CmpIPredicate::EQ:
   738      return lhs.eq(rhs);
   739    case CmpIPredicate::NE:
   740      return lhs.ne(rhs);
   741    case CmpIPredicate::SLT:
   742      return lhs.slt(rhs);
   743    case CmpIPredicate::SLE:
   744      return lhs.sle(rhs);
   745    case CmpIPredicate::SGT:
   746      return lhs.sgt(rhs);
   747    case CmpIPredicate::SGE:
   748      return lhs.sge(rhs);
   749    case CmpIPredicate::ULT:
   750      return lhs.ult(rhs);
   751    case CmpIPredicate::ULE:
   752      return lhs.ule(rhs);
   753    case CmpIPredicate::UGT:
   754      return lhs.ugt(rhs);
   755    case CmpIPredicate::UGE:
   756      return lhs.uge(rhs);
   757    default:
   758      llvm_unreachable("unknown comparison predicate");
   759    }
   760  }
   761  
   762  // Constant folding hook for comparisons.
   763  OpFoldResult CmpIOp::fold(ArrayRef<Attribute> operands) {
   764    assert(operands.size() == 2 && "cmpi takes two arguments");
   765  
   766    auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
   767    auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
   768    if (!lhs || !rhs)
   769      return {};
   770  
   771    auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
   772    return IntegerAttr::get(IntegerType::get(1, getContext()), APInt(1, val));
   773  }
   774  
   775  //===----------------------------------------------------------------------===//
   776  // CmpFOp
   777  //===----------------------------------------------------------------------===//
   778  
   779  // Returns an array of mnemonics for CmpFPredicates indexed by values thereof.
   780  static inline const char *const *getCmpFPredicateNames() {
   781    static const char *predicateNames[] = {
   782        /*AlwaysFalse*/ "false",
   783        /*OEQ*/ "oeq",
   784        /*OGT*/ "ogt",
   785        /*OGE*/ "oge",
   786        /*OLT*/ "olt",
   787        /*OLE*/ "ole",
   788        /*ONE*/ "one",
   789        /*ORD*/ "ord",
   790        /*UEQ*/ "ueq",
   791        /*UGT*/ "ugt",
   792        /*UGE*/ "uge",
   793        /*ULT*/ "ult",
   794        /*ULE*/ "ule",
   795        /*UNE*/ "une",
   796        /*UNO*/ "uno",
   797        /*AlwaysTrue*/ "true",
   798    };
   799    static_assert(std::extent<decltype(predicateNames)>::value ==
   800                      (size_t)CmpFPredicate::NumPredicates,
   801                  "wrong number of predicate names");
   802    return predicateNames;
   803  }
   804  
   805  // Returns a value of the predicate corresponding to the given mnemonic.
   806  // Returns NumPredicates (one-past-end) if there is no such mnemonic.
   807  CmpFPredicate CmpFOp::getPredicateByName(StringRef name) {
   808    return llvm::StringSwitch<CmpFPredicate>(name)
   809        .Case("false", CmpFPredicate::AlwaysFalse)
   810        .Case("oeq", CmpFPredicate::OEQ)
   811        .Case("ogt", CmpFPredicate::OGT)
   812        .Case("oge", CmpFPredicate::OGE)
   813        .Case("olt", CmpFPredicate::OLT)
   814        .Case("ole", CmpFPredicate::OLE)
   815        .Case("one", CmpFPredicate::ONE)
   816        .Case("ord", CmpFPredicate::ORD)
   817        .Case("ueq", CmpFPredicate::UEQ)
   818        .Case("ugt", CmpFPredicate::UGT)
   819        .Case("uge", CmpFPredicate::UGE)
   820        .Case("ult", CmpFPredicate::ULT)
   821        .Case("ule", CmpFPredicate::ULE)
   822        .Case("une", CmpFPredicate::UNE)
   823        .Case("uno", CmpFPredicate::UNO)
   824        .Case("true", CmpFPredicate::AlwaysTrue)
   825        .Default(CmpFPredicate::NumPredicates);
   826  }
   827  
   828  static void buildCmpFOp(Builder *build, OperationState *result,
   829                          CmpFPredicate predicate, Value *lhs, Value *rhs) {
   830    result->addOperands({lhs, rhs});
   831    result->types.push_back(getI1SameShape(build, lhs->getType()));
   832    result->addAttribute(
   833        CmpFOp::getPredicateAttrName(),
   834        build->getI64IntegerAttr(static_cast<int64_t>(predicate)));
   835  }
   836  
   837  static ParseResult parseCmpFOp(OpAsmParser *parser, OperationState *result) {
   838    SmallVector<OpAsmParser::OperandType, 2> ops;
   839    SmallVector<NamedAttribute, 4> attrs;
   840    Attribute predicateNameAttr;
   841    Type type;
   842    if (parser->parseAttribute(predicateNameAttr, CmpFOp::getPredicateAttrName(),
   843                               attrs) ||
   844        parser->parseComma() || parser->parseOperandList(ops, 2) ||
   845        parser->parseOptionalAttributeDict(attrs) ||
   846        parser->parseColonType(type) ||
   847        parser->resolveOperands(ops, type, result->operands))
   848      return failure();
   849  
   850    if (!predicateNameAttr.isa<StringAttr>())
   851      return parser->emitError(parser->getNameLoc(),
   852                               "expected string comparison predicate attribute");
   853  
   854    // Rewrite string attribute to an enum value.
   855    StringRef predicateName = predicateNameAttr.cast<StringAttr>().getValue();
   856    auto predicate = CmpFOp::getPredicateByName(predicateName);
   857    if (predicate == CmpFPredicate::NumPredicates)
   858      return parser->emitError(parser->getNameLoc(),
   859                               "unknown comparison predicate \"" + predicateName +
   860                                   "\"");
   861  
   862    auto builder = parser->getBuilder();
   863    Type i1Type = getCheckedI1SameShape(&builder, type);
   864    if (!i1Type)
   865      return parser->emitError(parser->getNameLoc(),
   866                               "expected type with valid i1 shape");
   867  
   868    attrs[0].second = builder.getI64IntegerAttr(static_cast<int64_t>(predicate));
   869    result->attributes = attrs;
   870  
   871    result->addTypes({i1Type});
   872    return success();
   873  }
   874  
   875  static void print(OpAsmPrinter *p, CmpFOp op) {
   876    *p << "cmpf ";
   877  
   878    auto predicateValue =
   879        op.getAttrOfType<IntegerAttr>(CmpFOp::getPredicateAttrName()).getInt();
   880    assert(predicateValue >= static_cast<int>(CmpFPredicate::FirstValidValue) &&
   881           predicateValue < static_cast<int>(CmpFPredicate::NumPredicates) &&
   882           "unknown predicate index");
   883    Builder b(op.getContext());
   884    auto predicateStringAttr =
   885        b.getStringAttr(getCmpFPredicateNames()[predicateValue]);
   886    p->printAttribute(predicateStringAttr);
   887  
   888    *p << ", ";
   889    p->printOperand(op.lhs());
   890    *p << ", ";
   891    p->printOperand(op.rhs());
   892    p->printOptionalAttrDict(op.getAttrs(),
   893                             /*elidedAttrs=*/{CmpFOp::getPredicateAttrName()});
   894    *p << " : " << op.lhs()->getType();
   895  }
   896  
   897  static LogicalResult verify(CmpFOp op) {
   898    auto predicateAttr =
   899        op.getAttrOfType<IntegerAttr>(CmpFOp::getPredicateAttrName());
   900    if (!predicateAttr)
   901      return op.emitOpError("requires an integer attribute named 'predicate'");
   902    auto predicate = predicateAttr.getInt();
   903    if (predicate < (int64_t)CmpFPredicate::FirstValidValue ||
   904        predicate >= (int64_t)CmpFPredicate::NumPredicates)
   905      return op.emitOpError("'predicate' attribute value out of range");
   906  
   907    return success();
   908  }
   909  
   910  // Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
   911  // comparison predicates.
   912  static bool applyCmpPredicate(CmpFPredicate predicate, const APFloat &lhs,
   913                                const APFloat &rhs) {
   914    auto cmpResult = lhs.compare(rhs);
   915    switch (predicate) {
   916    case CmpFPredicate::AlwaysFalse:
   917      return false;
   918    case CmpFPredicate::OEQ:
   919      return cmpResult == APFloat::cmpEqual;
   920    case CmpFPredicate::OGT:
   921      return cmpResult == APFloat::cmpGreaterThan;
   922    case CmpFPredicate::OGE:
   923      return cmpResult == APFloat::cmpGreaterThan ||
   924             cmpResult == APFloat::cmpEqual;
   925    case CmpFPredicate::OLT:
   926      return cmpResult == APFloat::cmpLessThan;
   927    case CmpFPredicate::OLE:
   928      return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
   929    case CmpFPredicate::ONE:
   930      return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
   931    case CmpFPredicate::ORD:
   932      return cmpResult != APFloat::cmpUnordered;
   933    case CmpFPredicate::UEQ:
   934      return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
   935    case CmpFPredicate::UGT:
   936      return cmpResult == APFloat::cmpUnordered ||
   937             cmpResult == APFloat::cmpGreaterThan;
   938    case CmpFPredicate::UGE:
   939      return cmpResult == APFloat::cmpUnordered ||
   940             cmpResult == APFloat::cmpGreaterThan ||
   941             cmpResult == APFloat::cmpEqual;
   942    case CmpFPredicate::ULT:
   943      return cmpResult == APFloat::cmpUnordered ||
   944             cmpResult == APFloat::cmpLessThan;
   945    case CmpFPredicate::ULE:
   946      return cmpResult == APFloat::cmpUnordered ||
   947             cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
   948    case CmpFPredicate::UNE:
   949      return cmpResult != APFloat::cmpEqual;
   950    case CmpFPredicate::UNO:
   951      return cmpResult == APFloat::cmpUnordered;
   952    case CmpFPredicate::AlwaysTrue:
   953      return true;
   954    default:
   955      llvm_unreachable("unknown comparison predicate");
   956    }
   957  }
   958  
   959  // Constant folding hook for comparisons.
   960  OpFoldResult CmpFOp::fold(ArrayRef<Attribute> operands) {
   961    assert(operands.size() == 2 && "cmpf takes two arguments");
   962  
   963    auto lhs = operands.front().dyn_cast_or_null<FloatAttr>();
   964    auto rhs = operands.back().dyn_cast_or_null<FloatAttr>();
   965    if (!lhs || !rhs ||
   966        // TODO(b/122019992) Implement and test constant folding for nan/inf when
   967        // it is possible to have constant nan/inf
   968        !lhs.getValue().isFinite() || !rhs.getValue().isFinite())
   969      return {};
   970  
   971    auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
   972    return IntegerAttr::get(IntegerType::get(1, getContext()), APInt(1, val));
   973  }
   974  
   975  //===----------------------------------------------------------------------===//
   976  // CondBranchOp
   977  //===----------------------------------------------------------------------===//
   978  
   979  namespace {
   980  /// cond_br true, ^bb1, ^bb2 -> br ^bb1
   981  /// cond_br false, ^bb1, ^bb2 -> br ^bb2
   982  ///
   983  struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> {
   984    using OpRewritePattern<CondBranchOp>::OpRewritePattern;
   985  
   986    PatternMatchResult matchAndRewrite(CondBranchOp condbr,
   987                                       PatternRewriter &rewriter) const override {
   988      // Check that the condition is a constant.
   989      if (!matchPattern(condbr.getCondition(), m_Op<ConstantOp>()))
   990        return matchFailure();
   991  
   992      Block *foldedDest;
   993      SmallVector<Value *, 4> branchArgs;
   994  
   995      // If the condition is known to evaluate to false we fold to a branch to the
   996      // false destination. Otherwise, we fold to a branch to the true
   997      // destination.
   998      if (matchPattern(condbr.getCondition(), m_Zero())) {
   999        foldedDest = condbr.getFalseDest();
  1000        branchArgs.assign(condbr.false_operand_begin(),
  1001                          condbr.false_operand_end());
  1002      } else {
  1003        foldedDest = condbr.getTrueDest();
  1004        branchArgs.assign(condbr.true_operand_begin(), condbr.true_operand_end());
  1005      }
  1006  
  1007      rewriter.replaceOpWithNewOp<BranchOp>(condbr, foldedDest, branchArgs);
  1008      return matchSuccess();
  1009    }
  1010  };
  1011  } // end anonymous namespace.
  1012  
  1013  static ParseResult parseCondBranchOp(OpAsmParser *parser,
  1014                                       OperationState *result) {
  1015    SmallVector<Value *, 4> destOperands;
  1016    Block *dest;
  1017    OpAsmParser::OperandType condInfo;
  1018  
  1019    // Parse the condition.
  1020    Type int1Ty = parser->getBuilder().getI1Type();
  1021    if (parser->parseOperand(condInfo) || parser->parseComma() ||
  1022        parser->resolveOperand(condInfo, int1Ty, result->operands)) {
  1023      return parser->emitError(parser->getNameLoc(),
  1024                               "expected condition type was boolean (i1)");
  1025    }
  1026  
  1027    // Parse the true successor.
  1028    if (parser->parseSuccessorAndUseList(dest, destOperands))
  1029      return failure();
  1030    result->addSuccessor(dest, destOperands);
  1031  
  1032    // Parse the false successor.
  1033    destOperands.clear();
  1034    if (parser->parseComma() ||
  1035        parser->parseSuccessorAndUseList(dest, destOperands))
  1036      return failure();
  1037    result->addSuccessor(dest, destOperands);
  1038  
  1039    return success();
  1040  }
  1041  
  1042  static void print(OpAsmPrinter *p, CondBranchOp op) {
  1043    *p << "cond_br ";
  1044    p->printOperand(op.getCondition());
  1045    *p << ", ";
  1046    p->printSuccessorAndUseList(op.getOperation(), CondBranchOp::trueIndex);
  1047    *p << ", ";
  1048    p->printSuccessorAndUseList(op.getOperation(), CondBranchOp::falseIndex);
  1049  }
  1050  
  1051  void CondBranchOp::getCanonicalizationPatterns(
  1052      OwningRewritePatternList &results, MLIRContext *context) {
  1053    results.insert<SimplifyConstCondBranchPred>(context);
  1054  }
  1055  
  1056  //===----------------------------------------------------------------------===//
  1057  // Constant*Op
  1058  //===----------------------------------------------------------------------===//
  1059  
  1060  static void print(OpAsmPrinter *p, ConstantOp &op) {
  1061    *p << "constant ";
  1062    p->printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"value"});
  1063  
  1064    if (op.getAttrs().size() > 1)
  1065      *p << ' ';
  1066    p->printAttribute(op.getValue());
  1067  
  1068    // If the value is a symbol reference, print a trailing type.
  1069    if (op.getValue().isa<SymbolRefAttr>())
  1070      *p << " : " << op.getType();
  1071  }
  1072  
  1073  static ParseResult parseConstantOp(OpAsmParser *parser,
  1074                                     OperationState *result) {
  1075    Attribute valueAttr;
  1076    if (parser->parseOptionalAttributeDict(result->attributes) ||
  1077        parser->parseAttribute(valueAttr, "value", result->attributes))
  1078      return failure();
  1079  
  1080    // If the attribute is a symbol reference, then we expect a trailing type.
  1081    Type type;
  1082    if (!valueAttr.isa<SymbolRefAttr>())
  1083      type = valueAttr.getType();
  1084    else if (parser->parseColonType(type))
  1085      return failure();
  1086  
  1087    // Add the attribute type to the list.
  1088    return parser->addTypeToList(type, result->types);
  1089  }
  1090  
  1091  /// The constant op requires an attribute, and furthermore requires that it
  1092  /// matches the return type.
  1093  static LogicalResult verify(ConstantOp &op) {
  1094    auto value = op.getValue();
  1095    if (!value)
  1096      return op.emitOpError("requires a 'value' attribute");
  1097  
  1098    auto type = op.getType();
  1099    if (!value.getType().isa<NoneType>() && type != value.getType())
  1100      return op.emitOpError() << "requires attribute's type (" << value.getType()
  1101                              << ") to match op's return type (" << type << ")";
  1102  
  1103    if (type.isa<IndexType>() || value.isa<BoolAttr>())
  1104      return success();
  1105  
  1106    if (auto intAttr = value.dyn_cast<IntegerAttr>()) {
  1107      // If the type has a known bitwidth we verify that the value can be
  1108      // represented with the given bitwidth.
  1109      auto bitwidth = type.cast<IntegerType>().getWidth();
  1110      auto intVal = intAttr.getValue();
  1111      if (!intVal.isSignedIntN(bitwidth) && !intVal.isIntN(bitwidth))
  1112        return op.emitOpError("requires 'value' to be an integer within the "
  1113                              "range of the integer result type");
  1114      return success();
  1115    }
  1116  
  1117    if (type.isa<FloatType>()) {
  1118      if (!value.isa<FloatAttr>())
  1119        return op.emitOpError("requires 'value' to be a floating point constant");
  1120      return success();
  1121    }
  1122  
  1123    if (type.isa<ShapedType>()) {
  1124      if (!value.isa<ElementsAttr>())
  1125        return op.emitOpError("requires 'value' to be a shaped constant");
  1126      return success();
  1127    }
  1128  
  1129    if (type.isa<FunctionType>()) {
  1130      auto fnAttr = value.dyn_cast<SymbolRefAttr>();
  1131      if (!fnAttr)
  1132        return op.emitOpError("requires 'value' to be a function reference");
  1133  
  1134      // Try to find the referenced function.
  1135      auto fn =
  1136          op.getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(fnAttr.getValue());
  1137      if (!fn)
  1138        return op.emitOpError("reference to undefined function 'bar'");
  1139  
  1140      // Check that the referenced function has the correct type.
  1141      if (fn.getType() != type)
  1142        return op.emitOpError("reference to function with mismatched type");
  1143  
  1144      return success();
  1145    }
  1146  
  1147    if (type.isa<NoneType>() && value.isa<UnitAttr>())
  1148      return success();
  1149  
  1150    return op.emitOpError("unsupported 'value' attribute: ") << value;
  1151  }
  1152  
  1153  OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) {
  1154    assert(operands.empty() && "constant has no operands");
  1155    return getValue();
  1156  }
  1157  
  1158  /// Returns true if a constant operation can be built with the given value and
  1159  /// result type.
  1160  bool ConstantOp::isBuildableWith(Attribute value, Type type) {
  1161    // SymbolRefAttr can only be used with a function type.
  1162    if (value.isa<SymbolRefAttr>())
  1163      return type.isa<FunctionType>();
  1164    // Otherwise, the attribute must have the same type as 'type'.
  1165    if (value.getType() != type)
  1166      return false;
  1167    // Finally, check that the attribute kind is handled.
  1168    return value.isa<BoolAttr>() || value.isa<IntegerAttr>() ||
  1169           value.isa<FloatAttr>() || value.isa<ElementsAttr>() ||
  1170           value.isa<UnitAttr>();
  1171  }
  1172  
  1173  void ConstantFloatOp::build(Builder *builder, OperationState *result,
  1174                              const APFloat &value, FloatType type) {
  1175    ConstantOp::build(builder, result, type, builder->getFloatAttr(type, value));
  1176  }
  1177  
  1178  bool ConstantFloatOp::classof(Operation *op) {
  1179    return ConstantOp::classof(op) &&
  1180           op->getResult(0)->getType().isa<FloatType>();
  1181  }
  1182  
  1183  /// ConstantIntOp only matches values whose result type is an IntegerType.
  1184  bool ConstantIntOp::classof(Operation *op) {
  1185    return ConstantOp::classof(op) &&
  1186           op->getResult(0)->getType().isa<IntegerType>();
  1187  }
  1188  
  1189  void ConstantIntOp::build(Builder *builder, OperationState *result,
  1190                            int64_t value, unsigned width) {
  1191    Type type = builder->getIntegerType(width);
  1192    ConstantOp::build(builder, result, type,
  1193                      builder->getIntegerAttr(type, value));
  1194  }
  1195  
  1196  /// Build a constant int op producing an integer with the specified type,
  1197  /// which must be an integer type.
  1198  void ConstantIntOp::build(Builder *builder, OperationState *result,
  1199                            int64_t value, Type type) {
  1200    assert(type.isa<IntegerType>() && "ConstantIntOp can only have integer type");
  1201    ConstantOp::build(builder, result, type,
  1202                      builder->getIntegerAttr(type, value));
  1203  }
  1204  
  1205  /// ConstantIndexOp only matches values whose result type is Index.
  1206  bool ConstantIndexOp::classof(Operation *op) {
  1207    return ConstantOp::classof(op) && op->getResult(0)->getType().isIndex();
  1208  }
  1209  
  1210  void ConstantIndexOp::build(Builder *builder, OperationState *result,
  1211                              int64_t value) {
  1212    Type type = builder->getIndexType();
  1213    ConstantOp::build(builder, result, type,
  1214                      builder->getIntegerAttr(type, value));
  1215  }
  1216  
  1217  //===----------------------------------------------------------------------===//
  1218  // DeallocOp
  1219  //===----------------------------------------------------------------------===//
  1220  namespace {
  1221  /// Fold Dealloc operations that are deallocating an AllocOp that is only used
  1222  /// by other Dealloc operations.
  1223  struct SimplifyDeadDealloc : public OpRewritePattern<DeallocOp> {
  1224    using OpRewritePattern<DeallocOp>::OpRewritePattern;
  1225  
  1226    PatternMatchResult matchAndRewrite(DeallocOp dealloc,
  1227                                       PatternRewriter &rewriter) const override {
  1228      // Check that the memref operand's defining operation is an AllocOp.
  1229      Value *memref = dealloc.memref();
  1230      if (!isa_and_nonnull<AllocOp>(memref->getDefiningOp()))
  1231        return matchFailure();
  1232  
  1233      // Check that all of the uses of the AllocOp are other DeallocOps.
  1234      for (auto *user : memref->getUsers())
  1235        if (!isa<DeallocOp>(user))
  1236          return matchFailure();
  1237  
  1238      // Erase the dealloc operation.
  1239      rewriter.replaceOp(dealloc, llvm::None);
  1240      return matchSuccess();
  1241    }
  1242  };
  1243  } // end anonymous namespace.
  1244  
  1245  static void print(OpAsmPrinter *p, DeallocOp op) {
  1246    *p << "dealloc " << *op.memref() << " : " << op.memref()->getType();
  1247  }
  1248  
  1249  static ParseResult parseDeallocOp(OpAsmParser *parser, OperationState *result) {
  1250    OpAsmParser::OperandType memrefInfo;
  1251    MemRefType type;
  1252  
  1253    return failure(parser->parseOperand(memrefInfo) ||
  1254                   parser->parseColonType(type) ||
  1255                   parser->resolveOperand(memrefInfo, type, result->operands));
  1256  }
  1257  
  1258  static LogicalResult verify(DeallocOp op) {
  1259    if (!op.memref()->getType().isa<MemRefType>())
  1260      return op.emitOpError("operand must be a memref");
  1261    return success();
  1262  }
  1263  
  1264  void DeallocOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
  1265                                              MLIRContext *context) {
  1266    /// dealloc(memrefcast) -> dealloc
  1267    results.insert<MemRefCastFolder>(getOperationName(), context);
  1268    results.insert<SimplifyDeadDealloc>(context);
  1269  }
  1270  
  1271  //===----------------------------------------------------------------------===//
  1272  // DimOp
  1273  //===----------------------------------------------------------------------===//
  1274  
  1275  static void print(OpAsmPrinter *p, DimOp op) {
  1276    *p << "dim " << *op.getOperand() << ", " << op.getIndex();
  1277    p->printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"index"});
  1278    *p << " : " << op.getOperand()->getType();
  1279  }
  1280  
  1281  static ParseResult parseDimOp(OpAsmParser *parser, OperationState *result) {
  1282    OpAsmParser::OperandType operandInfo;
  1283    IntegerAttr indexAttr;
  1284    Type type;
  1285    Type indexType = parser->getBuilder().getIndexType();
  1286  
  1287    return failure(parser->parseOperand(operandInfo) || parser->parseComma() ||
  1288                   parser->parseAttribute(indexAttr, indexType, "index",
  1289                                          result->attributes) ||
  1290                   parser->parseOptionalAttributeDict(result->attributes) ||
  1291                   parser->parseColonType(type) ||
  1292                   parser->resolveOperand(operandInfo, type, result->operands) ||
  1293                   parser->addTypeToList(indexType, result->types));
  1294  }
  1295  
  1296  static LogicalResult verify(DimOp op) {
  1297    // Check that we have an integer index operand.
  1298    auto indexAttr = op.getAttrOfType<IntegerAttr>("index");
  1299    if (!indexAttr)
  1300      return op.emitOpError("requires an integer attribute named 'index'");
  1301    int64_t index = indexAttr.getValue().getSExtValue();
  1302  
  1303    auto type = op.getOperand()->getType();
  1304    if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
  1305      if (index >= tensorType.getRank())
  1306        return op.emitOpError("index is out of range");
  1307    } else if (auto memrefType = type.dyn_cast<MemRefType>()) {
  1308      if (index >= memrefType.getRank())
  1309        return op.emitOpError("index is out of range");
  1310  
  1311    } else if (type.isa<UnrankedTensorType>()) {
  1312      // ok, assumed to be in-range.
  1313    } else {
  1314      return op.emitOpError("requires an operand with tensor or memref type");
  1315    }
  1316  
  1317    return success();
  1318  }
  1319  
  1320  OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
  1321    // Constant fold dim when the size along the index referred to is a constant.
  1322    auto opType = getOperand()->getType();
  1323    int64_t indexSize = -1;
  1324    if (auto tensorType = opType.dyn_cast<RankedTensorType>())
  1325      indexSize = tensorType.getShape()[getIndex()];
  1326    else if (auto memrefType = opType.dyn_cast<MemRefType>())
  1327      indexSize = memrefType.getShape()[getIndex()];
  1328  
  1329    if (indexSize >= 0)
  1330      return IntegerAttr::get(IndexType::get(getContext()), indexSize);
  1331  
  1332    return {};
  1333  }
  1334  
  1335  //===----------------------------------------------------------------------===//
  1336  // DivISOp
  1337  //===----------------------------------------------------------------------===//
  1338  
  1339  OpFoldResult DivISOp::fold(ArrayRef<Attribute> operands) {
  1340    assert(operands.size() == 2 && "binary operation takes two operands");
  1341  
  1342    auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
  1343    auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
  1344    if (!lhs || !rhs)
  1345      return {};
  1346  
  1347    // Don't fold if it requires division by zero.
  1348    if (rhs.getValue().isNullValue())
  1349      return {};
  1350  
  1351    // Don't fold if it would overflow.
  1352    bool overflow;
  1353    auto result = lhs.getValue().sdiv_ov(rhs.getValue(), overflow);
  1354    return overflow ? IntegerAttr() : IntegerAttr::get(lhs.getType(), result);
  1355  }
  1356  
  1357  //===----------------------------------------------------------------------===//
  1358  // DivIUOp
  1359  //===----------------------------------------------------------------------===//
  1360  
  1361  OpFoldResult DivIUOp::fold(ArrayRef<Attribute> operands) {
  1362    assert(operands.size() == 2 && "binary operation takes two operands");
  1363  
  1364    auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
  1365    auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
  1366    if (!lhs || !rhs)
  1367      return {};
  1368  
  1369    // Don't fold if it requires division by zero.
  1370    auto rhsValue = rhs.getValue();
  1371    if (rhsValue.isNullValue())
  1372      return {};
  1373  
  1374    return IntegerAttr::get(lhs.getType(), lhs.getValue().udiv(rhsValue));
  1375  }
  1376  
  1377  // ---------------------------------------------------------------------------
  1378  // DmaStartOp
  1379  // ---------------------------------------------------------------------------
  1380  
  1381  void DmaStartOp::build(Builder *builder, OperationState *result,
  1382                         Value *srcMemRef, ArrayRef<Value *> srcIndices,
  1383                         Value *destMemRef, ArrayRef<Value *> destIndices,
  1384                         Value *numElements, Value *tagMemRef,
  1385                         ArrayRef<Value *> tagIndices, Value *stride,
  1386                         Value *elementsPerStride) {
  1387    result->addOperands(srcMemRef);
  1388    result->addOperands(srcIndices);
  1389    result->addOperands(destMemRef);
  1390    result->addOperands(destIndices);
  1391    result->addOperands({numElements, tagMemRef});
  1392    result->addOperands(tagIndices);
  1393    if (stride)
  1394      result->addOperands({stride, elementsPerStride});
  1395  }
  1396  
  1397  void DmaStartOp::print(OpAsmPrinter *p) {
  1398    *p << "dma_start " << *getSrcMemRef() << '[';
  1399    p->printOperands(getSrcIndices());
  1400    *p << "], " << *getDstMemRef() << '[';
  1401    p->printOperands(getDstIndices());
  1402    *p << "], " << *getNumElements();
  1403    *p << ", " << *getTagMemRef() << '[';
  1404    p->printOperands(getTagIndices());
  1405    *p << ']';
  1406    if (isStrided()) {
  1407      *p << ", " << *getStride();
  1408      *p << ", " << *getNumElementsPerStride();
  1409    }
  1410    p->printOptionalAttrDict(getAttrs());
  1411    *p << " : " << getSrcMemRef()->getType();
  1412    *p << ", " << getDstMemRef()->getType();
  1413    *p << ", " << getTagMemRef()->getType();
  1414  }
  1415  
  1416  // Parse DmaStartOp.
  1417  // Ex:
  1418  //   %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size,
  1419  //                       %tag[%index], %stride, %num_elt_per_stride :
  1420  //                     : memref<3076 x f32, 0>,
  1421  //                       memref<1024 x f32, 2>,
  1422  //                       memref<1 x i32>
  1423  //
  1424  ParseResult DmaStartOp::parse(OpAsmParser *parser, OperationState *result) {
  1425    OpAsmParser::OperandType srcMemRefInfo;
  1426    SmallVector<OpAsmParser::OperandType, 4> srcIndexInfos;
  1427    OpAsmParser::OperandType dstMemRefInfo;
  1428    SmallVector<OpAsmParser::OperandType, 4> dstIndexInfos;
  1429    OpAsmParser::OperandType numElementsInfo;
  1430    OpAsmParser::OperandType tagMemrefInfo;
  1431    SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos;
  1432    SmallVector<OpAsmParser::OperandType, 2> strideInfo;
  1433  
  1434    SmallVector<Type, 3> types;
  1435    auto indexType = parser->getBuilder().getIndexType();
  1436  
  1437    // Parse and resolve the following list of operands:
  1438    // *) source memref followed by its indices (in square brackets).
  1439    // *) destination memref followed by its indices (in square brackets).
  1440    // *) dma size in KiB.
  1441    if (parser->parseOperand(srcMemRefInfo) ||
  1442        parser->parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) ||
  1443        parser->parseComma() || parser->parseOperand(dstMemRefInfo) ||
  1444        parser->parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) ||
  1445        parser->parseComma() || parser->parseOperand(numElementsInfo) ||
  1446        parser->parseComma() || parser->parseOperand(tagMemrefInfo) ||
  1447        parser->parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square))
  1448      return failure();
  1449  
  1450    // Parse optional stride and elements per stride.
  1451    if (parser->parseTrailingOperandList(strideInfo))
  1452      return failure();
  1453  
  1454    bool isStrided = strideInfo.size() == 2;
  1455    if (!strideInfo.empty() && !isStrided) {
  1456      return parser->emitError(parser->getNameLoc(),
  1457                               "expected two stride related operands");
  1458    }
  1459  
  1460    if (parser->parseColonTypeList(types))
  1461      return failure();
  1462    if (types.size() != 3)
  1463      return parser->emitError(parser->getNameLoc(), "fewer/more types expected");
  1464  
  1465    if (parser->resolveOperand(srcMemRefInfo, types[0], result->operands) ||
  1466        parser->resolveOperands(srcIndexInfos, indexType, result->operands) ||
  1467        parser->resolveOperand(dstMemRefInfo, types[1], result->operands) ||
  1468        parser->resolveOperands(dstIndexInfos, indexType, result->operands) ||
  1469        // size should be an index.
  1470        parser->resolveOperand(numElementsInfo, indexType, result->operands) ||
  1471        parser->resolveOperand(tagMemrefInfo, types[2], result->operands) ||
  1472        // tag indices should be index.
  1473        parser->resolveOperands(tagIndexInfos, indexType, result->operands))
  1474      return failure();
  1475  
  1476    auto memrefType0 = types[0].dyn_cast<MemRefType>();
  1477    if (!memrefType0)
  1478      return parser->emitError(parser->getNameLoc(),
  1479                               "expected source to be of memref type");
  1480  
  1481    auto memrefType1 = types[1].dyn_cast<MemRefType>();
  1482    if (!memrefType1)
  1483      return parser->emitError(parser->getNameLoc(),
  1484                               "expected destination to be of memref type");
  1485  
  1486    auto memrefType2 = types[2].dyn_cast<MemRefType>();
  1487    if (!memrefType2)
  1488      return parser->emitError(parser->getNameLoc(),
  1489                               "expected tag to be of memref type");
  1490  
  1491    if (isStrided) {
  1492      if (parser->resolveOperands(strideInfo, indexType, result->operands))
  1493        return failure();
  1494    }
  1495  
  1496    // Check that source/destination index list size matches associated rank.
  1497    if (static_cast<int64_t>(srcIndexInfos.size()) != memrefType0.getRank() ||
  1498        static_cast<int64_t>(dstIndexInfos.size()) != memrefType1.getRank())
  1499      return parser->emitError(parser->getNameLoc(),
  1500                               "memref rank not equal to indices count");
  1501    if (static_cast<int64_t>(tagIndexInfos.size()) != memrefType2.getRank())
  1502      return parser->emitError(parser->getNameLoc(),
  1503                               "tag memref rank not equal to indices count");
  1504  
  1505    return success();
  1506  }
  1507  
  1508  LogicalResult DmaStartOp::verify() {
  1509    // DMAs from different memory spaces supported.
  1510    if (getSrcMemorySpace() == getDstMemorySpace())
  1511      return emitOpError("DMA should be between different memory spaces");
  1512  
  1513    if (getNumOperands() != getTagMemRefRank() + getSrcMemRefRank() +
  1514                                getDstMemRefRank() + 3 + 1 &&
  1515        getNumOperands() != getTagMemRefRank() + getSrcMemRefRank() +
  1516                                getDstMemRefRank() + 3 + 1 + 2) {
  1517      return emitOpError("incorrect number of operands");
  1518    }
  1519    return success();
  1520  }
  1521  
  1522  void DmaStartOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
  1523                                               MLIRContext *context) {
  1524    /// dma_start(memrefcast) -> dma_start
  1525    results.insert<MemRefCastFolder>(getOperationName(), context);
  1526  }
  1527  
  1528  // ---------------------------------------------------------------------------
  1529  // DmaWaitOp
  1530  // ---------------------------------------------------------------------------
  1531  
  1532  void DmaWaitOp::build(Builder *builder, OperationState *result,
  1533                        Value *tagMemRef, ArrayRef<Value *> tagIndices,
  1534                        Value *numElements) {
  1535    result->addOperands(tagMemRef);
  1536    result->addOperands(tagIndices);
  1537    result->addOperands(numElements);
  1538  }
  1539  
  1540  void DmaWaitOp::print(OpAsmPrinter *p) {
  1541    *p << "dma_wait ";
  1542    p->printOperand(getTagMemRef());
  1543    *p << '[';
  1544    p->printOperands(getTagIndices());
  1545    *p << "], ";
  1546    p->printOperand(getNumElements());
  1547    p->printOptionalAttrDict(getAttrs());
  1548    *p << " : " << getTagMemRef()->getType();
  1549  }
  1550  
  1551  // Parse DmaWaitOp.
  1552  // Eg:
  1553  //   dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 4>
  1554  //
  1555  ParseResult DmaWaitOp::parse(OpAsmParser *parser, OperationState *result) {
  1556    OpAsmParser::OperandType tagMemrefInfo;
  1557    SmallVector<OpAsmParser::OperandType, 2> tagIndexInfos;
  1558    Type type;
  1559    auto indexType = parser->getBuilder().getIndexType();
  1560    OpAsmParser::OperandType numElementsInfo;
  1561  
  1562    // Parse tag memref, its indices, and dma size.
  1563    if (parser->parseOperand(tagMemrefInfo) ||
  1564        parser->parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square) ||
  1565        parser->parseComma() || parser->parseOperand(numElementsInfo) ||
  1566        parser->parseColonType(type) ||
  1567        parser->resolveOperand(tagMemrefInfo, type, result->operands) ||
  1568        parser->resolveOperands(tagIndexInfos, indexType, result->operands) ||
  1569        parser->resolveOperand(numElementsInfo, indexType, result->operands))
  1570      return failure();
  1571  
  1572    auto memrefType = type.dyn_cast<MemRefType>();
  1573    if (!memrefType)
  1574      return parser->emitError(parser->getNameLoc(),
  1575                               "expected tag to be of memref type");
  1576  
  1577    if (static_cast<int64_t>(tagIndexInfos.size()) != memrefType.getRank())
  1578      return parser->emitError(parser->getNameLoc(),
  1579                               "tag memref rank not equal to indices count");
  1580  
  1581    return success();
  1582  }
  1583  
  1584  void DmaWaitOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
  1585                                              MLIRContext *context) {
  1586    /// dma_wait(memrefcast) -> dma_wait
  1587    results.insert<MemRefCastFolder>(getOperationName(), context);
  1588  }
  1589  
  1590  //===----------------------------------------------------------------------===//
  1591  // ExtractElementOp
  1592  //===----------------------------------------------------------------------===//
  1593  
  1594  static void print(OpAsmPrinter *p, ExtractElementOp op) {
  1595    *p << "extract_element " << *op.getAggregate() << '[';
  1596    p->printOperands(op.getIndices());
  1597    *p << ']';
  1598    p->printOptionalAttrDict(op.getAttrs());
  1599    *p << " : " << op.getAggregate()->getType();
  1600  }
  1601  
  1602  static ParseResult parseExtractElementOp(OpAsmParser *parser,
  1603                                           OperationState *result) {
  1604    OpAsmParser::OperandType aggregateInfo;
  1605    SmallVector<OpAsmParser::OperandType, 4> indexInfo;
  1606    ShapedType type;
  1607  
  1608    auto affineIntTy = parser->getBuilder().getIndexType();
  1609    return failure(
  1610        parser->parseOperand(aggregateInfo) ||
  1611        parser->parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
  1612        parser->parseOptionalAttributeDict(result->attributes) ||
  1613        parser->parseColonType(type) ||
  1614        parser->resolveOperand(aggregateInfo, type, result->operands) ||
  1615        parser->resolveOperands(indexInfo, affineIntTy, result->operands) ||
  1616        parser->addTypeToList(type.getElementType(), result->types));
  1617  }
  1618  
  1619  static LogicalResult verify(ExtractElementOp op) {
  1620    auto aggregateType = op.getAggregate()->getType().cast<ShapedType>();
  1621  
  1622    // This should be possible with tablegen type constraints
  1623    if (op.getType() != aggregateType.getElementType())
  1624      return op.emitOpError("result type must match element type of aggregate");
  1625  
  1626    // Verify the # indices match if we have a ranked type.
  1627    if (aggregateType.hasRank() &&
  1628        aggregateType.getRank() != op.getNumOperands() - 1)
  1629      return op.emitOpError("incorrect number of indices for extract_element");
  1630  
  1631    return success();
  1632  }
  1633  
  1634  OpFoldResult ExtractElementOp::fold(ArrayRef<Attribute> operands) {
  1635    assert(!operands.empty() && "extract_element takes atleast one operand");
  1636  
  1637    // The aggregate operand must be a known constant.
  1638    Attribute aggregate = operands.front();
  1639    if (!aggregate)
  1640      return {};
  1641  
  1642    // If this is a splat elements attribute, simply return the value. All of the
  1643    // elements of a splat attribute are the same.
  1644    if (auto splatAggregate = aggregate.dyn_cast<SplatElementsAttr>())
  1645      return splatAggregate.getSplatValue();
  1646  
  1647    // Otherwise, collect the constant indices into the aggregate.
  1648    SmallVector<uint64_t, 8> indices;
  1649    for (Attribute indice : llvm::drop_begin(operands, 1)) {
  1650      if (!indice || !indice.isa<IntegerAttr>())
  1651        return {};
  1652      indices.push_back(indice.cast<IntegerAttr>().getInt());
  1653    }
  1654  
  1655    // If this is an elements attribute, query the value at the given indices.
  1656    auto elementsAttr = aggregate.dyn_cast<ElementsAttr>();
  1657    if (elementsAttr && elementsAttr.isValidIndex(indices))
  1658      return elementsAttr.getValue(indices);
  1659    return {};
  1660  }
  1661  
  1662  //===----------------------------------------------------------------------===//
  1663  // IndexCastOp
  1664  //===----------------------------------------------------------------------===//
  1665  
  1666  // Index cast is applicable from index to integer and backwards.
  1667  bool IndexCastOp::areCastCompatible(Type a, Type b) {
  1668    return (a.isIndex() && b.isa<IntegerType>()) ||
  1669           (a.isa<IntegerType>() && b.isIndex());
  1670  }
  1671  
  1672  //===----------------------------------------------------------------------===//
  1673  // LoadOp
  1674  //===----------------------------------------------------------------------===//
  1675  
  1676  static void print(OpAsmPrinter *p, LoadOp op) {
  1677    *p << "load " << *op.getMemRef() << '[';
  1678    p->printOperands(op.getIndices());
  1679    *p << ']';
  1680    p->printOptionalAttrDict(op.getAttrs());
  1681    *p << " : " << op.getMemRefType();
  1682  }
  1683  
  1684  static ParseResult parseLoadOp(OpAsmParser *parser, OperationState *result) {
  1685    OpAsmParser::OperandType memrefInfo;
  1686    SmallVector<OpAsmParser::OperandType, 4> indexInfo;
  1687    MemRefType type;
  1688  
  1689    auto affineIntTy = parser->getBuilder().getIndexType();
  1690    return failure(
  1691        parser->parseOperand(memrefInfo) ||
  1692        parser->parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
  1693        parser->parseOptionalAttributeDict(result->attributes) ||
  1694        parser->parseColonType(type) ||
  1695        parser->resolveOperand(memrefInfo, type, result->operands) ||
  1696        parser->resolveOperands(indexInfo, affineIntTy, result->operands) ||
  1697        parser->addTypeToList(type.getElementType(), result->types));
  1698  }
  1699  
  1700  static LogicalResult verify(LoadOp op) {
  1701    if (op.getType() != op.getMemRefType().getElementType())
  1702      return op.emitOpError("result type must match element type of memref");
  1703  
  1704    if (op.getMemRefType().getRank() != op.getNumOperands() - 1)
  1705      return op.emitOpError("incorrect number of indices for load");
  1706  
  1707    for (auto *idx : op.getIndices())
  1708      if (!idx->getType().isIndex())
  1709        return op.emitOpError("index to load must have 'index' type");
  1710  
  1711    // TODO: Verify we have the right number of indices.
  1712  
  1713    // TODO: in Function verify that the indices are parameters, IV's, or the
  1714    // result of an affine.apply.
  1715    return success();
  1716  }
  1717  
  1718  void LoadOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
  1719                                           MLIRContext *context) {
  1720    /// load(memrefcast) -> load
  1721    results.insert<MemRefCastFolder>(getOperationName(), context);
  1722  }
  1723  
  1724  //===----------------------------------------------------------------------===//
  1725  // MemRefCastOp
  1726  //===----------------------------------------------------------------------===//
  1727  
  1728  bool MemRefCastOp::areCastCompatible(Type a, Type b) {
  1729    auto aT = a.dyn_cast<MemRefType>();
  1730    auto bT = b.dyn_cast<MemRefType>();
  1731  
  1732    if (!aT || !bT)
  1733      return false;
  1734    if (aT.getElementType() != bT.getElementType())
  1735      return false;
  1736    if (aT.getAffineMaps() != bT.getAffineMaps())
  1737      return false;
  1738    if (aT.getMemorySpace() != bT.getMemorySpace())
  1739      return false;
  1740  
  1741    // They must have the same rank, and any specified dimensions must match.
  1742    if (aT.getRank() != bT.getRank())
  1743      return false;
  1744  
  1745    for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
  1746      int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
  1747      if (aDim != -1 && bDim != -1 && aDim != bDim)
  1748        return false;
  1749    }
  1750  
  1751    return true;
  1752  }
  1753  
  1754  OpFoldResult MemRefCastOp::fold(ArrayRef<Attribute> operands) {
  1755    return impl::foldCastOp(*this);
  1756  }
  1757  
  1758  //===----------------------------------------------------------------------===//
  1759  // MulFOp
  1760  //===----------------------------------------------------------------------===//
  1761  
  1762  OpFoldResult MulFOp::fold(ArrayRef<Attribute> operands) {
  1763    return constFoldBinaryOp<FloatAttr>(
  1764        operands, [](APFloat a, APFloat b) { return a * b; });
  1765  }
  1766  
  1767  //===----------------------------------------------------------------------===//
  1768  // MulIOp
  1769  //===----------------------------------------------------------------------===//
  1770  
  1771  OpFoldResult MulIOp::fold(ArrayRef<Attribute> operands) {
  1772    /// muli(x, 0) -> 0
  1773    if (matchPattern(rhs(), m_Zero()))
  1774      return rhs();
  1775    /// muli(x, 1) -> x
  1776    if (matchPattern(rhs(), m_One()))
  1777      return getOperand(0);
  1778  
  1779    // TODO: Handle the overflow case.
  1780    return constFoldBinaryOp<IntegerAttr>(operands,
  1781                                          [](APInt a, APInt b) { return a * b; });
  1782  }
  1783  
  1784  //===----------------------------------------------------------------------===//
  1785  // RankOp
  1786  //===----------------------------------------------------------------------===//
  1787  
  1788  static void print(OpAsmPrinter *p, RankOp op) {
  1789    *p << "rank " << *op.getOperand() << " : " << op.getOperand()->getType();
  1790  }
  1791  
  1792  static ParseResult parseRankOp(OpAsmParser *parser, OperationState *result) {
  1793    OpAsmParser::OperandType operandInfo;
  1794    Type type;
  1795    Type indexType = parser->getBuilder().getIndexType();
  1796    return failure(parser->parseOperand(operandInfo) ||
  1797                   parser->parseColonType(type) ||
  1798                   parser->resolveOperand(operandInfo, type, result->operands) ||
  1799                   parser->addTypeToList(indexType, result->types));
  1800  }
  1801  
  1802  OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
  1803    // Constant fold rank when the rank of the tensor is known.
  1804    auto type = getOperand()->getType();
  1805    if (auto tensorType = type.dyn_cast<RankedTensorType>())
  1806      return IntegerAttr::get(IndexType::get(getContext()), tensorType.getRank());
  1807    return IntegerAttr();
  1808  }
  1809  
  1810  //===----------------------------------------------------------------------===//
  1811  // RemISOp
  1812  //===----------------------------------------------------------------------===//
  1813  
  1814  OpFoldResult RemISOp::fold(ArrayRef<Attribute> operands) {
  1815    assert(operands.size() == 2 && "remis takes two operands");
  1816  
  1817    auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
  1818    if (!rhs)
  1819      return {};
  1820    auto rhsValue = rhs.getValue();
  1821  
  1822    // x % 1 = 0
  1823    if (rhsValue.isOneValue())
  1824      return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0));
  1825  
  1826    // Don't fold if it requires division by zero.
  1827    if (rhsValue.isNullValue())
  1828      return {};
  1829  
  1830    auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
  1831    if (!lhs)
  1832      return {};
  1833    return IntegerAttr::get(lhs.getType(), lhs.getValue().srem(rhsValue));
  1834  }
  1835  
  1836  //===----------------------------------------------------------------------===//
  1837  // RemIUOp
  1838  //===----------------------------------------------------------------------===//
  1839  
  1840  OpFoldResult RemIUOp::fold(ArrayRef<Attribute> operands) {
  1841    assert(operands.size() == 2 && "remiu takes two operands");
  1842  
  1843    auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
  1844    if (!rhs)
  1845      return {};
  1846    auto rhsValue = rhs.getValue();
  1847  
  1848    // x % 1 = 0
  1849    if (rhsValue.isOneValue())
  1850      return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0));
  1851  
  1852    // Don't fold if it requires division by zero.
  1853    if (rhsValue.isNullValue())
  1854      return {};
  1855  
  1856    auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
  1857    if (!lhs)
  1858      return {};
  1859    return IntegerAttr::get(lhs.getType(), lhs.getValue().urem(rhsValue));
  1860  }
  1861  
  1862  //===----------------------------------------------------------------------===//
  1863  // ReturnOp
  1864  //===----------------------------------------------------------------------===//
  1865  
  1866  static ParseResult parseReturnOp(OpAsmParser *parser, OperationState *result) {
  1867    SmallVector<OpAsmParser::OperandType, 2> opInfo;
  1868    SmallVector<Type, 2> types;
  1869    llvm::SMLoc loc = parser->getCurrentLocation();
  1870    return failure(parser->parseOperandList(opInfo) ||
  1871                   (!opInfo.empty() && parser->parseColonTypeList(types)) ||
  1872                   parser->resolveOperands(opInfo, types, loc, result->operands));
  1873  }
  1874  
  1875  static void print(OpAsmPrinter *p, ReturnOp op) {
  1876    *p << "return";
  1877    if (op.getNumOperands() != 0) {
  1878      *p << ' ';
  1879      p->printOperands(op.getOperands());
  1880      *p << " : ";
  1881      interleaveComma(op.getOperandTypes(), *p);
  1882    }
  1883  }
  1884  
  1885  static LogicalResult verify(ReturnOp op) {
  1886    auto function = cast<FuncOp>(op.getParentOp());
  1887  
  1888    // The operand number and types must match the function signature.
  1889    const auto &results = function.getType().getResults();
  1890    if (op.getNumOperands() != results.size())
  1891      return op.emitOpError("has ")
  1892             << op.getNumOperands()
  1893             << " operands, but enclosing function returns " << results.size();
  1894  
  1895    for (unsigned i = 0, e = results.size(); i != e; ++i)
  1896      if (op.getOperand(i)->getType() != results[i])
  1897        return op.emitError()
  1898               << "type of return operand " << i << " ("
  1899               << op.getOperand(i)->getType()
  1900               << ") doesn't match function result type (" << results[i] << ")";
  1901  
  1902    return success();
  1903  }
  1904  
  1905  //===----------------------------------------------------------------------===//
  1906  // SIToFPOp
  1907  //===----------------------------------------------------------------------===//
  1908  
  1909  // sitofp is applicable from integer types to float types.
  1910  bool SIToFPOp::areCastCompatible(Type a, Type b) {
  1911    return a.isa<IntegerType>() && b.isa<FloatType>();
  1912  }
  1913  
  1914  //===----------------------------------------------------------------------===//
  1915  // SelectOp
  1916  //===----------------------------------------------------------------------===//
  1917  
  1918  static ParseResult parseSelectOp(OpAsmParser *parser, OperationState *result) {
  1919    SmallVector<OpAsmParser::OperandType, 3> ops;
  1920    SmallVector<NamedAttribute, 4> attrs;
  1921    Type type;
  1922    if (parser->parseOperandList(ops, 3) ||
  1923        parser->parseOptionalAttributeDict(result->attributes) ||
  1924        parser->parseColonType(type))
  1925      return failure();
  1926  
  1927    auto i1Type = getCheckedI1SameShape(&parser->getBuilder(), type);
  1928    if (!i1Type)
  1929      return parser->emitError(parser->getNameLoc(),
  1930                               "expected type with valid i1 shape");
  1931  
  1932    SmallVector<Type, 3> types = {i1Type, type, type};
  1933    return failure(parser->resolveOperands(ops, types, parser->getNameLoc(),
  1934                                           result->operands) ||
  1935                   parser->addTypeToList(type, result->types));
  1936  }
  1937  
  1938  static void print(OpAsmPrinter *p, SelectOp op) {
  1939    *p << "select ";
  1940    p->printOperands(op.getOperands());
  1941    *p << " : " << op.getTrueValue()->getType();
  1942    p->printOptionalAttrDict(op.getAttrs());
  1943  }
  1944  
  1945  static LogicalResult verify(SelectOp op) {
  1946    auto trueType = op.getTrueValue()->getType();
  1947    auto falseType = op.getFalseValue()->getType();
  1948  
  1949    if (trueType != falseType)
  1950      return op.emitOpError(
  1951          "requires 'true' and 'false' arguments to be of the same type");
  1952  
  1953    return success();
  1954  }
  1955  
  1956  OpFoldResult SelectOp::fold(ArrayRef<Attribute> operands) {
  1957    auto *condition = getCondition();
  1958  
  1959    // select true, %0, %1 => %0
  1960    if (matchPattern(condition, m_One()))
  1961      return getTrueValue();
  1962  
  1963    // select false, %0, %1 => %1
  1964    if (matchPattern(condition, m_Zero()))
  1965      return getFalseValue();
  1966    return nullptr;
  1967  }
  1968  
  1969  //===----------------------------------------------------------------------===//
  1970  // StoreOp
  1971  //===----------------------------------------------------------------------===//
  1972  
  1973  static void print(OpAsmPrinter *p, StoreOp op) {
  1974    *p << "store " << *op.getValueToStore();
  1975    *p << ", " << *op.getMemRef() << '[';
  1976    p->printOperands(op.getIndices());
  1977    *p << ']';
  1978    p->printOptionalAttrDict(op.getAttrs());
  1979    *p << " : " << op.getMemRefType();
  1980  }
  1981  
  1982  static ParseResult parseStoreOp(OpAsmParser *parser, OperationState *result) {
  1983    OpAsmParser::OperandType storeValueInfo;
  1984    OpAsmParser::OperandType memrefInfo;
  1985    SmallVector<OpAsmParser::OperandType, 4> indexInfo;
  1986    MemRefType memrefType;
  1987  
  1988    auto affineIntTy = parser->getBuilder().getIndexType();
  1989    return failure(
  1990        parser->parseOperand(storeValueInfo) || parser->parseComma() ||
  1991        parser->parseOperand(memrefInfo) ||
  1992        parser->parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
  1993        parser->parseOptionalAttributeDict(result->attributes) ||
  1994        parser->parseColonType(memrefType) ||
  1995        parser->resolveOperand(storeValueInfo, memrefType.getElementType(),
  1996                               result->operands) ||
  1997        parser->resolveOperand(memrefInfo, memrefType, result->operands) ||
  1998        parser->resolveOperands(indexInfo, affineIntTy, result->operands));
  1999  }
  2000  
  2001  static LogicalResult verify(StoreOp op) {
  2002    // First operand must have same type as memref element type.
  2003    if (op.getValueToStore()->getType() != op.getMemRefType().getElementType())
  2004      return op.emitOpError(
  2005          "first operand must have same type memref element type");
  2006  
  2007    if (op.getNumOperands() != 2 + op.getMemRefType().getRank())
  2008      return op.emitOpError("store index operand count not equal to memref rank");
  2009  
  2010    for (auto *idx : op.getIndices())
  2011      if (!idx->getType().isIndex())
  2012        return op.emitOpError("index to load must have 'index' type");
  2013  
  2014    // TODO: Verify we have the right number of indices.
  2015  
  2016    // TODO: in Function verify that the indices are parameters, IV's, or the
  2017    // result of an affine.apply.
  2018    return success();
  2019  }
  2020  
  2021  void StoreOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
  2022                                            MLIRContext *context) {
  2023    /// store(memrefcast) -> store
  2024    results.insert<MemRefCastFolder>(getOperationName(), context);
  2025  }
  2026  
  2027  //===----------------------------------------------------------------------===//
  2028  // SubFOp
  2029  //===----------------------------------------------------------------------===//
  2030  
  2031  OpFoldResult SubFOp::fold(ArrayRef<Attribute> operands) {
  2032    return constFoldBinaryOp<FloatAttr>(
  2033        operands, [](APFloat a, APFloat b) { return a - b; });
  2034  }
  2035  
  2036  //===----------------------------------------------------------------------===//
  2037  // SubIOp
  2038  //===----------------------------------------------------------------------===//
  2039  
  2040  OpFoldResult SubIOp::fold(ArrayRef<Attribute> operands) {
  2041    // subi(x,x) -> 0
  2042    if (getOperand(0) == getOperand(1))
  2043      return Builder(getContext()).getZeroAttr(getType());
  2044  
  2045    return constFoldBinaryOp<IntegerAttr>(operands,
  2046                                          [](APInt a, APInt b) { return a - b; });
  2047  }
  2048  
  2049  //===----------------------------------------------------------------------===//
  2050  // AndOp
  2051  //===----------------------------------------------------------------------===//
  2052  
  2053  OpFoldResult AndOp::fold(ArrayRef<Attribute> operands) {
  2054    /// and(x, 0) -> 0
  2055    if (matchPattern(rhs(), m_Zero()))
  2056      return rhs();
  2057    /// and(x,x) -> x
  2058    if (lhs() == rhs())
  2059      return rhs();
  2060  
  2061    return constFoldBinaryOp<IntegerAttr>(operands,
  2062                                          [](APInt a, APInt b) { return a & b; });
  2063  }
  2064  
  2065  //===----------------------------------------------------------------------===//
  2066  // OrOp
  2067  //===----------------------------------------------------------------------===//
  2068  
  2069  OpFoldResult OrOp::fold(ArrayRef<Attribute> operands) {
  2070    /// or(x, 0) -> x
  2071    if (matchPattern(rhs(), m_Zero()))
  2072      return lhs();
  2073    /// or(x,x) -> x
  2074    if (lhs() == rhs())
  2075      return rhs();
  2076  
  2077    return constFoldBinaryOp<IntegerAttr>(operands,
  2078                                          [](APInt a, APInt b) { return a | b; });
  2079  }
  2080  
  2081  //===----------------------------------------------------------------------===//
  2082  // XOrOp
  2083  //===----------------------------------------------------------------------===//
  2084  
  2085  OpFoldResult XOrOp::fold(ArrayRef<Attribute> operands) {
  2086    /// xor(x, 0) -> x
  2087    if (matchPattern(rhs(), m_Zero()))
  2088      return lhs();
  2089    /// xor(x,x) -> 0
  2090    if (lhs() == rhs())
  2091      return Builder(getContext()).getZeroAttr(getType());
  2092  
  2093    return constFoldBinaryOp<IntegerAttr>(operands,
  2094                                          [](APInt a, APInt b) { return a ^ b; });
  2095  }
  2096  
  2097  //===----------------------------------------------------------------------===//
  2098  // TensorCastOp
  2099  //===----------------------------------------------------------------------===//
  2100  
  2101  bool TensorCastOp::areCastCompatible(Type a, Type b) {
  2102    auto aT = a.dyn_cast<TensorType>();
  2103    auto bT = b.dyn_cast<TensorType>();
  2104    if (!aT || !bT)
  2105      return false;
  2106  
  2107    if (aT.getElementType() != bT.getElementType())
  2108      return false;
  2109  
  2110    // If the either are unranked, then the cast is valid.
  2111    auto aRType = aT.dyn_cast<RankedTensorType>();
  2112    auto bRType = bT.dyn_cast<RankedTensorType>();
  2113    if (!aRType || !bRType)
  2114      return true;
  2115  
  2116    // If they are both ranked, they have to have the same rank, and any specified
  2117    // dimensions must match.
  2118    if (aRType.getRank() != bRType.getRank())
  2119      return false;
  2120  
  2121    for (unsigned i = 0, e = aRType.getRank(); i != e; ++i) {
  2122      int64_t aDim = aRType.getDimSize(i), bDim = bRType.getDimSize(i);
  2123      if (aDim != -1 && bDim != -1 && aDim != bDim)
  2124        return false;
  2125    }
  2126  
  2127    return true;
  2128  }
  2129  
  2130  OpFoldResult TensorCastOp::fold(ArrayRef<Attribute> operands) {
  2131    return impl::foldCastOp(*this);
  2132  }
  2133  
  2134  //===----------------------------------------------------------------------===//
  2135  // Helpers for Tensor[Load|Store]Op
  2136  //===----------------------------------------------------------------------===//
  2137  
  2138  static Type getTensorTypeFromMemRefType(Builder &b, Type type) {
  2139    if (auto memref = type.dyn_cast<MemRefType>())
  2140      return b.getTensorType(memref.getShape(), memref.getElementType());
  2141    return b.getNoneType();
  2142  }
  2143  
  2144  //===----------------------------------------------------------------------===//
  2145  // TensorLoadOp
  2146  //===----------------------------------------------------------------------===//
  2147  
  2148  static void print(OpAsmPrinter *p, TensorLoadOp op) {
  2149    *p << "tensor_load " << *op.getOperand();
  2150    p->printOptionalAttrDict(op.getAttrs());
  2151    *p << " : " << op.getOperand()->getType();
  2152  }
  2153  
  2154  static ParseResult parseTensorLoadOp(OpAsmParser *parser,
  2155                                       OperationState *result) {
  2156    OpAsmParser::OperandType op;
  2157    Type type;
  2158    return failure(parser->parseOperand(op) ||
  2159                   parser->parseOptionalAttributeDict(result->attributes) ||
  2160                   parser->parseColonType(type) ||
  2161                   parser->resolveOperand(op, type, result->operands) ||
  2162                   parser->addTypeToList(
  2163                       getTensorTypeFromMemRefType(parser->getBuilder(), type),
  2164                       result->types));
  2165  }
  2166  
  2167  //===----------------------------------------------------------------------===//
  2168  // TensorStoreOp
  2169  //===----------------------------------------------------------------------===//
  2170  
  2171  static void print(OpAsmPrinter *p, TensorStoreOp op) {
  2172    *p << "tensor_store " << *op.tensor() << ", " << *op.memref();
  2173    p->printOptionalAttrDict(op.getAttrs());
  2174    *p << " : " << op.memref()->getType();
  2175  }
  2176  
  2177  static ParseResult parseTensorStoreOp(OpAsmParser *parser,
  2178                                        OperationState *result) {
  2179    SmallVector<OpAsmParser::OperandType, 2> ops;
  2180    Type type;
  2181    llvm::SMLoc loc = parser->getCurrentLocation();
  2182    return failure(
  2183        parser->parseOperandList(ops, /*requiredOperandCount=*/2) ||
  2184        parser->parseOptionalAttributeDict(result->attributes) ||
  2185        parser->parseColonType(type) ||
  2186        parser->resolveOperands(
  2187            ops, {getTensorTypeFromMemRefType(parser->getBuilder(), type), type},
  2188            loc, result->operands));
  2189  }
  2190  
  2191  //===----------------------------------------------------------------------===//
  2192  // TableGen'd op method definitions
  2193  //===----------------------------------------------------------------------===//
  2194  
  2195  #define GET_OP_CLASSES
  2196  #include "mlir/Dialect/StandardOps/Ops.cpp.inc"