github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Analysis/Utils.cpp

//===- Utils.cpp ---- Misc utilities for analysis -------------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements miscellaneous analysis routines for non-loop IR
// structures.
//
//===----------------------------------------------------------------------===//

#include "mlir/Analysis/Utils.h"

#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/AffineStructures.h"
#include "mlir/Dialect/AffineOps/AffineOps.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/Builders.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "analysis-utils"

using namespace mlir;

using llvm::SmallDenseMap;

/// Populates 'loops' with IVs of the loops surrounding 'op' ordered from
/// the outermost 'affine.for' operation to the innermost one.
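/// For example, given:
///
///   affine.for %i = 0 to 16 {
///     affine.for %j = 0 to 32 {
///       "op"() : () -> ()
///     }
///   }
///
/// 'loops' is populated with %i's loop followed by %j's loop.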
void mlir::getLoopIVs(Operation &op, SmallVectorImpl<AffineForOp> *loops) {
  auto *currOp = op.getParentOp();
  AffineForOp currAffineForOp;
  // Traverse up the hierarchy collecting all 'affine.for' operations while
  // skipping over 'affine.if' operations.
  while (currOp && ((currAffineForOp = dyn_cast<AffineForOp>(currOp)) ||
                    isa<AffineIfOp>(currOp))) {
    if (currAffineForOp)
      loops->push_back(currAffineForOp);
    currOp = currOp->getParentOp();
  }
  std::reverse(loops->begin(), loops->end());
}

// Populates 'cst' with FlatAffineConstraints which represent slice bounds.
LogicalResult
ComputationSliceState::getAsConstraints(FlatAffineConstraints *cst) {
  assert(!lbOperands.empty());
  // Adds src 'ivs' as dimension identifiers in 'cst'.
  unsigned numDims = ivs.size();
  // Adds operands (dst ivs and symbols) as symbols in 'cst'.
  unsigned numSymbols = lbOperands[0].size();

  SmallVector<Value *, 4> values(ivs);
  // Append the operands from 'lbOperands[0]' to 'values'.
  values.append(lbOperands[0].begin(), lbOperands[0].end());
  cst->reset(numDims, numSymbols, 0, values);

  // Add loop bound constraints for values which are loop IVs and equality
  // constraints for symbols which are constants.
  for (const auto &value : values) {
    assert(cst->containsId(*value) && "value expected to be present");
    if (isValidSymbol(value)) {
      // Check if the symbol is a constant.
      if (auto cOp = dyn_cast_or_null<ConstantIndexOp>(value->getDefiningOp()))
        cst->setIdToConstant(*value, cOp.getValue());
    } else if (auto loop = getForInductionVarOwner(value)) {
      if (failed(cst->addAffineForOpDomain(loop)))
        return failure();
    }
  }

  // Add slice bounds on 'ivs' using maps 'lbs'/'ubs' with 'lbOperands[0]'.
  LogicalResult ret = cst->addSliceBounds(ivs, lbs, ubs, lbOperands[0]);
  assert(succeeded(ret) &&
         "should not fail as we never have semi-affine slice maps");
  (void)ret;
  return success();
}

// Clears the slice's bound maps and their operands.
void ComputationSliceState::clearBounds() {
  lbs.clear();
  ubs.clear();
  lbOperands.clear();
  ubOperands.clear();
}

unsigned MemRefRegion::getRank() const {
  return memref->getType().cast<MemRefType>().getRank();
}

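// Computes a constant upper bound on the number of elements in this region,
// returning the per-dimension extents in 'shape' if requested. For example,
// for the region {0 <= d0 <= 127, %i <= d1 <= %i + 7} on a memref of shape
// 128x64, this returns 128 * 8 = 1024 elements, with 'shape' = [128, 8].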
Optional<int64_t> MemRefRegion::getConstantBoundingSizeAndShape(
    SmallVectorImpl<int64_t> *shape, std::vector<SmallVector<int64_t, 4>> *lbs,
    SmallVectorImpl<int64_t> *lbDivisors) const {
  auto memRefType = memref->getType().cast<MemRefType>();
  unsigned rank = memRefType.getRank();
  if (shape)
    shape->reserve(rank);

  assert(rank == cst.getNumDimIds() && "inconsistent memref region");

  // Find a constant upper bound on the extent of this memref region along each
  // dimension.
  int64_t numElements = 1;
  int64_t diffConstant;
  int64_t lbDivisor;
  for (unsigned d = 0; d < rank; d++) {
    SmallVector<int64_t, 4> lb;
    Optional<int64_t> diff = cst.getConstantBoundOnDimSize(d, &lb, &lbDivisor);
    if (diff.hasValue()) {
      diffConstant = diff.getValue();
      assert(lbDivisor > 0);
    } else {
      // If no constant bound is found, the region can still be bounded by the
      // memref's dim size if the latter has a constant size along this dim.
      auto dimSize = memRefType.getDimSize(d);
      if (dimSize == -1)
        return None;
      diffConstant = dimSize;
      // Lower bound becomes 0.
      lb.resize(cst.getNumSymbolIds() + 1, 0);
      lbDivisor = 1;
    }
    numElements *= diffConstant;
    if (lbs) {
      lbs->push_back(lb);
      assert(lbDivisors && "both 'lbs' and 'lbDivisors' expected or none");
      lbDivisors->push_back(lbDivisor);
    }
    if (shape) {
      shape->push_back(diffConstant);
    }
  }
  return numElements;
}

LogicalResult MemRefRegion::unionBoundingBox(const MemRefRegion &other) {
  assert(memref == other.memref);
  return cst.unionBoundingBox(*other.getConstraints());
}

/// Computes the memory region accessed by this memref with the region
/// represented as constraints symbolic/parametric in 'loopDepth' loops
/// surrounding opInst and any additional Function symbols.
//  For example, the memref region for this load operation at loopDepth = 1 will
//  be as below:
//
//    affine.for %i = 0 to 32 {
//      affine.for %ii = %i to (d0) -> (d0 + 8) (%i) {
//        load %A[%ii]
//      }
//    }
//
// region:  {memref = %A, write = false, {%i <= m0 <= %i + 7} }
// The last field is a 2-d FlatAffineConstraints symbolic in %i.
//
// TODO(bondhugula): extend this to any other memref dereferencing ops
// (dma_start, dma_wait).
LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth,
                                    ComputationSliceState *sliceState,
                                    bool addMemRefDimBounds) {
  assert((isa<AffineLoadOp>(op) || isa<AffineStoreOp>(op)) &&
         "affine load/store op expected");

  MemRefAccess access(op);
  memref = access.memref;
  write = access.isStore();

  unsigned rank = access.getRank();

  LLVM_DEBUG(llvm::dbgs() << "MemRefRegion::compute: " << *op
                          << " depth: " << loopDepth << "\n";);

  if (rank == 0) {
    SmallVector<AffineForOp, 4> ivs;
    getLoopIVs(*op, &ivs);
    assert(loopDepth <= ivs.size() && "invalid 'loopDepth'");
    // The first 'loopDepth' IVs are the symbols for this region.
    ivs.resize(loopDepth);
    SmallVector<Value *, 8> regionSymbols;
    extractForInductionVars(ivs, &regionSymbols);
    // A rank 0 memref has a 0-d region.
    cst.reset(rank, loopDepth, 0, regionSymbols);
    return success();
  }

  // Build the constraints for this region.
  AffineValueMap accessValueMap;
  access.getAccessMap(&accessValueMap);
  AffineMap accessMap = accessValueMap.getAffineMap();

  unsigned numDims = accessMap.getNumDims();
  unsigned numSymbols = accessMap.getNumSymbols();
  unsigned numOperands = accessValueMap.getNumOperands();
  // Merge operands with slice operands.
  SmallVector<Value *, 4> operands;
  operands.resize(numOperands);
  for (unsigned i = 0; i < numOperands; ++i)
    operands[i] = accessValueMap.getOperand(i);

  if (sliceState != nullptr) {
    operands.reserve(operands.size() + sliceState->lbOperands[0].size());
    // Append slice operands to 'operands' as symbols.
    for (auto extraOperand : sliceState->lbOperands[0]) {
      if (!llvm::is_contained(operands, extraOperand)) {
        operands.push_back(extraOperand);
        numSymbols++;
      }
    }
  }
  // We'll first associate the dims and symbols of the access map to the dims
  // and symbols resp. of cst. This will change below once cst is
  // fully constructed.
  cst.reset(numDims, numSymbols, 0, operands);

  // Add inequalities for loop lower/upper bounds, and equalities for any
  // operands that are constant symbols.
  for (unsigned i = 0; i < numDims + numSymbols; ++i) {
    auto *operand = operands[i];
    if (auto loop = getForInductionVarOwner(operand)) {
      // Note that cst can now have more dimensions than accessMap if the
      // bounds expressions involve outer loops or other symbols.
      // TODO(bondhugula): rewrite this to use getInstIndexSet; this way
      // conditionals will be handled when the latter supports it.
      if (failed(cst.addAffineForOpDomain(loop)))
        return failure();
    } else {
      // Has to be a valid symbol.
      auto *symbol = operand;
      assert(isValidSymbol(symbol));
      // Check if the symbol is a constant.
      if (auto *op = symbol->getDefiningOp()) {
        if (auto constOp = dyn_cast<ConstantIndexOp>(op)) {
          cst.setIdToConstant(*symbol, constOp.getValue());
        }
      }
    }
  }

  // Add lower/upper bounds on loop IVs using bounds from 'sliceState'.
  if (sliceState != nullptr) {
    // Add dim and symbol slice operands.
    for (auto operand : sliceState->lbOperands[0]) {
      cst.addInductionVarOrTerminalSymbol(operand);
    }
    // Add upper/lower bounds from 'sliceState' to 'cst'.
    LogicalResult ret =
        cst.addSliceBounds(sliceState->ivs, sliceState->lbs, sliceState->ubs,
                           sliceState->lbOperands[0]);
    assert(succeeded(ret) &&
           "should not fail as we never have semi-affine slice maps");
    (void)ret;
  }

  // Add access function equalities to connect loop IVs to data dimensions.
  if (failed(cst.composeMap(&accessValueMap))) {
    op->emitError("getMemRefRegion: compose affine map failed");
    LLVM_DEBUG(accessValueMap.getAffineMap().dump());
    return failure();
  }

  // Set all identifiers appearing after the first 'rank' identifiers as
  // symbolic identifiers - so that the ones corresponding to the memref
  // dimensions are the dimensional identifiers for the memref region.
  cst.setDimSymbolSeparation(cst.getNumDimAndSymbolIds() - rank);

  // Eliminate any loop IVs other than the outermost 'loopDepth' IVs, on which
  // this memref region is symbolic.
  SmallVector<AffineForOp, 4> enclosingIVs;
  getLoopIVs(*op, &enclosingIVs);
  assert(loopDepth <= enclosingIVs.size() && "invalid loop depth");
  enclosingIVs.resize(loopDepth);
  SmallVector<Value *, 4> ids;
  cst.getIdValues(cst.getNumDimIds(), cst.getNumDimAndSymbolIds(), &ids);
  for (auto *id : ids) {
    AffineForOp iv;
    if ((iv = getForInductionVarOwner(id)) &&
        !llvm::is_contained(enclosingIVs, iv)) {
      cst.projectOut(id);
    }
  }

  // Project out any local variables (these would have been added for any
  // mod/divs).
  cst.projectOut(cst.getNumDimAndSymbolIds(), cst.getNumLocalIds());

  // Constant fold any symbolic identifiers.
  cst.constantFoldIdRange(/*pos=*/cst.getNumDimIds(),
                          /*num=*/cst.getNumSymbolIds());

  assert(cst.getNumDimIds() == rank && "unexpected MemRefRegion format");

  // Add upper/lower bounds for each memref dimension with static size
  // to guard against potential over-approximation from projection.
  // TODO(andydavis) Support dynamic memref dimensions.
  if (addMemRefDimBounds) {
    auto memRefType = memref->getType().cast<MemRefType>();
    for (unsigned r = 0; r < rank; r++) {
      cst.addConstantLowerBound(r, 0);
      int64_t dimSize = memRefType.getDimSize(r);
      if (ShapedType::isDynamic(dimSize))
        continue;
      cst.addConstantUpperBound(r, dimSize - 1);
    }
  }

  LLVM_DEBUG(llvm::dbgs() << "Memory region:\n");
  LLVM_DEBUG(cst.dump());
  return success();
}

//  TODO(mlir-team): improve/complete this when we have target data.
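//  For example, an f32 element is 4 bytes, and a vector<4xf32> element is
//  4 * 4 = 16 bytes.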
static unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
  auto elementType = memRefType.getElementType();

  unsigned sizeInBits;
  if (elementType.isIntOrFloat()) {
    sizeInBits = elementType.getIntOrFloatBitWidth();
  } else {
    auto vectorType = elementType.cast<VectorType>();
    sizeInBits =
        vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
  }
  return llvm::divideCeil(sizeInBits, 8);
}

// Returns the size of this MemRefRegion in bytes.
Optional<int64_t> MemRefRegion::getRegionSize() {
  auto memRefType = memref->getType().cast<MemRefType>();

  auto layoutMaps = memRefType.getAffineMaps();
  if (layoutMaps.size() > 1 ||
      (layoutMaps.size() == 1 && !layoutMaps[0].isIdentity())) {
    LLVM_DEBUG(llvm::dbgs() << "Non-identity layout map not yet supported\n");
    return None;
  }

  // Compute the extents of the region.
  Optional<int64_t> numElements = getConstantBoundingSizeAndShape();
  if (!numElements.hasValue()) {
    LLVM_DEBUG(llvm::dbgs() << "Dynamic shapes not yet supported\n");
    return None;
  }
  return getMemRefEltSizeInBytes(memRefType) * numElements.getValue();
}

/// Returns the size of memref data in bytes if it's statically shaped, None
/// otherwise. If the element of the memref has vector type, takes into account
/// size of the vector as well.
//  TODO(mlir-team): improve/complete this when we have target data.
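//  For example, a statically shaped memref<4x8xvector<2xf32>> has
//  4 * 8 * (2 * 4) = 256 bytes.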
Optional<uint64_t> mlir::getMemRefSizeInBytes(MemRefType memRefType) {
  if (!memRefType.hasStaticShape())
    return None;
  auto elementType = memRefType.getElementType();
  if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>())
    return None;

  uint64_t sizeInBytes = getMemRefEltSizeInBytes(memRefType);
  for (unsigned i = 0, e = memRefType.getRank(); i < e; i++) {
    sizeInBytes = sizeInBytes * memRefType.getDimSize(i);
  }
  return sizeInBytes;
}

template <typename LoadOrStoreOpPointer>
LogicalResult mlir::boundCheckLoadOrStoreOp(LoadOrStoreOpPointer loadOrStoreOp,
                                            bool emitError) {
  static_assert(std::is_same<LoadOrStoreOpPointer, AffineLoadOp>::value ||
                    std::is_same<LoadOrStoreOpPointer, AffineStoreOp>::value,
                "argument should be either an AffineLoadOp or an "
                "AffineStoreOp");

  Operation *opInst = loadOrStoreOp.getOperation();
  MemRefRegion region(opInst->getLoc());
  // Conservatively report no out-of-bounds access if the region can't be
  // computed.
  if (failed(region.compute(opInst, /*loopDepth=*/0, /*sliceState=*/nullptr,
                            /*addMemRefDimBounds=*/false)))
    return success();

  LLVM_DEBUG(llvm::dbgs() << "Memory region:\n");
  LLVM_DEBUG(region.getConstraints()->dump());

  bool outOfBounds = false;
  unsigned rank = loadOrStoreOp.getMemRefType().getRank();

  // For each dimension, check for out of bounds.
  for (unsigned r = 0; r < rank; r++) {
    FlatAffineConstraints ucst(*region.getConstraints());

    // Intersect the memory region with a constraint capturing an out of bounds
    // access (both above the upper bound and below the lower bound), and check
    // if the resulting constraint system is feasible. If it is, there is at
    // least one point out of bounds.
    int64_t dimSize = loadOrStoreOp.getMemRefType().getDimSize(r);
    // TODO(bondhugula): handle dynamic dim sizes.
    if (dimSize == -1)
      continue;

    // Check for overflow: d_i >= memref dim size.
    ucst.addConstantLowerBound(r, dimSize);
    if (!ucst.isEmpty()) {
      outOfBounds = true;
      if (emitError)
        loadOrStoreOp.emitOpError()
            << "memref out of upper bound access along dimension #" << (r + 1);
    }

    // Check for a negative index: d_i <= -1.
    FlatAffineConstraints lcst(*region.getConstraints());
    lcst.addConstantUpperBound(r, -1);
    if (!lcst.isEmpty()) {
      outOfBounds = true;
      if (emitError)
        loadOrStoreOp.emitOpError()
            << "memref out of lower bound access along dimension #" << (r + 1);
    }
  }
  return failure(outOfBounds);
}

// Explicitly instantiate the template so that the compiler knows we need them!
template LogicalResult mlir::boundCheckLoadOrStoreOp(AffineLoadOp loadOp,
                                                     bool emitError);
template LogicalResult mlir::boundCheckLoadOrStoreOp(AffineStoreOp storeOp,
                                                     bool emitError);

// Returns in 'positions' the Block positions of 'op' in each ancestor
// Block from the Block containing 'op', stopping at 'limitBlock'.
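// For example, if 'op' is the third operation in the body of an inner
// 'affine.for' that is itself the first operation in the body of an outer
// 'affine.for' contained directly in 'limitBlock', 'positions' is [0, 2].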
static void findInstPosition(Operation *op, Block *limitBlock,
                             SmallVectorImpl<unsigned> *positions) {
  Block *block = op->getBlock();
  while (block != limitBlock) {
    // FIXME: This algorithm is unnecessarily O(n) and should be improved to
    // not rely on linear scans.
    int instPosInBlock = std::distance(block->begin(), op->getIterator());
    positions->push_back(instPosInBlock);
    op = block->getParentOp();
    block = op->getBlock();
  }
  std::reverse(positions->begin(), positions->end());
}

// Returns the Operation in a possibly nested set of Blocks, where the
// position of the operation is represented by 'positions', which has a
// Block position for each level of nesting.
static Operation *getInstAtPosition(ArrayRef<unsigned> positions,
                                    unsigned level, Block *block) {
  unsigned i = 0;
  for (auto &op : *block) {
    if (i != positions[level]) {
      ++i;
      continue;
    }
    if (level == positions.size() - 1)
      return &op;
    if (auto childAffineForOp = dyn_cast<AffineForOp>(op))
      return getInstAtPosition(positions, level + 1,
                               childAffineForOp.getBody());

    for (auto &region : op.getRegions()) {
      for (auto &b : region)
        if (auto *ret = getInstAtPosition(positions, level + 1, &b))
          return ret;
    }
    return nullptr;
  }
  return nullptr;
}

// Adds loop IV bounds to 'cst' for loop IVs not found in 'ivs'.
static LogicalResult addMissingLoopIVBounds(SmallPtrSet<Value *, 8> &ivs,
                                            FlatAffineConstraints *cst) {
  for (unsigned i = 0, e = cst->getNumDimIds(); i < e; ++i) {
    auto *value = cst->getIdValue(i);
    if (ivs.count(value) == 0) {
      assert(isForInductionVar(value));
      auto loop = getForInductionVarOwner(value);
      if (failed(cst->addAffineForOpDomain(loop)))
        return failure();
    }
  }
  return success();
}

// Returns the innermost common loop depth for the set of operations in 'ops'.
// TODO(andydavis) Move this to LoopUtils.
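// For example, if 'ops' contains an op nested under loops (%i, %j) and
// another nested under (%i, %k), where %i is the same loop, this returns 1
// and populates 'surroundingLoops' with %i's loop.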
static unsigned
getInnermostCommonLoopDepth(ArrayRef<Operation *> ops,
                            SmallVectorImpl<AffineForOp> &surroundingLoops) {
  unsigned numOps = ops.size();
  assert(numOps > 0);

  std::vector<SmallVector<AffineForOp, 4>> loops(numOps);
  unsigned loopDepthLimit = std::numeric_limits<unsigned>::max();
  for (unsigned i = 0; i < numOps; ++i) {
    getLoopIVs(*ops[i], &loops[i]);
    loopDepthLimit =
        std::min(loopDepthLimit, static_cast<unsigned>(loops[i].size()));
  }

  unsigned loopDepth = 0;
  for (unsigned d = 0; d < loopDepthLimit; ++d) {
    unsigned i;
    for (i = 1; i < numOps; ++i) {
      if (loops[i - 1][d] != loops[i][d])
        return loopDepth;
    }
    surroundingLoops.push_back(loops[i - 1][d]);
    ++loopDepth;
  }
  return loopDepth;
}

/// Computes in 'sliceUnion' the union of all slice bounds computed at
/// 'loopDepth' between all dependent pairs of ops in 'opsA' and 'opsB'.
/// Returns 'success' if the union was computed, 'failure' otherwise.
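// The approach: for each dependent pair, compute the slice bounds as a
// FlatAffineConstraints system, align the identifiers of the per-pair
// systems, union them via a bounding box, and finally convert the unioned
// constraints back into lower/upper bound affine maps on the slice IVs.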
LogicalResult mlir::computeSliceUnion(ArrayRef<Operation *> opsA,
                                      ArrayRef<Operation *> opsB,
                                      unsigned loopDepth,
                                      unsigned numCommonLoops,
                                      bool isBackwardSlice,
                                      ComputationSliceState *sliceUnion) {
  // Compute the union of slice bounds between all pairs in 'opsA' and
  // 'opsB' in 'sliceUnionCst'.
  FlatAffineConstraints sliceUnionCst;
  assert(sliceUnionCst.getNumDimAndSymbolIds() == 0);
  std::vector<std::pair<Operation *, Operation *>> dependentOpPairs;
  for (unsigned i = 0, numOpsA = opsA.size(); i < numOpsA; ++i) {
    MemRefAccess srcAccess(opsA[i]);
    for (unsigned j = 0, numOpsB = opsB.size(); j < numOpsB; ++j) {
      MemRefAccess dstAccess(opsB[j]);
      if (srcAccess.memref != dstAccess.memref)
        continue;
      // Check if 'loopDepth' exceeds nesting depth of src/dst ops.
      if ((!isBackwardSlice && loopDepth > getNestingDepth(*opsA[i])) ||
          (isBackwardSlice && loopDepth > getNestingDepth(*opsB[j]))) {
        LLVM_DEBUG(llvm::dbgs() << "Invalid loop depth.\n");
        return failure();
      }

      bool readReadAccesses = isa<AffineLoadOp>(srcAccess.opInst) &&
                              isa<AffineLoadOp>(dstAccess.opInst);
      FlatAffineConstraints dependenceConstraints;
      // Check dependence between 'srcAccess' and 'dstAccess'.
      DependenceResult result = checkMemrefAccessDependence(
          srcAccess, dstAccess, /*loopDepth=*/numCommonLoops + 1,
          &dependenceConstraints, /*dependenceComponents=*/nullptr,
          /*allowRAR=*/readReadAccesses);
      if (result.value == DependenceResult::Failure) {
        LLVM_DEBUG(llvm::dbgs() << "Dependence check failed.\n");
        return failure();
      }
      if (result.value == DependenceResult::NoDependence)
        continue;
      dependentOpPairs.push_back({opsA[i], opsB[j]});

      // Compute slice bounds for 'srcAccess' and 'dstAccess'.
      ComputationSliceState tmpSliceState;
      mlir::getComputationSliceState(opsA[i], opsB[j], &dependenceConstraints,
                                     loopDepth, isBackwardSlice,
                                     &tmpSliceState);

      if (sliceUnionCst.getNumDimAndSymbolIds() == 0) {
        // Initialize 'sliceUnionCst' with the bounds computed in previous step.
        if (failed(tmpSliceState.getAsConstraints(&sliceUnionCst))) {
          LLVM_DEBUG(llvm::dbgs()
                     << "Unable to compute slice bound constraints.\n");
          return failure();
        }
        assert(sliceUnionCst.getNumDimAndSymbolIds() > 0);
        continue;
      }

      // Compute constraints for 'tmpSliceState' in 'tmpSliceCst'.
      FlatAffineConstraints tmpSliceCst;
      if (failed(tmpSliceState.getAsConstraints(&tmpSliceCst))) {
        LLVM_DEBUG(llvm::dbgs()
                   << "Unable to compute slice bound constraints.\n");
        return failure();
      }

      // Align coordinate spaces of 'sliceUnionCst' and 'tmpSliceCst' if needed.
      if (!sliceUnionCst.areIdsAlignedWithOther(tmpSliceCst)) {
        // Pre-constraint id alignment: record loop IVs used in each constraint
        // system.
        SmallPtrSet<Value *, 8> sliceUnionIVs;
        for (unsigned k = 0, l = sliceUnionCst.getNumDimIds(); k < l; ++k)
          sliceUnionIVs.insert(sliceUnionCst.getIdValue(k));
        SmallPtrSet<Value *, 8> tmpSliceIVs;
        for (unsigned k = 0, l = tmpSliceCst.getNumDimIds(); k < l; ++k)
          tmpSliceIVs.insert(tmpSliceCst.getIdValue(k));

        sliceUnionCst.mergeAndAlignIdsWithOther(/*offset=*/0, &tmpSliceCst);

        // Post-constraint id alignment: add loop IV bounds missing after
        // id alignment to constraint systems. This can occur if one constraint
        // system uses a loop IV that is not used by the other. The call
        // to unionBoundingBox below expects constraints for each loop IV, even
        // if they are the unsliced full loop bounds added here.
        if (failed(addMissingLoopIVBounds(sliceUnionIVs, &sliceUnionCst)))
          return failure();
        if (failed(addMissingLoopIVBounds(tmpSliceIVs, &tmpSliceCst)))
          return failure();
      }
      // Compute union bounding box of 'sliceUnionCst' and 'tmpSliceCst'.
      if (failed(sliceUnionCst.unionBoundingBox(tmpSliceCst))) {
        LLVM_DEBUG(llvm::dbgs()
                   << "Unable to compute union bounding box of slice "
                      "bounds.\n");
        return failure();
      }
    }
  }

  // Empty union.
  if (sliceUnionCst.getNumDimAndSymbolIds() == 0)
    return failure();

  // Gather loops surrounding ops from loop nest where slice will be inserted.
  SmallVector<Operation *, 4> ops;
  for (auto &dep : dependentOpPairs) {
    ops.push_back(isBackwardSlice ? dep.second : dep.first);
  }
  SmallVector<AffineForOp, 4> surroundingLoops;
  unsigned innermostCommonLoopDepth =
      getInnermostCommonLoopDepth(ops, surroundingLoops);
  if (loopDepth > innermostCommonLoopDepth) {
    LLVM_DEBUG(llvm::dbgs() << "Exceeds max loop depth.\n");
    return failure();
  }

  // Store 'numSliceLoopIVs' before converting dst loop IVs to dims.
  unsigned numSliceLoopIVs = sliceUnionCst.getNumDimIds();

  // Convert any dst loop IVs which are symbol identifiers to dim identifiers.
  sliceUnionCst.convertLoopIVSymbolsToDims();
  sliceUnion->clearBounds();
  sliceUnion->lbs.resize(numSliceLoopIVs, AffineMap());
  sliceUnion->ubs.resize(numSliceLoopIVs, AffineMap());

  // Get slice bounds from slice union constraints 'sliceUnionCst'.
  sliceUnionCst.getSliceBounds(/*offset=*/0, numSliceLoopIVs,
                               opsA[0]->getContext(), &sliceUnion->lbs,
                               &sliceUnion->ubs);

  // Add slice bound operands of union.
  SmallVector<Value *, 4> sliceBoundOperands;
  sliceUnionCst.getIdValues(numSliceLoopIVs,
                            sliceUnionCst.getNumDimAndSymbolIds(),
                            &sliceBoundOperands);

  // Copy src loop IVs from 'sliceUnionCst' to 'sliceUnion'.
  sliceUnion->ivs.clear();
  sliceUnionCst.getIdValues(0, numSliceLoopIVs, &sliceUnion->ivs);

  // Set loop nest insertion point to block start at 'loopDepth'.
  sliceUnion->insertPoint =
      isBackwardSlice
          ? surroundingLoops[loopDepth - 1].getBody()->begin()
          : std::prev(surroundingLoops[loopDepth - 1].getBody()->end());

  // Give each bound its own copy of 'sliceBoundOperands' for subsequent
  // canonicalization.
  sliceUnion->lbOperands.resize(numSliceLoopIVs, sliceBoundOperands);
  sliceUnion->ubOperands.resize(numSliceLoopIVs, sliceBoundOperands);
  return success();
}

const char *const kSliceFusionBarrierAttrName = "slice_fusion_barrier";
// Computes slice bounds by projecting out any loop IVs from
// 'dependenceConstraints' at depth greater than 'loopDepth', and computes slice
// bounds in 'sliceState' which represent one loop nest's IVs in terms of
// the other loop nest's IVs, symbols, and constants (using 'isBackwardSlice').
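// For example, for a backward slice of:
//
//   affine.for %i = 0 to 10 {
//     store %v, %A[%i]
//   }
//   affine.for %j = 0 to 10 {
//     load %A[%j]
//   }
//
// at 'loopDepth' = 1, the bounds computed for the source IV %i are
// [%j, %j + 1): exactly the source iteration the load at %j depends on.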
void mlir::getComputationSliceState(
    Operation *depSourceOp, Operation *depSinkOp,
    FlatAffineConstraints *dependenceConstraints, unsigned loopDepth,
    bool isBackwardSlice, ComputationSliceState *sliceState) {
  // Get loop nest surrounding src operation.
  SmallVector<AffineForOp, 4> srcLoopIVs;
  getLoopIVs(*depSourceOp, &srcLoopIVs);
  unsigned numSrcLoopIVs = srcLoopIVs.size();

  // Get loop nest surrounding dst operation.
  SmallVector<AffineForOp, 4> dstLoopIVs;
  getLoopIVs(*depSinkOp, &dstLoopIVs);
  unsigned numDstLoopIVs = dstLoopIVs.size();

  assert((!isBackwardSlice && loopDepth <= numSrcLoopIVs) ||
         (isBackwardSlice && loopDepth <= numDstLoopIVs));

  // Project out dimensions other than those up to 'loopDepth'.
  unsigned pos = isBackwardSlice ? numSrcLoopIVs + loopDepth : loopDepth;
  unsigned num =
      isBackwardSlice ? numDstLoopIVs - loopDepth : numSrcLoopIVs - loopDepth;
  dependenceConstraints->projectOut(pos, num);

  // Add slice loop IV values to 'sliceState'.
  unsigned offset = isBackwardSlice ? 0 : loopDepth;
  unsigned numSliceLoopIVs = isBackwardSlice ? numSrcLoopIVs : numDstLoopIVs;
  dependenceConstraints->getIdValues(offset, offset + numSliceLoopIVs,
                                     &sliceState->ivs);

  // Set up lower/upper bound affine maps for the slice.
  sliceState->lbs.resize(numSliceLoopIVs, AffineMap());
  sliceState->ubs.resize(numSliceLoopIVs, AffineMap());

  // Get bounds for slice IVs in terms of other IVs, symbols, and constants.
  dependenceConstraints->getSliceBounds(offset, numSliceLoopIVs,
                                        depSourceOp->getContext(),
                                        &sliceState->lbs, &sliceState->ubs);

  // Set up bound operands for the slice's lower and upper bounds.
  SmallVector<Value *, 4> sliceBoundOperands;
  unsigned numDimsAndSymbols = dependenceConstraints->getNumDimAndSymbolIds();
  for (unsigned i = 0; i < numDimsAndSymbols; ++i) {
    if (i < offset || i >= offset + numSliceLoopIVs) {
      sliceBoundOperands.push_back(dependenceConstraints->getIdValue(i));
    }
  }

  // Give each bound its own copy of 'sliceBoundOperands' for subsequent
  // canonicalization.
  sliceState->lbOperands.resize(numSliceLoopIVs, sliceBoundOperands);
  sliceState->ubOperands.resize(numSliceLoopIVs, sliceBoundOperands);

  // Set the slice insertion point to the block start at 'loopDepth' in the
  // loop nest the slice is inserted into.
  sliceState->insertPoint =
      isBackwardSlice ? dstLoopIVs[loopDepth - 1].getBody()->begin()
                      : std::prev(srcLoopIVs[loopDepth - 1].getBody()->end());

  llvm::SmallDenseSet<Value *, 8> sequentialLoops;
  if (isa<AffineLoadOp>(depSourceOp) && isa<AffineLoadOp>(depSinkOp)) {
    // For read-read access pairs, clear any slice bounds on sequential loops.
    // Get sequential loops in the loop nest rooted at 'srcLoopIVs[0]' or
    // 'dstLoopIVs[0]', depending on the slice direction.
    getSequentialLoops(isBackwardSlice ? srcLoopIVs[0] : dstLoopIVs[0],
                       &sequentialLoops);
  }
  // Clear all sliced loop bounds beginning at the first sequential loop or the
  // first loop with a slice fusion barrier attribute.
  // TODO(andydavis, bondhugula) Use MemRef read/write regions instead of
  // using 'kSliceFusionBarrierAttrName'.
  auto getSliceLoop = [&](unsigned i) {
    return isBackwardSlice ? srcLoopIVs[i] : dstLoopIVs[i];
  };
  for (unsigned i = 0; i < numSliceLoopIVs; ++i) {
    Value *iv = getSliceLoop(i).getInductionVar();
    if (sequentialLoops.count(iv) == 0 &&
        getSliceLoop(i).getAttr(kSliceFusionBarrierAttrName) == nullptr)
      continue;
    for (unsigned j = i; j < numSliceLoopIVs; ++j) {
      sliceState->lbs[j] = AffineMap();
      sliceState->ubs[j] = AffineMap();
    }
    break;
  }
}

/// Creates a computation slice of the loop nest surrounding 'srcOpInst',
/// updates the slice loop bounds with any non-null bound maps specified in
/// 'sliceState', and inserts this slice into the loop nest surrounding
/// 'dstOpInst' at loop depth 'dstLoopDepth'.
// TODO(andydavis,bondhugula): extend the slicing utility to compute slices that
// aren't necessarily a one-to-one relation between the source and destination.
// The relation between the source and destination could be many-to-many in
// general.
// TODO(andydavis,bondhugula): the slice computation is incorrect in the cases
// where the dependence from the source to the destination does not cover the
// entire destination index set. Subtract out the dependent destination
// iterations from destination index set and check for emptiness --- this is one
// solution.
AffineForOp
mlir::insertBackwardComputationSlice(Operation *srcOpInst, Operation *dstOpInst,
                                     unsigned dstLoopDepth,
                                     ComputationSliceState *sliceState) {
  // Get loop nest surrounding src operation.
  SmallVector<AffineForOp, 4> srcLoopIVs;
  getLoopIVs(*srcOpInst, &srcLoopIVs);
  unsigned numSrcLoopIVs = srcLoopIVs.size();

  // Get loop nest surrounding dst operation.
  SmallVector<AffineForOp, 4> dstLoopIVs;
  getLoopIVs(*dstOpInst, &dstLoopIVs);
  unsigned dstLoopIVsSize = dstLoopIVs.size();
  if (dstLoopDepth > dstLoopIVsSize) {
    dstOpInst->emitError("invalid destination loop depth");
    return AffineForOp();
  }

  // Find the op block positions of 'srcOpInst' within 'srcLoopIVs'.
  SmallVector<unsigned, 4> positions;
  // TODO(andydavis): This code is incorrect since srcLoopIVs can be 0-d.
  findInstPosition(srcOpInst, srcLoopIVs[0].getOperation()->getBlock(),
                   &positions);

  // Clone the src loop nest and insert it at the beginning of the operation
  // block of the loop at 'dstLoopDepth' in 'dstLoopIVs'.
  auto dstAffineForOp = dstLoopIVs[dstLoopDepth - 1];
  OpBuilder b(dstAffineForOp.getBody(), dstAffineForOp.getBody()->begin());
  auto sliceLoopNest =
      cast<AffineForOp>(b.clone(*srcLoopIVs[0].getOperation()));

  Operation *sliceInst =
      getInstAtPosition(positions, /*level=*/0, sliceLoopNest.getBody());
  // Get loop nest surrounding 'sliceInst'.
  SmallVector<AffineForOp, 4> sliceSurroundingLoops;
  getLoopIVs(*sliceInst, &sliceSurroundingLoops);

  // Sanity check.
  unsigned sliceSurroundingLoopsSize = sliceSurroundingLoops.size();
  (void)sliceSurroundingLoopsSize;
  assert(dstLoopDepth + numSrcLoopIVs >= sliceSurroundingLoopsSize);
  unsigned sliceLoopLimit = dstLoopDepth + numSrcLoopIVs;
  (void)sliceLoopLimit;
  assert(sliceLoopLimit >= sliceSurroundingLoopsSize);

  // Update loop bounds for loops in 'sliceLoopNest'.
  for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
    auto forOp = sliceSurroundingLoops[dstLoopDepth + i];
    if (AffineMap lbMap = sliceState->lbs[i])
      forOp.setLowerBound(sliceState->lbOperands[i], lbMap);
    if (AffineMap ubMap = sliceState->ubs[i])
      forOp.setUpperBound(sliceState->ubOperands[i], ubMap);
  }
  return sliceLoopNest;
}

// Constructs a MemRefAccess, populating it with the memref, indices, and op
// from 'loadOrStoreOpInst'.
MemRefAccess::MemRefAccess(Operation *loadOrStoreOpInst) {
  if (auto loadOp = dyn_cast<AffineLoadOp>(loadOrStoreOpInst)) {
    memref = loadOp.getMemRef();
    opInst = loadOrStoreOpInst;
    auto loadMemrefType = loadOp.getMemRefType();
    indices.reserve(loadMemrefType.getRank());
    for (auto *index : loadOp.getIndices()) {
      indices.push_back(index);
    }
  } else {
    assert(isa<AffineStoreOp>(loadOrStoreOpInst) && "load/store op expected");
    auto storeOp = cast<AffineStoreOp>(loadOrStoreOpInst);
    opInst = loadOrStoreOpInst;
    memref = storeOp.getMemRef();
    auto storeMemrefType = storeOp.getMemRefType();
    indices.reserve(storeMemrefType.getRank());
    for (auto *index : storeOp.getIndices()) {
      indices.push_back(index);
    }
  }
}

unsigned MemRefAccess::getRank() const {
  return memref->getType().cast<MemRefType>().getRank();
}

bool MemRefAccess::isStore() const { return isa<AffineStoreOp>(opInst); }

/// Returns the nesting depth of this operation, i.e., the number of loops
/// surrounding this operation.
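/// For example, an op nested directly under two 'affine.for' ops has nesting
/// depth 2.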
unsigned mlir::getNestingDepth(Operation &op) {
  Operation *currOp = &op;
  unsigned depth = 0;
  while ((currOp = currOp->getParentOp())) {
    if (isa<AffineForOp>(currOp))
      depth++;
  }
  return depth;
}

/// Returns the number of surrounding loops common to both A and B, where the
/// loops surrounding each op are listed from outermost to innermost.
unsigned mlir::getNumCommonSurroundingLoops(Operation &A, Operation &B) {
  SmallVector<AffineForOp, 4> loopsA, loopsB;
  getLoopIVs(A, &loopsA);
  getLoopIVs(B, &loopsB);

  unsigned minNumLoops = std::min(loopsA.size(), loopsB.size());
  unsigned numCommonLoops = 0;
  for (unsigned i = 0; i < minNumLoops; ++i) {
    if (loopsA[i].getOperation() != loopsB[i].getOperation())
      break;
    ++numCommonLoops;
  }
  return numCommonLoops;
}

static Optional<int64_t> getMemoryFootprintBytes(Block &block,
                                                 Block::iterator start,
                                                 Block::iterator end,
                                                 int memorySpace) {
  SmallDenseMap<Value *, std::unique_ptr<MemRefRegion>, 4> regions;

  // Walk this range of the block to gather all memory regions.
  auto result = block.walk(start, end, [&](Operation *opInst) -> WalkResult {
    if (!isa<AffineLoadOp>(opInst) && !isa<AffineStoreOp>(opInst)) {
      // Neither a load nor a store op.
      return WalkResult::advance();
    }

    // Compute the memref region symbolic in any IVs enclosing this block.
    auto region = std::make_unique<MemRefRegion>(opInst->getLoc());
    if (failed(
            region->compute(opInst,
                            /*loopDepth=*/getNestingDepth(*block.begin())))) {
      return opInst->emitError("error obtaining memory region");
    }

    auto it = regions.find(region->memref);
    if (it == regions.end()) {
      regions[region->memref] = std::move(region);
    } else if (failed(it->second->unionBoundingBox(*region))) {
      return opInst->emitWarning(
          "getMemoryFootprintBytes: unable to perform a union on a memory "
          "region");
    }
    return WalkResult::advance();
  });
  if (result.wasInterrupted())
    return None;

  int64_t totalSizeInBytes = 0;
  for (const auto &region : regions) {
    Optional<int64_t> size = region.second->getRegionSize();
    if (!size.hasValue())
      return None;
    totalSizeInBytes += size.getValue();
  }
  return totalSizeInBytes;
}

Optional<int64_t> mlir::getMemoryFootprintBytes(AffineForOp forOp,
                                                int memorySpace) {
  auto *forInst = forOp.getOperation();
  return ::getMemoryFootprintBytes(
      *forInst->getBlock(), Block::iterator(forInst),
      std::next(Block::iterator(forInst)), memorySpace);
}

/// Returns in 'sequentialLoops' all sequential loops in the loop nest rooted
/// at 'forOp'.
void mlir::getSequentialLoops(
    AffineForOp forOp, llvm::SmallDenseSet<Value *, 8> *sequentialLoops) {
  forOp.getOperation()->walk([&](Operation *op) {
    if (auto innerFor = dyn_cast<AffineForOp>(op))
      if (!isLoopParallel(innerFor))
        sequentialLoops->insert(innerFor.getInductionVar());
  });
}

/// Returns true if 'forOp' is parallel, i.e., if no iteration of 'forOp'
/// depends on any other iteration.
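/// For example:
///
///   affine.for %i = 0 to 10 {
///     %v = load %A[%i]
///     store %v, %B[%i]
///   }
///
/// is parallel, while a loop that both loads from and stores to %A[0] on
/// every iteration is not.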
bool mlir::isLoopParallel(AffineForOp forOp) {
  // Collect all load and store ops in the loop nest rooted at 'forOp'.
  SmallVector<Operation *, 8> loadAndStoreOpInsts;
  auto walkResult = forOp.walk([&](Operation *opInst) {
    if (isa<AffineLoadOp>(opInst) || isa<AffineStoreOp>(opInst))
      loadAndStoreOpInsts.push_back(opInst);
    else if (!isa<AffineForOp>(opInst) && !isa<AffineTerminatorOp>(opInst) &&
             !isa<AffineIfOp>(opInst) && !opInst->hasNoSideEffect())
      return WalkResult::interrupt();

    return WalkResult::advance();
  });

  // Stop early if the loop has unknown ops with side effects.
  if (walkResult.wasInterrupted())
    return false;

  // The dependence check depth is the number of enclosing loops + 1.
  unsigned depth = getNestingDepth(*forOp.getOperation()) + 1;

  // Check dependences between all pairs of ops in 'loadAndStoreOpInsts'.
  for (auto *srcOpInst : loadAndStoreOpInsts) {
    MemRefAccess srcAccess(srcOpInst);
    for (auto *dstOpInst : loadAndStoreOpInsts) {
      MemRefAccess dstAccess(dstOpInst);
      FlatAffineConstraints dependenceConstraints;
      DependenceResult result = checkMemrefAccessDependence(
          srcAccess, dstAccess, depth, &dependenceConstraints,
          /*dependenceComponents=*/nullptr);
      if (result.value != DependenceResult::NoDependence)
        return false;
    }
  }
  return true;
}