github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Dialect/AffineOps/AffineOps.cpp

     1  //===- AffineOps.cpp - MLIR Affine Operations -----------------------------===//
     2  //
     3  // Copyright 2019 The MLIR Authors.
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //   http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  // =============================================================================
    17  
    18  #include "mlir/Dialect/AffineOps/AffineOps.h"
    19  #include "mlir/Dialect/StandardOps/Ops.h"
    20  #include "mlir/IR/Block.h"
    21  #include "mlir/IR/Builders.h"
    22  #include "mlir/IR/Function.h"
    23  #include "mlir/IR/IntegerSet.h"
    24  #include "mlir/IR/Matchers.h"
    25  #include "mlir/IR/OpImplementation.h"
    26  #include "mlir/IR/PatternMatch.h"
    27  #include "llvm/ADT/SetVector.h"
    28  #include "llvm/ADT/SmallBitVector.h"
    29  #include "llvm/Support/Debug.h"
    30  using namespace mlir;
    31  using llvm::dbgs;
    32  
    33  #define DEBUG_TYPE "affine-analysis"
    34  
    35  //===----------------------------------------------------------------------===//
    36  // AffineOpsDialect
    37  //===----------------------------------------------------------------------===//
    38  
    39  AffineOpsDialect::AffineOpsDialect(MLIRContext *context)
    40      : Dialect(getDialectNamespace(), context) {
    41    addOperations<AffineApplyOp, AffineDmaStartOp, AffineDmaWaitOp, AffineLoadOp,
    42                  AffineStoreOp,
    43  #define GET_OP_LIST
    44  #include "mlir/Dialect/AffineOps/AffineOps.cpp.inc"
    45                  >();
    46  }
    47  
    48  /// A utility function to check if a given region is attached to a function.
    49  static bool isFunctionRegion(Region *region) {
    50    return llvm::isa<FuncOp>(region->getParentOp());
    51  }
    52  
    53  /// A utility function to check if a value is defined at the top level of a
    54  /// function. A value defined at the top level is always a valid symbol.
    55  bool mlir::isTopLevelSymbol(Value *value) {
    56    if (auto *arg = dyn_cast<BlockArgument>(value))
    57      return isFunctionRegion(arg->getOwner()->getParent());
    58    return isFunctionRegion(value->getDefiningOp()->getParentRegion());
    59  }
    60  
     61  // A value can be used as a dimension id if it is valid as a symbol, if it
     62  // is an induction variable, or if it is the result of an affine.apply
     63  // operation with dimension id arguments.
    64  bool mlir::isValidDim(Value *value) {
    65    // The value must be an index type.
    66    if (!value->getType().isIndex())
    67      return false;
    68  
    69    if (auto *op = value->getDefiningOp()) {
    70      // Top level operation or constant operation is ok.
    71      if (isFunctionRegion(op->getParentRegion()) || isa<ConstantOp>(op))
    72        return true;
    73      // Affine apply operation is ok if all of its operands are ok.
    74      if (auto applyOp = dyn_cast<AffineApplyOp>(op))
    75        return applyOp.isValidDim();
    76      // The dim op is okay if its operand memref/tensor is defined at the top
    77      // level.
    78      if (auto dimOp = dyn_cast<DimOp>(op))
    79        return isTopLevelSymbol(dimOp.getOperand());
    80      return false;
    81    }
    82    // This value is a block argument (which also includes 'affine.for' loop IVs).
    83    return true;
    84  }
    85  
     86  // A value can be used as a symbol if it is a constant, if it is defined at
     87  // the top level of a function, or if it is the result of an affine.apply
     88  // operation with symbol arguments.
    89  bool mlir::isValidSymbol(Value *value) {
    90    // The value must be an index type.
    91    if (!value->getType().isIndex())
    92      return false;
    93  
    94    if (auto *op = value->getDefiningOp()) {
    95      // Top level operation or constant operation is ok.
    96      if (isFunctionRegion(op->getParentRegion()) || isa<ConstantOp>(op))
    97        return true;
    98      // Affine apply operation is ok if all of its operands are ok.
    99      if (auto applyOp = dyn_cast<AffineApplyOp>(op))
   100        return applyOp.isValidSymbol();
   101      // The dim op is okay if its operand memref/tensor is defined at the top
   102      // level.
   103      if (auto dimOp = dyn_cast<DimOp>(op))
   104        return isTopLevelSymbol(dimOp.getOperand());
   105      return false;
   106    }
   107    // Otherwise, check that the value is a top level symbol.
   108    return isTopLevelSymbol(value);
   109  }
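
         // Illustration (a sketch; %A, %n, %c0, %d and %i are hypothetical values):
         // in IR such as
         //
         //   func @f(%A : memref<?xf32>, %n : index) {
         //     %c0 = constant 0 : index
         //     %d = dim %A, 0 : memref<?xf32>
         //     affine.for %i = 0 to %n {
         //       ...
         //     }
         //     return
         //   }
         //
         // %n, %c0 and %d are valid symbols (and therefore also valid dims), while
         // the induction variable %i is a valid dim but not a valid symbol.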
   110  
   111  // Returns true if 'value' is a valid index to an affine operation (e.g.
   112  // affine.load, affine.store, affine.dma_start, affine.dma_wait).
   113  // Returns false otherwise.
   114  static bool isValidAffineIndexOperand(Value *value) {
   115    return isValidDim(value) || isValidSymbol(value);
   116  }
   117  
   118  /// Utility function to verify that a set of operands are valid dimension and
    119  /// symbol identifiers. The operands should be laid out such that the dimension
   120  /// operands are before the symbol operands. This function returns failure if
   121  /// there was an invalid operand. An operation is provided to emit any necessary
   122  /// errors.
   123  template <typename OpTy>
   124  static LogicalResult
   125  verifyDimAndSymbolIdentifiers(OpTy &op, Operation::operand_range operands,
   126                                unsigned numDims) {
   127    unsigned opIt = 0;
   128    for (auto *operand : operands) {
   129      if (opIt++ < numDims) {
   130        if (!isValidDim(operand))
   131          return op.emitOpError("operand cannot be used as a dimension id");
   132      } else if (!isValidSymbol(operand)) {
   133        return op.emitOpError("operand cannot be used as a symbol");
   134      }
   135    }
   136    return success();
   137  }
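
         // For example (a sketch with hypothetical values): in
         //   affine.apply (d0, d1)[s0] -> (d0 + d1 + s0)(%i, %j)[%n]
         // the operand list is (%i, %j, %n) with numDims == 2, so %i and %j must be
         // valid dims and %n must be a valid symbol.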
   138  
   139  //===----------------------------------------------------------------------===//
   140  // AffineApplyOp
   141  //===----------------------------------------------------------------------===//
   142  
   143  void AffineApplyOp::build(Builder *builder, OperationState *result,
   144                            AffineMap map, ArrayRef<Value *> operands) {
   145    result->addOperands(operands);
   146    result->types.append(map.getNumResults(), builder->getIndexType());
   147    result->addAttribute("map", builder->getAffineMapAttr(map));
   148  }
   149  
   150  ParseResult AffineApplyOp::parse(OpAsmParser *parser, OperationState *result) {
   151    auto &builder = parser->getBuilder();
   152    auto affineIntTy = builder.getIndexType();
   153  
   154    AffineMapAttr mapAttr;
   155    unsigned numDims;
   156    if (parser->parseAttribute(mapAttr, "map", result->attributes) ||
   157        parseDimAndSymbolList(parser, result->operands, numDims) ||
   158        parser->parseOptionalAttributeDict(result->attributes))
   159      return failure();
   160    auto map = mapAttr.getValue();
   161  
   162    if (map.getNumDims() != numDims ||
   163        numDims + map.getNumSymbols() != result->operands.size()) {
   164      return parser->emitError(parser->getNameLoc(),
   165                               "dimension or symbol index mismatch");
   166    }
   167  
   168    result->types.append(map.getNumResults(), affineIntTy);
   169    return success();
   170  }
   171  
   172  void AffineApplyOp::print(OpAsmPrinter *p) {
   173    *p << "affine.apply " << getAttr("map");
   174    printDimAndSymbolList(operand_begin(), operand_end(),
   175                          getAffineMap().getNumDims(), p);
   176    p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/{"map"});
   177  }
   178  
   179  LogicalResult AffineApplyOp::verify() {
   180    // Check that affine map attribute was specified.
   181    auto affineMapAttr = getAttrOfType<AffineMapAttr>("map");
   182    if (!affineMapAttr)
   183      return emitOpError("requires an affine map");
   184  
   185    // Check input and output dimensions match.
   186    auto map = affineMapAttr.getValue();
   187  
   188    // Verify that operand count matches affine map dimension and symbol count.
   189    if (getNumOperands() != map.getNumDims() + map.getNumSymbols())
   190      return emitOpError(
   191          "operand count and affine map dimension and symbol count must match");
   192  
   193    // Verify that all operands are of `index` type.
   194    for (Type t : getOperandTypes()) {
   195      if (!t.isIndex())
   196        return emitOpError("operands must be of type 'index'");
   197    }
   198  
   199    if (!getResult()->getType().isIndex())
   200      return emitOpError("result must be of type 'index'");
   201  
   202    // Verify that the operands are valid dimension and symbol identifiers.
   203    if (failed(verifyDimAndSymbolIdentifiers(*this, getOperands(),
   204                                             map.getNumDims())))
   205      return failure();
   206  
   207    // Verify that the map only produces one result.
   208    if (map.getNumResults() != 1)
   209      return emitOpError("mapping must produce one value");
   210  
   211    return success();
   212  }
   213  
    214  // The result of an affine apply operation can be used as a dimension id if
    215  // all of its operands are valid dimension ids (see mlir::isValidDim above for
    216  // what qualifies as a valid dimension id).
   217  bool AffineApplyOp::isValidDim() {
   218    return llvm::all_of(getOperands(),
   219                        [](Value *op) { return mlir::isValidDim(op); });
   220  }
   221  
    222  // The result of an affine apply operation can be used as a symbol if all of
    223  // its operands are valid symbols (see mlir::isValidSymbol above).
   224  bool AffineApplyOp::isValidSymbol() {
   225    return llvm::all_of(getOperands(),
   226                        [](Value *op) { return mlir::isValidSymbol(op); });
   227  }
   228  
   229  OpFoldResult AffineApplyOp::fold(ArrayRef<Attribute> operands) {
   230    auto map = getAffineMap();
   231  
   232    // Fold dims and symbols to existing values.
   233    auto expr = map.getResult(0);
   234    if (auto dim = expr.dyn_cast<AffineDimExpr>())
   235      return getOperand(dim.getPosition());
   236    if (auto sym = expr.dyn_cast<AffineSymbolExpr>())
   237      return getOperand(map.getNumDims() + sym.getPosition());
   238  
   239    // Otherwise, default to folding the map.
   240    SmallVector<Attribute, 1> result;
   241    if (failed(map.constantFold(operands, result)))
   242      return {};
   243    return result[0];
   244  }
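
         // Folding examples (a sketch): `affine.apply (d0, d1) -> (d1)(%x, %y)`
         // folds to the existing value %y and `affine.apply ()[s0] -> (s0)()[%n]`
         // folds to %n. When the operands are known constants the map itself is
         // constant folded, e.g. (d0) -> (d0 + 2) over a constant 40 folds to the
         // attribute 42.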
   245  
   246  namespace {
   247  /// An `AffineApplyNormalizer` is a helper class that is not visible to the user
   248  /// and supports renumbering operands of AffineApplyOp. This acts as a
   249  /// reindexing map of Value* to positional dims or symbols and allows
   250  /// simplifications such as:
   251  ///
   252  /// ```mlir
   253  ///    %1 = affine.apply (d0, d1) -> (d0 - d1) (%0, %0)
   254  /// ```
   255  ///
   256  /// into:
   257  ///
   258  /// ```mlir
   259  ///    %1 = affine.apply () -> (0)
   260  /// ```
   261  struct AffineApplyNormalizer {
   262    AffineApplyNormalizer(AffineMap map, ArrayRef<Value *> operands);
   263  
   264    /// Returns the AffineMap resulting from normalization.
   265    AffineMap getAffineMap() { return affineMap; }
   266  
   267    SmallVector<Value *, 8> getOperands() {
   268      SmallVector<Value *, 8> res(reorderedDims);
   269      res.append(concatenatedSymbols.begin(), concatenatedSymbols.end());
   270      return res;
   271    }
   272  
   273  private:
   274    /// Helper function to insert `v` into the coordinate system of the current
   275    /// AffineApplyNormalizer. Returns the AffineDimExpr with the corresponding
   276    /// renumbered position.
   277    AffineDimExpr renumberOneDim(Value *v);
   278  
   279    /// Given an `other` normalizer, this rewrites `other.affineMap` in the
   280    /// coordinate system of the current AffineApplyNormalizer.
   281    /// Returns the rewritten AffineMap and updates the dims and symbols of
   282    /// `this`.
   283    AffineMap renumber(const AffineApplyNormalizer &other);
   284  
    285    /// Map from Value* to position in `affineMap`.
   286    DenseMap<Value *, unsigned> dimValueToPosition;
   287  
   288    /// Ordered dims and symbols matching positional dims and symbols in
   289    /// `affineMap`.
   290    SmallVector<Value *, 8> reorderedDims;
   291    SmallVector<Value *, 8> concatenatedSymbols;
   292  
   293    AffineMap affineMap;
   294  
    295    /// Used with RAII to control the depth at which AffineApplyOps are composed
   296    /// recursively. Only accepts depth 1 for now to allow a behavior where a
   297    /// newly composed AffineApplyOp does not increase the length of the chain of
   298    /// AffineApplyOps. Full composition is implemented iteratively on top of
   299    /// this behavior.
   300    static unsigned &affineApplyDepth() {
   301      static thread_local unsigned depth = 0;
   302      return depth;
   303    }
   304    static constexpr unsigned kMaxAffineApplyDepth = 1;
   305  
   306    AffineApplyNormalizer() { affineApplyDepth()++; }
   307  
   308  public:
   309    ~AffineApplyNormalizer() { affineApplyDepth()--; }
   310  };
   311  } // end anonymous namespace.
   312  
   313  AffineDimExpr AffineApplyNormalizer::renumberOneDim(Value *v) {
   314    DenseMap<Value *, unsigned>::iterator iterPos;
   315    bool inserted = false;
   316    std::tie(iterPos, inserted) =
   317        dimValueToPosition.insert(std::make_pair(v, dimValueToPosition.size()));
   318    if (inserted) {
   319      reorderedDims.push_back(v);
   320    }
   321    return getAffineDimExpr(iterPos->second, v->getContext())
   322        .cast<AffineDimExpr>();
   323  }
   324  
   325  AffineMap AffineApplyNormalizer::renumber(const AffineApplyNormalizer &other) {
   326    SmallVector<AffineExpr, 8> dimRemapping;
   327    for (auto *v : other.reorderedDims) {
   328      auto kvp = other.dimValueToPosition.find(v);
   329      if (dimRemapping.size() <= kvp->second)
   330        dimRemapping.resize(kvp->second + 1);
   331      dimRemapping[kvp->second] = renumberOneDim(kvp->first);
   332    }
   333    unsigned numSymbols = concatenatedSymbols.size();
   334    unsigned numOtherSymbols = other.concatenatedSymbols.size();
   335    SmallVector<AffineExpr, 8> symRemapping(numOtherSymbols);
   336    for (unsigned idx = 0; idx < numOtherSymbols; ++idx) {
   337      symRemapping[idx] =
   338          getAffineSymbolExpr(idx + numSymbols, other.affineMap.getContext());
   339    }
   340    concatenatedSymbols.insert(concatenatedSymbols.end(),
   341                               other.concatenatedSymbols.begin(),
   342                               other.concatenatedSymbols.end());
   343    auto map = other.affineMap;
   344    return map.replaceDimsAndSymbols(dimRemapping, symRemapping,
   345                                     dimRemapping.size(), symRemapping.size());
   346  }
   347  
   348  // Gather the positions of the operands that are produced by an AffineApplyOp.
   349  static llvm::SetVector<unsigned>
   350  indicesFromAffineApplyOp(ArrayRef<Value *> operands) {
   351    llvm::SetVector<unsigned> res;
   352    for (auto en : llvm::enumerate(operands))
   353      if (isa_and_nonnull<AffineApplyOp>(en.value()->getDefiningOp()))
   354        res.insert(en.index());
   355    return res;
   356  }
   357  
   358  // Support the special case of a symbol coming from an AffineApplyOp that needs
   359  // to be composed into the current AffineApplyOp.
   360  // This case is handled by rewriting all such symbols into dims for the purpose
   361  // of allowing mathematical AffineMap composition.
   362  // Returns an AffineMap where symbols that come from an AffineApplyOp have been
   363  // rewritten as dims and are ordered after the original dims.
   364  // TODO(andydavis,ntv): This promotion makes AffineMap lose track of which
   365  // symbols are represented as dims. This loss is static but can still be
   366  // recovered dynamically (with `isValidSymbol`). Still this is annoying for the
   367  // semi-affine map case. A dynamic canonicalization of all dims that are valid
   368  // symbols (a.k.a `canonicalizePromotedSymbols`) into symbols helps and even
   369  // results in better simplifications and foldings. But we should evaluate
    370  // whether this behavior is what we really want after using it more.
   371  static AffineMap promoteComposedSymbolsAsDims(AffineMap map,
   372                                                ArrayRef<Value *> symbols) {
   373    if (symbols.empty()) {
   374      return map;
   375    }
   376  
   377    // Sanity check on symbols.
   378    for (auto *sym : symbols) {
   379      assert(isValidSymbol(sym) && "Expected only valid symbols");
   380      (void)sym;
   381    }
   382  
   383    // Extract the symbol positions that come from an AffineApplyOp and
    384    // need to be rewritten as dims.
   385    auto symPositions = indicesFromAffineApplyOp(symbols);
   386    if (symPositions.empty()) {
   387      return map;
   388    }
   389  
   390    // Create the new map by replacing each symbol at pos by the next new dim.
   391    unsigned numDims = map.getNumDims();
   392    unsigned numSymbols = map.getNumSymbols();
   393    unsigned numNewDims = 0;
   394    unsigned numNewSymbols = 0;
   395    SmallVector<AffineExpr, 8> symReplacements(numSymbols);
   396    for (unsigned i = 0; i < numSymbols; ++i) {
   397      symReplacements[i] =
   398          symPositions.count(i) > 0
   399              ? getAffineDimExpr(numDims + numNewDims++, map.getContext())
   400              : getAffineSymbolExpr(numNewSymbols++, map.getContext());
   401    }
   402    assert(numSymbols >= numNewDims);
   403    AffineMap newMap = map.replaceDimsAndSymbols(
   404        {}, symReplacements, numDims + numNewDims, numNewSymbols);
   405  
   406    return newMap;
   407  }
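
         // For example (a sketch): given (d0)[s0, s1] -> (d0 + s0 + s1) where the
         // operand bound to s0 is produced by an affine.apply and the one bound to
         // s1 is not, s0 is promoted to the new dim d1 and s1 is renumbered to s0,
         // yielding (d0, d1)[s0] -> (d0 + d1 + s0); the operands themselves are
         // dispatched by the caller in their original order.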
   408  
   409  /// The AffineNormalizer composes AffineApplyOp recursively. Its purpose is to
   410  /// keep a correspondence between the mathematical `map` and the `operands` of
   411  /// a given AffineApplyOp. This correspondence is maintained by iterating over
   412  /// the operands and forming an `auxiliaryMap` that can be composed
   413  /// mathematically with `map`. To keep this correspondence in cases where
   414  /// symbols are produced by affine.apply operations, we perform a local rewrite
   415  /// of symbols as dims.
   416  ///
   417  /// Rationale for locally rewriting symbols as dims:
   418  /// ================================================
   419  /// The mathematical composition of AffineMap must always concatenate symbols
   420  /// because it does not have enough information to do otherwise. For example,
   421  /// composing `(d0)[s0] -> (d0 + s0)` with itself must produce
   422  /// `(d0)[s0, s1] -> (d0 + s0 + s1)`.
   423  ///
   424  /// The result is only equivalent to `(d0)[s0] -> (d0 + 2 * s0)` when
   425  /// applied to the same mlir::Value* for both s0 and s1.
   426  /// As a consequence mathematical composition of AffineMap always concatenates
   427  /// symbols.
   428  ///
   429  /// When AffineMaps are used in AffineApplyOp however, they may specify
   430  /// composition via symbols, which is ambiguous mathematically. This corner case
   431  /// is handled by locally rewriting such symbols that come from AffineApplyOp
   432  /// into dims and composing through dims.
    433  /// TODO(andydavis, ntv): Composition via symbols comes at a significant cost
    434  /// in code complexity. Alternatively we should investigate whether we want to
   435  /// explicitly disallow symbols coming from affine.apply and instead force the
   436  /// user to compose symbols beforehand. The annoyances may be small (i.e. 1 or 2
   437  /// extra API calls for such uses, which haven't popped up until now) and the
   438  /// benefit potentially big: simpler and more maintainable code for a
    439  /// non-trivial, recursive procedure.
   440  AffineApplyNormalizer::AffineApplyNormalizer(AffineMap map,
   441                                               ArrayRef<Value *> operands)
   442      : AffineApplyNormalizer() {
   443    static_assert(kMaxAffineApplyDepth > 0, "kMaxAffineApplyDepth must be > 0");
   444    assert(map.getNumInputs() == operands.size() &&
   445           "number of operands does not match the number of map inputs");
   446  
   447    LLVM_DEBUG(map.print(dbgs() << "\nInput map: "));
   448  
   449    // Promote symbols that come from an AffineApplyOp to dims by rewriting the
   450    // map to always refer to:
   451    //   (dims, symbols coming from AffineApplyOp, other symbols).
   452    // The order of operands can remain unchanged.
   453    // This is a simplification that relies on 2 ordering properties:
   454    //   1. rewritten symbols always appear after the original dims in the map;
   455    //   2. operands are traversed in order and either dispatched to:
   456    //      a. auxiliaryExprs (dims and symbols rewritten as dims);
   457    //      b. concatenatedSymbols (all other symbols)
   458    // This allows operand order to remain unchanged.
   459    unsigned numDimsBeforeRewrite = map.getNumDims();
   460    map = promoteComposedSymbolsAsDims(map,
   461                                       operands.take_back(map.getNumSymbols()));
   462  
   463    LLVM_DEBUG(map.print(dbgs() << "\nRewritten map: "));
   464  
   465    SmallVector<AffineExpr, 8> auxiliaryExprs;
   466    bool furtherCompose = (affineApplyDepth() <= kMaxAffineApplyDepth);
   467    // We fully spell out the 2 cases below. In this particular instance a little
   468    // code duplication greatly improves readability.
   469    // Note that the first branch would disappear if we only supported full
   470    // composition (i.e. infinite kMaxAffineApplyDepth).
   471    if (!furtherCompose) {
   472      // 1. Only dispatch dims or symbols.
   473      for (auto en : llvm::enumerate(operands)) {
   474        auto *t = en.value();
   475        assert(t->getType().isIndex());
   476        bool isDim = (en.index() < map.getNumDims());
   477        if (isDim) {
   478          // a. The mathematical composition of AffineMap composes dims.
   479          auxiliaryExprs.push_back(renumberOneDim(t));
   480        } else {
   481          // b. The mathematical composition of AffineMap concatenates symbols.
   482          //    We do the same for symbol operands.
   483          concatenatedSymbols.push_back(t);
   484        }
   485      }
   486    } else {
   487      assert(numDimsBeforeRewrite <= operands.size());
   488      // 2. Compose AffineApplyOps and dispatch dims or symbols.
   489      for (unsigned i = 0, e = operands.size(); i < e; ++i) {
   490        auto *t = operands[i];
   491        auto affineApply = dyn_cast_or_null<AffineApplyOp>(t->getDefiningOp());
   492        if (affineApply) {
   493          // a. Compose affine.apply operations.
   494          LLVM_DEBUG(affineApply.getOperation()->print(
   495              dbgs() << "\nCompose AffineApplyOp recursively: "));
   496          AffineMap affineApplyMap = affineApply.getAffineMap();
   497          SmallVector<Value *, 8> affineApplyOperands(
   498              affineApply.getOperands().begin(), affineApply.getOperands().end());
   499          AffineApplyNormalizer normalizer(affineApplyMap, affineApplyOperands);
   500  
   501          LLVM_DEBUG(normalizer.affineMap.print(
   502              dbgs() << "\nRenumber into current normalizer: "));
   503  
   504          auto renumberedMap = renumber(normalizer);
   505  
   506          LLVM_DEBUG(
   507              renumberedMap.print(dbgs() << "\nRecursive composition yields: "));
   508  
   509          auxiliaryExprs.push_back(renumberedMap.getResult(0));
   510        } else {
   511          if (i < numDimsBeforeRewrite) {
   512            // b. The mathematical composition of AffineMap composes dims.
   513            auxiliaryExprs.push_back(renumberOneDim(t));
   514          } else {
   515            // c. The mathematical composition of AffineMap concatenates symbols.
   516            //    We do the same for symbol operands.
   517            concatenatedSymbols.push_back(t);
   518          }
   519        }
   520      }
   521    }
   522  
   523    // Early exit if `map` is already composed.
   524    if (auxiliaryExprs.empty()) {
   525      affineMap = map;
   526      return;
   527    }
   528  
   529    assert(concatenatedSymbols.size() >= map.getNumSymbols() &&
   530           "Unexpected number of concatenated symbols");
   531    auto numDims = dimValueToPosition.size();
   532    auto numSymbols = concatenatedSymbols.size() - map.getNumSymbols();
   533    auto auxiliaryMap = AffineMap::get(numDims, numSymbols, auxiliaryExprs);
   534  
   535    LLVM_DEBUG(map.print(dbgs() << "\nCompose map: "));
   536    LLVM_DEBUG(auxiliaryMap.print(dbgs() << "\nWith map: "));
   537    LLVM_DEBUG(map.compose(auxiliaryMap).print(dbgs() << "\nResult: "));
   538  
   539    // TODO(andydavis,ntv): Disabling simplification results in major speed gains.
   540    // Another option is to cache the results as it is expected a lot of redundant
   541    // work is performed in practice.
   542    affineMap = simplifyAffineMap(map.compose(auxiliaryMap));
   543  
   544    LLVM_DEBUG(affineMap.print(dbgs() << "\nSimplified result: "));
   545    LLVM_DEBUG(dbgs() << "\n");
   546  }
   547  
   548  /// Implements `map` and `operands` composition and simplification to support
   549  /// `makeComposedAffineApply`. This can be called to achieve the same effects
   550  /// on `map` and `operands` without creating an AffineApplyOp that needs to be
   551  /// immediately deleted.
   552  static void composeAffineMapAndOperands(AffineMap *map,
   553                                          SmallVectorImpl<Value *> *operands) {
   554    AffineApplyNormalizer normalizer(*map, *operands);
   555    auto normalizedMap = normalizer.getAffineMap();
   556    auto normalizedOperands = normalizer.getOperands();
   557    canonicalizeMapAndOperands(&normalizedMap, &normalizedOperands);
   558    *map = normalizedMap;
   559    *operands = normalizedOperands;
   560    assert(*map);
   561  }
   562  
   563  void mlir::fullyComposeAffineMapAndOperands(
   564      AffineMap *map, SmallVectorImpl<Value *> *operands) {
   565    while (llvm::any_of(*operands, [](Value *v) {
   566      return isa_and_nonnull<AffineApplyOp>(v->getDefiningOp());
   567    })) {
   568      composeAffineMapAndOperands(map, operands);
   569    }
   570  }
   571  
   572  AffineApplyOp mlir::makeComposedAffineApply(OpBuilder &b, Location loc,
   573                                              AffineMap map,
   574                                              ArrayRef<Value *> operands) {
   575    AffineMap normalizedMap = map;
   576    SmallVector<Value *, 8> normalizedOperands(operands.begin(), operands.end());
   577    composeAffineMapAndOperands(&normalizedMap, &normalizedOperands);
   578    assert(normalizedMap);
   579    return b.create<AffineApplyOp>(loc, normalizedMap, normalizedOperands);
   580  }
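
         // For example (a sketch): if %0 = affine.apply (d0) -> (d0 * 2)(%i), then
         // calling makeComposedAffineApply with map (d0) -> (d0 + 1) and operand %0
         // emits a single operation equivalent to
         //   affine.apply (d0) -> (d0 * 2 + 1)(%i)
         // instead of chaining a second apply onto %0.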
   581  
   582  // A symbol may appear as a dim in affine.apply operations. This function
   583  // canonicalizes dims that are valid symbols into actual symbols.
   584  static void
   585  canonicalizePromotedSymbols(AffineMap *map,
   586                              llvm::SmallVectorImpl<Value *> *operands) {
   587    if (!map || operands->empty())
   588      return;
   589  
   590    assert(map->getNumInputs() == operands->size() &&
   591           "map inputs must match number of operands");
   592  
   593    auto *context = map->getContext();
   594    SmallVector<Value *, 8> resultOperands;
   595    resultOperands.reserve(operands->size());
   596    SmallVector<Value *, 8> remappedSymbols;
   597    remappedSymbols.reserve(operands->size());
   598    unsigned nextDim = 0;
   599    unsigned nextSym = 0;
   600    unsigned oldNumSyms = map->getNumSymbols();
   601    SmallVector<AffineExpr, 8> dimRemapping(map->getNumDims());
   602    for (unsigned i = 0, e = map->getNumInputs(); i != e; ++i) {
   603      if (i < map->getNumDims()) {
   604        if (isValidSymbol((*operands)[i])) {
   605          // This is a valid symbol that appears as a dim, canonicalize it.
   606          dimRemapping[i] = getAffineSymbolExpr(oldNumSyms + nextSym++, context);
   607          remappedSymbols.push_back((*operands)[i]);
   608        } else {
   609          dimRemapping[i] = getAffineDimExpr(nextDim++, context);
   610          resultOperands.push_back((*operands)[i]);
   611        }
   612      } else {
   613        resultOperands.push_back((*operands)[i]);
   614      }
   615    }
   616  
   617    resultOperands.append(remappedSymbols.begin(), remappedSymbols.end());
   618    *operands = resultOperands;
   619    *map = map->replaceDimsAndSymbols(dimRemapping, {}, nextDim,
   620                                      oldNumSyms + nextSym);
   621  
   622    assert(map->getNumInputs() == operands->size() &&
   623           "map inputs must match number of operands");
   624  }
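
         // For example (a sketch): with map (d0, d1) -> (d0 + d1) and operands
         // (%i, %n), where %i is a loop IV and %n is a valid symbol, d1 is rewritten
         // as a symbol and the result is (d0)[s0] -> (d0 + s0) over (%i)[%n].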
   625  
   626  void mlir::canonicalizeMapAndOperands(
   627      AffineMap *map, llvm::SmallVectorImpl<Value *> *operands) {
   628    if (!map || operands->empty())
   629      return;
   630  
   631    assert(map->getNumInputs() == operands->size() &&
   632           "map inputs must match number of operands");
   633  
   634    canonicalizePromotedSymbols(map, operands);
   635  
    636    // Check which dims and symbols are actually used.
   637    llvm::SmallBitVector usedDims(map->getNumDims());
   638    llvm::SmallBitVector usedSyms(map->getNumSymbols());
   639    map->walkExprs([&](AffineExpr expr) {
   640      if (auto dimExpr = expr.dyn_cast<AffineDimExpr>())
   641        usedDims[dimExpr.getPosition()] = true;
   642      else if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>())
   643        usedSyms[symExpr.getPosition()] = true;
   644    });
   645  
   646    auto *context = map->getContext();
   647  
   648    SmallVector<Value *, 8> resultOperands;
   649    resultOperands.reserve(operands->size());
   650  
   651    llvm::SmallDenseMap<Value *, AffineExpr, 8> seenDims;
   652    SmallVector<AffineExpr, 8> dimRemapping(map->getNumDims());
   653    unsigned nextDim = 0;
   654    for (unsigned i = 0, e = map->getNumDims(); i != e; ++i) {
   655      if (usedDims[i]) {
   656        // Remap dim positions for duplicate operands.
   657        auto it = seenDims.find((*operands)[i]);
   658        if (it == seenDims.end()) {
   659          dimRemapping[i] = getAffineDimExpr(nextDim++, context);
   660          resultOperands.push_back((*operands)[i]);
   661          seenDims.insert(std::make_pair((*operands)[i], dimRemapping[i]));
   662        } else {
   663          dimRemapping[i] = it->second;
   664        }
   665      }
   666    }
   667    llvm::SmallDenseMap<Value *, AffineExpr, 8> seenSymbols;
   668    SmallVector<AffineExpr, 8> symRemapping(map->getNumSymbols());
   669    unsigned nextSym = 0;
   670    for (unsigned i = 0, e = map->getNumSymbols(); i != e; ++i) {
   671      if (!usedSyms[i])
   672        continue;
   673      // Handle constant operands (only needed for symbolic operands since
   674      // constant operands in dimensional positions would have already been
   675      // promoted to symbolic positions above).
   676      IntegerAttr operandCst;
   677      if (matchPattern((*operands)[i + map->getNumDims()],
   678                       m_Constant(&operandCst))) {
   679        symRemapping[i] =
   680            getAffineConstantExpr(operandCst.getValue().getSExtValue(), context);
   681        continue;
   682      }
   683      // Remap symbol positions for duplicate operands.
   684      auto it = seenSymbols.find((*operands)[i + map->getNumDims()]);
   685      if (it == seenSymbols.end()) {
   686        symRemapping[i] = getAffineSymbolExpr(nextSym++, context);
   687        resultOperands.push_back((*operands)[i + map->getNumDims()]);
   688        seenSymbols.insert(
   689            std::make_pair((*operands)[i + map->getNumDims()], symRemapping[i]));
   690      } else {
   691        symRemapping[i] = it->second;
   692      }
   693    }
   694    *map =
   695        map->replaceDimsAndSymbols(dimRemapping, symRemapping, nextDim, nextSym);
   696    *operands = resultOperands;
   697  }
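
         // For example (a sketch): with map (d0, d1)[s0] -> (d0 + d1 + s0) and
         // operands (%i, %i, %c) where %c is defined by `constant 42 : index`, the
         // duplicate dim operands are merged and the constant symbol is inlined,
         // giving roughly (d0) -> (d0 + d0 + 42) over the single operand (%i)
         // (modulo local expression simplification).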
   698  
   699  namespace {
   700  /// Simplify AffineApply operations.
   701  ///
   702  struct SimplifyAffineApply : public OpRewritePattern<AffineApplyOp> {
   703    using OpRewritePattern<AffineApplyOp>::OpRewritePattern;
   704  
   705    PatternMatchResult matchAndRewrite(AffineApplyOp apply,
   706                                       PatternRewriter &rewriter) const override {
   707      auto map = apply.getAffineMap();
   708  
   709      AffineMap oldMap = map;
   710      SmallVector<Value *, 8> resultOperands(apply.getOperands());
   711      composeAffineMapAndOperands(&map, &resultOperands);
   712      if (map == oldMap)
   713        return matchFailure();
   714  
   715      rewriter.replaceOpWithNewOp<AffineApplyOp>(apply, map, resultOperands);
   716      return matchSuccess();
   717    }
   718  };
   719  } // end anonymous namespace.
   720  
   721  void AffineApplyOp::getCanonicalizationPatterns(
   722      OwningRewritePatternList &results, MLIRContext *context) {
   723    results.insert<SimplifyAffineApply>(context);
   724  }
   725  
   726  //===----------------------------------------------------------------------===//
   727  // Common canonicalization pattern support logic
   728  //===----------------------------------------------------------------------===//
   729  
   730  namespace {
   731  /// This is a common class used for patterns of the form
   732  /// "someop(memrefcast) -> someop".  It folds the source of any memref_cast
   733  /// into the root operation directly.
   734  struct MemRefCastFolder : public RewritePattern {
   735    /// The rootOpName is the name of the root operation to match against.
   736    MemRefCastFolder(StringRef rootOpName, MLIRContext *context)
   737        : RewritePattern(rootOpName, 1, context) {}
   738  
   739    PatternMatchResult match(Operation *op) const override {
   740      for (auto *operand : op->getOperands())
   741        if (matchPattern(operand, m_Op<MemRefCastOp>()))
   742          return matchSuccess();
   743  
   744      return matchFailure();
   745    }
   746  
   747    void rewrite(Operation *op, PatternRewriter &rewriter) const override {
   748      for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i)
   749        if (auto *memref = op->getOperand(i)->getDefiningOp())
   750          if (auto cast = dyn_cast<MemRefCastOp>(memref))
   751            op->setOperand(i, cast.getOperand());
   752      rewriter.updatedRootInPlace(op);
   753    }
   754  };
   755  
   756  } // end anonymous namespace.
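
         // For example (a sketch): if %1 = memref_cast %0 : memref<4xf32> to
         // memref<?xf32> and %1 feeds an affine.dma_start, the pattern rewrites the
         // op to use %0 directly, after which the cast may become dead and be
         // removed.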
   757  
   758  //===----------------------------------------------------------------------===//
   759  // AffineDmaStartOp
   760  //===----------------------------------------------------------------------===//
   761  
   762  // TODO(b/133776335) Check that map operands are loop IVs or symbols.
   763  void AffineDmaStartOp::build(Builder *builder, OperationState *result,
   764                               Value *srcMemRef, AffineMap srcMap,
   765                               ArrayRef<Value *> srcIndices, Value *destMemRef,
   766                               AffineMap dstMap, ArrayRef<Value *> destIndices,
   767                               Value *tagMemRef, AffineMap tagMap,
   768                               ArrayRef<Value *> tagIndices, Value *numElements,
   769                               Value *stride, Value *elementsPerStride) {
   770    result->addOperands(srcMemRef);
   771    result->addAttribute(getSrcMapAttrName(), builder->getAffineMapAttr(srcMap));
   772    result->addOperands(srcIndices);
   773    result->addOperands(destMemRef);
   774    result->addAttribute(getDstMapAttrName(), builder->getAffineMapAttr(dstMap));
   775    result->addOperands(destIndices);
   776    result->addOperands(tagMemRef);
   777    result->addAttribute(getTagMapAttrName(), builder->getAffineMapAttr(tagMap));
   778    result->addOperands(tagIndices);
   779    result->addOperands(numElements);
   780    if (stride) {
   781      result->addOperands({stride, elementsPerStride});
   782    }
   783  }
   784  
   785  void AffineDmaStartOp::print(OpAsmPrinter *p) {
   786    *p << "affine.dma_start " << *getSrcMemRef() << '[';
   787    SmallVector<Value *, 8> operands(getSrcIndices());
   788    p->printAffineMapOfSSAIds(getSrcMapAttr(), operands);
   789    *p << "], " << *getDstMemRef() << '[';
   790    operands.assign(getDstIndices().begin(), getDstIndices().end());
   791    p->printAffineMapOfSSAIds(getDstMapAttr(), operands);
   792    *p << "], " << *getTagMemRef() << '[';
   793    operands.assign(getTagIndices().begin(), getTagIndices().end());
   794    p->printAffineMapOfSSAIds(getTagMapAttr(), operands);
   795    *p << "], " << *getNumElements();
   796    if (isStrided()) {
   797      *p << ", " << *getStride();
   798      *p << ", " << *getNumElementsPerStride();
   799    }
   800    *p << " : " << getSrcMemRefType() << ", " << getDstMemRefType() << ", "
   801       << getTagMemRefType();
   802  }
   803  
   804  // Parse AffineDmaStartOp.
   805  // Ex:
   806  //   affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%index], %size,
   807  //     %stride, %num_elt_per_stride
   808  //       : memref<3076 x f32, 0>, memref<1024 x f32, 2>, memref<1 x i32>
   809  //
   810  ParseResult AffineDmaStartOp::parse(OpAsmParser *parser,
   811                                      OperationState *result) {
   812    OpAsmParser::OperandType srcMemRefInfo;
   813    AffineMapAttr srcMapAttr;
   814    SmallVector<OpAsmParser::OperandType, 4> srcMapOperands;
   815    OpAsmParser::OperandType dstMemRefInfo;
   816    AffineMapAttr dstMapAttr;
   817    SmallVector<OpAsmParser::OperandType, 4> dstMapOperands;
   818    OpAsmParser::OperandType tagMemRefInfo;
   819    AffineMapAttr tagMapAttr;
   820    SmallVector<OpAsmParser::OperandType, 4> tagMapOperands;
   821    OpAsmParser::OperandType numElementsInfo;
   822    SmallVector<OpAsmParser::OperandType, 2> strideInfo;
   823  
   824    SmallVector<Type, 3> types;
   825    auto indexType = parser->getBuilder().getIndexType();
   826  
   827    // Parse and resolve the following list of operands:
    828    // *) src memref followed by its affine map operands (in square brackets).
    829    // *) dst memref followed by its affine map operands (in square brackets).
   830    // *) tag memref followed by its affine map operands (in square brackets).
   831    // *) number of elements transferred by DMA operation.
   832    if (parser->parseOperand(srcMemRefInfo) ||
   833        parser->parseAffineMapOfSSAIds(srcMapOperands, srcMapAttr,
   834                                       getSrcMapAttrName(), result->attributes) ||
   835        parser->parseComma() || parser->parseOperand(dstMemRefInfo) ||
   836        parser->parseAffineMapOfSSAIds(dstMapOperands, dstMapAttr,
   837                                       getDstMapAttrName(), result->attributes) ||
   838        parser->parseComma() || parser->parseOperand(tagMemRefInfo) ||
   839        parser->parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr,
   840                                       getTagMapAttrName(), result->attributes) ||
   841        parser->parseComma() || parser->parseOperand(numElementsInfo))
   842      return failure();
   843  
   844    // Parse optional stride and elements per stride.
   845    if (parser->parseTrailingOperandList(strideInfo)) {
   846      return failure();
   847    }
   848    if (!strideInfo.empty() && strideInfo.size() != 2) {
   849      return parser->emitError(parser->getNameLoc(),
   850                               "expected two stride related operands");
   851    }
   852    bool isStrided = strideInfo.size() == 2;
   853  
   854    if (parser->parseColonTypeList(types))
   855      return failure();
   856  
   857    if (types.size() != 3)
   858      return parser->emitError(parser->getNameLoc(), "expected three types");
   859  
   860    if (parser->resolveOperand(srcMemRefInfo, types[0], result->operands) ||
   861        parser->resolveOperands(srcMapOperands, indexType, result->operands) ||
   862        parser->resolveOperand(dstMemRefInfo, types[1], result->operands) ||
   863        parser->resolveOperands(dstMapOperands, indexType, result->operands) ||
   864        parser->resolveOperand(tagMemRefInfo, types[2], result->operands) ||
   865        parser->resolveOperands(tagMapOperands, indexType, result->operands) ||
   866        parser->resolveOperand(numElementsInfo, indexType, result->operands))
   867      return failure();
   868  
   869    if (isStrided) {
   870      if (parser->resolveOperands(strideInfo, indexType, result->operands))
   871        return failure();
   872    }
   873  
   874    // Check that src/dst/tag operand counts match their map.numInputs.
   875    if (srcMapOperands.size() != srcMapAttr.getValue().getNumInputs() ||
   876        dstMapOperands.size() != dstMapAttr.getValue().getNumInputs() ||
   877        tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
   878      return parser->emitError(parser->getNameLoc(),
   879                               "memref operand count not equal to map.numInputs");
   880    return success();
   881  }
   882  
   883  LogicalResult AffineDmaStartOp::verify() {
   884    if (!getOperand(getSrcMemRefOperandIndex())->getType().isa<MemRefType>())
   885      return emitOpError("expected DMA source to be of memref type");
   886    if (!getOperand(getDstMemRefOperandIndex())->getType().isa<MemRefType>())
   887      return emitOpError("expected DMA destination to be of memref type");
   888    if (!getOperand(getTagMemRefOperandIndex())->getType().isa<MemRefType>())
   889      return emitOpError("expected DMA tag to be of memref type");
   890  
    891    // DMAs are only supported between different memory spaces.
   892    if (getSrcMemorySpace() == getDstMemorySpace()) {
   893      return emitOpError("DMA should be between different memory spaces");
   894    }
   895    unsigned numInputsAllMaps = getSrcMap().getNumInputs() +
   896                                getDstMap().getNumInputs() +
   897                                getTagMap().getNumInputs();
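           // Operands are the three memrefs (src, dst, tag), the map operands of
           // each of the three maps, the number of elements, and optionally a
           // stride and a number of elements per stride.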
   898    if (getNumOperands() != numInputsAllMaps + 3 + 1 &&
   899        getNumOperands() != numInputsAllMaps + 3 + 1 + 2) {
   900      return emitOpError("incorrect number of operands");
   901    }
   902  
   903    for (auto *idx : getSrcIndices()) {
   904      if (!idx->getType().isIndex())
   905        return emitOpError("src index to dma_start must have 'index' type");
   906      if (!isValidAffineIndexOperand(idx))
   907        return emitOpError("src index must be a dimension or symbol identifier");
   908    }
   909    for (auto *idx : getDstIndices()) {
   910      if (!idx->getType().isIndex())
   911        return emitOpError("dst index to dma_start must have 'index' type");
   912      if (!isValidAffineIndexOperand(idx))
   913        return emitOpError("dst index must be a dimension or symbol identifier");
   914    }
   915    for (auto *idx : getTagIndices()) {
   916      if (!idx->getType().isIndex())
   917        return emitOpError("tag index to dma_start must have 'index' type");
   918      if (!isValidAffineIndexOperand(idx))
   919        return emitOpError("tag index must be a dimension or symbol identifier");
   920    }
   921    return success();
   922  }
   923  
   924  void AffineDmaStartOp::getCanonicalizationPatterns(
   925      OwningRewritePatternList &results, MLIRContext *context) {
   926    /// dma_start(memrefcast) -> dma_start
   927    results.insert<MemRefCastFolder>(getOperationName(), context);
   928  }
   929  
   930  //===----------------------------------------------------------------------===//
   931  // AffineDmaWaitOp
   932  //===----------------------------------------------------------------------===//
   933  
   934  // TODO(b/133776335) Check that map operands are loop IVs or symbols.
   935  void AffineDmaWaitOp::build(Builder *builder, OperationState *result,
   936                              Value *tagMemRef, AffineMap tagMap,
   937                              ArrayRef<Value *> tagIndices, Value *numElements) {
   938    result->addOperands(tagMemRef);
   939    result->addAttribute(getTagMapAttrName(), builder->getAffineMapAttr(tagMap));
   940    result->addOperands(tagIndices);
   941    result->addOperands(numElements);
   942  }
   943  
   944  void AffineDmaWaitOp::print(OpAsmPrinter *p) {
   945    *p << "affine.dma_wait " << *getTagMemRef() << '[';
   946    SmallVector<Value *, 2> operands(getTagIndices());
   947    p->printAffineMapOfSSAIds(getTagMapAttr(), operands);
   948    *p << "], ";
   949    p->printOperand(getNumElements());
   950    *p << " : " << getTagMemRef()->getType();
   951  }
   952  
   953  // Parse AffineDmaWaitOp.
   954  // Eg:
   955  //   affine.dma_wait %tag[%index], %num_elements
   956  //     : memref<1 x i32, (d0) -> (d0), 4>
   957  //
   958  ParseResult AffineDmaWaitOp::parse(OpAsmParser *parser,
   959                                     OperationState *result) {
   960    OpAsmParser::OperandType tagMemRefInfo;
   961    AffineMapAttr tagMapAttr;
   962    SmallVector<OpAsmParser::OperandType, 2> tagMapOperands;
   963    Type type;
   964    auto indexType = parser->getBuilder().getIndexType();
   965    OpAsmParser::OperandType numElementsInfo;
   966  
   967    // Parse tag memref, its map operands, and dma size.
   968    if (parser->parseOperand(tagMemRefInfo) ||
   969        parser->parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr,
   970                                       getTagMapAttrName(), result->attributes) ||
   971        parser->parseComma() || parser->parseOperand(numElementsInfo) ||
   972        parser->parseColonType(type) ||
   973        parser->resolveOperand(tagMemRefInfo, type, result->operands) ||
   974        parser->resolveOperands(tagMapOperands, indexType, result->operands) ||
   975        parser->resolveOperand(numElementsInfo, indexType, result->operands))
   976      return failure();
   977  
   978    if (!type.isa<MemRefType>())
   979      return parser->emitError(parser->getNameLoc(),
   980                               "expected tag to be of memref type");
   981  
   982    if (tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
   983      return parser->emitError(parser->getNameLoc(),
    984                               "tag memref operand count not equal to map.numInputs");
   985    return success();
   986  }
   987  
   988  LogicalResult AffineDmaWaitOp::verify() {
   989    if (!getOperand(0)->getType().isa<MemRefType>())
   990      return emitOpError("expected DMA tag to be of memref type");
   991    for (auto *idx : getTagIndices()) {
   992      if (!idx->getType().isIndex())
   993        return emitOpError("index to dma_wait must have 'index' type");
   994      if (!isValidAffineIndexOperand(idx))
   995        return emitOpError("index must be a dimension or symbol identifier");
   996    }
   997    return success();
   998  }
   999  
  1000  void AffineDmaWaitOp::getCanonicalizationPatterns(
  1001      OwningRewritePatternList &results, MLIRContext *context) {
  1002    /// dma_wait(memrefcast) -> dma_wait
  1003    results.insert<MemRefCastFolder>(getOperationName(), context);
  1004  }
  1005  
  1006  //===----------------------------------------------------------------------===//
  1007  // AffineForOp
  1008  //===----------------------------------------------------------------------===//
  1009  
  1010  void AffineForOp::build(Builder *builder, OperationState *result,
  1011                          ArrayRef<Value *> lbOperands, AffineMap lbMap,
  1012                          ArrayRef<Value *> ubOperands, AffineMap ubMap,
  1013                          int64_t step) {
  1014    assert(((!lbMap && lbOperands.empty()) ||
  1015            lbOperands.size() == lbMap.getNumInputs()) &&
  1016           "lower bound operand count does not match the affine map");
  1017    assert(((!ubMap && ubOperands.empty()) ||
  1018            ubOperands.size() == ubMap.getNumInputs()) &&
  1019           "upper bound operand count does not match the affine map");
  1020    assert(step > 0 && "step has to be a positive integer constant");
  1021  
  1022    // Add an attribute for the step.
  1023    result->addAttribute(getStepAttrName(),
  1024                         builder->getIntegerAttr(builder->getIndexType(), step));
  1025  
  1026    // Add the lower bound.
  1027    result->addAttribute(getLowerBoundAttrName(),
  1028                         builder->getAffineMapAttr(lbMap));
  1029    result->addOperands(lbOperands);
  1030  
  1031    // Add the upper bound.
  1032    result->addAttribute(getUpperBoundAttrName(),
  1033                         builder->getAffineMapAttr(ubMap));
  1034    result->addOperands(ubOperands);
  1035  
  1036    // Create a region and a block for the body.  The argument of the region is
  1037    // the loop induction variable.
  1038    Region *bodyRegion = result->addRegion();
  1039    Block *body = new Block();
  1040    body->addArgument(IndexType::get(builder->getContext()));
  1041    bodyRegion->push_back(body);
  1042    ensureTerminator(*bodyRegion, *builder, result->location);
  1043  
  1044    // Set the operands list as resizable so that we can freely modify the bounds.
  1045    result->setOperandListToResizable();
  1046  }
  1047  
  1048  void AffineForOp::build(Builder *builder, OperationState *result, int64_t lb,
  1049                          int64_t ub, int64_t step) {
  1050    auto lbMap = AffineMap::getConstantMap(lb, builder->getContext());
  1051    auto ubMap = AffineMap::getConstantMap(ub, builder->getContext());
  1052    return build(builder, result, {}, lbMap, {}, ubMap, step);
  1053  }
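
         // For example (a sketch): building with lb = 0, ub = 10 and step = 1
         // creates a loop equivalent to `affine.for %i = 0 to 10` with the constant
         // bound maps () -> (0) and () -> (10).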
  1054  
  1055  static LogicalResult verify(AffineForOp op) {
   1056    // Check that the body defines a single block argument for the induction
  1057    // variable.
  1058    auto *body = op.getBody();
  1059    if (body->getNumArguments() != 1 ||
  1060        !body->getArgument(0)->getType().isIndex())
  1061      return op.emitOpError(
  1062          "expected body to have a single index argument for the "
  1063          "induction variable");
  1064  
  1065    // Verify that there are enough operands for the bounds.
  1066    AffineMap lowerBoundMap = op.getLowerBoundMap(),
  1067              upperBoundMap = op.getUpperBoundMap();
  1068    if (op.getNumOperands() !=
  1069        (lowerBoundMap.getNumInputs() + upperBoundMap.getNumInputs()))
  1070      return op.emitOpError(
  1071          "operand count must match with affine map dimension and symbol count");
  1072  
  1073    // Verify that the bound operands are valid dimension/symbols.
  1074    /// Lower bound.
  1075    if (failed(verifyDimAndSymbolIdentifiers(op, op.getLowerBoundOperands(),
  1076                                             op.getLowerBoundMap().getNumDims())))
  1077      return failure();
  1078    /// Upper bound.
  1079    if (failed(verifyDimAndSymbolIdentifiers(op, op.getUpperBoundOperands(),
  1080                                             op.getUpperBoundMap().getNumDims())))
  1081      return failure();
  1082    return success();
  1083  }
  1084  
   1085  /// Parse a lower or upper bound of an affine.for operation.
  1086  static ParseResult parseBound(bool isLower, OperationState *result,
  1087                                OpAsmParser *p) {
  1088    // 'min' / 'max' prefixes are generally syntactic sugar, but are required if
  1089    // the map has multiple results.
   1090    bool failedToParseMinMax =
  1091        failed(p->parseOptionalKeyword(isLower ? "max" : "min"));
  1092  
  1093    auto &builder = p->getBuilder();
  1094    auto boundAttrName = isLower ? AffineForOp::getLowerBoundAttrName()
  1095                                 : AffineForOp::getUpperBoundAttrName();
  1096  
  1097    // Parse ssa-id as identity map.
  1098    SmallVector<OpAsmParser::OperandType, 1> boundOpInfos;
  1099    if (p->parseOperandList(boundOpInfos))
  1100      return failure();
  1101  
  1102    if (!boundOpInfos.empty()) {
  1103      // Check that only one operand was parsed.
  1104      if (boundOpInfos.size() > 1)
  1105        return p->emitError(p->getNameLoc(),
  1106                            "expected only one loop bound operand");
  1107  
  1108      // TODO: improve error message when SSA value is not an affine integer.
  1109      // Currently it is 'use of value ... expects different type than prior uses'
  1110      if (p->resolveOperand(boundOpInfos.front(), builder.getIndexType(),
  1111                            result->operands))
  1112        return failure();
  1113  
  1114      // Create an identity map using symbol id. This representation is optimized
  1115      // for storage. Analysis passes may expand it into a multi-dimensional map
  1116      // if desired.
  1117      AffineMap map = builder.getSymbolIdentityMap();
  1118      result->addAttribute(boundAttrName, builder.getAffineMapAttr(map));
  1119      return success();
  1120    }
  1121  
  1122    // Get the attribute location.
  1123    llvm::SMLoc attrLoc = p->getCurrentLocation();
  1124  
  1125    Attribute boundAttr;
  1126    if (p->parseAttribute(boundAttr, builder.getIndexType(), boundAttrName,
  1127                          result->attributes))
  1128      return failure();
  1129  
  1130    // Parse full form - affine map followed by dim and symbol list.
  1131    if (auto affineMapAttr = boundAttr.dyn_cast<AffineMapAttr>()) {
  1132      unsigned currentNumOperands = result->operands.size();
  1133      unsigned numDims;
  1134      if (parseDimAndSymbolList(p, result->operands, numDims))
  1135        return failure();
  1136  
  1137      auto map = affineMapAttr.getValue();
  1138      if (map.getNumDims() != numDims)
  1139        return p->emitError(
  1140            p->getNameLoc(),
   1141            "dim operand count and affine map dim count must match");
  1142  
  1143      unsigned numDimAndSymbolOperands =
  1144          result->operands.size() - currentNumOperands;
  1145      if (numDims + map.getNumSymbols() != numDimAndSymbolOperands)
  1146        return p->emitError(
  1147            p->getNameLoc(),
   1148            "symbol operand count and affine map symbol count must match");
  1149  
  1150      // If the map has multiple results, make sure that we parsed the min/max
  1151      // prefix.
   1152      if (map.getNumResults() > 1 && failedToParseMinMax) {
  1153        if (isLower) {
  1154          return p->emitError(attrLoc, "lower loop bound affine map with "
  1155                                       "multiple results requires 'max' prefix");
  1156        }
  1157        return p->emitError(attrLoc, "upper loop bound affine map with multiple "
  1158                                     "results requires 'min' prefix");
  1159      }
  1160      return success();
  1161    }
  1162  
  1163    // Parse custom assembly form.
  1164    if (auto integerAttr = boundAttr.dyn_cast<IntegerAttr>()) {
  1165      result->attributes.pop_back();
  1166      result->addAttribute(
  1167          boundAttrName, builder.getAffineMapAttr(
  1168                             builder.getConstantAffineMap(integerAttr.getInt())));
  1169      return success();
  1170    }
  1171  
  1172    return p->emitError(
  1173        p->getNameLoc(),
  1174        "expected valid affine map representation for loop bounds");
  1175  }
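
        // For reference, the bound forms accepted above look roughly as follows
        // in the custom syntax (illustrative examples; #lb stands for an affine
        // map alias defined elsewhere):
        //
        //   affine.for %i = 5 to ...                      // constant bound
        //   affine.for %i = %lb to ...                    // single SSA symbol
        //   affine.for %i = max #lb(%d)[%s] to ...        // full map form
        //
        // A 'max' ('min') prefix is required when a lower (upper) bound map has
        // more than one result.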
  1176  
  1177  ParseResult parseAffineForOp(OpAsmParser *parser, OperationState *result) {
  1178    auto &builder = parser->getBuilder();
  1179    OpAsmParser::OperandType inductionVariable;
  1180    // Parse the induction variable followed by '='.
  1181    if (parser->parseRegionArgument(inductionVariable) || parser->parseEqual())
  1182      return failure();
  1183  
  1184    // Parse loop bounds.
  1185    if (parseBound(/*isLower=*/true, result, parser) ||
  1186        parser->parseKeyword("to", " between bounds") ||
  1187        parseBound(/*isLower=*/false, result, parser))
  1188      return failure();
  1189  
  1190    // Parse the optional loop step; default to 1 if one is not present.
  1191    if (parser->parseOptionalKeyword("step")) {
  1192      result->addAttribute(
  1193          AffineForOp::getStepAttrName(),
  1194          builder.getIntegerAttr(builder.getIndexType(), /*value=*/1));
  1195    } else {
  1196      llvm::SMLoc stepLoc = parser->getCurrentLocation();
  1197      IntegerAttr stepAttr;
  1198      if (parser->parseAttribute(stepAttr, builder.getIndexType(),
  1199                                 AffineForOp::getStepAttrName().data(),
  1200                                 result->attributes))
  1201        return failure();
  1202  
  1203      if (stepAttr.getValue().getSExtValue() < 0)
  1204        return parser->emitError(
  1205            stepLoc,
  1206            "expected step to be representable as a positive signed integer");
  1207    }
  1208  
  1209    // Parse the body region.
  1210    Region *body = result->addRegion();
  1211    if (parser->parseRegion(*body, inductionVariable, builder.getIndexType()))
  1212      return failure();
  1213  
  1214    AffineForOp::ensureTerminator(*body, builder, result->location);
  1215  
  1216    // Parse the optional attribute list.
  1217    if (parser->parseOptionalAttributeDict(result->attributes))
  1218      return failure();
  1219  
  1220    // Set the operands list as resizable so that we can freely modify the bounds.
  1221    result->setOperandListToResizable();
  1222    return success();
  1223  }
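
        // Putting the pieces together, a loop in the custom form parsed above
        // looks roughly like (illustrative example):
        //
        //   affine.for %i = 1 to %N step 2 {
        //     "foo"(%i) : (index) -> ()
        //   }
        //
        // where the step clause and the trailing attribute dictionary are both
        // optional.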
  1224  
  1225  static void printBound(AffineMapAttr boundMap,
  1226                         Operation::operand_range boundOperands,
  1227                         const char *prefix, OpAsmPrinter *p) {
  1228    AffineMap map = boundMap.getValue();
  1229  
  1230    // Check if this bound should be printed using custom assembly form.
  1231    // The decision to restrict printing custom assembly form to trivial cases
  1232    // comes from the desire to round-trip MLIR binary -> text -> binary in a
  1233    // lossless way.
  1234    // Therefore, custom assembly form parsing and printing is only supported for
  1235    // zero-operand constant maps and single symbol operand identity maps.
  1236    if (map.getNumResults() == 1) {
  1237      AffineExpr expr = map.getResult(0);
  1238  
  1239      // Print constant bound.
  1240      if (map.getNumDims() == 0 && map.getNumSymbols() == 0) {
  1241        if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) {
  1242          *p << constExpr.getValue();
  1243          return;
  1244        }
  1245      }
  1246  
  1247      // Print bound that consists of a single SSA symbol if the map is over a
  1248      // single symbol.
  1249      if (map.getNumDims() == 0 && map.getNumSymbols() == 1) {
  1250        if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>()) {
  1251          p->printOperand(*boundOperands.begin());
  1252          return;
  1253        }
  1254      }
  1255    } else {
  1256      // Map has multiple results. Print 'min' or 'max' prefix.
  1257      *p << prefix << ' ';
  1258    }
  1259  
  1260    // Print the map and its operands.
  1261    *p << boundMap;
  1262    printDimAndSymbolList(boundOperands.begin(), boundOperands.end(),
  1263                          map.getNumDims(), p);
  1264  }
  1265  
  1266  void print(OpAsmPrinter *p, AffineForOp op) {
  1267    *p << "affine.for ";
  1268    p->printOperand(op.getBody()->getArgument(0));
  1269    *p << " = ";
  1270    printBound(op.getLowerBoundMapAttr(), op.getLowerBoundOperands(), "max", p);
  1271    *p << " to ";
  1272    printBound(op.getUpperBoundMapAttr(), op.getUpperBoundOperands(), "min", p);
  1273  
  1274    if (op.getStep() != 1)
  1275      *p << " step " << op.getStep();
  1276    p->printRegion(op.region(),
  1277                   /*printEntryBlockArgs=*/false,
  1278                   /*printBlockTerminators=*/false);
  1279    p->printOptionalAttrDict(op.getAttrs(),
  1280                             /*elidedAttrs=*/{op.getLowerBoundAttrName(),
  1281                                              op.getUpperBoundAttrName(),
  1282                                              op.getStepAttrName()});
  1283  }
  1284  
  1285  namespace {
  1286  /// This is a pattern to fold trivially empty loops.
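        /// For example (illustrative), a loop whose body contains only the
        /// implicitly inserted terminator, such as
        ///   affine.for %i = 0 to 10 {
        ///   }
        /// is erased entirely.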
  1287  struct AffineForEmptyLoopFolder : public OpRewritePattern<AffineForOp> {
  1288    using OpRewritePattern<AffineForOp>::OpRewritePattern;
  1289  
  1290    PatternMatchResult matchAndRewrite(AffineForOp forOp,
  1291                                       PatternRewriter &rewriter) const override {
  1292      // Check that the body only contains a terminator.
  1293      auto *body = forOp.getBody();
  1294      if (std::next(body->begin()) != body->end())
  1295        return matchFailure();
  1296      rewriter.replaceOp(forOp, llvm::None);
  1297      return matchSuccess();
  1298    }
  1299  };
  1300  
  1301  /// This is a pattern to fold constant loop bounds.
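        /// For example (illustrative), if %a and %b are produced by constant ops
        /// with values 0 and 9, a lower bound written against the hypothetical
        /// map #m = (d0, d1) -> (d0, d1) as
        ///   affine.for %i = max #m(%a, %b) to ...
        /// folds to the constant lower bound 9 (the max over the folded results).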
  1302  struct AffineForLoopBoundFolder : public OpRewritePattern<AffineForOp> {
  1303    using OpRewritePattern<AffineForOp>::OpRewritePattern;
  1304  
  1305    PatternMatchResult matchAndRewrite(AffineForOp forOp,
  1306                                       PatternRewriter &rewriter) const override {
  1307      auto foldLowerOrUpperBound = [&forOp](bool lower) {
  1308        // Check to see if each of the operands is the result of a constant.  If
  1309        // so, get the value.  If not, ignore it.
  1310        SmallVector<Attribute, 8> operandConstants;
  1311        auto boundOperands =
  1312            lower ? forOp.getLowerBoundOperands() : forOp.getUpperBoundOperands();
  1313        for (auto *operand : boundOperands) {
  1314          Attribute operandCst;
  1315          matchPattern(operand, m_Constant(&operandCst));
  1316          operandConstants.push_back(operandCst);
  1317        }
  1318  
  1319        AffineMap boundMap =
  1320            lower ? forOp.getLowerBoundMap() : forOp.getUpperBoundMap();
  1321        assert(boundMap.getNumResults() >= 1 &&
  1322               "bound maps should have at least one result");
  1323        SmallVector<Attribute, 4> foldedResults;
  1324        if (failed(boundMap.constantFold(operandConstants, foldedResults)))
  1325          return failure();
  1326  
  1327        // Compute the max or min as applicable over the results.
  1328        assert(!foldedResults.empty() &&
  1329               "bounds should have at least one result");
  1330        auto maxOrMin = foldedResults[0].cast<IntegerAttr>().getValue();
  1331        for (unsigned i = 1, e = foldedResults.size(); i < e; i++) {
  1332          auto foldedResult = foldedResults[i].cast<IntegerAttr>().getValue();
  1333          maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult)
  1334                           : llvm::APIntOps::smin(maxOrMin, foldedResult);
  1335        }
  1336        lower ? forOp.setConstantLowerBound(maxOrMin.getSExtValue())
  1337              : forOp.setConstantUpperBound(maxOrMin.getSExtValue());
  1338        return success();
  1339      };
  1340  
  1341      // Try to fold the lower bound.
  1342      bool folded = false;
  1343      if (!forOp.hasConstantLowerBound())
  1344        folded |= succeeded(foldLowerOrUpperBound(/*lower=*/true));
  1345  
  1346      // Try to fold the upper bound.
  1347      if (!forOp.hasConstantUpperBound())
  1348        folded |= succeeded(foldLowerOrUpperBound(/*lower=*/false));
  1349  
  1350      // If any of the bounds were folded, we return success.
  1351      if (!folded)
  1352        return matchFailure();
  1353      rewriter.updatedRootInPlace(forOp);
  1354      return matchSuccess();
  1355    }
  1356  };
  1357  } // end anonymous namespace
  1358  
  1359  void AffineForOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
  1360                                                MLIRContext *context) {
  1361    results.insert<AffineForEmptyLoopFolder, AffineForLoopBoundFolder>(context);
  1362  }
  1363  
  1364  AffineBound AffineForOp::getLowerBound() {
  1365    auto lbMap = getLowerBoundMap();
  1366    return AffineBound(AffineForOp(*this), 0, lbMap.getNumInputs(), lbMap);
  1367  }
  1368  
  1369  AffineBound AffineForOp::getUpperBound() {
  1370    auto lbMap = getLowerBoundMap();
  1371    auto ubMap = getUpperBoundMap();
  1372    return AffineBound(AffineForOp(*this), lbMap.getNumInputs(), getNumOperands(),
  1373                       ubMap);
  1374  }
  1375  
  1376  void AffineForOp::setLowerBound(ArrayRef<Value *> lbOperands, AffineMap map) {
  1377    assert(lbOperands.size() == map.getNumInputs());
  1378    assert(map.getNumResults() >= 1 && "bound map has at least one result");
  1379  
  1380    SmallVector<Value *, 4> newOperands(lbOperands.begin(), lbOperands.end());
  1381  
  1382    auto ubOperands = getUpperBoundOperands();
  1383    newOperands.append(ubOperands.begin(), ubOperands.end());
  1384    getOperation()->setOperands(newOperands);
  1385  
  1386    setAttr(getLowerBoundAttrName(), AffineMapAttr::get(map));
  1387  }
  1388  
  1389  void AffineForOp::setUpperBound(ArrayRef<Value *> ubOperands, AffineMap map) {
  1390    assert(ubOperands.size() == map.getNumInputs());
  1391    assert(map.getNumResults() >= 1 && "bound map has at least one result");
  1392  
  1393    SmallVector<Value *, 4> newOperands(getLowerBoundOperands());
  1394    newOperands.append(ubOperands.begin(), ubOperands.end());
  1395    getOperation()->setOperands(newOperands);
  1396  
  1397    setAttr(getUpperBoundAttrName(), AffineMapAttr::get(map));
  1398  }
  1399  
  1400  void AffineForOp::setLowerBoundMap(AffineMap map) {
  1401    auto lbMap = getLowerBoundMap();
  1402    assert(lbMap.getNumDims() == map.getNumDims() &&
  1403           lbMap.getNumSymbols() == map.getNumSymbols());
  1404    assert(map.getNumResults() >= 1 && "bound map has at least one result");
  1405    (void)lbMap;
  1406    setAttr(getLowerBoundAttrName(), AffineMapAttr::get(map));
  1407  }
  1408  
  1409  void AffineForOp::setUpperBoundMap(AffineMap map) {
  1410    auto ubMap = getUpperBoundMap();
  1411    assert(ubMap.getNumDims() == map.getNumDims() &&
  1412           ubMap.getNumSymbols() == map.getNumSymbols());
  1413    assert(map.getNumResults() >= 1 && "bound map has at least one result");
  1414    (void)ubMap;
  1415    setAttr(getUpperBoundAttrName(), AffineMapAttr::get(map));
  1416  }
  1417  
  1418  bool AffineForOp::hasConstantLowerBound() {
  1419    return getLowerBoundMap().isSingleConstant();
  1420  }
  1421  
  1422  bool AffineForOp::hasConstantUpperBound() {
  1423    return getUpperBoundMap().isSingleConstant();
  1424  }
  1425  
  1426  int64_t AffineForOp::getConstantLowerBound() {
  1427    return getLowerBoundMap().getSingleConstantResult();
  1428  }
  1429  
  1430  int64_t AffineForOp::getConstantUpperBound() {
  1431    return getUpperBoundMap().getSingleConstantResult();
  1432  }
  1433  
  1434  void AffineForOp::setConstantLowerBound(int64_t value) {
  1435    setLowerBound({}, AffineMap::getConstantMap(value, getContext()));
  1436  }
  1437  
  1438  void AffineForOp::setConstantUpperBound(int64_t value) {
  1439    setUpperBound({}, AffineMap::getConstantMap(value, getContext()));
  1440  }
  1441  
  1442  AffineForOp::operand_range AffineForOp::getLowerBoundOperands() {
  1443    return {operand_begin(), operand_begin() + getLowerBoundMap().getNumInputs()};
  1444  }
  1445  
  1446  AffineForOp::operand_range AffineForOp::getUpperBoundOperands() {
  1447    return {operand_begin() + getLowerBoundMap().getNumInputs(), operand_end()};
  1448  }
  1449  
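        /// Returns true if the lower and upper bound maps use the same number of
        /// dims and symbols and the two bound operand lists are identical.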
  1450  bool AffineForOp::matchingBoundOperandList() {
  1451    auto lbMap = getLowerBoundMap();
  1452    auto ubMap = getUpperBoundMap();
  1453    if (lbMap.getNumDims() != ubMap.getNumDims() ||
  1454        lbMap.getNumSymbols() != ubMap.getNumSymbols())
  1455      return false;
  1456  
  1457    unsigned numOperands = lbMap.getNumInputs();
  1458    for (unsigned i = 0, e = numOperands; i < e; i++) {
  1459      // Compare Value *'s.
  1460      if (getOperand(i) != getOperand(numOperands + i))
  1461        return false;
  1462    }
  1463    return true;
  1464  }
  1465  
  1466  /// Returns true if 'val' is the induction variable of an AffineForOp.
  1467  bool mlir::isForInductionVar(Value *val) {
  1468    return getForInductionVarOwner(val) != AffineForOp();
  1469  }
  1470  
  1471  /// Returns the loop parent of an induction variable. If the provided value is
  1472  /// not an induction variable, returns a null AffineForOp.
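        /// A typical use (illustrative sketch):
        ///   if (AffineForOp forOp = getForInductionVarOwner(iv))
        ///     llvm::dbgs() << "loop step: " << forOp.getStep() << "\n";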
  1473  AffineForOp mlir::getForInductionVarOwner(Value *val) {
  1474    auto *ivArg = dyn_cast<BlockArgument>(val);
  1475    if (!ivArg || !ivArg->getOwner())
  1476      return AffineForOp();
  1477    auto *containingInst = ivArg->getOwner()->getParent()->getParentOp();
  1478    return dyn_cast<AffineForOp>(containingInst);
  1479  }
  1480  
  1481  /// Extracts the induction variables from a list of AffineForOps and places
  1482  /// them in 'ivs'.
  1483  void mlir::extractForInductionVars(ArrayRef<AffineForOp> forInsts,
  1484                                     SmallVectorImpl<Value *> *ivs) {
  1485    ivs->reserve(forInsts.size());
  1486    for (auto forInst : forInsts)
  1487      ivs->push_back(forInst.getInductionVar());
  1488  }
  1489  
  1490  //===----------------------------------------------------------------------===//
  1491  // AffineIfOp
  1492  //===----------------------------------------------------------------------===//
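
        // The custom form handled by the verifier/parser/printer below looks
        // roughly like (illustrative example; #set0 stands for an integer set
        // alias defined elsewhere):
        //
        //   affine.if #set0(%i)[%N] {
        //     ...
        //   } else {  // the 'else' region is optional
        //     ...
        //   }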
  1493  
  1494  static LogicalResult verify(AffineIfOp op) {
  1495    // Verify that we have a condition attribute.
  1496    auto conditionAttr =
  1497        op.getAttrOfType<IntegerSetAttr>(op.getConditionAttrName());
  1498    if (!conditionAttr)
  1499      return op.emitOpError(
  1500          "requires an integer set attribute named 'condition'");
  1501  
  1502    // Verify that there are enough operands for the condition.
  1503    IntegerSet condition = conditionAttr.getValue();
  1504    if (op.getNumOperands() != condition.getNumOperands())
  1505      return op.emitOpError(
  1506          "operand count and condition integer set dimension and "
  1507          "symbol count must match");
  1508  
  1509    // Verify that the operands are valid dimension/symbols.
  1510    if (failed(verifyDimAndSymbolIdentifiers(
  1511            op, op.getOperation()->getNonSuccessorOperands(),
  1512            condition.getNumDims())))
  1513      return failure();
  1514  
  1515    // Verify that none of the blocks in the child regions have arguments.
  1516    for (auto &region : op.getOperation()->getRegions()) {
  1517      for (auto &b : region)
  1518        if (b.getNumArguments() != 0)
  1519          return op.emitOpError(
  1520              "requires that child entry blocks have no arguments");
  1521    }
  1522    return success();
  1523  }
  1524  
  1525  ParseResult parseAffineIfOp(OpAsmParser *parser, OperationState *result) {
  1526    // Parse the condition attribute set.
  1527    IntegerSetAttr conditionAttr;
  1528    unsigned numDims;
  1529    if (parser->parseAttribute(conditionAttr, AffineIfOp::getConditionAttrName(),
  1530                               result->attributes) ||
  1531        parseDimAndSymbolList(parser, result->operands, numDims))
  1532      return failure();
  1533  
  1534    // Verify the condition operands.
  1535    auto set = conditionAttr.getValue();
  1536    if (set.getNumDims() != numDims)
  1537      return parser->emitError(
  1538          parser->getNameLoc(),
  1539          "dim operand count and integer set dim count must match");
  1540    if (numDims + set.getNumSymbols() != result->operands.size())
  1541      return parser->emitError(
  1542          parser->getNameLoc(),
  1543          "symbol operand count and integer set symbol count must match");
  1544  
  1545    // Create the regions for 'then' and 'else'.  The latter must be created even
  1546    // if it remains empty, for the operation to be valid.
  1547    result->regions.reserve(2);
  1548    Region *thenRegion = result->addRegion();
  1549    Region *elseRegion = result->addRegion();
  1550  
  1551    // Parse the 'then' region.
  1552    if (parser->parseRegion(*thenRegion, {}, {}))
  1553      return failure();
  1554    AffineIfOp::ensureTerminator(*thenRegion, parser->getBuilder(),
  1555                                 result->location);
  1556  
  1557    // If we find an 'else' keyword then parse the 'else' region.
  1558    if (!parser->parseOptionalKeyword("else")) {
  1559      if (parser->parseRegion(*elseRegion, {}, {}))
  1560        return failure();
  1561      AffineIfOp::ensureTerminator(*elseRegion, parser->getBuilder(),
  1562                                   result->location);
  1563    }
  1564  
  1565    // Parse the optional attribute list.
  1566    if (parser->parseOptionalAttributeDict(result->attributes))
  1567      return failure();
  1568  
  1569    return success();
  1570  }
  1571  
  1572  void print(OpAsmPrinter *p, AffineIfOp op) {
  1573    auto conditionAttr =
  1574        op.getAttrOfType<IntegerSetAttr>(op.getConditionAttrName());
  1575    *p << "affine.if " << conditionAttr;
  1576    printDimAndSymbolList(op.operand_begin(), op.operand_end(),
  1577                          conditionAttr.getValue().getNumDims(), p);
  1578    p->printRegion(op.thenRegion(),
  1579                   /*printEntryBlockArgs=*/false,
  1580                   /*printBlockTerminators=*/false);
  1581  
  1582    // Print the 'else' region if it has any blocks.
  1583    auto &elseRegion = op.elseRegion();
  1584    if (!elseRegion.empty()) {
  1585      *p << " else";
  1586      p->printRegion(elseRegion,
  1587                     /*printEntryBlockArgs=*/false,
  1588                     /*printBlockTerminators=*/false);
  1589    }
  1590  
  1591    // Print the attribute list.
  1592    p->printOptionalAttrDict(op.getAttrs(),
  1593                             /*elidedAttrs=*/op.getConditionAttrName());
  1594  }
  1595  
  1596  IntegerSet AffineIfOp::getIntegerSet() {
  1597    return getAttrOfType<IntegerSetAttr>(getConditionAttrName()).getValue();
  1598  }
  1599  void AffineIfOp::setIntegerSet(IntegerSet newSet) {
  1600    setAttr(getConditionAttrName(), IntegerSetAttr::get(newSet));
  1601  }
  1602  
  1603  //===----------------------------------------------------------------------===//
  1604  // AffineLoadOp
  1605  //===----------------------------------------------------------------------===//
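
        // The custom form handled below folds the access map into the subscript
        // list, e.g. (illustrative):
        //
        //   %1 = affine.load %0[%i0 + 3, %i1 + 7] : memref<100x100xf32>
        //
        // which applies the map (d0, d1) -> (d0 + 3, d1 + 7) to (%i0, %i1).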
  1606  
  1607  void AffineLoadOp::build(Builder *builder, OperationState *result,
  1608                           AffineMap map, ArrayRef<Value *> operands) {
  1609    result->addOperands(operands);
  1610    if (map)
  1611      result->addAttribute(getMapAttrName(), builder->getAffineMapAttr(map));
  1612    auto memrefType = operands[0]->getType().cast<MemRefType>();
  1613    result->types.push_back(memrefType.getElementType());
  1614  }
  1615  
  1616  void AffineLoadOp::build(Builder *builder, OperationState *result,
  1617                           Value *memref, ArrayRef<Value *> indices) {
  1618    result->addOperands(memref);
  1619    result->addOperands(indices);
  1620    auto memrefType = memref->getType().cast<MemRefType>();
  1621    auto rank = memrefType.getRank();
  1622    // Create identity map for memrefs with at least one dimension or () -> ()
  1623    // for zero-dimensional memrefs.
  1624    auto map = rank ? builder->getMultiDimIdentityMap(rank)
  1625                    : builder->getEmptyAffineMap();
  1626    result->addAttribute(getMapAttrName(), builder->getAffineMapAttr(map));
  1627    result->types.push_back(memrefType.getElementType());
  1628  }
  1629  
  1630  ParseResult AffineLoadOp::parse(OpAsmParser *parser, OperationState *result) {
  1631    auto &builder = parser->getBuilder();
  1632    auto affineIntTy = builder.getIndexType();
  1633  
  1634    MemRefType type;
  1635    OpAsmParser::OperandType memrefInfo;
  1636    AffineMapAttr mapAttr;
  1637    SmallVector<OpAsmParser::OperandType, 1> mapOperands;
  1638    return failure(
  1639        parser->parseOperand(memrefInfo) ||
  1640        parser->parseAffineMapOfSSAIds(mapOperands, mapAttr, getMapAttrName(),
  1641                                       result->attributes) ||
  1642        parser->parseOptionalAttributeDict(result->attributes) ||
  1643        parser->parseColonType(type) ||
  1644        parser->resolveOperand(memrefInfo, type, result->operands) ||
  1645        parser->resolveOperands(mapOperands, affineIntTy, result->operands) ||
  1646        parser->addTypeToList(type.getElementType(), result->types));
  1647  }
  1648  
  1649  void AffineLoadOp::print(OpAsmPrinter *p) {
  1650    *p << "affine.load " << *getMemRef() << '[';
  1651    AffineMapAttr mapAttr = getAttrOfType<AffineMapAttr>(getMapAttrName());
  1652    if (mapAttr) {
  1653      SmallVector<Value *, 2> operands(getIndices());
  1654      p->printAffineMapOfSSAIds(mapAttr, operands);
  1655    }
  1656    *p << ']';
  1657    p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/{getMapAttrName()});
  1658    *p << " : " << getMemRefType();
  1659  }
  1660  
  1661  LogicalResult AffineLoadOp::verify() {
  1662    if (getType() != getMemRefType().getElementType())
  1663      return emitOpError("result type must match element type of memref");
  1664  
  1665    auto mapAttr = getAttrOfType<AffineMapAttr>(getMapAttrName());
  1666    if (mapAttr) {
  1667      AffineMap map = mapAttr.getValue();
  1668      if (map.getNumResults() != getMemRefType().getRank())
  1669        return emitOpError("affine.load affine map num results must equal"
  1670                           " memref rank");
  1671      if (map.getNumInputs() != getNumOperands() - 1)
  1672        return emitOpError("expects as many subscripts as affine map inputs");
  1673    } else {
  1674      if (getMemRefType().getRank() != getNumOperands() - 1)
  1675        return emitOpError(
  1676            "expects the number of subscripts to be equal to memref rank");
  1677    }
  1678  
  1679    for (auto *idx : getIndices()) {
  1680      if (!idx->getType().isIndex())
  1681        return emitOpError("index to load must have 'index' type");
  1682      if (!isValidAffineIndexOperand(idx))
  1683        return emitOpError("index must be a dimension or symbol identifier");
  1684    }
  1685    return success();
  1686  }
  1687  
  1688  void AffineLoadOp::getCanonicalizationPatterns(
  1689      OwningRewritePatternList &results, MLIRContext *context) {
  1690    /// load(memrefcast) -> load
  1691    results.insert<MemRefCastFolder>(getOperationName(), context);
  1692  }
  1693  
  1694  //===----------------------------------------------------------------------===//
  1695  // AffineStoreOp
  1696  //===----------------------------------------------------------------------===//
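
        // The store form mirrors affine.load, with the stored value leading,
        // e.g. (illustrative):
        //
        //   affine.store %v0, %0[%i0 + 3, %i1 + 7] : memref<100x100xf32>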
  1697  
  1698  void AffineStoreOp::build(Builder *builder, OperationState *result,
  1699                            Value *valueToStore, AffineMap map,
  1700                            ArrayRef<Value *> operands) {
  1701    result->addOperands(valueToStore);
  1702    result->addOperands(operands);
  1703    if (map)
  1704      result->addAttribute(getMapAttrName(), builder->getAffineMapAttr(map));
  1705  }
  1706  
  1707  void AffineStoreOp::build(Builder *builder, OperationState *result,
  1708                            Value *valueToStore, Value *memref,
  1709                            ArrayRef<Value *> operands) {
  1710    result->addOperands(valueToStore);
  1711    result->addOperands(memref);
  1712    result->addOperands(operands);
  1713    auto memrefType = memref->getType().cast<MemRefType>();
  1714    auto rank = memrefType.getRank();
  1715    // Create identity map for memrefs with at least one dimension or () -> ()
  1716    // for zero-dimensional memrefs.
  1717    auto map = rank ? builder->getMultiDimIdentityMap(rank)
  1718                    : builder->getEmptyAffineMap();
  1719    result->addAttribute(getMapAttrName(), builder->getAffineMapAttr(map));
  1720  }
  1721  
  1722  ParseResult AffineStoreOp::parse(OpAsmParser *parser, OperationState *result) {
  1723    auto affineIntTy = parser->getBuilder().getIndexType();
  1724  
  1725    MemRefType type;
  1726    OpAsmParser::OperandType storeValueInfo;
  1727    OpAsmParser::OperandType memrefInfo;
  1728    AffineMapAttr mapAttr;
  1729    SmallVector<OpAsmParser::OperandType, 1> mapOperands;
  1730    return failure(
  1731        parser->parseOperand(storeValueInfo) || parser->parseComma() ||
  1732        parser->parseOperand(memrefInfo) ||
  1733        parser->parseAffineMapOfSSAIds(mapOperands, mapAttr, getMapAttrName(),
  1734                                       result->attributes) ||
  1735        parser->parseOptionalAttributeDict(result->attributes) ||
  1736        parser->parseColonType(type) ||
  1737        parser->resolveOperand(storeValueInfo, type.getElementType(),
  1738                               result->operands) ||
  1739        parser->resolveOperand(memrefInfo, type, result->operands) ||
  1740        parser->resolveOperands(mapOperands, affineIntTy, result->operands));
  1741  }
  1742  
  1743  void AffineStoreOp::print(OpAsmPrinter *p) {
  1744    *p << "affine.store " << *getValueToStore();
  1745    *p << ", " << *getMemRef() << '[';
  1746    AffineMapAttr mapAttr = getAttrOfType<AffineMapAttr>(getMapAttrName());
  1747    if (mapAttr) {
  1748      SmallVector<Value *, 2> operands(getIndices());
  1749      p->printAffineMapOfSSAIds(mapAttr, operands);
  1750    }
  1751    *p << ']';
  1752    p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/{getMapAttrName()});
  1753    *p << " : " << getMemRefType();
  1754  }
  1755  
  1756  LogicalResult AffineStoreOp::verify() {
  1757    // First operand must have same type as memref element type.
  1758    if (getValueToStore()->getType() != getMemRefType().getElementType())
  1759      return emitOpError("first operand must have same type as memref element type");
  1760  
  1761    auto mapAttr = getAttrOfType<AffineMapAttr>(getMapAttrName());
  1762    if (mapAttr) {
  1763      AffineMap map = mapAttr.getValue();
  1764      if (map.getNumResults() != getMemRefType().getRank())
  1765        return emitOpError("affine.store affine map num results must equal"
  1766                           " memref rank");
  1767      if (map.getNumInputs() != getNumOperands() - 2)
  1768        return emitOpError("expects as many subscripts as affine map inputs");
  1769    } else {
  1770      if (getMemRefType().getRank() != getNumOperands() - 2)
  1771        return emitOpError(
  1772            "expects the number of subscripts to be equal to memref rank");
  1773    }
  1774  
  1775    for (auto *idx : getIndices()) {
  1776      if (!idx->getType().isIndex())
  1777        return emitOpError("index to store must have 'index' type");
  1778      if (!isValidAffineIndexOperand(idx))
  1779        return emitOpError("index must be a dimension or symbol identifier");
  1780    }
  1781    return success();
  1782  }
  1783  
  1784  void AffineStoreOp::getCanonicalizationPatterns(
  1785      OwningRewritePatternList &results, MLIRContext *context) {
  1786    /// store(memrefcast) -> store
  1787    results.insert<MemRefCastFolder>(getOperationName(), context);
  1788  }
  1789  
  1790  #define GET_OP_CLASSES
  1791  #include "mlir/Dialect/AffineOps/AffineOps.cpp.inc"