github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (about) 1 //===- LinalgOps.cpp - Implementation of the linalg operations ------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 // 18 // This file implements a the Linalg operations. 19 // 20 //===----------------------------------------------------------------------===// 21 22 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 23 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" 24 #include "mlir/Dialect/Linalg/Utils/Utils.h" 25 #include "mlir/Dialect/LoopOps/LoopOps.h" 26 #include "mlir/EDSC/Helpers.h" 27 #include "mlir/IR/AffineExpr.h" 28 #include "mlir/IR/AffineMap.h" 29 #include "mlir/IR/Builders.h" 30 #include "mlir/IR/Function.h" 31 #include "mlir/IR/Module.h" 32 #include "mlir/IR/OpImplementation.h" 33 #include "mlir/IR/PatternMatch.h" 34 #include "mlir/IR/StandardTypes.h" 35 #include "mlir/Support/LLVM.h" 36 #include "mlir/Support/STLExtras.h" 37 #include "mlir/Transforms/FoldUtils.h" 38 39 #include "llvm/ADT/StringSet.h" 40 #include "llvm/Support/MathExtras.h" 41 #include "llvm/Support/raw_ostream.h" 42 43 using namespace mlir; 44 using namespace mlir::edsc; 45 using namespace mlir::edsc::intrinsics; 46 using namespace mlir::linalg; 47 48 namespace { 49 /// Fold constant dimensions into an alloc operation. 50 struct SimplifyDimOp : public OpRewritePattern<linalg::DimOp> { 51 using OpRewritePattern<linalg::DimOp>::OpRewritePattern; 52 53 PatternMatchResult matchAndRewrite(linalg::DimOp dimOp, 54 PatternRewriter &rewriter) const override; 55 }; 56 } // end namespace 57 58 PatternMatchResult 59 SimplifyDimOp::matchAndRewrite(linalg::DimOp dimOp, 60 PatternRewriter &rewriter) const { 61 auto *viewProducingOp = dimOp.view()->getDefiningOp(); 62 auto subView = dyn_cast_or_null<SubViewOp>(viewProducingOp); 63 auto slice = dyn_cast_or_null<SliceOp>(viewProducingOp); 64 auto view = dyn_cast_or_null<ViewOp>(viewProducingOp); 65 assert(subView || slice || view); 66 67 unsigned dim = dimOp.getIndex(); 68 Value *min, *max, *step; 69 if (view) { 70 // Cannot traverse block arguments, fail. 71 if (isa<BlockArgument>(view.getRange(dim))) 72 return matchFailure(); 73 // Record min, max, step for further processing. 74 auto range = cast<RangeOp>(view.getRange(dim)->getDefiningOp()); 75 std::tie(min, max, step) = 76 std::make_tuple(range.min(), range.max(), range.step()); 77 } else if (subView) { 78 // Record min, max, step for further processing. 79 auto range = subView.getRange(dim); 80 std::tie(min, max, step) = 81 std::make_tuple(range.min, range.max, range.step); 82 } else { 83 // Taking the dim of a slice must take a range (since other dims have been 84 // rank-reduced). 85 auto *rangeValue = slice.getRanges()[dim]; 86 // Cannot traverse block arguments, fail. 87 if (isa<BlockArgument>(rangeValue)) 88 return matchFailure(); 89 auto range = cast<RangeOp>(rangeValue->getDefiningOp()); 90 // Record min, max, step for further processing. 91 std::tie(min, max, step) = 92 std::make_tuple(range.min(), range.max(), range.step()); 93 } 94 95 // Only support constant steps of 1 atm. 96 auto constant = dyn_cast_or_null<ConstantIndexOp>(step->getDefiningOp()); 97 if (!constant || constant.getValue() != 1) 98 return matchFailure(); 99 100 // Circumvent affine constraints: 101 // emit an affine_apply when possible, otherwise emit a `subi`. 102 bool validAffineMin = isValidDim(min) || isValidSymbol(min) || 103 isa_and_nonnull<ConstantIndexOp>(min->getDefiningOp()); 104 bool validAffineMax = isValidDim(max) || isValidSymbol(max) || 105 isa_and_nonnull<ConstantIndexOp>(max->getDefiningOp()); 106 107 OpBuilder b(dimOp); 108 ScopedContext scope(b, dimOp.getLoc()); 109 // Emit `subi`. 110 if (!validAffineMin || !validAffineMax) { 111 rewriter.replaceOp(dimOp, {subi(max, min)}, {dimOp.view()}); 112 return matchSuccess(); 113 } 114 115 // Emit affine_apply. 116 using edsc::op::operator-; 117 rewriter.replaceOp(dimOp, {ValueHandle(max) - ValueHandle(min)}, 118 {dimOp.view()}); 119 return matchSuccess(); 120 } 121 122 ///////////////////// Operations defined with Tablegen ///////////////////////// 123 // For such operations that do not correspond to library calls (i.e. defined in 124 // LinalgOps.td), we define an overloaded `print` function and a 125 // parse`className` function. 126 127 //===----------------------------------------------------------------------===// 128 // BufferAllocOp 129 //===----------------------------------------------------------------------===// 130 131 static void print(OpAsmPrinter *p, BufferAllocOp op) { 132 *p << op.getOperationName() << " "; 133 if (!llvm::empty(op.size())) 134 *p << *op.getOperand(0); 135 if (op.alignment().hasValue() && op.alignment()->getSExtValue() != 0) 136 p->printOptionalAttrDict(op.getAttrs()); 137 else 138 p->printOptionalAttrDict(op.getAttrs(), 139 BufferAllocOp::getAlignmentAttrName()); 140 *p << " : " << op.getBufferType(); 141 } 142 143 static ParseResult parseBufferAllocOp(OpAsmParser *parser, 144 OperationState *result) { 145 SmallVector<OpAsmParser::OperandType, 1> sizeInfo; 146 BufferType bufferType; 147 auto indexTy = parser->getBuilder().getIndexType(); 148 if (parser->parseOperandList(sizeInfo) || 149 parser->parseOptionalAttributeDict(result->attributes) || 150 parser->parseColonType(bufferType)) 151 return failure(); 152 if (sizeInfo.empty()) 153 return parser->addTypeToList(bufferType, result->types); 154 return failure(parser->resolveOperands(sizeInfo, indexTy, result->operands) || 155 parser->addTypeToList(bufferType, result->types)); 156 } 157 158 static LogicalResult verify(BufferAllocOp op) { 159 if (!op.getBufferType().hasConstantSize()) { 160 if (llvm::size(op.size()) != 1) 161 return op.emitOpError("expected one index operand"); 162 } else { // op.getBufferType().hasConstantSize() 163 if (!llvm::empty(op.size())) 164 return op.emitOpError("expected zero operand"); 165 if (op.getBufferType().getBufferSize().getValue() <= 0) 166 return op.emitOpError("expected nonnegative static buffer size"); 167 } 168 if (op.alignment().hasValue()) { 169 auto align = op.alignment().getValue(); 170 if (align.getSExtValue() < 0) 171 return op.emitOpError("expected positive alignment"); 172 if (!llvm::isPowerOf2_64(align.getZExtValue())) 173 return op.emitOpError("expected power of 2 alignment"); 174 } 175 if (!TensorType::isValidElementType(op.getElementType())) 176 return op.emitOpError("expected valid buffer element type"); 177 return success(); 178 } 179 180 //===----------------------------------------------------------------------===// 181 // BufferDeallocOp 182 //===----------------------------------------------------------------------===// 183 184 static void print(OpAsmPrinter *p, BufferDeallocOp op) { 185 *p << op.getOperationName() << " " << *op.buffer(); 186 p->printOptionalAttrDict(op.getAttrs()); 187 *p << " : " << op.getBufferType(); 188 } 189 190 static ParseResult parseBufferDeallocOp(OpAsmParser *parser, 191 OperationState *result) { 192 OpAsmParser::OperandType bufferInfo; 193 BufferType bufferType; 194 if (parser->parseOperand(bufferInfo) || 195 parser->parseOptionalAttributeDict(result->attributes) || 196 parser->parseColonType(bufferType)) 197 return failure(); 198 return parser->resolveOperands(bufferInfo, bufferType, result->operands); 199 } 200 201 //===----------------------------------------------------------------------===// 202 // BufferSizeOp 203 //===----------------------------------------------------------------------===// 204 205 static void print(OpAsmPrinter *p, BufferSizeOp op) { 206 *p << op.getOperationName() << " " << *op.buffer(); 207 p->printOptionalAttrDict(op.getAttrs()); 208 *p << " : " << op.buffer()->getType(); 209 } 210 211 static ParseResult parseBufferSizeOp(OpAsmParser *parser, 212 OperationState *result) { 213 OpAsmParser::OperandType op; 214 Type type; 215 return failure(parser->parseOperand(op) || 216 parser->parseOptionalAttributeDict(result->attributes) || 217 parser->parseColonType(type) || 218 parser->resolveOperand(op, type, result->operands) || 219 parser->addTypeToList(parser->getBuilder().getIndexType(), 220 result->types)); 221 } 222 223 //===----------------------------------------------------------------------===// 224 // DimOp 225 //===----------------------------------------------------------------------===// 226 void mlir::linalg::DimOp::getCanonicalizationPatterns( 227 OwningRewritePatternList &results, MLIRContext *context) { 228 results.insert<SimplifyDimOp>(context); 229 } 230 231 static void print(OpAsmPrinter *p, linalg::DimOp op) { 232 *p << op.getOperationName() << " " << *op.getOperand() << ", " 233 << op.getIndex(); 234 p->printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"index"}); 235 *p << " : " << op.getOperand()->getType(); 236 } 237 238 static ParseResult parseDimOp(OpAsmParser *parser, OperationState *result) { 239 OpAsmParser::OperandType operandInfo; 240 IntegerAttr indexAttr; 241 Type type; 242 Type indexType = parser->getBuilder().getIndexType(); 243 return failure(parser->parseOperand(operandInfo) || parser->parseComma() || 244 parser->parseAttribute(indexAttr, indexType, "index", 245 result->attributes) || 246 parser->parseOptionalAttributeDict(result->attributes) || 247 parser->parseColonType(type) || 248 parser->resolveOperand(operandInfo, type, result->operands) || 249 parser->addTypeToList(indexType, result->types)); 250 } 251 252 //===----------------------------------------------------------------------===// 253 // GenericOp 254 //===----------------------------------------------------------------------===// 255 256 static void print(OpAsmPrinter *p, GenericOp op) { 257 auto attrNames = op.linalgTraitAttrNames(); 258 llvm::StringSet<> linalgTraitAttrsSet; 259 linalgTraitAttrsSet.insert(attrNames.begin(), attrNames.end()); 260 SmallVector<NamedAttribute, 8> attrs; 261 for (auto attr : op.getAttrs()) { 262 if (linalgTraitAttrsSet.count(attr.first.strref()) > 0) 263 attrs.push_back(attr); 264 } 265 auto dictAttr = DictionaryAttr::get(attrs, op.getContext()); 266 *p << op.getOperationName() << " " << dictAttr << " "; 267 p->printOperands(op.getOperands()); 268 if (!op.region().empty()) 269 p->printRegion(op.region()); 270 p->printOptionalAttrDict(op.getAttrs(), attrNames); 271 *p << ": "; 272 interleaveComma(op.getOperandTypes(), *p); 273 } 274 275 static ParseResult parseGenericOp(OpAsmParser *parser, OperationState *result) { 276 SmallVector<OpAsmParser::OperandType, 8> operandsInfo, regionOperandsInfo; 277 DictionaryAttr dictAttr; 278 // Parse the core linalg traits that must check into a dictAttr. 279 // The name is unimportant as we will overwrite result->attributes. 280 // The core linalg traits must contain the information necessary to pass the 281 // verifier. 282 if (parser->parseAttribute(dictAttr, "_", result->attributes) || 283 parser->parseOperandList(operandsInfo)) 284 return failure(); 285 result->attributes.assign(dictAttr.getValue().begin(), 286 dictAttr.getValue().end()); 287 288 Region ®ion = *result->addRegion(); 289 SmallVector<Type, 8> operandTypes, regionTypes; 290 // Optional attributes may be added. 291 // Either Optional "fun" attribute or region must be specified. 292 if (!dictAttr.get("fun") && 293 parser->parseOptionalRegion(region, regionOperandsInfo, regionTypes)) 294 return failure(); 295 if (parser->parseOptionalAttributeDict(result->attributes) || 296 parser->parseColonTypeList(operandTypes)) 297 return failure(); 298 return parser->resolveOperands(operandsInfo, operandTypes, 299 parser->getCurrentLocation(), 300 result->operands); 301 } 302 303 static LogicalResult verify(GenericOp op) { 304 auto nInputViews = op.getNumInputs(); 305 auto nViews = op.getNumInputsAndOutputs(); 306 if (nViews != llvm::size(op.views())) 307 return op.emitError("op expected exactly ") << nViews << " view operands"; 308 309 auto ®ion = op.region(); 310 auto funOp = op.getFunction(); 311 auto funType = funOp ? funOp.getType() : FunctionType(); 312 if (!region.empty()) { 313 if (region.getBlocks().size() != 1) 314 return op.emitError("op expected region with 1 block"); 315 316 auto &block = region.getBlocks().front(); 317 if (block.getNumArguments() != nViews) 318 return op.emitError( 319 "op expected number of block arguments to match number of views"); 320 321 for (unsigned i = 0; i < nViews; ++i) { 322 auto viewType = op.getViewType(i); 323 if (viewType.getElementType() != block.getArgument(i)->getType()) 324 return op.emitError("op expected block argument ") 325 << i << " of the same type as elemental type of " 326 << ((i < nInputViews) ? "input " : "output ") 327 << "view: " << viewType; 328 } 329 } else { 330 if (!funOp || !funOp.getType()) 331 return op.emitError( 332 "op expected fun attribute to refer to a defined symbol"); 333 if (funType.getNumInputs() != nViews) 334 return op.emitError("op expected fun arguments to match number of views"); 335 if (funType.getNumResults() != op.getNumOutputs()) 336 return op.emitError( 337 "op expected fun results to match number of output views"); 338 } 339 340 auto nLoops = op.getNumLoops(); 341 SmallVector<AffineMap, 4> indexingMaps; 342 indexingMaps.reserve(op.indexing_maps().size()); 343 for (auto en : llvm::enumerate(op.indexing_maps())) { 344 auto idx = en.index(); 345 auto m = en.value().cast<AffineMapAttr>().getValue(); 346 indexingMaps.push_back(m); // Save reference to map for further checks. 347 auto view = (idx < nInputViews) ? op.getInputViewType(idx) 348 : op.getOutputViewType(idx - nInputViews); 349 350 if (m.getNumSymbols() != 0) 351 return op.emitError("op expected indexing_map #") 352 << idx << " to have no symbols"; 353 354 if (m.getNumDims() != nLoops) 355 return op.emitError("op expected indexing_map #") 356 << idx << " to have " << nLoops 357 << " dim(s) to match the number of loops"; 358 359 if (m.getNumResults() == 1 && view.getRank() == 0) { 360 auto cst = m.getResult(0).dyn_cast<AffineConstantExpr>(); 361 if (!cst || cst.getValue() != 0) 362 return op.emitError("op expected indexing_map #") 363 << idx << " to be 0 to match 0-D view: " << view; 364 } 365 366 if (m.getNumResults() != view.getRank()) 367 return op.emitError("op expected indexing_map #") 368 << idx << " results to match view rank: " << view; 369 370 if (funType) { 371 if (funType.getInput(idx) != view.getElementType()) 372 return op.emitError("op expected fun argument ") 373 << idx 374 << " to match view element type: " << view.getElementType(); 375 376 if (idx >= nInputViews) 377 if (funType.getResult(idx - nInputViews) != view.getElementType()) 378 return op.emitError("op expected fun result ") 379 << idx << " to match output view element type: " 380 << view.getElementType(); 381 } 382 } 383 384 auto concatMap = concatAffineMaps(indexingMaps); 385 auto aggregateMap = inversePermutation(concatMap); 386 if (!aggregateMap) 387 return op.emitError("op expected the concatenation of maps in indexing_map " 388 "to be invertible"); 389 390 return success(); 391 } 392 393 //===----------------------------------------------------------------------===// 394 // LoadOp 395 //===----------------------------------------------------------------------===// 396 397 static void print(OpAsmPrinter *p, linalg::LoadOp op) { 398 *p << op.getOperationName() << " " << *op.view() << '['; 399 p->printOperands(op.indices()); 400 *p << ']'; 401 p->printOptionalAttrDict(op.getAttrs()); 402 *p << " : " << op.getViewType(); 403 } 404 405 static ParseResult parseLoadOp(OpAsmParser *parser, OperationState *result) { 406 OpAsmParser::OperandType viewInfo; 407 SmallVector<OpAsmParser::OperandType, 4> indexInfo; 408 ViewType type; 409 410 auto affineIntTy = parser->getBuilder().getIndexType(); 411 return failure( 412 parser->parseOperand(viewInfo) || 413 parser->parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) || 414 parser->parseOptionalAttributeDict(result->attributes) || 415 parser->parseColonType(type) || 416 parser->resolveOperand(viewInfo, type, result->operands) || 417 parser->resolveOperands(indexInfo, affineIntTy, result->operands) || 418 parser->addTypeToList(type.getElementType(), result->types)); 419 } 420 421 static LogicalResult verify(linalg::LoadOp op) { 422 if (op.getRank() != llvm::size(op.indices())) 423 return op.emitOpError("expected ") 424 << op.getRank() << " indices, got " << llvm::size(op.indices()); 425 return success(); 426 } 427 428 //===----------------------------------------------------------------------===// 429 // RangeOp 430 //===----------------------------------------------------------------------===// 431 432 static void print(OpAsmPrinter *p, RangeOp op) { 433 *p << op.getOperationName() << " " << *op.min() << ":" << *op.max() << ":" 434 << *op.step(); 435 p->printOptionalAttrDict(op.getAttrs()); 436 *p << " : " << op.getResult()->getType(); 437 } 438 439 static ParseResult parseRangeOp(OpAsmParser *parser, OperationState *result) { 440 SmallVector<OpAsmParser::OperandType, 3> rangeInfo(3); 441 RangeType type; 442 auto affineIntTy = parser->getBuilder().getIndexType(); 443 return failure( 444 parser->parseOperand(rangeInfo[0]) || parser->parseColon() || 445 parser->parseOperand(rangeInfo[1]) || parser->parseColon() || 446 parser->parseOperand(rangeInfo[2]) || 447 parser->parseOptionalAttributeDict(result->attributes) || 448 parser->parseColonType(type) || 449 parser->resolveOperands(rangeInfo, affineIntTy, result->operands) || 450 parser->addTypeToList(type, result->types)); 451 } 452 453 //===----------------------------------------------------------------------===// 454 // SliceOp 455 //===----------------------------------------------------------------------===// 456 457 void mlir::linalg::SliceOp::build(Builder *b, OperationState *result, 458 Value *base, ArrayRef<Value *> indexings) { 459 result->addOperands(base); 460 result->addOperands(indexings); 461 462 ViewType viewType = base->getType().cast<ViewType>(); 463 unsigned rank = viewType.getRank(); 464 for (auto *i : indexings) 465 if (!i->getType().isa<RangeType>()) 466 rank--; 467 Type elementType = viewType.getElementType(); 468 result->addTypes({ViewType::get(b->getContext(), elementType, rank)}); 469 } 470 471 static void print(OpAsmPrinter *p, SliceOp op) { 472 *p << SliceOp::getOperationName() << " " << *op.view() << "["; 473 p->printOperands(op.indexings()); 474 *p << "] "; 475 p->printOptionalAttrDict(op.getAttrs()); 476 *p << " : " << op.getBaseViewType(); 477 for (auto indexing : op.indexings()) { 478 *p << ", " << indexing->getType(); 479 } 480 *p << ", " << op.getType(); 481 } 482 483 static ParseResult parseSliceOp(OpAsmParser *parser, OperationState *result) { 484 OpAsmParser::OperandType baseInfo; 485 SmallVector<OpAsmParser::OperandType, 8> operands; 486 SmallVector<Type, 8> types; 487 if (parser->parseOperand(baseInfo) || 488 parser->parseOperandList(operands, OpAsmParser::Delimiter::Square) || 489 parser->parseOptionalAttributeDict(result->attributes) || 490 parser->parseColonTypeList(types)) 491 return failure(); 492 493 if (types.size() < 2) 494 return parser->emitError(parser->getCurrentLocation(), 495 "expected at least input and result view types"); 496 497 ArrayRef<Type> indexingTypes = ArrayRef<Type>(types).drop_front().drop_back(); 498 return failure( 499 parser->resolveOperand(baseInfo, types.front(), result->operands) || 500 (!operands.empty() && 501 parser->resolveOperands(operands, indexingTypes, 502 operands.front().location, result->operands)) || 503 parser->addTypeToList(types.back(), result->types)); 504 } 505 506 static LogicalResult verify(SliceOp op) { 507 unsigned rank = op.getBaseViewRank(); 508 if (rank != llvm::size(op.indexings())) 509 return op.emitOpError("expected ") 510 << op.getRank() << " indexings, got " << llvm::size(op.indexings()); 511 unsigned index = 0; 512 for (auto indexing : op.indexings()) { 513 if (indexing->getType().isa<IndexType>()) 514 --rank; 515 ++index; 516 } 517 if (op.getRank() != rank) 518 return op.emitOpError() << "expected rank of the view(" << op.getRank() 519 << ") to be the number of ranges(" << rank << ")"; 520 return success(); 521 } 522 523 //===----------------------------------------------------------------------===// 524 // StoreOp 525 //===----------------------------------------------------------------------===// 526 527 static void print(OpAsmPrinter *p, linalg::StoreOp op) { 528 *p << op.getOperationName() << " " << *op.value(); 529 *p << ", " << *op.view() << '['; 530 p->printOperands(op.indices()); 531 *p << ']'; 532 p->printOptionalAttrDict(op.getAttrs()); 533 *p << " : " << op.getViewType(); 534 } 535 536 static ParseResult parseStoreOp(OpAsmParser *parser, OperationState *result) { 537 OpAsmParser::OperandType storeValueInfo; 538 OpAsmParser::OperandType viewInfo; 539 SmallVector<OpAsmParser::OperandType, 4> indexInfo; 540 ViewType viewType; 541 542 auto affineIntTy = parser->getBuilder().getIndexType(); 543 return failure( 544 parser->parseOperand(storeValueInfo) || parser->parseComma() || 545 parser->parseOperand(viewInfo) || 546 parser->parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) || 547 parser->parseOptionalAttributeDict(result->attributes) || 548 parser->parseColonType(viewType) || 549 parser->resolveOperand(storeValueInfo, viewType.getElementType(), 550 result->operands) || 551 parser->resolveOperand(viewInfo, viewType, result->operands) || 552 parser->resolveOperands(indexInfo, affineIntTy, result->operands)); 553 } 554 555 static LogicalResult verify(linalg::StoreOp op) { 556 if (op.value()->getType() != op.getViewType().getElementType()) 557 return op.emitOpError("expected value type to match view element type"); 558 if (op.getRank() != llvm::size(op.indices())) 559 return op.emitOpError("expected ") 560 << op.getRank() << " indices, got " << llvm::size(op.indices()); 561 return success(); 562 } 563 564 //===----------------------------------------------------------------------===// 565 // SubViewOp 566 //===----------------------------------------------------------------------===// 567 568 static void print(OpAsmPrinter *p, SubViewOp op) { 569 *p << op.getOperationName() << " " << *op.getOperand(0) << "["; 570 auto ranges = op.getRanges(); 571 interleaveComma(ranges, *p, [&p](const SubViewOp::Range &i) { 572 *p << *i.min << ", " << *i.max << ", " << *i.step; 573 }); 574 *p << "]"; 575 p->printOptionalAttrDict(op.getAttrs()); 576 *p << " : " << op.getViewType(); 577 } 578 579 static ParseResult parseSubViewOp(OpAsmParser *parser, OperationState *result) { 580 OpAsmParser::OperandType inputView, resultView; 581 Type viewType; 582 if (parser->parseOperand(inputView)) 583 return failure(); 584 585 SmallVector<OpAsmParser::OperandType, 12> ops; 586 // TODO(ntv) evolve parsing from 587 // linalg.subview %0[%1, %2, %3, %4, %5, %6] 588 // to something resembling 589 // linalg.subview %0[%1:%2:%3][%4:%5:%6] 590 if (parser->parseOperandList(ops, OpAsmParser::Delimiter::Square) || 591 parser->parseOptionalAttributeDict(result->attributes) || 592 parser->parseColonType(viewType)) 593 return failure(); 594 595 auto indexTy = parser->getBuilder().getIndexType(); 596 return failure( 597 parser->resolveOperand(inputView, viewType, result->operands) || 598 parser->resolveOperands(ops, indexTy, result->operands) || 599 parser->addTypeToList(viewType, result->types)); 600 } 601 602 //===----------------------------------------------------------------------===// 603 // TransposeOp 604 //===----------------------------------------------------------------------===// 605 void mlir::linalg::TransposeOp::build(Builder *b, OperationState *result, 606 Value *view, AffineMapAttr permutation, 607 ArrayRef<NamedAttribute> attrs) { 608 // TODO(ntv): once views have static dimensions, compute the permuted type. 609 build(b, result, view->getType(), view, attrs); 610 result->addAttribute(TransposeOp::getPermutationAttrName(), permutation); 611 } 612 613 static void print(OpAsmPrinter *p, TransposeOp op) { 614 *p << op.getOperationName() << " " << *op.view() << " " << op.permutation(); 615 p->printOptionalAttrDict(op.getAttrs(), 616 {TransposeOp::getPermutationAttrName()}); 617 *p << " : " << op.view()->getType(); 618 } 619 620 static ParseResult parseTransposeOp(OpAsmParser *parser, 621 OperationState *result) { 622 OpAsmParser::OperandType view; 623 AffineMapAttr permutation; 624 Type type; 625 return failure(parser->parseOperand(view) || 626 parser->parseAttribute(permutation, 627 TransposeOp::getPermutationAttrName(), 628 result->attributes) || 629 parser->parseOptionalAttributeDict(result->attributes) || 630 parser->parseColonType(type) || 631 parser->resolveOperand(view, type, result->operands) || 632 parser->addTypeToList(type, result->types)); 633 } 634 635 //===----------------------------------------------------------------------===// 636 // ViewOp 637 //===----------------------------------------------------------------------===// 638 void mlir::linalg::ViewOp::build(Builder *b, OperationState *result, 639 Value *buffer, ArrayRef<Value *> ranges, 640 Type resultType, 641 ArrayRef<NamedAttribute> attrs) { 642 if (!resultType) { 643 Type elementType = buffer->getType().cast<BufferType>().getElementType(); 644 resultType = ViewType::get(b->getContext(), elementType, ranges.size()); 645 } 646 build(b, result, resultType, buffer, ranges); 647 result->addAttributes(attrs); 648 } 649 650 static void print(OpAsmPrinter *p, ViewOp op) { 651 *p << op.getOperationName() << " " << *op.buffer() << "["; 652 interleaveComma(op.ranges(), *p, [&](Value *v) { *p << *v; }); 653 *p << "] "; 654 p->printOptionalAttrDict(op.getAttrs()); 655 *p << " : " << op.buffer()->getType() << " -> " << op.getType(); 656 } 657 658 static ParseResult parseViewOp(OpAsmParser *parser, OperationState *result) { 659 OpAsmParser::OperandType bufferInfo; 660 SmallVector<OpAsmParser::OperandType, 8> rangesInfo; 661 Type bType, vType; 662 if (parser->parseOperand(bufferInfo) || 663 parser->parseOperandList(rangesInfo, OpAsmParser::Delimiter::Square) || 664 parser->parseOptionalAttributeDict(result->attributes) || 665 parser->parseColon() || parser->parseType(bType) || 666 parser->parseArrow() || parser->parseType(vType)) { 667 return failure(); 668 } 669 670 ViewType viewType = vType.dyn_cast<ViewType>(); 671 if (!viewType) 672 return parser->emitError(parser->getNameLoc(), "expected view type"); 673 if (viewType.getRank() != rangesInfo.size()) 674 return parser->emitError(parser->getNameLoc(), "expected ") 675 << viewType.getRank() << " ranges"; 676 return failure( 677 parser->resolveOperand(bufferInfo, bType, result->operands) || 678 (!rangesInfo.empty() && 679 parser->resolveOperands(rangesInfo, RangeType::get(vType.getContext()), 680 result->operands)) || 681 parser->addTypeToList(viewType, result->types)); 682 } 683 684 //===----------------------------------------------------------------------===// 685 // YieldOp 686 //===----------------------------------------------------------------------===// 687 688 static void print(OpAsmPrinter *p, YieldOp op) { 689 *p << op.getOperationName(); 690 if (op.getNumOperands() > 0) { 691 *p << ' '; 692 p->printOperands(op.operand_begin(), op.operand_end()); 693 } 694 p->printOptionalAttrDict(op.getAttrs()); 695 if (op.getNumOperands() > 0) { 696 *p << " : "; 697 interleaveComma(op.getOperands(), *p, 698 [&](Value *e) { p->printType(e->getType()); }); 699 } 700 } 701 702 static ParseResult parseYieldOp(OpAsmParser *parser, OperationState *result) { 703 SmallVector<OpAsmParser::OperandType, 2> opInfo; 704 SmallVector<Type, 2> types; 705 llvm::SMLoc loc = parser->getCurrentLocation(); 706 return failure(parser->parseOperandList(opInfo) || 707 parser->parseOptionalAttributeDict(result->attributes) || 708 (!opInfo.empty() && parser->parseColonTypeList(types)) || 709 parser->resolveOperands(opInfo, types, loc, result->operands)); 710 } 711 712 static LogicalResult verify(YieldOp op) { 713 auto *parentOp = op.getParentOp(); 714 if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty()) 715 return op.emitOpError("op expected single non-empty parent region"); 716 717 auto genericOp = dyn_cast<GenericOp>(parentOp); 718 if (!genericOp) 719 return op.emitOpError("op expected '") 720 << GenericOp::getOperationName() << "' parent op"; 721 722 // The operand number and types must match the view element types. 723 auto nOutputViews = genericOp.getNumOutputs(); 724 if (op.getNumOperands() != nOutputViews) 725 return op.emitOpError("op expected ") 726 << nOutputViews << " operand to match enclosing linalg.generic op"; 727 728 for (unsigned i = 0; i != nOutputViews; ++i) { 729 auto elementType = genericOp.getOutputViewType(i).getElementType(); 730 if (op.getOperand(i)->getType() != elementType) 731 return op.emitError("type of return operand ") 732 << i << " (" << op.getOperand(i)->getType() 733 << ") doesn't match view element type (" << elementType << ")"; 734 } 735 return success(); 736 } 737 738 /////// Operations corresponding to library calls defined with Tablegen //////// 739 // For such operations correspond to library calls (i.e. defined in 740 // LinalgLibraryOps.td), we define an overloaded `print` function and a 741 // parse`className` function. 742 743 // A LinalgLibraryOp prints as: 744 // 745 // ```{.mlir} 746 // concrete_op_name (ssa-inputs, ssa-outputs) : view-types 747 // ``` 748 // 749 // for example: 750 // 751 // ``` 752 // linalg.matmul(%0, %1, %2) : 753 // !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32> 754 // ``` 755 // 756 // Where %0, %1 and %2 are ssa-values of type ViewType. 757 static void printLinalgLibraryOp(OpAsmPrinter *p, Operation *op) { 758 assert(op->getAbstractOperation() && "unregistered operation"); 759 *p << op->getName().getStringRef() << "("; 760 interleave( 761 op->getOperands().begin(), op->getOperands().end(), 762 [&](Value *v) { *p << *v; }, [&]() { *p << ", "; }); 763 *p << ")"; 764 p->printOptionalAttrDict(op->getAttrs()); 765 *p << " : "; 766 interleave( 767 op->getOperands().begin(), op->getOperands().end(), 768 [&](Value *v) { *p << v->getType(); }, [&]() { *p << ", "; }); 769 } 770 771 static ParseResult parseLinalgLibraryOp(OpAsmParser *parser, 772 OperationState *result) { 773 SmallVector<OpAsmParser::OperandType, 3> ops; 774 SmallVector<Type, 3> types; 775 return failure(parser->parseOperandList(ops, OpAsmParser::Delimiter::Paren) || 776 parser->parseOptionalAttributeDict(result->attributes) || 777 parser->parseColonTypeList(types) || 778 parser->resolveOperands(ops, types, parser->getNameLoc(), 779 result->operands)); 780 } 781 782 static LogicalResult verify(FillOp op) { 783 auto viewType = op.getOutputViewType(0); 784 auto fillType = op.getValue()->getType(); 785 if (viewType.getElementType() != fillType) 786 return op.emitOpError("expects fill type to match view elemental type"); 787 return success(); 788 } 789 790 static LogicalResult verify(CopyOp op) { 791 auto outputViewType = op.getOutputViewType(0); 792 auto inputViewType = op.getInputViewType(0); 793 if (inputViewType.getElementType() != outputViewType.getElementType()) 794 return op.emitOpError("expects views of the same type"); 795 if (inputViewType.getRank() != outputViewType.getRank()) 796 return op.emitOpError("expects views of the same rank"); 797 auto rank = op.getNumParallelLoops(); 798 auto inputPermutationMap = op.inputPermutation(); 799 if (inputPermutationMap) { 800 if (inputPermutationMap->getNumInputs() != rank) 801 return op.emitOpError("expects optional input_permutation map of rank ") 802 << rank; 803 if (!inputPermutationMap->isPermutation()) 804 return op.emitOpError( 805 "expects optional input_permutation map to be a permutation"); 806 } 807 auto outputPermutationMap = op.outputPermutation(); 808 if (outputPermutationMap) { 809 if (outputPermutationMap->getNumInputs() != rank) 810 return op.emitOpError("expects optional output_permutation map of rank ") 811 << rank; 812 if (!outputPermutationMap->isPermutation()) 813 return op.emitOpError( 814 "expects optional output_permutation map to be a permutation"); 815 } 816 if (rank == 0 && inputPermutationMap) 817 return op.emitOpError("expected no input permutation when rank == 0"); 818 if (rank == 0 && outputPermutationMap) 819 return op.emitOpError("expected no output permutation when rank == 0"); 820 return success(); 821 } 822 823 static LogicalResult 824 verifyStrideOrDilation(ConvOp op, ArrayRef<Attribute> attrs, bool isStride) { 825 auto strideOrDilation = isStride ? "stride" : "dilation"; 826 if (attrs.size() != op.getNumWindowLoops()) 827 return op.emitOpError("expects num ") 828 << strideOrDilation 829 << "s equal to number of window dimensions: " << attrs.size() 830 << " vs " << op.getNumWindowLoops(); 831 return success(); 832 } 833 834 static LogicalResult verify(ConvOp op) { 835 auto oType = op.output()->getType().cast<ViewType>(); 836 auto fType = op.filter()->getType().cast<ViewType>(); 837 auto iType = op.input()->getType().cast<ViewType>(); 838 if (oType.getElementType() != iType.getElementType() || 839 oType.getElementType() != fType.getElementType()) 840 return op.emitOpError("expects view elemental types to match"); 841 if (oType.getRank() != iType.getRank() || oType.getRank() != fType.getRank()) 842 return op.emitOpError("expects view ranks to match"); 843 if (auto strides = op.strides()) { 844 if (failed( 845 verifyStrideOrDilation(op, strides->getValue(), /*isStride=*/true))) 846 return failure(); 847 } 848 if (auto dilations = op.dilations()) { 849 if (failed(verifyStrideOrDilation(op, dilations->getValue(), 850 /*isStride=*/false))) 851 return failure(); 852 } 853 return success(); 854 } 855 856 llvm::raw_ostream &mlir::linalg::operator<<(llvm::raw_ostream &os, 857 SubViewOp::Range &range) { 858 return os << "range " << *range.min << ":" << *range.max << ":" 859 << *range.step; 860 } 861 862 namespace mlir { 863 namespace linalg { 864 865 #include "mlir/Dialect/Linalg/IR/LinalgLibraryOpInterfaces.cpp.inc" 866 867 #define GET_OP_CLASSES 868 #include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc" 869 870 #define GET_OP_CLASSES 871 #include "mlir/Dialect/Linalg/IR/LinalgLibraryOps.cpp.inc" 872 873 } // namespace linalg 874 } // namespace mlir 875 876 static AffineMap extractOrIdentityMap(llvm::Optional<AffineMap> maybeMap, 877 unsigned rank, MLIRContext *context) { 878 if (maybeMap) 879 return maybeMap.getValue(); 880 if (rank == 0) 881 return AffineMap(); 882 return AffineMap::getMultiDimIdentityMap(rank, context); 883 } 884 885 // Returns `num` AffineDimExpr dimensions at positions [curIdx, curIdx + num) 886 // and increments `curIdx` to `curIdx + num`. 887 static SmallVector<AffineExpr, 4> 888 makeAffineDimExprs(unsigned num, unsigned &curIdx, MLIRContext *context) { 889 SmallVector<AffineExpr, 4> res; 890 res.reserve(num); 891 for (unsigned i = 0; i < num; ++i) 892 res.push_back(getAffineDimExpr(curIdx++, context)); 893 return res; 894 } 895 896 static SmallVector<AffineExpr, 4> 897 weightedConvInputIndex(ConvOp op, ArrayRef<AffineExpr> a, 898 ArrayRef<AffineExpr> b) { 899 assert(a.size() == b.size()); 900 SmallVector<AffineExpr, 4> res; 901 res.reserve(a.size()); 902 for (unsigned i = 0, e = a.size(); i < e; ++i) { 903 res.push_back(op.getStride(i) * a[i] + op.getDilation(i) * b[i]); 904 } 905 return res; 906 } 907 908 static SmallVector<AffineExpr, 4> concat(ArrayRef<AffineExpr> a, 909 ArrayRef<AffineExpr> b) { 910 SmallVector<AffineExpr, 4> res; 911 res.reserve(a.size() + b.size()); 912 res.assign(a.begin(), a.end()); 913 res.append(b.begin(), b.end()); 914 return res; 915 } 916 917 // Note: both functions below would completely disappear with a simple tensor 918 // kernel language. 919 // 920 // Ideally this should all be Tablegen'd but there is no good story for 921 // AffineMap for now. 922 SmallVector<AffineMap, 4> mlir::linalg::loopToOperandRangesMaps(Operation *op) { 923 MLIRContext *context = op->getContext(); 924 if (auto copyOp = dyn_cast<CopyOp>(op)) { 925 // I(input_perm(ivs)) -> O(output_perm(ivs)) 926 auto maybeInputMap = copyOp.inputPermutation(); 927 auto maybeOutputMap = copyOp.outputPermutation(); 928 unsigned inputRank = copyOp.getInputViewType(0).getRank(); 929 unsigned outputRank = copyOp.getOutputViewType(0).getRank(); 930 return SmallVector<AffineMap, 4>{ 931 extractOrIdentityMap(maybeInputMap, inputRank, context), 932 extractOrIdentityMap(maybeOutputMap, outputRank, context)}; 933 } 934 if (auto fillOp = dyn_cast<FillOp>(op)) { 935 // filling_value -> O(ivs) 936 unsigned rank = fillOp.getNumParallelLoops(); 937 return SmallVector<AffineMap, 4>{ 938 extractOrIdentityMap(llvm::None, rank, context)}; 939 } 940 auto i = getAffineDimExpr(0, context); 941 auto j = getAffineDimExpr(1, context); 942 auto k = getAffineDimExpr(2, context); 943 if (isa<DotOp>(op)) 944 // A(r_i) * B(r_i) -> C() 945 return SmallVector<AffineMap, 4>{AffineMap::get(1, 0, {i}), 946 AffineMap::get(1, 0, {i}), AffineMap()}; 947 if (isa<MatvecOp>(op)) 948 // A(i, r_j) * B(r_j) -> C(i) 949 return SmallVector<AffineMap, 4>{AffineMap::get(2, 0, {i, j}), 950 AffineMap::get(2, 0, {j}), 951 AffineMap::get(2, 0, {i})}; 952 if (isa<MatmulOp>(op)) 953 // A(i, r_k) * B(r_k, j) -> C(i, j) 954 return SmallVector<AffineMap, 4>{AffineMap::get(3, 0, {i, k}), 955 AffineMap::get(3, 0, {k, j}), 956 AffineMap::get(3, 0, {i, j})}; 957 if (auto convOp = dyn_cast<ConvOp>(op)) { 958 // F(z0, ..., zN-1, q, k) * I(b, x0 + z0, ..., xN-1 + zN-1, q) -> 959 // O(b, x0, ..., xN-1, k) 960 // for N equal to `nWindow`. 961 auto nWin = convOp.getNumWindowLoops(); 962 assert(nWin > 0 && "expected at least one window dimension"); 963 unsigned idx = 0; 964 // In the following, AffineDimExprs are indexed in loop order: 965 // [ b, xs, k, q, zs] 966 // parallels non-window reductions windows 967 // 968 // Parallel dims are exactly the dimensions indexing `output`: 969 // output[b, x[0], ..., x[N-1], k]; i.e. 970 // * batch dimensions (bs with #bs = 1 for now) 971 // * "image" dimensions (xs with #xs = #zs = output_rank - #bs - #ks) 972 // * output filter dimensions (ks with #ks = 1 for now) 973 auto bs = makeAffineDimExprs(convOp.getNumBatchDimensions(), idx, context); 974 auto xs = makeAffineDimExprs(nWin, idx, context); 975 auto ks = makeAffineDimExprs(convOp.getNumOutputFeatureDimensions(), idx, 976 context); 977 // Non-window reduction dim: sum_{z[0], ..., z[N-1], q} 978 auto qs = 979 makeAffineDimExprs(convOp.getNumInputFeatureDimensions(), idx, context); 980 // Window reduction dims: sum_{z[0], ..., z[N-1], q} 981 auto zs = makeAffineDimExprs(nWin, idx, context); 982 // Construct the weighedSum expression. 983 auto ws = weightedConvInputIndex(convOp, xs, zs); 984 return SmallVector<AffineMap, 4>{ 985 // filter[z[0], ..., z[N-1], q, k] 986 AffineMap::get(idx, 0, concat(concat(zs, qs), ks)), 987 // input[b, 988 // x[0]*s[0] + d[0]*z[0], ..., x[N-1]*s[N-1] + d[N-1]*z[N-1], 989 // q] 990 AffineMap::get(idx, 0, concat(concat(bs, ws), qs)), 991 // output[b, x[0], ..., x[N-1], k] 992 AffineMap::get(idx, 0, concat(concat(bs, xs), ks))}; 993 } else if (auto genericOp = dyn_cast<GenericOp>(op)) { 994 SmallVector<AffineMap, 4> res; 995 unsigned nViews = genericOp.getNumInputsAndOutputs(); 996 res.reserve(nViews); 997 for (unsigned i = 0, e = nViews; i < e; ++i) { 998 res.push_back(genericOp.getIndexingMap(i)); 999 } 1000 return res; 1001 } 1002 llvm_unreachable("Missing loopToOperandRangesMaps for op"); 1003 } 1004 1005 static void appendMangledType(llvm::raw_string_ostream &ss, Type t) { 1006 if (auto view = t.dyn_cast<ViewType>()) { 1007 ss << "view"; 1008 for (unsigned i = 0, e = view.getRank(); i < e; ++i) 1009 ss << "x"; 1010 appendMangledType(ss, view.getElementType()); 1011 } else if (auto vec = t.dyn_cast<VectorType>()) { 1012 ss << "vector"; 1013 interleave( 1014 vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; }); 1015 appendMangledType(ss, vec.getElementType()); 1016 } else if (t.isIntOrIndexOrFloat()) { 1017 ss << t; 1018 } else { 1019 llvm_unreachable("Invalid type for linalg library name mangling"); 1020 } 1021 } 1022 1023 std::string mlir::linalg::generateLibraryCallName(Operation *op) { 1024 assert(isa<LinalgOp>(op)); 1025 std::string name(op->getName().getStringRef().str()); 1026 name.reserve(128); 1027 std::replace(name.begin(), name.end(), '.', '_'); 1028 llvm::raw_string_ostream ss(name); 1029 ss << "_"; 1030 auto types = op->getOperandTypes(); 1031 interleave( 1032 types.begin(), types.end(), [&](Type t) { appendMangledType(ss, t); }, 1033 [&]() { ss << "_"; }); 1034 return ss.str(); 1035 }