//===- AffineStructures.cpp - MLIR Affine Structures Class-----------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// Structures for affine/polyhedral analysis of MLIR functions.
//
//===----------------------------------------------------------------------===//

#include "mlir/Analysis/AffineStructures.h"
#include "mlir/Dialect/AffineOps/AffineOps.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Operation.h"
#include "mlir/Support/MathExtras.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "affine-structures"

using namespace mlir;
using llvm::SmallDenseMap;
using llvm::SmallDenseSet;
using llvm::SmallPtrSet;

namespace {

// See comments for SimpleAffineExprFlattener.
// An AffineExprFlattener extends a SimpleAffineExprFlattener by recording
// constraint information associated with mod's, floordiv's, and ceildiv's
// in FlatAffineConstraints 'localVarCst'.
48 struct AffineExprFlattener : public SimpleAffineExprFlattener { 49 public: 50 // Constraints connecting newly introduced local variables (for mod's and 51 // div's) to existing (dimensional and symbolic) ones. These are always 52 // inequalities. 53 FlatAffineConstraints localVarCst; 54 55 AffineExprFlattener(unsigned nDims, unsigned nSymbols, MLIRContext *ctx) 56 : SimpleAffineExprFlattener(nDims, nSymbols) { 57 localVarCst.reset(nDims, nSymbols, /*numLocals=*/0); 58 } 59 60 private: 61 // Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr). 62 // The local identifier added is always a floordiv of a pure add/mul affine 63 // function of other identifiers, coefficients of which are specified in 64 // `dividend' and with respect to the positive constant `divisor'. localExpr 65 // is the simplified tree expression (AffineExpr) corresponding to the 66 // quantifier. 67 void addLocalFloorDivId(ArrayRef<int64_t> dividend, int64_t divisor, 68 AffineExpr localExpr) override { 69 SimpleAffineExprFlattener::addLocalFloorDivId(dividend, divisor, localExpr); 70 // Update localVarCst. 71 localVarCst.addLocalFloorDiv(dividend, divisor); 72 } 73 }; 74 75 } // end anonymous namespace 76 77 // Flattens the expressions in map. Returns failure if 'expr' was unable to be 78 // flattened (i.e., semi-affine expressions not handled yet). 79 static LogicalResult getFlattenedAffineExprs( 80 ArrayRef<AffineExpr> exprs, unsigned numDims, unsigned numSymbols, 81 std::vector<llvm::SmallVector<int64_t, 8>> *flattenedExprs, 82 FlatAffineConstraints *localVarCst) { 83 if (exprs.empty()) { 84 localVarCst->reset(numDims, numSymbols); 85 return success(); 86 } 87 88 AffineExprFlattener flattener(numDims, numSymbols, exprs[0].getContext()); 89 // Use the same flattener to simplify each expression successively. This way 90 // local identifiers / expressions are shared. 
91 for (auto expr : exprs) { 92 if (!expr.isPureAffine()) 93 return failure(); 94 95 flattener.walkPostOrder(expr); 96 } 97 98 assert(flattener.operandExprStack.size() == exprs.size()); 99 flattenedExprs->clear(); 100 flattenedExprs->assign(flattener.operandExprStack.begin(), 101 flattener.operandExprStack.end()); 102 103 if (localVarCst) { 104 localVarCst->clearAndCopyFrom(flattener.localVarCst); 105 } 106 107 return success(); 108 } 109 110 // Flattens 'expr' into 'flattenedExpr'. Returns failure if 'expr' was unable to 111 // be flattened (semi-affine expressions not handled yet). 112 LogicalResult 113 mlir::getFlattenedAffineExpr(AffineExpr expr, unsigned numDims, 114 unsigned numSymbols, 115 llvm::SmallVectorImpl<int64_t> *flattenedExpr, 116 FlatAffineConstraints *localVarCst) { 117 std::vector<SmallVector<int64_t, 8>> flattenedExprs; 118 LogicalResult ret = ::getFlattenedAffineExprs({expr}, numDims, numSymbols, 119 &flattenedExprs, localVarCst); 120 *flattenedExpr = flattenedExprs[0]; 121 return ret; 122 } 123 124 /// Flattens the expressions in map. Returns failure if 'expr' was unable to be 125 /// flattened (i.e., semi-affine expressions not handled yet). 
126 LogicalResult mlir::getFlattenedAffineExprs( 127 AffineMap map, std::vector<llvm::SmallVector<int64_t, 8>> *flattenedExprs, 128 FlatAffineConstraints *localVarCst) { 129 if (map.getNumResults() == 0) { 130 localVarCst->reset(map.getNumDims(), map.getNumSymbols()); 131 return success(); 132 } 133 return ::getFlattenedAffineExprs(map.getResults(), map.getNumDims(), 134 map.getNumSymbols(), flattenedExprs, 135 localVarCst); 136 } 137 138 LogicalResult mlir::getFlattenedAffineExprs( 139 IntegerSet set, std::vector<llvm::SmallVector<int64_t, 8>> *flattenedExprs, 140 FlatAffineConstraints *localVarCst) { 141 if (set.getNumConstraints() == 0) { 142 localVarCst->reset(set.getNumDims(), set.getNumSymbols()); 143 return success(); 144 } 145 return ::getFlattenedAffineExprs(set.getConstraints(), set.getNumDims(), 146 set.getNumSymbols(), flattenedExprs, 147 localVarCst); 148 } 149 150 //===----------------------------------------------------------------------===// 151 // MutableAffineMap. 152 //===----------------------------------------------------------------------===// 153 154 MutableAffineMap::MutableAffineMap(AffineMap map) 155 : numDims(map.getNumDims()), numSymbols(map.getNumSymbols()), 156 // A map always has at least 1 result by construction 157 context(map.getResult(0).getContext()) { 158 for (auto result : map.getResults()) 159 results.push_back(result); 160 } 161 162 void MutableAffineMap::reset(AffineMap map) { 163 results.clear(); 164 numDims = map.getNumDims(); 165 numSymbols = map.getNumSymbols(); 166 // A map always has at least 1 result by construction 167 context = map.getResult(0).getContext(); 168 for (auto result : map.getResults()) 169 results.push_back(result); 170 } 171 172 bool MutableAffineMap::isMultipleOf(unsigned idx, int64_t factor) const { 173 if (results[idx].isMultipleOf(factor)) 174 return true; 175 176 // TODO(bondhugula): use simplifyAffineExpr and FlatAffineConstraints to 177 // complete this (for a more powerful analysis). 
178 return false; 179 } 180 181 // Simplifies the result affine expressions of this map. The expressions have to 182 // be pure for the simplification implemented. 183 void MutableAffineMap::simplify() { 184 // Simplify each of the results if possible. 185 // TODO(ntv): functional-style map 186 for (unsigned i = 0, e = getNumResults(); i < e; i++) { 187 results[i] = simplifyAffineExpr(getResult(i), numDims, numSymbols); 188 } 189 } 190 191 AffineMap MutableAffineMap::getAffineMap() const { 192 return AffineMap::get(numDims, numSymbols, results); 193 } 194 195 MutableIntegerSet::MutableIntegerSet(IntegerSet set, MLIRContext *context) 196 : numDims(set.getNumDims()), numSymbols(set.getNumSymbols()) { 197 // TODO(bondhugula) 198 } 199 200 // Universal set. 201 MutableIntegerSet::MutableIntegerSet(unsigned numDims, unsigned numSymbols, 202 MLIRContext *context) 203 : numDims(numDims), numSymbols(numSymbols) {} 204 205 //===----------------------------------------------------------------------===// 206 // AffineValueMap. 
207 //===----------------------------------------------------------------------===// 208 209 AffineValueMap::AffineValueMap(AffineMap map, ArrayRef<Value *> operands, 210 ArrayRef<Value *> results) 211 : map(map), operands(operands.begin(), operands.end()), 212 results(results.begin(), results.end()) {} 213 214 AffineValueMap::AffineValueMap(AffineApplyOp applyOp) 215 : map(applyOp.getAffineMap()), 216 operands(applyOp.operand_begin(), applyOp.operand_end()) { 217 results.push_back(applyOp.getResult()); 218 } 219 220 AffineValueMap::AffineValueMap(AffineBound bound) 221 : map(bound.getMap()), 222 operands(bound.operand_begin(), bound.operand_end()) {} 223 224 void AffineValueMap::reset(AffineMap map, ArrayRef<Value *> operands, 225 ArrayRef<Value *> results) { 226 this->map.reset(map); 227 this->operands.assign(operands.begin(), operands.end()); 228 this->results.assign(results.begin(), results.end()); 229 } 230 231 // Returns true and sets 'indexOfMatch' if 'valueToMatch' is found in 232 // 'valuesToSearch' beginning at 'indexStart'. Returns false otherwise. 233 static bool findIndex(Value *valueToMatch, ArrayRef<Value *> valuesToSearch, 234 unsigned indexStart, unsigned *indexOfMatch) { 235 unsigned size = valuesToSearch.size(); 236 for (unsigned i = indexStart; i < size; ++i) { 237 if (valueToMatch == valuesToSearch[i]) { 238 *indexOfMatch = i; 239 return true; 240 } 241 } 242 return false; 243 } 244 245 inline bool AffineValueMap::isMultipleOf(unsigned idx, int64_t factor) const { 246 return map.isMultipleOf(idx, factor); 247 } 248 249 /// This method uses the invariant that operands are always positionally aligned 250 /// with the AffineDimExpr in the underlying AffineMap. 
251 bool AffineValueMap::isFunctionOf(unsigned idx, Value *value) const { 252 unsigned index; 253 if (!findIndex(value, operands, /*indexStart=*/0, &index)) { 254 return false; 255 } 256 auto expr = const_cast<AffineValueMap *>(this)->getAffineMap().getResult(idx); 257 // TODO(ntv): this is better implemented on a flattened representation. 258 // At least for now it is conservative. 259 return expr.isFunctionOfDim(index); 260 } 261 262 Value *AffineValueMap::getOperand(unsigned i) const { 263 return static_cast<Value *>(operands[i]); 264 } 265 266 ArrayRef<Value *> AffineValueMap::getOperands() const { 267 return ArrayRef<Value *>(operands); 268 } 269 270 AffineMap AffineValueMap::getAffineMap() const { return map.getAffineMap(); } 271 272 AffineValueMap::~AffineValueMap() {} 273 274 //===----------------------------------------------------------------------===// 275 // FlatAffineConstraints. 276 //===----------------------------------------------------------------------===// 277 278 // Copy constructor. 279 FlatAffineConstraints::FlatAffineConstraints( 280 const FlatAffineConstraints &other) { 281 numReservedCols = other.numReservedCols; 282 numDims = other.getNumDimIds(); 283 numSymbols = other.getNumSymbolIds(); 284 numIds = other.getNumIds(); 285 286 auto otherIds = other.getIds(); 287 ids.reserve(numReservedCols); 288 ids.append(otherIds.begin(), otherIds.end()); 289 290 unsigned numReservedEqualities = other.getNumReservedEqualities(); 291 unsigned numReservedInequalities = other.getNumReservedInequalities(); 292 293 equalities.reserve(numReservedEqualities * numReservedCols); 294 inequalities.reserve(numReservedInequalities * numReservedCols); 295 296 for (unsigned r = 0, e = other.getNumInequalities(); r < e; r++) { 297 addInequality(other.getInequality(r)); 298 } 299 for (unsigned r = 0, e = other.getNumEqualities(); r < e; r++) { 300 addEquality(other.getEquality(r)); 301 } 302 } 303 304 // Clones this object. 
305 std::unique_ptr<FlatAffineConstraints> FlatAffineConstraints::clone() const { 306 return std::make_unique<FlatAffineConstraints>(*this); 307 } 308 309 // Construct from an IntegerSet. 310 FlatAffineConstraints::FlatAffineConstraints(IntegerSet set) 311 : numReservedCols(set.getNumOperands() + 1), 312 numIds(set.getNumDims() + set.getNumSymbols()), numDims(set.getNumDims()), 313 numSymbols(set.getNumSymbols()) { 314 equalities.reserve(set.getNumEqualities() * numReservedCols); 315 inequalities.reserve(set.getNumInequalities() * numReservedCols); 316 ids.resize(numIds, None); 317 318 // Flatten expressions and add them to the constraint system. 319 std::vector<SmallVector<int64_t, 8>> flatExprs; 320 FlatAffineConstraints localVarCst; 321 if (failed(getFlattenedAffineExprs(set, &flatExprs, &localVarCst))) { 322 assert(false && "flattening unimplemented for semi-affine integer sets"); 323 return; 324 } 325 assert(flatExprs.size() == set.getNumConstraints()); 326 for (unsigned l = 0, e = localVarCst.getNumLocalIds(); l < e; l++) { 327 addLocalId(getNumLocalIds()); 328 } 329 330 for (unsigned i = 0, e = flatExprs.size(); i < e; ++i) { 331 const auto &flatExpr = flatExprs[i]; 332 assert(flatExpr.size() == getNumCols()); 333 if (set.getEqFlags()[i]) { 334 addEquality(flatExpr); 335 } else { 336 addInequality(flatExpr); 337 } 338 } 339 // Add the other constraints involving local id's from flattening. 
340 append(localVarCst); 341 } 342 343 void FlatAffineConstraints::reset(unsigned numReservedInequalities, 344 unsigned numReservedEqualities, 345 unsigned newNumReservedCols, 346 unsigned newNumDims, unsigned newNumSymbols, 347 unsigned newNumLocals, 348 ArrayRef<Value *> idArgs) { 349 assert(newNumReservedCols >= newNumDims + newNumSymbols + newNumLocals + 1 && 350 "minimum 1 column"); 351 numReservedCols = newNumReservedCols; 352 numDims = newNumDims; 353 numSymbols = newNumSymbols; 354 numIds = numDims + numSymbols + newNumLocals; 355 assert(idArgs.empty() || idArgs.size() == numIds); 356 357 clearConstraints(); 358 if (numReservedEqualities >= 1) 359 equalities.reserve(newNumReservedCols * numReservedEqualities); 360 if (numReservedInequalities >= 1) 361 inequalities.reserve(newNumReservedCols * numReservedInequalities); 362 if (idArgs.empty()) { 363 ids.resize(numIds, None); 364 } else { 365 ids.assign(idArgs.begin(), idArgs.end()); 366 } 367 } 368 369 void FlatAffineConstraints::reset(unsigned newNumDims, unsigned newNumSymbols, 370 unsigned newNumLocals, 371 ArrayRef<Value *> idArgs) { 372 reset(0, 0, newNumDims + newNumSymbols + newNumLocals + 1, newNumDims, 373 newNumSymbols, newNumLocals, idArgs); 374 } 375 376 void FlatAffineConstraints::append(const FlatAffineConstraints &other) { 377 assert(other.getNumCols() == getNumCols()); 378 assert(other.getNumDimIds() == getNumDimIds()); 379 assert(other.getNumSymbolIds() == getNumSymbolIds()); 380 381 inequalities.reserve(inequalities.size() + 382 other.getNumInequalities() * numReservedCols); 383 equalities.reserve(equalities.size() + 384 other.getNumEqualities() * numReservedCols); 385 386 for (unsigned r = 0, e = other.getNumInequalities(); r < e; r++) { 387 addInequality(other.getInequality(r)); 388 } 389 for (unsigned r = 0, e = other.getNumEqualities(); r < e; r++) { 390 addEquality(other.getEquality(r)); 391 } 392 } 393 394 void FlatAffineConstraints::addLocalId(unsigned pos) { 395 addId(IdKind::Local, 
pos); 396 } 397 398 void FlatAffineConstraints::addDimId(unsigned pos, Value *id) { 399 addId(IdKind::Dimension, pos, id); 400 } 401 402 void FlatAffineConstraints::addSymbolId(unsigned pos, Value *id) { 403 addId(IdKind::Symbol, pos, id); 404 } 405 406 /// Adds a dimensional identifier. The added column is initialized to 407 /// zero. 408 void FlatAffineConstraints::addId(IdKind kind, unsigned pos, Value *id) { 409 if (kind == IdKind::Dimension) { 410 assert(pos <= getNumDimIds()); 411 } else if (kind == IdKind::Symbol) { 412 assert(pos <= getNumSymbolIds()); 413 } else { 414 assert(pos <= getNumLocalIds()); 415 } 416 417 unsigned oldNumReservedCols = numReservedCols; 418 419 // Check if a resize is necessary. 420 if (getNumCols() + 1 > numReservedCols) { 421 equalities.resize(getNumEqualities() * (getNumCols() + 1)); 422 inequalities.resize(getNumInequalities() * (getNumCols() + 1)); 423 numReservedCols++; 424 } 425 426 int absolutePos; 427 428 if (kind == IdKind::Dimension) { 429 absolutePos = pos; 430 numDims++; 431 } else if (kind == IdKind::Symbol) { 432 absolutePos = pos + getNumDimIds(); 433 numSymbols++; 434 } else { 435 absolutePos = pos + getNumDimIds() + getNumSymbolIds(); 436 } 437 numIds++; 438 439 // Note that getNumCols() now will already return the new size, which will be 440 // at least one. 
441 int numInequalities = static_cast<int>(getNumInequalities()); 442 int numEqualities = static_cast<int>(getNumEqualities()); 443 int numCols = static_cast<int>(getNumCols()); 444 for (int r = numInequalities - 1; r >= 0; r--) { 445 for (int c = numCols - 2; c >= 0; c--) { 446 if (c < absolutePos) 447 atIneq(r, c) = inequalities[r * oldNumReservedCols + c]; 448 else 449 atIneq(r, c + 1) = inequalities[r * oldNumReservedCols + c]; 450 } 451 atIneq(r, absolutePos) = 0; 452 } 453 454 for (int r = numEqualities - 1; r >= 0; r--) { 455 for (int c = numCols - 2; c >= 0; c--) { 456 // All values in column absolutePositions < absolutePos have the same 457 // coordinates in the 2-d view of the coefficient buffer. 458 if (c < absolutePos) 459 atEq(r, c) = equalities[r * oldNumReservedCols + c]; 460 else 461 // Those at absolutePosition >= absolutePos, get a shifted 462 // absolutePosition. 463 atEq(r, c + 1) = equalities[r * oldNumReservedCols + c]; 464 } 465 // Initialize added dimension to zero. 466 atEq(r, absolutePos) = 0; 467 } 468 469 // If an 'id' is provided, insert it; otherwise use None. 470 if (id) { 471 ids.insert(ids.begin() + absolutePos, id); 472 } else { 473 ids.insert(ids.begin() + absolutePos, None); 474 } 475 assert(ids.size() == getNumIds()); 476 } 477 478 /// Checks if two constraint systems are in the same space, i.e., if they are 479 /// associated with the same set of identifiers, appearing in the same order. 480 static bool areIdsAligned(const FlatAffineConstraints &A, 481 const FlatAffineConstraints &B) { 482 return A.getNumDimIds() == B.getNumDimIds() && 483 A.getNumSymbolIds() == B.getNumSymbolIds() && 484 A.getNumIds() == B.getNumIds() && A.getIds().equals(B.getIds()); 485 } 486 487 /// Calls areIdsAligned to check if two constraint systems have the same set 488 /// of identifiers in the same order. 
489 bool FlatAffineConstraints::areIdsAlignedWithOther( 490 const FlatAffineConstraints &other) { 491 return areIdsAligned(*this, other); 492 } 493 494 /// Checks if the SSA values associated with `cst''s identifiers are unique. 495 static bool LLVM_ATTRIBUTE_UNUSED 496 areIdsUnique(const FlatAffineConstraints &cst) { 497 SmallPtrSet<Value *, 8> uniqueIds; 498 for (auto id : cst.getIds()) { 499 if (id.hasValue() && !uniqueIds.insert(id.getValue()).second) 500 return false; 501 } 502 return true; 503 } 504 505 // Swap the posA^th identifier with the posB^th identifier. 506 static void swapId(FlatAffineConstraints *A, unsigned posA, unsigned posB) { 507 assert(posA < A->getNumIds() && "invalid position A"); 508 assert(posB < A->getNumIds() && "invalid position B"); 509 510 if (posA == posB) 511 return; 512 513 for (unsigned r = 0, e = A->getNumInequalities(); r < e; r++) { 514 std::swap(A->atIneq(r, posA), A->atIneq(r, posB)); 515 } 516 for (unsigned r = 0, e = A->getNumEqualities(); r < e; r++) { 517 std::swap(A->atEq(r, posA), A->atEq(r, posB)); 518 } 519 std::swap(A->getId(posA), A->getId(posB)); 520 } 521 522 /// Merge and align the identifiers of A and B starting at 'offset', so that 523 /// both constraint systems get the union of the contained identifiers that is 524 /// dimension-wise and symbol-wise unique; both constraint systems are updated 525 /// so that they have the union of all identifiers, with A's original 526 /// identifiers appearing first followed by any of B's identifiers that didn't 527 /// appear in A. Local identifiers of each system are by design separate/local 528 /// and are placed one after other (A's followed by B's). 
529 // Eg: Input: A has ((%i %j) [%M %N]) and B has (%k, %j) [%P, %N, %M]) 530 // Output: both A, B have (%i, %j, %k) [%M, %N, %P] 531 // 532 static void mergeAndAlignIds(unsigned offset, FlatAffineConstraints *A, 533 FlatAffineConstraints *B) { 534 assert(offset <= A->getNumDimIds() && offset <= B->getNumDimIds()); 535 // A merge/align isn't meaningful if a cst's ids aren't distinct. 536 assert(areIdsUnique(*A) && "A's id values aren't unique"); 537 assert(areIdsUnique(*B) && "B's id values aren't unique"); 538 539 assert(std::all_of(A->getIds().begin() + offset, 540 A->getIds().begin() + A->getNumDimAndSymbolIds(), 541 [](Optional<Value *> id) { return id.hasValue(); })); 542 543 assert(std::all_of(B->getIds().begin() + offset, 544 B->getIds().begin() + B->getNumDimAndSymbolIds(), 545 [](Optional<Value *> id) { return id.hasValue(); })); 546 547 // Place local id's of A after local id's of B. 548 for (unsigned l = 0, e = A->getNumLocalIds(); l < e; l++) { 549 B->addLocalId(0); 550 } 551 for (unsigned t = 0, e = B->getNumLocalIds() - A->getNumLocalIds(); t < e; 552 t++) { 553 A->addLocalId(A->getNumLocalIds()); 554 } 555 556 SmallVector<Value *, 4> aDimValues, aSymValues; 557 A->getIdValues(offset, A->getNumDimIds(), &aDimValues); 558 A->getIdValues(A->getNumDimIds(), A->getNumDimAndSymbolIds(), &aSymValues); 559 { 560 // Merge dims from A into B. 561 unsigned d = offset; 562 for (auto *aDimValue : aDimValues) { 563 unsigned loc; 564 if (B->findId(*aDimValue, &loc)) { 565 assert(loc >= offset && "A's dim appears in B's aligned range"); 566 assert(loc < B->getNumDimIds() && 567 "A's dim appears in B's non-dim position"); 568 swapId(B, d, loc); 569 } else { 570 B->addDimId(d); 571 B->setIdValue(d, aDimValue); 572 } 573 d++; 574 } 575 576 // Dimensions that are in B, but not in A, are added at the end. 
577 for (unsigned t = A->getNumDimIds(), e = B->getNumDimIds(); t < e; t++) { 578 A->addDimId(A->getNumDimIds()); 579 A->setIdValue(A->getNumDimIds() - 1, B->getIdValue(t)); 580 } 581 } 582 { 583 // Merge symbols: merge A's symbols into B first. 584 unsigned s = B->getNumDimIds(); 585 for (auto *aSymValue : aSymValues) { 586 unsigned loc; 587 if (B->findId(*aSymValue, &loc)) { 588 assert(loc >= B->getNumDimIds() && loc < B->getNumDimAndSymbolIds() && 589 "A's symbol appears in B's non-symbol position"); 590 swapId(B, s, loc); 591 } else { 592 B->addSymbolId(s - B->getNumDimIds()); 593 B->setIdValue(s, aSymValue); 594 } 595 s++; 596 } 597 // Symbols that are in B, but not in A, are added at the end. 598 for (unsigned t = A->getNumDimAndSymbolIds(), 599 e = B->getNumDimAndSymbolIds(); 600 t < e; t++) { 601 A->addSymbolId(A->getNumSymbolIds()); 602 A->setIdValue(A->getNumDimAndSymbolIds() - 1, B->getIdValue(t)); 603 } 604 } 605 assert(areIdsAligned(*A, *B) && "IDs expected to be aligned"); 606 } 607 608 // Call 'mergeAndAlignIds' to align constraint systems of 'this' and 'other'. 609 void FlatAffineConstraints::mergeAndAlignIdsWithOther( 610 unsigned offset, FlatAffineConstraints *other) { 611 mergeAndAlignIds(offset, this, other); 612 } 613 614 // This routine may add additional local variables if the flattened expression 615 // corresponding to the map has such variables due to mod's, ceildiv's, and 616 // floordiv's in it. 617 LogicalResult FlatAffineConstraints::composeMap(AffineValueMap *vMap) { 618 std::vector<SmallVector<int64_t, 8>> flatExprs; 619 FlatAffineConstraints localCst; 620 if (failed(getFlattenedAffineExprs(vMap->getAffineMap(), &flatExprs, 621 &localCst))) { 622 LLVM_DEBUG(llvm::dbgs() 623 << "composition unimplemented for semi-affine maps\n"); 624 return failure(); 625 } 626 assert(flatExprs.size() == vMap->getNumResults()); 627 628 // Add localCst information. 
629 if (localCst.getNumLocalIds() > 0) { 630 SmallVector<Value *, 8> values(vMap->getOperands().begin(), 631 vMap->getOperands().end()); 632 localCst.setIdValues(0, localCst.getNumDimAndSymbolIds(), values); 633 // Align localCst and this. 634 mergeAndAlignIds(/*offset=*/0, &localCst, this); 635 // Finally, append localCst to this constraint set. 636 append(localCst); 637 } 638 639 // Add dimensions corresponding to the map's results. 640 for (unsigned t = 0, e = vMap->getNumResults(); t < e; t++) { 641 // TODO: Consider using a batched version to add a range of IDs. 642 addDimId(0); 643 } 644 645 // We add one equality for each result connecting the result dim of the map to 646 // the other identifiers. 647 // For eg: if the expression is 16*i0 + i1, and this is the r^th 648 // iteration/result of the value map, we are adding the equality: 649 // d_r - 16*i0 - i1 = 0. Hence, when flattening say (i0 + 1, i0 + 8*i2), we 650 // add two equalities overall: d_0 - i0 - 1 == 0, d1 - i0 - 8*i2 == 0. 651 for (unsigned r = 0, e = flatExprs.size(); r < e; r++) { 652 const auto &flatExpr = flatExprs[r]; 653 assert(flatExpr.size() >= vMap->getNumOperands() + 1); 654 655 // eqToAdd is the equality corresponding to the flattened affine expression. 656 SmallVector<int64_t, 8> eqToAdd(getNumCols(), 0); 657 // Set the coefficient for this result to one. 658 eqToAdd[r] = 1; 659 660 // Dims and symbols. 661 for (unsigned i = 0, e = vMap->getNumOperands(); i < e; i++) { 662 unsigned loc; 663 bool ret = findId(*vMap->getOperand(i), &loc); 664 assert(ret && "value map's id can't be found"); 665 (void)ret; 666 // Negate 'eq[r]' since the newly added dimension will be set to this one. 667 eqToAdd[loc] = -flatExpr[i]; 668 } 669 // Local vars common to eq and localCst are at the beginning. 
670 unsigned j = getNumDimIds() + getNumSymbolIds(); 671 unsigned end = flatExpr.size() - 1; 672 for (unsigned i = vMap->getNumOperands(); i < end; i++, j++) { 673 eqToAdd[j] = -flatExpr[i]; 674 } 675 676 // Constant term. 677 eqToAdd[getNumCols() - 1] = -flatExpr[flatExpr.size() - 1]; 678 679 // Add the equality connecting the result of the map to this constraint set. 680 addEquality(eqToAdd); 681 } 682 683 return success(); 684 } 685 686 // Turn a dimension into a symbol. 687 static void turnDimIntoSymbol(FlatAffineConstraints *cst, Value &id) { 688 unsigned pos; 689 if (cst->findId(id, &pos) && pos < cst->getNumDimIds()) { 690 swapId(cst, pos, cst->getNumDimIds() - 1); 691 cst->setDimSymbolSeparation(cst->getNumSymbolIds() + 1); 692 } 693 } 694 695 // Turn a symbol into a dimension. 696 static void turnSymbolIntoDim(FlatAffineConstraints *cst, Value &id) { 697 unsigned pos; 698 if (cst->findId(id, &pos) && pos >= cst->getNumDimIds() && 699 pos < cst->getNumDimAndSymbolIds()) { 700 swapId(cst, pos, cst->getNumDimIds()); 701 cst->setDimSymbolSeparation(cst->getNumSymbolIds() - 1); 702 } 703 } 704 705 // Changes all symbol identifiers which are loop IVs to dim identifiers. 706 void FlatAffineConstraints::convertLoopIVSymbolsToDims() { 707 // Gather all symbols which are loop IVs. 708 SmallVector<Value *, 4> loopIVs; 709 for (unsigned i = getNumDimIds(), e = getNumDimAndSymbolIds(); i < e; i++) { 710 if (ids[i].hasValue() && getForInductionVarOwner(ids[i].getValue())) 711 loopIVs.push_back(ids[i].getValue()); 712 } 713 // Turn each symbol in 'loopIVs' into a dim identifier. 714 for (auto *iv : loopIVs) { 715 turnSymbolIntoDim(this, *iv); 716 } 717 } 718 719 void FlatAffineConstraints::addInductionVarOrTerminalSymbol(Value *id) { 720 if (containsId(*id)) 721 return; 722 723 // Caller is expected to fully compose map/operands if necessary. 
724 assert((isTopLevelSymbol(id) || isForInductionVar(id)) && 725 "non-terminal symbol / loop IV expected"); 726 // Outer loop IVs could be used in forOp's bounds. 727 if (auto loop = getForInductionVarOwner(id)) { 728 addDimId(getNumDimIds(), id); 729 if (failed(this->addAffineForOpDomain(loop))) 730 LLVM_DEBUG( 731 loop.emitWarning("failed to add domain info to constraint system")); 732 return; 733 } 734 // Add top level symbol. 735 addSymbolId(getNumSymbolIds(), id); 736 // Check if the symbol is a constant. 737 if (auto constOp = dyn_cast_or_null<ConstantIndexOp>(id->getDefiningOp())) 738 setIdToConstant(*id, constOp.getValue()); 739 } 740 741 LogicalResult FlatAffineConstraints::addAffineForOpDomain(AffineForOp forOp) { 742 unsigned pos; 743 // Pre-condition for this method. 744 if (!findId(*forOp.getInductionVar(), &pos)) { 745 assert(false && "Value not found"); 746 return failure(); 747 } 748 749 int64_t step = forOp.getStep(); 750 if (step != 1) { 751 if (!forOp.hasConstantLowerBound()) 752 forOp.emitWarning("domain conservatively approximated"); 753 else { 754 // Add constraints for the stride. 755 // (iv - lb) % step = 0 can be written as: 756 // (iv - lb) - step * q = 0 where q = (iv - lb) / step. 757 // Add local variable 'q' and add the above equality. 758 // The first constraint is q = (iv - lb) floordiv step 759 SmallVector<int64_t, 8> dividend(getNumCols(), 0); 760 int64_t lb = forOp.getConstantLowerBound(); 761 dividend[pos] = 1; 762 dividend.back() -= lb; 763 addLocalFloorDiv(dividend, step); 764 // Second constraint: (iv - lb) - step * q = 0. 765 SmallVector<int64_t, 8> eq(getNumCols(), 0); 766 eq[pos] = 1; 767 eq.back() -= lb; 768 // For the local var just added above. 769 eq[getNumCols() - 2] = -step; 770 addEquality(eq); 771 } 772 } 773 774 if (forOp.hasConstantLowerBound()) { 775 addConstantLowerBound(pos, forOp.getConstantLowerBound()); 776 } else { 777 // Non-constant lower bound case. 
778 SmallVector<Value *, 4> lbOperands(forOp.getLowerBoundOperands().begin(), 779 forOp.getLowerBoundOperands().end()); 780 if (failed(addLowerOrUpperBound(pos, forOp.getLowerBoundMap(), lbOperands, 781 /*eq=*/false, /*lower=*/true))) 782 return failure(); 783 } 784 785 if (forOp.hasConstantUpperBound()) { 786 addConstantUpperBound(pos, forOp.getConstantUpperBound() - 1); 787 return success(); 788 } 789 // Non-constant upper bound case. 790 SmallVector<Value *, 4> ubOperands(forOp.getUpperBoundOperands().begin(), 791 forOp.getUpperBoundOperands().end()); 792 return addLowerOrUpperBound(pos, forOp.getUpperBoundMap(), ubOperands, 793 /*eq=*/false, /*lower=*/false); 794 } 795 796 // Searches for a constraint with a non-zero coefficient at 'colIdx' in 797 // equality (isEq=true) or inequality (isEq=false) constraints. 798 // Returns true and sets row found in search in 'rowIdx'. 799 // Returns false otherwise. 800 static bool 801 findConstraintWithNonZeroAt(const FlatAffineConstraints &constraints, 802 unsigned colIdx, bool isEq, unsigned *rowIdx) { 803 auto at = [&](unsigned rowIdx) -> int64_t { 804 return isEq ? constraints.atEq(rowIdx, colIdx) 805 : constraints.atIneq(rowIdx, colIdx); 806 }; 807 unsigned e = 808 isEq ? constraints.getNumEqualities() : constraints.getNumInequalities(); 809 for (*rowIdx = 0; *rowIdx < e; ++(*rowIdx)) { 810 if (at(*rowIdx) != 0) { 811 return true; 812 } 813 } 814 return false; 815 } 816 817 // Normalizes the coefficient values across all columns in 'rowIDx' by their 818 // GCD in equality or inequality contraints as specified by 'isEq'. 819 template <bool isEq> 820 static void normalizeConstraintByGCD(FlatAffineConstraints *constraints, 821 unsigned rowIdx) { 822 auto at = [&](unsigned colIdx) -> int64_t { 823 return isEq ? 
constraints->atEq(rowIdx, colIdx) 824 : constraints->atIneq(rowIdx, colIdx); 825 }; 826 uint64_t gcd = std::abs(at(0)); 827 for (unsigned j = 1, e = constraints->getNumCols(); j < e; ++j) { 828 gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(at(j))); 829 } 830 if (gcd > 0 && gcd != 1) { 831 for (unsigned j = 0, e = constraints->getNumCols(); j < e; ++j) { 832 int64_t v = at(j) / static_cast<int64_t>(gcd); 833 isEq ? constraints->atEq(rowIdx, j) = v 834 : constraints->atIneq(rowIdx, j) = v; 835 } 836 } 837 } 838 839 void FlatAffineConstraints::normalizeConstraintsByGCD() { 840 for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) { 841 normalizeConstraintByGCD</*isEq=*/true>(this, i); 842 } 843 for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) { 844 normalizeConstraintByGCD</*isEq=*/false>(this, i); 845 } 846 } 847 848 bool FlatAffineConstraints::hasConsistentState() const { 849 if (inequalities.size() != getNumInequalities() * numReservedCols) 850 return false; 851 if (equalities.size() != getNumEqualities() * numReservedCols) 852 return false; 853 if (ids.size() != getNumIds()) 854 return false; 855 856 // Catches errors where numDims, numSymbols, numIds aren't consistent. 857 if (numDims > numIds || numSymbols > numIds || numDims + numSymbols > numIds) 858 return false; 859 860 return true; 861 } 862 863 /// Checks all rows of equality/inequality constraints for trivial 864 /// contradictions (for example: 1 == 0, 0 >= 1), which may have surfaced 865 /// after elimination. Returns 'true' if an invalid constraint is found; 866 /// 'false' otherwise. 867 bool FlatAffineConstraints::hasInvalidConstraint() const { 868 assert(hasConsistentState()); 869 auto check = [&](bool isEq) -> bool { 870 unsigned numCols = getNumCols(); 871 unsigned numRows = isEq ? getNumEqualities() : getNumInequalities(); 872 for (unsigned i = 0, e = numRows; i < e; ++i) { 873 unsigned j; 874 for (j = 0; j < numCols - 1; ++j) { 875 int64_t v = isEq ? 
atEq(i, j) : atIneq(i, j); 876 // Skip rows with non-zero variable coefficients. 877 if (v != 0) 878 break; 879 } 880 if (j < numCols - 1) { 881 continue; 882 } 883 // Check validity of constant term at 'numCols - 1' w.r.t 'isEq'. 884 // Example invalid constraints include: '1 == 0' or '-1 >= 0' 885 int64_t v = isEq ? atEq(i, numCols - 1) : atIneq(i, numCols - 1); 886 if ((isEq && v != 0) || (!isEq && v < 0)) { 887 return true; 888 } 889 } 890 return false; 891 }; 892 if (check(/*isEq=*/true)) 893 return true; 894 return check(/*isEq=*/false); 895 } 896 897 // Eliminate identifier from constraint at 'rowIdx' based on coefficient at 898 // pivotRow, pivotCol. Columns in range [elimColStart, pivotCol) will not be 899 // updated as they have already been eliminated. 900 static void eliminateFromConstraint(FlatAffineConstraints *constraints, 901 unsigned rowIdx, unsigned pivotRow, 902 unsigned pivotCol, unsigned elimColStart, 903 bool isEq) { 904 // Skip if equality 'rowIdx' if same as 'pivotRow'. 905 if (isEq && rowIdx == pivotRow) 906 return; 907 auto at = [&](unsigned i, unsigned j) -> int64_t { 908 return isEq ? constraints->atEq(i, j) : constraints->atIneq(i, j); 909 }; 910 int64_t leadCoeff = at(rowIdx, pivotCol); 911 // Skip if leading coefficient at 'rowIdx' is already zero. 912 if (leadCoeff == 0) 913 return; 914 int64_t pivotCoeff = constraints->atEq(pivotRow, pivotCol); 915 int64_t sign = (leadCoeff * pivotCoeff > 0) ? -1 : 1; 916 int64_t lcm = mlir::lcm(pivotCoeff, leadCoeff); 917 int64_t pivotMultiplier = sign * (lcm / std::abs(pivotCoeff)); 918 int64_t rowMultiplier = lcm / std::abs(leadCoeff); 919 920 unsigned numCols = constraints->getNumCols(); 921 for (unsigned j = 0; j < numCols; ++j) { 922 // Skip updating column 'j' if it was just eliminated. 923 if (j >= elimColStart && j < pivotCol) 924 continue; 925 int64_t v = pivotMultiplier * constraints->atEq(pivotRow, j) + 926 rowMultiplier * at(rowIdx, j); 927 isEq ? 
constraints->atEq(rowIdx, j) = v 928 : constraints->atIneq(rowIdx, j) = v; 929 } 930 } 931 932 // Remove coefficients in column range [colStart, colLimit) in place. 933 // This removes in data in the specified column range, and copies any 934 // remaining valid data into place. 935 static void shiftColumnsToLeft(FlatAffineConstraints *constraints, 936 unsigned colStart, unsigned colLimit, 937 bool isEq) { 938 assert(colLimit <= constraints->getNumIds()); 939 if (colLimit <= colStart) 940 return; 941 942 unsigned numCols = constraints->getNumCols(); 943 unsigned numRows = isEq ? constraints->getNumEqualities() 944 : constraints->getNumInequalities(); 945 unsigned numToEliminate = colLimit - colStart; 946 for (unsigned r = 0, e = numRows; r < e; ++r) { 947 for (unsigned c = colLimit; c < numCols; ++c) { 948 if (isEq) { 949 constraints->atEq(r, c - numToEliminate) = constraints->atEq(r, c); 950 } else { 951 constraints->atIneq(r, c - numToEliminate) = constraints->atIneq(r, c); 952 } 953 } 954 } 955 } 956 957 // Removes identifiers in column range [idStart, idLimit), and copies any 958 // remaining valid data into place, and updates member variables. 959 void FlatAffineConstraints::removeIdRange(unsigned idStart, unsigned idLimit) { 960 assert(idLimit < getNumCols() && "invalid id limit"); 961 962 if (idStart >= idLimit) 963 return; 964 965 // We are going to be removing one or more identifiers from the range. 966 assert(idStart < numIds && "invalid idStart position"); 967 968 // TODO(andydavis) Make 'removeIdRange' a lambda called from here. 969 // Remove eliminated identifiers from equalities. 970 shiftColumnsToLeft(this, idStart, idLimit, /*isEq=*/true); 971 972 // Remove eliminated identifiers from inequalities. 973 shiftColumnsToLeft(this, idStart, idLimit, /*isEq=*/false); 974 975 // Update members numDims, numSymbols and numIds. 
976 unsigned numDimsEliminated = 0; 977 unsigned numLocalsEliminated = 0; 978 unsigned numColsEliminated = idLimit - idStart; 979 if (idStart < numDims) { 980 numDimsEliminated = std::min(numDims, idLimit) - idStart; 981 } 982 // Check how many local id's were removed. Note that our identifier order is 983 // [dims, symbols, locals]. Local id start at position numDims + numSymbols. 984 if (idLimit > numDims + numSymbols) { 985 numLocalsEliminated = std::min( 986 idLimit - std::max(idStart, numDims + numSymbols), getNumLocalIds()); 987 } 988 unsigned numSymbolsEliminated = 989 numColsEliminated - numDimsEliminated - numLocalsEliminated; 990 991 numDims -= numDimsEliminated; 992 numSymbols -= numSymbolsEliminated; 993 numIds = numIds - numColsEliminated; 994 995 ids.erase(ids.begin() + idStart, ids.begin() + idLimit); 996 997 // No resize necessary. numReservedCols remains the same. 998 } 999 1000 /// Returns the position of the identifier that has the minimum <number of lower 1001 /// bounds> times <number of upper bounds> from the specified range of 1002 /// identifiers [start, end). It is often best to eliminate in the increasing 1003 /// order of these counts when doing Fourier-Motzkin elimination since FM adds 1004 /// that many new constraints. 
1005 static unsigned getBestIdToEliminate(const FlatAffineConstraints &cst, 1006 unsigned start, unsigned end) { 1007 assert(start < cst.getNumIds() && end < cst.getNumIds() + 1); 1008 1009 auto getProductOfNumLowerUpperBounds = [&](unsigned pos) { 1010 unsigned numLb = 0; 1011 unsigned numUb = 0; 1012 for (unsigned r = 0, e = cst.getNumInequalities(); r < e; r++) { 1013 if (cst.atIneq(r, pos) > 0) { 1014 ++numLb; 1015 } else if (cst.atIneq(r, pos) < 0) { 1016 ++numUb; 1017 } 1018 } 1019 return numLb * numUb; 1020 }; 1021 1022 unsigned minLoc = start; 1023 unsigned min = getProductOfNumLowerUpperBounds(start); 1024 for (unsigned c = start + 1; c < end; c++) { 1025 unsigned numLbUbProduct = getProductOfNumLowerUpperBounds(c); 1026 if (numLbUbProduct < min) { 1027 min = numLbUbProduct; 1028 minLoc = c; 1029 } 1030 } 1031 return minLoc; 1032 } 1033 1034 // Checks for emptiness of the set by eliminating identifiers successively and 1035 // using the GCD test (on all equality constraints) and checking for trivially 1036 // invalid constraints. Returns 'true' if the constraint system is found to be 1037 // empty; false otherwise. 1038 bool FlatAffineConstraints::isEmpty() const { 1039 if (isEmptyByGCDTest() || hasInvalidConstraint()) 1040 return true; 1041 1042 // First, eliminate as many identifiers as possible using Gaussian 1043 // elimination. 1044 FlatAffineConstraints tmpCst(*this); 1045 unsigned currentPos = 0; 1046 while (currentPos < tmpCst.getNumIds()) { 1047 tmpCst.gaussianEliminateIds(currentPos, tmpCst.getNumIds()); 1048 ++currentPos; 1049 // We check emptiness through trivial checks after eliminating each ID to 1050 // detect emptiness early. Since the checks isEmptyByGCDTest() and 1051 // hasInvalidConstraint() are linear time and single sweep on the constraint 1052 // buffer, this appears reasonable - but can optimize in the future. 
1053 if (tmpCst.hasInvalidConstraint() || tmpCst.isEmptyByGCDTest()) 1054 return true; 1055 } 1056 1057 // Eliminate the remaining using FM. 1058 for (unsigned i = 0, e = tmpCst.getNumIds(); i < e; i++) { 1059 tmpCst.FourierMotzkinEliminate( 1060 getBestIdToEliminate(tmpCst, 0, tmpCst.getNumIds())); 1061 // Check for a constraint explosion. This rarely happens in practice, but 1062 // this check exists as a safeguard against improperly constructed 1063 // constraint systems or artifically created arbitrarily complex systems 1064 // that aren't the intended use case for FlatAffineConstraints. This is 1065 // needed since FM has a worst case exponential complexity in theory. 1066 if (tmpCst.getNumConstraints() >= kExplosionFactor * getNumIds()) { 1067 LLVM_DEBUG(llvm::dbgs() << "FM constraint explosion detected\n"); 1068 return false; 1069 } 1070 1071 // FM wouldn't have modified the equalities in any way. So no need to again 1072 // run GCD test. Check for trivial invalid constraints. 1073 if (tmpCst.hasInvalidConstraint()) 1074 return true; 1075 } 1076 return false; 1077 } 1078 1079 // Runs the GCD test on all equality constraints. Returns 'true' if this test 1080 // fails on any equality. Returns 'false' otherwise. 1081 // This test can be used to disprove the existence of a solution. If it returns 1082 // true, no integer solution to the equality constraints can exist. 1083 // 1084 // GCD test definition: 1085 // 1086 // The equality constraint: 1087 // 1088 // c_1*x_1 + c_2*x_2 + ... + c_n*x_n = c_0 1089 // 1090 // has an integer solution iff: 1091 // 1092 // GCD of c_1, c_2, ..., c_n divides c_0. 
1093 // 1094 bool FlatAffineConstraints::isEmptyByGCDTest() const { 1095 assert(hasConsistentState()); 1096 unsigned numCols = getNumCols(); 1097 for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) { 1098 uint64_t gcd = std::abs(atEq(i, 0)); 1099 for (unsigned j = 1; j < numCols - 1; ++j) { 1100 gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(atEq(i, j))); 1101 } 1102 int64_t v = std::abs(atEq(i, numCols - 1)); 1103 if (gcd > 0 && (v % gcd != 0)) { 1104 return true; 1105 } 1106 } 1107 return false; 1108 } 1109 1110 /// Tightens inequalities given that we are dealing with integer spaces. This is 1111 /// analogous to the GCD test but applied to inequalities. The constant term can 1112 /// be reduced to the preceding multiple of the GCD of the coefficients, i.e., 1113 /// 64*i - 100 >= 0 => 64*i - 128 >= 0 (since 'i' is an integer). This is a 1114 /// fast method - linear in the number of coefficients. 1115 // Example on how this affects practical cases: consider the scenario: 1116 // 64*i >= 100, j = 64*i; without a tightening, elimination of i would yield 1117 // j >= 100 instead of the tighter (exact) j >= 128. 1118 void FlatAffineConstraints::GCDTightenInequalities() { 1119 unsigned numCols = getNumCols(); 1120 for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) { 1121 uint64_t gcd = std::abs(atIneq(i, 0)); 1122 for (unsigned j = 1; j < numCols - 1; ++j) { 1123 gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(atIneq(i, j))); 1124 } 1125 if (gcd > 0 && gcd != 1) { 1126 int64_t gcdI = static_cast<int64_t>(gcd); 1127 // Tighten the constant term and normalize the constraint by the GCD. 1128 atIneq(i, numCols - 1) = mlir::floorDiv(atIneq(i, numCols - 1), gcdI); 1129 for (unsigned j = 0, e = numCols - 1; j < e; ++j) 1130 atIneq(i, j) /= gcdI; 1131 } 1132 } 1133 } 1134 1135 // Eliminates all identifer variables in column range [posStart, posLimit). 1136 // Returns the number of variables eliminated. 
1137 unsigned FlatAffineConstraints::gaussianEliminateIds(unsigned posStart, 1138 unsigned posLimit) { 1139 // Return if identifier positions to eliminate are out of range. 1140 assert(posLimit <= numIds); 1141 assert(hasConsistentState()); 1142 1143 if (posStart >= posLimit) 1144 return 0; 1145 1146 GCDTightenInequalities(); 1147 1148 unsigned pivotCol = 0; 1149 for (pivotCol = posStart; pivotCol < posLimit; ++pivotCol) { 1150 // Find a row which has a non-zero coefficient in column 'j'. 1151 unsigned pivotRow; 1152 if (!findConstraintWithNonZeroAt(*this, pivotCol, /*isEq=*/true, 1153 &pivotRow)) { 1154 // No pivot row in equalities with non-zero at 'pivotCol'. 1155 if (!findConstraintWithNonZeroAt(*this, pivotCol, /*isEq=*/false, 1156 &pivotRow)) { 1157 // If inequalities are also non-zero in 'pivotCol', it can be 1158 // eliminated. 1159 continue; 1160 } 1161 break; 1162 } 1163 1164 // Eliminate identifier at 'pivotCol' from each equality row. 1165 for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) { 1166 eliminateFromConstraint(this, i, pivotRow, pivotCol, posStart, 1167 /*isEq=*/true); 1168 normalizeConstraintByGCD</*isEq=*/true>(this, i); 1169 } 1170 1171 // Eliminate identifier at 'pivotCol' from each inequality row. 1172 for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) { 1173 eliminateFromConstraint(this, i, pivotRow, pivotCol, posStart, 1174 /*isEq=*/false); 1175 normalizeConstraintByGCD</*isEq=*/false>(this, i); 1176 } 1177 removeEquality(pivotRow); 1178 GCDTightenInequalities(); 1179 } 1180 // Update position limit based on number eliminated. 1181 posLimit = pivotCol; 1182 // Remove eliminated columns from all constraints. 1183 removeIdRange(posStart, posLimit); 1184 return posLimit - posStart; 1185 } 1186 1187 // Detect the identifier at 'pos' (say id_r) as modulo of another identifier 1188 // (say id_n) w.r.t a constant. When this happens, another identifier (say id_q) 1189 // could be detected as the floordiv of n. 
For eg: 1190 // id_n - 4*id_q - id_r = 0, 0 <= id_r <= 3 <=> 1191 // id_r = id_n mod 4, id_q = id_n floordiv 4. 1192 // lbConst and ubConst are the constant lower and upper bounds for 'pos' - 1193 // pre-detected at the caller. 1194 static bool detectAsMod(const FlatAffineConstraints &cst, unsigned pos, 1195 int64_t lbConst, int64_t ubConst, 1196 SmallVectorImpl<AffineExpr> *memo) { 1197 assert(pos < cst.getNumIds() && "invalid position"); 1198 1199 // Check if 0 <= id_r <= divisor - 1 and if id_r is equal to 1200 // id_n - divisor * id_q. If these are true, then id_n becomes the dividend 1201 // and id_q the quotient when dividing id_n by the divisor. 1202 1203 if (lbConst != 0 || ubConst < 1) 1204 return false; 1205 1206 int64_t divisor = ubConst + 1; 1207 1208 // Now check for: id_r = id_n - divisor * id_q. As an example, we 1209 // are looking r = d - 4q, i.e., either r - d + 4q = 0 or -r + d - 4q = 0. 1210 unsigned seenQuotient = 0, seenDividend = 0; 1211 int quotientPos = -1, dividendPos = -1; 1212 for (unsigned r = 0, e = cst.getNumEqualities(); r < e; r++) { 1213 // id_n should have coeff 1 or -1. 1214 if (std::abs(cst.atEq(r, pos)) != 1) 1215 continue; 1216 // constant term should be 0. 1217 if (cst.atEq(r, cst.getNumCols() - 1) != 0) 1218 continue; 1219 unsigned c, f; 1220 int quotientSign = 1, dividendSign = 1; 1221 for (c = 0, f = cst.getNumDimAndSymbolIds(); c < f; c++) { 1222 if (c == pos) 1223 continue; 1224 // The coefficient of the quotient should be +/-divisor. 1225 // TODO(bondhugula): could be extended to detect an affine function for 1226 // the quotient (i.e., the coeff could be a non-zero multiple of divisor). 1227 int64_t v = cst.atEq(r, c) * cst.atEq(r, pos); 1228 if (v == divisor || v == -divisor) { 1229 seenQuotient++; 1230 quotientPos = c; 1231 quotientSign = v > 0 ? 1 : -1; 1232 } 1233 // The coefficient of the dividend should be +/-1. 
1234 // TODO(bondhugula): could be extended to detect an affine function of 1235 // the other identifiers as the dividend. 1236 else if (v == -1 || v == 1) { 1237 seenDividend++; 1238 dividendPos = c; 1239 dividendSign = v < 0 ? 1 : -1; 1240 } else if (cst.atEq(r, c) != 0) { 1241 // Cannot be inferred as a mod since the constraint has a coefficient 1242 // for an identifier that's neither a unit nor the divisor (see TODOs 1243 // above). 1244 break; 1245 } 1246 } 1247 if (c < f) 1248 // Cannot be inferred as a mod since the constraint has a coefficient for 1249 // an identifier that's neither a unit nor the divisor (see TODOs above). 1250 continue; 1251 1252 // We are looking for exactly one identifier as the dividend. 1253 if (seenDividend == 1 && seenQuotient >= 1) { 1254 if (!(*memo)[dividendPos]) 1255 return false; 1256 // Successfully detected a mod. 1257 (*memo)[pos] = (*memo)[dividendPos] % divisor * dividendSign; 1258 auto ub = cst.getConstantUpperBound(dividendPos); 1259 if (ub.hasValue() && ub.getValue() < divisor) 1260 // The mod can be optimized away. 1261 (*memo)[pos] = (*memo)[dividendPos] * dividendSign; 1262 else 1263 (*memo)[pos] = (*memo)[dividendPos] % divisor * dividendSign; 1264 1265 if (seenQuotient == 1 && !(*memo)[quotientPos]) 1266 // Successfully detected a floordiv as well. 1267 (*memo)[quotientPos] = 1268 (*memo)[dividendPos].floorDiv(divisor) * quotientSign; 1269 return true; 1270 } 1271 } 1272 return false; 1273 } 1274 1275 // Gather lower and upper bounds for the pos^th identifier. 1276 static void getLowerAndUpperBoundIndices(const FlatAffineConstraints &cst, 1277 unsigned pos, 1278 SmallVectorImpl<unsigned> *lbIndices, 1279 SmallVectorImpl<unsigned> *ubIndices) { 1280 assert(pos < cst.getNumIds() && "invalid position"); 1281 1282 // Gather all lower bounds and upper bounds of the variable. Since the 1283 // canonical form c_1*x_1 + c_2*x_2 + ... 
+ c_0 >= 0, a constraint is a lower 1284 // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1. 1285 for (unsigned r = 0, e = cst.getNumInequalities(); r < e; r++) { 1286 if (cst.atIneq(r, pos) >= 1) { 1287 // Lower bound. 1288 lbIndices->push_back(r); 1289 } else if (cst.atIneq(r, pos) <= -1) { 1290 // Upper bound. 1291 ubIndices->push_back(r); 1292 } 1293 } 1294 } 1295 1296 // Check if the pos^th identifier can be expressed as a floordiv of an affine 1297 // function of other identifiers (where the divisor is a positive constant). 1298 // For eg: 4q <= i + j <= 4q + 3 <=> q = (i + j) floordiv 4. 1299 bool detectAsFloorDiv(const FlatAffineConstraints &cst, unsigned pos, 1300 SmallVectorImpl<AffineExpr> *memo, MLIRContext *context) { 1301 assert(pos < cst.getNumIds() && "invalid position"); 1302 1303 SmallVector<unsigned, 4> lbIndices, ubIndices; 1304 getLowerAndUpperBoundIndices(cst, pos, &lbIndices, &ubIndices); 1305 1306 // Check if any lower bound, upper bound pair is of the form: 1307 // divisor * id >= expr - (divisor - 1) <-- Lower bound for 'id' 1308 // divisor * id <= expr <-- Upper bound for 'id' 1309 // Then, 'id' is equivalent to 'expr floordiv divisor'. (where divisor > 1). 1310 // 1311 // For example, if -32*k + 16*i + j >= 0 1312 // 32*k - 16*i - j + 31 >= 0 <=> 1313 // k = ( 16*i + j ) floordiv 32 1314 unsigned seenDividends = 0; 1315 for (auto ubPos : ubIndices) { 1316 for (auto lbPos : lbIndices) { 1317 // Check if lower bound's constant term is 'divisor - 1'. The 'divisor' 1318 // here is cst.atIneq(lbPos, pos) and we already know that it's positive 1319 // (since cst.Ineq(lbPos, ...) is a lower bound expression for 'pos'. 1320 if (cst.atIneq(lbPos, cst.getNumCols() - 1) != cst.atIneq(lbPos, pos) - 1) 1321 continue; 1322 // Check if upper bound's constant term is 0. 
1323 if (cst.atIneq(ubPos, cst.getNumCols() - 1) != 0) 1324 continue; 1325 // For the remaining part, check if the lower bound expr's coeff's are 1326 // negations of corresponding upper bound ones'. 1327 unsigned c, f; 1328 for (c = 0, f = cst.getNumCols() - 1; c < f; c++) { 1329 if (cst.atIneq(lbPos, c) != -cst.atIneq(ubPos, c)) 1330 break; 1331 if (c != pos && cst.atIneq(lbPos, c) != 0) 1332 seenDividends++; 1333 } 1334 // Lb coeff's aren't negative of ub coeff's (for the non constant term 1335 // part). 1336 if (c < f) 1337 continue; 1338 if (seenDividends >= 1) { 1339 // The divisor is the constant term of the lower bound expression. 1340 // We already know that cst.atIneq(lbPos, pos) > 0. 1341 int64_t divisor = cst.atIneq(lbPos, pos); 1342 // Construct the dividend expression. 1343 auto dividendExpr = getAffineConstantExpr(0, context); 1344 unsigned c, f; 1345 for (c = 0, f = cst.getNumCols() - 1; c < f; c++) { 1346 if (c == pos) 1347 continue; 1348 int64_t ubVal = cst.atIneq(ubPos, c); 1349 if (ubVal == 0) 1350 continue; 1351 if (!(*memo)[c]) 1352 break; 1353 dividendExpr = dividendExpr + ubVal * (*memo)[c]; 1354 } 1355 // Expression can't be constructed as it depends on a yet unknown 1356 // identifier. 1357 // TODO(mlir-team): Visit/compute the identifiers in an order so that 1358 // this doesn't happen. More complex but much more efficient. 1359 if (c < f) 1360 continue; 1361 // Successfully detected the floordiv. 1362 (*memo)[pos] = dividendExpr.floorDiv(divisor); 1363 return true; 1364 } 1365 } 1366 } 1367 return false; 1368 } 1369 1370 // Fills an inequality row with the value 'val'. 1371 static inline void fillInequality(FlatAffineConstraints *cst, unsigned r, 1372 int64_t val) { 1373 for (unsigned c = 0, f = cst->getNumCols(); c < f; c++) { 1374 cst->atIneq(r, c) = val; 1375 } 1376 } 1377 1378 // Negates an inequality. 
1379 static inline void negateInequality(FlatAffineConstraints *cst, unsigned r) { 1380 for (unsigned c = 0, f = cst->getNumCols(); c < f; c++) { 1381 cst->atIneq(r, c) = -cst->atIneq(r, c); 1382 } 1383 } 1384 1385 // A more complex check to eliminate redundant inequalities. Uses FourierMotzkin 1386 // to check if a constraint is redundant. 1387 void FlatAffineConstraints::removeRedundantInequalities() { 1388 SmallVector<bool, 32> redun(getNumInequalities(), false); 1389 // To check if an inequality is redundant, we replace the inequality by its 1390 // complement (for eg., i - 1 >= 0 by i <= 0), and check if the resulting 1391 // system is empty. If it is, the inequality is redundant. 1392 FlatAffineConstraints tmpCst(*this); 1393 for (unsigned r = 0, e = getNumInequalities(); r < e; r++) { 1394 // Change the inequality to its complement. 1395 negateInequality(&tmpCst, r); 1396 tmpCst.atIneq(r, tmpCst.getNumCols() - 1)--; 1397 if (tmpCst.isEmpty()) { 1398 redun[r] = true; 1399 // Zero fill the redundant inequality. 1400 fillInequality(this, r, /*val=*/0); 1401 fillInequality(&tmpCst, r, /*val=*/0); 1402 } else { 1403 // Reverse the change (to avoid recreating tmpCst each time). 1404 tmpCst.atIneq(r, tmpCst.getNumCols() - 1)++; 1405 negateInequality(&tmpCst, r); 1406 } 1407 } 1408 1409 // Scan to get rid of all rows marked redundant, in-place. 
1410 auto copyRow = [&](unsigned src, unsigned dest) { 1411 if (src == dest) 1412 return; 1413 for (unsigned c = 0, e = getNumCols(); c < e; c++) { 1414 atIneq(dest, c) = atIneq(src, c); 1415 } 1416 }; 1417 unsigned pos = 0; 1418 for (unsigned r = 0, e = getNumInequalities(); r < e; r++) { 1419 if (!redun[r]) 1420 copyRow(r, pos++); 1421 } 1422 inequalities.resize(numReservedCols * pos); 1423 } 1424 1425 std::pair<AffineMap, AffineMap> FlatAffineConstraints::getLowerAndUpperBound( 1426 unsigned pos, unsigned offset, unsigned num, unsigned symStartPos, 1427 ArrayRef<AffineExpr> localExprs, MLIRContext *context) { 1428 assert(pos + offset < getNumDimIds() && "invalid dim start pos"); 1429 assert(symStartPos >= (pos + offset) && "invalid sym start pos"); 1430 assert(getNumLocalIds() == localExprs.size() && 1431 "incorrect local exprs count"); 1432 1433 SmallVector<unsigned, 4> lbIndices, ubIndices; 1434 getLowerAndUpperBoundIndices(*this, pos + offset, &lbIndices, &ubIndices); 1435 1436 /// Add to 'b' from 'a' in set [0, offset) U [offset + num, symbStartPos). 1437 auto addCoeffs = [&](ArrayRef<int64_t> a, SmallVectorImpl<int64_t> &b) { 1438 b.clear(); 1439 for (unsigned i = 0, e = a.size(); i < e; ++i) { 1440 if (i < offset || i >= offset + num) 1441 b.push_back(a[i]); 1442 } 1443 }; 1444 1445 SmallVector<int64_t, 8> lb, ub; 1446 SmallVector<AffineExpr, 4> exprs; 1447 unsigned dimCount = symStartPos - num; 1448 unsigned symCount = getNumDimAndSymbolIds() - symStartPos; 1449 exprs.reserve(lbIndices.size()); 1450 // Lower bound expressions. 1451 for (auto idx : lbIndices) { 1452 auto ineq = getInequality(idx); 1453 // Extract the lower bound (in terms of other coeff's + const), i.e., if 1454 // i - j + 1 >= 0 is the constraint, 'pos' is for i the lower bound is j 1455 // - 1. 
1456 addCoeffs(ineq, lb); 1457 std::transform(lb.begin(), lb.end(), lb.begin(), std::negate<int64_t>()); 1458 auto expr = mlir::toAffineExpr(lb, dimCount, symCount, localExprs, context); 1459 exprs.push_back(expr); 1460 } 1461 auto lbMap = 1462 exprs.empty() ? AffineMap() : AffineMap::get(dimCount, symCount, exprs); 1463 1464 exprs.clear(); 1465 exprs.reserve(ubIndices.size()); 1466 // Upper bound expressions. 1467 for (auto idx : ubIndices) { 1468 auto ineq = getInequality(idx); 1469 // Extract the upper bound (in terms of other coeff's + const). 1470 addCoeffs(ineq, ub); 1471 auto expr = mlir::toAffineExpr(ub, dimCount, symCount, localExprs, context); 1472 // Upper bound is exclusive. 1473 exprs.push_back(expr + 1); 1474 } 1475 auto ubMap = 1476 exprs.empty() ? AffineMap() : AffineMap::get(dimCount, symCount, exprs); 1477 1478 return {lbMap, ubMap}; 1479 } 1480 1481 /// Computes the lower and upper bounds of the first 'num' dimensional 1482 /// identifiers (starting at 'offset') as affine maps of the remaining 1483 /// identifiers (dimensional and symbolic identifiers). Local identifiers are 1484 /// themselves explicitly computed as affine functions of other identifiers in 1485 /// this process if needed. 1486 void FlatAffineConstraints::getSliceBounds(unsigned offset, unsigned num, 1487 MLIRContext *context, 1488 SmallVectorImpl<AffineMap> *lbMaps, 1489 SmallVectorImpl<AffineMap> *ubMaps) { 1490 assert(num < getNumDimIds() && "invalid range"); 1491 1492 // Basic simplification. 1493 normalizeConstraintsByGCD(); 1494 1495 LLVM_DEBUG(llvm::dbgs() << "getSliceBounds for first " << num 1496 << " identifiers\n"); 1497 LLVM_DEBUG(dump()); 1498 1499 // Record computed/detected identifiers. 1500 SmallVector<AffineExpr, 8> memo(getNumIds()); 1501 // Initialize dimensional and symbolic identifiers. 
1502 for (unsigned i = 0, e = getNumDimIds(); i < e; i++) { 1503 if (i < offset) 1504 memo[i] = getAffineDimExpr(i, context); 1505 else if (i >= offset + num) 1506 memo[i] = getAffineDimExpr(i - num, context); 1507 } 1508 for (unsigned i = getNumDimIds(), e = getNumDimAndSymbolIds(); i < e; i++) 1509 memo[i] = getAffineSymbolExpr(i - getNumDimIds(), context); 1510 1511 bool changed; 1512 do { 1513 changed = false; 1514 // Identify yet unknown identifiers as constants or mod's / floordiv's of 1515 // other identifiers if possible. 1516 for (unsigned pos = 0; pos < getNumIds(); pos++) { 1517 if (memo[pos]) 1518 continue; 1519 1520 auto lbConst = getConstantLowerBound(pos); 1521 auto ubConst = getConstantUpperBound(pos); 1522 if (lbConst.hasValue() && ubConst.hasValue()) { 1523 // Detect equality to a constant. 1524 if (lbConst.getValue() == ubConst.getValue()) { 1525 memo[pos] = getAffineConstantExpr(lbConst.getValue(), context); 1526 changed = true; 1527 continue; 1528 } 1529 1530 // Detect an identifier as modulo of another identifier w.r.t a 1531 // constant. 1532 if (detectAsMod(*this, pos, lbConst.getValue(), ubConst.getValue(), 1533 &memo)) { 1534 changed = true; 1535 continue; 1536 } 1537 } 1538 1539 // Detect an identifier as floordiv of another identifier w.r.t a 1540 // constant. 1541 if (detectAsFloorDiv(*this, pos, &memo, context)) { 1542 changed = true; 1543 continue; 1544 } 1545 1546 // Detect an identifier as an expression of other identifiers. 1547 unsigned idx; 1548 if (!findConstraintWithNonZeroAt(*this, pos, /*isEq=*/true, &idx)) { 1549 continue; 1550 } 1551 1552 // Build AffineExpr solving for identifier 'pos' in terms of all others. 1553 auto expr = getAffineConstantExpr(0, context); 1554 unsigned j, e; 1555 for (j = 0, e = getNumIds(); j < e; ++j) { 1556 if (j == pos) 1557 continue; 1558 int64_t c = atEq(idx, j); 1559 if (c == 0) 1560 continue; 1561 // If any of the involved IDs hasn't been found yet, we can't proceed. 
1562 if (!memo[j]) 1563 break; 1564 expr = expr + memo[j] * c; 1565 } 1566 if (j < e) 1567 // Can't construct expression as it depends on a yet uncomputed 1568 // identifier. 1569 continue; 1570 1571 // Add constant term to AffineExpr. 1572 expr = expr + atEq(idx, getNumIds()); 1573 int64_t vPos = atEq(idx, pos); 1574 assert(vPos != 0 && "expected non-zero here"); 1575 if (vPos > 0) 1576 expr = (-expr).floorDiv(vPos); 1577 else 1578 // vPos < 0. 1579 expr = expr.floorDiv(-vPos); 1580 // Successfully constructed expression. 1581 memo[pos] = expr; 1582 changed = true; 1583 } 1584 // This loop is guaranteed to reach a fixed point - since once an 1585 // identifier's explicit form is computed (in memo[pos]), it's not updated 1586 // again. 1587 } while (changed); 1588 1589 // Set the lower and upper bound maps for all the identifiers that were 1590 // computed as affine expressions of the rest as the "detected expr" and 1591 // "detected expr + 1" respectively; set the undetected ones to null. 1592 Optional<FlatAffineConstraints> tmpClone; 1593 for (unsigned pos = 0; pos < num; pos++) { 1594 unsigned numMapDims = getNumDimIds() - num; 1595 unsigned numMapSymbols = getNumSymbolIds(); 1596 AffineExpr expr = memo[pos + offset]; 1597 if (expr) 1598 expr = simplifyAffineExpr(expr, numMapDims, numMapSymbols); 1599 1600 AffineMap &lbMap = (*lbMaps)[pos]; 1601 AffineMap &ubMap = (*ubMaps)[pos]; 1602 1603 if (expr) { 1604 lbMap = AffineMap::get(numMapDims, numMapSymbols, expr); 1605 ubMap = AffineMap::get(numMapDims, numMapSymbols, expr + 1); 1606 } else { 1607 // TODO(bondhugula): Whenever there are local identifiers in the 1608 // dependence constraints, we'll conservatively over-approximate, since we 1609 // don't always explicitly compute them above (in the while loop). 1610 if (getNumLocalIds() == 0) { 1611 // Work on a copy so that we don't update this constraint system. 
1612 if (!tmpClone) { 1613 tmpClone.emplace(FlatAffineConstraints(*this)); 1614 // Removing redudnant inequalities is necessary so that we don't get 1615 // redundant loop bounds. 1616 tmpClone->removeRedundantInequalities(); 1617 } 1618 std::tie(lbMap, ubMap) = tmpClone->getLowerAndUpperBound( 1619 pos, offset, num, getNumDimIds(), {}, context); 1620 } 1621 1622 // If the above fails, we'll just use the constant lower bound and the 1623 // constant upper bound (if they exist) as the slice bounds. 1624 // TODO(b/126426796): being conservative for the moment in cases that 1625 // lead to multiple bounds - until getConstDifference in LoopFusion.cpp is 1626 // fixed (b/126426796). 1627 if (!lbMap || lbMap.getNumResults() > 1) { 1628 LLVM_DEBUG(llvm::dbgs() 1629 << "WARNING: Potentially over-approximating slice lb\n"); 1630 auto lbConst = getConstantLowerBound(pos + offset); 1631 if (lbConst.hasValue()) { 1632 lbMap = AffineMap::get( 1633 numMapDims, numMapSymbols, 1634 getAffineConstantExpr(lbConst.getValue(), context)); 1635 } 1636 } 1637 if (!ubMap || ubMap.getNumResults() > 1) { 1638 LLVM_DEBUG(llvm::dbgs() 1639 << "WARNING: Potentially over-approximating slice ub\n"); 1640 auto ubConst = getConstantUpperBound(pos + offset); 1641 if (ubConst.hasValue()) { 1642 (ubMap) = AffineMap::get( 1643 numMapDims, numMapSymbols, 1644 getAffineConstantExpr(ubConst.getValue() + 1, context)); 1645 } 1646 } 1647 } 1648 LLVM_DEBUG(llvm::dbgs() 1649 << "lb map for pos = " << Twine(pos + offset) << ", expr: "); 1650 LLVM_DEBUG(lbMap.dump();); 1651 LLVM_DEBUG(llvm::dbgs() 1652 << "ub map for pos = " << Twine(pos + offset) << ", expr: "); 1653 LLVM_DEBUG(ubMap.dump();); 1654 } 1655 } 1656 1657 LogicalResult 1658 FlatAffineConstraints::addLowerOrUpperBound(unsigned pos, AffineMap boundMap, 1659 ArrayRef<Value *> boundOperands, 1660 bool eq, bool lower) { 1661 assert(pos < getNumDimAndSymbolIds() && "invalid position"); 1662 // Equality follows the logic of lower bound except that we add 
an equality 1663 // instead of an inequality. 1664 assert((!eq || boundMap.getNumResults() == 1) && "single result expected"); 1665 if (eq) 1666 lower = true; 1667 1668 // Fully commpose map and operands; canonicalize and simplify so that we 1669 // transitively get to terminal symbols or loop IVs. 1670 auto map = boundMap; 1671 SmallVector<Value *, 4> operands(boundOperands.begin(), boundOperands.end()); 1672 fullyComposeAffineMapAndOperands(&map, &operands); 1673 map = simplifyAffineMap(map); 1674 canonicalizeMapAndOperands(&map, &operands); 1675 for (auto *operand : operands) 1676 addInductionVarOrTerminalSymbol(operand); 1677 1678 FlatAffineConstraints localVarCst; 1679 std::vector<SmallVector<int64_t, 8>> flatExprs; 1680 if (failed(getFlattenedAffineExprs(map, &flatExprs, &localVarCst))) { 1681 LLVM_DEBUG(llvm::dbgs() << "semi-affine expressions not yet supported\n"); 1682 return failure(); 1683 } 1684 1685 // Merge and align with localVarCst. 1686 if (localVarCst.getNumLocalIds() > 0) { 1687 // Set values for localVarCst. 1688 localVarCst.setIdValues(0, localVarCst.getNumDimAndSymbolIds(), operands); 1689 for (auto *operand : operands) { 1690 unsigned pos; 1691 if (findId(*operand, &pos)) { 1692 if (pos >= getNumDimIds() && pos < getNumDimAndSymbolIds()) { 1693 // If the local var cst has this as a dim, turn it into its symbol. 1694 turnDimIntoSymbol(&localVarCst, *operand); 1695 } else if (pos < getNumDimIds()) { 1696 // Or vice versa. 1697 turnSymbolIntoDim(&localVarCst, *operand); 1698 } 1699 } 1700 } 1701 mergeAndAlignIds(/*offset=*/0, this, &localVarCst); 1702 append(localVarCst); 1703 } 1704 1705 // Record positions of the operands in the constraint system. Need to do 1706 // this here since the constraint system changes after a bound is added. 
1707 SmallVector<unsigned, 8> positions; 1708 unsigned numOperands = operands.size(); 1709 for (auto *operand : operands) { 1710 unsigned pos; 1711 if (!findId(*operand, &pos)) 1712 assert(0 && "expected to be found"); 1713 positions.push_back(pos); 1714 } 1715 1716 for (const auto &flatExpr : flatExprs) { 1717 SmallVector<int64_t, 4> ineq(getNumCols(), 0); 1718 ineq[pos] = lower ? 1 : -1; 1719 // Dims and symbols. 1720 for (unsigned j = 0, e = map.getNumInputs(); j < e; j++) { 1721 ineq[positions[j]] = lower ? -flatExpr[j] : flatExpr[j]; 1722 } 1723 // Copy over the local id coefficients. 1724 unsigned numLocalIds = flatExpr.size() - 1 - numOperands; 1725 for (unsigned jj = 0, j = getNumIds() - numLocalIds; jj < numLocalIds; 1726 jj++, j++) { 1727 ineq[j] = 1728 lower ? -flatExpr[numOperands + jj] : flatExpr[numOperands + jj]; 1729 } 1730 // Constant term. 1731 ineq[getNumCols() - 1] = 1732 lower ? -flatExpr[flatExpr.size() - 1] 1733 // Upper bound in flattenedExpr is an exclusive one. 1734 : flatExpr[flatExpr.size() - 1] - 1; 1735 eq ? addEquality(ineq) : addInequality(ineq); 1736 } 1737 return success(); 1738 } 1739 1740 // Adds slice lower bounds represented by lower bounds in 'lbMaps' and upper 1741 // bounds in 'ubMaps' to each value in `values' that appears in the constraint 1742 // system. Note that both lower/upper bounds share the same operand list 1743 // 'operands'. 1744 // This function assumes 'values.size' == 'lbMaps.size' == 'ubMaps.size', and 1745 // skips any null AffineMaps in 'lbMaps' or 'ubMaps'. 1746 // Note that both lower/upper bounds use operands from 'operands'. 1747 // Returns failure for unimplemented cases such as semi-affine expressions or 1748 // expressions with mod/floordiv. 
1749 LogicalResult FlatAffineConstraints::addSliceBounds( 1750 ArrayRef<Value *> values, ArrayRef<AffineMap> lbMaps, 1751 ArrayRef<AffineMap> ubMaps, ArrayRef<Value *> operands) { 1752 assert(values.size() == lbMaps.size()); 1753 assert(lbMaps.size() == ubMaps.size()); 1754 1755 for (unsigned i = 0, e = lbMaps.size(); i < e; ++i) { 1756 unsigned pos; 1757 if (!findId(*values[i], &pos)) 1758 continue; 1759 1760 AffineMap lbMap = lbMaps[i]; 1761 AffineMap ubMap = ubMaps[i]; 1762 assert(!lbMap || lbMap.getNumInputs() == operands.size()); 1763 assert(!ubMap || ubMap.getNumInputs() == operands.size()); 1764 1765 // Check if this slice is just an equality along this dimension. 1766 if (lbMap && ubMap && lbMap.getNumResults() == 1 && 1767 ubMap.getNumResults() == 1 && 1768 lbMap.getResult(0) + 1 == ubMap.getResult(0)) { 1769 if (failed(addLowerOrUpperBound(pos, lbMap, operands, /*eq=*/true, 1770 /*lower=*/true))) 1771 return failure(); 1772 continue; 1773 } 1774 1775 if (lbMap && failed(addLowerOrUpperBound(pos, lbMap, operands, /*eq=*/false, 1776 /*lower=*/true))) 1777 return failure(); 1778 1779 if (ubMap && failed(addLowerOrUpperBound(pos, ubMap, operands, /*eq=*/false, 1780 /*lower=*/false))) 1781 return failure(); 1782 } 1783 return success(); 1784 } 1785 1786 void FlatAffineConstraints::addEquality(ArrayRef<int64_t> eq) { 1787 assert(eq.size() == getNumCols()); 1788 unsigned offset = equalities.size(); 1789 equalities.resize(equalities.size() + numReservedCols); 1790 std::copy(eq.begin(), eq.end(), equalities.begin() + offset); 1791 } 1792 1793 void FlatAffineConstraints::addInequality(ArrayRef<int64_t> inEq) { 1794 assert(inEq.size() == getNumCols()); 1795 unsigned offset = inequalities.size(); 1796 inequalities.resize(inequalities.size() + numReservedCols); 1797 std::copy(inEq.begin(), inEq.end(), inequalities.begin() + offset); 1798 } 1799 1800 void FlatAffineConstraints::addConstantLowerBound(unsigned pos, int64_t lb) { 1801 assert(pos < getNumCols()); 1802 
unsigned offset = inequalities.size(); 1803 inequalities.resize(inequalities.size() + numReservedCols); 1804 std::fill(inequalities.begin() + offset, 1805 inequalities.begin() + offset + getNumCols(), 0); 1806 inequalities[offset + pos] = 1; 1807 inequalities[offset + getNumCols() - 1] = -lb; 1808 } 1809 1810 void FlatAffineConstraints::addConstantUpperBound(unsigned pos, int64_t ub) { 1811 assert(pos < getNumCols()); 1812 unsigned offset = inequalities.size(); 1813 inequalities.resize(inequalities.size() + numReservedCols); 1814 std::fill(inequalities.begin() + offset, 1815 inequalities.begin() + offset + getNumCols(), 0); 1816 inequalities[offset + pos] = -1; 1817 inequalities[offset + getNumCols() - 1] = ub; 1818 } 1819 1820 void FlatAffineConstraints::addConstantLowerBound(ArrayRef<int64_t> expr, 1821 int64_t lb) { 1822 assert(expr.size() == getNumCols()); 1823 unsigned offset = inequalities.size(); 1824 inequalities.resize(inequalities.size() + numReservedCols); 1825 std::fill(inequalities.begin() + offset, 1826 inequalities.begin() + offset + getNumCols(), 0); 1827 std::copy(expr.begin(), expr.end(), inequalities.begin() + offset); 1828 inequalities[offset + getNumCols() - 1] += -lb; 1829 } 1830 1831 void FlatAffineConstraints::addConstantUpperBound(ArrayRef<int64_t> expr, 1832 int64_t ub) { 1833 assert(expr.size() == getNumCols()); 1834 unsigned offset = inequalities.size(); 1835 inequalities.resize(inequalities.size() + numReservedCols); 1836 std::fill(inequalities.begin() + offset, 1837 inequalities.begin() + offset + getNumCols(), 0); 1838 for (unsigned i = 0, e = getNumCols(); i < e; i++) { 1839 inequalities[offset + i] = -expr[i]; 1840 } 1841 inequalities[offset + getNumCols() - 1] += ub; 1842 } 1843 1844 /// Adds a new local identifier as the floordiv of an affine function of other 1845 /// identifiers, the coefficients of which are provided in 'dividend' and with 1846 /// respect to a positive constant 'divisor'. 
Two constraints are added to the 1847 /// system to capture equivalence with the floordiv. 1848 /// q = expr floordiv c <=> c*q <= expr <= c*q + c - 1. 1849 void FlatAffineConstraints::addLocalFloorDiv(ArrayRef<int64_t> dividend, 1850 int64_t divisor) { 1851 assert(dividend.size() == getNumCols() && "incorrect dividend size"); 1852 assert(divisor > 0 && "positive divisor expected"); 1853 1854 addLocalId(getNumLocalIds()); 1855 1856 // Add two constraints for this new identifier 'q'. 1857 SmallVector<int64_t, 8> bound(dividend.size() + 1); 1858 1859 // dividend - q * divisor >= 0 1860 std::copy(dividend.begin(), dividend.begin() + dividend.size() - 1, 1861 bound.begin()); 1862 bound.back() = dividend.back(); 1863 bound[getNumIds() - 1] = -divisor; 1864 addInequality(bound); 1865 1866 // -dividend +qdivisor * q + divisor - 1 >= 0 1867 std::transform(bound.begin(), bound.end(), bound.begin(), 1868 std::negate<int64_t>()); 1869 bound[bound.size() - 1] += divisor - 1; 1870 addInequality(bound); 1871 } 1872 1873 bool FlatAffineConstraints::findId(Value &id, unsigned *pos) const { 1874 unsigned i = 0; 1875 for (const auto &mayBeId : ids) { 1876 if (mayBeId.hasValue() && mayBeId.getValue() == &id) { 1877 *pos = i; 1878 return true; 1879 } 1880 i++; 1881 } 1882 return false; 1883 } 1884 1885 bool FlatAffineConstraints::containsId(Value &id) const { 1886 return llvm::any_of(ids, [&](const Optional<Value *> &mayBeId) { 1887 return mayBeId.hasValue() && mayBeId.getValue() == &id; 1888 }); 1889 } 1890 1891 void FlatAffineConstraints::setDimSymbolSeparation(unsigned newSymbolCount) { 1892 assert(newSymbolCount <= numDims + numSymbols && 1893 "invalid separation position"); 1894 numDims = numDims + numSymbols - newSymbolCount; 1895 numSymbols = newSymbolCount; 1896 } 1897 1898 /// Sets the specified identifer to a constant value. 
1899 void FlatAffineConstraints::setIdToConstant(unsigned pos, int64_t val) { 1900 unsigned offset = equalities.size(); 1901 equalities.resize(equalities.size() + numReservedCols); 1902 std::fill(equalities.begin() + offset, 1903 equalities.begin() + offset + getNumCols(), 0); 1904 equalities[offset + pos] = 1; 1905 equalities[offset + getNumCols() - 1] = -val; 1906 } 1907 1908 /// Sets the specified identifer to a constant value; asserts if the id is not 1909 /// found. 1910 void FlatAffineConstraints::setIdToConstant(Value &id, int64_t val) { 1911 unsigned pos; 1912 if (!findId(id, &pos)) 1913 // This is a pre-condition for this method. 1914 assert(0 && "id not found"); 1915 setIdToConstant(pos, val); 1916 } 1917 1918 void FlatAffineConstraints::removeEquality(unsigned pos) { 1919 unsigned numEqualities = getNumEqualities(); 1920 assert(pos < numEqualities); 1921 unsigned outputIndex = pos * numReservedCols; 1922 unsigned inputIndex = (pos + 1) * numReservedCols; 1923 unsigned numElemsToCopy = (numEqualities - pos - 1) * numReservedCols; 1924 std::copy(equalities.begin() + inputIndex, 1925 equalities.begin() + inputIndex + numElemsToCopy, 1926 equalities.begin() + outputIndex); 1927 equalities.resize(equalities.size() - numReservedCols); 1928 } 1929 1930 /// Finds an equality that equates the specified identifier to a constant. 1931 /// Returns the position of the equality row. If 'symbolic' is set to true, 1932 /// symbols are also treated like a constant, i.e., an affine function of the 1933 /// symbols is also treated like a constant. 1934 static int findEqualityToConstant(const FlatAffineConstraints &cst, 1935 unsigned pos, bool symbolic = false) { 1936 assert(pos < cst.getNumIds() && "invalid position"); 1937 for (unsigned r = 0, e = cst.getNumEqualities(); r < e; r++) { 1938 int64_t v = cst.atEq(r, pos); 1939 if (v * v != 1) 1940 continue; 1941 unsigned c; 1942 unsigned f = symbolic ? 
cst.getNumDimIds() : cst.getNumIds(); 1943 // This checks for zeros in all positions other than 'pos' in [0, f) 1944 for (c = 0; c < f; c++) { 1945 if (c == pos) 1946 continue; 1947 if (cst.atEq(r, c) != 0) { 1948 // Dependent on another identifier. 1949 break; 1950 } 1951 } 1952 if (c == f) 1953 // Equality is free of other identifiers. 1954 return r; 1955 } 1956 return -1; 1957 } 1958 1959 void FlatAffineConstraints::setAndEliminate(unsigned pos, int64_t constVal) { 1960 assert(pos < getNumIds() && "invalid position"); 1961 for (unsigned r = 0, e = getNumInequalities(); r < e; r++) { 1962 atIneq(r, getNumCols() - 1) += atIneq(r, pos) * constVal; 1963 } 1964 for (unsigned r = 0, e = getNumEqualities(); r < e; r++) { 1965 atEq(r, getNumCols() - 1) += atEq(r, pos) * constVal; 1966 } 1967 removeId(pos); 1968 } 1969 1970 LogicalResult FlatAffineConstraints::constantFoldId(unsigned pos) { 1971 assert(pos < getNumIds() && "invalid position"); 1972 int rowIdx; 1973 if ((rowIdx = findEqualityToConstant(*this, pos)) == -1) 1974 return failure(); 1975 1976 // atEq(rowIdx, pos) is either -1 or 1. 1977 assert(atEq(rowIdx, pos) * atEq(rowIdx, pos) == 1); 1978 int64_t constVal = -atEq(rowIdx, getNumCols() - 1) / atEq(rowIdx, pos); 1979 setAndEliminate(pos, constVal); 1980 return success(); 1981 } 1982 1983 void FlatAffineConstraints::constantFoldIdRange(unsigned pos, unsigned num) { 1984 for (unsigned s = pos, t = pos, e = pos + num; s < e; s++) { 1985 if (failed(constantFoldId(t))) 1986 t++; 1987 } 1988 } 1989 1990 /// Returns the extent (upper bound - lower bound) of the specified 1991 /// identifier if it is found to be a constant; returns None if it's not a 1992 /// constant. This methods treats symbolic identifiers specially, i.e., 1993 /// it looks for constant differences between affine expressions involving 1994 /// only the symbolic identifiers. See comments at function definition for 1995 /// example. 
'lb', if provided, is set to the lower bound associated with the 1996 /// constant difference. Note that 'lb' is purely symbolic and thus will contain 1997 /// the coefficients of the symbolic identifiers and the constant coefficient. 1998 // Egs: 0 <= i <= 15, return 16. 1999 // s0 + 2 <= i <= s0 + 17, returns 16. (s0 has to be a symbol) 2000 // s0 + s1 + 16 <= d0 <= s0 + s1 + 31, returns 16. 2001 // s0 - 7 <= 8*j <= s0 returns 1 with lb = s0, lbDivisor = 8 (since lb = 2002 // ceil(s0 - 7 / 8) = floor(s0 / 8)). 2003 Optional<int64_t> FlatAffineConstraints::getConstantBoundOnDimSize( 2004 unsigned pos, SmallVectorImpl<int64_t> *lb, int64_t *lbFloorDivisor, 2005 SmallVectorImpl<int64_t> *ub) const { 2006 assert(pos < getNumDimIds() && "Invalid identifier position"); 2007 assert(getNumLocalIds() == 0); 2008 2009 // TODO(bondhugula): eliminate all remaining dimensional identifiers (other 2010 // than the one at 'pos' to make this more powerful. Not needed for 2011 // hyper-rectangular spaces. 2012 2013 // Find an equality for 'pos'^th identifier that equates it to some function 2014 // of the symbolic identifiers (+ constant). 2015 int eqRow = findEqualityToConstant(*this, pos, /*symbolic=*/true); 2016 if (eqRow != -1) { 2017 // This identifier can only take a single value. 2018 if (lb) { 2019 // Set lb to the symbolic value. 2020 lb->resize(getNumSymbolIds() + 1); 2021 if (ub) 2022 ub->resize(getNumSymbolIds() + 1); 2023 for (unsigned c = 0, f = getNumSymbolIds() + 1; c < f; c++) { 2024 int64_t v = atEq(eqRow, pos); 2025 // atEq(eqRow, pos) is either -1 or 1. 2026 assert(v * v == 1); 2027 (*lb)[c] = v < 0 ? atEq(eqRow, getNumDimIds() + c) / -v 2028 : -atEq(eqRow, getNumDimIds() + c) / v; 2029 // Since this is an equality, ub = lb. 
2030 if (ub) 2031 (*ub)[c] = (*lb)[c]; 2032 } 2033 assert(lbFloorDivisor && 2034 "both lb and divisor or none should be provided"); 2035 *lbFloorDivisor = 1; 2036 } 2037 return 1; 2038 } 2039 2040 // Check if the identifier appears at all in any of the inequalities. 2041 unsigned r, e; 2042 for (r = 0, e = getNumInequalities(); r < e; r++) { 2043 if (atIneq(r, pos) != 0) 2044 break; 2045 } 2046 if (r == e) 2047 // If it doesn't, there isn't a bound on it. 2048 return None; 2049 2050 // Positions of constraints that are lower/upper bounds on the variable. 2051 SmallVector<unsigned, 4> lbIndices, ubIndices; 2052 2053 // Gather all symbolic lower bounds and upper bounds of the variable. Since 2054 // the canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a 2055 // lower bound for x_i if c_i >= 1, and an upper bound if c_i <= -1. 2056 for (unsigned r = 0, e = getNumInequalities(); r < e; r++) { 2057 unsigned c, f; 2058 for (c = 0, f = getNumDimIds(); c < f; c++) { 2059 if (c != pos && atIneq(r, c) != 0) 2060 break; 2061 } 2062 if (c < getNumDimIds()) 2063 // Not a pure symbolic bound. 2064 continue; 2065 if (atIneq(r, pos) >= 1) 2066 // Lower bound. 2067 lbIndices.push_back(r); 2068 else if (atIneq(r, pos) <= -1) 2069 // Upper bound. 2070 ubIndices.push_back(r); 2071 } 2072 2073 // TODO(bondhugula): eliminate other dimensional identifiers to make this more 2074 // powerful. Not needed for hyper-rectangular iteration spaces. 2075 2076 Optional<int64_t> minDiff = None; 2077 unsigned minLbPosition, minUbPosition; 2078 for (auto ubPos : ubIndices) { 2079 for (auto lbPos : lbIndices) { 2080 // Look for a lower bound and an upper bound that only differ by a 2081 // constant, i.e., pairs of the form 0 <= c_pos - f(c_i's) <= diffConst. 2082 // For example, if ii is the pos^th variable, we are looking for 2083 // constraints like ii >= i, ii <= ii + 50, 50 being the difference. 
The 2084 // minimum among all such constant differences is kept since that's the 2085 // constant bounding the extent of the pos^th variable. 2086 unsigned j, e; 2087 for (j = 0, e = getNumCols() - 1; j < e; j++) 2088 if (atIneq(ubPos, j) != -atIneq(lbPos, j)) { 2089 break; 2090 } 2091 if (j < getNumCols() - 1) 2092 continue; 2093 int64_t diff = ceilDiv(atIneq(ubPos, getNumCols() - 1) + 2094 atIneq(lbPos, getNumCols() - 1) + 1, 2095 atIneq(lbPos, pos)); 2096 if (minDiff == None || diff < minDiff) { 2097 minDiff = diff; 2098 minLbPosition = lbPos; 2099 minUbPosition = ubPos; 2100 } 2101 } 2102 } 2103 if (lb && minDiff.hasValue()) { 2104 // Set lb to the symbolic lower bound. 2105 lb->resize(getNumSymbolIds() + 1); 2106 if (ub) 2107 ub->resize(getNumSymbolIds() + 1); 2108 // The lower bound is the ceildiv of the lb constraint over the coefficient 2109 // of the variable at 'pos'. We express the ceildiv equivalently as a floor 2110 // for uniformity. For eg., if the lower bound constraint was: 32*d0 - N + 2111 // 31 >= 0, the lower bound for d0 is ceil(N - 31, 32), i.e., floor(N, 32). 2112 *lbFloorDivisor = atIneq(minLbPosition, pos); 2113 assert(*lbFloorDivisor == -atIneq(minUbPosition, pos)); 2114 for (unsigned c = 0, e = getNumSymbolIds() + 1; c < e; c++) { 2115 (*lb)[c] = -atIneq(minLbPosition, getNumDimIds() + c); 2116 } 2117 if (ub) { 2118 for (unsigned c = 0, e = getNumSymbolIds() + 1; c < e; c++) 2119 (*ub)[c] = atIneq(minUbPosition, getNumDimIds() + c); 2120 } 2121 // The lower bound leads to a ceildiv while the upper bound is a floordiv 2122 // whenever the cofficient at pos != 1. ceildiv (val / d) = floordiv (val + 2123 // d - 1 / d); hence, the addition of 'atIneq(minLbPosition, pos) - 1' to 2124 // the constant term for the lower bound. 
2125 (*lb)[getNumSymbolIds()] += atIneq(minLbPosition, pos) - 1; 2126 } 2127 return minDiff; 2128 } 2129 2130 template <bool isLower> 2131 Optional<int64_t> 2132 FlatAffineConstraints::computeConstantLowerOrUpperBound(unsigned pos) { 2133 assert(pos < getNumIds() && "invalid position"); 2134 // Project to 'pos'. 2135 projectOut(0, pos); 2136 projectOut(1, getNumIds() - 1); 2137 // Check if there's an equality equating the '0'^th identifier to a constant. 2138 int eqRowIdx = findEqualityToConstant(*this, 0, /*symbolic=*/false); 2139 if (eqRowIdx != -1) 2140 // atEq(rowIdx, 0) is either -1 or 1. 2141 return -atEq(eqRowIdx, getNumCols() - 1) / atEq(eqRowIdx, 0); 2142 2143 // Check if the identifier appears at all in any of the inequalities. 2144 unsigned r, e; 2145 for (r = 0, e = getNumInequalities(); r < e; r++) { 2146 if (atIneq(r, 0) != 0) 2147 break; 2148 } 2149 if (r == e) 2150 // If it doesn't, there isn't a bound on it. 2151 return None; 2152 2153 Optional<int64_t> minOrMaxConst = None; 2154 2155 // Take the max across all const lower bounds (or min across all constant 2156 // upper bounds). 2157 for (unsigned r = 0, e = getNumInequalities(); r < e; r++) { 2158 if (isLower) { 2159 if (atIneq(r, 0) <= 0) 2160 // Not a lower bound. 2161 continue; 2162 } else if (atIneq(r, 0) >= 0) { 2163 // Not an upper bound. 2164 continue; 2165 } 2166 unsigned c, f; 2167 for (c = 0, f = getNumCols() - 1; c < f; c++) 2168 if (c != 0 && atIneq(r, c) != 0) 2169 break; 2170 if (c < getNumCols() - 1) 2171 // Not a constant bound. 2172 continue; 2173 2174 int64_t boundConst = 2175 isLower ? 
mlir::ceilDiv(-atIneq(r, getNumCols() - 1), atIneq(r, 0)) 2176 : mlir::floorDiv(atIneq(r, getNumCols() - 1), -atIneq(r, 0)); 2177 if (isLower) { 2178 if (minOrMaxConst == None || boundConst > minOrMaxConst) 2179 minOrMaxConst = boundConst; 2180 } else { 2181 if (minOrMaxConst == None || boundConst < minOrMaxConst) 2182 minOrMaxConst = boundConst; 2183 } 2184 } 2185 return minOrMaxConst; 2186 } 2187 2188 Optional<int64_t> 2189 FlatAffineConstraints::getConstantLowerBound(unsigned pos) const { 2190 FlatAffineConstraints tmpCst(*this); 2191 return tmpCst.computeConstantLowerOrUpperBound</*isLower=*/true>(pos); 2192 } 2193 2194 Optional<int64_t> 2195 FlatAffineConstraints::getConstantUpperBound(unsigned pos) const { 2196 FlatAffineConstraints tmpCst(*this); 2197 return tmpCst.computeConstantLowerOrUpperBound</*isLower=*/false>(pos); 2198 } 2199 2200 // A simple (naive and conservative) check for hyper-rectangularlity. 2201 bool FlatAffineConstraints::isHyperRectangular(unsigned pos, 2202 unsigned num) const { 2203 assert(pos < getNumCols() - 1); 2204 // Check for two non-zero coefficients in the range [pos, pos + sum). 
2205 for (unsigned r = 0, e = getNumInequalities(); r < e; r++) { 2206 unsigned sum = 0; 2207 for (unsigned c = pos; c < pos + num; c++) { 2208 if (atIneq(r, c) != 0) 2209 sum++; 2210 } 2211 if (sum > 1) 2212 return false; 2213 } 2214 for (unsigned r = 0, e = getNumEqualities(); r < e; r++) { 2215 unsigned sum = 0; 2216 for (unsigned c = pos; c < pos + num; c++) { 2217 if (atEq(r, c) != 0) 2218 sum++; 2219 } 2220 if (sum > 1) 2221 return false; 2222 } 2223 return true; 2224 } 2225 2226 void FlatAffineConstraints::print(raw_ostream &os) const { 2227 assert(hasConsistentState()); 2228 os << "\nConstraints (" << getNumDimIds() << " dims, " << getNumSymbolIds() 2229 << " symbols, " << getNumLocalIds() << " locals), (" << getNumConstraints() 2230 << " constraints)\n"; 2231 os << "("; 2232 for (unsigned i = 0, e = getNumIds(); i < e; i++) { 2233 if (ids[i] == None) 2234 os << "None "; 2235 else 2236 os << "Value "; 2237 } 2238 os << " const)\n"; 2239 for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) { 2240 for (unsigned j = 0, f = getNumCols(); j < f; ++j) { 2241 os << atEq(i, j) << " "; 2242 } 2243 os << "= 0\n"; 2244 } 2245 for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) { 2246 for (unsigned j = 0, f = getNumCols(); j < f; ++j) { 2247 os << atIneq(i, j) << " "; 2248 } 2249 os << ">= 0\n"; 2250 } 2251 os << '\n'; 2252 } 2253 2254 void FlatAffineConstraints::dump() const { print(llvm::errs()); } 2255 2256 /// Removes duplicate constraints, trivially true constraints, and constraints 2257 /// that can be detected as redundant as a result of differing only in their 2258 /// constant term part. A constraint of the form <non-negative constant> >= 0 is 2259 /// considered trivially true. 2260 // Uses a DenseSet to hash and detect duplicates followed by a linear scan to 2261 // remove duplicates in place. 
2262 void FlatAffineConstraints::removeTrivialRedundancy() { 2263 SmallDenseSet<ArrayRef<int64_t>, 8> rowSet; 2264 2265 // A map used to detect redundancy stemming from constraints that only differ 2266 // in their constant term. The value stored is <row position, const term> 2267 // for a given row. 2268 SmallDenseMap<ArrayRef<int64_t>, std::pair<unsigned, int64_t>> 2269 rowsWithoutConstTerm; 2270 2271 // Check if constraint is of the form <non-negative-constant> >= 0. 2272 auto isTriviallyValid = [&](unsigned r) -> bool { 2273 for (unsigned c = 0, e = getNumCols() - 1; c < e; c++) { 2274 if (atIneq(r, c) != 0) 2275 return false; 2276 } 2277 return atIneq(r, getNumCols() - 1) >= 0; 2278 }; 2279 2280 // Detect and mark redundant constraints. 2281 SmallVector<bool, 256> redunIneq(getNumInequalities(), false); 2282 for (unsigned r = 0, e = getNumInequalities(); r < e; r++) { 2283 int64_t *rowStart = inequalities.data() + numReservedCols * r; 2284 auto row = ArrayRef<int64_t>(rowStart, getNumCols()); 2285 if (isTriviallyValid(r) || !rowSet.insert(row).second) { 2286 redunIneq[r] = true; 2287 continue; 2288 } 2289 2290 // Among constraints that only differ in the constant term part, mark 2291 // everything other than the one with the smallest constant term redundant. 2292 // (eg: among i - 16j - 5 >= 0, i - 16j - 1 >=0, i - 16j - 7 >= 0, the 2293 // former two are redundant). 2294 int64_t constTerm = atIneq(r, getNumCols() - 1); 2295 auto rowWithoutConstTerm = ArrayRef<int64_t>(rowStart, getNumCols() - 1); 2296 const auto &ret = 2297 rowsWithoutConstTerm.insert({rowWithoutConstTerm, {r, constTerm}}); 2298 if (!ret.second) { 2299 // Check if the other constraint has a higher constant term. 2300 auto &val = ret.first->second; 2301 if (val.second > constTerm) { 2302 // The stored row is redundant. Mark it so, and update with this one. 2303 redunIneq[val.first] = true; 2304 val = {r, constTerm}; 2305 } else { 2306 // The one stored makes this one redundant. 
2307 redunIneq[r] = true; 2308 } 2309 } 2310 } 2311 2312 auto copyRow = [&](unsigned src, unsigned dest) { 2313 if (src == dest) 2314 return; 2315 for (unsigned c = 0, e = getNumCols(); c < e; c++) { 2316 atIneq(dest, c) = atIneq(src, c); 2317 } 2318 }; 2319 2320 // Scan to get rid of all rows marked redundant, in-place. 2321 unsigned pos = 0; 2322 for (unsigned r = 0, e = getNumInequalities(); r < e; r++) { 2323 if (!redunIneq[r]) 2324 copyRow(r, pos++); 2325 } 2326 inequalities.resize(numReservedCols * pos); 2327 2328 // TODO(bondhugula): consider doing this for equalities as well, but probably 2329 // not worth the savings. 2330 } 2331 2332 void FlatAffineConstraints::clearAndCopyFrom( 2333 const FlatAffineConstraints &other) { 2334 FlatAffineConstraints copy(other); 2335 std::swap(*this, copy); 2336 assert(copy.getNumIds() == copy.getIds().size()); 2337 } 2338 2339 void FlatAffineConstraints::removeId(unsigned pos) { 2340 removeIdRange(pos, pos + 1); 2341 } 2342 2343 static std::pair<unsigned, unsigned> 2344 getNewNumDimsSymbols(unsigned pos, const FlatAffineConstraints &cst) { 2345 unsigned numDims = cst.getNumDimIds(); 2346 unsigned numSymbols = cst.getNumSymbolIds(); 2347 unsigned newNumDims, newNumSymbols; 2348 if (pos < numDims) { 2349 newNumDims = numDims - 1; 2350 newNumSymbols = numSymbols; 2351 } else if (pos < numDims + numSymbols) { 2352 assert(numSymbols >= 1); 2353 newNumDims = numDims; 2354 newNumSymbols = numSymbols - 1; 2355 } else { 2356 newNumDims = numDims; 2357 newNumSymbols = numSymbols; 2358 } 2359 return {newNumDims, newNumSymbols}; 2360 } 2361 2362 #undef DEBUG_TYPE 2363 #define DEBUG_TYPE "fm" 2364 2365 /// Eliminates identifier at the specified position using Fourier-Motzkin 2366 /// variable elimination. This technique is exact for rational spaces but 2367 /// conservative (in "rare" cases) for integer spaces. 
The operation corresponds 2368 /// to a projection operation yielding the (convex) set of integer points 2369 /// contained in the rational shadow of the set. An emptiness test that relies 2370 /// on this method will guarantee emptiness, i.e., it disproves the existence of 2371 /// a solution if it says it's empty. 2372 /// If a non-null isResultIntegerExact is passed, it is set to true if the 2373 /// result is also integer exact. If it's set to false, the obtained solution 2374 /// *may* not be exact, i.e., it may contain integer points that do not have an 2375 /// integer pre-image in the original set. 2376 /// 2377 /// Eg: 2378 /// j >= 0, j <= i + 1 2379 /// i >= 0, i <= N + 1 2380 /// Eliminating i yields, 2381 /// j >= 0, 0 <= N + 1, j - 1 <= N + 1 2382 /// 2383 /// If darkShadow = true, this method computes the dark shadow on elimination; 2384 /// the dark shadow is a convex integer subset of the exact integer shadow. A 2385 /// non-empty dark shadow proves the existence of an integer solution. The 2386 /// elimination in such a case could however be an under-approximation, and thus 2387 /// should not be used for scanning sets or used by itself for dependence 2388 /// checking. 2389 /// 2390 /// Eg: 2-d set, * represents grid points, 'o' represents a point in the set. 2391 /// ^ 2392 /// | 2393 /// | * * * * o o 2394 /// i | * * o o o o 2395 /// | o * * * * * 2396 /// ---------------> 2397 /// j -> 2398 /// 2399 /// Eliminating i from this system (projecting on the j dimension): 2400 /// rational shadow / integer light shadow: 1 <= j <= 6 2401 /// dark shadow: 3 <= j <= 6 2402 /// exact integer shadow: j = 1 \union 3 <= j <= 6 2403 /// holes/splinters: j = 2 2404 /// 2405 /// darkShadow = false, isResultIntegerExact = nullptr are default values. 2406 // TODO(bondhugula): a slight modification to yield dark shadow version of FM 2407 // (tightened), which can prove the existence of a solution if there is one. 
2408 void FlatAffineConstraints::FourierMotzkinEliminate( 2409 unsigned pos, bool darkShadow, bool *isResultIntegerExact) { 2410 LLVM_DEBUG(llvm::dbgs() << "FM input (eliminate pos " << pos << "):\n"); 2411 LLVM_DEBUG(dump()); 2412 assert(pos < getNumIds() && "invalid position"); 2413 assert(hasConsistentState()); 2414 2415 // Check if this identifier can be eliminated through a substitution. 2416 for (unsigned r = 0, e = getNumEqualities(); r < e; r++) { 2417 if (atEq(r, pos) != 0) { 2418 // Use Gaussian elimination here (since we have an equality). 2419 LogicalResult ret = gaussianEliminateId(pos); 2420 (void)ret; 2421 assert(succeeded(ret) && "Gaussian elimination guaranteed to succeed"); 2422 LLVM_DEBUG(llvm::dbgs() << "FM output (through Gaussian elimination):\n"); 2423 LLVM_DEBUG(dump()); 2424 return; 2425 } 2426 } 2427 2428 // A fast linear time tightening. 2429 GCDTightenInequalities(); 2430 2431 // Check if the identifier appears at all in any of the inequalities. 2432 unsigned r, e; 2433 for (r = 0, e = getNumInequalities(); r < e; r++) { 2434 if (atIneq(r, pos) != 0) 2435 break; 2436 } 2437 if (r == getNumInequalities()) { 2438 // If it doesn't appear, just remove the column and return. 2439 // TODO(andydavis,bondhugula): refactor removeColumns to use it from here. 2440 removeId(pos); 2441 LLVM_DEBUG(llvm::dbgs() << "FM output:\n"); 2442 LLVM_DEBUG(dump()); 2443 return; 2444 } 2445 2446 // Positions of constraints that are lower bounds on the variable. 2447 SmallVector<unsigned, 4> lbIndices; 2448 // Positions of constraints that are lower bounds on the variable. 2449 SmallVector<unsigned, 4> ubIndices; 2450 // Positions of constraints that do not involve the variable. 2451 std::vector<unsigned> nbIndices; 2452 nbIndices.reserve(getNumInequalities()); 2453 2454 // Gather all lower bounds and upper bounds of the variable. Since the 2455 // canonical form c_1*x_1 + c_2*x_2 + ... 
+ c_0 >= 0, a constraint is a lower 2456 // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1. 2457 for (unsigned r = 0, e = getNumInequalities(); r < e; r++) { 2458 if (atIneq(r, pos) == 0) { 2459 // Id does not appear in bound. 2460 nbIndices.push_back(r); 2461 } else if (atIneq(r, pos) >= 1) { 2462 // Lower bound. 2463 lbIndices.push_back(r); 2464 } else { 2465 // Upper bound. 2466 ubIndices.push_back(r); 2467 } 2468 } 2469 2470 // Set the number of dimensions, symbols in the resulting system. 2471 const auto &dimsSymbols = getNewNumDimsSymbols(pos, *this); 2472 unsigned newNumDims = dimsSymbols.first; 2473 unsigned newNumSymbols = dimsSymbols.second; 2474 2475 SmallVector<Optional<Value *>, 8> newIds; 2476 newIds.reserve(numIds - 1); 2477 newIds.append(ids.begin(), ids.begin() + pos); 2478 newIds.append(ids.begin() + pos + 1, ids.end()); 2479 2480 /// Create the new system which has one identifier less. 2481 FlatAffineConstraints newFac( 2482 lbIndices.size() * ubIndices.size() + nbIndices.size(), 2483 getNumEqualities(), getNumCols() - 1, newNumDims, newNumSymbols, 2484 /*numLocals=*/getNumIds() - 1 - newNumDims - newNumSymbols, newIds); 2485 2486 assert(newFac.getIds().size() == newFac.getNumIds()); 2487 2488 // This will be used to check if the elimination was integer exact. 2489 unsigned lcmProducts = 1; 2490 2491 // Let x be the variable we are eliminating. 2492 // For each lower bound, lb <= c_l*x, and each upper bound c_u*x <= ub, (note 2493 // that c_l, c_u >= 1) we have: 2494 // lb*lcm(c_l, c_u)/c_l <= lcm(c_l, c_u)*x <= ub*lcm(c_l, c_u)/c_u 2495 // We thus generate a constraint: 2496 // lcm(c_l, c_u)/c_l*lb <= lcm(c_l, c_u)/c_u*ub. 2497 // Note if c_l = c_u = 1, all integer points captured by the resulting 2498 // constraint correspond to integer points in the original system (i.e., they 2499 // have integer pre-images). Hence, if the lcm's are all 1, the elimination is 2500 // integer exact. 
2501 for (auto ubPos : ubIndices) { 2502 for (auto lbPos : lbIndices) { 2503 SmallVector<int64_t, 4> ineq; 2504 ineq.reserve(newFac.getNumCols()); 2505 int64_t lbCoeff = atIneq(lbPos, pos); 2506 // Note that in the comments above, ubCoeff is the negation of the 2507 // coefficient in the canonical form as the view taken here is that of the 2508 // term being moved to the other size of '>='. 2509 int64_t ubCoeff = -atIneq(ubPos, pos); 2510 // TODO(bondhugula): refactor this loop to avoid all branches inside. 2511 for (unsigned l = 0, e = getNumCols(); l < e; l++) { 2512 if (l == pos) 2513 continue; 2514 assert(lbCoeff >= 1 && ubCoeff >= 1 && "bounds wrongly identified"); 2515 int64_t lcm = mlir::lcm(lbCoeff, ubCoeff); 2516 ineq.push_back(atIneq(ubPos, l) * (lcm / ubCoeff) + 2517 atIneq(lbPos, l) * (lcm / lbCoeff)); 2518 lcmProducts *= lcm; 2519 } 2520 if (darkShadow) { 2521 // The dark shadow is a convex subset of the exact integer shadow. If 2522 // there is a point here, it proves the existence of a solution. 2523 ineq[ineq.size() - 1] += lbCoeff * ubCoeff - lbCoeff - ubCoeff + 1; 2524 } 2525 // TODO: we need to have a way to add inequalities in-place in 2526 // FlatAffineConstraints instead of creating and copying over. 2527 newFac.addInequality(ineq); 2528 } 2529 } 2530 2531 LLVM_DEBUG(llvm::dbgs() << "FM isResultIntegerExact: " << (lcmProducts == 1) 2532 << "\n"); 2533 if (lcmProducts == 1 && isResultIntegerExact) 2534 *isResultIntegerExact = 1; 2535 2536 // Copy over the constraints not involving this variable. 2537 for (auto nbPos : nbIndices) { 2538 SmallVector<int64_t, 4> ineq; 2539 ineq.reserve(getNumCols() - 1); 2540 for (unsigned l = 0, e = getNumCols(); l < e; l++) { 2541 if (l == pos) 2542 continue; 2543 ineq.push_back(atIneq(nbPos, l)); 2544 } 2545 newFac.addInequality(ineq); 2546 } 2547 2548 assert(newFac.getNumConstraints() == 2549 lbIndices.size() * ubIndices.size() + nbIndices.size()); 2550 2551 // Copy over the equalities. 
2552 for (unsigned r = 0, e = getNumEqualities(); r < e; r++) { 2553 SmallVector<int64_t, 4> eq; 2554 eq.reserve(newFac.getNumCols()); 2555 for (unsigned l = 0, e = getNumCols(); l < e; l++) { 2556 if (l == pos) 2557 continue; 2558 eq.push_back(atEq(r, l)); 2559 } 2560 newFac.addEquality(eq); 2561 } 2562 2563 // GCD tightening and normalization allows detection of more trivially 2564 // redundant constraints. 2565 newFac.GCDTightenInequalities(); 2566 newFac.normalizeConstraintsByGCD(); 2567 newFac.removeTrivialRedundancy(); 2568 clearAndCopyFrom(newFac); 2569 LLVM_DEBUG(llvm::dbgs() << "FM output:\n"); 2570 LLVM_DEBUG(dump()); 2571 } 2572 2573 #undef DEBUG_TYPE 2574 #define DEBUG_TYPE "affine-structures" 2575 2576 void FlatAffineConstraints::projectOut(unsigned pos, unsigned num) { 2577 if (num == 0) 2578 return; 2579 2580 // 'pos' can be at most getNumCols() - 2 if num > 0. 2581 assert((getNumCols() < 2 || pos <= getNumCols() - 2) && "invalid position"); 2582 assert(pos + num < getNumCols() && "invalid range"); 2583 2584 // Eliminate as many identifiers as possible using Gaussian elimination. 2585 unsigned currentPos = pos; 2586 unsigned numToEliminate = num; 2587 unsigned numGaussianEliminated = 0; 2588 2589 while (currentPos < getNumIds()) { 2590 unsigned curNumEliminated = 2591 gaussianEliminateIds(currentPos, currentPos + numToEliminate); 2592 ++currentPos; 2593 numToEliminate -= curNumEliminated + 1; 2594 numGaussianEliminated += curNumEliminated; 2595 } 2596 2597 // Eliminate the remaining using Fourier-Motzkin. 2598 for (unsigned i = 0; i < num - numGaussianEliminated; i++) { 2599 unsigned numToEliminate = num - numGaussianEliminated - i; 2600 FourierMotzkinEliminate( 2601 getBestIdToEliminate(*this, pos, pos + numToEliminate)); 2602 } 2603 2604 // Fast/trivial simplifications. 2605 GCDTightenInequalities(); 2606 // Normalize constraints after tightening since the latter impacts this, but 2607 // not the other way round. 
2608 normalizeConstraintsByGCD(); 2609 } 2610 2611 void FlatAffineConstraints::projectOut(Value *id) { 2612 unsigned pos; 2613 bool ret = findId(*id, &pos); 2614 assert(ret); 2615 (void)ret; 2616 FourierMotzkinEliminate(pos); 2617 } 2618 2619 bool FlatAffineConstraints::isRangeOneToOne(unsigned start, 2620 unsigned limit) const { 2621 assert(start <= getNumIds() - 1 && "invalid start position"); 2622 assert(limit > start && limit <= getNumIds() && "invalid limit"); 2623 2624 FlatAffineConstraints tmpCst(*this); 2625 2626 if (start != 0) { 2627 // Move [start, limit) to the left. 2628 for (unsigned r = 0, e = getNumInequalities(); r < e; ++r) { 2629 for (unsigned c = 0, f = getNumCols(); c < f; ++c) { 2630 if (c >= start && c < limit) 2631 tmpCst.atIneq(r, c - start) = atIneq(r, c); 2632 else if (c < start) 2633 tmpCst.atIneq(r, c + limit - start) = atIneq(r, c); 2634 else 2635 tmpCst.atIneq(r, c) = atIneq(r, c); 2636 } 2637 } 2638 for (unsigned r = 0, e = getNumEqualities(); r < e; ++r) { 2639 for (unsigned c = 0, f = getNumCols(); c < f; ++c) { 2640 if (c >= start && c < limit) 2641 tmpCst.atEq(r, c - start) = atEq(r, c); 2642 else if (c < start) 2643 tmpCst.atEq(r, c + limit - start) = atEq(r, c); 2644 else 2645 tmpCst.atEq(r, c) = atEq(r, c); 2646 } 2647 } 2648 } 2649 2650 // Mark everything to the right as symbols so that we can check the extents in 2651 // a symbolic way below. 2652 tmpCst.setDimSymbolSeparation(getNumIds() - (limit - start)); 2653 2654 // Check if the extents of all the specified dimensions are just one (when 2655 // treating the rest as symbols). 
2656 for (unsigned pos = 0, e = tmpCst.getNumDimIds(); pos < e; ++pos) { 2657 auto extent = tmpCst.getConstantBoundOnDimSize(pos); 2658 if (!extent.hasValue() || extent.getValue() != 1) 2659 return false; 2660 } 2661 return true; 2662 } 2663 2664 void FlatAffineConstraints::clearConstraints() { 2665 equalities.clear(); 2666 inequalities.clear(); 2667 } 2668 2669 namespace { 2670 2671 enum BoundCmpResult { Greater, Less, Equal, Unknown }; 2672 2673 /// Compares two affine bounds whose coefficients are provided in 'first' and 2674 /// 'second'. The last coefficient is the constant term. 2675 static BoundCmpResult compareBounds(ArrayRef<int64_t> a, ArrayRef<int64_t> b) { 2676 assert(a.size() == b.size()); 2677 2678 // For the bounds to be comparable, their corresponding identifier 2679 // coefficients should be equal; the constant terms are then compared to 2680 // determine less/greater/equal. 2681 2682 if (!std::equal(a.begin(), a.end() - 1, b.begin())) 2683 return Unknown; 2684 2685 if (a.back() == b.back()) 2686 return Equal; 2687 2688 return a.back() < b.back() ? Less : Greater; 2689 } 2690 } // namespace 2691 2692 // Computes the bounding box with respect to 'other' by finding the min of the 2693 // lower bounds and the max of the upper bounds along each of the dimensions. 
2694 LogicalResult 2695 FlatAffineConstraints::unionBoundingBox(const FlatAffineConstraints &otherCst) { 2696 assert(otherCst.getNumDimIds() == numDims && "dims mismatch"); 2697 assert(otherCst.getIds() 2698 .slice(0, getNumDimIds()) 2699 .equals(getIds().slice(0, getNumDimIds())) && 2700 "dim values mismatch"); 2701 assert(otherCst.getNumLocalIds() == 0 && "local ids not supported here"); 2702 assert(getNumLocalIds() == 0 && "local ids not supported yet here"); 2703 2704 Optional<FlatAffineConstraints> otherCopy; 2705 if (!areIdsAligned(*this, otherCst)) { 2706 otherCopy.emplace(FlatAffineConstraints(otherCst)); 2707 mergeAndAlignIds(/*offset=*/numDims, this, &otherCopy.getValue()); 2708 } 2709 2710 const auto &other = otherCopy ? *otherCopy : otherCst; 2711 2712 std::vector<SmallVector<int64_t, 8>> boundingLbs; 2713 std::vector<SmallVector<int64_t, 8>> boundingUbs; 2714 boundingLbs.reserve(2 * getNumDimIds()); 2715 boundingUbs.reserve(2 * getNumDimIds()); 2716 2717 // To hold lower and upper bounds for each dimension. 2718 SmallVector<int64_t, 4> lb, otherLb, ub, otherUb; 2719 // To compute min of lower bounds and max of upper bounds for each dimension. 2720 SmallVector<int64_t, 4> minLb(getNumSymbolIds() + 1); 2721 SmallVector<int64_t, 4> maxUb(getNumSymbolIds() + 1); 2722 // To compute final new lower and upper bounds for the union. 2723 SmallVector<int64_t, 8> newLb(getNumCols()), newUb(getNumCols()); 2724 2725 int64_t lbFloorDivisor, otherLbFloorDivisor; 2726 for (unsigned d = 0, e = getNumDimIds(); d < e; ++d) { 2727 auto extent = getConstantBoundOnDimSize(d, &lb, &lbFloorDivisor, &ub); 2728 if (!extent.hasValue()) 2729 // TODO(bondhugula): symbolic extents when necessary. 2730 // TODO(bondhugula): handle union if a dimension is unbounded. 
2731 return failure(); 2732 2733 auto otherExtent = other.getConstantBoundOnDimSize( 2734 d, &otherLb, &otherLbFloorDivisor, &otherUb); 2735 if (!otherExtent.hasValue() || lbFloorDivisor != otherLbFloorDivisor) 2736 // TODO(bondhugula): symbolic extents when necessary. 2737 return failure(); 2738 2739 assert(lbFloorDivisor > 0 && "divisor always expected to be positive"); 2740 2741 auto res = compareBounds(lb, otherLb); 2742 // Identify min. 2743 if (res == BoundCmpResult::Less || res == BoundCmpResult::Equal) { 2744 minLb = lb; 2745 // Since the divisor is for a floordiv, we need to convert to ceildiv, 2746 // i.e., i >= expr floordiv div <=> i >= (expr - div + 1) ceildiv div <=> 2747 // div * i >= expr - div + 1. 2748 minLb.back() -= lbFloorDivisor - 1; 2749 } else if (res == BoundCmpResult::Greater) { 2750 minLb = otherLb; 2751 minLb.back() -= otherLbFloorDivisor - 1; 2752 } else { 2753 // Uncomparable - check for constant lower/upper bounds. 2754 auto constLb = getConstantLowerBound(d); 2755 auto constOtherLb = other.getConstantLowerBound(d); 2756 if (!constLb.hasValue() || !constOtherLb.hasValue()) 2757 return failure(); 2758 std::fill(minLb.begin(), minLb.end(), 0); 2759 minLb.back() = std::min(constLb.getValue(), constOtherLb.getValue()); 2760 } 2761 2762 // Do the same for ub's but max of upper bounds. Identify max. 2763 auto uRes = compareBounds(ub, otherUb); 2764 if (uRes == BoundCmpResult::Greater || uRes == BoundCmpResult::Equal) { 2765 maxUb = ub; 2766 } else if (uRes == BoundCmpResult::Less) { 2767 maxUb = otherUb; 2768 } else { 2769 // Uncomparable - check for constant lower/upper bounds. 
2770 auto constUb = getConstantUpperBound(d); 2771 auto constOtherUb = other.getConstantUpperBound(d); 2772 if (!constUb.hasValue() || !constOtherUb.hasValue()) 2773 return failure(); 2774 std::fill(maxUb.begin(), maxUb.end(), 0); 2775 maxUb.back() = std::max(constUb.getValue(), constOtherUb.getValue()); 2776 } 2777 2778 std::fill(newLb.begin(), newLb.end(), 0); 2779 std::fill(newUb.begin(), newUb.end(), 0); 2780 2781 // The divisor for lb, ub, otherLb, otherUb at this point is lbDivisor, 2782 // and so it's the divisor for newLb and newUb as well. 2783 newLb[d] = lbFloorDivisor; 2784 newUb[d] = -lbFloorDivisor; 2785 // Copy over the symbolic part + constant term. 2786 std::copy(minLb.begin(), minLb.end(), newLb.begin() + getNumDimIds()); 2787 std::transform(newLb.begin() + getNumDimIds(), newLb.end(), 2788 newLb.begin() + getNumDimIds(), std::negate<int64_t>()); 2789 std::copy(maxUb.begin(), maxUb.end(), newUb.begin() + getNumDimIds()); 2790 2791 boundingLbs.push_back(newLb); 2792 boundingUbs.push_back(newUb); 2793 } 2794 2795 // Clear all constraints and add the lower/upper bounds for the bounding box. 2796 clearConstraints(); 2797 for (unsigned d = 0, e = getNumDimIds(); d < e; ++d) { 2798 addInequality(boundingLbs[d]); 2799 addInequality(boundingUbs[d]); 2800 } 2801 // TODO(mlir-team): copy over pure symbolic constraints from this and 'other' 2802 // over to the union (since the above are just the union along dimensions); we 2803 // shouldn't be discarding any other constraints on the symbols. 2804 2805 return success(); 2806 }