github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Dialect/StandardOps/Ops.cpp (about) 1 //===- Ops.cpp - Standard MLIR Operations ---------------------------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 18 #include "mlir/Dialect/StandardOps/Ops.h" 19 20 #include "mlir/IR/AffineExpr.h" 21 #include "mlir/IR/AffineMap.h" 22 #include "mlir/IR/Builders.h" 23 #include "mlir/IR/Function.h" 24 #include "mlir/IR/Matchers.h" 25 #include "mlir/IR/Module.h" 26 #include "mlir/IR/OpImplementation.h" 27 #include "mlir/IR/PatternMatch.h" 28 #include "mlir/IR/StandardTypes.h" 29 #include "mlir/IR/Value.h" 30 #include "mlir/Support/MathExtras.h" 31 #include "mlir/Support/STLExtras.h" 32 #include "llvm/ADT/StringSwitch.h" 33 #include "llvm/Support/FormatVariadic.h" 34 #include "llvm/Support/raw_ostream.h" 35 using namespace mlir; 36 37 //===----------------------------------------------------------------------===// 38 // StandardOpsDialect Interfaces 39 //===----------------------------------------------------------------------===// 40 namespace { 41 struct StdOpAsmInterface : public OpAsmDialectInterface { 42 using OpAsmDialectInterface::OpAsmDialectInterface; 43 44 /// Get a special name to use when printing the given operation. The desired 45 /// name should be streamed into 'os'. 46 void getOpResultName(Operation *op, raw_ostream &os) const final { 47 if (ConstantOp constant = dyn_cast<ConstantOp>(op)) 48 return getConstantOpResultName(constant, os); 49 } 50 51 /// Get a special name to use when printing the given constant. 52 static void getConstantOpResultName(ConstantOp op, raw_ostream &os) { 53 Type type = op.getType(); 54 Attribute value = op.getValue(); 55 if (auto intCst = value.dyn_cast<IntegerAttr>()) { 56 if (type.isIndex()) { 57 os << 'c' << intCst.getInt(); 58 } else if (type.cast<IntegerType>().isInteger(1)) { 59 // i1 constants get special names. 60 os << (intCst.getInt() ? "true" : "false"); 61 } else { 62 os << 'c' << intCst.getInt() << '_' << type; 63 } 64 } else if (type.isa<FunctionType>()) { 65 os << 'f'; 66 } else { 67 os << "cst"; 68 } 69 } 70 }; 71 } // end anonymous namespace 72 73 //===----------------------------------------------------------------------===// 74 // StandardOpsDialect 75 //===----------------------------------------------------------------------===// 76 77 /// A custom binary operation printer that omits the "std." prefix from the 78 /// operation names. 79 static void printStandardBinaryOp(Operation *op, OpAsmPrinter *p) { 80 assert(op->getNumOperands() == 2 && "binary op should have two operands"); 81 assert(op->getNumResults() == 1 && "binary op should have one result"); 82 83 // If not all the operand and result types are the same, just use the 84 // generic assembly form to avoid omitting information in printing. 85 auto resultType = op->getResult(0)->getType(); 86 if (op->getOperand(0)->getType() != resultType || 87 op->getOperand(1)->getType() != resultType) { 88 p->printGenericOp(op); 89 return; 90 } 91 92 *p << op->getName().getStringRef().drop_front(strlen("std.")) << ' ' 93 << *op->getOperand(0) << ", " << *op->getOperand(1); 94 p->printOptionalAttrDict(op->getAttrs()); 95 96 // Now we can output only one type for all operands and the result. 97 *p << " : " << op->getResult(0)->getType(); 98 } 99 100 /// A custom cast operation printer that omits the "std." prefix from the 101 /// operation names. 102 static void printStandardCastOp(Operation *op, OpAsmPrinter *p) { 103 *p << op->getName().getStringRef().drop_front(strlen("std.")) << ' ' 104 << *op->getOperand(0) << " : " << op->getOperand(0)->getType() << " to " 105 << op->getResult(0)->getType(); 106 } 107 108 /// A custom cast operation verifier. 109 template <typename T> static LogicalResult verifyCastOp(T op) { 110 auto opType = op.getOperand()->getType(); 111 auto resType = op.getType(); 112 if (!T::areCastCompatible(opType, resType)) 113 return op.emitError("operand type ") << opType << " and result type " 114 << resType << " are cast incompatible"; 115 116 return success(); 117 } 118 119 StandardOpsDialect::StandardOpsDialect(MLIRContext *context) 120 : Dialect(getDialectNamespace(), context) { 121 addOperations<DmaStartOp, DmaWaitOp, 122 #define GET_OP_LIST 123 #include "mlir/Dialect/StandardOps/Ops.cpp.inc" 124 >(); 125 addInterfaces<StdOpAsmInterface>(); 126 } 127 128 void mlir::printDimAndSymbolList(Operation::operand_iterator begin, 129 Operation::operand_iterator end, 130 unsigned numDims, OpAsmPrinter *p) { 131 *p << '('; 132 p->printOperands(begin, begin + numDims); 133 *p << ')'; 134 135 if (begin + numDims != end) { 136 *p << '['; 137 p->printOperands(begin + numDims, end); 138 *p << ']'; 139 } 140 } 141 142 // Parses dimension and symbol list, and sets 'numDims' to the number of 143 // dimension operands parsed. 144 // Returns 'false' on success and 'true' on error. 145 ParseResult mlir::parseDimAndSymbolList(OpAsmParser *parser, 146 SmallVector<Value *, 4> &operands, 147 unsigned &numDims) { 148 SmallVector<OpAsmParser::OperandType, 8> opInfos; 149 if (parser->parseOperandList(opInfos, OpAsmParser::Delimiter::Paren)) 150 return failure(); 151 // Store number of dimensions for validation by caller. 152 numDims = opInfos.size(); 153 154 // Parse the optional symbol operands. 155 auto affineIntTy = parser->getBuilder().getIndexType(); 156 if (parser->parseOperandList(opInfos, 157 OpAsmParser::Delimiter::OptionalSquare) || 158 parser->resolveOperands(opInfos, affineIntTy, operands)) 159 return failure(); 160 return success(); 161 } 162 163 /// Matches a ConstantIndexOp. 164 /// TODO: This should probably just be a general matcher that uses m_Constant 165 /// and checks the operation for an index type. 166 static detail::op_matcher<ConstantIndexOp> m_ConstantIndex() { 167 return detail::op_matcher<ConstantIndexOp>(); 168 } 169 170 //===----------------------------------------------------------------------===// 171 // Common canonicalization pattern support logic 172 //===----------------------------------------------------------------------===// 173 174 namespace { 175 /// This is a common class used for patterns of the form 176 /// "someop(memrefcast) -> someop". It folds the source of any memref_cast 177 /// into the root operation directly. 178 struct MemRefCastFolder : public RewritePattern { 179 /// The rootOpName is the name of the root operation to match against. 180 MemRefCastFolder(StringRef rootOpName, MLIRContext *context) 181 : RewritePattern(rootOpName, 1, context) {} 182 183 PatternMatchResult match(Operation *op) const override { 184 for (auto *operand : op->getOperands()) 185 if (matchPattern(operand, m_Op<MemRefCastOp>())) 186 return matchSuccess(); 187 188 return matchFailure(); 189 } 190 191 void rewrite(Operation *op, PatternRewriter &rewriter) const override { 192 for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i) 193 if (auto *memref = op->getOperand(i)->getDefiningOp()) 194 if (auto cast = dyn_cast<MemRefCastOp>(memref)) 195 op->setOperand(i, cast.getOperand()); 196 rewriter.updatedRootInPlace(op); 197 } 198 }; 199 200 /// Performs const folding `calculate` with element-wise behavior on the two 201 /// attributes in `operands` and returns the result if possible. 202 template <class AttrElementT, 203 class ElementValueT = typename AttrElementT::ValueType, 204 class CalculationT = 205 std::function<ElementValueT(ElementValueT, ElementValueT)>> 206 Attribute constFoldBinaryOp(ArrayRef<Attribute> operands, 207 const CalculationT &calculate) { 208 assert(operands.size() == 2 && "binary op takes two operands"); 209 210 if (auto lhs = operands[0].dyn_cast_or_null<AttrElementT>()) { 211 auto rhs = operands[1].dyn_cast_or_null<AttrElementT>(); 212 if (!rhs || lhs.getType() != rhs.getType()) 213 return {}; 214 215 return AttrElementT::get(lhs.getType(), 216 calculate(lhs.getValue(), rhs.getValue())); 217 } else if (auto lhs = operands[0].dyn_cast_or_null<SplatElementsAttr>()) { 218 auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>(); 219 if (!rhs || lhs.getType() != rhs.getType()) 220 return {}; 221 222 auto elementResult = constFoldBinaryOp<AttrElementT>( 223 {lhs.getSplatValue(), rhs.getSplatValue()}, calculate); 224 if (!elementResult) 225 return {}; 226 227 return DenseElementsAttr::get(lhs.getType(), elementResult); 228 } 229 return {}; 230 } 231 } // end anonymous namespace. 232 233 //===----------------------------------------------------------------------===// 234 // AddFOp 235 //===----------------------------------------------------------------------===// 236 237 OpFoldResult AddFOp::fold(ArrayRef<Attribute> operands) { 238 return constFoldBinaryOp<FloatAttr>( 239 operands, [](APFloat a, APFloat b) { return a + b; }); 240 } 241 242 //===----------------------------------------------------------------------===// 243 // AddIOp 244 //===----------------------------------------------------------------------===// 245 246 OpFoldResult AddIOp::fold(ArrayRef<Attribute> operands) { 247 /// addi(x, 0) -> x 248 if (matchPattern(rhs(), m_Zero())) 249 return lhs(); 250 251 return constFoldBinaryOp<IntegerAttr>(operands, 252 [](APInt a, APInt b) { return a + b; }); 253 } 254 255 //===----------------------------------------------------------------------===// 256 // AllocOp 257 //===----------------------------------------------------------------------===// 258 259 static void print(OpAsmPrinter *p, AllocOp op) { 260 *p << "alloc"; 261 262 // Print dynamic dimension operands. 263 MemRefType type = op.getType(); 264 printDimAndSymbolList(op.operand_begin(), op.operand_end(), 265 type.getNumDynamicDims(), p); 266 p->printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"map"}); 267 *p << " : " << type; 268 } 269 270 static ParseResult parseAllocOp(OpAsmParser *parser, OperationState *result) { 271 MemRefType type; 272 273 // Parse the dimension operands and optional symbol operands, followed by a 274 // memref type. 275 unsigned numDimOperands; 276 if (parseDimAndSymbolList(parser, result->operands, numDimOperands) || 277 parser->parseOptionalAttributeDict(result->attributes) || 278 parser->parseColonType(type)) 279 return failure(); 280 281 // Check numDynamicDims against number of question marks in memref type. 282 // Note: this check remains here (instead of in verify()), because the 283 // partition between dim operands and symbol operands is lost after parsing. 284 // Verification still checks that the total number of operands matches 285 // the number of symbols in the affine map, plus the number of dynamic 286 // dimensions in the memref. 287 if (numDimOperands != type.getNumDynamicDims()) 288 return parser->emitError(parser->getNameLoc()) 289 << "dimension operand count does not equal memref dynamic dimension " 290 "count"; 291 result->types.push_back(type); 292 return success(); 293 } 294 295 static LogicalResult verify(AllocOp op) { 296 auto memRefType = op.getResult()->getType().dyn_cast<MemRefType>(); 297 if (!memRefType) 298 return op.emitOpError("result must be a memref"); 299 300 unsigned numSymbols = 0; 301 if (!memRefType.getAffineMaps().empty()) { 302 // Store number of symbols used in affine map (used in subsequent check). 303 AffineMap affineMap = memRefType.getAffineMaps()[0]; 304 numSymbols = affineMap.getNumSymbols(); 305 } 306 307 // Check that the total number of operands matches the number of symbols in 308 // the affine map, plus the number of dynamic dimensions specified in the 309 // memref type. 310 unsigned numDynamicDims = memRefType.getNumDynamicDims(); 311 if (op.getOperation()->getNumOperands() != numDynamicDims + numSymbols) 312 return op.emitOpError( 313 "operand count does not equal dimension plus symbol operand count"); 314 315 // Verify that all operands are of type Index. 316 for (auto operandType : op.getOperandTypes()) 317 if (!operandType.isIndex()) 318 return op.emitOpError("requires operands to be of type Index"); 319 return success(); 320 } 321 322 namespace { 323 /// Fold constant dimensions into an alloc operation. 324 struct SimplifyAllocConst : public OpRewritePattern<AllocOp> { 325 using OpRewritePattern<AllocOp>::OpRewritePattern; 326 327 PatternMatchResult matchAndRewrite(AllocOp alloc, 328 PatternRewriter &rewriter) const override { 329 // Check to see if any dimensions operands are constants. If so, we can 330 // substitute and drop them. 331 if (llvm::none_of(alloc.getOperands(), [](Value *operand) { 332 return matchPattern(operand, m_ConstantIndex()); 333 })) 334 return matchFailure(); 335 336 auto memrefType = alloc.getType(); 337 338 // Ok, we have one or more constant operands. Collect the non-constant ones 339 // and keep track of the resultant memref type to build. 340 SmallVector<int64_t, 4> newShapeConstants; 341 newShapeConstants.reserve(memrefType.getRank()); 342 SmallVector<Value *, 4> newOperands; 343 SmallVector<Value *, 4> droppedOperands; 344 345 unsigned dynamicDimPos = 0; 346 for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) { 347 int64_t dimSize = memrefType.getDimSize(dim); 348 // If this is already static dimension, keep it. 349 if (dimSize != -1) { 350 newShapeConstants.push_back(dimSize); 351 continue; 352 } 353 auto *defOp = alloc.getOperand(dynamicDimPos)->getDefiningOp(); 354 if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) { 355 // Dynamic shape dimension will be folded. 356 newShapeConstants.push_back(constantIndexOp.getValue()); 357 // Record to check for zero uses later below. 358 droppedOperands.push_back(constantIndexOp); 359 } else { 360 // Dynamic shape dimension not folded; copy operand from old memref. 361 newShapeConstants.push_back(-1); 362 newOperands.push_back(alloc.getOperand(dynamicDimPos)); 363 } 364 dynamicDimPos++; 365 } 366 367 // Create new memref type (which will have fewer dynamic dimensions). 368 auto newMemRefType = MemRefType::get( 369 newShapeConstants, memrefType.getElementType(), 370 memrefType.getAffineMaps(), memrefType.getMemorySpace()); 371 assert(static_cast<int64_t>(newOperands.size()) == 372 newMemRefType.getNumDynamicDims()); 373 374 // Create and insert the alloc op for the new memref. 375 auto newAlloc = 376 rewriter.create<AllocOp>(alloc.getLoc(), newMemRefType, newOperands); 377 // Insert a cast so we have the same type as the old alloc. 378 auto resultCast = rewriter.create<MemRefCastOp>(alloc.getLoc(), newAlloc, 379 alloc.getType()); 380 381 rewriter.replaceOp(alloc, {resultCast}, droppedOperands); 382 return matchSuccess(); 383 } 384 }; 385 386 /// Fold alloc operations with no uses. Alloc has side effects on the heap, 387 /// but can still be deleted if it has zero uses. 388 struct SimplifyDeadAlloc : public OpRewritePattern<AllocOp> { 389 using OpRewritePattern<AllocOp>::OpRewritePattern; 390 391 PatternMatchResult matchAndRewrite(AllocOp alloc, 392 PatternRewriter &rewriter) const override { 393 // Check if the alloc'ed value has any uses. 394 if (!alloc.use_empty()) 395 return matchFailure(); 396 397 // If it doesn't, we can eliminate it. 398 alloc.erase(); 399 return matchSuccess(); 400 } 401 }; 402 } // end anonymous namespace. 403 404 void AllocOp::getCanonicalizationPatterns(OwningRewritePatternList &results, 405 MLIRContext *context) { 406 results.insert<SimplifyAllocConst, SimplifyDeadAlloc>(context); 407 } 408 409 //===----------------------------------------------------------------------===// 410 // BranchOp 411 //===----------------------------------------------------------------------===// 412 413 static ParseResult parseBranchOp(OpAsmParser *parser, OperationState *result) { 414 Block *dest; 415 SmallVector<Value *, 4> destOperands; 416 if (parser->parseSuccessorAndUseList(dest, destOperands)) 417 return failure(); 418 result->addSuccessor(dest, destOperands); 419 return success(); 420 } 421 422 static void print(OpAsmPrinter *p, BranchOp op) { 423 *p << "br "; 424 p->printSuccessorAndUseList(op.getOperation(), 0); 425 } 426 427 Block *BranchOp::getDest() { return getOperation()->getSuccessor(0); } 428 429 void BranchOp::setDest(Block *block) { 430 return getOperation()->setSuccessor(block, 0); 431 } 432 433 void BranchOp::eraseOperand(unsigned index) { 434 getOperation()->eraseSuccessorOperand(0, index); 435 } 436 437 //===----------------------------------------------------------------------===// 438 // CallOp 439 //===----------------------------------------------------------------------===// 440 441 static ParseResult parseCallOp(OpAsmParser *parser, OperationState *result) { 442 SymbolRefAttr calleeAttr; 443 FunctionType calleeType; 444 SmallVector<OpAsmParser::OperandType, 4> operands; 445 auto calleeLoc = parser->getNameLoc(); 446 if (parser->parseAttribute(calleeAttr, "callee", result->attributes) || 447 parser->parseOperandList(operands, OpAsmParser::Delimiter::Paren) || 448 parser->parseOptionalAttributeDict(result->attributes) || 449 parser->parseColonType(calleeType) || 450 parser->addTypesToList(calleeType.getResults(), result->types) || 451 parser->resolveOperands(operands, calleeType.getInputs(), calleeLoc, 452 result->operands)) 453 return failure(); 454 455 return success(); 456 } 457 458 static void print(OpAsmPrinter *p, CallOp op) { 459 *p << "call " << op.getAttr("callee") << '('; 460 p->printOperands(op.getOperands()); 461 *p << ')'; 462 p->printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"}); 463 *p << " : "; 464 p->printType(op.getCalleeType()); 465 } 466 467 static LogicalResult verify(CallOp op) { 468 // Check that the callee attribute was specified. 469 auto fnAttr = op.getAttrOfType<SymbolRefAttr>("callee"); 470 if (!fnAttr) 471 return op.emitOpError("requires a 'callee' symbol reference attribute"); 472 auto fn = 473 op.getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(fnAttr.getValue()); 474 if (!fn) 475 return op.emitOpError() << "'" << fnAttr.getValue() 476 << "' does not reference a valid function"; 477 478 // Verify that the operand and result types match the callee. 479 auto fnType = fn.getType(); 480 if (fnType.getNumInputs() != op.getNumOperands()) 481 return op.emitOpError("incorrect number of operands for callee"); 482 483 for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) 484 if (op.getOperand(i)->getType() != fnType.getInput(i)) 485 return op.emitOpError("operand type mismatch"); 486 487 if (fnType.getNumResults() != op.getNumResults()) 488 return op.emitOpError("incorrect number of results for callee"); 489 490 for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) 491 if (op.getResult(i)->getType() != fnType.getResult(i)) 492 return op.emitOpError("result type mismatch"); 493 494 return success(); 495 } 496 497 FunctionType CallOp::getCalleeType() { 498 SmallVector<Type, 4> resultTypes(getResultTypes()); 499 SmallVector<Type, 8> argTypes(getOperandTypes()); 500 return FunctionType::get(argTypes, resultTypes, getContext()); 501 } 502 503 //===----------------------------------------------------------------------===// 504 // CallIndirectOp 505 //===----------------------------------------------------------------------===// 506 namespace { 507 /// Fold indirect calls that have a constant function as the callee operand. 508 struct SimplifyIndirectCallWithKnownCallee 509 : public OpRewritePattern<CallIndirectOp> { 510 using OpRewritePattern<CallIndirectOp>::OpRewritePattern; 511 512 PatternMatchResult matchAndRewrite(CallIndirectOp indirectCall, 513 PatternRewriter &rewriter) const override { 514 // Check that the callee is a constant callee. 515 SymbolRefAttr calledFn; 516 if (!matchPattern(indirectCall.getCallee(), m_Constant(&calledFn))) 517 return matchFailure(); 518 519 // Replace with a direct call. 520 SmallVector<Type, 8> callResults(indirectCall.getResultTypes()); 521 SmallVector<Value *, 8> callOperands(indirectCall.getArgOperands()); 522 rewriter.replaceOpWithNewOp<CallOp>(indirectCall, calledFn.getValue(), 523 callResults, callOperands); 524 return matchSuccess(); 525 } 526 }; 527 } // end anonymous namespace. 528 529 static ParseResult parseCallIndirectOp(OpAsmParser *parser, 530 OperationState *result) { 531 FunctionType calleeType; 532 OpAsmParser::OperandType callee; 533 llvm::SMLoc operandsLoc; 534 SmallVector<OpAsmParser::OperandType, 4> operands; 535 return failure( 536 parser->parseOperand(callee) || 537 parser->getCurrentLocation(&operandsLoc) || 538 parser->parseOperandList(operands, OpAsmParser::Delimiter::Paren) || 539 parser->parseOptionalAttributeDict(result->attributes) || 540 parser->parseColonType(calleeType) || 541 parser->resolveOperand(callee, calleeType, result->operands) || 542 parser->resolveOperands(operands, calleeType.getInputs(), operandsLoc, 543 result->operands) || 544 parser->addTypesToList(calleeType.getResults(), result->types)); 545 } 546 547 static void print(OpAsmPrinter *p, CallIndirectOp op) { 548 *p << "call_indirect "; 549 p->printOperand(op.getCallee()); 550 *p << '('; 551 p->printOperands(op.getArgOperands()); 552 *p << ')'; 553 p->printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"}); 554 *p << " : " << op.getCallee()->getType(); 555 } 556 557 static LogicalResult verify(CallIndirectOp op) { 558 // The callee must be a function. 559 auto fnType = op.getCallee()->getType().dyn_cast<FunctionType>(); 560 if (!fnType) 561 return op.emitOpError("callee must have function type"); 562 563 // Verify that the operand and result types match the callee. 564 if (fnType.getNumInputs() != op.getNumOperands() - 1) 565 return op.emitOpError("incorrect number of operands for callee"); 566 567 for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) 568 if (op.getOperand(i + 1)->getType() != fnType.getInput(i)) 569 return op.emitOpError("operand type mismatch"); 570 571 if (fnType.getNumResults() != op.getNumResults()) 572 return op.emitOpError("incorrect number of results for callee"); 573 574 for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) 575 if (op.getResult(i)->getType() != fnType.getResult(i)) 576 return op.emitOpError("result type mismatch"); 577 578 return success(); 579 } 580 581 void CallIndirectOp::getCanonicalizationPatterns( 582 OwningRewritePatternList &results, MLIRContext *context) { 583 results.insert<SimplifyIndirectCallWithKnownCallee>(context); 584 } 585 586 //===----------------------------------------------------------------------===// 587 // General helpers for comparison ops 588 //===----------------------------------------------------------------------===// 589 590 // Return the type of the same shape (scalar, vector or tensor) containing i1. 591 static Type getCheckedI1SameShape(Builder *build, Type type) { 592 auto i1Type = build->getI1Type(); 593 if (type.isIntOrIndexOrFloat()) 594 return i1Type; 595 if (auto tensorType = type.dyn_cast<RankedTensorType>()) 596 return build->getTensorType(tensorType.getShape(), i1Type); 597 if (type.isa<UnrankedTensorType>()) 598 return build->getTensorType(i1Type); 599 if (auto vectorType = type.dyn_cast<VectorType>()) 600 return build->getVectorType(vectorType.getShape(), i1Type); 601 return Type(); 602 } 603 604 static Type getI1SameShape(Builder *build, Type type) { 605 Type res = getCheckedI1SameShape(build, type); 606 assert(res && "expected type with valid i1 shape"); 607 return res; 608 } 609 610 //===----------------------------------------------------------------------===// 611 // CmpIOp 612 //===----------------------------------------------------------------------===// 613 614 // Returns an array of mnemonics for CmpIPredicates indexed by values thereof. 615 static inline const char *const *getCmpIPredicateNames() { 616 static const char *predicateNames[]{ 617 /*EQ*/ "eq", 618 /*NE*/ "ne", 619 /*SLT*/ "slt", 620 /*SLE*/ "sle", 621 /*SGT*/ "sgt", 622 /*SGE*/ "sge", 623 /*ULT*/ "ult", 624 /*ULE*/ "ule", 625 /*UGT*/ "ugt", 626 /*UGE*/ "uge", 627 }; 628 static_assert(std::extent<decltype(predicateNames)>::value == 629 (size_t)CmpIPredicate::NumPredicates, 630 "wrong number of predicate names"); 631 return predicateNames; 632 } 633 634 // Returns a value of the predicate corresponding to the given mnemonic. 635 // Returns NumPredicates (one-past-end) if there is no such mnemonic. 636 CmpIPredicate CmpIOp::getPredicateByName(StringRef name) { 637 return llvm::StringSwitch<CmpIPredicate>(name) 638 .Case("eq", CmpIPredicate::EQ) 639 .Case("ne", CmpIPredicate::NE) 640 .Case("slt", CmpIPredicate::SLT) 641 .Case("sle", CmpIPredicate::SLE) 642 .Case("sgt", CmpIPredicate::SGT) 643 .Case("sge", CmpIPredicate::SGE) 644 .Case("ult", CmpIPredicate::ULT) 645 .Case("ule", CmpIPredicate::ULE) 646 .Case("ugt", CmpIPredicate::UGT) 647 .Case("uge", CmpIPredicate::UGE) 648 .Default(CmpIPredicate::NumPredicates); 649 } 650 651 static void buildCmpIOp(Builder *build, OperationState *result, 652 CmpIPredicate predicate, Value *lhs, Value *rhs) { 653 result->addOperands({lhs, rhs}); 654 result->types.push_back(getI1SameShape(build, lhs->getType())); 655 result->addAttribute( 656 CmpIOp::getPredicateAttrName(), 657 build->getI64IntegerAttr(static_cast<int64_t>(predicate))); 658 } 659 660 static ParseResult parseCmpIOp(OpAsmParser *parser, OperationState *result) { 661 SmallVector<OpAsmParser::OperandType, 2> ops; 662 SmallVector<NamedAttribute, 4> attrs; 663 Attribute predicateNameAttr; 664 Type type; 665 if (parser->parseAttribute(predicateNameAttr, CmpIOp::getPredicateAttrName(), 666 attrs) || 667 parser->parseComma() || parser->parseOperandList(ops, 2) || 668 parser->parseOptionalAttributeDict(attrs) || 669 parser->parseColonType(type) || 670 parser->resolveOperands(ops, type, result->operands)) 671 return failure(); 672 673 if (!predicateNameAttr.isa<StringAttr>()) 674 return parser->emitError(parser->getNameLoc(), 675 "expected string comparison predicate attribute"); 676 677 // Rewrite string attribute to an enum value. 678 StringRef predicateName = predicateNameAttr.cast<StringAttr>().getValue(); 679 auto predicate = CmpIOp::getPredicateByName(predicateName); 680 if (predicate == CmpIPredicate::NumPredicates) 681 return parser->emitError(parser->getNameLoc()) 682 << "unknown comparison predicate \"" << predicateName << "\""; 683 684 auto builder = parser->getBuilder(); 685 Type i1Type = getCheckedI1SameShape(&builder, type); 686 if (!i1Type) 687 return parser->emitError(parser->getNameLoc(), 688 "expected type with valid i1 shape"); 689 690 attrs[0].second = builder.getI64IntegerAttr(static_cast<int64_t>(predicate)); 691 result->attributes = attrs; 692 693 result->addTypes({i1Type}); 694 return success(); 695 } 696 697 static void print(OpAsmPrinter *p, CmpIOp op) { 698 *p << "cmpi "; 699 700 auto predicateValue = 701 op.getAttrOfType<IntegerAttr>(CmpIOp::getPredicateAttrName()).getInt(); 702 assert(predicateValue >= static_cast<int>(CmpIPredicate::FirstValidValue) && 703 predicateValue < static_cast<int>(CmpIPredicate::NumPredicates) && 704 "unknown predicate index"); 705 Builder b(op.getContext()); 706 auto predicateStringAttr = 707 b.getStringAttr(getCmpIPredicateNames()[predicateValue]); 708 p->printAttribute(predicateStringAttr); 709 710 *p << ", "; 711 p->printOperand(op.lhs()); 712 *p << ", "; 713 p->printOperand(op.rhs()); 714 p->printOptionalAttrDict(op.getAttrs(), 715 /*elidedAttrs=*/{CmpIOp::getPredicateAttrName()}); 716 *p << " : " << op.lhs()->getType(); 717 } 718 719 static LogicalResult verify(CmpIOp op) { 720 auto predicateAttr = 721 op.getAttrOfType<IntegerAttr>(CmpIOp::getPredicateAttrName()); 722 if (!predicateAttr) 723 return op.emitOpError("requires an integer attribute named 'predicate'"); 724 auto predicate = predicateAttr.getInt(); 725 if (predicate < (int64_t)CmpIPredicate::FirstValidValue || 726 predicate >= (int64_t)CmpIPredicate::NumPredicates) 727 return op.emitOpError("'predicate' attribute value out of range"); 728 729 return success(); 730 } 731 732 // Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer 733 // comparison predicates. 734 static bool applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs, 735 const APInt &rhs) { 736 switch (predicate) { 737 case CmpIPredicate::EQ: 738 return lhs.eq(rhs); 739 case CmpIPredicate::NE: 740 return lhs.ne(rhs); 741 case CmpIPredicate::SLT: 742 return lhs.slt(rhs); 743 case CmpIPredicate::SLE: 744 return lhs.sle(rhs); 745 case CmpIPredicate::SGT: 746 return lhs.sgt(rhs); 747 case CmpIPredicate::SGE: 748 return lhs.sge(rhs); 749 case CmpIPredicate::ULT: 750 return lhs.ult(rhs); 751 case CmpIPredicate::ULE: 752 return lhs.ule(rhs); 753 case CmpIPredicate::UGT: 754 return lhs.ugt(rhs); 755 case CmpIPredicate::UGE: 756 return lhs.uge(rhs); 757 default: 758 llvm_unreachable("unknown comparison predicate"); 759 } 760 } 761 762 // Constant folding hook for comparisons. 763 OpFoldResult CmpIOp::fold(ArrayRef<Attribute> operands) { 764 assert(operands.size() == 2 && "cmpi takes two arguments"); 765 766 auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>(); 767 auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>(); 768 if (!lhs || !rhs) 769 return {}; 770 771 auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue()); 772 return IntegerAttr::get(IntegerType::get(1, getContext()), APInt(1, val)); 773 } 774 775 //===----------------------------------------------------------------------===// 776 // CmpFOp 777 //===----------------------------------------------------------------------===// 778 779 // Returns an array of mnemonics for CmpFPredicates indexed by values thereof. 780 static inline const char *const *getCmpFPredicateNames() { 781 static const char *predicateNames[] = { 782 /*AlwaysFalse*/ "false", 783 /*OEQ*/ "oeq", 784 /*OGT*/ "ogt", 785 /*OGE*/ "oge", 786 /*OLT*/ "olt", 787 /*OLE*/ "ole", 788 /*ONE*/ "one", 789 /*ORD*/ "ord", 790 /*UEQ*/ "ueq", 791 /*UGT*/ "ugt", 792 /*UGE*/ "uge", 793 /*ULT*/ "ult", 794 /*ULE*/ "ule", 795 /*UNE*/ "une", 796 /*UNO*/ "uno", 797 /*AlwaysTrue*/ "true", 798 }; 799 static_assert(std::extent<decltype(predicateNames)>::value == 800 (size_t)CmpFPredicate::NumPredicates, 801 "wrong number of predicate names"); 802 return predicateNames; 803 } 804 805 // Returns a value of the predicate corresponding to the given mnemonic. 806 // Returns NumPredicates (one-past-end) if there is no such mnemonic. 807 CmpFPredicate CmpFOp::getPredicateByName(StringRef name) { 808 return llvm::StringSwitch<CmpFPredicate>(name) 809 .Case("false", CmpFPredicate::AlwaysFalse) 810 .Case("oeq", CmpFPredicate::OEQ) 811 .Case("ogt", CmpFPredicate::OGT) 812 .Case("oge", CmpFPredicate::OGE) 813 .Case("olt", CmpFPredicate::OLT) 814 .Case("ole", CmpFPredicate::OLE) 815 .Case("one", CmpFPredicate::ONE) 816 .Case("ord", CmpFPredicate::ORD) 817 .Case("ueq", CmpFPredicate::UEQ) 818 .Case("ugt", CmpFPredicate::UGT) 819 .Case("uge", CmpFPredicate::UGE) 820 .Case("ult", CmpFPredicate::ULT) 821 .Case("ule", CmpFPredicate::ULE) 822 .Case("une", CmpFPredicate::UNE) 823 .Case("uno", CmpFPredicate::UNO) 824 .Case("true", CmpFPredicate::AlwaysTrue) 825 .Default(CmpFPredicate::NumPredicates); 826 } 827 828 static void buildCmpFOp(Builder *build, OperationState *result, 829 CmpFPredicate predicate, Value *lhs, Value *rhs) { 830 result->addOperands({lhs, rhs}); 831 result->types.push_back(getI1SameShape(build, lhs->getType())); 832 result->addAttribute( 833 CmpFOp::getPredicateAttrName(), 834 build->getI64IntegerAttr(static_cast<int64_t>(predicate))); 835 } 836 837 static ParseResult parseCmpFOp(OpAsmParser *parser, OperationState *result) { 838 SmallVector<OpAsmParser::OperandType, 2> ops; 839 SmallVector<NamedAttribute, 4> attrs; 840 Attribute predicateNameAttr; 841 Type type; 842 if (parser->parseAttribute(predicateNameAttr, CmpFOp::getPredicateAttrName(), 843 attrs) || 844 parser->parseComma() || parser->parseOperandList(ops, 2) || 845 parser->parseOptionalAttributeDict(attrs) || 846 parser->parseColonType(type) || 847 parser->resolveOperands(ops, type, result->operands)) 848 return failure(); 849 850 if (!predicateNameAttr.isa<StringAttr>()) 851 return parser->emitError(parser->getNameLoc(), 852 "expected string comparison predicate attribute"); 853 854 // Rewrite string attribute to an enum value. 855 StringRef predicateName = predicateNameAttr.cast<StringAttr>().getValue(); 856 auto predicate = CmpFOp::getPredicateByName(predicateName); 857 if (predicate == CmpFPredicate::NumPredicates) 858 return parser->emitError(parser->getNameLoc(), 859 "unknown comparison predicate \"" + predicateName + 860 "\""); 861 862 auto builder = parser->getBuilder(); 863 Type i1Type = getCheckedI1SameShape(&builder, type); 864 if (!i1Type) 865 return parser->emitError(parser->getNameLoc(), 866 "expected type with valid i1 shape"); 867 868 attrs[0].second = builder.getI64IntegerAttr(static_cast<int64_t>(predicate)); 869 result->attributes = attrs; 870 871 result->addTypes({i1Type}); 872 return success(); 873 } 874 875 static void print(OpAsmPrinter *p, CmpFOp op) { 876 *p << "cmpf "; 877 878 auto predicateValue = 879 op.getAttrOfType<IntegerAttr>(CmpFOp::getPredicateAttrName()).getInt(); 880 assert(predicateValue >= static_cast<int>(CmpFPredicate::FirstValidValue) && 881 predicateValue < static_cast<int>(CmpFPredicate::NumPredicates) && 882 "unknown predicate index"); 883 Builder b(op.getContext()); 884 auto predicateStringAttr = 885 b.getStringAttr(getCmpFPredicateNames()[predicateValue]); 886 p->printAttribute(predicateStringAttr); 887 888 *p << ", "; 889 p->printOperand(op.lhs()); 890 *p << ", "; 891 p->printOperand(op.rhs()); 892 p->printOptionalAttrDict(op.getAttrs(), 893 /*elidedAttrs=*/{CmpFOp::getPredicateAttrName()}); 894 *p << " : " << op.lhs()->getType(); 895 } 896 897 static LogicalResult verify(CmpFOp op) { 898 auto predicateAttr = 899 op.getAttrOfType<IntegerAttr>(CmpFOp::getPredicateAttrName()); 900 if (!predicateAttr) 901 return op.emitOpError("requires an integer attribute named 'predicate'"); 902 auto predicate = predicateAttr.getInt(); 903 if (predicate < (int64_t)CmpFPredicate::FirstValidValue || 904 predicate >= (int64_t)CmpFPredicate::NumPredicates) 905 return op.emitOpError("'predicate' attribute value out of range"); 906 907 return success(); 908 } 909 910 // Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point 911 // comparison predicates. 912 static bool applyCmpPredicate(CmpFPredicate predicate, const APFloat &lhs, 913 const APFloat &rhs) { 914 auto cmpResult = lhs.compare(rhs); 915 switch (predicate) { 916 case CmpFPredicate::AlwaysFalse: 917 return false; 918 case CmpFPredicate::OEQ: 919 return cmpResult == APFloat::cmpEqual; 920 case CmpFPredicate::OGT: 921 return cmpResult == APFloat::cmpGreaterThan; 922 case CmpFPredicate::OGE: 923 return cmpResult == APFloat::cmpGreaterThan || 924 cmpResult == APFloat::cmpEqual; 925 case CmpFPredicate::OLT: 926 return cmpResult == APFloat::cmpLessThan; 927 case CmpFPredicate::OLE: 928 return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual; 929 case CmpFPredicate::ONE: 930 return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual; 931 case CmpFPredicate::ORD: 932 return cmpResult != APFloat::cmpUnordered; 933 case CmpFPredicate::UEQ: 934 return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual; 935 case CmpFPredicate::UGT: 936 return cmpResult == APFloat::cmpUnordered || 937 cmpResult == APFloat::cmpGreaterThan; 938 case CmpFPredicate::UGE: 939 return cmpResult == APFloat::cmpUnordered || 940 cmpResult == APFloat::cmpGreaterThan || 941 cmpResult == APFloat::cmpEqual; 942 case CmpFPredicate::ULT: 943 return cmpResult == APFloat::cmpUnordered || 944 cmpResult == APFloat::cmpLessThan; 945 case CmpFPredicate::ULE: 946 return cmpResult == APFloat::cmpUnordered || 947 cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual; 948 case CmpFPredicate::UNE: 949 return cmpResult != APFloat::cmpEqual; 950 case CmpFPredicate::UNO: 951 return cmpResult == APFloat::cmpUnordered; 952 case CmpFPredicate::AlwaysTrue: 953 return true; 954 default: 955 llvm_unreachable("unknown comparison predicate"); 956 } 957 } 958 959 // Constant folding hook for comparisons. 960 OpFoldResult CmpFOp::fold(ArrayRef<Attribute> operands) { 961 assert(operands.size() == 2 && "cmpf takes two arguments"); 962 963 auto lhs = operands.front().dyn_cast_or_null<FloatAttr>(); 964 auto rhs = operands.back().dyn_cast_or_null<FloatAttr>(); 965 if (!lhs || !rhs || 966 // TODO(b/122019992) Implement and test constant folding for nan/inf when 967 // it is possible to have constant nan/inf 968 !lhs.getValue().isFinite() || !rhs.getValue().isFinite()) 969 return {}; 970 971 auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue()); 972 return IntegerAttr::get(IntegerType::get(1, getContext()), APInt(1, val)); 973 } 974 975 //===----------------------------------------------------------------------===// 976 // CondBranchOp 977 //===----------------------------------------------------------------------===// 978 979 namespace { 980 /// cond_br true, ^bb1, ^bb2 -> br ^bb1 981 /// cond_br false, ^bb1, ^bb2 -> br ^bb2 982 /// 983 struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> { 984 using OpRewritePattern<CondBranchOp>::OpRewritePattern; 985 986 PatternMatchResult matchAndRewrite(CondBranchOp condbr, 987 PatternRewriter &rewriter) const override { 988 // Check that the condition is a constant. 989 if (!matchPattern(condbr.getCondition(), m_Op<ConstantOp>())) 990 return matchFailure(); 991 992 Block *foldedDest; 993 SmallVector<Value *, 4> branchArgs; 994 995 // If the condition is known to evaluate to false we fold to a branch to the 996 // false destination. Otherwise, we fold to a branch to the true 997 // destination. 998 if (matchPattern(condbr.getCondition(), m_Zero())) { 999 foldedDest = condbr.getFalseDest(); 1000 branchArgs.assign(condbr.false_operand_begin(), 1001 condbr.false_operand_end()); 1002 } else { 1003 foldedDest = condbr.getTrueDest(); 1004 branchArgs.assign(condbr.true_operand_begin(), condbr.true_operand_end()); 1005 } 1006 1007 rewriter.replaceOpWithNewOp<BranchOp>(condbr, foldedDest, branchArgs); 1008 return matchSuccess(); 1009 } 1010 }; 1011 } // end anonymous namespace. 1012 1013 static ParseResult parseCondBranchOp(OpAsmParser *parser, 1014 OperationState *result) { 1015 SmallVector<Value *, 4> destOperands; 1016 Block *dest; 1017 OpAsmParser::OperandType condInfo; 1018 1019 // Parse the condition. 1020 Type int1Ty = parser->getBuilder().getI1Type(); 1021 if (parser->parseOperand(condInfo) || parser->parseComma() || 1022 parser->resolveOperand(condInfo, int1Ty, result->operands)) { 1023 return parser->emitError(parser->getNameLoc(), 1024 "expected condition type was boolean (i1)"); 1025 } 1026 1027 // Parse the true successor. 1028 if (parser->parseSuccessorAndUseList(dest, destOperands)) 1029 return failure(); 1030 result->addSuccessor(dest, destOperands); 1031 1032 // Parse the false successor. 1033 destOperands.clear(); 1034 if (parser->parseComma() || 1035 parser->parseSuccessorAndUseList(dest, destOperands)) 1036 return failure(); 1037 result->addSuccessor(dest, destOperands); 1038 1039 return success(); 1040 } 1041 1042 static void print(OpAsmPrinter *p, CondBranchOp op) { 1043 *p << "cond_br "; 1044 p->printOperand(op.getCondition()); 1045 *p << ", "; 1046 p->printSuccessorAndUseList(op.getOperation(), CondBranchOp::trueIndex); 1047 *p << ", "; 1048 p->printSuccessorAndUseList(op.getOperation(), CondBranchOp::falseIndex); 1049 } 1050 1051 void CondBranchOp::getCanonicalizationPatterns( 1052 OwningRewritePatternList &results, MLIRContext *context) { 1053 results.insert<SimplifyConstCondBranchPred>(context); 1054 } 1055 1056 //===----------------------------------------------------------------------===// 1057 // Constant*Op 1058 //===----------------------------------------------------------------------===// 1059 1060 static void print(OpAsmPrinter *p, ConstantOp &op) { 1061 *p << "constant "; 1062 p->printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"value"}); 1063 1064 if (op.getAttrs().size() > 1) 1065 *p << ' '; 1066 p->printAttribute(op.getValue()); 1067 1068 // If the value is a symbol reference, print a trailing type. 1069 if (op.getValue().isa<SymbolRefAttr>()) 1070 *p << " : " << op.getType(); 1071 } 1072 1073 static ParseResult parseConstantOp(OpAsmParser *parser, 1074 OperationState *result) { 1075 Attribute valueAttr; 1076 if (parser->parseOptionalAttributeDict(result->attributes) || 1077 parser->parseAttribute(valueAttr, "value", result->attributes)) 1078 return failure(); 1079 1080 // If the attribute is a symbol reference, then we expect a trailing type. 1081 Type type; 1082 if (!valueAttr.isa<SymbolRefAttr>()) 1083 type = valueAttr.getType(); 1084 else if (parser->parseColonType(type)) 1085 return failure(); 1086 1087 // Add the attribute type to the list. 1088 return parser->addTypeToList(type, result->types); 1089 } 1090 1091 /// The constant op requires an attribute, and furthermore requires that it 1092 /// matches the return type. 1093 static LogicalResult verify(ConstantOp &op) { 1094 auto value = op.getValue(); 1095 if (!value) 1096 return op.emitOpError("requires a 'value' attribute"); 1097 1098 auto type = op.getType(); 1099 if (!value.getType().isa<NoneType>() && type != value.getType()) 1100 return op.emitOpError() << "requires attribute's type (" << value.getType() 1101 << ") to match op's return type (" << type << ")"; 1102 1103 if (type.isa<IndexType>() || value.isa<BoolAttr>()) 1104 return success(); 1105 1106 if (auto intAttr = value.dyn_cast<IntegerAttr>()) { 1107 // If the type has a known bitwidth we verify that the value can be 1108 // represented with the given bitwidth. 1109 auto bitwidth = type.cast<IntegerType>().getWidth(); 1110 auto intVal = intAttr.getValue(); 1111 if (!intVal.isSignedIntN(bitwidth) && !intVal.isIntN(bitwidth)) 1112 return op.emitOpError("requires 'value' to be an integer within the " 1113 "range of the integer result type"); 1114 return success(); 1115 } 1116 1117 if (type.isa<FloatType>()) { 1118 if (!value.isa<FloatAttr>()) 1119 return op.emitOpError("requires 'value' to be a floating point constant"); 1120 return success(); 1121 } 1122 1123 if (type.isa<ShapedType>()) { 1124 if (!value.isa<ElementsAttr>()) 1125 return op.emitOpError("requires 'value' to be a shaped constant"); 1126 return success(); 1127 } 1128 1129 if (type.isa<FunctionType>()) { 1130 auto fnAttr = value.dyn_cast<SymbolRefAttr>(); 1131 if (!fnAttr) 1132 return op.emitOpError("requires 'value' to be a function reference"); 1133 1134 // Try to find the referenced function. 1135 auto fn = 1136 op.getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(fnAttr.getValue()); 1137 if (!fn) 1138 return op.emitOpError("reference to undefined function 'bar'"); 1139 1140 // Check that the referenced function has the correct type. 1141 if (fn.getType() != type) 1142 return op.emitOpError("reference to function with mismatched type"); 1143 1144 return success(); 1145 } 1146 1147 if (type.isa<NoneType>() && value.isa<UnitAttr>()) 1148 return success(); 1149 1150 return op.emitOpError("unsupported 'value' attribute: ") << value; 1151 } 1152 1153 OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) { 1154 assert(operands.empty() && "constant has no operands"); 1155 return getValue(); 1156 } 1157 1158 /// Returns true if a constant operation can be built with the given value and 1159 /// result type. 1160 bool ConstantOp::isBuildableWith(Attribute value, Type type) { 1161 // SymbolRefAttr can only be used with a function type. 1162 if (value.isa<SymbolRefAttr>()) 1163 return type.isa<FunctionType>(); 1164 // Otherwise, the attribute must have the same type as 'type'. 1165 if (value.getType() != type) 1166 return false; 1167 // Finally, check that the attribute kind is handled. 1168 return value.isa<BoolAttr>() || value.isa<IntegerAttr>() || 1169 value.isa<FloatAttr>() || value.isa<ElementsAttr>() || 1170 value.isa<UnitAttr>(); 1171 } 1172 1173 void ConstantFloatOp::build(Builder *builder, OperationState *result, 1174 const APFloat &value, FloatType type) { 1175 ConstantOp::build(builder, result, type, builder->getFloatAttr(type, value)); 1176 } 1177 1178 bool ConstantFloatOp::classof(Operation *op) { 1179 return ConstantOp::classof(op) && 1180 op->getResult(0)->getType().isa<FloatType>(); 1181 } 1182 1183 /// ConstantIntOp only matches values whose result type is an IntegerType. 1184 bool ConstantIntOp::classof(Operation *op) { 1185 return ConstantOp::classof(op) && 1186 op->getResult(0)->getType().isa<IntegerType>(); 1187 } 1188 1189 void ConstantIntOp::build(Builder *builder, OperationState *result, 1190 int64_t value, unsigned width) { 1191 Type type = builder->getIntegerType(width); 1192 ConstantOp::build(builder, result, type, 1193 builder->getIntegerAttr(type, value)); 1194 } 1195 1196 /// Build a constant int op producing an integer with the specified type, 1197 /// which must be an integer type. 1198 void ConstantIntOp::build(Builder *builder, OperationState *result, 1199 int64_t value, Type type) { 1200 assert(type.isa<IntegerType>() && "ConstantIntOp can only have integer type"); 1201 ConstantOp::build(builder, result, type, 1202 builder->getIntegerAttr(type, value)); 1203 } 1204 1205 /// ConstantIndexOp only matches values whose result type is Index. 1206 bool ConstantIndexOp::classof(Operation *op) { 1207 return ConstantOp::classof(op) && op->getResult(0)->getType().isIndex(); 1208 } 1209 1210 void ConstantIndexOp::build(Builder *builder, OperationState *result, 1211 int64_t value) { 1212 Type type = builder->getIndexType(); 1213 ConstantOp::build(builder, result, type, 1214 builder->getIntegerAttr(type, value)); 1215 } 1216 1217 //===----------------------------------------------------------------------===// 1218 // DeallocOp 1219 //===----------------------------------------------------------------------===// 1220 namespace { 1221 /// Fold Dealloc operations that are deallocating an AllocOp that is only used 1222 /// by other Dealloc operations. 1223 struct SimplifyDeadDealloc : public OpRewritePattern<DeallocOp> { 1224 using OpRewritePattern<DeallocOp>::OpRewritePattern; 1225 1226 PatternMatchResult matchAndRewrite(DeallocOp dealloc, 1227 PatternRewriter &rewriter) const override { 1228 // Check that the memref operand's defining operation is an AllocOp. 1229 Value *memref = dealloc.memref(); 1230 if (!isa_and_nonnull<AllocOp>(memref->getDefiningOp())) 1231 return matchFailure(); 1232 1233 // Check that all of the uses of the AllocOp are other DeallocOps. 1234 for (auto *user : memref->getUsers()) 1235 if (!isa<DeallocOp>(user)) 1236 return matchFailure(); 1237 1238 // Erase the dealloc operation. 1239 rewriter.replaceOp(dealloc, llvm::None); 1240 return matchSuccess(); 1241 } 1242 }; 1243 } // end anonymous namespace. 1244 1245 static void print(OpAsmPrinter *p, DeallocOp op) { 1246 *p << "dealloc " << *op.memref() << " : " << op.memref()->getType(); 1247 } 1248 1249 static ParseResult parseDeallocOp(OpAsmParser *parser, OperationState *result) { 1250 OpAsmParser::OperandType memrefInfo; 1251 MemRefType type; 1252 1253 return failure(parser->parseOperand(memrefInfo) || 1254 parser->parseColonType(type) || 1255 parser->resolveOperand(memrefInfo, type, result->operands)); 1256 } 1257 1258 static LogicalResult verify(DeallocOp op) { 1259 if (!op.memref()->getType().isa<MemRefType>()) 1260 return op.emitOpError("operand must be a memref"); 1261 return success(); 1262 } 1263 1264 void DeallocOp::getCanonicalizationPatterns(OwningRewritePatternList &results, 1265 MLIRContext *context) { 1266 /// dealloc(memrefcast) -> dealloc 1267 results.insert<MemRefCastFolder>(getOperationName(), context); 1268 results.insert<SimplifyDeadDealloc>(context); 1269 } 1270 1271 //===----------------------------------------------------------------------===// 1272 // DimOp 1273 //===----------------------------------------------------------------------===// 1274 1275 static void print(OpAsmPrinter *p, DimOp op) { 1276 *p << "dim " << *op.getOperand() << ", " << op.getIndex(); 1277 p->printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"index"}); 1278 *p << " : " << op.getOperand()->getType(); 1279 } 1280 1281 static ParseResult parseDimOp(OpAsmParser *parser, OperationState *result) { 1282 OpAsmParser::OperandType operandInfo; 1283 IntegerAttr indexAttr; 1284 Type type; 1285 Type indexType = parser->getBuilder().getIndexType(); 1286 1287 return failure(parser->parseOperand(operandInfo) || parser->parseComma() || 1288 parser->parseAttribute(indexAttr, indexType, "index", 1289 result->attributes) || 1290 parser->parseOptionalAttributeDict(result->attributes) || 1291 parser->parseColonType(type) || 1292 parser->resolveOperand(operandInfo, type, result->operands) || 1293 parser->addTypeToList(indexType, result->types)); 1294 } 1295 1296 static LogicalResult verify(DimOp op) { 1297 // Check that we have an integer index operand. 1298 auto indexAttr = op.getAttrOfType<IntegerAttr>("index"); 1299 if (!indexAttr) 1300 return op.emitOpError("requires an integer attribute named 'index'"); 1301 int64_t index = indexAttr.getValue().getSExtValue(); 1302 1303 auto type = op.getOperand()->getType(); 1304 if (auto tensorType = type.dyn_cast<RankedTensorType>()) { 1305 if (index >= tensorType.getRank()) 1306 return op.emitOpError("index is out of range"); 1307 } else if (auto memrefType = type.dyn_cast<MemRefType>()) { 1308 if (index >= memrefType.getRank()) 1309 return op.emitOpError("index is out of range"); 1310 1311 } else if (type.isa<UnrankedTensorType>()) { 1312 // ok, assumed to be in-range. 1313 } else { 1314 return op.emitOpError("requires an operand with tensor or memref type"); 1315 } 1316 1317 return success(); 1318 } 1319 1320 OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) { 1321 // Constant fold dim when the size along the index referred to is a constant. 1322 auto opType = getOperand()->getType(); 1323 int64_t indexSize = -1; 1324 if (auto tensorType = opType.dyn_cast<RankedTensorType>()) 1325 indexSize = tensorType.getShape()[getIndex()]; 1326 else if (auto memrefType = opType.dyn_cast<MemRefType>()) 1327 indexSize = memrefType.getShape()[getIndex()]; 1328 1329 if (indexSize >= 0) 1330 return IntegerAttr::get(IndexType::get(getContext()), indexSize); 1331 1332 return {}; 1333 } 1334 1335 //===----------------------------------------------------------------------===// 1336 // DivISOp 1337 //===----------------------------------------------------------------------===// 1338 1339 OpFoldResult DivISOp::fold(ArrayRef<Attribute> operands) { 1340 assert(operands.size() == 2 && "binary operation takes two operands"); 1341 1342 auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>(); 1343 auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>(); 1344 if (!lhs || !rhs) 1345 return {}; 1346 1347 // Don't fold if it requires division by zero. 1348 if (rhs.getValue().isNullValue()) 1349 return {}; 1350 1351 // Don't fold if it would overflow. 1352 bool overflow; 1353 auto result = lhs.getValue().sdiv_ov(rhs.getValue(), overflow); 1354 return overflow ? IntegerAttr() : IntegerAttr::get(lhs.getType(), result); 1355 } 1356 1357 //===----------------------------------------------------------------------===// 1358 // DivIUOp 1359 //===----------------------------------------------------------------------===// 1360 1361 OpFoldResult DivIUOp::fold(ArrayRef<Attribute> operands) { 1362 assert(operands.size() == 2 && "binary operation takes two operands"); 1363 1364 auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>(); 1365 auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>(); 1366 if (!lhs || !rhs) 1367 return {}; 1368 1369 // Don't fold if it requires division by zero. 1370 auto rhsValue = rhs.getValue(); 1371 if (rhsValue.isNullValue()) 1372 return {}; 1373 1374 return IntegerAttr::get(lhs.getType(), lhs.getValue().udiv(rhsValue)); 1375 } 1376 1377 // --------------------------------------------------------------------------- 1378 // DmaStartOp 1379 // --------------------------------------------------------------------------- 1380 1381 void DmaStartOp::build(Builder *builder, OperationState *result, 1382 Value *srcMemRef, ArrayRef<Value *> srcIndices, 1383 Value *destMemRef, ArrayRef<Value *> destIndices, 1384 Value *numElements, Value *tagMemRef, 1385 ArrayRef<Value *> tagIndices, Value *stride, 1386 Value *elementsPerStride) { 1387 result->addOperands(srcMemRef); 1388 result->addOperands(srcIndices); 1389 result->addOperands(destMemRef); 1390 result->addOperands(destIndices); 1391 result->addOperands({numElements, tagMemRef}); 1392 result->addOperands(tagIndices); 1393 if (stride) 1394 result->addOperands({stride, elementsPerStride}); 1395 } 1396 1397 void DmaStartOp::print(OpAsmPrinter *p) { 1398 *p << "dma_start " << *getSrcMemRef() << '['; 1399 p->printOperands(getSrcIndices()); 1400 *p << "], " << *getDstMemRef() << '['; 1401 p->printOperands(getDstIndices()); 1402 *p << "], " << *getNumElements(); 1403 *p << ", " << *getTagMemRef() << '['; 1404 p->printOperands(getTagIndices()); 1405 *p << ']'; 1406 if (isStrided()) { 1407 *p << ", " << *getStride(); 1408 *p << ", " << *getNumElementsPerStride(); 1409 } 1410 p->printOptionalAttrDict(getAttrs()); 1411 *p << " : " << getSrcMemRef()->getType(); 1412 *p << ", " << getDstMemRef()->getType(); 1413 *p << ", " << getTagMemRef()->getType(); 1414 } 1415 1416 // Parse DmaStartOp. 1417 // Ex: 1418 // %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size, 1419 // %tag[%index], %stride, %num_elt_per_stride : 1420 // : memref<3076 x f32, 0>, 1421 // memref<1024 x f32, 2>, 1422 // memref<1 x i32> 1423 // 1424 ParseResult DmaStartOp::parse(OpAsmParser *parser, OperationState *result) { 1425 OpAsmParser::OperandType srcMemRefInfo; 1426 SmallVector<OpAsmParser::OperandType, 4> srcIndexInfos; 1427 OpAsmParser::OperandType dstMemRefInfo; 1428 SmallVector<OpAsmParser::OperandType, 4> dstIndexInfos; 1429 OpAsmParser::OperandType numElementsInfo; 1430 OpAsmParser::OperandType tagMemrefInfo; 1431 SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos; 1432 SmallVector<OpAsmParser::OperandType, 2> strideInfo; 1433 1434 SmallVector<Type, 3> types; 1435 auto indexType = parser->getBuilder().getIndexType(); 1436 1437 // Parse and resolve the following list of operands: 1438 // *) source memref followed by its indices (in square brackets). 1439 // *) destination memref followed by its indices (in square brackets). 1440 // *) dma size in KiB. 1441 if (parser->parseOperand(srcMemRefInfo) || 1442 parser->parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) || 1443 parser->parseComma() || parser->parseOperand(dstMemRefInfo) || 1444 parser->parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) || 1445 parser->parseComma() || parser->parseOperand(numElementsInfo) || 1446 parser->parseComma() || parser->parseOperand(tagMemrefInfo) || 1447 parser->parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square)) 1448 return failure(); 1449 1450 // Parse optional stride and elements per stride. 1451 if (parser->parseTrailingOperandList(strideInfo)) 1452 return failure(); 1453 1454 bool isStrided = strideInfo.size() == 2; 1455 if (!strideInfo.empty() && !isStrided) { 1456 return parser->emitError(parser->getNameLoc(), 1457 "expected two stride related operands"); 1458 } 1459 1460 if (parser->parseColonTypeList(types)) 1461 return failure(); 1462 if (types.size() != 3) 1463 return parser->emitError(parser->getNameLoc(), "fewer/more types expected"); 1464 1465 if (parser->resolveOperand(srcMemRefInfo, types[0], result->operands) || 1466 parser->resolveOperands(srcIndexInfos, indexType, result->operands) || 1467 parser->resolveOperand(dstMemRefInfo, types[1], result->operands) || 1468 parser->resolveOperands(dstIndexInfos, indexType, result->operands) || 1469 // size should be an index. 1470 parser->resolveOperand(numElementsInfo, indexType, result->operands) || 1471 parser->resolveOperand(tagMemrefInfo, types[2], result->operands) || 1472 // tag indices should be index. 1473 parser->resolveOperands(tagIndexInfos, indexType, result->operands)) 1474 return failure(); 1475 1476 auto memrefType0 = types[0].dyn_cast<MemRefType>(); 1477 if (!memrefType0) 1478 return parser->emitError(parser->getNameLoc(), 1479 "expected source to be of memref type"); 1480 1481 auto memrefType1 = types[1].dyn_cast<MemRefType>(); 1482 if (!memrefType1) 1483 return parser->emitError(parser->getNameLoc(), 1484 "expected destination to be of memref type"); 1485 1486 auto memrefType2 = types[2].dyn_cast<MemRefType>(); 1487 if (!memrefType2) 1488 return parser->emitError(parser->getNameLoc(), 1489 "expected tag to be of memref type"); 1490 1491 if (isStrided) { 1492 if (parser->resolveOperands(strideInfo, indexType, result->operands)) 1493 return failure(); 1494 } 1495 1496 // Check that source/destination index list size matches associated rank. 1497 if (static_cast<int64_t>(srcIndexInfos.size()) != memrefType0.getRank() || 1498 static_cast<int64_t>(dstIndexInfos.size()) != memrefType1.getRank()) 1499 return parser->emitError(parser->getNameLoc(), 1500 "memref rank not equal to indices count"); 1501 if (static_cast<int64_t>(tagIndexInfos.size()) != memrefType2.getRank()) 1502 return parser->emitError(parser->getNameLoc(), 1503 "tag memref rank not equal to indices count"); 1504 1505 return success(); 1506 } 1507 1508 LogicalResult DmaStartOp::verify() { 1509 // DMAs from different memory spaces supported. 1510 if (getSrcMemorySpace() == getDstMemorySpace()) 1511 return emitOpError("DMA should be between different memory spaces"); 1512 1513 if (getNumOperands() != getTagMemRefRank() + getSrcMemRefRank() + 1514 getDstMemRefRank() + 3 + 1 && 1515 getNumOperands() != getTagMemRefRank() + getSrcMemRefRank() + 1516 getDstMemRefRank() + 3 + 1 + 2) { 1517 return emitOpError("incorrect number of operands"); 1518 } 1519 return success(); 1520 } 1521 1522 void DmaStartOp::getCanonicalizationPatterns(OwningRewritePatternList &results, 1523 MLIRContext *context) { 1524 /// dma_start(memrefcast) -> dma_start 1525 results.insert<MemRefCastFolder>(getOperationName(), context); 1526 } 1527 1528 // --------------------------------------------------------------------------- 1529 // DmaWaitOp 1530 // --------------------------------------------------------------------------- 1531 1532 void DmaWaitOp::build(Builder *builder, OperationState *result, 1533 Value *tagMemRef, ArrayRef<Value *> tagIndices, 1534 Value *numElements) { 1535 result->addOperands(tagMemRef); 1536 result->addOperands(tagIndices); 1537 result->addOperands(numElements); 1538 } 1539 1540 void DmaWaitOp::print(OpAsmPrinter *p) { 1541 *p << "dma_wait "; 1542 p->printOperand(getTagMemRef()); 1543 *p << '['; 1544 p->printOperands(getTagIndices()); 1545 *p << "], "; 1546 p->printOperand(getNumElements()); 1547 p->printOptionalAttrDict(getAttrs()); 1548 *p << " : " << getTagMemRef()->getType(); 1549 } 1550 1551 // Parse DmaWaitOp. 1552 // Eg: 1553 // dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 4> 1554 // 1555 ParseResult DmaWaitOp::parse(OpAsmParser *parser, OperationState *result) { 1556 OpAsmParser::OperandType tagMemrefInfo; 1557 SmallVector<OpAsmParser::OperandType, 2> tagIndexInfos; 1558 Type type; 1559 auto indexType = parser->getBuilder().getIndexType(); 1560 OpAsmParser::OperandType numElementsInfo; 1561 1562 // Parse tag memref, its indices, and dma size. 1563 if (parser->parseOperand(tagMemrefInfo) || 1564 parser->parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square) || 1565 parser->parseComma() || parser->parseOperand(numElementsInfo) || 1566 parser->parseColonType(type) || 1567 parser->resolveOperand(tagMemrefInfo, type, result->operands) || 1568 parser->resolveOperands(tagIndexInfos, indexType, result->operands) || 1569 parser->resolveOperand(numElementsInfo, indexType, result->operands)) 1570 return failure(); 1571 1572 auto memrefType = type.dyn_cast<MemRefType>(); 1573 if (!memrefType) 1574 return parser->emitError(parser->getNameLoc(), 1575 "expected tag to be of memref type"); 1576 1577 if (static_cast<int64_t>(tagIndexInfos.size()) != memrefType.getRank()) 1578 return parser->emitError(parser->getNameLoc(), 1579 "tag memref rank not equal to indices count"); 1580 1581 return success(); 1582 } 1583 1584 void DmaWaitOp::getCanonicalizationPatterns(OwningRewritePatternList &results, 1585 MLIRContext *context) { 1586 /// dma_wait(memrefcast) -> dma_wait 1587 results.insert<MemRefCastFolder>(getOperationName(), context); 1588 } 1589 1590 //===----------------------------------------------------------------------===// 1591 // ExtractElementOp 1592 //===----------------------------------------------------------------------===// 1593 1594 static void print(OpAsmPrinter *p, ExtractElementOp op) { 1595 *p << "extract_element " << *op.getAggregate() << '['; 1596 p->printOperands(op.getIndices()); 1597 *p << ']'; 1598 p->printOptionalAttrDict(op.getAttrs()); 1599 *p << " : " << op.getAggregate()->getType(); 1600 } 1601 1602 static ParseResult parseExtractElementOp(OpAsmParser *parser, 1603 OperationState *result) { 1604 OpAsmParser::OperandType aggregateInfo; 1605 SmallVector<OpAsmParser::OperandType, 4> indexInfo; 1606 ShapedType type; 1607 1608 auto affineIntTy = parser->getBuilder().getIndexType(); 1609 return failure( 1610 parser->parseOperand(aggregateInfo) || 1611 parser->parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) || 1612 parser->parseOptionalAttributeDict(result->attributes) || 1613 parser->parseColonType(type) || 1614 parser->resolveOperand(aggregateInfo, type, result->operands) || 1615 parser->resolveOperands(indexInfo, affineIntTy, result->operands) || 1616 parser->addTypeToList(type.getElementType(), result->types)); 1617 } 1618 1619 static LogicalResult verify(ExtractElementOp op) { 1620 auto aggregateType = op.getAggregate()->getType().cast<ShapedType>(); 1621 1622 // This should be possible with tablegen type constraints 1623 if (op.getType() != aggregateType.getElementType()) 1624 return op.emitOpError("result type must match element type of aggregate"); 1625 1626 // Verify the # indices match if we have a ranked type. 1627 if (aggregateType.hasRank() && 1628 aggregateType.getRank() != op.getNumOperands() - 1) 1629 return op.emitOpError("incorrect number of indices for extract_element"); 1630 1631 return success(); 1632 } 1633 1634 OpFoldResult ExtractElementOp::fold(ArrayRef<Attribute> operands) { 1635 assert(!operands.empty() && "extract_element takes atleast one operand"); 1636 1637 // The aggregate operand must be a known constant. 1638 Attribute aggregate = operands.front(); 1639 if (!aggregate) 1640 return {}; 1641 1642 // If this is a splat elements attribute, simply return the value. All of the 1643 // elements of a splat attribute are the same. 1644 if (auto splatAggregate = aggregate.dyn_cast<SplatElementsAttr>()) 1645 return splatAggregate.getSplatValue(); 1646 1647 // Otherwise, collect the constant indices into the aggregate. 1648 SmallVector<uint64_t, 8> indices; 1649 for (Attribute indice : llvm::drop_begin(operands, 1)) { 1650 if (!indice || !indice.isa<IntegerAttr>()) 1651 return {}; 1652 indices.push_back(indice.cast<IntegerAttr>().getInt()); 1653 } 1654 1655 // If this is an elements attribute, query the value at the given indices. 1656 auto elementsAttr = aggregate.dyn_cast<ElementsAttr>(); 1657 if (elementsAttr && elementsAttr.isValidIndex(indices)) 1658 return elementsAttr.getValue(indices); 1659 return {}; 1660 } 1661 1662 //===----------------------------------------------------------------------===// 1663 // IndexCastOp 1664 //===----------------------------------------------------------------------===// 1665 1666 // Index cast is applicable from index to integer and backwards. 1667 bool IndexCastOp::areCastCompatible(Type a, Type b) { 1668 return (a.isIndex() && b.isa<IntegerType>()) || 1669 (a.isa<IntegerType>() && b.isIndex()); 1670 } 1671 1672 //===----------------------------------------------------------------------===// 1673 // LoadOp 1674 //===----------------------------------------------------------------------===// 1675 1676 static void print(OpAsmPrinter *p, LoadOp op) { 1677 *p << "load " << *op.getMemRef() << '['; 1678 p->printOperands(op.getIndices()); 1679 *p << ']'; 1680 p->printOptionalAttrDict(op.getAttrs()); 1681 *p << " : " << op.getMemRefType(); 1682 } 1683 1684 static ParseResult parseLoadOp(OpAsmParser *parser, OperationState *result) { 1685 OpAsmParser::OperandType memrefInfo; 1686 SmallVector<OpAsmParser::OperandType, 4> indexInfo; 1687 MemRefType type; 1688 1689 auto affineIntTy = parser->getBuilder().getIndexType(); 1690 return failure( 1691 parser->parseOperand(memrefInfo) || 1692 parser->parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) || 1693 parser->parseOptionalAttributeDict(result->attributes) || 1694 parser->parseColonType(type) || 1695 parser->resolveOperand(memrefInfo, type, result->operands) || 1696 parser->resolveOperands(indexInfo, affineIntTy, result->operands) || 1697 parser->addTypeToList(type.getElementType(), result->types)); 1698 } 1699 1700 static LogicalResult verify(LoadOp op) { 1701 if (op.getType() != op.getMemRefType().getElementType()) 1702 return op.emitOpError("result type must match element type of memref"); 1703 1704 if (op.getMemRefType().getRank() != op.getNumOperands() - 1) 1705 return op.emitOpError("incorrect number of indices for load"); 1706 1707 for (auto *idx : op.getIndices()) 1708 if (!idx->getType().isIndex()) 1709 return op.emitOpError("index to load must have 'index' type"); 1710 1711 // TODO: Verify we have the right number of indices. 1712 1713 // TODO: in Function verify that the indices are parameters, IV's, or the 1714 // result of an affine.apply. 1715 return success(); 1716 } 1717 1718 void LoadOp::getCanonicalizationPatterns(OwningRewritePatternList &results, 1719 MLIRContext *context) { 1720 /// load(memrefcast) -> load 1721 results.insert<MemRefCastFolder>(getOperationName(), context); 1722 } 1723 1724 //===----------------------------------------------------------------------===// 1725 // MemRefCastOp 1726 //===----------------------------------------------------------------------===// 1727 1728 bool MemRefCastOp::areCastCompatible(Type a, Type b) { 1729 auto aT = a.dyn_cast<MemRefType>(); 1730 auto bT = b.dyn_cast<MemRefType>(); 1731 1732 if (!aT || !bT) 1733 return false; 1734 if (aT.getElementType() != bT.getElementType()) 1735 return false; 1736 if (aT.getAffineMaps() != bT.getAffineMaps()) 1737 return false; 1738 if (aT.getMemorySpace() != bT.getMemorySpace()) 1739 return false; 1740 1741 // They must have the same rank, and any specified dimensions must match. 1742 if (aT.getRank() != bT.getRank()) 1743 return false; 1744 1745 for (unsigned i = 0, e = aT.getRank(); i != e; ++i) { 1746 int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i); 1747 if (aDim != -1 && bDim != -1 && aDim != bDim) 1748 return false; 1749 } 1750 1751 return true; 1752 } 1753 1754 OpFoldResult MemRefCastOp::fold(ArrayRef<Attribute> operands) { 1755 return impl::foldCastOp(*this); 1756 } 1757 1758 //===----------------------------------------------------------------------===// 1759 // MulFOp 1760 //===----------------------------------------------------------------------===// 1761 1762 OpFoldResult MulFOp::fold(ArrayRef<Attribute> operands) { 1763 return constFoldBinaryOp<FloatAttr>( 1764 operands, [](APFloat a, APFloat b) { return a * b; }); 1765 } 1766 1767 //===----------------------------------------------------------------------===// 1768 // MulIOp 1769 //===----------------------------------------------------------------------===// 1770 1771 OpFoldResult MulIOp::fold(ArrayRef<Attribute> operands) { 1772 /// muli(x, 0) -> 0 1773 if (matchPattern(rhs(), m_Zero())) 1774 return rhs(); 1775 /// muli(x, 1) -> x 1776 if (matchPattern(rhs(), m_One())) 1777 return getOperand(0); 1778 1779 // TODO: Handle the overflow case. 1780 return constFoldBinaryOp<IntegerAttr>(operands, 1781 [](APInt a, APInt b) { return a * b; }); 1782 } 1783 1784 //===----------------------------------------------------------------------===// 1785 // RankOp 1786 //===----------------------------------------------------------------------===// 1787 1788 static void print(OpAsmPrinter *p, RankOp op) { 1789 *p << "rank " << *op.getOperand() << " : " << op.getOperand()->getType(); 1790 } 1791 1792 static ParseResult parseRankOp(OpAsmParser *parser, OperationState *result) { 1793 OpAsmParser::OperandType operandInfo; 1794 Type type; 1795 Type indexType = parser->getBuilder().getIndexType(); 1796 return failure(parser->parseOperand(operandInfo) || 1797 parser->parseColonType(type) || 1798 parser->resolveOperand(operandInfo, type, result->operands) || 1799 parser->addTypeToList(indexType, result->types)); 1800 } 1801 1802 OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) { 1803 // Constant fold rank when the rank of the tensor is known. 1804 auto type = getOperand()->getType(); 1805 if (auto tensorType = type.dyn_cast<RankedTensorType>()) 1806 return IntegerAttr::get(IndexType::get(getContext()), tensorType.getRank()); 1807 return IntegerAttr(); 1808 } 1809 1810 //===----------------------------------------------------------------------===// 1811 // RemISOp 1812 //===----------------------------------------------------------------------===// 1813 1814 OpFoldResult RemISOp::fold(ArrayRef<Attribute> operands) { 1815 assert(operands.size() == 2 && "remis takes two operands"); 1816 1817 auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>(); 1818 if (!rhs) 1819 return {}; 1820 auto rhsValue = rhs.getValue(); 1821 1822 // x % 1 = 0 1823 if (rhsValue.isOneValue()) 1824 return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0)); 1825 1826 // Don't fold if it requires division by zero. 1827 if (rhsValue.isNullValue()) 1828 return {}; 1829 1830 auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>(); 1831 if (!lhs) 1832 return {}; 1833 return IntegerAttr::get(lhs.getType(), lhs.getValue().srem(rhsValue)); 1834 } 1835 1836 //===----------------------------------------------------------------------===// 1837 // RemIUOp 1838 //===----------------------------------------------------------------------===// 1839 1840 OpFoldResult RemIUOp::fold(ArrayRef<Attribute> operands) { 1841 assert(operands.size() == 2 && "remiu takes two operands"); 1842 1843 auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>(); 1844 if (!rhs) 1845 return {}; 1846 auto rhsValue = rhs.getValue(); 1847 1848 // x % 1 = 0 1849 if (rhsValue.isOneValue()) 1850 return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0)); 1851 1852 // Don't fold if it requires division by zero. 1853 if (rhsValue.isNullValue()) 1854 return {}; 1855 1856 auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>(); 1857 if (!lhs) 1858 return {}; 1859 return IntegerAttr::get(lhs.getType(), lhs.getValue().urem(rhsValue)); 1860 } 1861 1862 //===----------------------------------------------------------------------===// 1863 // ReturnOp 1864 //===----------------------------------------------------------------------===// 1865 1866 static ParseResult parseReturnOp(OpAsmParser *parser, OperationState *result) { 1867 SmallVector<OpAsmParser::OperandType, 2> opInfo; 1868 SmallVector<Type, 2> types; 1869 llvm::SMLoc loc = parser->getCurrentLocation(); 1870 return failure(parser->parseOperandList(opInfo) || 1871 (!opInfo.empty() && parser->parseColonTypeList(types)) || 1872 parser->resolveOperands(opInfo, types, loc, result->operands)); 1873 } 1874 1875 static void print(OpAsmPrinter *p, ReturnOp op) { 1876 *p << "return"; 1877 if (op.getNumOperands() != 0) { 1878 *p << ' '; 1879 p->printOperands(op.getOperands()); 1880 *p << " : "; 1881 interleaveComma(op.getOperandTypes(), *p); 1882 } 1883 } 1884 1885 static LogicalResult verify(ReturnOp op) { 1886 auto function = cast<FuncOp>(op.getParentOp()); 1887 1888 // The operand number and types must match the function signature. 1889 const auto &results = function.getType().getResults(); 1890 if (op.getNumOperands() != results.size()) 1891 return op.emitOpError("has ") 1892 << op.getNumOperands() 1893 << " operands, but enclosing function returns " << results.size(); 1894 1895 for (unsigned i = 0, e = results.size(); i != e; ++i) 1896 if (op.getOperand(i)->getType() != results[i]) 1897 return op.emitError() 1898 << "type of return operand " << i << " (" 1899 << op.getOperand(i)->getType() 1900 << ") doesn't match function result type (" << results[i] << ")"; 1901 1902 return success(); 1903 } 1904 1905 //===----------------------------------------------------------------------===// 1906 // SIToFPOp 1907 //===----------------------------------------------------------------------===// 1908 1909 // sitofp is applicable from integer types to float types. 1910 bool SIToFPOp::areCastCompatible(Type a, Type b) { 1911 return a.isa<IntegerType>() && b.isa<FloatType>(); 1912 } 1913 1914 //===----------------------------------------------------------------------===// 1915 // SelectOp 1916 //===----------------------------------------------------------------------===// 1917 1918 static ParseResult parseSelectOp(OpAsmParser *parser, OperationState *result) { 1919 SmallVector<OpAsmParser::OperandType, 3> ops; 1920 SmallVector<NamedAttribute, 4> attrs; 1921 Type type; 1922 if (parser->parseOperandList(ops, 3) || 1923 parser->parseOptionalAttributeDict(result->attributes) || 1924 parser->parseColonType(type)) 1925 return failure(); 1926 1927 auto i1Type = getCheckedI1SameShape(&parser->getBuilder(), type); 1928 if (!i1Type) 1929 return parser->emitError(parser->getNameLoc(), 1930 "expected type with valid i1 shape"); 1931 1932 SmallVector<Type, 3> types = {i1Type, type, type}; 1933 return failure(parser->resolveOperands(ops, types, parser->getNameLoc(), 1934 result->operands) || 1935 parser->addTypeToList(type, result->types)); 1936 } 1937 1938 static void print(OpAsmPrinter *p, SelectOp op) { 1939 *p << "select "; 1940 p->printOperands(op.getOperands()); 1941 *p << " : " << op.getTrueValue()->getType(); 1942 p->printOptionalAttrDict(op.getAttrs()); 1943 } 1944 1945 static LogicalResult verify(SelectOp op) { 1946 auto trueType = op.getTrueValue()->getType(); 1947 auto falseType = op.getFalseValue()->getType(); 1948 1949 if (trueType != falseType) 1950 return op.emitOpError( 1951 "requires 'true' and 'false' arguments to be of the same type"); 1952 1953 return success(); 1954 } 1955 1956 OpFoldResult SelectOp::fold(ArrayRef<Attribute> operands) { 1957 auto *condition = getCondition(); 1958 1959 // select true, %0, %1 => %0 1960 if (matchPattern(condition, m_One())) 1961 return getTrueValue(); 1962 1963 // select false, %0, %1 => %1 1964 if (matchPattern(condition, m_Zero())) 1965 return getFalseValue(); 1966 return nullptr; 1967 } 1968 1969 //===----------------------------------------------------------------------===// 1970 // StoreOp 1971 //===----------------------------------------------------------------------===// 1972 1973 static void print(OpAsmPrinter *p, StoreOp op) { 1974 *p << "store " << *op.getValueToStore(); 1975 *p << ", " << *op.getMemRef() << '['; 1976 p->printOperands(op.getIndices()); 1977 *p << ']'; 1978 p->printOptionalAttrDict(op.getAttrs()); 1979 *p << " : " << op.getMemRefType(); 1980 } 1981 1982 static ParseResult parseStoreOp(OpAsmParser *parser, OperationState *result) { 1983 OpAsmParser::OperandType storeValueInfo; 1984 OpAsmParser::OperandType memrefInfo; 1985 SmallVector<OpAsmParser::OperandType, 4> indexInfo; 1986 MemRefType memrefType; 1987 1988 auto affineIntTy = parser->getBuilder().getIndexType(); 1989 return failure( 1990 parser->parseOperand(storeValueInfo) || parser->parseComma() || 1991 parser->parseOperand(memrefInfo) || 1992 parser->parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) || 1993 parser->parseOptionalAttributeDict(result->attributes) || 1994 parser->parseColonType(memrefType) || 1995 parser->resolveOperand(storeValueInfo, memrefType.getElementType(), 1996 result->operands) || 1997 parser->resolveOperand(memrefInfo, memrefType, result->operands) || 1998 parser->resolveOperands(indexInfo, affineIntTy, result->operands)); 1999 } 2000 2001 static LogicalResult verify(StoreOp op) { 2002 // First operand must have same type as memref element type. 2003 if (op.getValueToStore()->getType() != op.getMemRefType().getElementType()) 2004 return op.emitOpError( 2005 "first operand must have same type memref element type"); 2006 2007 if (op.getNumOperands() != 2 + op.getMemRefType().getRank()) 2008 return op.emitOpError("store index operand count not equal to memref rank"); 2009 2010 for (auto *idx : op.getIndices()) 2011 if (!idx->getType().isIndex()) 2012 return op.emitOpError("index to load must have 'index' type"); 2013 2014 // TODO: Verify we have the right number of indices. 2015 2016 // TODO: in Function verify that the indices are parameters, IV's, or the 2017 // result of an affine.apply. 2018 return success(); 2019 } 2020 2021 void StoreOp::getCanonicalizationPatterns(OwningRewritePatternList &results, 2022 MLIRContext *context) { 2023 /// store(memrefcast) -> store 2024 results.insert<MemRefCastFolder>(getOperationName(), context); 2025 } 2026 2027 //===----------------------------------------------------------------------===// 2028 // SubFOp 2029 //===----------------------------------------------------------------------===// 2030 2031 OpFoldResult SubFOp::fold(ArrayRef<Attribute> operands) { 2032 return constFoldBinaryOp<FloatAttr>( 2033 operands, [](APFloat a, APFloat b) { return a - b; }); 2034 } 2035 2036 //===----------------------------------------------------------------------===// 2037 // SubIOp 2038 //===----------------------------------------------------------------------===// 2039 2040 OpFoldResult SubIOp::fold(ArrayRef<Attribute> operands) { 2041 // subi(x,x) -> 0 2042 if (getOperand(0) == getOperand(1)) 2043 return Builder(getContext()).getZeroAttr(getType()); 2044 2045 return constFoldBinaryOp<IntegerAttr>(operands, 2046 [](APInt a, APInt b) { return a - b; }); 2047 } 2048 2049 //===----------------------------------------------------------------------===// 2050 // AndOp 2051 //===----------------------------------------------------------------------===// 2052 2053 OpFoldResult AndOp::fold(ArrayRef<Attribute> operands) { 2054 /// and(x, 0) -> 0 2055 if (matchPattern(rhs(), m_Zero())) 2056 return rhs(); 2057 /// and(x,x) -> x 2058 if (lhs() == rhs()) 2059 return rhs(); 2060 2061 return constFoldBinaryOp<IntegerAttr>(operands, 2062 [](APInt a, APInt b) { return a & b; }); 2063 } 2064 2065 //===----------------------------------------------------------------------===// 2066 // OrOp 2067 //===----------------------------------------------------------------------===// 2068 2069 OpFoldResult OrOp::fold(ArrayRef<Attribute> operands) { 2070 /// or(x, 0) -> x 2071 if (matchPattern(rhs(), m_Zero())) 2072 return lhs(); 2073 /// or(x,x) -> x 2074 if (lhs() == rhs()) 2075 return rhs(); 2076 2077 return constFoldBinaryOp<IntegerAttr>(operands, 2078 [](APInt a, APInt b) { return a | b; }); 2079 } 2080 2081 //===----------------------------------------------------------------------===// 2082 // XOrOp 2083 //===----------------------------------------------------------------------===// 2084 2085 OpFoldResult XOrOp::fold(ArrayRef<Attribute> operands) { 2086 /// xor(x, 0) -> x 2087 if (matchPattern(rhs(), m_Zero())) 2088 return lhs(); 2089 /// xor(x,x) -> 0 2090 if (lhs() == rhs()) 2091 return Builder(getContext()).getZeroAttr(getType()); 2092 2093 return constFoldBinaryOp<IntegerAttr>(operands, 2094 [](APInt a, APInt b) { return a ^ b; }); 2095 } 2096 2097 //===----------------------------------------------------------------------===// 2098 // TensorCastOp 2099 //===----------------------------------------------------------------------===// 2100 2101 bool TensorCastOp::areCastCompatible(Type a, Type b) { 2102 auto aT = a.dyn_cast<TensorType>(); 2103 auto bT = b.dyn_cast<TensorType>(); 2104 if (!aT || !bT) 2105 return false; 2106 2107 if (aT.getElementType() != bT.getElementType()) 2108 return false; 2109 2110 // If the either are unranked, then the cast is valid. 2111 auto aRType = aT.dyn_cast<RankedTensorType>(); 2112 auto bRType = bT.dyn_cast<RankedTensorType>(); 2113 if (!aRType || !bRType) 2114 return true; 2115 2116 // If they are both ranked, they have to have the same rank, and any specified 2117 // dimensions must match. 2118 if (aRType.getRank() != bRType.getRank()) 2119 return false; 2120 2121 for (unsigned i = 0, e = aRType.getRank(); i != e; ++i) { 2122 int64_t aDim = aRType.getDimSize(i), bDim = bRType.getDimSize(i); 2123 if (aDim != -1 && bDim != -1 && aDim != bDim) 2124 return false; 2125 } 2126 2127 return true; 2128 } 2129 2130 OpFoldResult TensorCastOp::fold(ArrayRef<Attribute> operands) { 2131 return impl::foldCastOp(*this); 2132 } 2133 2134 //===----------------------------------------------------------------------===// 2135 // Helpers for Tensor[Load|Store]Op 2136 //===----------------------------------------------------------------------===// 2137 2138 static Type getTensorTypeFromMemRefType(Builder &b, Type type) { 2139 if (auto memref = type.dyn_cast<MemRefType>()) 2140 return b.getTensorType(memref.getShape(), memref.getElementType()); 2141 return b.getNoneType(); 2142 } 2143 2144 //===----------------------------------------------------------------------===// 2145 // TensorLoadOp 2146 //===----------------------------------------------------------------------===// 2147 2148 static void print(OpAsmPrinter *p, TensorLoadOp op) { 2149 *p << "tensor_load " << *op.getOperand(); 2150 p->printOptionalAttrDict(op.getAttrs()); 2151 *p << " : " << op.getOperand()->getType(); 2152 } 2153 2154 static ParseResult parseTensorLoadOp(OpAsmParser *parser, 2155 OperationState *result) { 2156 OpAsmParser::OperandType op; 2157 Type type; 2158 return failure(parser->parseOperand(op) || 2159 parser->parseOptionalAttributeDict(result->attributes) || 2160 parser->parseColonType(type) || 2161 parser->resolveOperand(op, type, result->operands) || 2162 parser->addTypeToList( 2163 getTensorTypeFromMemRefType(parser->getBuilder(), type), 2164 result->types)); 2165 } 2166 2167 //===----------------------------------------------------------------------===// 2168 // TensorStoreOp 2169 //===----------------------------------------------------------------------===// 2170 2171 static void print(OpAsmPrinter *p, TensorStoreOp op) { 2172 *p << "tensor_store " << *op.tensor() << ", " << *op.memref(); 2173 p->printOptionalAttrDict(op.getAttrs()); 2174 *p << " : " << op.memref()->getType(); 2175 } 2176 2177 static ParseResult parseTensorStoreOp(OpAsmParser *parser, 2178 OperationState *result) { 2179 SmallVector<OpAsmParser::OperandType, 2> ops; 2180 Type type; 2181 llvm::SMLoc loc = parser->getCurrentLocation(); 2182 return failure( 2183 parser->parseOperandList(ops, /*requiredOperandCount=*/2) || 2184 parser->parseOptionalAttributeDict(result->attributes) || 2185 parser->parseColonType(type) || 2186 parser->resolveOperands( 2187 ops, {getTensorTypeFromMemRefType(parser->getBuilder(), type), type}, 2188 loc, result->operands)); 2189 } 2190 2191 //===----------------------------------------------------------------------===// 2192 // TableGen'd op method definitions 2193 //===----------------------------------------------------------------------===// 2194 2195 #define GET_OP_CLASSES 2196 #include "mlir/Dialect/StandardOps/Ops.cpp.inc"