github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Analysis/AffineStructures.cpp (about)

     1  //===- AffineStructures.cpp - MLIR Affine Structures Class-----------------===//
     2  //
     3  // Copyright 2019 The MLIR Authors.
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //   http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  // =============================================================================
    17  //
    18  // Structures for affine/polyhedral analysis of MLIR functions.
    19  //
    20  //===----------------------------------------------------------------------===//
    21  
    22  #include "mlir/Analysis/AffineStructures.h"
    23  #include "mlir/Dialect/AffineOps/AffineOps.h"
    24  #include "mlir/Dialect/StandardOps/Ops.h"
    25  #include "mlir/IR/AffineExprVisitor.h"
    26  #include "mlir/IR/AffineMap.h"
    27  #include "mlir/IR/IntegerSet.h"
    28  #include "mlir/IR/Operation.h"
    29  #include "mlir/Support/MathExtras.h"
    30  #include "llvm/ADT/DenseSet.h"
    31  #include "llvm/ADT/SmallPtrSet.h"
    32  #include "llvm/Support/Debug.h"
    33  #include "llvm/Support/raw_ostream.h"
    34  
    35  #define DEBUG_TYPE "affine-structures"
    36  
    37  using namespace mlir;
    38  using llvm::SmallDenseMap;
    39  using llvm::SmallDenseSet;
    40  using llvm::SmallPtrSet;
    41  
    42  namespace {
    43  
    44  // See comments for SimpleAffineExprFlattener.
    45  // An AffineExprFlattener extends a SimpleAffineExprFlattener by recording
    46  // constraint information associated with mod's, floordiv's, and ceildiv's
    47  // in FlatAffineConstraints 'localVarCst'.
    48  struct AffineExprFlattener : public SimpleAffineExprFlattener {
    49  public:
    50    // Constraints connecting newly introduced local variables (for mod's and
    51    // div's) to existing (dimensional and symbolic) ones. These are always
    52    // inequalities.
    53    FlatAffineConstraints localVarCst;
    54  
    55    AffineExprFlattener(unsigned nDims, unsigned nSymbols, MLIRContext *ctx)
    56        : SimpleAffineExprFlattener(nDims, nSymbols) {
    57      localVarCst.reset(nDims, nSymbols, /*numLocals=*/0);
    58    }
    59  
    60  private:
    61    // Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr).
    62    // The local identifier added is always a floordiv of a pure add/mul affine
    63    // function of other identifiers, coefficients of which are specified in
    64    // `dividend' and with respect to the positive constant `divisor'. localExpr
    65    // is the simplified tree expression (AffineExpr) corresponding to the
    66    // quantifier.
    67    void addLocalFloorDivId(ArrayRef<int64_t> dividend, int64_t divisor,
    68                            AffineExpr localExpr) override {
    69      SimpleAffineExprFlattener::addLocalFloorDivId(dividend, divisor, localExpr);
    70      // Update localVarCst.
    71      localVarCst.addLocalFloorDiv(dividend, divisor);
    72    }
    73  };
    74  
    75  } // end anonymous namespace
    76  
    77  // Flattens the expressions in map. Returns failure if 'expr' was unable to be
    78  // flattened (i.e., semi-affine expressions not handled yet).
    79  static LogicalResult getFlattenedAffineExprs(
    80      ArrayRef<AffineExpr> exprs, unsigned numDims, unsigned numSymbols,
    81      std::vector<llvm::SmallVector<int64_t, 8>> *flattenedExprs,
    82      FlatAffineConstraints *localVarCst) {
    83    if (exprs.empty()) {
    84      localVarCst->reset(numDims, numSymbols);
    85      return success();
    86    }
    87  
    88    AffineExprFlattener flattener(numDims, numSymbols, exprs[0].getContext());
    89    // Use the same flattener to simplify each expression successively. This way
    90    // local identifiers / expressions are shared.
    91    for (auto expr : exprs) {
    92      if (!expr.isPureAffine())
    93        return failure();
    94  
    95      flattener.walkPostOrder(expr);
    96    }
    97  
    98    assert(flattener.operandExprStack.size() == exprs.size());
    99    flattenedExprs->clear();
   100    flattenedExprs->assign(flattener.operandExprStack.begin(),
   101                           flattener.operandExprStack.end());
   102  
   103    if (localVarCst) {
   104      localVarCst->clearAndCopyFrom(flattener.localVarCst);
   105    }
   106  
   107    return success();
   108  }
   109  
   110  // Flattens 'expr' into 'flattenedExpr'. Returns failure if 'expr' was unable to
   111  // be flattened (semi-affine expressions not handled yet).
   112  LogicalResult
   113  mlir::getFlattenedAffineExpr(AffineExpr expr, unsigned numDims,
   114                               unsigned numSymbols,
   115                               llvm::SmallVectorImpl<int64_t> *flattenedExpr,
   116                               FlatAffineConstraints *localVarCst) {
   117    std::vector<SmallVector<int64_t, 8>> flattenedExprs;
   118    LogicalResult ret = ::getFlattenedAffineExprs({expr}, numDims, numSymbols,
   119                                                  &flattenedExprs, localVarCst);
   120    *flattenedExpr = flattenedExprs[0];
   121    return ret;
   122  }
   123  
   124  /// Flattens the expressions in map. Returns failure if 'expr' was unable to be
   125  /// flattened (i.e., semi-affine expressions not handled yet).
   126  LogicalResult mlir::getFlattenedAffineExprs(
   127      AffineMap map, std::vector<llvm::SmallVector<int64_t, 8>> *flattenedExprs,
   128      FlatAffineConstraints *localVarCst) {
   129    if (map.getNumResults() == 0) {
   130      localVarCst->reset(map.getNumDims(), map.getNumSymbols());
   131      return success();
   132    }
   133    return ::getFlattenedAffineExprs(map.getResults(), map.getNumDims(),
   134                                     map.getNumSymbols(), flattenedExprs,
   135                                     localVarCst);
   136  }
   137  
   138  LogicalResult mlir::getFlattenedAffineExprs(
   139      IntegerSet set, std::vector<llvm::SmallVector<int64_t, 8>> *flattenedExprs,
   140      FlatAffineConstraints *localVarCst) {
   141    if (set.getNumConstraints() == 0) {
   142      localVarCst->reset(set.getNumDims(), set.getNumSymbols());
   143      return success();
   144    }
   145    return ::getFlattenedAffineExprs(set.getConstraints(), set.getNumDims(),
   146                                     set.getNumSymbols(), flattenedExprs,
   147                                     localVarCst);
   148  }
   149  
   150  //===----------------------------------------------------------------------===//
   151  // MutableAffineMap.
   152  //===----------------------------------------------------------------------===//
   153  
   154  MutableAffineMap::MutableAffineMap(AffineMap map)
   155      : numDims(map.getNumDims()), numSymbols(map.getNumSymbols()),
   156        // A map always has at least 1 result by construction
   157        context(map.getResult(0).getContext()) {
   158    for (auto result : map.getResults())
   159      results.push_back(result);
   160  }
   161  
   162  void MutableAffineMap::reset(AffineMap map) {
   163    results.clear();
   164    numDims = map.getNumDims();
   165    numSymbols = map.getNumSymbols();
   166    // A map always has at least 1 result by construction
   167    context = map.getResult(0).getContext();
   168    for (auto result : map.getResults())
   169      results.push_back(result);
   170  }
   171  
   172  bool MutableAffineMap::isMultipleOf(unsigned idx, int64_t factor) const {
   173    if (results[idx].isMultipleOf(factor))
   174      return true;
   175  
   176    // TODO(bondhugula): use simplifyAffineExpr and FlatAffineConstraints to
   177    // complete this (for a more powerful analysis).
   178    return false;
   179  }
   180  
   181  // Simplifies the result affine expressions of this map. The expressions have to
   182  // be pure for the simplification implemented.
   183  void MutableAffineMap::simplify() {
   184    // Simplify each of the results if possible.
   185    // TODO(ntv): functional-style map
   186    for (unsigned i = 0, e = getNumResults(); i < e; i++) {
   187      results[i] = simplifyAffineExpr(getResult(i), numDims, numSymbols);
   188    }
   189  }
   190  
   191  AffineMap MutableAffineMap::getAffineMap() const {
   192    return AffineMap::get(numDims, numSymbols, results);
   193  }
   194  
   195  MutableIntegerSet::MutableIntegerSet(IntegerSet set, MLIRContext *context)
   196      : numDims(set.getNumDims()), numSymbols(set.getNumSymbols()) {
   197    // TODO(bondhugula)
   198  }
   199  
   200  // Universal set.
   201  MutableIntegerSet::MutableIntegerSet(unsigned numDims, unsigned numSymbols,
   202                                       MLIRContext *context)
   203      : numDims(numDims), numSymbols(numSymbols) {}
   204  
   205  //===----------------------------------------------------------------------===//
   206  // AffineValueMap.
   207  //===----------------------------------------------------------------------===//
   208  
   209  AffineValueMap::AffineValueMap(AffineMap map, ArrayRef<Value *> operands,
   210                                 ArrayRef<Value *> results)
   211      : map(map), operands(operands.begin(), operands.end()),
   212        results(results.begin(), results.end()) {}
   213  
   214  AffineValueMap::AffineValueMap(AffineApplyOp applyOp)
   215      : map(applyOp.getAffineMap()),
   216        operands(applyOp.operand_begin(), applyOp.operand_end()) {
   217    results.push_back(applyOp.getResult());
   218  }
   219  
   220  AffineValueMap::AffineValueMap(AffineBound bound)
   221      : map(bound.getMap()),
   222        operands(bound.operand_begin(), bound.operand_end()) {}
   223  
   224  void AffineValueMap::reset(AffineMap map, ArrayRef<Value *> operands,
   225                             ArrayRef<Value *> results) {
   226    this->map.reset(map);
   227    this->operands.assign(operands.begin(), operands.end());
   228    this->results.assign(results.begin(), results.end());
   229  }
   230  
   231  // Returns true and sets 'indexOfMatch' if 'valueToMatch' is found in
   232  // 'valuesToSearch' beginning at 'indexStart'. Returns false otherwise.
   233  static bool findIndex(Value *valueToMatch, ArrayRef<Value *> valuesToSearch,
   234                        unsigned indexStart, unsigned *indexOfMatch) {
   235    unsigned size = valuesToSearch.size();
   236    for (unsigned i = indexStart; i < size; ++i) {
   237      if (valueToMatch == valuesToSearch[i]) {
   238        *indexOfMatch = i;
   239        return true;
   240      }
   241    }
   242    return false;
   243  }
   244  
   245  inline bool AffineValueMap::isMultipleOf(unsigned idx, int64_t factor) const {
   246    return map.isMultipleOf(idx, factor);
   247  }
   248  
   249  /// This method uses the invariant that operands are always positionally aligned
   250  /// with the AffineDimExpr in the underlying AffineMap.
   251  bool AffineValueMap::isFunctionOf(unsigned idx, Value *value) const {
   252    unsigned index;
   253    if (!findIndex(value, operands, /*indexStart=*/0, &index)) {
   254      return false;
   255    }
   256    auto expr = const_cast<AffineValueMap *>(this)->getAffineMap().getResult(idx);
   257    // TODO(ntv): this is better implemented on a flattened representation.
   258    // At least for now it is conservative.
   259    return expr.isFunctionOfDim(index);
   260  }
   261  
   262  Value *AffineValueMap::getOperand(unsigned i) const {
   263    return static_cast<Value *>(operands[i]);
   264  }
   265  
   266  ArrayRef<Value *> AffineValueMap::getOperands() const {
   267    return ArrayRef<Value *>(operands);
   268  }
   269  
   270  AffineMap AffineValueMap::getAffineMap() const { return map.getAffineMap(); }
   271  
   272  AffineValueMap::~AffineValueMap() {}
   273  
   274  //===----------------------------------------------------------------------===//
   275  // FlatAffineConstraints.
   276  //===----------------------------------------------------------------------===//
   277  
   278  // Copy constructor.
   279  FlatAffineConstraints::FlatAffineConstraints(
   280      const FlatAffineConstraints &other) {
   281    numReservedCols = other.numReservedCols;
   282    numDims = other.getNumDimIds();
   283    numSymbols = other.getNumSymbolIds();
   284    numIds = other.getNumIds();
   285  
   286    auto otherIds = other.getIds();
   287    ids.reserve(numReservedCols);
   288    ids.append(otherIds.begin(), otherIds.end());
   289  
   290    unsigned numReservedEqualities = other.getNumReservedEqualities();
   291    unsigned numReservedInequalities = other.getNumReservedInequalities();
   292  
   293    equalities.reserve(numReservedEqualities * numReservedCols);
   294    inequalities.reserve(numReservedInequalities * numReservedCols);
   295  
   296    for (unsigned r = 0, e = other.getNumInequalities(); r < e; r++) {
   297      addInequality(other.getInequality(r));
   298    }
   299    for (unsigned r = 0, e = other.getNumEqualities(); r < e; r++) {
   300      addEquality(other.getEquality(r));
   301    }
   302  }
   303  
   304  // Clones this object.
   305  std::unique_ptr<FlatAffineConstraints> FlatAffineConstraints::clone() const {
   306    return std::make_unique<FlatAffineConstraints>(*this);
   307  }
   308  
   309  // Construct from an IntegerSet.
   310  FlatAffineConstraints::FlatAffineConstraints(IntegerSet set)
   311      : numReservedCols(set.getNumOperands() + 1),
   312        numIds(set.getNumDims() + set.getNumSymbols()), numDims(set.getNumDims()),
   313        numSymbols(set.getNumSymbols()) {
   314    equalities.reserve(set.getNumEqualities() * numReservedCols);
   315    inequalities.reserve(set.getNumInequalities() * numReservedCols);
   316    ids.resize(numIds, None);
   317  
   318    // Flatten expressions and add them to the constraint system.
   319    std::vector<SmallVector<int64_t, 8>> flatExprs;
   320    FlatAffineConstraints localVarCst;
   321    if (failed(getFlattenedAffineExprs(set, &flatExprs, &localVarCst))) {
   322      assert(false && "flattening unimplemented for semi-affine integer sets");
   323      return;
   324    }
   325    assert(flatExprs.size() == set.getNumConstraints());
   326    for (unsigned l = 0, e = localVarCst.getNumLocalIds(); l < e; l++) {
   327      addLocalId(getNumLocalIds());
   328    }
   329  
   330    for (unsigned i = 0, e = flatExprs.size(); i < e; ++i) {
   331      const auto &flatExpr = flatExprs[i];
   332      assert(flatExpr.size() == getNumCols());
   333      if (set.getEqFlags()[i]) {
   334        addEquality(flatExpr);
   335      } else {
   336        addInequality(flatExpr);
   337      }
   338    }
   339    // Add the other constraints involving local id's from flattening.
   340    append(localVarCst);
   341  }
   342  
   343  void FlatAffineConstraints::reset(unsigned numReservedInequalities,
   344                                    unsigned numReservedEqualities,
   345                                    unsigned newNumReservedCols,
   346                                    unsigned newNumDims, unsigned newNumSymbols,
   347                                    unsigned newNumLocals,
   348                                    ArrayRef<Value *> idArgs) {
   349    assert(newNumReservedCols >= newNumDims + newNumSymbols + newNumLocals + 1 &&
   350           "minimum 1 column");
   351    numReservedCols = newNumReservedCols;
   352    numDims = newNumDims;
   353    numSymbols = newNumSymbols;
   354    numIds = numDims + numSymbols + newNumLocals;
   355    assert(idArgs.empty() || idArgs.size() == numIds);
   356  
   357    clearConstraints();
   358    if (numReservedEqualities >= 1)
   359      equalities.reserve(newNumReservedCols * numReservedEqualities);
   360    if (numReservedInequalities >= 1)
   361      inequalities.reserve(newNumReservedCols * numReservedInequalities);
   362    if (idArgs.empty()) {
   363      ids.resize(numIds, None);
   364    } else {
   365      ids.assign(idArgs.begin(), idArgs.end());
   366    }
   367  }
   368  
   369  void FlatAffineConstraints::reset(unsigned newNumDims, unsigned newNumSymbols,
   370                                    unsigned newNumLocals,
   371                                    ArrayRef<Value *> idArgs) {
   372    reset(0, 0, newNumDims + newNumSymbols + newNumLocals + 1, newNumDims,
   373          newNumSymbols, newNumLocals, idArgs);
   374  }
   375  
   376  void FlatAffineConstraints::append(const FlatAffineConstraints &other) {
   377    assert(other.getNumCols() == getNumCols());
   378    assert(other.getNumDimIds() == getNumDimIds());
   379    assert(other.getNumSymbolIds() == getNumSymbolIds());
   380  
   381    inequalities.reserve(inequalities.size() +
   382                         other.getNumInequalities() * numReservedCols);
   383    equalities.reserve(equalities.size() +
   384                       other.getNumEqualities() * numReservedCols);
   385  
   386    for (unsigned r = 0, e = other.getNumInequalities(); r < e; r++) {
   387      addInequality(other.getInequality(r));
   388    }
   389    for (unsigned r = 0, e = other.getNumEqualities(); r < e; r++) {
   390      addEquality(other.getEquality(r));
   391    }
   392  }
   393  
   394  void FlatAffineConstraints::addLocalId(unsigned pos) {
   395    addId(IdKind::Local, pos);
   396  }
   397  
   398  void FlatAffineConstraints::addDimId(unsigned pos, Value *id) {
   399    addId(IdKind::Dimension, pos, id);
   400  }
   401  
   402  void FlatAffineConstraints::addSymbolId(unsigned pos, Value *id) {
   403    addId(IdKind::Symbol, pos, id);
   404  }
   405  
   406  /// Adds a dimensional identifier. The added column is initialized to
   407  /// zero.
   408  void FlatAffineConstraints::addId(IdKind kind, unsigned pos, Value *id) {
   409    if (kind == IdKind::Dimension) {
   410      assert(pos <= getNumDimIds());
   411    } else if (kind == IdKind::Symbol) {
   412      assert(pos <= getNumSymbolIds());
   413    } else {
   414      assert(pos <= getNumLocalIds());
   415    }
   416  
   417    unsigned oldNumReservedCols = numReservedCols;
   418  
   419    // Check if a resize is necessary.
   420    if (getNumCols() + 1 > numReservedCols) {
   421      equalities.resize(getNumEqualities() * (getNumCols() + 1));
   422      inequalities.resize(getNumInequalities() * (getNumCols() + 1));
   423      numReservedCols++;
   424    }
   425  
   426    int absolutePos;
   427  
   428    if (kind == IdKind::Dimension) {
   429      absolutePos = pos;
   430      numDims++;
   431    } else if (kind == IdKind::Symbol) {
   432      absolutePos = pos + getNumDimIds();
   433      numSymbols++;
   434    } else {
   435      absolutePos = pos + getNumDimIds() + getNumSymbolIds();
   436    }
   437    numIds++;
   438  
   439    // Note that getNumCols() now will already return the new size, which will be
   440    // at least one.
   441    int numInequalities = static_cast<int>(getNumInequalities());
   442    int numEqualities = static_cast<int>(getNumEqualities());
   443    int numCols = static_cast<int>(getNumCols());
   444    for (int r = numInequalities - 1; r >= 0; r--) {
   445      for (int c = numCols - 2; c >= 0; c--) {
   446        if (c < absolutePos)
   447          atIneq(r, c) = inequalities[r * oldNumReservedCols + c];
   448        else
   449          atIneq(r, c + 1) = inequalities[r * oldNumReservedCols + c];
   450      }
   451      atIneq(r, absolutePos) = 0;
   452    }
   453  
   454    for (int r = numEqualities - 1; r >= 0; r--) {
   455      for (int c = numCols - 2; c >= 0; c--) {
   456        // All values in column absolutePositions < absolutePos have the same
   457        // coordinates in the 2-d view of the coefficient buffer.
   458        if (c < absolutePos)
   459          atEq(r, c) = equalities[r * oldNumReservedCols + c];
   460        else
   461          // Those at absolutePosition >= absolutePos, get a shifted
   462          // absolutePosition.
   463          atEq(r, c + 1) = equalities[r * oldNumReservedCols + c];
   464      }
   465      // Initialize added dimension to zero.
   466      atEq(r, absolutePos) = 0;
   467    }
   468  
   469    // If an 'id' is provided, insert it; otherwise use None.
   470    if (id) {
   471      ids.insert(ids.begin() + absolutePos, id);
   472    } else {
   473      ids.insert(ids.begin() + absolutePos, None);
   474    }
   475    assert(ids.size() == getNumIds());
   476  }
   477  
   478  /// Checks if two constraint systems are in the same space, i.e., if they are
   479  /// associated with the same set of identifiers, appearing in the same order.
   480  static bool areIdsAligned(const FlatAffineConstraints &A,
   481                            const FlatAffineConstraints &B) {
   482    return A.getNumDimIds() == B.getNumDimIds() &&
   483           A.getNumSymbolIds() == B.getNumSymbolIds() &&
   484           A.getNumIds() == B.getNumIds() && A.getIds().equals(B.getIds());
   485  }
   486  
   487  /// Calls areIdsAligned to check if two constraint systems have the same set
   488  /// of identifiers in the same order.
   489  bool FlatAffineConstraints::areIdsAlignedWithOther(
   490      const FlatAffineConstraints &other) {
   491    return areIdsAligned(*this, other);
   492  }
   493  
   494  /// Checks if the SSA values associated with `cst''s identifiers are unique.
   495  static bool LLVM_ATTRIBUTE_UNUSED
   496  areIdsUnique(const FlatAffineConstraints &cst) {
   497    SmallPtrSet<Value *, 8> uniqueIds;
   498    for (auto id : cst.getIds()) {
   499      if (id.hasValue() && !uniqueIds.insert(id.getValue()).second)
   500        return false;
   501    }
   502    return true;
   503  }
   504  
   505  // Swap the posA^th identifier with the posB^th identifier.
   506  static void swapId(FlatAffineConstraints *A, unsigned posA, unsigned posB) {
   507    assert(posA < A->getNumIds() && "invalid position A");
   508    assert(posB < A->getNumIds() && "invalid position B");
   509  
   510    if (posA == posB)
   511      return;
   512  
   513    for (unsigned r = 0, e = A->getNumInequalities(); r < e; r++) {
   514      std::swap(A->atIneq(r, posA), A->atIneq(r, posB));
   515    }
   516    for (unsigned r = 0, e = A->getNumEqualities(); r < e; r++) {
   517      std::swap(A->atEq(r, posA), A->atEq(r, posB));
   518    }
   519    std::swap(A->getId(posA), A->getId(posB));
   520  }
   521  
   522  /// Merge and align the identifiers of A and B starting at 'offset', so that
   523  /// both constraint systems get the union of the contained identifiers that is
   524  /// dimension-wise and symbol-wise unique; both constraint systems are updated
   525  /// so that they have the union of all identifiers, with A's original
   526  /// identifiers appearing first followed by any of B's identifiers that didn't
   527  /// appear in A. Local identifiers of each system are by design separate/local
   528  /// and are placed one after other (A's followed by B's).
   529  //  Eg: Input: A has ((%i %j) [%M %N]) and B has (%k, %j) [%P, %N, %M])
   530  //      Output: both A, B have (%i, %j, %k) [%M, %N, %P]
   531  //
   532  static void mergeAndAlignIds(unsigned offset, FlatAffineConstraints *A,
   533                               FlatAffineConstraints *B) {
   534    assert(offset <= A->getNumDimIds() && offset <= B->getNumDimIds());
   535    // A merge/align isn't meaningful if a cst's ids aren't distinct.
   536    assert(areIdsUnique(*A) && "A's id values aren't unique");
   537    assert(areIdsUnique(*B) && "B's id values aren't unique");
   538  
   539    assert(std::all_of(A->getIds().begin() + offset,
   540                       A->getIds().begin() + A->getNumDimAndSymbolIds(),
   541                       [](Optional<Value *> id) { return id.hasValue(); }));
   542  
   543    assert(std::all_of(B->getIds().begin() + offset,
   544                       B->getIds().begin() + B->getNumDimAndSymbolIds(),
   545                       [](Optional<Value *> id) { return id.hasValue(); }));
   546  
   547    // Place local id's of A after local id's of B.
   548    for (unsigned l = 0, e = A->getNumLocalIds(); l < e; l++) {
   549      B->addLocalId(0);
   550    }
   551    for (unsigned t = 0, e = B->getNumLocalIds() - A->getNumLocalIds(); t < e;
   552         t++) {
   553      A->addLocalId(A->getNumLocalIds());
   554    }
   555  
   556    SmallVector<Value *, 4> aDimValues, aSymValues;
   557    A->getIdValues(offset, A->getNumDimIds(), &aDimValues);
   558    A->getIdValues(A->getNumDimIds(), A->getNumDimAndSymbolIds(), &aSymValues);
   559    {
   560      // Merge dims from A into B.
   561      unsigned d = offset;
   562      for (auto *aDimValue : aDimValues) {
   563        unsigned loc;
   564        if (B->findId(*aDimValue, &loc)) {
   565          assert(loc >= offset && "A's dim appears in B's aligned range");
   566          assert(loc < B->getNumDimIds() &&
   567                 "A's dim appears in B's non-dim position");
   568          swapId(B, d, loc);
   569        } else {
   570          B->addDimId(d);
   571          B->setIdValue(d, aDimValue);
   572        }
   573        d++;
   574      }
   575  
   576      // Dimensions that are in B, but not in A, are added at the end.
   577      for (unsigned t = A->getNumDimIds(), e = B->getNumDimIds(); t < e; t++) {
   578        A->addDimId(A->getNumDimIds());
   579        A->setIdValue(A->getNumDimIds() - 1, B->getIdValue(t));
   580      }
   581    }
   582    {
   583      // Merge symbols: merge A's symbols into B first.
   584      unsigned s = B->getNumDimIds();
   585      for (auto *aSymValue : aSymValues) {
   586        unsigned loc;
   587        if (B->findId(*aSymValue, &loc)) {
   588          assert(loc >= B->getNumDimIds() && loc < B->getNumDimAndSymbolIds() &&
   589                 "A's symbol appears in B's non-symbol position");
   590          swapId(B, s, loc);
   591        } else {
   592          B->addSymbolId(s - B->getNumDimIds());
   593          B->setIdValue(s, aSymValue);
   594        }
   595        s++;
   596      }
   597      // Symbols that are in B, but not in A, are added at the end.
   598      for (unsigned t = A->getNumDimAndSymbolIds(),
   599                    e = B->getNumDimAndSymbolIds();
   600           t < e; t++) {
   601        A->addSymbolId(A->getNumSymbolIds());
   602        A->setIdValue(A->getNumDimAndSymbolIds() - 1, B->getIdValue(t));
   603      }
   604    }
   605    assert(areIdsAligned(*A, *B) && "IDs expected to be aligned");
   606  }
   607  
   608  // Call 'mergeAndAlignIds' to align constraint systems of 'this' and 'other'.
   609  void FlatAffineConstraints::mergeAndAlignIdsWithOther(
   610      unsigned offset, FlatAffineConstraints *other) {
   611    mergeAndAlignIds(offset, this, other);
   612  }
   613  
   614  // This routine may add additional local variables if the flattened expression
   615  // corresponding to the map has such variables due to mod's, ceildiv's, and
   616  // floordiv's in it.
   617  LogicalResult FlatAffineConstraints::composeMap(AffineValueMap *vMap) {
   618    std::vector<SmallVector<int64_t, 8>> flatExprs;
   619    FlatAffineConstraints localCst;
   620    if (failed(getFlattenedAffineExprs(vMap->getAffineMap(), &flatExprs,
   621                                       &localCst))) {
   622      LLVM_DEBUG(llvm::dbgs()
   623                 << "composition unimplemented for semi-affine maps\n");
   624      return failure();
   625    }
   626    assert(flatExprs.size() == vMap->getNumResults());
   627  
   628    // Add localCst information.
   629    if (localCst.getNumLocalIds() > 0) {
   630      SmallVector<Value *, 8> values(vMap->getOperands().begin(),
   631                                     vMap->getOperands().end());
   632      localCst.setIdValues(0, localCst.getNumDimAndSymbolIds(), values);
   633      // Align localCst and this.
   634      mergeAndAlignIds(/*offset=*/0, &localCst, this);
   635      // Finally, append localCst to this constraint set.
   636      append(localCst);
   637    }
   638  
   639    // Add dimensions corresponding to the map's results.
   640    for (unsigned t = 0, e = vMap->getNumResults(); t < e; t++) {
   641      // TODO: Consider using a batched version to add a range of IDs.
   642      addDimId(0);
   643    }
   644  
   645    // We add one equality for each result connecting the result dim of the map to
   646    // the other identifiers.
   647    // For eg: if the expression is 16*i0 + i1, and this is the r^th
   648    // iteration/result of the value map, we are adding the equality:
   649    //  d_r - 16*i0 - i1 = 0. Hence, when flattening say (i0 + 1, i0 + 8*i2), we
   650    //  add two equalities overall: d_0 - i0 - 1 == 0, d1 - i0 - 8*i2 == 0.
   651    for (unsigned r = 0, e = flatExprs.size(); r < e; r++) {
   652      const auto &flatExpr = flatExprs[r];
   653      assert(flatExpr.size() >= vMap->getNumOperands() + 1);
   654  
   655      // eqToAdd is the equality corresponding to the flattened affine expression.
   656      SmallVector<int64_t, 8> eqToAdd(getNumCols(), 0);
   657      // Set the coefficient for this result to one.
   658      eqToAdd[r] = 1;
   659  
   660      // Dims and symbols.
   661      for (unsigned i = 0, e = vMap->getNumOperands(); i < e; i++) {
   662        unsigned loc;
   663        bool ret = findId(*vMap->getOperand(i), &loc);
   664        assert(ret && "value map's id can't be found");
   665        (void)ret;
   666        // Negate 'eq[r]' since the newly added dimension will be set to this one.
   667        eqToAdd[loc] = -flatExpr[i];
   668      }
   669      // Local vars common to eq and localCst are at the beginning.
   670      unsigned j = getNumDimIds() + getNumSymbolIds();
   671      unsigned end = flatExpr.size() - 1;
   672      for (unsigned i = vMap->getNumOperands(); i < end; i++, j++) {
   673        eqToAdd[j] = -flatExpr[i];
   674      }
   675  
   676      // Constant term.
   677      eqToAdd[getNumCols() - 1] = -flatExpr[flatExpr.size() - 1];
   678  
   679      // Add the equality connecting the result of the map to this constraint set.
   680      addEquality(eqToAdd);
   681    }
   682  
   683    return success();
   684  }
   685  
   686  // Turn a dimension into a symbol.
   687  static void turnDimIntoSymbol(FlatAffineConstraints *cst, Value &id) {
   688    unsigned pos;
   689    if (cst->findId(id, &pos) && pos < cst->getNumDimIds()) {
   690      swapId(cst, pos, cst->getNumDimIds() - 1);
   691      cst->setDimSymbolSeparation(cst->getNumSymbolIds() + 1);
   692    }
   693  }
   694  
   695  // Turn a symbol into a dimension.
   696  static void turnSymbolIntoDim(FlatAffineConstraints *cst, Value &id) {
   697    unsigned pos;
   698    if (cst->findId(id, &pos) && pos >= cst->getNumDimIds() &&
   699        pos < cst->getNumDimAndSymbolIds()) {
   700      swapId(cst, pos, cst->getNumDimIds());
   701      cst->setDimSymbolSeparation(cst->getNumSymbolIds() - 1);
   702    }
   703  }
   704  
   705  // Changes all symbol identifiers which are loop IVs to dim identifiers.
   706  void FlatAffineConstraints::convertLoopIVSymbolsToDims() {
   707    // Gather all symbols which are loop IVs.
   708    SmallVector<Value *, 4> loopIVs;
   709    for (unsigned i = getNumDimIds(), e = getNumDimAndSymbolIds(); i < e; i++) {
   710      if (ids[i].hasValue() && getForInductionVarOwner(ids[i].getValue()))
   711        loopIVs.push_back(ids[i].getValue());
   712    }
   713    // Turn each symbol in 'loopIVs' into a dim identifier.
   714    for (auto *iv : loopIVs) {
   715      turnSymbolIntoDim(this, *iv);
   716    }
   717  }
   718  
   719  void FlatAffineConstraints::addInductionVarOrTerminalSymbol(Value *id) {
   720    if (containsId(*id))
   721      return;
   722  
   723    // Caller is expected to fully compose map/operands if necessary.
   724    assert((isTopLevelSymbol(id) || isForInductionVar(id)) &&
   725           "non-terminal symbol / loop IV expected");
   726    // Outer loop IVs could be used in forOp's bounds.
   727    if (auto loop = getForInductionVarOwner(id)) {
   728      addDimId(getNumDimIds(), id);
   729      if (failed(this->addAffineForOpDomain(loop)))
   730        LLVM_DEBUG(
   731            loop.emitWarning("failed to add domain info to constraint system"));
   732      return;
   733    }
   734    // Add top level symbol.
   735    addSymbolId(getNumSymbolIds(), id);
   736    // Check if the symbol is a constant.
   737    if (auto constOp = dyn_cast_or_null<ConstantIndexOp>(id->getDefiningOp()))
   738      setIdToConstant(*id, constOp.getValue());
   739  }
   740  
   741  LogicalResult FlatAffineConstraints::addAffineForOpDomain(AffineForOp forOp) {
   742    unsigned pos;
   743    // Pre-condition for this method.
   744    if (!findId(*forOp.getInductionVar(), &pos)) {
   745      assert(false && "Value not found");
   746      return failure();
   747    }
   748  
   749    int64_t step = forOp.getStep();
   750    if (step != 1) {
   751      if (!forOp.hasConstantLowerBound())
   752        forOp.emitWarning("domain conservatively approximated");
   753      else {
   754        // Add constraints for the stride.
   755        // (iv - lb) % step = 0 can be written as:
   756        // (iv - lb) - step * q = 0 where q = (iv - lb) / step.
   757        // Add local variable 'q' and add the above equality.
   758        // The first constraint is q = (iv - lb) floordiv step
   759        SmallVector<int64_t, 8> dividend(getNumCols(), 0);
   760        int64_t lb = forOp.getConstantLowerBound();
   761        dividend[pos] = 1;
   762        dividend.back() -= lb;
   763        addLocalFloorDiv(dividend, step);
   764        // Second constraint: (iv - lb) - step * q = 0.
   765        SmallVector<int64_t, 8> eq(getNumCols(), 0);
   766        eq[pos] = 1;
   767        eq.back() -= lb;
   768        // For the local var just added above.
   769        eq[getNumCols() - 2] = -step;
   770        addEquality(eq);
   771      }
   772    }
   773  
   774    if (forOp.hasConstantLowerBound()) {
   775      addConstantLowerBound(pos, forOp.getConstantLowerBound());
   776    } else {
   777      // Non-constant lower bound case.
   778      SmallVector<Value *, 4> lbOperands(forOp.getLowerBoundOperands().begin(),
   779                                         forOp.getLowerBoundOperands().end());
   780      if (failed(addLowerOrUpperBound(pos, forOp.getLowerBoundMap(), lbOperands,
   781                                      /*eq=*/false, /*lower=*/true)))
   782        return failure();
   783    }
   784  
   785    if (forOp.hasConstantUpperBound()) {
   786      addConstantUpperBound(pos, forOp.getConstantUpperBound() - 1);
   787      return success();
   788    }
   789    // Non-constant upper bound case.
   790    SmallVector<Value *, 4> ubOperands(forOp.getUpperBoundOperands().begin(),
   791                                       forOp.getUpperBoundOperands().end());
   792    return addLowerOrUpperBound(pos, forOp.getUpperBoundMap(), ubOperands,
   793                                /*eq=*/false, /*lower=*/false);
   794  }
   795  
   796  // Searches for a constraint with a non-zero coefficient at 'colIdx' in
   797  // equality (isEq=true) or inequality (isEq=false) constraints.
   798  // Returns true and sets row found in search in 'rowIdx'.
   799  // Returns false otherwise.
   800  static bool
   801  findConstraintWithNonZeroAt(const FlatAffineConstraints &constraints,
   802                              unsigned colIdx, bool isEq, unsigned *rowIdx) {
   803    auto at = [&](unsigned rowIdx) -> int64_t {
   804      return isEq ? constraints.atEq(rowIdx, colIdx)
   805                  : constraints.atIneq(rowIdx, colIdx);
   806    };
   807    unsigned e =
   808        isEq ? constraints.getNumEqualities() : constraints.getNumInequalities();
   809    for (*rowIdx = 0; *rowIdx < e; ++(*rowIdx)) {
   810      if (at(*rowIdx) != 0) {
   811        return true;
   812      }
   813    }
   814    return false;
   815  }
   816  
   817  // Normalizes the coefficient values across all columns in 'rowIDx' by their
   818  // GCD in equality or inequality contraints as specified by 'isEq'.
   819  template <bool isEq>
   820  static void normalizeConstraintByGCD(FlatAffineConstraints *constraints,
   821                                       unsigned rowIdx) {
   822    auto at = [&](unsigned colIdx) -> int64_t {
   823      return isEq ? constraints->atEq(rowIdx, colIdx)
   824                  : constraints->atIneq(rowIdx, colIdx);
   825    };
   826    uint64_t gcd = std::abs(at(0));
   827    for (unsigned j = 1, e = constraints->getNumCols(); j < e; ++j) {
   828      gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(at(j)));
   829    }
   830    if (gcd > 0 && gcd != 1) {
   831      for (unsigned j = 0, e = constraints->getNumCols(); j < e; ++j) {
   832        int64_t v = at(j) / static_cast<int64_t>(gcd);
   833        isEq ? constraints->atEq(rowIdx, j) = v
   834             : constraints->atIneq(rowIdx, j) = v;
   835      }
   836    }
   837  }
   838  
   839  void FlatAffineConstraints::normalizeConstraintsByGCD() {
   840    for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
   841      normalizeConstraintByGCD</*isEq=*/true>(this, i);
   842    }
   843    for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
   844      normalizeConstraintByGCD</*isEq=*/false>(this, i);
   845    }
   846  }
   847  
   848  bool FlatAffineConstraints::hasConsistentState() const {
   849    if (inequalities.size() != getNumInequalities() * numReservedCols)
   850      return false;
   851    if (equalities.size() != getNumEqualities() * numReservedCols)
   852      return false;
   853    if (ids.size() != getNumIds())
   854      return false;
   855  
   856    // Catches errors where numDims, numSymbols, numIds aren't consistent.
   857    if (numDims > numIds || numSymbols > numIds || numDims + numSymbols > numIds)
   858      return false;
   859  
   860    return true;
   861  }
   862  
   863  /// Checks all rows of equality/inequality constraints for trivial
   864  /// contradictions (for example: 1 == 0, 0 >= 1), which may have surfaced
   865  /// after elimination. Returns 'true' if an invalid constraint is found;
   866  /// 'false' otherwise.
   867  bool FlatAffineConstraints::hasInvalidConstraint() const {
   868    assert(hasConsistentState());
   869    auto check = [&](bool isEq) -> bool {
   870      unsigned numCols = getNumCols();
   871      unsigned numRows = isEq ? getNumEqualities() : getNumInequalities();
   872      for (unsigned i = 0, e = numRows; i < e; ++i) {
   873        unsigned j;
   874        for (j = 0; j < numCols - 1; ++j) {
   875          int64_t v = isEq ? atEq(i, j) : atIneq(i, j);
   876          // Skip rows with non-zero variable coefficients.
   877          if (v != 0)
   878            break;
   879        }
   880        if (j < numCols - 1) {
   881          continue;
   882        }
   883        // Check validity of constant term at 'numCols - 1' w.r.t 'isEq'.
   884        // Example invalid constraints include: '1 == 0' or '-1 >= 0'
   885        int64_t v = isEq ? atEq(i, numCols - 1) : atIneq(i, numCols - 1);
   886        if ((isEq && v != 0) || (!isEq && v < 0)) {
   887          return true;
   888        }
   889      }
   890      return false;
   891    };
   892    if (check(/*isEq=*/true))
   893      return true;
   894    return check(/*isEq=*/false);
   895  }
   896  
   897  // Eliminate identifier from constraint at 'rowIdx' based on coefficient at
   898  // pivotRow, pivotCol. Columns in range [elimColStart, pivotCol) will not be
   899  // updated as they have already been eliminated.
   900  static void eliminateFromConstraint(FlatAffineConstraints *constraints,
   901                                      unsigned rowIdx, unsigned pivotRow,
   902                                      unsigned pivotCol, unsigned elimColStart,
   903                                      bool isEq) {
   904    // Skip if equality 'rowIdx' if same as 'pivotRow'.
   905    if (isEq && rowIdx == pivotRow)
   906      return;
   907    auto at = [&](unsigned i, unsigned j) -> int64_t {
   908      return isEq ? constraints->atEq(i, j) : constraints->atIneq(i, j);
   909    };
   910    int64_t leadCoeff = at(rowIdx, pivotCol);
   911    // Skip if leading coefficient at 'rowIdx' is already zero.
   912    if (leadCoeff == 0)
   913      return;
   914    int64_t pivotCoeff = constraints->atEq(pivotRow, pivotCol);
   915    int64_t sign = (leadCoeff * pivotCoeff > 0) ? -1 : 1;
   916    int64_t lcm = mlir::lcm(pivotCoeff, leadCoeff);
   917    int64_t pivotMultiplier = sign * (lcm / std::abs(pivotCoeff));
   918    int64_t rowMultiplier = lcm / std::abs(leadCoeff);
   919  
   920    unsigned numCols = constraints->getNumCols();
   921    for (unsigned j = 0; j < numCols; ++j) {
   922      // Skip updating column 'j' if it was just eliminated.
   923      if (j >= elimColStart && j < pivotCol)
   924        continue;
   925      int64_t v = pivotMultiplier * constraints->atEq(pivotRow, j) +
   926                  rowMultiplier * at(rowIdx, j);
   927      isEq ? constraints->atEq(rowIdx, j) = v
   928           : constraints->atIneq(rowIdx, j) = v;
   929    }
   930  }
   931  
   932  // Remove coefficients in column range [colStart, colLimit) in place.
   933  // This removes in data in the specified column range, and copies any
   934  // remaining valid data into place.
   935  static void shiftColumnsToLeft(FlatAffineConstraints *constraints,
   936                                 unsigned colStart, unsigned colLimit,
   937                                 bool isEq) {
   938    assert(colLimit <= constraints->getNumIds());
   939    if (colLimit <= colStart)
   940      return;
   941  
   942    unsigned numCols = constraints->getNumCols();
   943    unsigned numRows = isEq ? constraints->getNumEqualities()
   944                            : constraints->getNumInequalities();
   945    unsigned numToEliminate = colLimit - colStart;
   946    for (unsigned r = 0, e = numRows; r < e; ++r) {
   947      for (unsigned c = colLimit; c < numCols; ++c) {
   948        if (isEq) {
   949          constraints->atEq(r, c - numToEliminate) = constraints->atEq(r, c);
   950        } else {
   951          constraints->atIneq(r, c - numToEliminate) = constraints->atIneq(r, c);
   952        }
   953      }
   954    }
   955  }
   956  
   957  // Removes identifiers in column range [idStart, idLimit), and copies any
   958  // remaining valid data into place, and updates member variables.
   959  void FlatAffineConstraints::removeIdRange(unsigned idStart, unsigned idLimit) {
   960    assert(idLimit < getNumCols() && "invalid id limit");
   961  
   962    if (idStart >= idLimit)
   963      return;
   964  
   965    // We are going to be removing one or more identifiers from the range.
   966    assert(idStart < numIds && "invalid idStart position");
   967  
   968    // TODO(andydavis) Make 'removeIdRange' a lambda called from here.
   969    // Remove eliminated identifiers from equalities.
   970    shiftColumnsToLeft(this, idStart, idLimit, /*isEq=*/true);
   971  
   972    // Remove eliminated identifiers from inequalities.
   973    shiftColumnsToLeft(this, idStart, idLimit, /*isEq=*/false);
   974  
   975    // Update members numDims, numSymbols and numIds.
   976    unsigned numDimsEliminated = 0;
   977    unsigned numLocalsEliminated = 0;
   978    unsigned numColsEliminated = idLimit - idStart;
   979    if (idStart < numDims) {
   980      numDimsEliminated = std::min(numDims, idLimit) - idStart;
   981    }
   982    // Check how many local id's were removed. Note that our identifier order is
   983    // [dims, symbols, locals]. Local id start at position numDims + numSymbols.
   984    if (idLimit > numDims + numSymbols) {
   985      numLocalsEliminated = std::min(
   986          idLimit - std::max(idStart, numDims + numSymbols), getNumLocalIds());
   987    }
   988    unsigned numSymbolsEliminated =
   989        numColsEliminated - numDimsEliminated - numLocalsEliminated;
   990  
   991    numDims -= numDimsEliminated;
   992    numSymbols -= numSymbolsEliminated;
   993    numIds = numIds - numColsEliminated;
   994  
   995    ids.erase(ids.begin() + idStart, ids.begin() + idLimit);
   996  
   997    // No resize necessary. numReservedCols remains the same.
   998  }
   999  
  1000  /// Returns the position of the identifier that has the minimum <number of lower
  1001  /// bounds> times <number of upper bounds> from the specified range of
  1002  /// identifiers [start, end). It is often best to eliminate in the increasing
  1003  /// order of these counts when doing Fourier-Motzkin elimination since FM adds
  1004  /// that many new constraints.
  1005  static unsigned getBestIdToEliminate(const FlatAffineConstraints &cst,
  1006                                       unsigned start, unsigned end) {
  1007    assert(start < cst.getNumIds() && end < cst.getNumIds() + 1);
  1008  
  1009    auto getProductOfNumLowerUpperBounds = [&](unsigned pos) {
  1010      unsigned numLb = 0;
  1011      unsigned numUb = 0;
  1012      for (unsigned r = 0, e = cst.getNumInequalities(); r < e; r++) {
  1013        if (cst.atIneq(r, pos) > 0) {
  1014          ++numLb;
  1015        } else if (cst.atIneq(r, pos) < 0) {
  1016          ++numUb;
  1017        }
  1018      }
  1019      return numLb * numUb;
  1020    };
  1021  
  1022    unsigned minLoc = start;
  1023    unsigned min = getProductOfNumLowerUpperBounds(start);
  1024    for (unsigned c = start + 1; c < end; c++) {
  1025      unsigned numLbUbProduct = getProductOfNumLowerUpperBounds(c);
  1026      if (numLbUbProduct < min) {
  1027        min = numLbUbProduct;
  1028        minLoc = c;
  1029      }
  1030    }
  1031    return minLoc;
  1032  }
  1033  
  1034  // Checks for emptiness of the set by eliminating identifiers successively and
  1035  // using the GCD test (on all equality constraints) and checking for trivially
  1036  // invalid constraints. Returns 'true' if the constraint system is found to be
  1037  // empty; false otherwise.
  1038  bool FlatAffineConstraints::isEmpty() const {
  1039    if (isEmptyByGCDTest() || hasInvalidConstraint())
  1040      return true;
  1041  
  1042    // First, eliminate as many identifiers as possible using Gaussian
  1043    // elimination.
  1044    FlatAffineConstraints tmpCst(*this);
  1045    unsigned currentPos = 0;
  1046    while (currentPos < tmpCst.getNumIds()) {
  1047      tmpCst.gaussianEliminateIds(currentPos, tmpCst.getNumIds());
  1048      ++currentPos;
  1049      // We check emptiness through trivial checks after eliminating each ID to
  1050      // detect emptiness early. Since the checks isEmptyByGCDTest() and
  1051      // hasInvalidConstraint() are linear time and single sweep on the constraint
  1052      // buffer, this appears reasonable - but can optimize in the future.
  1053      if (tmpCst.hasInvalidConstraint() || tmpCst.isEmptyByGCDTest())
  1054        return true;
  1055    }
  1056  
  1057    // Eliminate the remaining using FM.
  1058    for (unsigned i = 0, e = tmpCst.getNumIds(); i < e; i++) {
  1059      tmpCst.FourierMotzkinEliminate(
  1060          getBestIdToEliminate(tmpCst, 0, tmpCst.getNumIds()));
  1061      // Check for a constraint explosion. This rarely happens in practice, but
  1062      // this check exists as a safeguard against improperly constructed
  1063      // constraint systems or artifically created arbitrarily complex systems
  1064      // that aren't the intended use case for FlatAffineConstraints. This is
  1065      // needed since FM has a worst case exponential complexity in theory.
  1066      if (tmpCst.getNumConstraints() >= kExplosionFactor * getNumIds()) {
  1067        LLVM_DEBUG(llvm::dbgs() << "FM constraint explosion detected\n");
  1068        return false;
  1069      }
  1070  
  1071      // FM wouldn't have modified the equalities in any way. So no need to again
  1072      // run GCD test. Check for trivial invalid constraints.
  1073      if (tmpCst.hasInvalidConstraint())
  1074        return true;
  1075    }
  1076    return false;
  1077  }
  1078  
  1079  // Runs the GCD test on all equality constraints. Returns 'true' if this test
  1080  // fails on any equality. Returns 'false' otherwise.
  1081  // This test can be used to disprove the existence of a solution. If it returns
  1082  // true, no integer solution to the equality constraints can exist.
  1083  //
  1084  // GCD test definition:
  1085  //
  1086  // The equality constraint:
  1087  //
  1088  //  c_1*x_1 + c_2*x_2 + ... + c_n*x_n = c_0
  1089  //
  1090  // has an integer solution iff:
  1091  //
  1092  //  GCD of c_1, c_2, ..., c_n divides c_0.
  1093  //
  1094  bool FlatAffineConstraints::isEmptyByGCDTest() const {
  1095    assert(hasConsistentState());
  1096    unsigned numCols = getNumCols();
  1097    for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
  1098      uint64_t gcd = std::abs(atEq(i, 0));
  1099      for (unsigned j = 1; j < numCols - 1; ++j) {
  1100        gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(atEq(i, j)));
  1101      }
  1102      int64_t v = std::abs(atEq(i, numCols - 1));
  1103      if (gcd > 0 && (v % gcd != 0)) {
  1104        return true;
  1105      }
  1106    }
  1107    return false;
  1108  }
  1109  
  1110  /// Tightens inequalities given that we are dealing with integer spaces. This is
  1111  /// analogous to the GCD test but applied to inequalities. The constant term can
  1112  /// be reduced to the preceding multiple of the GCD of the coefficients, i.e.,
  1113  ///  64*i - 100 >= 0  =>  64*i - 128 >= 0 (since 'i' is an integer). This is a
  1114  /// fast method - linear in the number of coefficients.
  1115  // Example on how this affects practical cases: consider the scenario:
  1116  // 64*i >= 100, j = 64*i; without a tightening, elimination of i would yield
  1117  // j >= 100 instead of the tighter (exact) j >= 128.
  1118  void FlatAffineConstraints::GCDTightenInequalities() {
  1119    unsigned numCols = getNumCols();
  1120    for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
  1121      uint64_t gcd = std::abs(atIneq(i, 0));
  1122      for (unsigned j = 1; j < numCols - 1; ++j) {
  1123        gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(atIneq(i, j)));
  1124      }
  1125      if (gcd > 0 && gcd != 1) {
  1126        int64_t gcdI = static_cast<int64_t>(gcd);
  1127        // Tighten the constant term and normalize the constraint by the GCD.
  1128        atIneq(i, numCols - 1) = mlir::floorDiv(atIneq(i, numCols - 1), gcdI);
  1129        for (unsigned j = 0, e = numCols - 1; j < e; ++j)
  1130          atIneq(i, j) /= gcdI;
  1131      }
  1132    }
  1133  }
  1134  
  1135  // Eliminates all identifer variables in column range [posStart, posLimit).
  1136  // Returns the number of variables eliminated.
  1137  unsigned FlatAffineConstraints::gaussianEliminateIds(unsigned posStart,
  1138                                                       unsigned posLimit) {
  1139    // Return if identifier positions to eliminate are out of range.
  1140    assert(posLimit <= numIds);
  1141    assert(hasConsistentState());
  1142  
  1143    if (posStart >= posLimit)
  1144      return 0;
  1145  
  1146    GCDTightenInequalities();
  1147  
  1148    unsigned pivotCol = 0;
  1149    for (pivotCol = posStart; pivotCol < posLimit; ++pivotCol) {
  1150      // Find a row which has a non-zero coefficient in column 'j'.
  1151      unsigned pivotRow;
  1152      if (!findConstraintWithNonZeroAt(*this, pivotCol, /*isEq=*/true,
  1153                                       &pivotRow)) {
  1154        // No pivot row in equalities with non-zero at 'pivotCol'.
  1155        if (!findConstraintWithNonZeroAt(*this, pivotCol, /*isEq=*/false,
  1156                                         &pivotRow)) {
  1157          // If inequalities are also non-zero in 'pivotCol', it can be
  1158          // eliminated.
  1159          continue;
  1160        }
  1161        break;
  1162      }
  1163  
  1164      // Eliminate identifier at 'pivotCol' from each equality row.
  1165      for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
  1166        eliminateFromConstraint(this, i, pivotRow, pivotCol, posStart,
  1167                                /*isEq=*/true);
  1168        normalizeConstraintByGCD</*isEq=*/true>(this, i);
  1169      }
  1170  
  1171      // Eliminate identifier at 'pivotCol' from each inequality row.
  1172      for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
  1173        eliminateFromConstraint(this, i, pivotRow, pivotCol, posStart,
  1174                                /*isEq=*/false);
  1175        normalizeConstraintByGCD</*isEq=*/false>(this, i);
  1176      }
  1177      removeEquality(pivotRow);
  1178      GCDTightenInequalities();
  1179    }
  1180    // Update position limit based on number eliminated.
  1181    posLimit = pivotCol;
  1182    // Remove eliminated columns from all constraints.
  1183    removeIdRange(posStart, posLimit);
  1184    return posLimit - posStart;
  1185  }
  1186  
  1187  // Detect the identifier at 'pos' (say id_r) as modulo of another identifier
  1188  // (say id_n) w.r.t a constant. When this happens, another identifier (say id_q)
  1189  // could be detected as the floordiv of n. For eg:
  1190  // id_n - 4*id_q - id_r = 0, 0 <= id_r <= 3    <=>
  1191  //                          id_r = id_n mod 4, id_q = id_n floordiv 4.
  1192  // lbConst and ubConst are the constant lower and upper bounds for 'pos' -
  1193  // pre-detected at the caller.
  1194  static bool detectAsMod(const FlatAffineConstraints &cst, unsigned pos,
  1195                          int64_t lbConst, int64_t ubConst,
  1196                          SmallVectorImpl<AffineExpr> *memo) {
  1197    assert(pos < cst.getNumIds() && "invalid position");
  1198  
  1199    // Check if 0 <= id_r <= divisor - 1 and if id_r is equal to
  1200    // id_n - divisor * id_q. If these are true, then id_n becomes the dividend
  1201    // and id_q the quotient when dividing id_n by the divisor.
  1202  
  1203    if (lbConst != 0 || ubConst < 1)
  1204      return false;
  1205  
  1206    int64_t divisor = ubConst + 1;
  1207  
  1208    // Now check for: id_r =  id_n - divisor * id_q. As an example, we
  1209    // are looking r = d - 4q, i.e., either r - d + 4q = 0 or -r + d - 4q = 0.
  1210    unsigned seenQuotient = 0, seenDividend = 0;
  1211    int quotientPos = -1, dividendPos = -1;
  1212    for (unsigned r = 0, e = cst.getNumEqualities(); r < e; r++) {
  1213      // id_n should have coeff 1 or -1.
  1214      if (std::abs(cst.atEq(r, pos)) != 1)
  1215        continue;
  1216      // constant term should be 0.
  1217      if (cst.atEq(r, cst.getNumCols() - 1) != 0)
  1218        continue;
  1219      unsigned c, f;
  1220      int quotientSign = 1, dividendSign = 1;
  1221      for (c = 0, f = cst.getNumDimAndSymbolIds(); c < f; c++) {
  1222        if (c == pos)
  1223          continue;
  1224        // The coefficient of the quotient should be +/-divisor.
  1225        // TODO(bondhugula): could be extended to detect an affine function for
  1226        // the quotient (i.e., the coeff could be a non-zero multiple of divisor).
  1227        int64_t v = cst.atEq(r, c) * cst.atEq(r, pos);
  1228        if (v == divisor || v == -divisor) {
  1229          seenQuotient++;
  1230          quotientPos = c;
  1231          quotientSign = v > 0 ? 1 : -1;
  1232        }
  1233        // The coefficient of the dividend should be +/-1.
  1234        // TODO(bondhugula): could be extended to detect an affine function of
  1235        // the other identifiers as the dividend.
  1236        else if (v == -1 || v == 1) {
  1237          seenDividend++;
  1238          dividendPos = c;
  1239          dividendSign = v < 0 ? 1 : -1;
  1240        } else if (cst.atEq(r, c) != 0) {
  1241          // Cannot be inferred as a mod since the constraint has a coefficient
  1242          // for an identifier that's neither a unit nor the divisor (see TODOs
  1243          // above).
  1244          break;
  1245        }
  1246      }
  1247      if (c < f)
  1248        // Cannot be inferred as a mod since the constraint has a coefficient for
  1249        // an identifier that's neither a unit nor the divisor (see TODOs above).
  1250        continue;
  1251  
  1252      // We are looking for exactly one identifier as the dividend.
  1253      if (seenDividend == 1 && seenQuotient >= 1) {
  1254        if (!(*memo)[dividendPos])
  1255          return false;
  1256        // Successfully detected a mod.
  1257        (*memo)[pos] = (*memo)[dividendPos] % divisor * dividendSign;
  1258        auto ub = cst.getConstantUpperBound(dividendPos);
  1259        if (ub.hasValue() && ub.getValue() < divisor)
  1260          // The mod can be optimized away.
  1261          (*memo)[pos] = (*memo)[dividendPos] * dividendSign;
  1262        else
  1263          (*memo)[pos] = (*memo)[dividendPos] % divisor * dividendSign;
  1264  
  1265        if (seenQuotient == 1 && !(*memo)[quotientPos])
  1266          // Successfully detected a floordiv as well.
  1267          (*memo)[quotientPos] =
  1268              (*memo)[dividendPos].floorDiv(divisor) * quotientSign;
  1269        return true;
  1270      }
  1271    }
  1272    return false;
  1273  }
  1274  
  1275  // Gather lower and upper bounds for the pos^th identifier.
  1276  static void getLowerAndUpperBoundIndices(const FlatAffineConstraints &cst,
  1277                                           unsigned pos,
  1278                                           SmallVectorImpl<unsigned> *lbIndices,
  1279                                           SmallVectorImpl<unsigned> *ubIndices) {
  1280    assert(pos < cst.getNumIds() && "invalid position");
  1281  
  1282    // Gather all lower bounds and upper bounds of the variable. Since the
  1283    // canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower
  1284    // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1.
  1285    for (unsigned r = 0, e = cst.getNumInequalities(); r < e; r++) {
  1286      if (cst.atIneq(r, pos) >= 1) {
  1287        // Lower bound.
  1288        lbIndices->push_back(r);
  1289      } else if (cst.atIneq(r, pos) <= -1) {
  1290        // Upper bound.
  1291        ubIndices->push_back(r);
  1292      }
  1293    }
  1294  }
  1295  
  1296  // Check if the pos^th identifier can be expressed as a floordiv of an affine
  1297  // function of other identifiers (where the divisor is a positive constant).
  1298  // For eg: 4q <= i + j <= 4q + 3   <=>   q = (i + j) floordiv 4.
  1299  bool detectAsFloorDiv(const FlatAffineConstraints &cst, unsigned pos,
  1300                        SmallVectorImpl<AffineExpr> *memo, MLIRContext *context) {
  1301    assert(pos < cst.getNumIds() && "invalid position");
  1302  
  1303    SmallVector<unsigned, 4> lbIndices, ubIndices;
  1304    getLowerAndUpperBoundIndices(cst, pos, &lbIndices, &ubIndices);
  1305  
  1306    // Check if any lower bound, upper bound pair is of the form:
  1307    // divisor * id >=  expr - (divisor - 1)    <-- Lower bound for 'id'
  1308    // divisor * id <=  expr                    <-- Upper bound for 'id'
  1309    // Then, 'id' is equivalent to 'expr floordiv divisor'.  (where divisor > 1).
  1310    //
  1311    // For example, if -32*k + 16*i + j >= 0
  1312    //                  32*k - 16*i - j + 31 >= 0   <=>
  1313    //             k = ( 16*i + j ) floordiv 32
  1314    unsigned seenDividends = 0;
  1315    for (auto ubPos : ubIndices) {
  1316      for (auto lbPos : lbIndices) {
  1317        // Check if lower bound's constant term is 'divisor - 1'. The 'divisor'
  1318        // here is cst.atIneq(lbPos, pos) and we already know that it's positive
  1319        // (since cst.Ineq(lbPos, ...) is a lower bound expression for 'pos'.
  1320        if (cst.atIneq(lbPos, cst.getNumCols() - 1) != cst.atIneq(lbPos, pos) - 1)
  1321          continue;
  1322        // Check if upper bound's constant term is 0.
  1323        if (cst.atIneq(ubPos, cst.getNumCols() - 1) != 0)
  1324          continue;
  1325        // For the remaining part, check if the lower bound expr's coeff's are
  1326        // negations of corresponding upper bound ones'.
  1327        unsigned c, f;
  1328        for (c = 0, f = cst.getNumCols() - 1; c < f; c++) {
  1329          if (cst.atIneq(lbPos, c) != -cst.atIneq(ubPos, c))
  1330            break;
  1331          if (c != pos && cst.atIneq(lbPos, c) != 0)
  1332            seenDividends++;
  1333        }
  1334        // Lb coeff's aren't negative of ub coeff's (for the non constant term
  1335        // part).
  1336        if (c < f)
  1337          continue;
  1338        if (seenDividends >= 1) {
  1339          // The divisor is the constant term of the lower bound expression.
  1340          // We already know that cst.atIneq(lbPos, pos) > 0.
  1341          int64_t divisor = cst.atIneq(lbPos, pos);
  1342          // Construct the dividend expression.
  1343          auto dividendExpr = getAffineConstantExpr(0, context);
  1344          unsigned c, f;
  1345          for (c = 0, f = cst.getNumCols() - 1; c < f; c++) {
  1346            if (c == pos)
  1347              continue;
  1348            int64_t ubVal = cst.atIneq(ubPos, c);
  1349            if (ubVal == 0)
  1350              continue;
  1351            if (!(*memo)[c])
  1352              break;
  1353            dividendExpr = dividendExpr + ubVal * (*memo)[c];
  1354          }
  1355          // Expression can't be constructed as it depends on a yet unknown
  1356          // identifier.
  1357          // TODO(mlir-team): Visit/compute the identifiers in an order so that
  1358          // this doesn't happen. More complex but much more efficient.
  1359          if (c < f)
  1360            continue;
  1361          // Successfully detected the floordiv.
  1362          (*memo)[pos] = dividendExpr.floorDiv(divisor);
  1363          return true;
  1364        }
  1365      }
  1366    }
  1367    return false;
  1368  }
  1369  
  1370  // Fills an inequality row with the value 'val'.
  1371  static inline void fillInequality(FlatAffineConstraints *cst, unsigned r,
  1372                                    int64_t val) {
  1373    for (unsigned c = 0, f = cst->getNumCols(); c < f; c++) {
  1374      cst->atIneq(r, c) = val;
  1375    }
  1376  }
  1377  
  1378  // Negates an inequality.
  1379  static inline void negateInequality(FlatAffineConstraints *cst, unsigned r) {
  1380    for (unsigned c = 0, f = cst->getNumCols(); c < f; c++) {
  1381      cst->atIneq(r, c) = -cst->atIneq(r, c);
  1382    }
  1383  }
  1384  
  1385  // A more complex check to eliminate redundant inequalities. Uses FourierMotzkin
  1386  // to check if a constraint is redundant.
  1387  void FlatAffineConstraints::removeRedundantInequalities() {
  1388    SmallVector<bool, 32> redun(getNumInequalities(), false);
  1389    // To check if an inequality is redundant, we replace the inequality by its
  1390    // complement (for eg., i - 1 >= 0 by i <= 0), and check if the resulting
  1391    // system is empty. If it is, the inequality is redundant.
  1392    FlatAffineConstraints tmpCst(*this);
  1393    for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
  1394      // Change the inequality to its complement.
  1395      negateInequality(&tmpCst, r);
  1396      tmpCst.atIneq(r, tmpCst.getNumCols() - 1)--;
  1397      if (tmpCst.isEmpty()) {
  1398        redun[r] = true;
  1399        // Zero fill the redundant inequality.
  1400        fillInequality(this, r, /*val=*/0);
  1401        fillInequality(&tmpCst, r, /*val=*/0);
  1402      } else {
  1403        // Reverse the change (to avoid recreating tmpCst each time).
  1404        tmpCst.atIneq(r, tmpCst.getNumCols() - 1)++;
  1405        negateInequality(&tmpCst, r);
  1406      }
  1407    }
  1408  
  1409    // Scan to get rid of all rows marked redundant, in-place.
  1410    auto copyRow = [&](unsigned src, unsigned dest) {
  1411      if (src == dest)
  1412        return;
  1413      for (unsigned c = 0, e = getNumCols(); c < e; c++) {
  1414        atIneq(dest, c) = atIneq(src, c);
  1415      }
  1416    };
  1417    unsigned pos = 0;
  1418    for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
  1419      if (!redun[r])
  1420        copyRow(r, pos++);
  1421    }
  1422    inequalities.resize(numReservedCols * pos);
  1423  }
  1424  
  1425  std::pair<AffineMap, AffineMap> FlatAffineConstraints::getLowerAndUpperBound(
  1426      unsigned pos, unsigned offset, unsigned num, unsigned symStartPos,
  1427      ArrayRef<AffineExpr> localExprs, MLIRContext *context) {
  1428    assert(pos + offset < getNumDimIds() && "invalid dim start pos");
  1429    assert(symStartPos >= (pos + offset) && "invalid sym start pos");
  1430    assert(getNumLocalIds() == localExprs.size() &&
  1431           "incorrect local exprs count");
  1432  
  1433    SmallVector<unsigned, 4> lbIndices, ubIndices;
  1434    getLowerAndUpperBoundIndices(*this, pos + offset, &lbIndices, &ubIndices);
  1435  
  1436    /// Add to 'b' from 'a' in set [0, offset) U [offset + num, symbStartPos).
  1437    auto addCoeffs = [&](ArrayRef<int64_t> a, SmallVectorImpl<int64_t> &b) {
  1438      b.clear();
  1439      for (unsigned i = 0, e = a.size(); i < e; ++i) {
  1440        if (i < offset || i >= offset + num)
  1441          b.push_back(a[i]);
  1442      }
  1443    };
  1444  
  1445    SmallVector<int64_t, 8> lb, ub;
  1446    SmallVector<AffineExpr, 4> exprs;
  1447    unsigned dimCount = symStartPos - num;
  1448    unsigned symCount = getNumDimAndSymbolIds() - symStartPos;
  1449    exprs.reserve(lbIndices.size());
  1450    // Lower bound expressions.
  1451    for (auto idx : lbIndices) {
  1452      auto ineq = getInequality(idx);
  1453      // Extract the lower bound (in terms of other coeff's + const), i.e., if
  1454      // i - j + 1 >= 0 is the constraint, 'pos' is for i the lower bound is j
  1455      // - 1.
  1456      addCoeffs(ineq, lb);
  1457      std::transform(lb.begin(), lb.end(), lb.begin(), std::negate<int64_t>());
  1458      auto expr = mlir::toAffineExpr(lb, dimCount, symCount, localExprs, context);
  1459      exprs.push_back(expr);
  1460    }
  1461    auto lbMap =
  1462        exprs.empty() ? AffineMap() : AffineMap::get(dimCount, symCount, exprs);
  1463  
  1464    exprs.clear();
  1465    exprs.reserve(ubIndices.size());
  1466    // Upper bound expressions.
  1467    for (auto idx : ubIndices) {
  1468      auto ineq = getInequality(idx);
  1469      // Extract the upper bound (in terms of other coeff's + const).
  1470      addCoeffs(ineq, ub);
  1471      auto expr = mlir::toAffineExpr(ub, dimCount, symCount, localExprs, context);
  1472      // Upper bound is exclusive.
  1473      exprs.push_back(expr + 1);
  1474    }
  1475    auto ubMap =
  1476        exprs.empty() ? AffineMap() : AffineMap::get(dimCount, symCount, exprs);
  1477  
  1478    return {lbMap, ubMap};
  1479  }
  1480  
  1481  /// Computes the lower and upper bounds of the first 'num' dimensional
  1482  /// identifiers (starting at 'offset') as affine maps of the remaining
  1483  /// identifiers (dimensional and symbolic identifiers). Local identifiers are
  1484  /// themselves explicitly computed as affine functions of other identifiers in
  1485  /// this process if needed.
  1486  void FlatAffineConstraints::getSliceBounds(unsigned offset, unsigned num,
  1487                                             MLIRContext *context,
  1488                                             SmallVectorImpl<AffineMap> *lbMaps,
  1489                                             SmallVectorImpl<AffineMap> *ubMaps) {
  1490    assert(num < getNumDimIds() && "invalid range");
  1491  
  1492    // Basic simplification.
  1493    normalizeConstraintsByGCD();
  1494  
  1495    LLVM_DEBUG(llvm::dbgs() << "getSliceBounds for first " << num
  1496                            << " identifiers\n");
  1497    LLVM_DEBUG(dump());
  1498  
  1499    // Record computed/detected identifiers.
  1500    SmallVector<AffineExpr, 8> memo(getNumIds());
  1501    // Initialize dimensional and symbolic identifiers.
  1502    for (unsigned i = 0, e = getNumDimIds(); i < e; i++) {
  1503      if (i < offset)
  1504        memo[i] = getAffineDimExpr(i, context);
  1505      else if (i >= offset + num)
  1506        memo[i] = getAffineDimExpr(i - num, context);
  1507    }
  1508    for (unsigned i = getNumDimIds(), e = getNumDimAndSymbolIds(); i < e; i++)
  1509      memo[i] = getAffineSymbolExpr(i - getNumDimIds(), context);
  1510  
  1511    bool changed;
  1512    do {
  1513      changed = false;
  1514      // Identify yet unknown identifiers as constants or mod's / floordiv's of
  1515      // other identifiers if possible.
  1516      for (unsigned pos = 0; pos < getNumIds(); pos++) {
  1517        if (memo[pos])
  1518          continue;
  1519  
  1520        auto lbConst = getConstantLowerBound(pos);
  1521        auto ubConst = getConstantUpperBound(pos);
  1522        if (lbConst.hasValue() && ubConst.hasValue()) {
  1523          // Detect equality to a constant.
  1524          if (lbConst.getValue() == ubConst.getValue()) {
  1525            memo[pos] = getAffineConstantExpr(lbConst.getValue(), context);
  1526            changed = true;
  1527            continue;
  1528          }
  1529  
  1530          // Detect an identifier as modulo of another identifier w.r.t a
  1531          // constant.
  1532          if (detectAsMod(*this, pos, lbConst.getValue(), ubConst.getValue(),
  1533                          &memo)) {
  1534            changed = true;
  1535            continue;
  1536          }
  1537        }
  1538  
  1539        // Detect an identifier as floordiv of another identifier w.r.t a
  1540        // constant.
  1541        if (detectAsFloorDiv(*this, pos, &memo, context)) {
  1542          changed = true;
  1543          continue;
  1544        }
  1545  
  1546        // Detect an identifier as an expression of other identifiers.
  1547        unsigned idx;
  1548        if (!findConstraintWithNonZeroAt(*this, pos, /*isEq=*/true, &idx)) {
  1549          continue;
  1550        }
  1551  
  1552        // Build AffineExpr solving for identifier 'pos' in terms of all others.
  1553        auto expr = getAffineConstantExpr(0, context);
  1554        unsigned j, e;
  1555        for (j = 0, e = getNumIds(); j < e; ++j) {
  1556          if (j == pos)
  1557            continue;
  1558          int64_t c = atEq(idx, j);
  1559          if (c == 0)
  1560            continue;
  1561          // If any of the involved IDs hasn't been found yet, we can't proceed.
  1562          if (!memo[j])
  1563            break;
  1564          expr = expr + memo[j] * c;
  1565        }
  1566        if (j < e)
  1567          // Can't construct expression as it depends on a yet uncomputed
  1568          // identifier.
  1569          continue;
  1570  
  1571        // Add constant term to AffineExpr.
  1572        expr = expr + atEq(idx, getNumIds());
  1573        int64_t vPos = atEq(idx, pos);
  1574        assert(vPos != 0 && "expected non-zero here");
  1575        if (vPos > 0)
  1576          expr = (-expr).floorDiv(vPos);
  1577        else
  1578          // vPos < 0.
  1579          expr = expr.floorDiv(-vPos);
  1580        // Successfully constructed expression.
  1581        memo[pos] = expr;
  1582        changed = true;
  1583      }
  1584      // This loop is guaranteed to reach a fixed point - since once an
  1585      // identifier's explicit form is computed (in memo[pos]), it's not updated
  1586      // again.
  1587    } while (changed);
  1588  
  1589    // Set the lower and upper bound maps for all the identifiers that were
  1590    // computed as affine expressions of the rest as the "detected expr" and
  1591    // "detected expr + 1" respectively; set the undetected ones to null.
  1592    Optional<FlatAffineConstraints> tmpClone;
  1593    for (unsigned pos = 0; pos < num; pos++) {
  1594      unsigned numMapDims = getNumDimIds() - num;
  1595      unsigned numMapSymbols = getNumSymbolIds();
  1596      AffineExpr expr = memo[pos + offset];
  1597      if (expr)
  1598        expr = simplifyAffineExpr(expr, numMapDims, numMapSymbols);
  1599  
  1600      AffineMap &lbMap = (*lbMaps)[pos];
  1601      AffineMap &ubMap = (*ubMaps)[pos];
  1602  
  1603      if (expr) {
  1604        lbMap = AffineMap::get(numMapDims, numMapSymbols, expr);
  1605        ubMap = AffineMap::get(numMapDims, numMapSymbols, expr + 1);
  1606      } else {
  1607        // TODO(bondhugula): Whenever there are local identifiers in the
  1608        // dependence constraints, we'll conservatively over-approximate, since we
  1609        // don't always explicitly compute them above (in the while loop).
  1610        if (getNumLocalIds() == 0) {
  1611          // Work on a copy so that we don't update this constraint system.
  1612          if (!tmpClone) {
  1613            tmpClone.emplace(FlatAffineConstraints(*this));
  1614            // Removing redudnant inequalities is necessary so that we don't get
  1615            // redundant loop bounds.
  1616            tmpClone->removeRedundantInequalities();
  1617          }
  1618          std::tie(lbMap, ubMap) = tmpClone->getLowerAndUpperBound(
  1619              pos, offset, num, getNumDimIds(), {}, context);
  1620        }
  1621  
  1622        // If the above fails, we'll just use the constant lower bound and the
  1623        // constant upper bound (if they exist) as the slice bounds.
  1624        // TODO(b/126426796): being conservative for the moment in cases that
  1625        // lead to multiple bounds - until getConstDifference in LoopFusion.cpp is
  1626        // fixed (b/126426796).
  1627        if (!lbMap || lbMap.getNumResults() > 1) {
  1628          LLVM_DEBUG(llvm::dbgs()
  1629                     << "WARNING: Potentially over-approximating slice lb\n");
  1630          auto lbConst = getConstantLowerBound(pos + offset);
  1631          if (lbConst.hasValue()) {
  1632            lbMap = AffineMap::get(
  1633                numMapDims, numMapSymbols,
  1634                getAffineConstantExpr(lbConst.getValue(), context));
  1635          }
  1636        }
  1637        if (!ubMap || ubMap.getNumResults() > 1) {
  1638          LLVM_DEBUG(llvm::dbgs()
  1639                     << "WARNING: Potentially over-approximating slice ub\n");
  1640          auto ubConst = getConstantUpperBound(pos + offset);
  1641          if (ubConst.hasValue()) {
  1642            (ubMap) = AffineMap::get(
  1643                numMapDims, numMapSymbols,
  1644                getAffineConstantExpr(ubConst.getValue() + 1, context));
  1645          }
  1646        }
  1647      }
  1648      LLVM_DEBUG(llvm::dbgs()
  1649                 << "lb map for pos = " << Twine(pos + offset) << ", expr: ");
  1650      LLVM_DEBUG(lbMap.dump(););
  1651      LLVM_DEBUG(llvm::dbgs()
  1652                 << "ub map for pos = " << Twine(pos + offset) << ", expr: ");
  1653      LLVM_DEBUG(ubMap.dump(););
  1654    }
  1655  }
  1656  
  1657  LogicalResult
  1658  FlatAffineConstraints::addLowerOrUpperBound(unsigned pos, AffineMap boundMap,
  1659                                              ArrayRef<Value *> boundOperands,
  1660                                              bool eq, bool lower) {
  1661    assert(pos < getNumDimAndSymbolIds() && "invalid position");
  1662    // Equality follows the logic of lower bound except that we add an equality
  1663    // instead of an inequality.
  1664    assert((!eq || boundMap.getNumResults() == 1) && "single result expected");
  1665    if (eq)
  1666      lower = true;
  1667  
  1668    // Fully commpose map and operands; canonicalize and simplify so that we
  1669    // transitively get to terminal symbols or loop IVs.
  1670    auto map = boundMap;
  1671    SmallVector<Value *, 4> operands(boundOperands.begin(), boundOperands.end());
  1672    fullyComposeAffineMapAndOperands(&map, &operands);
  1673    map = simplifyAffineMap(map);
  1674    canonicalizeMapAndOperands(&map, &operands);
  1675    for (auto *operand : operands)
  1676      addInductionVarOrTerminalSymbol(operand);
  1677  
  1678    FlatAffineConstraints localVarCst;
  1679    std::vector<SmallVector<int64_t, 8>> flatExprs;
  1680    if (failed(getFlattenedAffineExprs(map, &flatExprs, &localVarCst))) {
  1681      LLVM_DEBUG(llvm::dbgs() << "semi-affine expressions not yet supported\n");
  1682      return failure();
  1683    }
  1684  
  1685    // Merge and align with localVarCst.
  1686    if (localVarCst.getNumLocalIds() > 0) {
  1687      // Set values for localVarCst.
  1688      localVarCst.setIdValues(0, localVarCst.getNumDimAndSymbolIds(), operands);
  1689      for (auto *operand : operands) {
  1690        unsigned pos;
  1691        if (findId(*operand, &pos)) {
  1692          if (pos >= getNumDimIds() && pos < getNumDimAndSymbolIds()) {
  1693            // If the local var cst has this as a dim, turn it into its symbol.
  1694            turnDimIntoSymbol(&localVarCst, *operand);
  1695          } else if (pos < getNumDimIds()) {
  1696            // Or vice versa.
  1697            turnSymbolIntoDim(&localVarCst, *operand);
  1698          }
  1699        }
  1700      }
  1701      mergeAndAlignIds(/*offset=*/0, this, &localVarCst);
  1702      append(localVarCst);
  1703    }
  1704  
  1705    // Record positions of the operands in the constraint system. Need to do
  1706    // this here since the constraint system changes after a bound is added.
  1707    SmallVector<unsigned, 8> positions;
  1708    unsigned numOperands = operands.size();
  1709    for (auto *operand : operands) {
  1710      unsigned pos;
  1711      if (!findId(*operand, &pos))
  1712        assert(0 && "expected to be found");
  1713      positions.push_back(pos);
  1714    }
  1715  
  1716    for (const auto &flatExpr : flatExprs) {
  1717      SmallVector<int64_t, 4> ineq(getNumCols(), 0);
  1718      ineq[pos] = lower ? 1 : -1;
  1719      // Dims and symbols.
  1720      for (unsigned j = 0, e = map.getNumInputs(); j < e; j++) {
  1721        ineq[positions[j]] = lower ? -flatExpr[j] : flatExpr[j];
  1722      }
  1723      // Copy over the local id coefficients.
  1724      unsigned numLocalIds = flatExpr.size() - 1 - numOperands;
  1725      for (unsigned jj = 0, j = getNumIds() - numLocalIds; jj < numLocalIds;
  1726           jj++, j++) {
  1727        ineq[j] =
  1728            lower ? -flatExpr[numOperands + jj] : flatExpr[numOperands + jj];
  1729      }
  1730      // Constant term.
  1731      ineq[getNumCols() - 1] =
  1732          lower ? -flatExpr[flatExpr.size() - 1]
  1733                // Upper bound in flattenedExpr is an exclusive one.
  1734                : flatExpr[flatExpr.size() - 1] - 1;
  1735      eq ? addEquality(ineq) : addInequality(ineq);
  1736    }
  1737    return success();
  1738  }
  1739  
  1740  // Adds slice lower bounds represented by lower bounds in 'lbMaps' and upper
  1741  // bounds in 'ubMaps' to each value in `values' that appears in the constraint
  1742  // system. Note that both lower/upper bounds share the same operand list
  1743  // 'operands'.
  1744  // This function assumes 'values.size' == 'lbMaps.size' == 'ubMaps.size', and
  1745  // skips any null AffineMaps in 'lbMaps' or 'ubMaps'.
  1746  // Note that both lower/upper bounds use operands from 'operands'.
  1747  // Returns failure for unimplemented cases such as semi-affine expressions or
  1748  // expressions with mod/floordiv.
  1749  LogicalResult FlatAffineConstraints::addSliceBounds(
  1750      ArrayRef<Value *> values, ArrayRef<AffineMap> lbMaps,
  1751      ArrayRef<AffineMap> ubMaps, ArrayRef<Value *> operands) {
  1752    assert(values.size() == lbMaps.size());
  1753    assert(lbMaps.size() == ubMaps.size());
  1754  
  1755    for (unsigned i = 0, e = lbMaps.size(); i < e; ++i) {
  1756      unsigned pos;
  1757      if (!findId(*values[i], &pos))
  1758        continue;
  1759  
  1760      AffineMap lbMap = lbMaps[i];
  1761      AffineMap ubMap = ubMaps[i];
  1762      assert(!lbMap || lbMap.getNumInputs() == operands.size());
  1763      assert(!ubMap || ubMap.getNumInputs() == operands.size());
  1764  
  1765      // Check if this slice is just an equality along this dimension.
  1766      if (lbMap && ubMap && lbMap.getNumResults() == 1 &&
  1767          ubMap.getNumResults() == 1 &&
  1768          lbMap.getResult(0) + 1 == ubMap.getResult(0)) {
  1769        if (failed(addLowerOrUpperBound(pos, lbMap, operands, /*eq=*/true,
  1770                                        /*lower=*/true)))
  1771          return failure();
  1772        continue;
  1773      }
  1774  
  1775      if (lbMap && failed(addLowerOrUpperBound(pos, lbMap, operands, /*eq=*/false,
  1776                                               /*lower=*/true)))
  1777        return failure();
  1778  
  1779      if (ubMap && failed(addLowerOrUpperBound(pos, ubMap, operands, /*eq=*/false,
  1780                                               /*lower=*/false)))
  1781        return failure();
  1782    }
  1783    return success();
  1784  }
  1785  
  1786  void FlatAffineConstraints::addEquality(ArrayRef<int64_t> eq) {
  1787    assert(eq.size() == getNumCols());
  1788    unsigned offset = equalities.size();
  1789    equalities.resize(equalities.size() + numReservedCols);
  1790    std::copy(eq.begin(), eq.end(), equalities.begin() + offset);
  1791  }
  1792  
  1793  void FlatAffineConstraints::addInequality(ArrayRef<int64_t> inEq) {
  1794    assert(inEq.size() == getNumCols());
  1795    unsigned offset = inequalities.size();
  1796    inequalities.resize(inequalities.size() + numReservedCols);
  1797    std::copy(inEq.begin(), inEq.end(), inequalities.begin() + offset);
  1798  }
  1799  
  1800  void FlatAffineConstraints::addConstantLowerBound(unsigned pos, int64_t lb) {
  1801    assert(pos < getNumCols());
  1802    unsigned offset = inequalities.size();
  1803    inequalities.resize(inequalities.size() + numReservedCols);
  1804    std::fill(inequalities.begin() + offset,
  1805              inequalities.begin() + offset + getNumCols(), 0);
  1806    inequalities[offset + pos] = 1;
  1807    inequalities[offset + getNumCols() - 1] = -lb;
  1808  }
  1809  
  1810  void FlatAffineConstraints::addConstantUpperBound(unsigned pos, int64_t ub) {
  1811    assert(pos < getNumCols());
  1812    unsigned offset = inequalities.size();
  1813    inequalities.resize(inequalities.size() + numReservedCols);
  1814    std::fill(inequalities.begin() + offset,
  1815              inequalities.begin() + offset + getNumCols(), 0);
  1816    inequalities[offset + pos] = -1;
  1817    inequalities[offset + getNumCols() - 1] = ub;
  1818  }
  1819  
  1820  void FlatAffineConstraints::addConstantLowerBound(ArrayRef<int64_t> expr,
  1821                                                    int64_t lb) {
  1822    assert(expr.size() == getNumCols());
  1823    unsigned offset = inequalities.size();
  1824    inequalities.resize(inequalities.size() + numReservedCols);
  1825    std::fill(inequalities.begin() + offset,
  1826              inequalities.begin() + offset + getNumCols(), 0);
  1827    std::copy(expr.begin(), expr.end(), inequalities.begin() + offset);
  1828    inequalities[offset + getNumCols() - 1] += -lb;
  1829  }
  1830  
  1831  void FlatAffineConstraints::addConstantUpperBound(ArrayRef<int64_t> expr,
  1832                                                    int64_t ub) {
  1833    assert(expr.size() == getNumCols());
  1834    unsigned offset = inequalities.size();
  1835    inequalities.resize(inequalities.size() + numReservedCols);
  1836    std::fill(inequalities.begin() + offset,
  1837              inequalities.begin() + offset + getNumCols(), 0);
  1838    for (unsigned i = 0, e = getNumCols(); i < e; i++) {
  1839      inequalities[offset + i] = -expr[i];
  1840    }
  1841    inequalities[offset + getNumCols() - 1] += ub;
  1842  }
  1843  
  1844  /// Adds a new local identifier as the floordiv of an affine function of other
  1845  /// identifiers, the coefficients of which are provided in 'dividend' and with
  1846  /// respect to a positive constant 'divisor'. Two constraints are added to the
  1847  /// system to capture equivalence with the floordiv.
  1848  ///      q = expr floordiv c    <=>   c*q <= expr <= c*q + c - 1.
  1849  void FlatAffineConstraints::addLocalFloorDiv(ArrayRef<int64_t> dividend,
  1850                                               int64_t divisor) {
  1851    assert(dividend.size() == getNumCols() && "incorrect dividend size");
  1852    assert(divisor > 0 && "positive divisor expected");
  1853  
  1854    addLocalId(getNumLocalIds());
  1855  
  1856    // Add two constraints for this new identifier 'q'.
  1857    SmallVector<int64_t, 8> bound(dividend.size() + 1);
  1858  
  1859    // dividend - q * divisor >= 0
  1860    std::copy(dividend.begin(), dividend.begin() + dividend.size() - 1,
  1861              bound.begin());
  1862    bound.back() = dividend.back();
  1863    bound[getNumIds() - 1] = -divisor;
  1864    addInequality(bound);
  1865  
  1866    // -dividend +qdivisor * q + divisor - 1 >= 0
  1867    std::transform(bound.begin(), bound.end(), bound.begin(),
  1868                   std::negate<int64_t>());
  1869    bound[bound.size() - 1] += divisor - 1;
  1870    addInequality(bound);
  1871  }
  1872  
  1873  bool FlatAffineConstraints::findId(Value &id, unsigned *pos) const {
  1874    unsigned i = 0;
  1875    for (const auto &mayBeId : ids) {
  1876      if (mayBeId.hasValue() && mayBeId.getValue() == &id) {
  1877        *pos = i;
  1878        return true;
  1879      }
  1880      i++;
  1881    }
  1882    return false;
  1883  }
  1884  
  1885  bool FlatAffineConstraints::containsId(Value &id) const {
  1886    return llvm::any_of(ids, [&](const Optional<Value *> &mayBeId) {
  1887      return mayBeId.hasValue() && mayBeId.getValue() == &id;
  1888    });
  1889  }
  1890  
  1891  void FlatAffineConstraints::setDimSymbolSeparation(unsigned newSymbolCount) {
  1892    assert(newSymbolCount <= numDims + numSymbols &&
  1893           "invalid separation position");
  1894    numDims = numDims + numSymbols - newSymbolCount;
  1895    numSymbols = newSymbolCount;
  1896  }
  1897  
  1898  /// Sets the specified identifer to a constant value.
  1899  void FlatAffineConstraints::setIdToConstant(unsigned pos, int64_t val) {
  1900    unsigned offset = equalities.size();
  1901    equalities.resize(equalities.size() + numReservedCols);
  1902    std::fill(equalities.begin() + offset,
  1903              equalities.begin() + offset + getNumCols(), 0);
  1904    equalities[offset + pos] = 1;
  1905    equalities[offset + getNumCols() - 1] = -val;
  1906  }
  1907  
  1908  /// Sets the specified identifer to a constant value; asserts if the id is not
  1909  /// found.
  1910  void FlatAffineConstraints::setIdToConstant(Value &id, int64_t val) {
  1911    unsigned pos;
  1912    if (!findId(id, &pos))
  1913      // This is a pre-condition for this method.
  1914      assert(0 && "id not found");
  1915    setIdToConstant(pos, val);
  1916  }
  1917  
  1918  void FlatAffineConstraints::removeEquality(unsigned pos) {
  1919    unsigned numEqualities = getNumEqualities();
  1920    assert(pos < numEqualities);
  1921    unsigned outputIndex = pos * numReservedCols;
  1922    unsigned inputIndex = (pos + 1) * numReservedCols;
  1923    unsigned numElemsToCopy = (numEqualities - pos - 1) * numReservedCols;
  1924    std::copy(equalities.begin() + inputIndex,
  1925              equalities.begin() + inputIndex + numElemsToCopy,
  1926              equalities.begin() + outputIndex);
  1927    equalities.resize(equalities.size() - numReservedCols);
  1928  }
  1929  
  1930  /// Finds an equality that equates the specified identifier to a constant.
  1931  /// Returns the position of the equality row. If 'symbolic' is set to true,
  1932  /// symbols are also treated like a constant, i.e., an affine function of the
  1933  /// symbols is also treated like a constant.
  1934  static int findEqualityToConstant(const FlatAffineConstraints &cst,
  1935                                    unsigned pos, bool symbolic = false) {
  1936    assert(pos < cst.getNumIds() && "invalid position");
  1937    for (unsigned r = 0, e = cst.getNumEqualities(); r < e; r++) {
  1938      int64_t v = cst.atEq(r, pos);
  1939      if (v * v != 1)
  1940        continue;
  1941      unsigned c;
  1942      unsigned f = symbolic ? cst.getNumDimIds() : cst.getNumIds();
  1943      // This checks for zeros in all positions other than 'pos' in [0, f)
  1944      for (c = 0; c < f; c++) {
  1945        if (c == pos)
  1946          continue;
  1947        if (cst.atEq(r, c) != 0) {
  1948          // Dependent on another identifier.
  1949          break;
  1950        }
  1951      }
  1952      if (c == f)
  1953        // Equality is free of other identifiers.
  1954        return r;
  1955    }
  1956    return -1;
  1957  }
  1958  
  1959  void FlatAffineConstraints::setAndEliminate(unsigned pos, int64_t constVal) {
  1960    assert(pos < getNumIds() && "invalid position");
  1961    for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
  1962      atIneq(r, getNumCols() - 1) += atIneq(r, pos) * constVal;
  1963    }
  1964    for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
  1965      atEq(r, getNumCols() - 1) += atEq(r, pos) * constVal;
  1966    }
  1967    removeId(pos);
  1968  }
  1969  
  1970  LogicalResult FlatAffineConstraints::constantFoldId(unsigned pos) {
  1971    assert(pos < getNumIds() && "invalid position");
  1972    int rowIdx;
  1973    if ((rowIdx = findEqualityToConstant(*this, pos)) == -1)
  1974      return failure();
  1975  
  1976    // atEq(rowIdx, pos) is either -1 or 1.
  1977    assert(atEq(rowIdx, pos) * atEq(rowIdx, pos) == 1);
  1978    int64_t constVal = -atEq(rowIdx, getNumCols() - 1) / atEq(rowIdx, pos);
  1979    setAndEliminate(pos, constVal);
  1980    return success();
  1981  }
  1982  
  1983  void FlatAffineConstraints::constantFoldIdRange(unsigned pos, unsigned num) {
  1984    for (unsigned s = pos, t = pos, e = pos + num; s < e; s++) {
  1985      if (failed(constantFoldId(t)))
  1986        t++;
  1987    }
  1988  }
  1989  
  1990  /// Returns the extent (upper bound - lower bound) of the specified
  1991  /// identifier if it is found to be a constant; returns None if it's not a
  1992  /// constant. This method treats symbolic identifiers specially, i.e.,
  1993  /// it looks for constant differences between affine expressions involving
  1994  /// only the symbolic identifiers. See comments at function definition for
  1995  /// example. 'lb', if provided, is set to the lower bound associated with the
  1996  /// constant difference. Note that 'lb' is purely symbolic and thus will contain
  1997  /// the coefficients of the symbolic identifiers and the constant coefficient.
  1998  //  Egs: 0 <= i <= 15, return 16.
  1999  //       s0 + 2 <= i <= s0 + 17, returns 16. (s0 has to be a symbol)
  2000  //       s0 + s1 + 16 <= d0 <= s0 + s1 + 31, returns 16.
  2001  //       s0 - 7 <= 8*j <= s0 returns 1 with lb = s0, lbDivisor = 8 (since lb =
  2002  //       ceil(s0 - 7 / 8) = floor(s0 / 8)).
  2003  Optional<int64_t> FlatAffineConstraints::getConstantBoundOnDimSize(
  2004      unsigned pos, SmallVectorImpl<int64_t> *lb, int64_t *lbFloorDivisor,
  2005      SmallVectorImpl<int64_t> *ub) const {
  2006    assert(pos < getNumDimIds() && "Invalid identifier position");
  2007    assert(getNumLocalIds() == 0);
  2008  
  2009    // TODO(bondhugula): eliminate all remaining dimensional identifiers (other
  2010    // than the one at 'pos' to make this more powerful. Not needed for
  2011    // hyper-rectangular spaces.
  2012  
  2013    // Find an equality for 'pos'^th identifier that equates it to some function
  2014    // of the symbolic identifiers (+ constant).
  2015    int eqRow = findEqualityToConstant(*this, pos, /*symbolic=*/true);
  2016    if (eqRow != -1) {
  2017      // This identifier can only take a single value.
  2018      if (lb) {
  2019        // Set lb to the symbolic value.
  2020        lb->resize(getNumSymbolIds() + 1);
  2021        if (ub)
  2022          ub->resize(getNumSymbolIds() + 1);
  2023        for (unsigned c = 0, f = getNumSymbolIds() + 1; c < f; c++) {
  2024          int64_t v = atEq(eqRow, pos);
  2025          // atEq(eqRow, pos) is either -1 or 1.
  2026          assert(v * v == 1);
  2027          (*lb)[c] = v < 0 ? atEq(eqRow, getNumDimIds() + c) / -v
  2028                           : -atEq(eqRow, getNumDimIds() + c) / v;
  2029          // Since this is an equality, ub = lb.
  2030          if (ub)
  2031            (*ub)[c] = (*lb)[c];
  2032        }
  2033        assert(lbFloorDivisor &&
  2034               "both lb and divisor or none should be provided");
  2035        *lbFloorDivisor = 1;
  2036      }
  2037      return 1;
  2038    }
  2039  
  2040    // Check if the identifier appears at all in any of the inequalities.
  2041    unsigned r, e;
  2042    for (r = 0, e = getNumInequalities(); r < e; r++) {
  2043      if (atIneq(r, pos) != 0)
  2044        break;
  2045    }
  2046    if (r == e)
  2047      // If it doesn't, there isn't a bound on it.
  2048      return None;
  2049  
  2050    // Positions of constraints that are lower/upper bounds on the variable.
  2051    SmallVector<unsigned, 4> lbIndices, ubIndices;
  2052  
  2053    // Gather all symbolic lower bounds and upper bounds of the variable. Since
  2054    // the canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a
  2055    // lower bound for x_i if c_i >= 1, and an upper bound if c_i <= -1.
  2056    for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
  2057      unsigned c, f;
  2058      for (c = 0, f = getNumDimIds(); c < f; c++) {
  2059        if (c != pos && atIneq(r, c) != 0)
  2060          break;
  2061      }
  2062      if (c < getNumDimIds())
  2063        // Not a pure symbolic bound.
  2064        continue;
  2065      if (atIneq(r, pos) >= 1)
  2066        // Lower bound.
  2067        lbIndices.push_back(r);
  2068      else if (atIneq(r, pos) <= -1)
  2069        // Upper bound.
  2070        ubIndices.push_back(r);
  2071    }
  2072  
  2073    // TODO(bondhugula): eliminate other dimensional identifiers to make this more
  2074    // powerful. Not needed for hyper-rectangular iteration spaces.
  2075  
  2076    Optional<int64_t> minDiff = None;
  2077    unsigned minLbPosition, minUbPosition;
  2078    for (auto ubPos : ubIndices) {
  2079      for (auto lbPos : lbIndices) {
  2080        // Look for a lower bound and an upper bound that only differ by a
  2081        // constant, i.e., pairs of the form  0 <= c_pos - f(c_i's) <= diffConst.
  2082        // For example, if ii is the pos^th variable, we are looking for
  2083        // constraints like ii >= i, ii <= ii + 50, 50 being the difference. The
  2084        // minimum among all such constant differences is kept since that's the
  2085        // constant bounding the extent of the pos^th variable.
  2086        unsigned j, e;
  2087        for (j = 0, e = getNumCols() - 1; j < e; j++)
  2088          if (atIneq(ubPos, j) != -atIneq(lbPos, j)) {
  2089            break;
  2090          }
  2091        if (j < getNumCols() - 1)
  2092          continue;
  2093        int64_t diff = ceilDiv(atIneq(ubPos, getNumCols() - 1) +
  2094                                   atIneq(lbPos, getNumCols() - 1) + 1,
  2095                               atIneq(lbPos, pos));
  2096        if (minDiff == None || diff < minDiff) {
  2097          minDiff = diff;
  2098          minLbPosition = lbPos;
  2099          minUbPosition = ubPos;
  2100        }
  2101      }
  2102    }
  2103    if (lb && minDiff.hasValue()) {
  2104      // Set lb to the symbolic lower bound.
  2105      lb->resize(getNumSymbolIds() + 1);
  2106      if (ub)
  2107        ub->resize(getNumSymbolIds() + 1);
  2108      // The lower bound is the ceildiv of the lb constraint over the coefficient
  2109      // of the variable at 'pos'. We express the ceildiv equivalently as a floor
  2110      // for uniformity. For eg., if the lower bound constraint was: 32*d0 - N +
  2111      // 31 >= 0, the lower bound for d0 is ceil(N - 31, 32), i.e., floor(N, 32).
  2112      *lbFloorDivisor = atIneq(minLbPosition, pos);
  2113      assert(*lbFloorDivisor == -atIneq(minUbPosition, pos));
  2114      for (unsigned c = 0, e = getNumSymbolIds() + 1; c < e; c++) {
  2115        (*lb)[c] = -atIneq(minLbPosition, getNumDimIds() + c);
  2116      }
  2117      if (ub) {
  2118        for (unsigned c = 0, e = getNumSymbolIds() + 1; c < e; c++)
  2119          (*ub)[c] = atIneq(minUbPosition, getNumDimIds() + c);
  2120      }
  2121      // The lower bound leads to a ceildiv while the upper bound is a floordiv
  2122      // whenever the cofficient at pos != 1. ceildiv (val / d) = floordiv (val +
  2123      // d - 1 / d); hence, the addition of 'atIneq(minLbPosition, pos) - 1' to
  2124      // the constant term for the lower bound.
  2125      (*lb)[getNumSymbolIds()] += atIneq(minLbPosition, pos) - 1;
  2126    }
  2127    return minDiff;
  2128  }
  2129  
  2130  template <bool isLower>
  2131  Optional<int64_t>
  2132  FlatAffineConstraints::computeConstantLowerOrUpperBound(unsigned pos) {
  2133    assert(pos < getNumIds() && "invalid position");
  2134    // Project to 'pos'.
  2135    projectOut(0, pos);
  2136    projectOut(1, getNumIds() - 1);
  2137    // Check if there's an equality equating the '0'^th identifier to a constant.
  2138    int eqRowIdx = findEqualityToConstant(*this, 0, /*symbolic=*/false);
  2139    if (eqRowIdx != -1)
  2140      // atEq(rowIdx, 0) is either -1 or 1.
  2141      return -atEq(eqRowIdx, getNumCols() - 1) / atEq(eqRowIdx, 0);
  2142  
  2143    // Check if the identifier appears at all in any of the inequalities.
  2144    unsigned r, e;
  2145    for (r = 0, e = getNumInequalities(); r < e; r++) {
  2146      if (atIneq(r, 0) != 0)
  2147        break;
  2148    }
  2149    if (r == e)
  2150      // If it doesn't, there isn't a bound on it.
  2151      return None;
  2152  
  2153    Optional<int64_t> minOrMaxConst = None;
  2154  
  2155    // Take the max across all const lower bounds (or min across all constant
  2156    // upper bounds).
  2157    for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
  2158      if (isLower) {
  2159        if (atIneq(r, 0) <= 0)
  2160          // Not a lower bound.
  2161          continue;
  2162      } else if (atIneq(r, 0) >= 0) {
  2163        // Not an upper bound.
  2164        continue;
  2165      }
  2166      unsigned c, f;
  2167      for (c = 0, f = getNumCols() - 1; c < f; c++)
  2168        if (c != 0 && atIneq(r, c) != 0)
  2169          break;
  2170      if (c < getNumCols() - 1)
  2171        // Not a constant bound.
  2172        continue;
  2173  
  2174      int64_t boundConst =
  2175          isLower ? mlir::ceilDiv(-atIneq(r, getNumCols() - 1), atIneq(r, 0))
  2176                  : mlir::floorDiv(atIneq(r, getNumCols() - 1), -atIneq(r, 0));
  2177      if (isLower) {
  2178        if (minOrMaxConst == None || boundConst > minOrMaxConst)
  2179          minOrMaxConst = boundConst;
  2180      } else {
  2181        if (minOrMaxConst == None || boundConst < minOrMaxConst)
  2182          minOrMaxConst = boundConst;
  2183      }
  2184    }
  2185    return minOrMaxConst;
  2186  }
  2187  
  2188  Optional<int64_t>
  2189  FlatAffineConstraints::getConstantLowerBound(unsigned pos) const {
  2190    FlatAffineConstraints tmpCst(*this);
  2191    return tmpCst.computeConstantLowerOrUpperBound</*isLower=*/true>(pos);
  2192  }
  2193  
  2194  Optional<int64_t>
  2195  FlatAffineConstraints::getConstantUpperBound(unsigned pos) const {
  2196    FlatAffineConstraints tmpCst(*this);
  2197    return tmpCst.computeConstantLowerOrUpperBound</*isLower=*/false>(pos);
  2198  }
  2199  
  2200  // A simple (naive and conservative) check for hyper-rectangularlity.
  2201  bool FlatAffineConstraints::isHyperRectangular(unsigned pos,
  2202                                                 unsigned num) const {
  2203    assert(pos < getNumCols() - 1);
  2204    // Check for two non-zero coefficients in the range [pos, pos + sum).
  2205    for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
  2206      unsigned sum = 0;
  2207      for (unsigned c = pos; c < pos + num; c++) {
  2208        if (atIneq(r, c) != 0)
  2209          sum++;
  2210      }
  2211      if (sum > 1)
  2212        return false;
  2213    }
  2214    for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
  2215      unsigned sum = 0;
  2216      for (unsigned c = pos; c < pos + num; c++) {
  2217        if (atEq(r, c) != 0)
  2218          sum++;
  2219      }
  2220      if (sum > 1)
  2221        return false;
  2222    }
  2223    return true;
  2224  }
  2225  
  2226  void FlatAffineConstraints::print(raw_ostream &os) const {
  2227    assert(hasConsistentState());
  2228    os << "\nConstraints (" << getNumDimIds() << " dims, " << getNumSymbolIds()
  2229       << " symbols, " << getNumLocalIds() << " locals), (" << getNumConstraints()
  2230       << " constraints)\n";
  2231    os << "(";
  2232    for (unsigned i = 0, e = getNumIds(); i < e; i++) {
  2233      if (ids[i] == None)
  2234        os << "None ";
  2235      else
  2236        os << "Value ";
  2237    }
  2238    os << " const)\n";
  2239    for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
  2240      for (unsigned j = 0, f = getNumCols(); j < f; ++j) {
  2241        os << atEq(i, j) << " ";
  2242      }
  2243      os << "= 0\n";
  2244    }
  2245    for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
  2246      for (unsigned j = 0, f = getNumCols(); j < f; ++j) {
  2247        os << atIneq(i, j) << " ";
  2248      }
  2249      os << ">= 0\n";
  2250    }
  2251    os << '\n';
  2252  }
  2253  
  2254  void FlatAffineConstraints::dump() const { print(llvm::errs()); }
  2255  
  2256  /// Removes duplicate constraints, trivially true constraints, and constraints
  2257  /// that can be detected as redundant as a result of differing only in their
  2258  /// constant term part. A constraint of the form <non-negative constant> >= 0 is
  2259  /// considered trivially true.
  2260  //  Uses a DenseSet to hash and detect duplicates followed by a linear scan to
  2261  //  remove duplicates in place.
  2262  void FlatAffineConstraints::removeTrivialRedundancy() {
  2263    SmallDenseSet<ArrayRef<int64_t>, 8> rowSet;
  2264  
  2265    // A map used to detect redundancy stemming from constraints that only differ
  2266    // in their constant term. The value stored is <row position, const term>
  2267    // for a given row.
  2268    SmallDenseMap<ArrayRef<int64_t>, std::pair<unsigned, int64_t>>
  2269        rowsWithoutConstTerm;
  2270  
  2271    // Check if constraint is of the form <non-negative-constant> >= 0.
  2272    auto isTriviallyValid = [&](unsigned r) -> bool {
  2273      for (unsigned c = 0, e = getNumCols() - 1; c < e; c++) {
  2274        if (atIneq(r, c) != 0)
  2275          return false;
  2276      }
  2277      return atIneq(r, getNumCols() - 1) >= 0;
  2278    };
  2279  
  2280    // Detect and mark redundant constraints.
  2281    SmallVector<bool, 256> redunIneq(getNumInequalities(), false);
  2282    for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
  2283      int64_t *rowStart = inequalities.data() + numReservedCols * r;
  2284      auto row = ArrayRef<int64_t>(rowStart, getNumCols());
  2285      if (isTriviallyValid(r) || !rowSet.insert(row).second) {
  2286        redunIneq[r] = true;
  2287        continue;
  2288      }
  2289  
  2290      // Among constraints that only differ in the constant term part, mark
  2291      // everything other than the one with the smallest constant term redundant.
  2292      // (eg: among i - 16j - 5 >= 0, i - 16j - 1 >=0, i - 16j - 7 >= 0, the
  2293      // former two are redundant).
  2294      int64_t constTerm = atIneq(r, getNumCols() - 1);
  2295      auto rowWithoutConstTerm = ArrayRef<int64_t>(rowStart, getNumCols() - 1);
  2296      const auto &ret =
  2297          rowsWithoutConstTerm.insert({rowWithoutConstTerm, {r, constTerm}});
  2298      if (!ret.second) {
  2299        // Check if the other constraint has a higher constant term.
  2300        auto &val = ret.first->second;
  2301        if (val.second > constTerm) {
  2302          // The stored row is redundant. Mark it so, and update with this one.
  2303          redunIneq[val.first] = true;
  2304          val = {r, constTerm};
  2305        } else {
  2306          // The one stored makes this one redundant.
  2307          redunIneq[r] = true;
  2308        }
  2309      }
  2310    }
  2311  
  2312    auto copyRow = [&](unsigned src, unsigned dest) {
  2313      if (src == dest)
  2314        return;
  2315      for (unsigned c = 0, e = getNumCols(); c < e; c++) {
  2316        atIneq(dest, c) = atIneq(src, c);
  2317      }
  2318    };
  2319  
  2320    // Scan to get rid of all rows marked redundant, in-place.
  2321    unsigned pos = 0;
  2322    for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
  2323      if (!redunIneq[r])
  2324        copyRow(r, pos++);
  2325    }
  2326    inequalities.resize(numReservedCols * pos);
  2327  
  2328    // TODO(bondhugula): consider doing this for equalities as well, but probably
  2329    // not worth the savings.
  2330  }
  2331  
  2332  void FlatAffineConstraints::clearAndCopyFrom(
  2333      const FlatAffineConstraints &other) {
  2334    FlatAffineConstraints copy(other);
  2335    std::swap(*this, copy);
  2336    assert(copy.getNumIds() == copy.getIds().size());
  2337  }
  2338  
  2339  void FlatAffineConstraints::removeId(unsigned pos) {
  2340    removeIdRange(pos, pos + 1);
  2341  }
  2342  
  2343  static std::pair<unsigned, unsigned>
  2344  getNewNumDimsSymbols(unsigned pos, const FlatAffineConstraints &cst) {
  2345    unsigned numDims = cst.getNumDimIds();
  2346    unsigned numSymbols = cst.getNumSymbolIds();
  2347    unsigned newNumDims, newNumSymbols;
  2348    if (pos < numDims) {
  2349      newNumDims = numDims - 1;
  2350      newNumSymbols = numSymbols;
  2351    } else if (pos < numDims + numSymbols) {
  2352      assert(numSymbols >= 1);
  2353      newNumDims = numDims;
  2354      newNumSymbols = numSymbols - 1;
  2355    } else {
  2356      newNumDims = numDims;
  2357      newNumSymbols = numSymbols;
  2358    }
  2359    return {newNumDims, newNumSymbols};
  2360  }
  2361  
  2362  #undef DEBUG_TYPE
  2363  #define DEBUG_TYPE "fm"
  2364  
  2365  /// Eliminates identifier at the specified position using Fourier-Motzkin
  2366  /// variable elimination. This technique is exact for rational spaces but
  2367  /// conservative (in "rare" cases) for integer spaces. The operation corresponds
  2368  /// to a projection operation yielding the (convex) set of integer points
  2369  /// contained in the rational shadow of the set. An emptiness test that relies
  2370  /// on this method will guarantee emptiness, i.e., it disproves the existence of
  2371  /// a solution if it says it's empty.
  2372  /// If a non-null isResultIntegerExact is passed, it is set to true if the
  2373  /// result is also integer exact. If it's set to false, the obtained solution
  2374  /// *may* not be exact, i.e., it may contain integer points that do not have an
  2375  /// integer pre-image in the original set.
  2376  ///
  2377  /// Eg:
  2378  /// j >= 0, j <= i + 1
  2379  /// i >= 0, i <= N + 1
  2380  /// Eliminating i yields,
  2381  ///   j >= 0, 0 <= N + 1, j - 1 <= N + 1
  2382  ///
  2383  /// If darkShadow = true, this method computes the dark shadow on elimination;
  2384  /// the dark shadow is a convex integer subset of the exact integer shadow. A
  2385  /// non-empty dark shadow proves the existence of an integer solution. The
  2386  /// elimination in such a case could however be an under-approximation, and thus
  2387  /// should not be used for scanning sets or used by itself for dependence
  2388  /// checking.
  2389  ///
  2390  /// Eg: 2-d set, * represents grid points, 'o' represents a point in the set.
  2391  ///            ^
  2392  ///            |
  2393  ///            | * * * * o o
  2394  ///         i  | * * o o o o
  2395  ///            | o * * * * *
  2396  ///            --------------->
  2397  ///                 j ->
  2398  ///
  2399  /// Eliminating i from this system (projecting on the j dimension):
  2400  /// rational shadow / integer light shadow:  1 <= j <= 6
  2401  /// dark shadow:                             3 <= j <= 6
  2402  /// exact integer shadow:                    j = 1 \union  3 <= j <= 6
  2403  /// holes/splinters:                         j = 2
  2404  ///
  2405  /// darkShadow = false, isResultIntegerExact = nullptr are default values.
  2406  // TODO(bondhugula): a slight modification to yield dark shadow version of FM
  2407  // (tightened), which can prove the existence of a solution if there is one.
  2408  void FlatAffineConstraints::FourierMotzkinEliminate(
  2409      unsigned pos, bool darkShadow, bool *isResultIntegerExact) {
  2410    LLVM_DEBUG(llvm::dbgs() << "FM input (eliminate pos " << pos << "):\n");
  2411    LLVM_DEBUG(dump());
  2412    assert(pos < getNumIds() && "invalid position");
  2413    assert(hasConsistentState());
  2414  
  2415    // Check if this identifier can be eliminated through a substitution.
  2416    for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
  2417      if (atEq(r, pos) != 0) {
  2418        // Use Gaussian elimination here (since we have an equality).
  2419        LogicalResult ret = gaussianEliminateId(pos);
  2420        (void)ret;
  2421        assert(succeeded(ret) && "Gaussian elimination guaranteed to succeed");
  2422        LLVM_DEBUG(llvm::dbgs() << "FM output (through Gaussian elimination):\n");
  2423        LLVM_DEBUG(dump());
  2424        return;
  2425      }
  2426    }
  2427  
  2428    // A fast linear time tightening.
  2429    GCDTightenInequalities();
  2430  
  2431    // Check if the identifier appears at all in any of the inequalities.
  2432    unsigned r, e;
  2433    for (r = 0, e = getNumInequalities(); r < e; r++) {
  2434      if (atIneq(r, pos) != 0)
  2435        break;
  2436    }
  2437    if (r == getNumInequalities()) {
  2438      // If it doesn't appear, just remove the column and return.
  2439      // TODO(andydavis,bondhugula): refactor removeColumns to use it from here.
  2440      removeId(pos);
  2441      LLVM_DEBUG(llvm::dbgs() << "FM output:\n");
  2442      LLVM_DEBUG(dump());
  2443      return;
  2444    }
  2445  
  2446    // Positions of constraints that are lower bounds on the variable.
  2447    SmallVector<unsigned, 4> lbIndices;
  2448    // Positions of constraints that are lower bounds on the variable.
  2449    SmallVector<unsigned, 4> ubIndices;
  2450    // Positions of constraints that do not involve the variable.
  2451    std::vector<unsigned> nbIndices;
  2452    nbIndices.reserve(getNumInequalities());
  2453  
  2454    // Gather all lower bounds and upper bounds of the variable. Since the
  2455    // canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower
  2456    // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1.
  2457    for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
  2458      if (atIneq(r, pos) == 0) {
  2459        // Id does not appear in bound.
  2460        nbIndices.push_back(r);
  2461      } else if (atIneq(r, pos) >= 1) {
  2462        // Lower bound.
  2463        lbIndices.push_back(r);
  2464      } else {
  2465        // Upper bound.
  2466        ubIndices.push_back(r);
  2467      }
  2468    }
  2469  
  2470    // Set the number of dimensions, symbols in the resulting system.
  2471    const auto &dimsSymbols = getNewNumDimsSymbols(pos, *this);
  2472    unsigned newNumDims = dimsSymbols.first;
  2473    unsigned newNumSymbols = dimsSymbols.second;
  2474  
  2475    SmallVector<Optional<Value *>, 8> newIds;
  2476    newIds.reserve(numIds - 1);
  2477    newIds.append(ids.begin(), ids.begin() + pos);
  2478    newIds.append(ids.begin() + pos + 1, ids.end());
  2479  
  2480    /// Create the new system which has one identifier less.
  2481    FlatAffineConstraints newFac(
  2482        lbIndices.size() * ubIndices.size() + nbIndices.size(),
  2483        getNumEqualities(), getNumCols() - 1, newNumDims, newNumSymbols,
  2484        /*numLocals=*/getNumIds() - 1 - newNumDims - newNumSymbols, newIds);
  2485  
  2486    assert(newFac.getIds().size() == newFac.getNumIds());
  2487  
  2488    // This will be used to check if the elimination was integer exact.
  2489    unsigned lcmProducts = 1;
  2490  
  2491    // Let x be the variable we are eliminating.
  2492    // For each lower bound, lb <= c_l*x, and each upper bound c_u*x <= ub, (note
  2493    // that c_l, c_u >= 1) we have:
  2494    // lb*lcm(c_l, c_u)/c_l <= lcm(c_l, c_u)*x <= ub*lcm(c_l, c_u)/c_u
  2495    // We thus generate a constraint:
  2496    // lcm(c_l, c_u)/c_l*lb <= lcm(c_l, c_u)/c_u*ub.
  2497    // Note if c_l = c_u = 1, all integer points captured by the resulting
  2498    // constraint correspond to integer points in the original system (i.e., they
  2499    // have integer pre-images). Hence, if the lcm's are all 1, the elimination is
  2500    // integer exact.
  2501    for (auto ubPos : ubIndices) {
  2502      for (auto lbPos : lbIndices) {
  2503        SmallVector<int64_t, 4> ineq;
  2504        ineq.reserve(newFac.getNumCols());
  2505        int64_t lbCoeff = atIneq(lbPos, pos);
  2506        // Note that in the comments above, ubCoeff is the negation of the
  2507        // coefficient in the canonical form as the view taken here is that of the
  2508        // term being moved to the other size of '>='.
  2509        int64_t ubCoeff = -atIneq(ubPos, pos);
  2510        // TODO(bondhugula): refactor this loop to avoid all branches inside.
  2511        for (unsigned l = 0, e = getNumCols(); l < e; l++) {
  2512          if (l == pos)
  2513            continue;
  2514          assert(lbCoeff >= 1 && ubCoeff >= 1 && "bounds wrongly identified");
  2515          int64_t lcm = mlir::lcm(lbCoeff, ubCoeff);
  2516          ineq.push_back(atIneq(ubPos, l) * (lcm / ubCoeff) +
  2517                         atIneq(lbPos, l) * (lcm / lbCoeff));
  2518          lcmProducts *= lcm;
  2519        }
  2520        if (darkShadow) {
  2521          // The dark shadow is a convex subset of the exact integer shadow. If
  2522          // there is a point here, it proves the existence of a solution.
  2523          ineq[ineq.size() - 1] += lbCoeff * ubCoeff - lbCoeff - ubCoeff + 1;
  2524        }
  2525        // TODO: we need to have a way to add inequalities in-place in
  2526        // FlatAffineConstraints instead of creating and copying over.
  2527        newFac.addInequality(ineq);
  2528      }
  2529    }
  2530  
  2531    LLVM_DEBUG(llvm::dbgs() << "FM isResultIntegerExact: " << (lcmProducts == 1)
  2532                            << "\n");
  2533    if (lcmProducts == 1 && isResultIntegerExact)
  2534      *isResultIntegerExact = 1;
  2535  
  2536    // Copy over the constraints not involving this variable.
  2537    for (auto nbPos : nbIndices) {
  2538      SmallVector<int64_t, 4> ineq;
  2539      ineq.reserve(getNumCols() - 1);
  2540      for (unsigned l = 0, e = getNumCols(); l < e; l++) {
  2541        if (l == pos)
  2542          continue;
  2543        ineq.push_back(atIneq(nbPos, l));
  2544      }
  2545      newFac.addInequality(ineq);
  2546    }
  2547  
  2548    assert(newFac.getNumConstraints() ==
  2549           lbIndices.size() * ubIndices.size() + nbIndices.size());
  2550  
  2551    // Copy over the equalities.
  2552    for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
  2553      SmallVector<int64_t, 4> eq;
  2554      eq.reserve(newFac.getNumCols());
  2555      for (unsigned l = 0, e = getNumCols(); l < e; l++) {
  2556        if (l == pos)
  2557          continue;
  2558        eq.push_back(atEq(r, l));
  2559      }
  2560      newFac.addEquality(eq);
  2561    }
  2562  
  2563    // GCD tightening and normalization allows detection of more trivially
  2564    // redundant constraints.
  2565    newFac.GCDTightenInequalities();
  2566    newFac.normalizeConstraintsByGCD();
  2567    newFac.removeTrivialRedundancy();
  2568    clearAndCopyFrom(newFac);
  2569    LLVM_DEBUG(llvm::dbgs() << "FM output:\n");
  2570    LLVM_DEBUG(dump());
  2571  }
  2572  
  2573  #undef DEBUG_TYPE
  2574  #define DEBUG_TYPE "affine-structures"
  2575  
  2576  void FlatAffineConstraints::projectOut(unsigned pos, unsigned num) {
  2577    if (num == 0)
  2578      return;
  2579  
  2580    // 'pos' can be at most getNumCols() - 2 if num > 0.
  2581    assert((getNumCols() < 2 || pos <= getNumCols() - 2) && "invalid position");
  2582    assert(pos + num < getNumCols() && "invalid range");
  2583  
  2584    // Eliminate as many identifiers as possible using Gaussian elimination.
  2585    unsigned currentPos = pos;
  2586    unsigned numToEliminate = num;
  2587    unsigned numGaussianEliminated = 0;
  2588  
  2589    while (currentPos < getNumIds()) {
  2590      unsigned curNumEliminated =
  2591          gaussianEliminateIds(currentPos, currentPos + numToEliminate);
  2592      ++currentPos;
  2593      numToEliminate -= curNumEliminated + 1;
  2594      numGaussianEliminated += curNumEliminated;
  2595    }
  2596  
  2597    // Eliminate the remaining using Fourier-Motzkin.
  2598    for (unsigned i = 0; i < num - numGaussianEliminated; i++) {
  2599      unsigned numToEliminate = num - numGaussianEliminated - i;
  2600      FourierMotzkinEliminate(
  2601          getBestIdToEliminate(*this, pos, pos + numToEliminate));
  2602    }
  2603  
  2604    // Fast/trivial simplifications.
  2605    GCDTightenInequalities();
  2606    // Normalize constraints after tightening since the latter impacts this, but
  2607    // not the other way round.
  2608    normalizeConstraintsByGCD();
  2609  }
  2610  
  2611  void FlatAffineConstraints::projectOut(Value *id) {
  2612    unsigned pos;
  2613    bool ret = findId(*id, &pos);
  2614    assert(ret);
  2615    (void)ret;
  2616    FourierMotzkinEliminate(pos);
  2617  }
  2618  
  2619  bool FlatAffineConstraints::isRangeOneToOne(unsigned start,
  2620                                              unsigned limit) const {
  2621    assert(start <= getNumIds() - 1 && "invalid start position");
  2622    assert(limit > start && limit <= getNumIds() && "invalid limit");
  2623  
  2624    FlatAffineConstraints tmpCst(*this);
  2625  
  2626    if (start != 0) {
  2627      // Move [start, limit) to the left.
  2628      for (unsigned r = 0, e = getNumInequalities(); r < e; ++r) {
  2629        for (unsigned c = 0, f = getNumCols(); c < f; ++c) {
  2630          if (c >= start && c < limit)
  2631            tmpCst.atIneq(r, c - start) = atIneq(r, c);
  2632          else if (c < start)
  2633            tmpCst.atIneq(r, c + limit - start) = atIneq(r, c);
  2634          else
  2635            tmpCst.atIneq(r, c) = atIneq(r, c);
  2636        }
  2637      }
  2638      for (unsigned r = 0, e = getNumEqualities(); r < e; ++r) {
  2639        for (unsigned c = 0, f = getNumCols(); c < f; ++c) {
  2640          if (c >= start && c < limit)
  2641            tmpCst.atEq(r, c - start) = atEq(r, c);
  2642          else if (c < start)
  2643            tmpCst.atEq(r, c + limit - start) = atEq(r, c);
  2644          else
  2645            tmpCst.atEq(r, c) = atEq(r, c);
  2646        }
  2647      }
  2648    }
  2649  
  2650    // Mark everything to the right as symbols so that we can check the extents in
  2651    // a symbolic way below.
  2652    tmpCst.setDimSymbolSeparation(getNumIds() - (limit - start));
  2653  
  2654    // Check if the extents of all the specified dimensions are just one (when
  2655    // treating the rest as symbols).
  2656    for (unsigned pos = 0, e = tmpCst.getNumDimIds(); pos < e; ++pos) {
  2657      auto extent = tmpCst.getConstantBoundOnDimSize(pos);
  2658      if (!extent.hasValue() || extent.getValue() != 1)
  2659        return false;
  2660    }
  2661    return true;
  2662  }
  2663  
  2664  void FlatAffineConstraints::clearConstraints() {
  2665    equalities.clear();
  2666    inequalities.clear();
  2667  }
  2668  
  2669  namespace {
  2670  
  2671  enum BoundCmpResult { Greater, Less, Equal, Unknown };
  2672  
  2673  /// Compares two affine bounds whose coefficients are provided in 'first' and
  2674  /// 'second'. The last coefficient is the constant term.
  2675  static BoundCmpResult compareBounds(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
  2676    assert(a.size() == b.size());
  2677  
  2678    // For the bounds to be comparable, their corresponding identifier
  2679    // coefficients should be equal; the constant terms are then compared to
  2680    // determine less/greater/equal.
  2681  
  2682    if (!std::equal(a.begin(), a.end() - 1, b.begin()))
  2683      return Unknown;
  2684  
  2685    if (a.back() == b.back())
  2686      return Equal;
  2687  
  2688    return a.back() < b.back() ? Less : Greater;
  2689  }
  2690  } // namespace
  2691  
  2692  // Computes the bounding box with respect to 'other' by finding the min of the
  2693  // lower bounds and the max of the upper bounds along each of the dimensions.
  2694  LogicalResult
  2695  FlatAffineConstraints::unionBoundingBox(const FlatAffineConstraints &otherCst) {
  2696    assert(otherCst.getNumDimIds() == numDims && "dims mismatch");
  2697    assert(otherCst.getIds()
  2698               .slice(0, getNumDimIds())
  2699               .equals(getIds().slice(0, getNumDimIds())) &&
  2700           "dim values mismatch");
  2701    assert(otherCst.getNumLocalIds() == 0 && "local ids not supported here");
  2702    assert(getNumLocalIds() == 0 && "local ids not supported yet here");
  2703  
  2704    Optional<FlatAffineConstraints> otherCopy;
  2705    if (!areIdsAligned(*this, otherCst)) {
  2706      otherCopy.emplace(FlatAffineConstraints(otherCst));
  2707      mergeAndAlignIds(/*offset=*/numDims, this, &otherCopy.getValue());
  2708    }
  2709  
  2710    const auto &other = otherCopy ? *otherCopy : otherCst;
  2711  
  2712    std::vector<SmallVector<int64_t, 8>> boundingLbs;
  2713    std::vector<SmallVector<int64_t, 8>> boundingUbs;
  2714    boundingLbs.reserve(2 * getNumDimIds());
  2715    boundingUbs.reserve(2 * getNumDimIds());
  2716  
  2717    // To hold lower and upper bounds for each dimension.
  2718    SmallVector<int64_t, 4> lb, otherLb, ub, otherUb;
  2719    // To compute min of lower bounds and max of upper bounds for each dimension.
  2720    SmallVector<int64_t, 4> minLb(getNumSymbolIds() + 1);
  2721    SmallVector<int64_t, 4> maxUb(getNumSymbolIds() + 1);
  2722    // To compute final new lower and upper bounds for the union.
  2723    SmallVector<int64_t, 8> newLb(getNumCols()), newUb(getNumCols());
  2724  
  2725    int64_t lbFloorDivisor, otherLbFloorDivisor;
  2726    for (unsigned d = 0, e = getNumDimIds(); d < e; ++d) {
  2727      auto extent = getConstantBoundOnDimSize(d, &lb, &lbFloorDivisor, &ub);
  2728      if (!extent.hasValue())
  2729        // TODO(bondhugula): symbolic extents when necessary.
  2730        // TODO(bondhugula): handle union if a dimension is unbounded.
  2731        return failure();
  2732  
  2733      auto otherExtent = other.getConstantBoundOnDimSize(
  2734          d, &otherLb, &otherLbFloorDivisor, &otherUb);
  2735      if (!otherExtent.hasValue() || lbFloorDivisor != otherLbFloorDivisor)
  2736        // TODO(bondhugula): symbolic extents when necessary.
  2737        return failure();
  2738  
  2739      assert(lbFloorDivisor > 0 && "divisor always expected to be positive");
  2740  
  2741      auto res = compareBounds(lb, otherLb);
  2742      // Identify min.
  2743      if (res == BoundCmpResult::Less || res == BoundCmpResult::Equal) {
  2744        minLb = lb;
  2745        // Since the divisor is for a floordiv, we need to convert to ceildiv,
  2746        // i.e., i >= expr floordiv div <=> i >= (expr - div + 1) ceildiv div <=>
  2747        // div * i >= expr - div + 1.
  2748        minLb.back() -= lbFloorDivisor - 1;
  2749      } else if (res == BoundCmpResult::Greater) {
  2750        minLb = otherLb;
  2751        minLb.back() -= otherLbFloorDivisor - 1;
  2752      } else {
  2753        // Uncomparable - check for constant lower/upper bounds.
  2754        auto constLb = getConstantLowerBound(d);
  2755        auto constOtherLb = other.getConstantLowerBound(d);
  2756        if (!constLb.hasValue() || !constOtherLb.hasValue())
  2757          return failure();
  2758        std::fill(minLb.begin(), minLb.end(), 0);
  2759        minLb.back() = std::min(constLb.getValue(), constOtherLb.getValue());
  2760      }
  2761  
  2762      // Do the same for ub's but max of upper bounds. Identify max.
  2763      auto uRes = compareBounds(ub, otherUb);
  2764      if (uRes == BoundCmpResult::Greater || uRes == BoundCmpResult::Equal) {
  2765        maxUb = ub;
  2766      } else if (uRes == BoundCmpResult::Less) {
  2767        maxUb = otherUb;
  2768      } else {
  2769        // Uncomparable - check for constant lower/upper bounds.
  2770        auto constUb = getConstantUpperBound(d);
  2771        auto constOtherUb = other.getConstantUpperBound(d);
  2772        if (!constUb.hasValue() || !constOtherUb.hasValue())
  2773          return failure();
  2774        std::fill(maxUb.begin(), maxUb.end(), 0);
  2775        maxUb.back() = std::max(constUb.getValue(), constOtherUb.getValue());
  2776      }
  2777  
  2778      std::fill(newLb.begin(), newLb.end(), 0);
  2779      std::fill(newUb.begin(), newUb.end(), 0);
  2780  
  2781      // The divisor for lb, ub, otherLb, otherUb at this point is lbDivisor,
  2782      // and so it's the divisor for newLb and newUb as well.
  2783      newLb[d] = lbFloorDivisor;
  2784      newUb[d] = -lbFloorDivisor;
  2785      // Copy over the symbolic part + constant term.
  2786      std::copy(minLb.begin(), minLb.end(), newLb.begin() + getNumDimIds());
  2787      std::transform(newLb.begin() + getNumDimIds(), newLb.end(),
  2788                     newLb.begin() + getNumDimIds(), std::negate<int64_t>());
  2789      std::copy(maxUb.begin(), maxUb.end(), newUb.begin() + getNumDimIds());
  2790  
  2791      boundingLbs.push_back(newLb);
  2792      boundingUbs.push_back(newUb);
  2793    }
  2794  
  2795    // Clear all constraints and add the lower/upper bounds for the bounding box.
  2796    clearConstraints();
  2797    for (unsigned d = 0, e = getNumDimIds(); d < e; ++d) {
  2798      addInequality(boundingLbs[d]);
  2799      addInequality(boundingUbs[d]);
  2800    }
  2801    // TODO(mlir-team): copy over pure symbolic constraints from this and 'other'
  2802    // over to the union (since the above are just the union along dimensions); we
  2803    // shouldn't be discarding any other constraints on the symbols.
  2804  
  2805    return success();
  2806  }