github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Analysis/AffineAnalysis.cpp (about)

     1  //===- AffineAnalysis.cpp - Affine structures analysis routines -----------===//
     2  //
     3  // Copyright 2019 The MLIR Authors.
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //   http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  // =============================================================================
    17  //
    18  // This file implements miscellaneous analysis routines for affine structures
    19  // (expressions, maps, sets), and other utilities relying on such analysis.
    20  //
    21  //===----------------------------------------------------------------------===//
    22  
    23  #include "mlir/Analysis/AffineAnalysis.h"
    24  #include "mlir/Analysis/AffineStructures.h"
    25  #include "mlir/Analysis/Utils.h"
    26  #include "mlir/Dialect/AffineOps/AffineOps.h"
    27  #include "mlir/Dialect/StandardOps/Ops.h"
    28  #include "mlir/IR/AffineExprVisitor.h"
    29  #include "mlir/IR/Builders.h"
    30  #include "mlir/IR/Function.h"
    31  #include "mlir/IR/IntegerSet.h"
    32  #include "mlir/IR/Operation.h"
    33  #include "mlir/Support/MathExtras.h"
    34  #include "mlir/Support/STLExtras.h"
    35  #include "llvm/ADT/DenseMap.h"
    36  #include "llvm/Support/Debug.h"
    37  #include "llvm/Support/raw_ostream.h"
    38  
    39  #define DEBUG_TYPE "affine-analysis"
    40  
    41  using namespace mlir;
    42  
    43  using llvm::dbgs;
    44  
    45  /// Returns the sequence of AffineApplyOp Operations operation in
    46  /// 'affineApplyOps', which are reachable via a search starting from 'operands',
    47  /// and ending at operands which are not defined by AffineApplyOps.
    48  // TODO(andydavis) Add a method to AffineApplyOp which forward substitutes
    49  // the AffineApplyOp into any user AffineApplyOps.
    50  void mlir::getReachableAffineApplyOps(
    51      ArrayRef<Value *> operands, SmallVectorImpl<Operation *> &affineApplyOps) {
    52    struct State {
    53      // The ssa value for this node in the DFS traversal.
    54      Value *value;
    55      // The operand index of 'value' to explore next during DFS traversal.
    56      unsigned operandIndex;
    57    };
    58    SmallVector<State, 4> worklist;
    59    for (auto *operand : operands) {
    60      worklist.push_back({operand, 0});
    61    }
    62  
    63    while (!worklist.empty()) {
    64      State &state = worklist.back();
    65      auto *opInst = state.value->getDefiningOp();
    66      // Note: getDefiningOp will return nullptr if the operand is not an
    67      // Operation (i.e. block argument), which is a terminator for the search.
    68      if (!isa_and_nonnull<AffineApplyOp>(opInst)) {
    69        worklist.pop_back();
    70        continue;
    71      }
    72  
    73      if (state.operandIndex == 0) {
    74        // Pre-Visit: Add 'opInst' to reachable sequence.
    75        affineApplyOps.push_back(opInst);
    76      }
    77      if (state.operandIndex < opInst->getNumOperands()) {
    78        // Visit: Add next 'affineApplyOp' operand to worklist.
    79        // Get next operand to visit at 'operandIndex'.
    80        auto *nextOperand = opInst->getOperand(state.operandIndex);
    81        // Increment 'operandIndex' in 'state'.
    82        ++state.operandIndex;
    83        // Add 'nextOperand' to worklist.
    84        worklist.push_back({nextOperand, 0});
    85      } else {
    86        // Post-visit: done visiting operands AffineApplyOp, pop off stack.
    87        worklist.pop_back();
    88      }
    89    }
    90  }
    91  
    92  // Builds a system of constraints with dimensional identifiers corresponding to
    93  // the loop IVs of the forOps appearing in that order. Any symbols founds in
    94  // the bound operands are added as symbols in the system. Returns failure for
    95  // the yet unimplemented cases.
    96  // TODO(andydavis,bondhugula) Handle non-unit steps through local variables or
    97  // stride information in FlatAffineConstraints. (For eg., by using iv - lb %
    98  // step = 0 and/or by introducing a method in FlatAffineConstraints
    99  // setExprStride(ArrayRef<int64_t> expr, int64_t stride)
   100  LogicalResult mlir::getIndexSet(MutableArrayRef<AffineForOp> forOps,
   101                                  FlatAffineConstraints *domain) {
   102    SmallVector<Value *, 4> indices;
   103    extractForInductionVars(forOps, &indices);
   104    // Reset while associated Values in 'indices' to the domain.
   105    domain->reset(forOps.size(), /*numSymbols=*/0, /*numLocals=*/0, indices);
   106    for (auto forOp : forOps) {
   107      // Add constraints from forOp's bounds.
   108      if (failed(domain->addAffineForOpDomain(forOp)))
   109        return failure();
   110    }
   111    return success();
   112  }
   113  
   114  // Computes the iteration domain for 'opInst' and populates 'indexSet', which
   115  // encapsulates the constraints involving loops surrounding 'opInst' and
   116  // potentially involving any Function symbols. The dimensional identifiers in
   117  // 'indexSet' correspond to the loops surounding 'op' from outermost to
   118  // innermost.
   119  // TODO(andydavis) Add support to handle IfInsts surrounding 'op'.
   120  static LogicalResult getInstIndexSet(Operation *op,
   121                                       FlatAffineConstraints *indexSet) {
   122    // TODO(andydavis) Extend this to gather enclosing IfInsts and consider
   123    // factoring it out into a utility function.
   124    SmallVector<AffineForOp, 4> loops;
   125    getLoopIVs(*op, &loops);
   126    return getIndexSet(loops, indexSet);
   127  }
   128  
   129  // ValuePositionMap manages the mapping from Values which represent dimension
   130  // and symbol identifiers from 'src' and 'dst' access functions to positions
   131  // in new space where some Values are kept separate (using addSrc/DstValue)
   132  // and some Values are merged (addSymbolValue).
   133  // Position lookups return the absolute position in the new space which
   134  // has the following format:
   135  //
   136  //   [src-dim-identifiers] [dst-dim-identifiers] [symbol-identifers]
   137  //
   138  // Note: access function non-IV dimension identifiers (that have 'dimension'
   139  // positions in the access function position space) are assigned as symbols
   140  // in the output position space. Convienience access functions which lookup
   141  // an Value in multiple maps are provided (i.e. getSrcDimOrSymPos) to handle
   142  // the common case of resolving positions for all access function operands.
   143  //
   144  // TODO(andydavis) Generalize this: could take a template parameter for
   145  // the number of maps (3 in the current case), and lookups could take indices
   146  // of maps to check. So getSrcDimOrSymPos would be "getPos(value, {0, 2})".
   147  class ValuePositionMap {
   148  public:
   149    void addSrcValue(Value *value) {
   150      if (addValueAt(value, &srcDimPosMap, numSrcDims))
   151        ++numSrcDims;
   152    }
   153    void addDstValue(Value *value) {
   154      if (addValueAt(value, &dstDimPosMap, numDstDims))
   155        ++numDstDims;
   156    }
   157    void addSymbolValue(Value *value) {
   158      if (addValueAt(value, &symbolPosMap, numSymbols))
   159        ++numSymbols;
   160    }
   161    unsigned getSrcDimOrSymPos(Value *value) const {
   162      return getDimOrSymPos(value, srcDimPosMap, 0);
   163    }
   164    unsigned getDstDimOrSymPos(Value *value) const {
   165      return getDimOrSymPos(value, dstDimPosMap, numSrcDims);
   166    }
   167    unsigned getSymPos(Value *value) const {
   168      auto it = symbolPosMap.find(value);
   169      assert(it != symbolPosMap.end());
   170      return numSrcDims + numDstDims + it->second;
   171    }
   172  
   173    unsigned getNumSrcDims() const { return numSrcDims; }
   174    unsigned getNumDstDims() const { return numDstDims; }
   175    unsigned getNumDims() const { return numSrcDims + numDstDims; }
   176    unsigned getNumSymbols() const { return numSymbols; }
   177  
   178  private:
   179    bool addValueAt(Value *value, DenseMap<Value *, unsigned> *posMap,
   180                    unsigned position) {
   181      auto it = posMap->find(value);
   182      if (it == posMap->end()) {
   183        (*posMap)[value] = position;
   184        return true;
   185      }
   186      return false;
   187    }
   188    unsigned getDimOrSymPos(Value *value,
   189                            const DenseMap<Value *, unsigned> &dimPosMap,
   190                            unsigned dimPosOffset) const {
   191      auto it = dimPosMap.find(value);
   192      if (it != dimPosMap.end()) {
   193        return dimPosOffset + it->second;
   194      }
   195      it = symbolPosMap.find(value);
   196      assert(it != symbolPosMap.end());
   197      return numSrcDims + numDstDims + it->second;
   198    }
   199  
   200    unsigned numSrcDims = 0;
   201    unsigned numDstDims = 0;
   202    unsigned numSymbols = 0;
   203    DenseMap<Value *, unsigned> srcDimPosMap;
   204    DenseMap<Value *, unsigned> dstDimPosMap;
   205    DenseMap<Value *, unsigned> symbolPosMap;
   206  };
   207  
   208  // Builds a map from Value to identifier position in a new merged identifier
   209  // list, which is the result of merging dim/symbol lists from src/dst
   210  // iteration domains, the format of which is as follows:
   211  //
   212  //   [src-dim-identifiers, dst-dim-identifiers, symbol-identifiers, const_term]
   213  //
   214  // This method populates 'valuePosMap' with mappings from operand Values in
   215  // 'srcAccessMap'/'dstAccessMap' (as well as those in 'srcDomain'/'dstDomain')
   216  // to the position of these values in the merged list.
   217  static void buildDimAndSymbolPositionMaps(
   218      const FlatAffineConstraints &srcDomain,
   219      const FlatAffineConstraints &dstDomain, const AffineValueMap &srcAccessMap,
   220      const AffineValueMap &dstAccessMap, ValuePositionMap *valuePosMap,
   221      FlatAffineConstraints *dependenceConstraints) {
   222    auto updateValuePosMap = [&](ArrayRef<Value *> values, bool isSrc) {
   223      for (unsigned i = 0, e = values.size(); i < e; ++i) {
   224        auto *value = values[i];
   225        if (!isForInductionVar(values[i])) {
   226          assert(isValidSymbol(values[i]) &&
   227                 "access operand has to be either a loop IV or a symbol");
   228          valuePosMap->addSymbolValue(value);
   229        } else if (isSrc) {
   230          valuePosMap->addSrcValue(value);
   231        } else {
   232          valuePosMap->addDstValue(value);
   233        }
   234      }
   235    };
   236  
   237    SmallVector<Value *, 4> srcValues, destValues;
   238    srcDomain.getIdValues(0, srcDomain.getNumDimAndSymbolIds(), &srcValues);
   239    dstDomain.getIdValues(0, dstDomain.getNumDimAndSymbolIds(), &destValues);
   240    // Update value position map with identifiers from src iteration domain.
   241    updateValuePosMap(srcValues, /*isSrc=*/true);
   242    // Update value position map with identifiers from dst iteration domain.
   243    updateValuePosMap(destValues, /*isSrc=*/false);
   244    // Update value position map with identifiers from src access function.
   245    updateValuePosMap(srcAccessMap.getOperands(), /*isSrc=*/true);
   246    // Update value position map with identifiers from dst access function.
   247    updateValuePosMap(dstAccessMap.getOperands(), /*isSrc=*/false);
   248  }
   249  
   250  // Sets up dependence constraints columns appropriately, in the format:
   251  // [src-dim-ids, dst-dim-ids, symbol-ids, local-ids, const_term]
   252  void initDependenceConstraints(const FlatAffineConstraints &srcDomain,
   253                                 const FlatAffineConstraints &dstDomain,
   254                                 const AffineValueMap &srcAccessMap,
   255                                 const AffineValueMap &dstAccessMap,
   256                                 const ValuePositionMap &valuePosMap,
   257                                 FlatAffineConstraints *dependenceConstraints) {
   258    // Calculate number of equalities/inequalities and columns required to
   259    // initialize FlatAffineConstraints for 'dependenceDomain'.
   260    unsigned numIneq =
   261        srcDomain.getNumInequalities() + dstDomain.getNumInequalities();
   262    AffineMap srcMap = srcAccessMap.getAffineMap();
   263    assert(srcMap.getNumResults() == dstAccessMap.getAffineMap().getNumResults());
   264    unsigned numEq = srcMap.getNumResults();
   265    unsigned numDims = srcDomain.getNumDimIds() + dstDomain.getNumDimIds();
   266    unsigned numSymbols = valuePosMap.getNumSymbols();
   267    unsigned numLocals = srcDomain.getNumLocalIds() + dstDomain.getNumLocalIds();
   268    unsigned numIds = numDims + numSymbols + numLocals;
   269    unsigned numCols = numIds + 1;
   270  
   271    // Set flat affine constraints sizes and reserving space for constraints.
   272    dependenceConstraints->reset(numIneq, numEq, numCols, numDims, numSymbols,
   273                                 numLocals);
   274  
   275    // Set values corresponding to dependence constraint identifiers.
   276    SmallVector<Value *, 4> srcLoopIVs, dstLoopIVs;
   277    srcDomain.getIdValues(0, srcDomain.getNumDimIds(), &srcLoopIVs);
   278    dstDomain.getIdValues(0, dstDomain.getNumDimIds(), &dstLoopIVs);
   279  
   280    dependenceConstraints->setIdValues(0, srcLoopIVs.size(), srcLoopIVs);
   281    dependenceConstraints->setIdValues(
   282        srcLoopIVs.size(), srcLoopIVs.size() + dstLoopIVs.size(), dstLoopIVs);
   283  
   284    // Set values for the symbolic identifier dimensions.
   285    auto setSymbolIds = [&](ArrayRef<Value *> values) {
   286      for (auto *value : values) {
   287        if (!isForInductionVar(value)) {
   288          assert(isValidSymbol(value) && "expected symbol");
   289          dependenceConstraints->setIdValue(valuePosMap.getSymPos(value), value);
   290        }
   291      }
   292    };
   293  
   294    setSymbolIds(srcAccessMap.getOperands());
   295    setSymbolIds(dstAccessMap.getOperands());
   296  
   297    SmallVector<Value *, 8> srcSymbolValues, dstSymbolValues;
   298    srcDomain.getIdValues(srcDomain.getNumDimIds(),
   299                          srcDomain.getNumDimAndSymbolIds(), &srcSymbolValues);
   300    dstDomain.getIdValues(dstDomain.getNumDimIds(),
   301                          dstDomain.getNumDimAndSymbolIds(), &dstSymbolValues);
   302    setSymbolIds(srcSymbolValues);
   303    setSymbolIds(dstSymbolValues);
   304  
   305    for (unsigned i = 0, e = dependenceConstraints->getNumDimAndSymbolIds();
   306         i < e; i++)
   307      assert(dependenceConstraints->getIds()[i].hasValue());
   308  }
   309  
   310  // Adds iteration domain constraints from 'srcDomain' and 'dstDomain' into
   311  // 'dependenceDomain'.
   312  // Uses 'valuePosMap' to determine the position in 'dependenceDomain' to which a
   313  // srcDomain/dstDomain Value maps.
   314  static void addDomainConstraints(const FlatAffineConstraints &srcDomain,
   315                                   const FlatAffineConstraints &dstDomain,
   316                                   const ValuePositionMap &valuePosMap,
   317                                   FlatAffineConstraints *dependenceDomain) {
   318    unsigned depNumDimsAndSymbolIds = dependenceDomain->getNumDimAndSymbolIds();
   319  
   320    SmallVector<int64_t, 4> cst(dependenceDomain->getNumCols());
   321  
   322    auto addDomain = [&](bool isSrc, bool isEq, unsigned localOffset) {
   323      const FlatAffineConstraints &domain = isSrc ? srcDomain : dstDomain;
   324      unsigned numCsts =
   325          isEq ? domain.getNumEqualities() : domain.getNumInequalities();
   326      unsigned numDimAndSymbolIds = domain.getNumDimAndSymbolIds();
   327      auto at = [&](unsigned i, unsigned j) -> int64_t {
   328        return isEq ? domain.atEq(i, j) : domain.atIneq(i, j);
   329      };
   330      auto map = [&](unsigned i) -> int64_t {
   331        return isSrc ? valuePosMap.getSrcDimOrSymPos(domain.getIdValue(i))
   332                     : valuePosMap.getDstDimOrSymPos(domain.getIdValue(i));
   333      };
   334  
   335      for (unsigned i = 0; i < numCsts; ++i) {
   336        // Zero fill.
   337        std::fill(cst.begin(), cst.end(), 0);
   338        // Set coefficients for identifiers corresponding to domain.
   339        for (unsigned j = 0; j < numDimAndSymbolIds; ++j)
   340          cst[map(j)] = at(i, j);
   341        // Local terms.
   342        for (unsigned j = 0, e = domain.getNumLocalIds(); j < e; j++)
   343          cst[depNumDimsAndSymbolIds + localOffset + j] =
   344              at(i, numDimAndSymbolIds + j);
   345        // Set constant term.
   346        cst[cst.size() - 1] = at(i, domain.getNumCols() - 1);
   347        // Add constraint.
   348        if (isEq)
   349          dependenceDomain->addEquality(cst);
   350        else
   351          dependenceDomain->addInequality(cst);
   352      }
   353    };
   354  
   355    // Add equalities from src domain.
   356    addDomain(/*isSrc=*/true, /*isEq=*/true, /*localOffset=*/0);
   357    // Add inequalities from src domain.
   358    addDomain(/*isSrc=*/true, /*isEq=*/false, /*localOffset=*/0);
   359    // Add equalities from dst domain.
   360    addDomain(/*isSrc=*/false, /*isEq=*/true,
   361              /*localOffset=*/srcDomain.getNumLocalIds());
   362    // Add inequalities from dst domain.
   363    addDomain(/*isSrc=*/false, /*isEq=*/false,
   364              /*localOffset=*/srcDomain.getNumLocalIds());
   365  }
   366  
   367  // Adds equality constraints that equate src and dst access functions
   368  // represented by 'srcAccessMap' and 'dstAccessMap' for each result.
   369  // Requires that 'srcAccessMap' and 'dstAccessMap' have the same results count.
   370  // For example, given the following two accesses functions to a 2D memref:
   371  //
   372  //   Source access function:
   373  //     (a0 * d0 + a1 * s0 + a2, b0 * d0 + b1 * s0 + b2)
   374  //
   375  //   Destination acceses function:
   376  //     (c0 * d0 + c1 * s0 + c2, f0 * d0 + f1 * s0 + f2)
   377  //
   378  // This method constructs the following equality constraints in
   379  // 'dependenceDomain', by equating the access functions for each result
   380  // (i.e. each memref dim). Notice that 'd0' for the destination access function
   381  // is mapped into 'd0' in the equality constraint:
   382  //
   383  //   d0      d1      s0         c
   384  //   --      --      --         --
   385  //   a0     -c0      (a1 - c1)  (a1 - c2) = 0
   386  //   b0     -f0      (b1 - f1)  (b1 - f2) = 0
   387  //
   388  // Returns failure if any AffineExpr cannot be flattened (due to it being
   389  // semi-affine). Returns success otherwise.
   390  static LogicalResult
   391  addMemRefAccessConstraints(const AffineValueMap &srcAccessMap,
   392                             const AffineValueMap &dstAccessMap,
   393                             const ValuePositionMap &valuePosMap,
   394                             FlatAffineConstraints *dependenceDomain) {
   395    AffineMap srcMap = srcAccessMap.getAffineMap();
   396    AffineMap dstMap = dstAccessMap.getAffineMap();
   397    assert(srcMap.getNumResults() == dstMap.getNumResults());
   398    unsigned numResults = srcMap.getNumResults();
   399  
   400    unsigned srcNumIds = srcMap.getNumDims() + srcMap.getNumSymbols();
   401    ArrayRef<Value *> srcOperands = srcAccessMap.getOperands();
   402  
   403    unsigned dstNumIds = dstMap.getNumDims() + dstMap.getNumSymbols();
   404    ArrayRef<Value *> dstOperands = dstAccessMap.getOperands();
   405  
   406    std::vector<SmallVector<int64_t, 8>> srcFlatExprs;
   407    std::vector<SmallVector<int64_t, 8>> destFlatExprs;
   408    FlatAffineConstraints srcLocalVarCst, destLocalVarCst;
   409    // Get flattened expressions for the source destination maps.
   410    if (failed(getFlattenedAffineExprs(srcMap, &srcFlatExprs, &srcLocalVarCst)) ||
   411        failed(getFlattenedAffineExprs(dstMap, &destFlatExprs, &destLocalVarCst)))
   412      return failure();
   413  
   414    unsigned domNumLocalIds = dependenceDomain->getNumLocalIds();
   415    unsigned srcNumLocalIds = srcLocalVarCst.getNumLocalIds();
   416    unsigned dstNumLocalIds = destLocalVarCst.getNumLocalIds();
   417    unsigned numLocalIdsToAdd = srcNumLocalIds + dstNumLocalIds;
   418    for (unsigned i = 0; i < numLocalIdsToAdd; i++) {
   419      dependenceDomain->addLocalId(dependenceDomain->getNumLocalIds());
   420    }
   421  
   422    unsigned numDims = dependenceDomain->getNumDimIds();
   423    unsigned numSymbols = dependenceDomain->getNumSymbolIds();
   424    unsigned numSrcLocalIds = srcLocalVarCst.getNumLocalIds();
   425    unsigned newLocalIdOffset = numDims + numSymbols + domNumLocalIds;
   426  
   427    // Equality to add.
   428    SmallVector<int64_t, 8> eq(dependenceDomain->getNumCols());
   429    for (unsigned i = 0; i < numResults; ++i) {
   430      // Zero fill.
   431      std::fill(eq.begin(), eq.end(), 0);
   432  
   433      // Flattened AffineExpr for src result 'i'.
   434      const auto &srcFlatExpr = srcFlatExprs[i];
   435      // Set identifier coefficients from src access function.
   436      for (unsigned j = 0, e = srcOperands.size(); j < e; ++j)
   437        eq[valuePosMap.getSrcDimOrSymPos(srcOperands[j])] = srcFlatExpr[j];
   438      // Local terms.
   439      for (unsigned j = 0, e = srcNumLocalIds; j < e; j++)
   440        eq[newLocalIdOffset + j] = srcFlatExpr[srcNumIds + j];
   441      // Set constant term.
   442      eq[eq.size() - 1] = srcFlatExpr[srcFlatExpr.size() - 1];
   443  
   444      // Flattened AffineExpr for dest result 'i'.
   445      const auto &destFlatExpr = destFlatExprs[i];
   446      // Set identifier coefficients from dst access function.
   447      for (unsigned j = 0, e = dstOperands.size(); j < e; ++j)
   448        eq[valuePosMap.getDstDimOrSymPos(dstOperands[j])] -= destFlatExpr[j];
   449      // Local terms.
   450      for (unsigned j = 0, e = dstNumLocalIds; j < e; j++)
   451        eq[newLocalIdOffset + numSrcLocalIds + j] = -destFlatExpr[dstNumIds + j];
   452      // Set constant term.
   453      eq[eq.size() - 1] -= destFlatExpr[destFlatExpr.size() - 1];
   454  
   455      // Add equality constraint.
   456      dependenceDomain->addEquality(eq);
   457    }
   458  
   459    // Add equality constraints for any operands that are defined by constant ops.
   460    auto addEqForConstOperands = [&](ArrayRef<Value *> operands) {
   461      for (unsigned i = 0, e = operands.size(); i < e; ++i) {
   462        if (isForInductionVar(operands[i]))
   463          continue;
   464        auto *symbol = operands[i];
   465        assert(isValidSymbol(symbol));
   466        // Check if the symbol is a constant.
   467        if (auto cOp = dyn_cast_or_null<ConstantIndexOp>(symbol->getDefiningOp()))
   468          dependenceDomain->setIdToConstant(valuePosMap.getSymPos(symbol),
   469                                            cOp.getValue());
   470      }
   471    };
   472  
   473    // Add equality constraints for any src symbols defined by constant ops.
   474    addEqForConstOperands(srcOperands);
   475    // Add equality constraints for any dst symbols defined by constant ops.
   476    addEqForConstOperands(dstOperands);
   477  
   478    // By construction (see flattener), local var constraints will not have any
   479    // equalities.
   480    assert(srcLocalVarCst.getNumEqualities() == 0 &&
   481           destLocalVarCst.getNumEqualities() == 0);
   482    // Add inequalities from srcLocalVarCst and destLocalVarCst into the
   483    // dependence domain.
   484    SmallVector<int64_t, 8> ineq(dependenceDomain->getNumCols());
   485    for (unsigned r = 0, e = srcLocalVarCst.getNumInequalities(); r < e; r++) {
   486      std::fill(ineq.begin(), ineq.end(), 0);
   487  
   488      // Set identifier coefficients from src local var constraints.
   489      for (unsigned j = 0, e = srcOperands.size(); j < e; ++j)
   490        ineq[valuePosMap.getSrcDimOrSymPos(srcOperands[j])] =
   491            srcLocalVarCst.atIneq(r, j);
   492      // Local terms.
   493      for (unsigned j = 0, e = srcNumLocalIds; j < e; j++)
   494        ineq[newLocalIdOffset + j] = srcLocalVarCst.atIneq(r, srcNumIds + j);
   495      // Set constant term.
   496      ineq[ineq.size() - 1] =
   497          srcLocalVarCst.atIneq(r, srcLocalVarCst.getNumCols() - 1);
   498      dependenceDomain->addInequality(ineq);
   499    }
   500  
   501    for (unsigned r = 0, e = destLocalVarCst.getNumInequalities(); r < e; r++) {
   502      std::fill(ineq.begin(), ineq.end(), 0);
   503      // Set identifier coefficients from dest local var constraints.
   504      for (unsigned j = 0, e = dstOperands.size(); j < e; ++j)
   505        ineq[valuePosMap.getDstDimOrSymPos(dstOperands[j])] =
   506            destLocalVarCst.atIneq(r, j);
   507      // Local terms.
   508      for (unsigned j = 0, e = dstNumLocalIds; j < e; j++)
   509        ineq[newLocalIdOffset + numSrcLocalIds + j] =
   510            destLocalVarCst.atIneq(r, dstNumIds + j);
   511      // Set constant term.
   512      ineq[ineq.size() - 1] =
   513          destLocalVarCst.atIneq(r, destLocalVarCst.getNumCols() - 1);
   514  
   515      dependenceDomain->addInequality(ineq);
   516    }
   517    return success();
   518  }
   519  
   520  // Returns the number of outer loop common to 'src/dstDomain'.
   521  // Loops common to 'src/dst' domains are added to 'commonLoops' if non-null.
   522  static unsigned
   523  getNumCommonLoops(const FlatAffineConstraints &srcDomain,
   524                    const FlatAffineConstraints &dstDomain,
   525                    SmallVectorImpl<AffineForOp> *commonLoops = nullptr) {
   526    // Find the number of common loops shared by src and dst accesses.
   527    unsigned minNumLoops =
   528        std::min(srcDomain.getNumDimIds(), dstDomain.getNumDimIds());
   529    unsigned numCommonLoops = 0;
   530    for (unsigned i = 0; i < minNumLoops; ++i) {
   531      if (!isForInductionVar(srcDomain.getIdValue(i)) ||
   532          !isForInductionVar(dstDomain.getIdValue(i)) ||
   533          srcDomain.getIdValue(i) != dstDomain.getIdValue(i))
   534        break;
   535      if (commonLoops != nullptr)
   536        commonLoops->push_back(getForInductionVarOwner(srcDomain.getIdValue(i)));
   537      ++numCommonLoops;
   538    }
   539    if (commonLoops != nullptr)
   540      assert(commonLoops->size() == numCommonLoops);
   541    return numCommonLoops;
   542  }
   543  
   544  // Returns Block common to 'srcAccess.opInst' and 'dstAccess.opInst'.
   545  static Block *getCommonBlock(const MemRefAccess &srcAccess,
   546                               const MemRefAccess &dstAccess,
   547                               const FlatAffineConstraints &srcDomain,
   548                               unsigned numCommonLoops) {
   549    if (numCommonLoops == 0) {
   550      auto *block = srcAccess.opInst->getBlock();
   551      while (!llvm::isa<FuncOp>(block->getParentOp())) {
   552        block = block->getParentOp()->getBlock();
   553      }
   554      return block;
   555    }
   556    auto *commonForValue = srcDomain.getIdValue(numCommonLoops - 1);
   557    auto forOp = getForInductionVarOwner(commonForValue);
   558    assert(forOp && "commonForValue was not an induction variable");
   559    return forOp.getBody();
   560  }
   561  
   562  // Returns true if the ancestor operation of 'srcAccess' appears before the
   563  // ancestor operation of 'dstAccess' in the common ancestral block. Returns
   564  // false otherwise.
   565  // Note that because 'srcAccess' or 'dstAccess' may be nested in conditionals,
   566  // the function is named 'srcAppearsBeforeDstInCommonBlock'. Note that
   567  // 'numCommonLoops' is the number of contiguous surrounding outer loops.
   568  static bool srcAppearsBeforeDstInAncestralBlock(
   569      const MemRefAccess &srcAccess, const MemRefAccess &dstAccess,
   570      const FlatAffineConstraints &srcDomain, unsigned numCommonLoops) {
   571    // Get Block common to 'srcAccess.opInst' and 'dstAccess.opInst'.
   572    auto *commonBlock =
   573        getCommonBlock(srcAccess, dstAccess, srcDomain, numCommonLoops);
   574    // Check the dominance relationship between the respective ancestors of the
   575    // src and dst in the Block of the innermost among the common loops.
   576    auto *srcInst = commonBlock->findAncestorInstInBlock(*srcAccess.opInst);
   577    assert(srcInst != nullptr);
   578    auto *dstInst = commonBlock->findAncestorInstInBlock(*dstAccess.opInst);
   579    assert(dstInst != nullptr);
   580  
   581    // Determine whether dstInst comes after srcInst.
   582    return srcInst->isBeforeInBlock(dstInst);
   583  }
   584  
   585  // Adds ordering constraints to 'dependenceDomain' based on number of loops
   586  // common to 'src/dstDomain' and requested 'loopDepth'.
   587  // Note that 'loopDepth' cannot exceed the number of common loops plus one.
   588  // EX: Given a loop nest of depth 2 with IVs 'i' and 'j':
   589  // *) If 'loopDepth == 1' then one constraint is added: i' >= i + 1
   590  // *) If 'loopDepth == 2' then two constraints are added: i == i' and j' > j + 1
   591  // *) If 'loopDepth == 3' then two constraints are added: i == i' and j == j'
   592  static void addOrderingConstraints(const FlatAffineConstraints &srcDomain,
   593                                     const FlatAffineConstraints &dstDomain,
   594                                     unsigned loopDepth,
   595                                     FlatAffineConstraints *dependenceDomain) {
   596    unsigned numCols = dependenceDomain->getNumCols();
   597    SmallVector<int64_t, 4> eq(numCols);
   598    unsigned numSrcDims = srcDomain.getNumDimIds();
   599    unsigned numCommonLoops = getNumCommonLoops(srcDomain, dstDomain);
   600    unsigned numCommonLoopConstraints = std::min(numCommonLoops, loopDepth);
   601    for (unsigned i = 0; i < numCommonLoopConstraints; ++i) {
   602      std::fill(eq.begin(), eq.end(), 0);
   603      eq[i] = -1;
   604      eq[i + numSrcDims] = 1;
   605      if (i == loopDepth - 1) {
   606        eq[numCols - 1] = -1;
   607        dependenceDomain->addInequality(eq);
   608      } else {
   609        dependenceDomain->addEquality(eq);
   610      }
   611    }
   612  }
   613  
   614  // Computes distance and direction vectors in 'dependences', by adding
   615  // variables to 'dependenceDomain' which represent the difference of the IVs,
   616  // eliminating all other variables, and reading off distance vectors from
   617  // equality constraints (if possible), and direction vectors from inequalities.
   618  static void computeDirectionVector(
   619      const FlatAffineConstraints &srcDomain,
   620      const FlatAffineConstraints &dstDomain, unsigned loopDepth,
   621      FlatAffineConstraints *dependenceDomain,
   622      llvm::SmallVector<DependenceComponent, 2> *dependenceComponents) {
   623    // Find the number of common loops shared by src and dst accesses.
   624    SmallVector<AffineForOp, 4> commonLoops;
   625    unsigned numCommonLoops =
   626        getNumCommonLoops(srcDomain, dstDomain, &commonLoops);
   627    if (numCommonLoops == 0)
   628      return;
   629    // Compute direction vectors for requested loop depth.
   630    unsigned numIdsToEliminate = dependenceDomain->getNumIds();
   631    // Add new variables to 'dependenceDomain' to represent the direction
   632    // constraints for each shared loop.
   633    for (unsigned j = 0; j < numCommonLoops; ++j) {
   634      dependenceDomain->addDimId(j);
   635    }
   636  
   637    // Add equality contraints for each common loop, setting newly introduced
   638    // variable at column 'j' to the 'dst' IV minus the 'src IV.
   639    SmallVector<int64_t, 4> eq;
   640    eq.resize(dependenceDomain->getNumCols());
   641    unsigned numSrcDims = srcDomain.getNumDimIds();
   642    // Constraint variables format:
   643    // [num-common-loops][num-src-dim-ids][num-dst-dim-ids][num-symbols][constant]
   644    for (unsigned j = 0; j < numCommonLoops; ++j) {
   645      std::fill(eq.begin(), eq.end(), 0);
   646      eq[j] = 1;
   647      eq[j + numCommonLoops] = 1;
   648      eq[j + numCommonLoops + numSrcDims] = -1;
   649      dependenceDomain->addEquality(eq);
   650    }
   651  
   652    // Eliminate all variables other than the direction variables just added.
   653    dependenceDomain->projectOut(numCommonLoops, numIdsToEliminate);
   654  
   655    // Scan each common loop variable column and set direction vectors based
   656    // on eliminated constraint system.
   657    dependenceComponents->resize(numCommonLoops);
   658    for (unsigned j = 0; j < numCommonLoops; ++j) {
   659      (*dependenceComponents)[j].op = commonLoops[j].getOperation();
   660      auto lbConst = dependenceDomain->getConstantLowerBound(j);
   661      (*dependenceComponents)[j].lb =
   662          lbConst.getValueOr(std::numeric_limits<int64_t>::min());
   663      auto ubConst = dependenceDomain->getConstantUpperBound(j);
   664      (*dependenceComponents)[j].ub =
   665          ubConst.getValueOr(std::numeric_limits<int64_t>::max());
   666    }
   667  }
   668  
   669  // Populates 'accessMap' with composition of AffineApplyOps reachable from
   670  // indices of MemRefAccess.
   671  void MemRefAccess::getAccessMap(AffineValueMap *accessMap) const {
   672    // Get affine map from AffineLoad/Store.
   673    AffineMap map;
   674    if (auto loadOp = dyn_cast<AffineLoadOp>(opInst))
   675      map = loadOp.getAffineMap();
   676    else if (auto storeOp = dyn_cast<AffineStoreOp>(opInst))
   677      map = storeOp.getAffineMap();
   678    SmallVector<Value *, 8> operands(indices.begin(), indices.end());
   679    fullyComposeAffineMapAndOperands(&map, &operands);
   680    map = simplifyAffineMap(map);
   681    canonicalizeMapAndOperands(&map, &operands);
   682    accessMap->reset(map, operands);
   683  }
   684  
