github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (about)

     1  //===- LinalgOps.cpp - Implementation of the linalg operations ------------===//
     2  //
     3  // Copyright 2019 The MLIR Authors.
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //   http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  // =============================================================================
    17  //
    18  // This file implements the Linalg operations.
    19  //
    20  //===----------------------------------------------------------------------===//
    21  
    22  #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
    23  #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
    24  #include "mlir/Dialect/Linalg/Utils/Utils.h"
    25  #include "mlir/Dialect/LoopOps/LoopOps.h"
    26  #include "mlir/EDSC/Helpers.h"
    27  #include "mlir/IR/AffineExpr.h"
    28  #include "mlir/IR/AffineMap.h"
    29  #include "mlir/IR/Builders.h"
    30  #include "mlir/IR/Function.h"
    31  #include "mlir/IR/Module.h"
    32  #include "mlir/IR/OpImplementation.h"
    33  #include "mlir/IR/PatternMatch.h"
    34  #include "mlir/IR/StandardTypes.h"
    35  #include "mlir/Support/LLVM.h"
    36  #include "mlir/Support/STLExtras.h"
    37  #include "mlir/Transforms/FoldUtils.h"
    38  
    39  #include "llvm/ADT/StringSet.h"
    40  #include "llvm/Support/MathExtras.h"
    41  #include "llvm/Support/raw_ostream.h"
    42  
    43  using namespace mlir;
    44  using namespace mlir::edsc;
    45  using namespace mlir::edsc::intrinsics;
    46  using namespace mlir::linalg;
    47  
    48  namespace {
    49  /// Fold constant dimensions into an alloc operation.
    50  struct SimplifyDimOp : public OpRewritePattern<linalg::DimOp> {
    51    using OpRewritePattern<linalg::DimOp>::OpRewritePattern;
    52  
    53    PatternMatchResult matchAndRewrite(linalg::DimOp dimOp,
    54                                       PatternRewriter &rewriter) const override;
    55  };
    56  } // end namespace
    57  
    58  PatternMatchResult
    59  SimplifyDimOp::matchAndRewrite(linalg::DimOp dimOp,
    60                                 PatternRewriter &rewriter) const {
    61    auto *viewProducingOp = dimOp.view()->getDefiningOp();
    62    auto subView = dyn_cast_or_null<SubViewOp>(viewProducingOp);
    63    auto slice = dyn_cast_or_null<SliceOp>(viewProducingOp);
    64    auto view = dyn_cast_or_null<ViewOp>(viewProducingOp);
    65    assert(subView || slice || view);
    66  
    67    unsigned dim = dimOp.getIndex();
    68    Value *min, *max, *step;
    69    if (view) {
    70      // Cannot traverse block arguments, fail.
    71      if (isa<BlockArgument>(view.getRange(dim)))
    72        return matchFailure();
    73      // Record min, max, step for further processing.
    74      auto range = cast<RangeOp>(view.getRange(dim)->getDefiningOp());
    75      std::tie(min, max, step) =
    76          std::make_tuple(range.min(), range.max(), range.step());
    77    } else if (subView) {
    78      // Record min, max, step for further processing.
    79      auto range = subView.getRange(dim);
    80      std::tie(min, max, step) =
    81          std::make_tuple(range.min, range.max, range.step);
    82    } else {
    83      // Taking the dim of a slice must take a range (since other dims have been
    84      // rank-reduced).
    85      auto *rangeValue = slice.getRanges()[dim];
    86      // Cannot traverse block arguments, fail.
    87      if (isa<BlockArgument>(rangeValue))
    88        return matchFailure();
    89      auto range = cast<RangeOp>(rangeValue->getDefiningOp());
    90      // Record min, max, step for further processing.
    91      std::tie(min, max, step) =
    92          std::make_tuple(range.min(), range.max(), range.step());
    93    }
    94  
    95    // Only support constant steps of 1 atm.
    96    auto constant = dyn_cast_or_null<ConstantIndexOp>(step->getDefiningOp());
    97    if (!constant || constant.getValue() != 1)
    98      return matchFailure();
    99  
   100    // Circumvent affine constraints:
   101    //   emit an affine_apply when possible, otherwise emit a `subi`.
   102    bool validAffineMin = isValidDim(min) || isValidSymbol(min) ||
   103                          isa_and_nonnull<ConstantIndexOp>(min->getDefiningOp());
   104    bool validAffineMax = isValidDim(max) || isValidSymbol(max) ||
   105                          isa_and_nonnull<ConstantIndexOp>(max->getDefiningOp());
   106  
   107    OpBuilder b(dimOp);
   108    ScopedContext scope(b, dimOp.getLoc());
   109    // Emit `subi`.
   110    if (!validAffineMin || !validAffineMax) {
   111      rewriter.replaceOp(dimOp, {subi(max, min)}, {dimOp.view()});
   112      return matchSuccess();
   113    }
   114  
   115    // Emit affine_apply.
   116    using edsc::op::operator-;
   117    rewriter.replaceOp(dimOp, {ValueHandle(max) - ValueHandle(min)},
   118                       {dimOp.view()});
   119    return matchSuccess();
   120  }
   121  
   122  ///////////////////// Operations defined with Tablegen /////////////////////////
    123  // For operations that do not correspond to library calls (i.e. those defined
    124  // in LinalgOps.td), we define an overloaded `print` function and a
    125  // `parse<ClassName>` function.
   126  
   127  //===----------------------------------------------------------------------===//
   128  // BufferAllocOp
   129  //===----------------------------------------------------------------------===//
   130  
   131  static void print(OpAsmPrinter *p, BufferAllocOp op) {
   132    *p << op.getOperationName() << " ";
   133    if (!llvm::empty(op.size()))
   134      *p << *op.getOperand(0);
   135    if (op.alignment().hasValue() && op.alignment()->getSExtValue() != 0)
   136      p->printOptionalAttrDict(op.getAttrs());
   137    else
   138      p->printOptionalAttrDict(op.getAttrs(),
   139                               BufferAllocOp::getAlignmentAttrName());
   140    *p << " : " << op.getBufferType();
   141  }
   142  
   143  static ParseResult parseBufferAllocOp(OpAsmParser *parser,
   144                                        OperationState *result) {
   145    SmallVector<OpAsmParser::OperandType, 1> sizeInfo;
   146    BufferType bufferType;
   147    auto indexTy = parser->getBuilder().getIndexType();
   148    if (parser->parseOperandList(sizeInfo) ||
   149        parser->parseOptionalAttributeDict(result->attributes) ||
   150        parser->parseColonType(bufferType))
   151      return failure();
   152    if (sizeInfo.empty())
   153      return parser->addTypeToList(bufferType, result->types);
   154    return failure(parser->resolveOperands(sizeInfo, indexTy, result->operands) ||
   155                   parser->addTypeToList(bufferType, result->types));
   156  }
   157  
   158  static LogicalResult verify(BufferAllocOp op) {
   159    if (!op.getBufferType().hasConstantSize()) {
   160      if (llvm::size(op.size()) != 1)
   161        return op.emitOpError("expected one index operand");
   162    } else { // op.getBufferType().hasConstantSize()
   163      if (!llvm::empty(op.size()))
   164        return op.emitOpError("expected zero operand");
   165      if (op.getBufferType().getBufferSize().getValue() <= 0)
   166        return op.emitOpError("expected nonnegative static buffer size");
   167    }
   168    if (op.alignment().hasValue()) {
   169      auto align = op.alignment().getValue();
   170      if (align.getSExtValue() < 0)
   171        return op.emitOpError("expected positive alignment");
   172      if (!llvm::isPowerOf2_64(align.getZExtValue()))
   173        return op.emitOpError("expected power of 2 alignment");
   174    }
   175    if (!TensorType::isValidElementType(op.getElementType()))
   176      return op.emitOpError("expected valid buffer element type");
   177    return success();
   178  }
   179  
   180  //===----------------------------------------------------------------------===//
   181  // BufferDeallocOp
   182  //===----------------------------------------------------------------------===//
   183  
   184  static void print(OpAsmPrinter *p, BufferDeallocOp op) {
   185    *p << op.getOperationName() << " " << *op.buffer();
   186    p->printOptionalAttrDict(op.getAttrs());
   187    *p << " : " << op.getBufferType();
   188  }
   189  
   190  static ParseResult parseBufferDeallocOp(OpAsmParser *parser,
   191                                          OperationState *result) {
   192    OpAsmParser::OperandType bufferInfo;
   193    BufferType bufferType;
   194    if (parser->parseOperand(bufferInfo) ||
   195        parser->parseOptionalAttributeDict(result->attributes) ||
   196        parser->parseColonType(bufferType))
   197      return failure();
   198    return parser->resolveOperands(bufferInfo, bufferType, result->operands);
   199  }
   200  
   201  //===----------------------------------------------------------------------===//
   202  // BufferSizeOp
   203  //===----------------------------------------------------------------------===//
   204  
   205  static void print(OpAsmPrinter *p, BufferSizeOp op) {
   206    *p << op.getOperationName() << " " << *op.buffer();
   207    p->printOptionalAttrDict(op.getAttrs());
   208    *p << " : " << op.buffer()->getType();
   209  }
   210  
   211  static ParseResult parseBufferSizeOp(OpAsmParser *parser,
   212                                       OperationState *result) {
   213    OpAsmParser::OperandType op;
   214    Type type;
   215    return failure(parser->parseOperand(op) ||
   216                   parser->parseOptionalAttributeDict(result->attributes) ||
   217                   parser->parseColonType(type) ||
   218                   parser->resolveOperand(op, type, result->operands) ||
   219                   parser->addTypeToList(parser->getBuilder().getIndexType(),
   220                                         result->types));
   221  }
   222  
   223  //===----------------------------------------------------------------------===//
   224  // DimOp
   225  //===----------------------------------------------------------------------===//
   226  void mlir::linalg::DimOp::getCanonicalizationPatterns(
   227      OwningRewritePatternList &results, MLIRContext *context) {
   228    results.insert<SimplifyDimOp>(context);
   229  }
   230  
   231  static void print(OpAsmPrinter *p, linalg::DimOp op) {
   232    *p << op.getOperationName() << " " << *op.getOperand() << ", "
   233       << op.getIndex();
   234    p->printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"index"});
   235    *p << " : " << op.getOperand()->getType();
   236  }
   237  
   238  static ParseResult parseDimOp(OpAsmParser *parser, OperationState *result) {
   239    OpAsmParser::OperandType operandInfo;
   240    IntegerAttr indexAttr;
   241    Type type;
   242    Type indexType = parser->getBuilder().getIndexType();
   243    return failure(parser->parseOperand(operandInfo) || parser->parseComma() ||
   244                   parser->parseAttribute(indexAttr, indexType, "index",
   245                                          result->attributes) ||
   246                   parser->parseOptionalAttributeDict(result->attributes) ||
   247                   parser->parseColonType(type) ||
   248                   parser->resolveOperand(operandInfo, type, result->operands) ||
   249                   parser->addTypeToList(indexType, result->types));
   250  }
   251  
   252  //===----------------------------------------------------------------------===//
   253  // GenericOp
   254  //===----------------------------------------------------------------------===//
   255  
   256  static void print(OpAsmPrinter *p, GenericOp op) {
   257    auto attrNames = op.linalgTraitAttrNames();
   258    llvm::StringSet<> linalgTraitAttrsSet;
   259    linalgTraitAttrsSet.insert(attrNames.begin(), attrNames.end());
   260    SmallVector<NamedAttribute, 8> attrs;
   261    for (auto attr : op.getAttrs()) {
   262      if (linalgTraitAttrsSet.count(attr.first.strref()) > 0)
   263        attrs.push_back(attr);
   264    }
   265    auto dictAttr = DictionaryAttr::get(attrs, op.getContext());
   266    *p << op.getOperationName() << " " << dictAttr << " ";
   267    p->printOperands(op.getOperands());
   268    if (!op.region().empty())
   269      p->printRegion(op.region());
   270    p->printOptionalAttrDict(op.getAttrs(), attrNames);
   271    *p << ": ";
   272    interleaveComma(op.getOperandTypes(), *p);
   273  }
   274  
   275  static ParseResult parseGenericOp(OpAsmParser *parser, OperationState *result) {
   276    SmallVector<OpAsmParser::OperandType, 8> operandsInfo, regionOperandsInfo;
   277    DictionaryAttr dictAttr;
   278    // Parse the core linalg traits that must check into a dictAttr.
   279    // The name is unimportant as we will overwrite result->attributes.
   280    // The core linalg traits must contain the information necessary to pass the
   281    // verifier.
   282    if (parser->parseAttribute(dictAttr, "_", result->attributes) ||
   283        parser->parseOperandList(operandsInfo))
   284      return failure();
   285    result->attributes.assign(dictAttr.getValue().begin(),
   286                              dictAttr.getValue().end());
   287  
   288    Region &region = *result->addRegion();
   289    SmallVector<Type, 8> operandTypes, regionTypes;
   290    // Optional attributes may be added.
   291    // Either Optional "fun" attribute or region must be specified.
   292    if (!dictAttr.get("fun") &&
   293        parser->parseOptionalRegion(region, regionOperandsInfo, regionTypes))
   294      return failure();
   295    if (parser->parseOptionalAttributeDict(result->attributes) ||
   296        parser->parseColonTypeList(operandTypes))
   297      return failure();
   298    return parser->resolveOperands(operandsInfo, operandTypes,
   299                                   parser->getCurrentLocation(),
   300                                   result->operands);
   301  }
   302  
   303  static LogicalResult verify(GenericOp op) {
   304    auto nInputViews = op.getNumInputs();
   305    auto nViews = op.getNumInputsAndOutputs();
   306    if (nViews != llvm::size(op.views()))
   307      return op.emitError("op expected exactly ") << nViews << " view operands";
   308  
   309    auto &region = op.region();
   310    auto funOp = op.getFunction();
   311    auto funType = funOp ? funOp.getType() : FunctionType();
   312    if (!region.empty()) {
   313      if (region.getBlocks().size() != 1)
   314        return op.emitError("op expected region with 1 block");
   315  
   316      auto &block = region.getBlocks().front();
   317      if (block.getNumArguments() != nViews)
   318        return op.emitError(
   319            "op expected number of block arguments to match number of views");
   320  
   321      for (unsigned i = 0; i < nViews; ++i) {
   322        auto viewType = op.getViewType(i);
   323        if (viewType.getElementType() != block.getArgument(i)->getType())
   324          return op.emitError("op expected block argument ")
   325                 << i << " of the same type as elemental type of "
   326                 << ((i < nInputViews) ? "input " : "output ")
   327                 << "view: " << viewType;
   328      }
   329    } else {
   330      if (!funOp || !funOp.getType())
   331        return op.emitError(
   332            "op expected fun attribute to refer to a defined symbol");
   333      if (funType.getNumInputs() != nViews)
   334        return op.emitError("op expected fun arguments to match number of views");
   335      if (funType.getNumResults() != op.getNumOutputs())
   336        return op.emitError(
   337            "op expected fun results to match number of output views");
   338    }
   339  
   340    auto nLoops = op.getNumLoops();
   341    SmallVector<AffineMap, 4> indexingMaps;
   342    indexingMaps.reserve(op.indexing_maps().size());
   343    for (auto en : llvm::enumerate(op.indexing_maps())) {
   344      auto idx = en.index();
   345      auto m = en.value().cast<AffineMapAttr>().getValue();
   346      indexingMaps.push_back(m); // Save reference to map for further checks.
   347      auto view = (idx < nInputViews) ? op.getInputViewType(idx)
   348                                      : op.getOutputViewType(idx - nInputViews);
   349  
   350      if (m.getNumSymbols() != 0)
   351        return op.emitError("op expected indexing_map #")
   352               << idx << " to have no symbols";
   353  
   354      if (m.getNumDims() != nLoops)
   355        return op.emitError("op expected indexing_map #")
   356               << idx << " to have " << nLoops
   357               << " dim(s) to match the number of loops";
   358  
   359      if (m.getNumResults() == 1 && view.getRank() == 0) {
   360        auto cst = m.getResult(0).dyn_cast<AffineConstantExpr>();
   361        if (!cst || cst.getValue() != 0)
   362          return op.emitError("op expected indexing_map #")
   363                 << idx << " to be 0 to match 0-D view: " << view;
   364      }
   365  
   366      if (m.getNumResults() != view.getRank())
   367        return op.emitError("op expected indexing_map #")
   368               << idx << " results to match view rank: " << view;
   369  
   370      if (funType) {
   371        if (funType.getInput(idx) != view.getElementType())
   372          return op.emitError("op expected fun argument ")
   373                 << idx
   374                 << " to match view element type: " << view.getElementType();
   375  
   376        if (idx >= nInputViews)
   377          if (funType.getResult(idx - nInputViews) != view.getElementType())
   378            return op.emitError("op expected fun result ")
   379                   << idx << " to match output view element type: "
   380                   << view.getElementType();
   381      }
   382    }
   383  
   384    auto concatMap = concatAffineMaps(indexingMaps);
   385    auto aggregateMap = inversePermutation(concatMap);
   386    if (!aggregateMap)
   387      return op.emitError("op expected the concatenation of maps in indexing_map "
   388                          "to be invertible");
   389  
   390    return success();
   391  }
   392  
   393  //===----------------------------------------------------------------------===//
   394  // LoadOp
   395  //===----------------------------------------------------------------------===//
   396  
   397  static void print(OpAsmPrinter *p, linalg::LoadOp op) {
   398    *p << op.getOperationName() << " " << *op.view() << '[';
   399    p->printOperands(op.indices());
   400    *p << ']';
   401    p->printOptionalAttrDict(op.getAttrs());
   402    *p << " : " << op.getViewType();
   403  }
   404  
   405  static ParseResult parseLoadOp(OpAsmParser *parser, OperationState *result) {
   406    OpAsmParser::OperandType viewInfo;
   407    SmallVector<OpAsmParser::OperandType, 4> indexInfo;
   408    ViewType type;
   409  
   410    auto affineIntTy = parser->getBuilder().getIndexType();
   411    return failure(
   412        parser->parseOperand(viewInfo) ||
   413        parser->parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
   414        parser->parseOptionalAttributeDict(result->attributes) ||
   415        parser->parseColonType(type) ||
   416        parser->resolveOperand(viewInfo, type, result->operands) ||
   417        parser->resolveOperands(indexInfo, affineIntTy, result->operands) ||
   418        parser->addTypeToList(type.getElementType(), result->types));
   419  }
   420  
   421  static LogicalResult verify(linalg::LoadOp op) {
   422    if (op.getRank() != llvm::size(op.indices()))
   423      return op.emitOpError("expected ")
   424             << op.getRank() << " indices, got " << llvm::size(op.indices());
   425    return success();
   426  }
   427  
   428  //===----------------------------------------------------------------------===//
   429  // RangeOp
   430  //===----------------------------------------------------------------------===//
   431  
   432  static void print(OpAsmPrinter *p, RangeOp op) {
   433    *p << op.getOperationName() << " " << *op.min() << ":" << *op.max() << ":"
   434       << *op.step();
   435    p->printOptionalAttrDict(op.getAttrs());
   436    *p << " : " << op.getResult()->getType();
   437  }
   438  
   439  static ParseResult parseRangeOp(OpAsmParser *parser, OperationState *result) {
   440    SmallVector<OpAsmParser::OperandType, 3> rangeInfo(3);
   441    RangeType type;
   442    auto affineIntTy = parser->getBuilder().getIndexType();
   443    return failure(
   444        parser->parseOperand(rangeInfo[0]) || parser->parseColon() ||
   445        parser->parseOperand(rangeInfo[1]) || parser->parseColon() ||
   446        parser->parseOperand(rangeInfo[2]) ||
   447        parser->parseOptionalAttributeDict(result->attributes) ||
   448        parser->parseColonType(type) ||
   449        parser->resolveOperands(rangeInfo, affineIntTy, result->operands) ||
   450        parser->addTypeToList(type, result->types));
   451  }
   452  
   453  //===----------------------------------------------------------------------===//
   454  // SliceOp
   455  //===----------------------------------------------------------------------===//
   456  
   457  void mlir::linalg::SliceOp::build(Builder *b, OperationState *result,
   458                                    Value *base, ArrayRef<Value *> indexings) {
   459    result->addOperands(base);
   460    result->addOperands(indexings);
   461  
   462    ViewType viewType = base->getType().cast<ViewType>();
   463    unsigned rank = viewType.getRank();
   464    for (auto *i : indexings)
   465      if (!i->getType().isa<RangeType>())
   466        rank--;
   467    Type elementType = viewType.getElementType();
   468    result->addTypes({ViewType::get(b->getContext(), elementType, rank)});
   469  }
   470  
   471  static void print(OpAsmPrinter *p, SliceOp op) {
   472    *p << SliceOp::getOperationName() << " " << *op.view() << "[";
   473    p->printOperands(op.indexings());
   474    *p << "] ";
   475    p->printOptionalAttrDict(op.getAttrs());
   476    *p << " : " << op.getBaseViewType();
   477    for (auto indexing : op.indexings()) {
   478      *p << ", " << indexing->getType();
   479    }
   480    *p << ", " << op.getType();
   481  }
   482  
   483  static ParseResult parseSliceOp(OpAsmParser *parser, OperationState *result) {
   484    OpAsmParser::OperandType baseInfo;
   485    SmallVector<OpAsmParser::OperandType, 8> operands;
   486    SmallVector<Type, 8> types;
   487    if (parser->parseOperand(baseInfo) ||
   488        parser->parseOperandList(operands, OpAsmParser::Delimiter::Square) ||
   489        parser->parseOptionalAttributeDict(result->attributes) ||
   490        parser->parseColonTypeList(types))
   491      return failure();
   492  
   493    if (types.size() < 2)
   494      return parser->emitError(parser->getCurrentLocation(),
   495                               "expected at least input and result view types");
   496  
   497    ArrayRef<Type> indexingTypes = ArrayRef<Type>(types).drop_front().drop_back();
   498    return failure(
   499        parser->resolveOperand(baseInfo, types.front(), result->operands) ||
   500        (!operands.empty() &&
   501         parser->resolveOperands(operands, indexingTypes,
   502                                 operands.front().location, result->operands)) ||
   503        parser->addTypeToList(types.back(), result->types));
   504  }
   505  
   506  static LogicalResult verify(SliceOp op) {
   507    unsigned rank = op.getBaseViewRank();
   508    if (rank != llvm::size(op.indexings()))
   509      return op.emitOpError("expected ")
   510             << op.getRank() << " indexings, got " << llvm::size(op.indexings());
   511    unsigned index = 0;
   512    for (auto indexing : op.indexings()) {
   513      if (indexing->getType().isa<IndexType>())
   514        --rank;
   515      ++index;
   516    }
   517    if (op.getRank() != rank)
   518      return op.emitOpError() << "expected rank of the view(" << op.getRank()
   519                              << ") to be the number of ranges(" << rank << ")";
   520    return success();
   521  }
   522  
   523  //===----------------------------------------------------------------------===//
   524  // StoreOp
   525  //===----------------------------------------------------------------------===//
   526  
   527  static void print(OpAsmPrinter *p, linalg::StoreOp op) {
   528    *p << op.getOperationName() << " " << *op.value();
   529    *p << ", " << *op.view() << '[';
   530    p->printOperands(op.indices());
   531    *p << ']';
   532    p->printOptionalAttrDict(op.getAttrs());
   533    *p << " : " << op.getViewType();
   534  }
   535  
   536  static ParseResult parseStoreOp(OpAsmParser *parser, OperationState *result) {
   537    OpAsmParser::OperandType storeValueInfo;
   538    OpAsmParser::OperandType viewInfo;
   539    SmallVector<OpAsmParser::OperandType, 4> indexInfo;
   540    ViewType viewType;
   541  
   542    auto affineIntTy = parser->getBuilder().getIndexType();
   543    return failure(
   544        parser->parseOperand(storeValueInfo) || parser->parseComma() ||
   545        parser->parseOperand(viewInfo) ||
   546        parser->parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
   547        parser->parseOptionalAttributeDict(result->attributes) ||
   548        parser->parseColonType(viewType) ||
   549        parser->resolveOperand(storeValueInfo, viewType.getElementType(),
   550                               result->operands) ||
   551        parser->resolveOperand(viewInfo, viewType, result->operands) ||
   552        parser->resolveOperands(indexInfo, affineIntTy, result->operands));
   553  }
   554  
   555  static LogicalResult verify(linalg::StoreOp op) {
   556    if (op.value()->getType() != op.getViewType().getElementType())
   557      return op.emitOpError("expected value type to match view element type");
   558    if (op.getRank() != llvm::size(op.indices()))
   559      return op.emitOpError("expected ")
   560             << op.getRank() << " indices, got " << llvm::size(op.indices());
   561    return success();
   562  }
   563  
   564  //===----------------------------------------------------------------------===//
   565  // SubViewOp
   566  //===----------------------------------------------------------------------===//
   567  
   568  static void print(OpAsmPrinter *p, SubViewOp op) {
   569    *p << op.getOperationName() << " " << *op.getOperand(0) << "[";
   570    auto ranges = op.getRanges();
   571    interleaveComma(ranges, *p, [&p](const SubViewOp::Range &i) {
   572      *p << *i.min << ", " << *i.max << ", " << *i.step;
   573    });
   574    *p << "]";
   575    p->printOptionalAttrDict(op.getAttrs());
   576    *p << " : " << op.getViewType();
   577  }
   578  
   579  static ParseResult parseSubViewOp(OpAsmParser *parser, OperationState *result) {
   580    OpAsmParser::OperandType inputView, resultView;
   581    Type viewType;
   582    if (parser->parseOperand(inputView))
   583      return failure();
   584  
   585    SmallVector<OpAsmParser::OperandType, 12> ops;
   586    // TODO(ntv) evolve parsing from
   587    //    linalg.subview %0[%1, %2, %3, %4, %5, %6]
   588    // to something resembling
   589    //    linalg.subview %0[%1:%2:%3][%4:%5:%6]
   590    if (parser->parseOperandList(ops, OpAsmParser::Delimiter::Square) ||
   591        parser->parseOptionalAttributeDict(result->attributes) ||
   592        parser->parseColonType(viewType))
   593      return failure();
   594  
   595    auto indexTy = parser->getBuilder().getIndexType();
   596    return failure(
   597        parser->resolveOperand(inputView, viewType, result->operands) ||
   598        parser->resolveOperands(ops, indexTy, result->operands) ||
   599        parser->addTypeToList(viewType, result->types));
   600  }
   601  
   602  //===----------------------------------------------------------------------===//
   603  // TransposeOp
   604  //===----------------------------------------------------------------------===//
   605  void mlir::linalg::TransposeOp::build(Builder *b, OperationState *result,
   606                                        Value *view, AffineMapAttr permutation,
   607                                        ArrayRef<NamedAttribute> attrs) {
   608    // TODO(ntv): once views have static dimensions, compute the permuted type.
   609    build(b, result, view->getType(), view, attrs);
   610    result->addAttribute(TransposeOp::getPermutationAttrName(), permutation);
   611  }
   612  
   613  static void print(OpAsmPrinter *p, TransposeOp op) {
   614    *p << op.getOperationName() << " " << *op.view() << " " << op.permutation();
   615    p->printOptionalAttrDict(op.getAttrs(),
   616                             {TransposeOp::getPermutationAttrName()});
   617    *p << " : " << op.view()->getType();
   618  }
   619  
   620  static ParseResult parseTransposeOp(OpAsmParser *parser,
   621                                      OperationState *result) {
   622    OpAsmParser::OperandType view;
   623    AffineMapAttr permutation;
   624    Type type;
   625    return failure(parser->parseOperand(view) ||
   626                   parser->parseAttribute(permutation,
   627                                          TransposeOp::getPermutationAttrName(),
   628                                          result->attributes) ||
   629                   parser->parseOptionalAttributeDict(result->attributes) ||
   630                   parser->parseColonType(type) ||
   631                   parser->resolveOperand(view, type, result->operands) ||
   632                   parser->addTypeToList(type, result->types));
   633  }
   634  
   635  //===----------------------------------------------------------------------===//
   636  // ViewOp
   637  //===----------------------------------------------------------------------===//
   638  void mlir::linalg::ViewOp::build(Builder *b, OperationState *result,
   639                                   Value *buffer, ArrayRef<Value *> ranges,
   640                                   Type resultType,
   641                                   ArrayRef<NamedAttribute> attrs) {
   642    if (!resultType) {
   643      Type elementType = buffer->getType().cast<BufferType>().getElementType();
   644      resultType = ViewType::get(b->getContext(), elementType, ranges.size());
   645    }
   646    build(b, result, resultType, buffer, ranges);
   647    result->addAttributes(attrs);
   648  }
   649  
   650  static void print(OpAsmPrinter *p, ViewOp op) {
   651    *p << op.getOperationName() << " " << *op.buffer() << "[";
   652    interleaveComma(op.ranges(), *p, [&](Value *v) { *p << *v; });
   653    *p << "] ";
   654    p->printOptionalAttrDict(op.getAttrs());
   655    *p << " : " << op.buffer()->getType() << " -> " << op.getType();
   656  }
   657  
   658  static ParseResult parseViewOp(OpAsmParser *parser, OperationState *result) {
   659    OpAsmParser::OperandType bufferInfo;
   660    SmallVector<OpAsmParser::OperandType, 8> rangesInfo;
   661    Type bType, vType;
   662    if (parser->parseOperand(bufferInfo) ||
   663        parser->parseOperandList(rangesInfo, OpAsmParser::Delimiter::Square) ||
   664        parser->parseOptionalAttributeDict(result->attributes) ||
   665        parser->parseColon() || parser->parseType(bType) ||
   666        parser->parseArrow() || parser->parseType(vType)) {
   667      return failure();
   668    }
   669  
   670    ViewType viewType = vType.dyn_cast<ViewType>();
   671    if (!viewType)
   672      return parser->emitError(parser->getNameLoc(), "expected view type");
   673    if (viewType.getRank() != rangesInfo.size())
   674      return parser->emitError(parser->getNameLoc(), "expected ")
   675             << viewType.getRank() << " ranges";
   676    return failure(
   677        parser->resolveOperand(bufferInfo, bType, result->operands) ||
   678        (!rangesInfo.empty() &&
   679         parser->resolveOperands(rangesInfo, RangeType::get(vType.getContext()),
   680                                 result->operands)) ||
   681        parser->addTypeToList(viewType, result->types));
   682  }
   683  
   684  //===----------------------------------------------------------------------===//
   685  // YieldOp
   686  //===----------------------------------------------------------------------===//
   687  
   688  static void print(OpAsmPrinter *p, YieldOp op) {
   689    *p << op.getOperationName();
   690    if (op.getNumOperands() > 0) {
   691      *p << ' ';
   692      p->printOperands(op.operand_begin(), op.operand_end());
   693    }
   694    p->printOptionalAttrDict(op.getAttrs());
   695    if (op.getNumOperands() > 0) {
   696      *p << " : ";
   697      interleaveComma(op.getOperands(), *p,
   698                      [&](Value *e) { p->printType(e->getType()); });
   699    }
   700  }
   701  
   702  static ParseResult parseYieldOp(OpAsmParser *parser, OperationState *result) {
   703    SmallVector<OpAsmParser::OperandType, 2> opInfo;
   704    SmallVector<Type, 2> types;
   705    llvm::SMLoc loc = parser->getCurrentLocation();
   706    return failure(parser->parseOperandList(opInfo) ||
   707                   parser->parseOptionalAttributeDict(result->attributes) ||
   708                   (!opInfo.empty() && parser->parseColonTypeList(types)) ||
   709                   parser->resolveOperands(opInfo, types, loc, result->operands));
   710  }
   711  
   712  static LogicalResult verify(YieldOp op) {
   713    auto *parentOp = op.getParentOp();
   714    if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
   715      return op.emitOpError("op expected single non-empty parent region");
   716  
   717    auto genericOp = dyn_cast<GenericOp>(parentOp);
   718    if (!genericOp)
   719      return op.emitOpError("op expected '")
   720             << GenericOp::getOperationName() << "' parent op";
   721  
   722    // The operand number and types must match the view element types.
   723    auto nOutputViews = genericOp.getNumOutputs();
   724    if (op.getNumOperands() != nOutputViews)
   725      return op.emitOpError("op expected ")
   726             << nOutputViews << " operand to match enclosing linalg.generic op";
   727  
   728    for (unsigned i = 0; i != nOutputViews; ++i) {
   729      auto elementType = genericOp.getOutputViewType(i).getElementType();
   730      if (op.getOperand(i)->getType() != elementType)
   731        return op.emitError("type of return operand ")
   732               << i << " (" << op.getOperand(i)->getType()
   733               << ") doesn't match view element type (" << elementType << ")";
   734    }
   735    return success();
   736  }
   737  
   738  /////// Operations corresponding to library calls defined with Tablegen ////////
