github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Dialect/AffineOps/AffineOps.cpp (about) 1 //===- AffineOps.cpp - MLIR Affine Operations -----------------------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 18 #include "mlir/Dialect/AffineOps/AffineOps.h" 19 #include "mlir/Dialect/StandardOps/Ops.h" 20 #include "mlir/IR/Block.h" 21 #include "mlir/IR/Builders.h" 22 #include "mlir/IR/Function.h" 23 #include "mlir/IR/IntegerSet.h" 24 #include "mlir/IR/Matchers.h" 25 #include "mlir/IR/OpImplementation.h" 26 #include "mlir/IR/PatternMatch.h" 27 #include "llvm/ADT/SetVector.h" 28 #include "llvm/ADT/SmallBitVector.h" 29 #include "llvm/Support/Debug.h" 30 using namespace mlir; 31 using llvm::dbgs; 32 33 #define DEBUG_TYPE "affine-analysis" 34 35 //===----------------------------------------------------------------------===// 36 // AffineOpsDialect 37 //===----------------------------------------------------------------------===// 38 39 AffineOpsDialect::AffineOpsDialect(MLIRContext *context) 40 : Dialect(getDialectNamespace(), context) { 41 addOperations<AffineApplyOp, AffineDmaStartOp, AffineDmaWaitOp, AffineLoadOp, 42 AffineStoreOp, 43 #define GET_OP_LIST 44 #include "mlir/Dialect/AffineOps/AffineOps.cpp.inc" 45 >(); 46 } 47 48 /// A utility function to check if a given region is attached to a function. 49 static bool isFunctionRegion(Region *region) { 50 return llvm::isa<FuncOp>(region->getParentOp()); 51 } 52 53 /// A utility function to check if a value is defined at the top level of a 54 /// function. A value defined at the top level is always a valid symbol. 55 bool mlir::isTopLevelSymbol(Value *value) { 56 if (auto *arg = dyn_cast<BlockArgument>(value)) 57 return isFunctionRegion(arg->getOwner()->getParent()); 58 return isFunctionRegion(value->getDefiningOp()->getParentRegion()); 59 } 60 61 // Value can be used as a dimension id if it is valid as a symbol, or 62 // it is an induction variable, or it is a result of affine apply operation 63 // with dimension id arguments. 64 bool mlir::isValidDim(Value *value) { 65 // The value must be an index type. 66 if (!value->getType().isIndex()) 67 return false; 68 69 if (auto *op = value->getDefiningOp()) { 70 // Top level operation or constant operation is ok. 71 if (isFunctionRegion(op->getParentRegion()) || isa<ConstantOp>(op)) 72 return true; 73 // Affine apply operation is ok if all of its operands are ok. 74 if (auto applyOp = dyn_cast<AffineApplyOp>(op)) 75 return applyOp.isValidDim(); 76 // The dim op is okay if its operand memref/tensor is defined at the top 77 // level. 78 if (auto dimOp = dyn_cast<DimOp>(op)) 79 return isTopLevelSymbol(dimOp.getOperand()); 80 return false; 81 } 82 // This value is a block argument (which also includes 'affine.for' loop IVs). 83 return true; 84 } 85 86 // Value can be used as a symbol if it is a constant, or it is defined at 87 // the top level, or it is a result of affine apply operation with symbol 88 // arguments. 89 bool mlir::isValidSymbol(Value *value) { 90 // The value must be an index type. 91 if (!value->getType().isIndex()) 92 return false; 93 94 if (auto *op = value->getDefiningOp()) { 95 // Top level operation or constant operation is ok. 96 if (isFunctionRegion(op->getParentRegion()) || isa<ConstantOp>(op)) 97 return true; 98 // Affine apply operation is ok if all of its operands are ok. 99 if (auto applyOp = dyn_cast<AffineApplyOp>(op)) 100 return applyOp.isValidSymbol(); 101 // The dim op is okay if its operand memref/tensor is defined at the top 102 // level. 103 if (auto dimOp = dyn_cast<DimOp>(op)) 104 return isTopLevelSymbol(dimOp.getOperand()); 105 return false; 106 } 107 // Otherwise, check that the value is a top level symbol. 108 return isTopLevelSymbol(value); 109 } 110 111 // Returns true if 'value' is a valid index to an affine operation (e.g. 112 // affine.load, affine.store, affine.dma_start, affine.dma_wait). 113 // Returns false otherwise. 114 static bool isValidAffineIndexOperand(Value *value) { 115 return isValidDim(value) || isValidSymbol(value); 116 } 117 118 /// Utility function to verify that a set of operands are valid dimension and 119 /// symbol identifiers. The operands should be layed out such that the dimension 120 /// operands are before the symbol operands. This function returns failure if 121 /// there was an invalid operand. An operation is provided to emit any necessary 122 /// errors. 123 template <typename OpTy> 124 static LogicalResult 125 verifyDimAndSymbolIdentifiers(OpTy &op, Operation::operand_range operands, 126 unsigned numDims) { 127 unsigned opIt = 0; 128 for (auto *operand : operands) { 129 if (opIt++ < numDims) { 130 if (!isValidDim(operand)) 131 return op.emitOpError("operand cannot be used as a dimension id"); 132 } else if (!isValidSymbol(operand)) { 133 return op.emitOpError("operand cannot be used as a symbol"); 134 } 135 } 136 return success(); 137 } 138 139 //===----------------------------------------------------------------------===// 140 // AffineApplyOp 141 //===----------------------------------------------------------------------===// 142 143 void AffineApplyOp::build(Builder *builder, OperationState *result, 144 AffineMap map, ArrayRef<Value *> operands) { 145 result->addOperands(operands); 146 result->types.append(map.getNumResults(), builder->getIndexType()); 147 result->addAttribute("map", builder->getAffineMapAttr(map)); 148 } 149 150 ParseResult AffineApplyOp::parse(OpAsmParser *parser, OperationState *result) { 151 auto &builder = parser->getBuilder(); 152 auto affineIntTy = builder.getIndexType(); 153 154 AffineMapAttr mapAttr; 155 unsigned numDims; 156 if (parser->parseAttribute(mapAttr, "map", result->attributes) || 157 parseDimAndSymbolList(parser, result->operands, numDims) || 158 parser->parseOptionalAttributeDict(result->attributes)) 159 return failure(); 160 auto map = mapAttr.getValue(); 161 162 if (map.getNumDims() != numDims || 163 numDims + map.getNumSymbols() != result->operands.size()) { 164 return parser->emitError(parser->getNameLoc(), 165 "dimension or symbol index mismatch"); 166 } 167 168 result->types.append(map.getNumResults(), affineIntTy); 169 return success(); 170 } 171 172 void AffineApplyOp::print(OpAsmPrinter *p) { 173 *p << "affine.apply " << getAttr("map"); 174 printDimAndSymbolList(operand_begin(), operand_end(), 175 getAffineMap().getNumDims(), p); 176 p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/{"map"}); 177 } 178 179 LogicalResult AffineApplyOp::verify() { 180 // Check that affine map attribute was specified. 181 auto affineMapAttr = getAttrOfType<AffineMapAttr>("map"); 182 if (!affineMapAttr) 183 return emitOpError("requires an affine map"); 184 185 // Check input and output dimensions match. 186 auto map = affineMapAttr.getValue(); 187 188 // Verify that operand count matches affine map dimension and symbol count. 189 if (getNumOperands() != map.getNumDims() + map.getNumSymbols()) 190 return emitOpError( 191 "operand count and affine map dimension and symbol count must match"); 192 193 // Verify that all operands are of `index` type. 194 for (Type t : getOperandTypes()) { 195 if (!t.isIndex()) 196 return emitOpError("operands must be of type 'index'"); 197 } 198 199 if (!getResult()->getType().isIndex()) 200 return emitOpError("result must be of type 'index'"); 201 202 // Verify that the operands are valid dimension and symbol identifiers. 203 if (failed(verifyDimAndSymbolIdentifiers(*this, getOperands(), 204 map.getNumDims()))) 205 return failure(); 206 207 // Verify that the map only produces one result. 208 if (map.getNumResults() != 1) 209 return emitOpError("mapping must produce one value"); 210 211 return success(); 212 } 213 214 // The result of the affine apply operation can be used as a dimension id if it 215 // is a CFG value or if it is an Value, and all the operands are valid 216 // dimension ids. 217 bool AffineApplyOp::isValidDim() { 218 return llvm::all_of(getOperands(), 219 [](Value *op) { return mlir::isValidDim(op); }); 220 } 221 222 // The result of the affine apply operation can be used as a symbol if it is 223 // a CFG value or if it is an Value, and all the operands are symbols. 224 bool AffineApplyOp::isValidSymbol() { 225 return llvm::all_of(getOperands(), 226 [](Value *op) { return mlir::isValidSymbol(op); }); 227 } 228 229 OpFoldResult AffineApplyOp::fold(ArrayRef<Attribute> operands) { 230 auto map = getAffineMap(); 231 232 // Fold dims and symbols to existing values. 233 auto expr = map.getResult(0); 234 if (auto dim = expr.dyn_cast<AffineDimExpr>()) 235 return getOperand(dim.getPosition()); 236 if (auto sym = expr.dyn_cast<AffineSymbolExpr>()) 237 return getOperand(map.getNumDims() + sym.getPosition()); 238 239 // Otherwise, default to folding the map. 240 SmallVector<Attribute, 1> result; 241 if (failed(map.constantFold(operands, result))) 242 return {}; 243 return result[0]; 244 } 245 246 namespace { 247 /// An `AffineApplyNormalizer` is a helper class that is not visible to the user 248 /// and supports renumbering operands of AffineApplyOp. This acts as a 249 /// reindexing map of Value* to positional dims or symbols and allows 250 /// simplifications such as: 251 /// 252 /// ```mlir 253 /// %1 = affine.apply (d0, d1) -> (d0 - d1) (%0, %0) 254 /// ``` 255 /// 256 /// into: 257 /// 258 /// ```mlir 259 /// %1 = affine.apply () -> (0) 260 /// ``` 261 struct AffineApplyNormalizer { 262 AffineApplyNormalizer(AffineMap map, ArrayRef<Value *> operands); 263 264 /// Returns the AffineMap resulting from normalization. 265 AffineMap getAffineMap() { return affineMap; } 266 267 SmallVector<Value *, 8> getOperands() { 268 SmallVector<Value *, 8> res(reorderedDims); 269 res.append(concatenatedSymbols.begin(), concatenatedSymbols.end()); 270 return res; 271 } 272 273 private: 274 /// Helper function to insert `v` into the coordinate system of the current 275 /// AffineApplyNormalizer. Returns the AffineDimExpr with the corresponding 276 /// renumbered position. 277 AffineDimExpr renumberOneDim(Value *v); 278 279 /// Given an `other` normalizer, this rewrites `other.affineMap` in the 280 /// coordinate system of the current AffineApplyNormalizer. 281 /// Returns the rewritten AffineMap and updates the dims and symbols of 282 /// `this`. 283 AffineMap renumber(const AffineApplyNormalizer &other); 284 285 /// Maps of Value* to position in `affineMap`. 286 DenseMap<Value *, unsigned> dimValueToPosition; 287 288 /// Ordered dims and symbols matching positional dims and symbols in 289 /// `affineMap`. 290 SmallVector<Value *, 8> reorderedDims; 291 SmallVector<Value *, 8> concatenatedSymbols; 292 293 AffineMap affineMap; 294 295 /// Used with RAII to control the depth at which AffineApply are composed 296 /// recursively. Only accepts depth 1 for now to allow a behavior where a 297 /// newly composed AffineApplyOp does not increase the length of the chain of 298 /// AffineApplyOps. Full composition is implemented iteratively on top of 299 /// this behavior. 300 static unsigned &affineApplyDepth() { 301 static thread_local unsigned depth = 0; 302 return depth; 303 } 304 static constexpr unsigned kMaxAffineApplyDepth = 1; 305 306 AffineApplyNormalizer() { affineApplyDepth()++; } 307 308 public: 309 ~AffineApplyNormalizer() { affineApplyDepth()--; } 310 }; 311 } // end anonymous namespace. 312 313 AffineDimExpr AffineApplyNormalizer::renumberOneDim(Value *v) { 314 DenseMap<Value *, unsigned>::iterator iterPos; 315 bool inserted = false; 316 std::tie(iterPos, inserted) = 317 dimValueToPosition.insert(std::make_pair(v, dimValueToPosition.size())); 318 if (inserted) { 319 reorderedDims.push_back(v); 320 } 321 return getAffineDimExpr(iterPos->second, v->getContext()) 322 .cast<AffineDimExpr>(); 323 } 324 325 AffineMap AffineApplyNormalizer::renumber(const AffineApplyNormalizer &other) { 326 SmallVector<AffineExpr, 8> dimRemapping; 327 for (auto *v : other.reorderedDims) { 328 auto kvp = other.dimValueToPosition.find(v); 329 if (dimRemapping.size() <= kvp->second) 330 dimRemapping.resize(kvp->second + 1); 331 dimRemapping[kvp->second] = renumberOneDim(kvp->first); 332 } 333 unsigned numSymbols = concatenatedSymbols.size(); 334 unsigned numOtherSymbols = other.concatenatedSymbols.size(); 335 SmallVector<AffineExpr, 8> symRemapping(numOtherSymbols); 336 for (unsigned idx = 0; idx < numOtherSymbols; ++idx) { 337 symRemapping[idx] = 338 getAffineSymbolExpr(idx + numSymbols, other.affineMap.getContext()); 339 } 340 concatenatedSymbols.insert(concatenatedSymbols.end(), 341 other.concatenatedSymbols.begin(), 342 other.concatenatedSymbols.end()); 343 auto map = other.affineMap; 344 return map.replaceDimsAndSymbols(dimRemapping, symRemapping, 345 dimRemapping.size(), symRemapping.size()); 346 } 347 348 // Gather the positions of the operands that are produced by an AffineApplyOp. 349 static llvm::SetVector<unsigned> 350 indicesFromAffineApplyOp(ArrayRef<Value *> operands) { 351 llvm::SetVector<unsigned> res; 352 for (auto en : llvm::enumerate(operands)) 353 if (isa_and_nonnull<AffineApplyOp>(en.value()->getDefiningOp())) 354 res.insert(en.index()); 355 return res; 356 } 357 358 // Support the special case of a symbol coming from an AffineApplyOp that needs 359 // to be composed into the current AffineApplyOp. 360 // This case is handled by rewriting all such symbols into dims for the purpose 361 // of allowing mathematical AffineMap composition. 362 // Returns an AffineMap where symbols that come from an AffineApplyOp have been 363 // rewritten as dims and are ordered after the original dims. 364 // TODO(andydavis,ntv): This promotion makes AffineMap lose track of which 365 // symbols are represented as dims. This loss is static but can still be 366 // recovered dynamically (with `isValidSymbol`). Still this is annoying for the 367 // semi-affine map case. A dynamic canonicalization of all dims that are valid 368 // symbols (a.k.a `canonicalizePromotedSymbols`) into symbols helps and even 369 // results in better simplifications and foldings. But we should evaluate 370 // whether this behavior is what we really want after using more. 371 static AffineMap promoteComposedSymbolsAsDims(AffineMap map, 372 ArrayRef<Value *> symbols) { 373 if (symbols.empty()) { 374 return map; 375 } 376 377 // Sanity check on symbols. 378 for (auto *sym : symbols) { 379 assert(isValidSymbol(sym) && "Expected only valid symbols"); 380 (void)sym; 381 } 382 383 // Extract the symbol positions that come from an AffineApplyOp and 384 // needs to be rewritten as dims. 385 auto symPositions = indicesFromAffineApplyOp(symbols); 386 if (symPositions.empty()) { 387 return map; 388 } 389 390 // Create the new map by replacing each symbol at pos by the next new dim. 391 unsigned numDims = map.getNumDims(); 392 unsigned numSymbols = map.getNumSymbols(); 393 unsigned numNewDims = 0; 394 unsigned numNewSymbols = 0; 395 SmallVector<AffineExpr, 8> symReplacements(numSymbols); 396 for (unsigned i = 0; i < numSymbols; ++i) { 397 symReplacements[i] = 398 symPositions.count(i) > 0 399 ? getAffineDimExpr(numDims + numNewDims++, map.getContext()) 400 : getAffineSymbolExpr(numNewSymbols++, map.getContext()); 401 } 402 assert(numSymbols >= numNewDims); 403 AffineMap newMap = map.replaceDimsAndSymbols( 404 {}, symReplacements, numDims + numNewDims, numNewSymbols); 405 406 return newMap; 407 } 408 409 /// The AffineNormalizer composes AffineApplyOp recursively. Its purpose is to 410 /// keep a correspondence between the mathematical `map` and the `operands` of 411 /// a given AffineApplyOp. This correspondence is maintained by iterating over 412 /// the operands and forming an `auxiliaryMap` that can be composed 413 /// mathematically with `map`. To keep this correspondence in cases where 414 /// symbols are produced by affine.apply operations, we perform a local rewrite 415 /// of symbols as dims. 416 /// 417 /// Rationale for locally rewriting symbols as dims: 418 /// ================================================ 419 /// The mathematical composition of AffineMap must always concatenate symbols 420 /// because it does not have enough information to do otherwise. For example, 421 /// composing `(d0)[s0] -> (d0 + s0)` with itself must produce 422 /// `(d0)[s0, s1] -> (d0 + s0 + s1)`. 423 /// 424 /// The result is only equivalent to `(d0)[s0] -> (d0 + 2 * s0)` when 425 /// applied to the same mlir::Value* for both s0 and s1. 426 /// As a consequence mathematical composition of AffineMap always concatenates 427 /// symbols. 428 /// 429 /// When AffineMaps are used in AffineApplyOp however, they may specify 430 /// composition via symbols, which is ambiguous mathematically. This corner case 431 /// is handled by locally rewriting such symbols that come from AffineApplyOp 432 /// into dims and composing through dims. 433 /// TODO(andydavis, ntv): Composition via symbols comes at a significant code 434 /// complexity. Alternatively we should investigate whether we want to 435 /// explicitly disallow symbols coming from affine.apply and instead force the 436 /// user to compose symbols beforehand. The annoyances may be small (i.e. 1 or 2 437 /// extra API calls for such uses, which haven't popped up until now) and the 438 /// benefit potentially big: simpler and more maintainable code for a 439 /// non-trivial, recursive, procedure. 440 AffineApplyNormalizer::AffineApplyNormalizer(AffineMap map, 441 ArrayRef<Value *> operands) 442 : AffineApplyNormalizer() { 443 static_assert(kMaxAffineApplyDepth > 0, "kMaxAffineApplyDepth must be > 0"); 444 assert(map.getNumInputs() == operands.size() && 445 "number of operands does not match the number of map inputs"); 446 447 LLVM_DEBUG(map.print(dbgs() << "\nInput map: ")); 448 449 // Promote symbols that come from an AffineApplyOp to dims by rewriting the 450 // map to always refer to: 451 // (dims, symbols coming from AffineApplyOp, other symbols). 452 // The order of operands can remain unchanged. 453 // This is a simplification that relies on 2 ordering properties: 454 // 1. rewritten symbols always appear after the original dims in the map; 455 // 2. operands are traversed in order and either dispatched to: 456 // a. auxiliaryExprs (dims and symbols rewritten as dims); 457 // b. concatenatedSymbols (all other symbols) 458 // This allows operand order to remain unchanged. 459 unsigned numDimsBeforeRewrite = map.getNumDims(); 460 map = promoteComposedSymbolsAsDims(map, 461 operands.take_back(map.getNumSymbols())); 462 463 LLVM_DEBUG(map.print(dbgs() << "\nRewritten map: ")); 464 465 SmallVector<AffineExpr, 8> auxiliaryExprs; 466 bool furtherCompose = (affineApplyDepth() <= kMaxAffineApplyDepth); 467 // We fully spell out the 2 cases below. In this particular instance a little 468 // code duplication greatly improves readability. 469 // Note that the first branch would disappear if we only supported full 470 // composition (i.e. infinite kMaxAffineApplyDepth). 471 if (!furtherCompose) { 472 // 1. Only dispatch dims or symbols. 473 for (auto en : llvm::enumerate(operands)) { 474 auto *t = en.value(); 475 assert(t->getType().isIndex()); 476 bool isDim = (en.index() < map.getNumDims()); 477 if (isDim) { 478 // a. The mathematical composition of AffineMap composes dims. 479 auxiliaryExprs.push_back(renumberOneDim(t)); 480 } else { 481 // b. The mathematical composition of AffineMap concatenates symbols. 482 // We do the same for symbol operands. 483 concatenatedSymbols.push_back(t); 484 } 485 } 486 } else { 487 assert(numDimsBeforeRewrite <= operands.size()); 488 // 2. Compose AffineApplyOps and dispatch dims or symbols. 489 for (unsigned i = 0, e = operands.size(); i < e; ++i) { 490 auto *t = operands[i]; 491 auto affineApply = dyn_cast_or_null<AffineApplyOp>(t->getDefiningOp()); 492 if (affineApply) { 493 // a. Compose affine.apply operations. 494 LLVM_DEBUG(affineApply.getOperation()->print( 495 dbgs() << "\nCompose AffineApplyOp recursively: ")); 496 AffineMap affineApplyMap = affineApply.getAffineMap(); 497 SmallVector<Value *, 8> affineApplyOperands( 498 affineApply.getOperands().begin(), affineApply.getOperands().end()); 499 AffineApplyNormalizer normalizer(affineApplyMap, affineApplyOperands); 500 501 LLVM_DEBUG(normalizer.affineMap.print( 502 dbgs() << "\nRenumber into current normalizer: ")); 503 504 auto renumberedMap = renumber(normalizer); 505 506 LLVM_DEBUG( 507 renumberedMap.print(dbgs() << "\nRecursive composition yields: ")); 508 509 auxiliaryExprs.push_back(renumberedMap.getResult(0)); 510 } else { 511 if (i < numDimsBeforeRewrite) { 512 // b. The mathematical composition of AffineMap composes dims. 513 auxiliaryExprs.push_back(renumberOneDim(t)); 514 } else { 515 // c. The mathematical composition of AffineMap concatenates symbols. 516 // We do the same for symbol operands. 517 concatenatedSymbols.push_back(t); 518 } 519 } 520 } 521 } 522 523 // Early exit if `map` is already composed. 524 if (auxiliaryExprs.empty()) { 525 affineMap = map; 526 return; 527 } 528 529 assert(concatenatedSymbols.size() >= map.getNumSymbols() && 530 "Unexpected number of concatenated symbols"); 531 auto numDims = dimValueToPosition.size(); 532 auto numSymbols = concatenatedSymbols.size() - map.getNumSymbols(); 533 auto auxiliaryMap = AffineMap::get(numDims, numSymbols, auxiliaryExprs); 534 535 LLVM_DEBUG(map.print(dbgs() << "\nCompose map: ")); 536 LLVM_DEBUG(auxiliaryMap.print(dbgs() << "\nWith map: ")); 537 LLVM_DEBUG(map.compose(auxiliaryMap).print(dbgs() << "\nResult: ")); 538 539 // TODO(andydavis,ntv): Disabling simplification results in major speed gains. 540 // Another option is to cache the results as it is expected a lot of redundant 541 // work is performed in practice. 542 affineMap = simplifyAffineMap(map.compose(auxiliaryMap)); 543 544 LLVM_DEBUG(affineMap.print(dbgs() << "\nSimplified result: ")); 545 LLVM_DEBUG(dbgs() << "\n"); 546 } 547 548 /// Implements `map` and `operands` composition and simplification to support 549 /// `makeComposedAffineApply`. This can be called to achieve the same effects 550 /// on `map` and `operands` without creating an AffineApplyOp that needs to be 551 /// immediately deleted. 552 static void composeAffineMapAndOperands(AffineMap *map, 553 SmallVectorImpl<Value *> *operands) { 554 AffineApplyNormalizer normalizer(*map, *operands); 555 auto normalizedMap = normalizer.getAffineMap(); 556 auto normalizedOperands = normalizer.getOperands(); 557 canonicalizeMapAndOperands(&normalizedMap, &normalizedOperands); 558 *map = normalizedMap; 559 *operands = normalizedOperands; 560 assert(*map); 561 } 562 563 void mlir::fullyComposeAffineMapAndOperands( 564 AffineMap *map, SmallVectorImpl<Value *> *operands) { 565 while (llvm::any_of(*operands, [](Value *v) { 566 return isa_and_nonnull<AffineApplyOp>(v->getDefiningOp()); 567 })) { 568 composeAffineMapAndOperands(map, operands); 569 } 570 } 571 572 AffineApplyOp mlir::makeComposedAffineApply(OpBuilder &b, Location loc, 573 AffineMap map, 574 ArrayRef<Value *> operands) { 575 AffineMap normalizedMap = map; 576 SmallVector<Value *, 8> normalizedOperands(operands.begin(), operands.end()); 577 composeAffineMapAndOperands(&normalizedMap, &normalizedOperands); 578 assert(normalizedMap); 579 return b.create<AffineApplyOp>(loc, normalizedMap, normalizedOperands); 580 } 581 582 // A symbol may appear as a dim in affine.apply operations. This function 583 // canonicalizes dims that are valid symbols into actual symbols. 584 static void 585 canonicalizePromotedSymbols(AffineMap *map, 586 llvm::SmallVectorImpl<Value *> *operands) { 587 if (!map || operands->empty()) 588 return; 589 590 assert(map->getNumInputs() == operands->size() && 591 "map inputs must match number of operands"); 592 593 auto *context = map->getContext(); 594 SmallVector<Value *, 8> resultOperands; 595 resultOperands.reserve(operands->size()); 596 SmallVector<Value *, 8> remappedSymbols; 597 remappedSymbols.reserve(operands->size()); 598 unsigned nextDim = 0; 599 unsigned nextSym = 0; 600 unsigned oldNumSyms = map->getNumSymbols(); 601 SmallVector<AffineExpr, 8> dimRemapping(map->getNumDims()); 602 for (unsigned i = 0, e = map->getNumInputs(); i != e; ++i) { 603 if (i < map->getNumDims()) { 604 if (isValidSymbol((*operands)[i])) { 605 // This is a valid symbol that appears as a dim, canonicalize it. 606 dimRemapping[i] = getAffineSymbolExpr(oldNumSyms + nextSym++, context); 607 remappedSymbols.push_back((*operands)[i]); 608 } else { 609 dimRemapping[i] = getAffineDimExpr(nextDim++, context); 610 resultOperands.push_back((*operands)[i]); 611 } 612 } else { 613 resultOperands.push_back((*operands)[i]); 614 } 615 } 616 617 resultOperands.append(remappedSymbols.begin(), remappedSymbols.end()); 618 *operands = resultOperands; 619 *map = map->replaceDimsAndSymbols(dimRemapping, {}, nextDim, 620 oldNumSyms + nextSym); 621 622 assert(map->getNumInputs() == operands->size() && 623 "map inputs must match number of operands"); 624 } 625 626 void mlir::canonicalizeMapAndOperands( 627 AffineMap *map, llvm::SmallVectorImpl<Value *> *operands) { 628 if (!map || operands->empty()) 629 return; 630 631 assert(map->getNumInputs() == operands->size() && 632 "map inputs must match number of operands"); 633 634 canonicalizePromotedSymbols(map, operands); 635 636 // Check to see what dims are used. 637 llvm::SmallBitVector usedDims(map->getNumDims()); 638 llvm::SmallBitVector usedSyms(map->getNumSymbols()); 639 map->walkExprs([&](AffineExpr expr) { 640 if (auto dimExpr = expr.dyn_cast<AffineDimExpr>()) 641 usedDims[dimExpr.getPosition()] = true; 642 else if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>()) 643 usedSyms[symExpr.getPosition()] = true; 644 }); 645 646 auto *context = map->getContext(); 647 648 SmallVector<Value *, 8> resultOperands; 649 resultOperands.reserve(operands->size()); 650 651 llvm::SmallDenseMap<Value *, AffineExpr, 8> seenDims; 652 SmallVector<AffineExpr, 8> dimRemapping(map->getNumDims()); 653 unsigned nextDim = 0; 654 for (unsigned i = 0, e = map->getNumDims(); i != e; ++i) { 655 if (usedDims[i]) { 656 // Remap dim positions for duplicate operands. 657 auto it = seenDims.find((*operands)[i]); 658 if (it == seenDims.end()) { 659 dimRemapping[i] = getAffineDimExpr(nextDim++, context); 660 resultOperands.push_back((*operands)[i]); 661 seenDims.insert(std::make_pair((*operands)[i], dimRemapping[i])); 662 } else { 663 dimRemapping[i] = it->second; 664 } 665 } 666 } 667 llvm::SmallDenseMap<Value *, AffineExpr, 8> seenSymbols; 668 SmallVector<AffineExpr, 8> symRemapping(map->getNumSymbols()); 669 unsigned nextSym = 0; 670 for (unsigned i = 0, e = map->getNumSymbols(); i != e; ++i) { 671 if (!usedSyms[i]) 672 continue; 673 // Handle constant operands (only needed for symbolic operands since 674 // constant operands in dimensional positions would have already been 675 // promoted to symbolic positions above). 676 IntegerAttr operandCst; 677 if (matchPattern((*operands)[i + map->getNumDims()], 678 m_Constant(&operandCst))) { 679 symRemapping[i] = 680 getAffineConstantExpr(operandCst.getValue().getSExtValue(), context); 681 continue; 682 } 683 // Remap symbol positions for duplicate operands. 684 auto it = seenSymbols.find((*operands)[i + map->getNumDims()]); 685 if (it == seenSymbols.end()) { 686 symRemapping[i] = getAffineSymbolExpr(nextSym++, context); 687 resultOperands.push_back((*operands)[i + map->getNumDims()]); 688 seenSymbols.insert( 689 std::make_pair((*operands)[i + map->getNumDims()], symRemapping[i])); 690 } else { 691 symRemapping[i] = it->second; 692 } 693 } 694 *map = 695 map->replaceDimsAndSymbols(dimRemapping, symRemapping, nextDim, nextSym); 696 *operands = resultOperands; 697 } 698 699 namespace { 700 /// Simplify AffineApply operations. 701 /// 702 struct SimplifyAffineApply : public OpRewritePattern<AffineApplyOp> { 703 using OpRewritePattern<AffineApplyOp>::OpRewritePattern; 704 705 PatternMatchResult matchAndRewrite(AffineApplyOp apply, 706 PatternRewriter &rewriter) const override { 707 auto map = apply.getAffineMap(); 708 709 AffineMap oldMap = map; 710 SmallVector<Value *, 8> resultOperands(apply.getOperands()); 711 composeAffineMapAndOperands(&map, &resultOperands); 712 if (map == oldMap) 713 return matchFailure(); 714 715 rewriter.replaceOpWithNewOp<AffineApplyOp>(apply, map, resultOperands); 716 return matchSuccess(); 717 } 718 }; 719 } // end anonymous namespace. 720 721 void AffineApplyOp::getCanonicalizationPatterns( 722 OwningRewritePatternList &results, MLIRContext *context) { 723 results.insert<SimplifyAffineApply>(context); 724 } 725 726 //===----------------------------------------------------------------------===// 727 // Common canonicalization pattern support logic 728 //===----------------------------------------------------------------------===// 729 730 namespace { 731 /// This is a common class used for patterns of the form 732 /// "someop(memrefcast) -> someop". It folds the source of any memref_cast 733 /// into the root operation directly. 734 struct MemRefCastFolder : public RewritePattern { 735 /// The rootOpName is the name of the root operation to match against. 736 MemRefCastFolder(StringRef rootOpName, MLIRContext *context) 737 : RewritePattern(rootOpName, 1, context) {} 738 739 PatternMatchResult match(Operation *op) const override { 740 for (auto *operand : op->getOperands()) 741 if (matchPattern(operand, m_Op<MemRefCastOp>())) 742 return matchSuccess(); 743 744 return matchFailure(); 745 } 746 747 void rewrite(Operation *op, PatternRewriter &rewriter) const override { 748 for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i) 749 if (auto *memref = op->getOperand(i)->getDefiningOp()) 750 if (auto cast = dyn_cast<MemRefCastOp>(memref)) 751 op->setOperand(i, cast.getOperand()); 752 rewriter.updatedRootInPlace(op); 753 } 754 }; 755 756 } // end anonymous namespace. 757 758 //===----------------------------------------------------------------------===// 759 // AffineDmaStartOp 760 //===----------------------------------------------------------------------===// 761 762 // TODO(b/133776335) Check that map operands are loop IVs or symbols. 763 void AffineDmaStartOp::build(Builder *builder, OperationState *result, 764 Value *srcMemRef, AffineMap srcMap, 765 ArrayRef<Value *> srcIndices, Value *destMemRef, 766 AffineMap dstMap, ArrayRef<Value *> destIndices, 767 Value *tagMemRef, AffineMap tagMap, 768 ArrayRef<Value *> tagIndices, Value *numElements, 769 Value *stride, Value *elementsPerStride) { 770 result->addOperands(srcMemRef); 771 result->addAttribute(getSrcMapAttrName(), builder->getAffineMapAttr(srcMap)); 772 result->addOperands(srcIndices); 773 result->addOperands(destMemRef); 774 result->addAttribute(getDstMapAttrName(), builder->getAffineMapAttr(dstMap)); 775 result->addOperands(destIndices); 776 result->addOperands(tagMemRef); 777 result->addAttribute(getTagMapAttrName(), builder->getAffineMapAttr(tagMap)); 778 result->addOperands(tagIndices); 779 result->addOperands(numElements); 780 if (stride) { 781 result->addOperands({stride, elementsPerStride}); 782 } 783 } 784 785 void AffineDmaStartOp::print(OpAsmPrinter *p) { 786 *p << "affine.dma_start " << *getSrcMemRef() << '['; 787 SmallVector<Value *, 8> operands(getSrcIndices()); 788 p->printAffineMapOfSSAIds(getSrcMapAttr(), operands); 789 *p << "], " << *getDstMemRef() << '['; 790 operands.assign(getDstIndices().begin(), getDstIndices().end()); 791 p->printAffineMapOfSSAIds(getDstMapAttr(), operands); 792 *p << "], " << *getTagMemRef() << '['; 793 operands.assign(getTagIndices().begin(), getTagIndices().end()); 794 p->printAffineMapOfSSAIds(getTagMapAttr(), operands); 795 *p << "], " << *getNumElements(); 796 if (isStrided()) { 797 *p << ", " << *getStride(); 798 *p << ", " << *getNumElementsPerStride(); 799 } 800 *p << " : " << getSrcMemRefType() << ", " << getDstMemRefType() << ", " 801 << getTagMemRefType(); 802 } 803 804 // Parse AffineDmaStartOp. 805 // Ex: 806 // affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%index], %size, 807 // %stride, %num_elt_per_stride 808 // : memref<3076 x f32, 0>, memref<1024 x f32, 2>, memref<1 x i32> 809 // 810 ParseResult AffineDmaStartOp::parse(OpAsmParser *parser, 811 OperationState *result) { 812 OpAsmParser::OperandType srcMemRefInfo; 813 AffineMapAttr srcMapAttr; 814 SmallVector<OpAsmParser::OperandType, 4> srcMapOperands; 815 OpAsmParser::OperandType dstMemRefInfo; 816 AffineMapAttr dstMapAttr; 817 SmallVector<OpAsmParser::OperandType, 4> dstMapOperands; 818 OpAsmParser::OperandType tagMemRefInfo; 819 AffineMapAttr tagMapAttr; 820 SmallVector<OpAsmParser::OperandType, 4> tagMapOperands; 821 OpAsmParser::OperandType numElementsInfo; 822 SmallVector<OpAsmParser::OperandType, 2> strideInfo; 823 824 SmallVector<Type, 3> types; 825 auto indexType = parser->getBuilder().getIndexType(); 826 827 // Parse and resolve the following list of operands: 828 // *) dst memref followed by its affine maps operands (in square brackets). 829 // *) src memref followed by its affine map operands (in square brackets). 830 // *) tag memref followed by its affine map operands (in square brackets). 831 // *) number of elements transferred by DMA operation. 832 if (parser->parseOperand(srcMemRefInfo) || 833 parser->parseAffineMapOfSSAIds(srcMapOperands, srcMapAttr, 834 getSrcMapAttrName(), result->attributes) || 835 parser->parseComma() || parser->parseOperand(dstMemRefInfo) || 836 parser->parseAffineMapOfSSAIds(dstMapOperands, dstMapAttr, 837 getDstMapAttrName(), result->attributes) || 838 parser->parseComma() || parser->parseOperand(tagMemRefInfo) || 839 parser->parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr, 840 getTagMapAttrName(), result->attributes) || 841 parser->parseComma() || parser->parseOperand(numElementsInfo)) 842 return failure(); 843 844 // Parse optional stride and elements per stride. 845 if (parser->parseTrailingOperandList(strideInfo)) { 846 return failure(); 847 } 848 if (!strideInfo.empty() && strideInfo.size() != 2) { 849 return parser->emitError(parser->getNameLoc(), 850 "expected two stride related operands"); 851 } 852 bool isStrided = strideInfo.size() == 2; 853 854 if (parser->parseColonTypeList(types)) 855 return failure(); 856 857 if (types.size() != 3) 858 return parser->emitError(parser->getNameLoc(), "expected three types"); 859 860 if (parser->resolveOperand(srcMemRefInfo, types[0], result->operands) || 861 parser->resolveOperands(srcMapOperands, indexType, result->operands) || 862 parser->resolveOperand(dstMemRefInfo, types[1], result->operands) || 863 parser->resolveOperands(dstMapOperands, indexType, result->operands) || 864 parser->resolveOperand(tagMemRefInfo, types[2], result->operands) || 865 parser->resolveOperands(tagMapOperands, indexType, result->operands) || 866 parser->resolveOperand(numElementsInfo, indexType, result->operands)) 867 return failure(); 868 869 if (isStrided) { 870 if (parser->resolveOperands(strideInfo, indexType, result->operands)) 871 return failure(); 872 } 873 874 // Check that src/dst/tag operand counts match their map.numInputs. 875 if (srcMapOperands.size() != srcMapAttr.getValue().getNumInputs() || 876 dstMapOperands.size() != dstMapAttr.getValue().getNumInputs() || 877 tagMapOperands.size() != tagMapAttr.getValue().getNumInputs()) 878 return parser->emitError(parser->getNameLoc(), 879 "memref operand count not equal to map.numInputs"); 880 return success(); 881 } 882 883 LogicalResult AffineDmaStartOp::verify() { 884 if (!getOperand(getSrcMemRefOperandIndex())->getType().isa<MemRefType>()) 885 return emitOpError("expected DMA source to be of memref type"); 886 if (!getOperand(getDstMemRefOperandIndex())->getType().isa<MemRefType>()) 887 return emitOpError("expected DMA destination to be of memref type"); 888 if (!getOperand(getTagMemRefOperandIndex())->getType().isa<MemRefType>()) 889 return emitOpError("expected DMA tag to be of memref type"); 890 891 // DMAs from different memory spaces supported. 892 if (getSrcMemorySpace() == getDstMemorySpace()) { 893 return emitOpError("DMA should be between different memory spaces"); 894 } 895 unsigned numInputsAllMaps = getSrcMap().getNumInputs() + 896 getDstMap().getNumInputs() + 897 getTagMap().getNumInputs(); 898 if (getNumOperands() != numInputsAllMaps + 3 + 1 && 899 getNumOperands() != numInputsAllMaps + 3 + 1 + 2) { 900 return emitOpError("incorrect number of operands"); 901 } 902 903 for (auto *idx : getSrcIndices()) { 904 if (!idx->getType().isIndex()) 905 return emitOpError("src index to dma_start must have 'index' type"); 906 if (!isValidAffineIndexOperand(idx)) 907 return emitOpError("src index must be a dimension or symbol identifier"); 908 } 909 for (auto *idx : getDstIndices()) { 910 if (!idx->getType().isIndex()) 911 return emitOpError("dst index to dma_start must have 'index' type"); 912 if (!isValidAffineIndexOperand(idx)) 913 return emitOpError("dst index must be a dimension or symbol identifier"); 914 } 915 for (auto *idx : getTagIndices()) { 916 if (!idx->getType().isIndex()) 917 return emitOpError("tag index to dma_start must have 'index' type"); 918 if (!isValidAffineIndexOperand(idx)) 919 return emitOpError("tag index must be a dimension or symbol identifier"); 920 } 921 return success(); 922 } 923 924 void AffineDmaStartOp::getCanonicalizationPatterns( 925 OwningRewritePatternList &results, MLIRContext *context) { 926 /// dma_start(memrefcast) -> dma_start 927 results.insert<MemRefCastFolder>(getOperationName(), context); 928 } 929 930 //===----------------------------------------------------------------------===// 931 // AffineDmaWaitOp 932 //===----------------------------------------------------------------------===// 933 934 // TODO(b/133776335) Check that map operands are loop IVs or symbols. 935 void AffineDmaWaitOp::build(Builder *builder, OperationState *result, 936 Value *tagMemRef, AffineMap tagMap, 937 ArrayRef<Value *> tagIndices, Value *numElements) { 938 result->addOperands(tagMemRef); 939 result->addAttribute(getTagMapAttrName(), builder->getAffineMapAttr(tagMap)); 940 result->addOperands(tagIndices); 941 result->addOperands(numElements); 942 } 943 944 void AffineDmaWaitOp::print(OpAsmPrinter *p) { 945 *p << "affine.dma_wait " << *getTagMemRef() << '['; 946 SmallVector<Value *, 2> operands(getTagIndices()); 947 p->printAffineMapOfSSAIds(getTagMapAttr(), operands); 948 *p << "], "; 949 p->printOperand(getNumElements()); 950 *p << " : " << getTagMemRef()->getType(); 951 } 952 953 // Parse AffineDmaWaitOp. 954 // Eg: 955 // affine.dma_wait %tag[%index], %num_elements 956 // : memref<1 x i32, (d0) -> (d0), 4> 957 // 958 ParseResult AffineDmaWaitOp::parse(OpAsmParser *parser, 959 OperationState *result) { 960 OpAsmParser::OperandType tagMemRefInfo; 961 AffineMapAttr tagMapAttr; 962 SmallVector<OpAsmParser::OperandType, 2> tagMapOperands; 963 Type type; 964 auto indexType = parser->getBuilder().getIndexType(); 965 OpAsmParser::OperandType numElementsInfo; 966 967 // Parse tag memref, its map operands, and dma size. 968 if (parser->parseOperand(tagMemRefInfo) || 969 parser->parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr, 970 getTagMapAttrName(), result->attributes) || 971 parser->parseComma() || parser->parseOperand(numElementsInfo) || 972 parser->parseColonType(type) || 973 parser->resolveOperand(tagMemRefInfo, type, result->operands) || 974 parser->resolveOperands(tagMapOperands, indexType, result->operands) || 975 parser->resolveOperand(numElementsInfo, indexType, result->operands)) 976 return failure(); 977 978 if (!type.isa<MemRefType>()) 979 return parser->emitError(parser->getNameLoc(), 980 "expected tag to be of memref type"); 981 982 if (tagMapOperands.size() != tagMapAttr.getValue().getNumInputs()) 983 return parser->emitError(parser->getNameLoc(), 984 "tag memref operand count != to map.numInputs"); 985 return success(); 986 } 987 988 LogicalResult AffineDmaWaitOp::verify() { 989 if (!getOperand(0)->getType().isa<MemRefType>()) 990 return emitOpError("expected DMA tag to be of memref type"); 991 for (auto *idx : getTagIndices()) { 992 if (!idx->getType().isIndex()) 993 return emitOpError("index to dma_wait must have 'index' type"); 994 if (!isValidAffineIndexOperand(idx)) 995 return emitOpError("index must be a dimension or symbol identifier"); 996 } 997 return success(); 998 } 999 1000 void AffineDmaWaitOp::getCanonicalizationPatterns( 1001 OwningRewritePatternList &results, MLIRContext *context) { 1002 /// dma_wait(memrefcast) -> dma_wait 1003 results.insert<MemRefCastFolder>(getOperationName(), context); 1004 } 1005 1006 //===----------------------------------------------------------------------===// 1007 // AffineForOp 1008 //===----------------------------------------------------------------------===// 1009 1010 void AffineForOp::build(Builder *builder, OperationState *result, 1011 ArrayRef<Value *> lbOperands, AffineMap lbMap, 1012 ArrayRef<Value *> ubOperands, AffineMap ubMap, 1013 int64_t step) { 1014 assert(((!lbMap && lbOperands.empty()) || 1015 lbOperands.size() == lbMap.getNumInputs()) && 1016 "lower bound operand count does not match the affine map"); 1017 assert(((!ubMap && ubOperands.empty()) || 1018 ubOperands.size() == ubMap.getNumInputs()) && 1019 "upper bound operand count does not match the affine map"); 1020 assert(step > 0 && "step has to be a positive integer constant"); 1021 1022 // Add an attribute for the step. 1023 result->addAttribute(getStepAttrName(), 1024 builder->getIntegerAttr(builder->getIndexType(), step)); 1025 1026 // Add the lower bound. 1027 result->addAttribute(getLowerBoundAttrName(), 1028 builder->getAffineMapAttr(lbMap)); 1029 result->addOperands(lbOperands); 1030 1031 // Add the upper bound. 1032 result->addAttribute(getUpperBoundAttrName(), 1033 builder->getAffineMapAttr(ubMap)); 1034 result->addOperands(ubOperands); 1035 1036 // Create a region and a block for the body. The argument of the region is 1037 // the loop induction variable. 1038 Region *bodyRegion = result->addRegion(); 1039 Block *body = new Block(); 1040 body->addArgument(IndexType::get(builder->getContext())); 1041 bodyRegion->push_back(body); 1042 ensureTerminator(*bodyRegion, *builder, result->location); 1043 1044 // Set the operands list as resizable so that we can freely modify the bounds. 1045 result->setOperandListToResizable(); 1046 } 1047 1048 void AffineForOp::build(Builder *builder, OperationState *result, int64_t lb, 1049 int64_t ub, int64_t step) { 1050 auto lbMap = AffineMap::getConstantMap(lb, builder->getContext()); 1051 auto ubMap = AffineMap::getConstantMap(ub, builder->getContext()); 1052 return build(builder, result, {}, lbMap, {}, ubMap, step); 1053 } 1054 1055 static LogicalResult verify(AffineForOp op) { 1056 // Check that the body defines as single block argument for the induction 1057 // variable. 1058 auto *body = op.getBody(); 1059 if (body->getNumArguments() != 1 || 1060 !body->getArgument(0)->getType().isIndex()) 1061 return op.emitOpError( 1062 "expected body to have a single index argument for the " 1063 "induction variable"); 1064 1065 // Verify that there are enough operands for the bounds. 1066 AffineMap lowerBoundMap = op.getLowerBoundMap(), 1067 upperBoundMap = op.getUpperBoundMap(); 1068 if (op.getNumOperands() != 1069 (lowerBoundMap.getNumInputs() + upperBoundMap.getNumInputs())) 1070 return op.emitOpError( 1071 "operand count must match with affine map dimension and symbol count"); 1072 1073 // Verify that the bound operands are valid dimension/symbols. 1074 /// Lower bound. 1075 if (failed(verifyDimAndSymbolIdentifiers(op, op.getLowerBoundOperands(), 1076 op.getLowerBoundMap().getNumDims()))) 1077 return failure(); 1078 /// Upper bound. 1079 if (failed(verifyDimAndSymbolIdentifiers(op, op.getUpperBoundOperands(), 1080 op.getUpperBoundMap().getNumDims()))) 1081 return failure(); 1082 return success(); 1083 } 1084 1085 /// Parse a for operation loop bounds. 1086 static ParseResult parseBound(bool isLower, OperationState *result, 1087 OpAsmParser *p) { 1088 // 'min' / 'max' prefixes are generally syntactic sugar, but are required if 1089 // the map has multiple results. 1090 bool failedToParsedMinMax = 1091 failed(p->parseOptionalKeyword(isLower ? "max" : "min")); 1092 1093 auto &builder = p->getBuilder(); 1094 auto boundAttrName = isLower ? AffineForOp::getLowerBoundAttrName() 1095 : AffineForOp::getUpperBoundAttrName(); 1096 1097 // Parse ssa-id as identity map. 1098 SmallVector<OpAsmParser::OperandType, 1> boundOpInfos; 1099 if (p->parseOperandList(boundOpInfos)) 1100 return failure(); 1101 1102 if (!boundOpInfos.empty()) { 1103 // Check that only one operand was parsed. 1104 if (boundOpInfos.size() > 1) 1105 return p->emitError(p->getNameLoc(), 1106 "expected only one loop bound operand"); 1107 1108 // TODO: improve error message when SSA value is not an affine integer. 1109 // Currently it is 'use of value ... expects different type than prior uses' 1110 if (p->resolveOperand(boundOpInfos.front(), builder.getIndexType(), 1111 result->operands)) 1112 return failure(); 1113 1114 // Create an identity map using symbol id. This representation is optimized 1115 // for storage. Analysis passes may expand it into a multi-dimensional map 1116 // if desired. 1117 AffineMap map = builder.getSymbolIdentityMap(); 1118 result->addAttribute(boundAttrName, builder.getAffineMapAttr(map)); 1119 return success(); 1120 } 1121 1122 // Get the attribute location. 1123 llvm::SMLoc attrLoc = p->getCurrentLocation(); 1124 1125 Attribute boundAttr; 1126 if (p->parseAttribute(boundAttr, builder.getIndexType(), boundAttrName, 1127 result->attributes)) 1128 return failure(); 1129 1130 // Parse full form - affine map followed by dim and symbol list. 1131 if (auto affineMapAttr = boundAttr.dyn_cast<AffineMapAttr>()) { 1132 unsigned currentNumOperands = result->operands.size(); 1133 unsigned numDims; 1134 if (parseDimAndSymbolList(p, result->operands, numDims)) 1135 return failure(); 1136 1137 auto map = affineMapAttr.getValue(); 1138 if (map.getNumDims() != numDims) 1139 return p->emitError( 1140 p->getNameLoc(), 1141 "dim operand count and integer set dim count must match"); 1142 1143 unsigned numDimAndSymbolOperands = 1144 result->operands.size() - currentNumOperands; 1145 if (numDims + map.getNumSymbols() != numDimAndSymbolOperands) 1146 return p->emitError( 1147 p->getNameLoc(), 1148 "symbol operand count and integer set symbol count must match"); 1149 1150 // If the map has multiple results, make sure that we parsed the min/max 1151 // prefix. 1152 if (map.getNumResults() > 1 && failedToParsedMinMax) { 1153 if (isLower) { 1154 return p->emitError(attrLoc, "lower loop bound affine map with " 1155 "multiple results requires 'max' prefix"); 1156 } 1157 return p->emitError(attrLoc, "upper loop bound affine map with multiple " 1158 "results requires 'min' prefix"); 1159 } 1160 return success(); 1161 } 1162 1163 // Parse custom assembly form. 1164 if (auto integerAttr = boundAttr.dyn_cast<IntegerAttr>()) { 1165 result->attributes.pop_back(); 1166 result->addAttribute( 1167 boundAttrName, builder.getAffineMapAttr( 1168 builder.getConstantAffineMap(integerAttr.getInt()))); 1169 return success(); 1170 } 1171 1172 return p->emitError( 1173 p->getNameLoc(), 1174 "expected valid affine map representation for loop bounds"); 1175 } 1176 1177 ParseResult parseAffineForOp(OpAsmParser *parser, OperationState *result) { 1178 auto &builder = parser->getBuilder(); 1179 OpAsmParser::OperandType inductionVariable; 1180 // Parse the induction variable followed by '='. 1181 if (parser->parseRegionArgument(inductionVariable) || parser->parseEqual()) 1182 return failure(); 1183 1184 // Parse loop bounds. 1185 if (parseBound(/*isLower=*/true, result, parser) || 1186 parser->parseKeyword("to", " between bounds") || 1187 parseBound(/*isLower=*/false, result, parser)) 1188 return failure(); 1189 1190 // Parse the optional loop step, we default to 1 if one is not present. 1191 if (parser->parseOptionalKeyword("step")) { 1192 result->addAttribute( 1193 AffineForOp::getStepAttrName(), 1194 builder.getIntegerAttr(builder.getIndexType(), /*value=*/1)); 1195 } else { 1196 llvm::SMLoc stepLoc = parser->getCurrentLocation(); 1197 IntegerAttr stepAttr; 1198 if (parser->parseAttribute(stepAttr, builder.getIndexType(), 1199 AffineForOp::getStepAttrName().data(), 1200 result->attributes)) 1201 return failure(); 1202 1203 if (stepAttr.getValue().getSExtValue() < 0) 1204 return parser->emitError( 1205 stepLoc, 1206 "expected step to be representable as a positive signed integer"); 1207 } 1208 1209 // Parse the body region. 1210 Region *body = result->addRegion(); 1211 if (parser->parseRegion(*body, inductionVariable, builder.getIndexType())) 1212 return failure(); 1213 1214 AffineForOp::ensureTerminator(*body, builder, result->location); 1215 1216 // Parse the optional attribute list. 1217 if (parser->parseOptionalAttributeDict(result->attributes)) 1218 return failure(); 1219 1220 // Set the operands list as resizable so that we can freely modify the bounds. 1221 result->setOperandListToResizable(); 1222 return success(); 1223 } 1224 1225 static void printBound(AffineMapAttr boundMap, 1226 Operation::operand_range boundOperands, 1227 const char *prefix, OpAsmPrinter *p) { 1228 AffineMap map = boundMap.getValue(); 1229 1230 // Check if this bound should be printed using custom assembly form. 1231 // The decision to restrict printing custom assembly form to trivial cases 1232 // comes from the will to roundtrip MLIR binary -> text -> binary in a 1233 // lossless way. 1234 // Therefore, custom assembly form parsing and printing is only supported for 1235 // zero-operand constant maps and single symbol operand identity maps. 1236 if (map.getNumResults() == 1) { 1237 AffineExpr expr = map.getResult(0); 1238 1239 // Print constant bound. 1240 if (map.getNumDims() == 0 && map.getNumSymbols() == 0) { 1241 if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) { 1242 *p << constExpr.getValue(); 1243 return; 1244 } 1245 } 1246 1247 // Print bound that consists of a single SSA symbol if the map is over a 1248 // single symbol. 1249 if (map.getNumDims() == 0 && map.getNumSymbols() == 1) { 1250 if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>()) { 1251 p->printOperand(*boundOperands.begin()); 1252 return; 1253 } 1254 } 1255 } else { 1256 // Map has multiple results. Print 'min' or 'max' prefix. 1257 *p << prefix << ' '; 1258 } 1259 1260 // Print the map and its operands. 1261 *p << boundMap; 1262 printDimAndSymbolList(boundOperands.begin(), boundOperands.end(), 1263 map.getNumDims(), p); 1264 } 1265 1266 void print(OpAsmPrinter *p, AffineForOp op) { 1267 *p << "affine.for "; 1268 p->printOperand(op.getBody()->getArgument(0)); 1269 *p << " = "; 1270 printBound(op.getLowerBoundMapAttr(), op.getLowerBoundOperands(), "max", p); 1271 *p << " to "; 1272 printBound(op.getUpperBoundMapAttr(), op.getUpperBoundOperands(), "min", p); 1273 1274 if (op.getStep() != 1) 1275 *p << " step " << op.getStep(); 1276 p->printRegion(op.region(), 1277 /*printEntryBlockArgs=*/false, 1278 /*printBlockTerminators=*/false); 1279 p->printOptionalAttrDict(op.getAttrs(), 1280 /*elidedAttrs=*/{op.getLowerBoundAttrName(), 1281 op.getUpperBoundAttrName(), 1282 op.getStepAttrName()}); 1283 } 1284 1285 namespace { 1286 /// This is a pattern to fold trivially empty loops. 1287 struct AffineForEmptyLoopFolder : public OpRewritePattern<AffineForOp> { 1288 using OpRewritePattern<AffineForOp>::OpRewritePattern; 1289 1290 PatternMatchResult matchAndRewrite(AffineForOp forOp, 1291 PatternRewriter &rewriter) const override { 1292 // Check that the body only contains a terminator. 1293 auto *body = forOp.getBody(); 1294 if (std::next(body->begin()) != body->end()) 1295 return matchFailure(); 1296 rewriter.replaceOp(forOp, llvm::None); 1297 return matchSuccess(); 1298 } 1299 }; 1300 1301 /// This is a pattern to fold constant loop bounds. 1302 struct AffineForLoopBoundFolder : public OpRewritePattern<AffineForOp> { 1303 using OpRewritePattern<AffineForOp>::OpRewritePattern; 1304 1305 PatternMatchResult matchAndRewrite(AffineForOp forOp, 1306 PatternRewriter &rewriter) const override { 1307 auto foldLowerOrUpperBound = [&forOp](bool lower) { 1308 // Check to see if each of the operands is the result of a constant. If 1309 // so, get the value. If not, ignore it. 1310 SmallVector<Attribute, 8> operandConstants; 1311 auto boundOperands = 1312 lower ? forOp.getLowerBoundOperands() : forOp.getUpperBoundOperands(); 1313 for (auto *operand : boundOperands) { 1314 Attribute operandCst; 1315 matchPattern(operand, m_Constant(&operandCst)); 1316 operandConstants.push_back(operandCst); 1317 } 1318 1319 AffineMap boundMap = 1320 lower ? forOp.getLowerBoundMap() : forOp.getUpperBoundMap(); 1321 assert(boundMap.getNumResults() >= 1 && 1322 "bound maps should have at least one result"); 1323 SmallVector<Attribute, 4> foldedResults; 1324 if (failed(boundMap.constantFold(operandConstants, foldedResults))) 1325 return failure(); 1326 1327 // Compute the max or min as applicable over the results. 1328 assert(!foldedResults.empty() && 1329 "bounds should have at least one result"); 1330 auto maxOrMin = foldedResults[0].cast<IntegerAttr>().getValue(); 1331 for (unsigned i = 1, e = foldedResults.size(); i < e; i++) { 1332 auto foldedResult = foldedResults[i].cast<IntegerAttr>().getValue(); 1333 maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult) 1334 : llvm::APIntOps::smin(maxOrMin, foldedResult); 1335 } 1336 lower ? forOp.setConstantLowerBound(maxOrMin.getSExtValue()) 1337 : forOp.setConstantUpperBound(maxOrMin.getSExtValue()); 1338 return success(); 1339 }; 1340 1341 // Try to fold the lower bound. 1342 bool folded = false; 1343 if (!forOp.hasConstantLowerBound()) 1344 folded |= succeeded(foldLowerOrUpperBound(/*lower=*/true)); 1345 1346 // Try to fold the upper bound. 1347 if (!forOp.hasConstantUpperBound()) 1348 folded |= succeeded(foldLowerOrUpperBound(/*lower=*/false)); 1349 1350 // If any of the bounds were folded we return success. 1351 if (!folded) 1352 return matchFailure(); 1353 rewriter.updatedRootInPlace(forOp); 1354 return matchSuccess(); 1355 } 1356 }; 1357 } // end anonymous namespace 1358 1359 void AffineForOp::getCanonicalizationPatterns(OwningRewritePatternList &results, 1360 MLIRContext *context) { 1361 results.insert<AffineForEmptyLoopFolder, AffineForLoopBoundFolder>(context); 1362 } 1363 1364 AffineBound AffineForOp::getLowerBound() { 1365 auto lbMap = getLowerBoundMap(); 1366 return AffineBound(AffineForOp(*this), 0, lbMap.getNumInputs(), lbMap); 1367 } 1368 1369 AffineBound AffineForOp::getUpperBound() { 1370 auto lbMap = getLowerBoundMap(); 1371 auto ubMap = getUpperBoundMap(); 1372 return AffineBound(AffineForOp(*this), lbMap.getNumInputs(), getNumOperands(), 1373 ubMap); 1374 } 1375 1376 void AffineForOp::setLowerBound(ArrayRef<Value *> lbOperands, AffineMap map) { 1377 assert(lbOperands.size() == map.getNumInputs()); 1378 assert(map.getNumResults() >= 1 && "bound map has at least one result"); 1379 1380 SmallVector<Value *, 4> newOperands(lbOperands.begin(), lbOperands.end()); 1381 1382 auto ubOperands = getUpperBoundOperands(); 1383 newOperands.append(ubOperands.begin(), ubOperands.end()); 1384 getOperation()->setOperands(newOperands); 1385 1386 setAttr(getLowerBoundAttrName(), AffineMapAttr::get(map)); 1387 } 1388 1389 void AffineForOp::setUpperBound(ArrayRef<Value *> ubOperands, AffineMap map) { 1390 assert(ubOperands.size() == map.getNumInputs()); 1391 assert(map.getNumResults() >= 1 && "bound map has at least one result"); 1392 1393 SmallVector<Value *, 4> newOperands(getLowerBoundOperands()); 1394 newOperands.append(ubOperands.begin(), ubOperands.end()); 1395 getOperation()->setOperands(newOperands); 1396 1397 setAttr(getUpperBoundAttrName(), AffineMapAttr::get(map)); 1398 } 1399 1400 void AffineForOp::setLowerBoundMap(AffineMap map) { 1401 auto lbMap = getLowerBoundMap(); 1402 assert(lbMap.getNumDims() == map.getNumDims() && 1403 lbMap.getNumSymbols() == map.getNumSymbols()); 1404 assert(map.getNumResults() >= 1 && "bound map has at least one result"); 1405 (void)lbMap; 1406 setAttr(getLowerBoundAttrName(), AffineMapAttr::get(map)); 1407 } 1408 1409 void AffineForOp::setUpperBoundMap(AffineMap map) { 1410 auto ubMap = getUpperBoundMap(); 1411 assert(ubMap.getNumDims() == map.getNumDims() && 1412 ubMap.getNumSymbols() == map.getNumSymbols()); 1413 assert(map.getNumResults() >= 1 && "bound map has at least one result"); 1414 (void)ubMap; 1415 setAttr(getUpperBoundAttrName(), AffineMapAttr::get(map)); 1416 } 1417 1418 bool AffineForOp::hasConstantLowerBound() { 1419 return getLowerBoundMap().isSingleConstant(); 1420 } 1421 1422 bool AffineForOp::hasConstantUpperBound() { 1423 return getUpperBoundMap().isSingleConstant(); 1424 } 1425 1426 int64_t AffineForOp::getConstantLowerBound() { 1427 return getLowerBoundMap().getSingleConstantResult(); 1428 } 1429 1430 int64_t AffineForOp::getConstantUpperBound() { 1431 return getUpperBoundMap().getSingleConstantResult(); 1432 } 1433 1434 void AffineForOp::setConstantLowerBound(int64_t value) { 1435 setLowerBound({}, AffineMap::getConstantMap(value, getContext())); 1436 } 1437 1438 void AffineForOp::setConstantUpperBound(int64_t value) { 1439 setUpperBound({}, AffineMap::getConstantMap(value, getContext())); 1440 } 1441 1442 AffineForOp::operand_range AffineForOp::getLowerBoundOperands() { 1443 return {operand_begin(), operand_begin() + getLowerBoundMap().getNumInputs()}; 1444 } 1445 1446 AffineForOp::operand_range AffineForOp::getUpperBoundOperands() { 1447 return {operand_begin() + getLowerBoundMap().getNumInputs(), operand_end()}; 1448 } 1449 1450 bool AffineForOp::matchingBoundOperandList() { 1451 auto lbMap = getLowerBoundMap(); 1452 auto ubMap = getUpperBoundMap(); 1453 if (lbMap.getNumDims() != ubMap.getNumDims() || 1454 lbMap.getNumSymbols() != ubMap.getNumSymbols()) 1455 return false; 1456 1457 unsigned numOperands = lbMap.getNumInputs(); 1458 for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) { 1459 // Compare Value *'s. 1460 if (getOperand(i) != getOperand(numOperands + i)) 1461 return false; 1462 } 1463 return true; 1464 } 1465 1466 /// Returns if the provided value is the induction variable of a AffineForOp. 1467 bool mlir::isForInductionVar(Value *val) { 1468 return getForInductionVarOwner(val) != AffineForOp(); 1469 } 1470 1471 /// Returns the loop parent of an induction variable. If the provided value is 1472 /// not an induction variable, then return nullptr. 1473 AffineForOp mlir::getForInductionVarOwner(Value *val) { 1474 auto *ivArg = dyn_cast<BlockArgument>(val); 1475 if (!ivArg || !ivArg->getOwner()) 1476 return AffineForOp(); 1477 auto *containingInst = ivArg->getOwner()->getParent()->getParentOp(); 1478 return dyn_cast<AffineForOp>(containingInst); 1479 } 1480 1481 /// Extracts the induction variables from a list of AffineForOps and returns 1482 /// them. 1483 void mlir::extractForInductionVars(ArrayRef<AffineForOp> forInsts, 1484 SmallVectorImpl<Value *> *ivs) { 1485 ivs->reserve(forInsts.size()); 1486 for (auto forInst : forInsts) 1487 ivs->push_back(forInst.getInductionVar()); 1488 } 1489 1490 //===----------------------------------------------------------------------===// 1491 // AffineIfOp 1492 //===----------------------------------------------------------------------===// 1493 1494 static LogicalResult verify(AffineIfOp op) { 1495 // Verify that we have a condition attribute. 1496 auto conditionAttr = 1497 op.getAttrOfType<IntegerSetAttr>(op.getConditionAttrName()); 1498 if (!conditionAttr) 1499 return op.emitOpError( 1500 "requires an integer set attribute named 'condition'"); 1501 1502 // Verify that there are enough operands for the condition. 1503 IntegerSet condition = conditionAttr.getValue(); 1504 if (op.getNumOperands() != condition.getNumOperands()) 1505 return op.emitOpError( 1506 "operand count and condition integer set dimension and " 1507 "symbol count must match"); 1508 1509 // Verify that the operands are valid dimension/symbols. 1510 if (failed(verifyDimAndSymbolIdentifiers( 1511 op, op.getOperation()->getNonSuccessorOperands(), 1512 condition.getNumDims()))) 1513 return failure(); 1514 1515 // Verify that the entry of each child region does not have arguments. 1516 for (auto ®ion : op.getOperation()->getRegions()) { 1517 for (auto &b : region) 1518 if (b.getNumArguments() != 0) 1519 return op.emitOpError( 1520 "requires that child entry blocks have no arguments"); 1521 } 1522 return success(); 1523 } 1524 1525 ParseResult parseAffineIfOp(OpAsmParser *parser, OperationState *result) { 1526 // Parse the condition attribute set. 1527 IntegerSetAttr conditionAttr; 1528 unsigned numDims; 1529 if (parser->parseAttribute(conditionAttr, AffineIfOp::getConditionAttrName(), 1530 result->attributes) || 1531 parseDimAndSymbolList(parser, result->operands, numDims)) 1532 return failure(); 1533 1534 // Verify the condition operands. 1535 auto set = conditionAttr.getValue(); 1536 if (set.getNumDims() != numDims) 1537 return parser->emitError( 1538 parser->getNameLoc(), 1539 "dim operand count and integer set dim count must match"); 1540 if (numDims + set.getNumSymbols() != result->operands.size()) 1541 return parser->emitError( 1542 parser->getNameLoc(), 1543 "symbol operand count and integer set symbol count must match"); 1544 1545 // Create the regions for 'then' and 'else'. The latter must be created even 1546 // if it remains empty for the validity of the operation. 1547 result->regions.reserve(2); 1548 Region *thenRegion = result->addRegion(); 1549 Region *elseRegion = result->addRegion(); 1550 1551 // Parse the 'then' region. 1552 if (parser->parseRegion(*thenRegion, {}, {})) 1553 return failure(); 1554 AffineIfOp::ensureTerminator(*thenRegion, parser->getBuilder(), 1555 result->location); 1556 1557 // If we find an 'else' keyword then parse the 'else' region. 1558 if (!parser->parseOptionalKeyword("else")) { 1559 if (parser->parseRegion(*elseRegion, {}, {})) 1560 return failure(); 1561 AffineIfOp::ensureTerminator(*elseRegion, parser->getBuilder(), 1562 result->location); 1563 } 1564 1565 // Parse the optional attribute list. 1566 if (parser->parseOptionalAttributeDict(result->attributes)) 1567 return failure(); 1568 1569 return success(); 1570 } 1571 1572 void print(OpAsmPrinter *p, AffineIfOp op) { 1573 auto conditionAttr = 1574 op.getAttrOfType<IntegerSetAttr>(op.getConditionAttrName()); 1575 *p << "affine.if " << conditionAttr; 1576 printDimAndSymbolList(op.operand_begin(), op.operand_end(), 1577 conditionAttr.getValue().getNumDims(), p); 1578 p->printRegion(op.thenRegion(), 1579 /*printEntryBlockArgs=*/false, 1580 /*printBlockTerminators=*/false); 1581 1582 // Print the 'else' regions if it has any blocks. 1583 auto &elseRegion = op.elseRegion(); 1584 if (!elseRegion.empty()) { 1585 *p << " else"; 1586 p->printRegion(elseRegion, 1587 /*printEntryBlockArgs=*/false, 1588 /*printBlockTerminators=*/false); 1589 } 1590 1591 // Print the attribute list. 1592 p->printOptionalAttrDict(op.getAttrs(), 1593 /*elidedAttrs=*/op.getConditionAttrName()); 1594 } 1595 1596 IntegerSet AffineIfOp::getIntegerSet() { 1597 return getAttrOfType<IntegerSetAttr>(getConditionAttrName()).getValue(); 1598 } 1599 void AffineIfOp::setIntegerSet(IntegerSet newSet) { 1600 setAttr(getConditionAttrName(), IntegerSetAttr::get(newSet)); 1601 } 1602 1603 //===----------------------------------------------------------------------===// 1604 // AffineLoadOp 1605 //===----------------------------------------------------------------------===// 1606 1607 void AffineLoadOp::build(Builder *builder, OperationState *result, 1608 AffineMap map, ArrayRef<Value *> operands) { 1609 result->addOperands(operands); 1610 if (map) 1611 result->addAttribute(getMapAttrName(), builder->getAffineMapAttr(map)); 1612 auto memrefType = operands[0]->getType().cast<MemRefType>(); 1613 result->types.push_back(memrefType.getElementType()); 1614 } 1615 1616 void AffineLoadOp::build(Builder *builder, OperationState *result, 1617 Value *memref, ArrayRef<Value *> indices) { 1618 result->addOperands(memref); 1619 result->addOperands(indices); 1620 auto memrefType = memref->getType().cast<MemRefType>(); 1621 auto rank = memrefType.getRank(); 1622 // Create identity map for memrefs with at least one dimension or () -> () 1623 // for zero-dimensional memrefs. 1624 auto map = rank ? builder->getMultiDimIdentityMap(rank) 1625 : builder->getEmptyAffineMap(); 1626 result->addAttribute(getMapAttrName(), builder->getAffineMapAttr(map)); 1627 result->types.push_back(memrefType.getElementType()); 1628 } 1629 1630 ParseResult AffineLoadOp::parse(OpAsmParser *parser, OperationState *result) { 1631 auto &builder = parser->getBuilder(); 1632 auto affineIntTy = builder.getIndexType(); 1633 1634 MemRefType type; 1635 OpAsmParser::OperandType memrefInfo; 1636 AffineMapAttr mapAttr; 1637 SmallVector<OpAsmParser::OperandType, 1> mapOperands; 1638 return failure( 1639 parser->parseOperand(memrefInfo) || 1640 parser->parseAffineMapOfSSAIds(mapOperands, mapAttr, getMapAttrName(), 1641 result->attributes) || 1642 parser->parseOptionalAttributeDict(result->attributes) || 1643 parser->parseColonType(type) || 1644 parser->resolveOperand(memrefInfo, type, result->operands) || 1645 parser->resolveOperands(mapOperands, affineIntTy, result->operands) || 1646 parser->addTypeToList(type.getElementType(), result->types)); 1647 } 1648 1649 void AffineLoadOp::print(OpAsmPrinter *p) { 1650 *p << "affine.load " << *getMemRef() << '['; 1651 AffineMapAttr mapAttr = getAttrOfType<AffineMapAttr>(getMapAttrName()); 1652 if (mapAttr) { 1653 SmallVector<Value *, 2> operands(getIndices()); 1654 p->printAffineMapOfSSAIds(mapAttr, operands); 1655 } 1656 *p << ']'; 1657 p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/{getMapAttrName()}); 1658 *p << " : " << getMemRefType(); 1659 } 1660 1661 LogicalResult AffineLoadOp::verify() { 1662 if (getType() != getMemRefType().getElementType()) 1663 return emitOpError("result type must match element type of memref"); 1664 1665 auto mapAttr = getAttrOfType<AffineMapAttr>(getMapAttrName()); 1666 if (mapAttr) { 1667 AffineMap map = getAttrOfType<AffineMapAttr>(getMapAttrName()).getValue(); 1668 if (map.getNumResults() != getMemRefType().getRank()) 1669 return emitOpError("affine.load affine map num results must equal" 1670 " memref rank"); 1671 if (map.getNumInputs() != getNumOperands() - 1) 1672 return emitOpError("expects as many subscripts as affine map inputs"); 1673 } else { 1674 if (getMemRefType().getRank() != getNumOperands() - 1) 1675 return emitOpError( 1676 "expects the number of subscripts to be equal to memref rank"); 1677 } 1678 1679 for (auto *idx : getIndices()) { 1680 if (!idx->getType().isIndex()) 1681 return emitOpError("index to load must have 'index' type"); 1682 if (!isValidAffineIndexOperand(idx)) 1683 return emitOpError("index must be a dimension or symbol identifier"); 1684 } 1685 return success(); 1686 } 1687 1688 void AffineLoadOp::getCanonicalizationPatterns( 1689 OwningRewritePatternList &results, MLIRContext *context) { 1690 /// load(memrefcast) -> load 1691 results.insert<MemRefCastFolder>(getOperationName(), context); 1692 } 1693 1694 //===----------------------------------------------------------------------===// 1695 // AffineStoreOp 1696 //===----------------------------------------------------------------------===// 1697 1698 void AffineStoreOp::build(Builder *builder, OperationState *result, 1699 Value *valueToStore, AffineMap map, 1700 ArrayRef<Value *> operands) { 1701 result->addOperands(valueToStore); 1702 result->addOperands(operands); 1703 if (map) 1704 result->addAttribute(getMapAttrName(), builder->getAffineMapAttr(map)); 1705 } 1706 1707 void AffineStoreOp::build(Builder *builder, OperationState *result, 1708 Value *valueToStore, Value *memref, 1709 ArrayRef<Value *> operands) { 1710 result->addOperands(valueToStore); 1711 result->addOperands(memref); 1712 result->addOperands(operands); 1713 auto memrefType = memref->getType().cast<MemRefType>(); 1714 auto rank = memrefType.getRank(); 1715 // Create identity map for memrefs with at least one dimension or () -> () 1716 // for zero-dimensional memrefs. 1717 auto map = rank ? builder->getMultiDimIdentityMap(rank) 1718 : builder->getEmptyAffineMap(); 1719 result->addAttribute(getMapAttrName(), builder->getAffineMapAttr(map)); 1720 } 1721 1722 ParseResult AffineStoreOp::parse(OpAsmParser *parser, OperationState *result) { 1723 auto affineIntTy = parser->getBuilder().getIndexType(); 1724 1725 MemRefType type; 1726 OpAsmParser::OperandType storeValueInfo; 1727 OpAsmParser::OperandType memrefInfo; 1728 AffineMapAttr mapAttr; 1729 SmallVector<OpAsmParser::OperandType, 1> mapOperands; 1730 return failure( 1731 parser->parseOperand(storeValueInfo) || parser->parseComma() || 1732 parser->parseOperand(memrefInfo) || 1733 parser->parseAffineMapOfSSAIds(mapOperands, mapAttr, getMapAttrName(), 1734 result->attributes) || 1735 parser->parseOptionalAttributeDict(result->attributes) || 1736 parser->parseColonType(type) || 1737 parser->resolveOperand(storeValueInfo, type.getElementType(), 1738 result->operands) || 1739 parser->resolveOperand(memrefInfo, type, result->operands) || 1740 parser->resolveOperands(mapOperands, affineIntTy, result->operands)); 1741 } 1742 1743 void AffineStoreOp::print(OpAsmPrinter *p) { 1744 *p << "affine.store " << *getValueToStore(); 1745 *p << ", " << *getMemRef() << '['; 1746 AffineMapAttr mapAttr = getAttrOfType<AffineMapAttr>(getMapAttrName()); 1747 if (mapAttr) { 1748 SmallVector<Value *, 2> operands(getIndices()); 1749 p->printAffineMapOfSSAIds(mapAttr, operands); 1750 } 1751 *p << ']'; 1752 p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/{getMapAttrName()}); 1753 *p << " : " << getMemRefType(); 1754 } 1755 1756 LogicalResult AffineStoreOp::verify() { 1757 // First operand must have same type as memref element type. 1758 if (getValueToStore()->getType() != getMemRefType().getElementType()) 1759 return emitOpError("first operand must have same type memref element type"); 1760 1761 auto mapAttr = getAttrOfType<AffineMapAttr>(getMapAttrName()); 1762 if (mapAttr) { 1763 AffineMap map = mapAttr.getValue(); 1764 if (map.getNumResults() != getMemRefType().getRank()) 1765 return emitOpError("affine.store affine map num results must equal" 1766 " memref rank"); 1767 if (map.getNumInputs() != getNumOperands() - 2) 1768 return emitOpError("expects as many subscripts as affine map inputs"); 1769 } else { 1770 if (getMemRefType().getRank() != getNumOperands() - 2) 1771 return emitOpError( 1772 "expects the number of subscripts to be equal to memref rank"); 1773 } 1774 1775 for (auto *idx : getIndices()) { 1776 if (!idx->getType().isIndex()) 1777 return emitOpError("index to store must have 'index' type"); 1778 if (!isValidAffineIndexOperand(idx)) 1779 return emitOpError("index must be a dimension or symbol identifier"); 1780 } 1781 return success(); 1782 } 1783 1784 void AffineStoreOp::getCanonicalizationPatterns( 1785 OwningRewritePatternList &results, MLIRContext *context) { 1786 /// load(memrefcast) -> load 1787 results.insert<MemRefCastFolder>(getOperationName(), context); 1788 } 1789 1790 #define GET_OP_CLASSES 1791 #include "mlir/Dialect/AffineOps/AffineOps.cpp.inc"