// Builds a flat affine constraint system to check if there exists a dependence
// between memref accesses 'srcAccess' and 'dstAccess'.
// Returns 'NoDependence' if the accesses can be definitively shown not to
// access the same element.
// Returns 'HasDependence' if the accesses do access the same element.
// Returns 'Failure' if an error or unsupported case was encountered.
// If a dependence exists, returns in 'dependenceComponents' a direction
// vector for the dependence, with a component for each loop IV in loops
// common to both accesses (see Dependence in AffineAnalysis.h for details).
//
// Parameters:
// *) 'loopDepth': the depth at which the dependence is checked. Ordering
//    constraints force the common loops outside that depth to run the same
//    iteration, and make the loop at that depth (if common) run 'src' before
//    'dst'. Must not exceed the number of common loops plus one (asserted).
// *) 'dependenceConstraints': output dependence constraint system; must be
//    non-null (it is dereferenced unconditionally).
// *) 'dependenceComponents': optional output (may be null). When non-null and
//    a dependence is found, 'dependenceConstraints' may additionally be
//    projected down to the IV-difference variables by computeDirectionVector.
// *) 'allowRAR': when false, a pair where neither access is an AffineStoreOp
//    (read-after-read) is reported as 'NoDependence' without further analysis,
//    and the lexicographic-order check below is applied.
//
// The memref access dependence check comprises the following steps:
// *) Compute access functions for each access. Access functions are computed
//    using AffineValueMaps initialized with the indices from an access, then
//    composed with AffineApplyOps reachable from operands of that access,
//    until operands of the AffineValueMap are loop IVs or symbols.
// *) Build iteration domain constraints for each access. Iteration domain
//    constraints are pairs of inequality constraints representing the
//    upper/lower loop bounds for each AffineForOp in the loop nest associated
//    with each access.
// *) Build dimension and symbol position maps for each access, which map
//    Values from access functions and iteration domains to their position
//    in the merged constraint system built by this method.
// *) Equate the src/dst access functions, add 'src' happens-before 'dst'
//    ordering constraints for 'loopDepth', and add both iteration domains.
//    A dependence is reported unless the resulting system is found empty.
//
// This method builds a constraint system with the following column format:
//
//  [src-dim-identifiers, dst-dim-identifiers, symbols, constant]
//
// For example, given the following MLIR code with "source" and
// "destination" accesses to the same memref labeled, and symbols %M, %N, %K:
//
//   affine.for %i0 = 0 to 100 {
//     affine.for %i1 = 0 to 50 {
//       %a0 = affine.apply
//         (d0, d1) -> (d0 * 2 - d1 * 4 + s1, d1 * 3 - s0) (%i0, %i1)[%M, %N]
//       // Source memref access.
//       store %v0, %m[%a0#0, %a0#1] : memref<4x4xf32>
//     }
//   }
//
//   affine.for %i2 = 0 to 100 {
//     affine.for %i3 = 0 to 50 {
//       %a1 = affine.apply
//         (d0, d1) -> (d0 * 7 + d1 * 9 - s1, d1 * 11 - s0) (%i2, %i3)[%K, %M]
//       // Destination memref access.
//       %v1 = load %m[%a1#0, %a1#1] : memref<4x4xf32>
//     }
//   }
//
// The access functions would be the following:
//
//   src: (%i0 * 2 - %i1 * 4 + %N, %i1 * 3 - %M)
//   dst: (%i2 * 7 + %i3 * 9 - %M, %i3 * 11 - %K)
//
// The iteration domains for the src/dst accesses would be the following
// (affine.for upper bounds are exclusive):
//
//   src: 0 <= %i0 <= 99, 0 <= %i1 <= 49
//   dst: 0 <= %i2 <= 99, 0 <= %i3 <= 49
//
// The symbols used by both accesses would be assigned to a canonical position
// order which will be used in the dependence constraint system:
//
//   symbol name: %M  %N  %K
//   symbol  pos:  0   1   2
//
// Equality constraints are built by equating each result of src/destination
// access functions. For this example, the following two equality constraints
// will be added to the dependence constraint system:
//
//   [src_dim0, src_dim1, dst_dim0, dst_dim1, sym0, sym1, sym2, const]
//      2         -4        -7        -9       1      1     0     0    = 0
//      0          3         0        -11     -1      0     1     0    = 0
//
// Inequality constraints from the iteration domain will be merged into
// the dependence constraint system:
//
//   [src_dim0, src_dim1, dst_dim0, dst_dim1, sym0, sym1, sym2, const]
//       1         0         0         0        0     0     0     0    >= 0
//      -1         0         0         0        0     0     0     99   >= 0
//       0         1         0         0        0     0     0     0    >= 0
//       0        -1         0         0        0     0     0     49   >= 0
//       0         0         1         0        0     0     0     0    >= 0
//       0         0        -1         0        0     0     0     99   >= 0
//       0         0         0         1        0     0     0     0    >= 0
//       0         0         0        -1        0     0     0     49   >= 0
//
//
// TODO(andydavis) Support AffineExprs mod/floordiv/ceildiv.
DependenceResult mlir::checkMemrefAccessDependence(
    const MemRefAccess &srcAccess, const MemRefAccess &dstAccess,
    unsigned loopDepth, FlatAffineConstraints *dependenceConstraints,
    llvm::SmallVector<DependenceComponent, 2> *dependenceComponents,
    bool allowRAR) {
  LLVM_DEBUG(llvm::dbgs() << "Checking for dependence at depth: "
                          << Twine(loopDepth) << " between:\n";);
  LLVM_DEBUG(srcAccess.opInst->dump(););
  LLVM_DEBUG(dstAccess.opInst->dump(););

  // Return 'NoDependence' if these accesses do not access the same memref.
  if (srcAccess.memref != dstAccess.memref)
    return DependenceResult::NoDependence;

  // Return 'NoDependence' if neither access is an AffineStoreOp (i.e. this is
  // a read-after-read pair), unless RAR dependences were requested.
  if (!allowRAR && !isa<AffineStoreOp>(srcAccess.opInst) &&
      !isa<AffineStoreOp>(dstAccess.opInst))
    return DependenceResult::NoDependence;

  // Get composed access function for 'srcAccess'.
  AffineValueMap srcAccessMap;
  srcAccess.getAccessMap(&srcAccessMap);

  // Get composed access function for 'dstAccess'.
  AffineValueMap dstAccessMap;
  dstAccess.getAccessMap(&dstAccessMap);

  // Get iteration domain for the 'srcAccess' operation.
  FlatAffineConstraints srcDomain;
  if (failed(getInstIndexSet(srcAccess.opInst, &srcDomain)))
    return DependenceResult::Failure;

  // Get iteration domain for 'dstAccess' operation.
  FlatAffineConstraints dstDomain;
  if (failed(getInstIndexSet(dstAccess.opInst, &dstDomain)))
    return DependenceResult::Failure;

  // Return 'NoDependence' if loopDepth > numCommonLoops and if the ancestor
  // operation of 'srcAccess' does not properly dominate the ancestor
  // operation of 'dstAccess' in the same common operation block.
  // Note: this check is skipped if 'allowRAR' is true, because RAR
  // deps can exist irrespective of lexicographic ordering b/w src and dst.
  unsigned numCommonLoops = getNumCommonLoops(srcDomain, dstDomain);
  assert(loopDepth <= numCommonLoops + 1);
  if (!allowRAR && loopDepth > numCommonLoops &&
      !srcAppearsBeforeDstInAncestralBlock(srcAccess, dstAccess, srcDomain,
                                           numCommonLoops)) {
    return DependenceResult::NoDependence;
  }
  // Build dim and symbol position maps for each access from access operand
  // Value to position in merged constraint system.
  ValuePositionMap valuePosMap;
  buildDimAndSymbolPositionMaps(srcDomain, dstDomain, srcAccessMap,
                                dstAccessMap, &valuePosMap,
                                dependenceConstraints);

  initDependenceConstraints(srcDomain, dstDomain, srcAccessMap, dstAccessMap,
                            valuePosMap, dependenceConstraints);

  assert(valuePosMap.getNumDims() ==
         srcDomain.getNumDimIds() + dstDomain.getNumDimIds());

  // Create memref access constraint by equating src/dst access functions.
  // Note that this check is conservative, and will fail in the future when
  // local variables for mod/div exprs are supported.
  if (failed(addMemRefAccessConstraints(srcAccessMap, dstAccessMap, valuePosMap,
                                        dependenceConstraints)))
    return DependenceResult::Failure;

  // Add 'src' happens before 'dst' ordering constraints.
  addOrderingConstraints(srcDomain, dstDomain, loopDepth,
                         dependenceConstraints);
  // Add src and dst domain constraints.
  addDomainConstraints(srcDomain, dstDomain, valuePosMap,
                       dependenceConstraints);

  // Return 'NoDependence' if the solution space is empty: no dependence.
  if (dependenceConstraints->isEmpty()) {
    return DependenceResult::NoDependence;
  }

  // Compute the dependence direction vector if requested. Note that this may
  // project 'dependenceConstraints' onto the IV-difference variables, so the
  // debug dump below then shows the projected system.
  if (dependenceComponents != nullptr) {
    computeDirectionVector(srcDomain, dstDomain, loopDepth,
                           dependenceConstraints, dependenceComponents);
  }

  LLVM_DEBUG(llvm::dbgs() << "Dependence polyhedron:\n");
  LLVM_DEBUG(dependenceConstraints->dump());
  return DependenceResult::HasDependence;
}
   863  
   864  /// Gathers dependence components for dependences between all ops in loop nest
   865  /// rooted at 'forOp' at loop depths in range [1, maxLoopDepth].
   866  void mlir::getDependenceComponents(
   867      AffineForOp forOp, unsigned maxLoopDepth,
   868      std::vector<llvm::SmallVector<DependenceComponent, 2>> *depCompsVec) {
   869    // Collect all load and store ops in loop nest rooted at 'forOp'.
   870    SmallVector<Operation *, 8> loadAndStoreOpInsts;
   871    forOp.getOperation()->walk([&](Operation *opInst) {
   872      if (isa<AffineLoadOp>(opInst) || isa<AffineStoreOp>(opInst))
   873        loadAndStoreOpInsts.push_back(opInst);
   874    });
   875  
   876    unsigned numOps = loadAndStoreOpInsts.size();
   877    for (unsigned d = 1; d <= maxLoopDepth; ++d) {
   878      for (unsigned i = 0; i < numOps; ++i) {
   879        auto *srcOpInst = loadAndStoreOpInsts[i];
   880        MemRefAccess srcAccess(srcOpInst);
   881        for (unsigned j = 0; j < numOps; ++j) {
   882          auto *dstOpInst = loadAndStoreOpInsts[j];
   883          MemRefAccess dstAccess(dstOpInst);
   884  
   885          FlatAffineConstraints dependenceConstraints;
   886          llvm::SmallVector<DependenceComponent, 2> depComps;
   887          // TODO(andydavis,bondhugula) Explore whether it would be profitable
   888          // to pre-compute and store deps instead of repeatedly checking.
   889          DependenceResult result = checkMemrefAccessDependence(
   890              srcAccess, dstAccess, d, &dependenceConstraints, &depComps);
   891          if (hasDependence(result))
   892            depCompsVec->push_back(depComps);
   893        }
   894      }
   895    }
   896  }