// For operations that correspond to library calls (i.e. defined in
   740  // LinalgLibraryOps.td), we define an overloaded `print` function and a
   741  // parse`className` function.
   742  
   743  // A LinalgLibraryOp prints as:
   744  //
   745  // ```{.mlir}
   746  //   concrete_op_name (ssa-inputs, ssa-outputs) : view-types
   747  // ```
   748  //
   749  // for example:
   750  //
   751  // ```
   752  //   linalg.matmul(%0, %1, %2) :
   753  //     !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
   754  // ```
   755  //
   756  // Where %0, %1 and %2 are ssa-values of type ViewType.
   757  static void printLinalgLibraryOp(OpAsmPrinter *p, Operation *op) {
   758    assert(op->getAbstractOperation() && "unregistered operation");
   759    *p << op->getName().getStringRef() << "(";
   760    interleave(
   761        op->getOperands().begin(), op->getOperands().end(),
   762        [&](Value *v) { *p << *v; }, [&]() { *p << ", "; });
   763    *p << ")";
   764    p->printOptionalAttrDict(op->getAttrs());
   765    *p << " : ";
   766    interleave(
   767        op->getOperands().begin(), op->getOperands().end(),
   768        [&](Value *v) { *p << v->getType(); }, [&]() { *p << ", "; });
   769  }
   770  
   771  static ParseResult parseLinalgLibraryOp(OpAsmParser *parser,
   772                                          OperationState *result) {
   773    SmallVector<OpAsmParser::OperandType, 3> ops;
   774    SmallVector<Type, 3> types;
   775    return failure(parser->parseOperandList(ops, OpAsmParser::Delimiter::Paren) ||
   776                   parser->parseOptionalAttributeDict(result->attributes) ||
   777                   parser->parseColonTypeList(types) ||
   778                   parser->resolveOperands(ops, types, parser->getNameLoc(),
   779                                           result->operands));
   780  }
   781  
   782  static LogicalResult verify(FillOp op) {
   783    auto viewType = op.getOutputViewType(0);
   784    auto fillType = op.getValue()->getType();
   785    if (viewType.getElementType() != fillType)
   786      return op.emitOpError("expects fill type to match view elemental type");
   787    return success();
   788  }
   789  
   790  static LogicalResult verify(CopyOp op) {
   791    auto outputViewType = op.getOutputViewType(0);
   792    auto inputViewType = op.getInputViewType(0);
   793    if (inputViewType.getElementType() != outputViewType.getElementType())
   794      return op.emitOpError("expects views of the same type");
   795    if (inputViewType.getRank() != outputViewType.getRank())
   796      return op.emitOpError("expects views of the same rank");
   797    auto rank = op.getNumParallelLoops();
   798    auto inputPermutationMap = op.inputPermutation();
   799    if (inputPermutationMap) {
   800      if (inputPermutationMap->getNumInputs() != rank)
   801        return op.emitOpError("expects optional input_permutation map of rank ")
   802               << rank;
   803      if (!inputPermutationMap->isPermutation())
   804        return op.emitOpError(
   805            "expects optional input_permutation map to be a permutation");
   806    }
   807    auto outputPermutationMap = op.outputPermutation();
   808    if (outputPermutationMap) {
   809      if (outputPermutationMap->getNumInputs() != rank)
   810        return op.emitOpError("expects optional output_permutation map of rank ")
   811               << rank;
   812      if (!outputPermutationMap->isPermutation())
   813        return op.emitOpError(
   814            "expects optional output_permutation map to be a permutation");
   815    }
   816    if (rank == 0 && inputPermutationMap)
   817      return op.emitOpError("expected no input permutation when rank == 0");
   818    if (rank == 0 && outputPermutationMap)
   819      return op.emitOpError("expected no output permutation when rank == 0");
   820    return success();
   821  }
   822  
   823  static LogicalResult
   824  verifyStrideOrDilation(ConvOp op, ArrayRef<Attribute> attrs, bool isStride) {
   825    auto strideOrDilation = isStride ? "stride" : "dilation";
   826    if (attrs.size() != op.getNumWindowLoops())
   827      return op.emitOpError("expects num ")
   828             << strideOrDilation
   829             << "s equal to number of window dimensions: " << attrs.size()
   830             << " vs " << op.getNumWindowLoops();
   831    return success();
   832  }
   833  
   834  static LogicalResult verify(ConvOp op) {
   835    auto oType = op.output()->getType().cast<ViewType>();
   836    auto fType = op.filter()->getType().cast<ViewType>();
   837    auto iType = op.input()->getType().cast<ViewType>();
   838    if (oType.getElementType() != iType.getElementType() ||
   839        oType.getElementType() != fType.getElementType())
   840      return op.emitOpError("expects view elemental types to match");
   841    if (oType.getRank() != iType.getRank() || oType.getRank() != fType.getRank())
   842      return op.emitOpError("expects view ranks to match");
   843    if (auto strides = op.strides()) {
   844      if (failed(
   845              verifyStrideOrDilation(op, strides->getValue(), /*isStride=*/true)))
   846        return failure();
   847    }
   848    if (auto dilations = op.dilations()) {
   849      if (failed(verifyStrideOrDilation(op, dilations->getValue(),
   850                                        /*isStride=*/false)))
   851        return failure();
   852    }
   853    return success();
   854  }
   855  
   856  llvm::raw_ostream &mlir::linalg::operator<<(llvm::raw_ostream &os,
   857                                              SubViewOp::Range &range) {
   858    return os << "range " << *range.min << ":" << *range.max << ":"
   859              << *range.step;
   860  }
   861  
namespace mlir {
namespace linalg {

// TableGen-generated implementation of the Linalg library op interfaces.
#include "mlir/Dialect/Linalg/IR/LinalgLibraryOpInterfaces.cpp.inc"

// GET_OP_CLASSES selects the op class definitions from each generated .inc
// file. It is (re)defined before each include; presumably the .inc #undefs it.
// TODO(review): confirm.
#define GET_OP_CLASSES
#include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/Linalg/IR/LinalgLibraryOps.cpp.inc"

} // namespace linalg
} // namespace mlir
   875  
   876  static AffineMap extractOrIdentityMap(llvm::Optional<AffineMap> maybeMap,
   877                                        unsigned rank, MLIRContext *context) {
   878    if (maybeMap)
   879      return maybeMap.getValue();
   880    if (rank == 0)
   881      return AffineMap();
   882    return AffineMap::getMultiDimIdentityMap(rank, context);
   883  }
   884  
   885  // Returns `num` AffineDimExpr dimensions at positions [curIdx, curIdx + num)
   886  // and increments `curIdx` to `curIdx + num`.
   887  static SmallVector<AffineExpr, 4>
   888  makeAffineDimExprs(unsigned num, unsigned &curIdx, MLIRContext *context) {
   889    SmallVector<AffineExpr, 4> res;
   890    res.reserve(num);
   891    for (unsigned i = 0; i < num; ++i)
   892      res.push_back(getAffineDimExpr(curIdx++, context));
   893    return res;
   894  }
   895  
   896  static SmallVector<AffineExpr, 4>
   897  weightedConvInputIndex(ConvOp op, ArrayRef<AffineExpr> a,
   898                         ArrayRef<AffineExpr> b) {
   899    assert(a.size() == b.size());
   900    SmallVector<AffineExpr, 4> res;
   901    res.reserve(a.size());
   902    for (unsigned i = 0, e = a.size(); i < e; ++i) {
   903      res.push_back(op.getStride(i) * a[i] + op.getDilation(i) * b[i]);
   904    }
   905    return res;
   906  }
   907  
   908  static SmallVector<AffineExpr, 4> concat(ArrayRef<AffineExpr> a,
   909                                           ArrayRef<AffineExpr> b) {
   910    SmallVector<AffineExpr, 4> res;
   911    res.reserve(a.size() + b.size());
   912    res.assign(a.begin(), a.end());
   913    res.append(b.begin(), b.end());
   914    return res;
   915  }
   916  
   917  // Note: both functions below would completely disappear with a simple tensor
   918  // kernel language.
   919  //
   920  // Ideally this should all be Tablegen'd but there is no good story for
   921  // AffineMap for now.
   922  SmallVector<AffineMap, 4> mlir::linalg::loopToOperandRangesMaps(Operation *op) {
   923    MLIRContext *context = op->getContext();
   924    if (auto copyOp = dyn_cast<CopyOp>(op)) {
   925      // I(input_perm(ivs)) -> O(output_perm(ivs))
   926      auto maybeInputMap = copyOp.inputPermutation();
   927      auto maybeOutputMap = copyOp.outputPermutation();
   928      unsigned inputRank = copyOp.getInputViewType(0).getRank();
   929      unsigned outputRank = copyOp.getOutputViewType(0).getRank();
   930      return SmallVector<AffineMap, 4>{
   931          extractOrIdentityMap(maybeInputMap, inputRank, context),
   932          extractOrIdentityMap(maybeOutputMap, outputRank, context)};
   933    }
   934    if (auto fillOp = dyn_cast<FillOp>(op)) {
   935      // filling_value -> O(ivs)
   936      unsigned rank = fillOp.getNumParallelLoops();
   937      return SmallVector<AffineMap, 4>{
   938          extractOrIdentityMap(llvm::None, rank, context)};
   939    }
   940    auto i = getAffineDimExpr(0, context);
   941    auto j = getAffineDimExpr(1, context);
   942    auto k = getAffineDimExpr(2, context);
   943    if (isa<DotOp>(op))
   944      // A(r_i) * B(r_i) -> C()
   945      return SmallVector<AffineMap, 4>{AffineMap::get(1, 0, {i}),
   946                                       AffineMap::get(1, 0, {i}), AffineMap()};
   947    if (isa<MatvecOp>(op))
   948      //   A(i, r_j) * B(r_j) -> C(i)
   949      return SmallVector<AffineMap, 4>{AffineMap::get(2, 0, {i, j}),
   950                                       AffineMap::get(2, 0, {j}),
   951                                       AffineMap::get(2, 0, {i})};
   952    if (isa<MatmulOp>(op))
   953      //   A(i, r_k) * B(r_k, j) -> C(i, j)
   954      return SmallVector<AffineMap, 4>{AffineMap::get(3, 0, {i, k}),
   955                                       AffineMap::get(3, 0, {k, j}),
   956                                       AffineMap::get(3, 0, {i, j})};
   957    if (auto convOp = dyn_cast<ConvOp>(op)) {
   958      //   F(z0, ..., zN-1, q, k) * I(b, x0 + z0, ..., xN-1 + zN-1, q) ->
   959      //     O(b, x0, ..., xN-1, k)
   960      // for N equal to `nWindow`.
   961      auto nWin = convOp.getNumWindowLoops();
   962      assert(nWin > 0 && "expected at least one window dimension");
   963      unsigned idx = 0;
   964      // In the following, AffineDimExprs are indexed in loop order:
   965      //   [ b, xs, k,           q,                     zs]
   966      //    parallels     non-window reductions     windows
   967      //
   968      // Parallel dims are exactly the dimensions indexing `output`:
   969      //     output[b, x[0], ..., x[N-1], k]; i.e.
   970      //  * batch dimensions (bs with #bs = 1 for now)
   971      //  * "image" dimensions (xs with #xs = #zs = output_rank - #bs - #ks)
   972      //  * output filter dimensions (ks with #ks = 1 for now)
   973      auto bs = makeAffineDimExprs(convOp.getNumBatchDimensions(), idx, context);
   974      auto xs = makeAffineDimExprs(nWin, idx, context);
   975      auto ks = makeAffineDimExprs(convOp.getNumOutputFeatureDimensions(), idx,
   976                                   context);
   977      // Non-window reduction dim: sum_{z[0], ..., z[N-1], q}
   978      auto qs =
   979          makeAffineDimExprs(convOp.getNumInputFeatureDimensions(), idx, context);
   980      // Window reduction dims: sum_{z[0], ..., z[N-1], q}
   981      auto zs = makeAffineDimExprs(nWin, idx, context);
   982      // Construct the weighedSum expression.
   983      auto ws = weightedConvInputIndex(convOp, xs, zs);
   984      return SmallVector<AffineMap, 4>{
   985          // filter[z[0], ..., z[N-1], q, k]
   986          AffineMap::get(idx, 0, concat(concat(zs, qs), ks)),
   987          // input[b,
   988          //       x[0]*s[0] + d[0]*z[0], ..., x[N-1]*s[N-1] + d[N-1]*z[N-1],
   989          //       q]
   990          AffineMap::get(idx, 0, concat(concat(bs, ws), qs)),
   991          // output[b, x[0], ..., x[N-1], k]
   992          AffineMap::get(idx, 0, concat(concat(bs, xs), ks))};
   993    } else if (auto genericOp = dyn_cast<GenericOp>(op)) {
   994      SmallVector<AffineMap, 4> res;
   995      unsigned nViews = genericOp.getNumInputsAndOutputs();
   996      res.reserve(nViews);
   997      for (unsigned i = 0, e = nViews; i < e; ++i) {
   998        res.push_back(genericOp.getIndexingMap(i));
   999      }
  1000      return res;
  1001    }
  1002    llvm_unreachable("Missing loopToOperandRangesMaps for op");
  1003  }
  1004  
  1005  static void appendMangledType(llvm::raw_string_ostream &ss, Type t) {
  1006    if (auto view = t.dyn_cast<ViewType>()) {
  1007      ss << "view";
  1008      for (unsigned i = 0, e = view.getRank(); i < e; ++i)
  1009        ss << "x";
  1010      appendMangledType(ss, view.getElementType());
  1011    } else if (auto vec = t.dyn_cast<VectorType>()) {
  1012      ss << "vector";
  1013      interleave(
  1014          vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; });
  1015      appendMangledType(ss, vec.getElementType());
  1016    } else if (t.isIntOrIndexOrFloat()) {
  1017      ss << t;
  1018    } else {
  1019      llvm_unreachable("Invalid type for linalg library name mangling");
  1020    }
  1021  }
  1022  
  1023  std::string mlir::linalg::generateLibraryCallName(Operation *op) {
  1024    assert(isa<LinalgOp>(op));
  1025    std::string name(op->getName().getStringRef().str());
  1026    name.reserve(128);
  1027    std::replace(name.begin(), name.end(), '.', '_');
  1028    llvm::raw_string_ostream ss(name);
  1029    ss << "_";
  1030    auto types = op->getOperandTypes();
  1031    interleave(
  1032        types.begin(), types.end(), [&](Type t) { appendMangledType(ss, t); },
  1033        [&]() { ss << "_"; });
  1034    return ss.str();
  1035  }