//===- AffineExpr.cpp - MLIR Affine Expr Classes --------------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

#include "mlir/IR/AffineExpr.h"
#include "AffineExprDetail.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/Support/MathExtras.h"
#include "mlir/Support/STLExtras.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;
using namespace mlir::detail;

/// Returns the MLIRContext that owns this expression's uniqued storage.
MLIRContext *AffineExpr::getContext() const { return expr->context; }

/// Returns the kind (dim/symbol/constant or one of the binary ops) recorded
/// in the underlying storage object.
AffineExprKind AffineExpr::getKind() const {
  return static_cast<AffineExprKind>(expr->getKind());
}

/// Walk all of the AffineExprs in this subgraph in postorder.
37 void AffineExpr::walk(std::function<void(AffineExpr)> callback) const { 38 struct AffineExprWalker : public AffineExprVisitor<AffineExprWalker> { 39 std::function<void(AffineExpr)> callback; 40 41 AffineExprWalker(std::function<void(AffineExpr)> callback) 42 : callback(callback) {} 43 44 void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) { callback(expr); } 45 void visitConstantExpr(AffineConstantExpr expr) { callback(expr); } 46 void visitDimExpr(AffineDimExpr expr) { callback(expr); } 47 void visitSymbolExpr(AffineSymbolExpr expr) { callback(expr); } 48 }; 49 50 AffineExprWalker(callback).walkPostOrder(*this); 51 } 52 53 // Dispatch affine expression construction based on kind. 54 AffineExpr mlir::getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs, 55 AffineExpr rhs) { 56 if (kind == AffineExprKind::Add) 57 return lhs + rhs; 58 if (kind == AffineExprKind::Mul) 59 return lhs * rhs; 60 if (kind == AffineExprKind::FloorDiv) 61 return lhs.floorDiv(rhs); 62 if (kind == AffineExprKind::CeilDiv) 63 return lhs.ceilDiv(rhs); 64 if (kind == AffineExprKind::Mod) 65 return lhs % rhs; 66 67 llvm_unreachable("unknown binary operation on affine expressions"); 68 } 69 70 /// This method substitutes any uses of dimensions and symbols (e.g. 71 /// dim#0 with dimReplacements[0]) and returns the modified expression tree. 
AffineExpr
AffineExpr::replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements,
                                  ArrayRef<AffineExpr> symReplacements) const {
  switch (getKind()) {
  case AffineExprKind::Constant:
    // Constants contain no dims or symbols; nothing to substitute.
    return *this;
  case AffineExprKind::DimId: {
    unsigned dimId = cast<AffineDimExpr>().getPosition();
    // Positions beyond the replacement list are left untouched.
    if (dimId >= dimReplacements.size())
      return *this;
    return dimReplacements[dimId];
  }
  case AffineExprKind::SymbolId: {
    unsigned symId = cast<AffineSymbolExpr>().getPosition();
    // Positions beyond the replacement list are left untouched.
    if (symId >= symReplacements.size())
      return *this;
    return symReplacements[symId];
  }
  case AffineExprKind::Add:
  case AffineExprKind::Mul:
  case AffineExprKind::FloorDiv:
  case AffineExprKind::CeilDiv:
  case AffineExprKind::Mod:
    auto binOp = cast<AffineBinaryOpExpr>();
    auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
    auto newLHS = lhs.replaceDimsAndSymbols(dimReplacements, symReplacements);
    auto newRHS = rhs.replaceDimsAndSymbols(dimReplacements, symReplacements);
    // Only rebuild the binary op when a side actually changed, so untouched
    // subtrees keep their original (uniqued) identity.
    if (newLHS == lhs && newRHS == rhs)
      return *this;
    return getAffineBinaryOpExpr(getKind(), newLHS, newRHS);
  }
  llvm_unreachable("Unknown AffineExpr");
}

/// Returns true if this expression is made out of only symbols and
/// constants (no dimensional identifiers).
108 bool AffineExpr::isSymbolicOrConstant() const { 109 switch (getKind()) { 110 case AffineExprKind::Constant: 111 return true; 112 case AffineExprKind::DimId: 113 return false; 114 case AffineExprKind::SymbolId: 115 return true; 116 117 case AffineExprKind::Add: 118 case AffineExprKind::Mul: 119 case AffineExprKind::FloorDiv: 120 case AffineExprKind::CeilDiv: 121 case AffineExprKind::Mod: { 122 auto expr = this->cast<AffineBinaryOpExpr>(); 123 return expr.getLHS().isSymbolicOrConstant() && 124 expr.getRHS().isSymbolicOrConstant(); 125 } 126 } 127 llvm_unreachable("Unknown AffineExpr"); 128 } 129 130 /// Returns true if this is a pure affine expression, i.e., multiplication, 131 /// floordiv, ceildiv, and mod is only allowed w.r.t constants. 132 bool AffineExpr::isPureAffine() const { 133 switch (getKind()) { 134 case AffineExprKind::SymbolId: 135 case AffineExprKind::DimId: 136 case AffineExprKind::Constant: 137 return true; 138 case AffineExprKind::Add: { 139 auto op = cast<AffineBinaryOpExpr>(); 140 return op.getLHS().isPureAffine() && op.getRHS().isPureAffine(); 141 } 142 143 case AffineExprKind::Mul: { 144 // TODO: Canonicalize the constants in binary operators to the RHS when 145 // possible, allowing this to merge into the next case. 146 auto op = cast<AffineBinaryOpExpr>(); 147 return op.getLHS().isPureAffine() && op.getRHS().isPureAffine() && 148 (op.getLHS().template isa<AffineConstantExpr>() || 149 op.getRHS().template isa<AffineConstantExpr>()); 150 } 151 case AffineExprKind::FloorDiv: 152 case AffineExprKind::CeilDiv: 153 case AffineExprKind::Mod: { 154 auto op = cast<AffineBinaryOpExpr>(); 155 return op.getLHS().isPureAffine() && 156 op.getRHS().template isa<AffineConstantExpr>(); 157 } 158 } 159 llvm_unreachable("Unknown AffineExpr"); 160 } 161 162 // Returns the greatest known integral divisor of this affine expression. 
163 uint64_t AffineExpr::getLargestKnownDivisor() const { 164 AffineBinaryOpExpr binExpr(nullptr); 165 switch (getKind()) { 166 case AffineExprKind::SymbolId: 167 LLVM_FALLTHROUGH; 168 case AffineExprKind::DimId: 169 return 1; 170 case AffineExprKind::Constant: 171 return std::abs(this->cast<AffineConstantExpr>().getValue()); 172 case AffineExprKind::Mul: { 173 binExpr = this->cast<AffineBinaryOpExpr>(); 174 return binExpr.getLHS().getLargestKnownDivisor() * 175 binExpr.getRHS().getLargestKnownDivisor(); 176 } 177 case AffineExprKind::Add: 178 LLVM_FALLTHROUGH; 179 case AffineExprKind::FloorDiv: 180 case AffineExprKind::CeilDiv: 181 case AffineExprKind::Mod: { 182 binExpr = cast<AffineBinaryOpExpr>(); 183 return llvm::GreatestCommonDivisor64( 184 binExpr.getLHS().getLargestKnownDivisor(), 185 binExpr.getRHS().getLargestKnownDivisor()); 186 } 187 } 188 llvm_unreachable("Unknown AffineExpr"); 189 } 190 191 bool AffineExpr::isMultipleOf(int64_t factor) const { 192 AffineBinaryOpExpr binExpr(nullptr); 193 uint64_t l, u; 194 switch (getKind()) { 195 case AffineExprKind::SymbolId: 196 LLVM_FALLTHROUGH; 197 case AffineExprKind::DimId: 198 return factor * factor == 1; 199 case AffineExprKind::Constant: 200 return cast<AffineConstantExpr>().getValue() % factor == 0; 201 case AffineExprKind::Mul: { 202 binExpr = cast<AffineBinaryOpExpr>(); 203 // It's probably not worth optimizing this further (to not traverse the 204 // whole sub-tree under - it that would require a version of isMultipleOf 205 // that on a 'false' return also returns the largest known divisor). 
206 return (l = binExpr.getLHS().getLargestKnownDivisor()) % factor == 0 || 207 (u = binExpr.getRHS().getLargestKnownDivisor()) % factor == 0 || 208 (l * u) % factor == 0; 209 } 210 case AffineExprKind::Add: 211 case AffineExprKind::FloorDiv: 212 case AffineExprKind::CeilDiv: 213 case AffineExprKind::Mod: { 214 binExpr = cast<AffineBinaryOpExpr>(); 215 return llvm::GreatestCommonDivisor64( 216 binExpr.getLHS().getLargestKnownDivisor(), 217 binExpr.getRHS().getLargestKnownDivisor()) % 218 factor == 219 0; 220 } 221 } 222 llvm_unreachable("Unknown AffineExpr"); 223 } 224 225 bool AffineExpr::isFunctionOfDim(unsigned position) const { 226 if (getKind() == AffineExprKind::DimId) { 227 return *this == mlir::getAffineDimExpr(position, getContext()); 228 } 229 if (auto expr = this->dyn_cast<AffineBinaryOpExpr>()) { 230 return expr.getLHS().isFunctionOfDim(position) || 231 expr.getRHS().isFunctionOfDim(position); 232 } 233 return false; 234 } 235 236 AffineBinaryOpExpr::AffineBinaryOpExpr(AffineExpr::ImplType *ptr) 237 : AffineExpr(ptr) {} 238 AffineExpr AffineBinaryOpExpr::getLHS() const { 239 return static_cast<ImplType *>(expr)->lhs; 240 } 241 AffineExpr AffineBinaryOpExpr::getRHS() const { 242 return static_cast<ImplType *>(expr)->rhs; 243 } 244 245 AffineDimExpr::AffineDimExpr(AffineExpr::ImplType *ptr) : AffineExpr(ptr) {} 246 unsigned AffineDimExpr::getPosition() const { 247 return static_cast<ImplType *>(expr)->position; 248 } 249 250 static AffineExpr getAffineDimOrSymbol(AffineExprKind kind, unsigned position, 251 MLIRContext *context) { 252 auto assignCtx = [context](AffineDimExprStorage *storage) { 253 storage->context = context; 254 }; 255 256 StorageUniquer &uniquer = context->getAffineUniquer(); 257 return uniquer.get<AffineDimExprStorage>( 258 assignCtx, static_cast<unsigned>(kind), position); 259 } 260 261 AffineExpr mlir::getAffineDimExpr(unsigned position, MLIRContext *context) { 262 return getAffineDimOrSymbol(AffineExprKind::DimId, position, context); 
263 } 264 265 AffineSymbolExpr::AffineSymbolExpr(AffineExpr::ImplType *ptr) 266 : AffineExpr(ptr) {} 267 unsigned AffineSymbolExpr::getPosition() const { 268 return static_cast<ImplType *>(expr)->position; 269 } 270 271 AffineExpr mlir::getAffineSymbolExpr(unsigned position, MLIRContext *context) { 272 return getAffineDimOrSymbol(AffineExprKind::SymbolId, position, context); 273 ; 274 } 275 276 AffineConstantExpr::AffineConstantExpr(AffineExpr::ImplType *ptr) 277 : AffineExpr(ptr) {} 278 int64_t AffineConstantExpr::getValue() const { 279 return static_cast<ImplType *>(expr)->constant; 280 } 281 282 AffineExpr mlir::getAffineConstantExpr(int64_t constant, MLIRContext *context) { 283 auto assignCtx = [context](AffineConstantExprStorage *storage) { 284 storage->context = context; 285 }; 286 287 StorageUniquer &uniquer = context->getAffineUniquer(); 288 return uniquer.get<AffineConstantExprStorage>( 289 assignCtx, static_cast<unsigned>(AffineExprKind::Constant), constant); 290 } 291 292 /// Simplify add expression. Return nullptr if it can't be simplified. 293 static AffineExpr simplifyAdd(AffineExpr lhs, AffineExpr rhs) { 294 auto lhsConst = lhs.dyn_cast<AffineConstantExpr>(); 295 auto rhsConst = rhs.dyn_cast<AffineConstantExpr>(); 296 // Fold if both LHS, RHS are a constant. 297 if (lhsConst && rhsConst) 298 return getAffineConstantExpr(lhsConst.getValue() + rhsConst.getValue(), 299 lhs.getContext()); 300 301 // Canonicalize so that only the RHS is a constant. (4 + d0 becomes d0 + 4). 302 // If only one of them is a symbolic expressions, make it the RHS. 303 if (lhs.isa<AffineConstantExpr>() || 304 (lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant())) { 305 return rhs + lhs; 306 } 307 308 // At this point, if there was a constant, it would be on the right. 309 310 // Addition with a zero is a noop, return the other input. 311 if (rhsConst) { 312 if (rhsConst.getValue() == 0) 313 return lhs; 314 } 315 // Fold successive additions like (d0 + 2) + 3 into d0 + 5. 
316 auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>(); 317 if (lBin && rhsConst && lBin.getKind() == AffineExprKind::Add) { 318 if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) 319 return lBin.getLHS() + (lrhs.getValue() + rhsConst.getValue()); 320 } 321 322 // When doing successive additions, bring constant to the right: turn (d0 + 2) 323 // + d1 into (d0 + d1) + 2. 324 if (lBin && lBin.getKind() == AffineExprKind::Add) { 325 if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) { 326 return lBin.getLHS() + rhs + lrhs; 327 } 328 } 329 330 // Detect and transform "expr - c * (expr floordiv c)" to "expr mod c". This 331 // leads to a much more efficient form when 'c' is a power of two, and in 332 // general a more compact and readable form. 333 334 // Process '(expr floordiv c) * (-c)'. 335 AffineBinaryOpExpr rBinOpExpr = rhs.dyn_cast<AffineBinaryOpExpr>(); 336 if (!rBinOpExpr) 337 return nullptr; 338 339 auto lrhs = rBinOpExpr.getLHS(); 340 auto rrhs = rBinOpExpr.getRHS(); 341 342 // Process lrhs, which is 'expr floordiv c'. 343 AffineBinaryOpExpr lrBinOpExpr = lrhs.dyn_cast<AffineBinaryOpExpr>(); 344 if (!lrBinOpExpr || lrBinOpExpr.getKind() != AffineExprKind::FloorDiv) 345 return nullptr; 346 347 auto llrhs = lrBinOpExpr.getLHS(); 348 auto rlrhs = lrBinOpExpr.getRHS(); 349 350 if (lhs == llrhs && rlrhs == -rrhs) { 351 return lhs % rlrhs; 352 } 353 return nullptr; 354 } 355 356 AffineExpr AffineExpr::operator+(int64_t v) const { 357 return *this + getAffineConstantExpr(v, getContext()); 358 } 359 AffineExpr AffineExpr::operator+(AffineExpr other) const { 360 if (auto simplified = simplifyAdd(*this, other)) 361 return simplified; 362 363 StorageUniquer &uniquer = getContext()->getAffineUniquer(); 364 return uniquer.get<AffineBinaryOpExprStorage>( 365 /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Add), *this, other); 366 } 367 368 /// Simplify a multiply expression. Return nullptr if it can't be simplified. 
static AffineExpr simplifyMul(AffineExpr lhs, AffineExpr rhs) {
  auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
  auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();

  // Fold if both operands are constants.
  if (lhsConst && rhsConst)
    return getAffineConstantExpr(lhsConst.getValue() * rhsConst.getValue(),
                                 lhs.getContext());

  // Semi-affine products (both sides non-symbolic) are not built here.
  assert(lhs.isSymbolicOrConstant() || rhs.isSymbolicOrConstant());

  // Canonicalize the mul expression so that the constant/symbolic term is the
  // RHS. If both the lhs and rhs are symbolic, swap them if the lhs is a
  // constant. (Note that a constant is trivially symbolic).
  if (!rhs.isSymbolicOrConstant() || lhs.isa<AffineConstantExpr>()) {
    // At least one of them has to be symbolic.
    return rhs * lhs;
  }

  // At this point, if there was a constant, it would be on the right.

  // Multiplication with a one is a noop, return the other input.
  if (rhsConst) {
    if (rhsConst.getValue() == 1)
      return lhs;
    // Multiplication with zero.
    if (rhsConst.getValue() == 0)
      return rhsConst;
  }

  // Fold successive multiplications: eg: (d0 * 2) * 3 into d0 * 6.
  auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
  if (lBin && rhsConst && lBin.getKind() == AffineExprKind::Mul) {
    if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>())
      return lBin.getLHS() * (lrhs.getValue() * rhsConst.getValue());
  }

  // When doing successive multiplication, bring constant to the right: turn (d0
  // * 2) * d1 into (d0 * d1) * 2.
  if (lBin && lBin.getKind() == AffineExprKind::Mul) {
    if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) {
      return (lBin.getLHS() * rhs) * lrhs;
    }
  }

  return nullptr;
}

AffineExpr AffineExpr::operator*(int64_t v) const {
  return *this * getAffineConstantExpr(v, getContext());
}
AffineExpr AffineExpr::operator*(AffineExpr other) const {
  // Try local simplification first; otherwise build a uniqued Mul node.
  if (auto simplified = simplifyMul(*this, other))
    return simplified;

  StorageUniquer &uniquer = getContext()->getAffineUniquer();
  return uniquer.get<AffineBinaryOpExprStorage>(
      /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mul), *this, other);
}

// Unary minus, delegate to operator*.
AffineExpr AffineExpr::operator-() const {
  return *this * getAffineConstantExpr(-1, getContext());
}

// Delegate to operator+.
AffineExpr AffineExpr::operator-(int64_t v) const { return *this + (-v); }
AffineExpr AffineExpr::operator-(AffineExpr other) const {
  return *this + (-other);
}

/// Simplify a floordiv expression. Returns nullptr if it can't be simplified.
/// Only divisions by a positive constant are simplified.
static AffineExpr simplifyFloorDiv(AffineExpr lhs, AffineExpr rhs) {
  auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
  auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();

  if (!rhsConst || rhsConst.getValue() < 1)
    return nullptr;

  // Fold constant / constant.
  if (lhsConst)
    return getAffineConstantExpr(
        floorDiv(lhsConst.getValue(), rhsConst.getValue()), lhs.getContext());

  // Division by one is a noop.
  if (rhsConst.getValue() == 1)
    return lhs;

  // Fold floordiv of a multiply with a constant that is a multiple of the
  // divisor. Eg: (i * 128) floordiv 64 = i * 2.
  auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
  if (lBin && lBin.getKind() == AffineExprKind::Mul) {
    if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) {
      // rhsConst is known to be positive if a constant.
      if (lrhs.getValue() % rhsConst.getValue() == 0)
        return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue());
    }
  }

  return nullptr;
}

AffineExpr AffineExpr::floorDiv(uint64_t v) const {
  return floorDiv(getAffineConstantExpr(v, getContext()));
}
AffineExpr AffineExpr::floorDiv(AffineExpr other) const {
  if (auto simplified = simplifyFloorDiv(*this, other))
    return simplified;

  StorageUniquer &uniquer = getContext()->getAffineUniquer();
  return uniquer.get<AffineBinaryOpExprStorage>(
      /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::FloorDiv), *this,
      other);
}

/// Simplify a ceildiv expression. Returns nullptr if it can't be simplified.
/// Only divisions by a positive constant are simplified.
static AffineExpr simplifyCeilDiv(AffineExpr lhs, AffineExpr rhs) {
  auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
  auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();

  if (!rhsConst || rhsConst.getValue() < 1)
    return nullptr;

  // Fold constant / constant.
  if (lhsConst)
    return getAffineConstantExpr(
        ceilDiv(lhsConst.getValue(), rhsConst.getValue()), lhs.getContext());

  // Division by one is a noop.
  if (rhsConst.getValue() == 1)
    return lhs;

  // Fold ceildiv of a multiply with a constant that is a multiple of the
  // divisor. Eg: (i * 128) ceildiv 64 = i * 2.
  auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
  if (lBin && lBin.getKind() == AffineExprKind::Mul) {
    if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) {
      // rhsConst is known to be positive if a constant.
      if (lrhs.getValue() % rhsConst.getValue() == 0)
        return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue());
    }
  }

  return nullptr;
}

AffineExpr AffineExpr::ceilDiv(uint64_t v) const {
  return ceilDiv(getAffineConstantExpr(v, getContext()));
}
AffineExpr AffineExpr::ceilDiv(AffineExpr other) const {
  if (auto simplified = simplifyCeilDiv(*this, other))
    return simplified;

  StorageUniquer &uniquer = getContext()->getAffineUniquer();
  return uniquer.get<AffineBinaryOpExprStorage>(
      /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::CeilDiv), *this,
      other);
}

/// Simplify a mod expression. Returns nullptr if it can't be simplified.
/// Only modulo by a positive constant is simplified.
static AffineExpr simplifyMod(AffineExpr lhs, AffineExpr rhs) {
  auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
  auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();

  if (!rhsConst || rhsConst.getValue() < 1)
    return nullptr;

  // Fold constant % constant (Euclidean mod, always non-negative result).
  if (lhsConst)
    return getAffineConstantExpr(mod(lhsConst.getValue(), rhsConst.getValue()),
                                 lhs.getContext());

  // Fold modulo of an expression that is known to be a multiple of a constant
  // to zero if that constant is a multiple of the modulo factor. Eg: (i * 128)
  // mod 64 is folded to 0, and less trivially, (i*(j*4*(k*32))) mod 128 = 0.
  if (lhs.getLargestKnownDivisor() % rhsConst.getValue() == 0)
    return getAffineConstantExpr(0, lhs.getContext());

  return nullptr;
  // TODO(bondhugula): In general, this can be simplified more by using the GCD
  // test, or in general using quantifier elimination (add two new variables q
  // and r, and eliminate all variables from the linear system other than r. All
  // of this can be done through mlir/Analysis/'s FlatAffineConstraints.
}

AffineExpr AffineExpr::operator%(uint64_t v) const {
  return *this % getAffineConstantExpr(v, getContext());
}
AffineExpr AffineExpr::operator%(AffineExpr other) const {
  if (auto simplified = simplifyMod(*this, other))
    return simplified;

  StorageUniquer &uniquer = getContext()->getAffineUniquer();
  return uniquer.get<AffineBinaryOpExprStorage>(
      /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mod), *this, other);
}

/// Substitutes each dim in this expression with the corresponding result of
/// 'map'. Symbols are passed through unchanged (empty replacement list).
AffineExpr AffineExpr::compose(AffineMap map) const {
  SmallVector<AffineExpr, 8> dimReplacements(map.getResults().begin(),
                                             map.getResults().end());
  return replaceDimsAndSymbols(dimReplacements, {});
}
raw_ostream &mlir::operator<<(raw_ostream &os, AffineExpr &expr) {
  expr.print(os);
  return os;
}

/// Constructs an affine expression from a flat ArrayRef. If there are local
/// identifiers (neither dimensional nor symbolic) that appear in the sum of
/// products expression, 'localExprs' is expected to have the AffineExpr
/// for it, and is substituted into. The ArrayRef 'eq' is expected to be in the
/// format [dims, symbols, locals, constant term].
AffineExpr mlir::toAffineExpr(ArrayRef<int64_t> eq, unsigned numDims,
                              unsigned numSymbols,
                              ArrayRef<AffineExpr> localExprs,
                              MLIRContext *context) {
  // Assert expected numLocals = eq.size() - numDims - numSymbols - 1
  assert(eq.size() - numDims - numSymbols - 1 == localExprs.size() &&
         "unexpected number of local expressions");

  auto expr = getAffineConstantExpr(0, context);
  // Dimensions and symbols: accumulate coeff * identifier terms.
  for (unsigned j = 0; j < numDims + numSymbols; j++) {
    if (eq[j] == 0) {
      continue;
    }
    auto id = j < numDims ? getAffineDimExpr(j, context)
                          : getAffineSymbolExpr(j - numDims, context);
    expr = expr + id * eq[j];
  }

  // Local identifiers: substitute the provided sub-expressions.
  for (unsigned j = numDims + numSymbols, e = eq.size() - 1; j < e; j++) {
    if (eq[j] == 0) {
      continue;
    }
    auto term = localExprs[j - numDims - numSymbols] * eq[j];
    expr = expr + term;
  }

  // Constant term (last entry of 'eq').
  int64_t constTerm = eq[eq.size() - 1];
  if (constTerm != 0)
    expr = expr + constTerm;
  return expr;
}

SimpleAffineExprFlattener::SimpleAffineExprFlattener(unsigned numDims,
                                                     unsigned numSymbols)
    : numDims(numDims), numSymbols(numSymbols), numLocals(0) {
  operandExprStack.reserve(8);
}

void SimpleAffineExprFlattener::visitMulExpr(AffineBinaryOpExpr expr) {
  assert(operandExprStack.size() >= 2);
  // This is a pure affine expr; the RHS will be a constant.
  assert(expr.getRHS().isa<AffineConstantExpr>());
  // Get the RHS constant.
  auto rhsConst = operandExprStack.back()[getConstantIndex()];
  operandExprStack.pop_back();
  // Update the LHS in place instead of pop and push: scale every
  // coefficient by the constant.
  auto &lhs = operandExprStack.back();
  for (unsigned i = 0, e = lhs.size(); i < e; i++) {
    lhs[i] *= rhsConst;
  }
}

void SimpleAffineExprFlattener::visitAddExpr(AffineBinaryOpExpr expr) {
  assert(operandExprStack.size() >= 2);
  const auto &rhs = operandExprStack.back();
  auto &lhs = operandExprStack[operandExprStack.size() - 2];
  assert(lhs.size() == rhs.size());
  // Update the LHS in place: coefficient-wise sum.
  for (unsigned i = 0, e = rhs.size(); i < e; i++) {
    lhs[i] += rhs[i];
  }
  // Pop off the RHS.
  operandExprStack.pop_back();
}

//
// t = expr mod c   <=>  t = expr - c*q and c*q <= expr <= c*q + c - 1
//
// A mod expression "expr mod c" is thus flattened by introducing a new local
// variable q (= expr floordiv c), such that expr mod c is replaced with
// 'expr - c * q' and c * q <= expr <= c * q + c - 1 are added to localVarCst.
646 void SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) { 647 assert(operandExprStack.size() >= 2); 648 // This is a pure affine expr; the RHS will be a constant. 649 assert(expr.getRHS().isa<AffineConstantExpr>()); 650 auto rhsConst = operandExprStack.back()[getConstantIndex()]; 651 operandExprStack.pop_back(); 652 auto &lhs = operandExprStack.back(); 653 // TODO(bondhugula): handle modulo by zero case when this issue is fixed 654 // at the other places in the IR. 655 assert(rhsConst > 0 && "RHS constant has to be positive"); 656 657 // Check if the LHS expression is a multiple of modulo factor. 658 unsigned i, e; 659 for (i = 0, e = lhs.size(); i < e; i++) 660 if (lhs[i] % rhsConst != 0) 661 break; 662 // If yes, modulo expression here simplifies to zero. 663 if (i == lhs.size()) { 664 std::fill(lhs.begin(), lhs.end(), 0); 665 return; 666 } 667 668 // Add a local variable for the quotient, i.e., expr % c is replaced by 669 // (expr - q * c) where q = expr floordiv c. Do this while canceling out 670 // the GCD of expr and c. 671 SmallVector<int64_t, 8> floorDividend(lhs); 672 uint64_t gcd = rhsConst; 673 for (unsigned i = 0, e = lhs.size(); i < e; i++) 674 gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(lhs[i])); 675 // Simplify the numerator and the denominator. 676 if (gcd != 1) { 677 for (unsigned i = 0, e = floorDividend.size(); i < e; i++) 678 floorDividend[i] = floorDividend[i] / static_cast<int64_t>(gcd); 679 } 680 int64_t floorDivisor = rhsConst / static_cast<int64_t>(gcd); 681 682 // Construct the AffineExpr form of the floordiv to store in localExprs. 
683 MLIRContext *context = expr.getContext(); 684 auto dividendExpr = 685 toAffineExpr(floorDividend, numDims, numSymbols, localExprs, context); 686 auto divisorExpr = getAffineConstantExpr(floorDivisor, context); 687 auto floorDivExpr = dividendExpr.floorDiv(divisorExpr); 688 int loc; 689 if ((loc = findLocalId(floorDivExpr)) == -1) { 690 addLocalFloorDivId(floorDividend, floorDivisor, floorDivExpr); 691 // Set result at top of stack to "lhs - rhsConst * q". 692 lhs[getLocalVarStartIndex() + numLocals - 1] = -rhsConst; 693 } else { 694 // Reuse the existing local id. 695 lhs[getLocalVarStartIndex() + loc] = -rhsConst; 696 } 697 } 698 699 void SimpleAffineExprFlattener::visitCeilDivExpr(AffineBinaryOpExpr expr) { 700 visitDivExpr(expr, /*isCeil=*/true); 701 } 702 void SimpleAffineExprFlattener::visitFloorDivExpr(AffineBinaryOpExpr expr) { 703 visitDivExpr(expr, /*isCeil=*/false); 704 } 705 706 void SimpleAffineExprFlattener::visitDimExpr(AffineDimExpr expr) { 707 operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0)); 708 auto &eq = operandExprStack.back(); 709 assert(expr.getPosition() < numDims && "Inconsistent number of dims"); 710 eq[getDimStartIndex() + expr.getPosition()] = 1; 711 } 712 713 void SimpleAffineExprFlattener::visitSymbolExpr(AffineSymbolExpr expr) { 714 operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0)); 715 auto &eq = operandExprStack.back(); 716 assert(expr.getPosition() < numSymbols && "inconsistent number of symbols"); 717 eq[getSymbolStartIndex() + expr.getPosition()] = 1; 718 } 719 720 void SimpleAffineExprFlattener::visitConstantExpr(AffineConstantExpr expr) { 721 operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0)); 722 auto &eq = operandExprStack.back(); 723 eq[getConstantIndex()] = expr.getValue(); 724 } 725 726 // t = expr floordiv c <=> t = q, c * q <= expr <= c * q + c - 1 727 // A floordiv is thus flattened by introducing a new local variable q, and 728 // replacing that 
expression with 'q' while adding the constraints 729 // c * q <= expr <= c * q + c - 1 to localVarCst (done by 730 // FlatAffineConstraints::addLocalFloorDiv). 731 // 732 // A ceildiv is similarly flattened: 733 // t = expr ceildiv c <=> t = (expr + c - 1) floordiv c 734 void SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr, 735 bool isCeil) { 736 assert(operandExprStack.size() >= 2); 737 assert(expr.getRHS().isa<AffineConstantExpr>()); 738 739 // This is a pure affine expr; the RHS is a positive constant. 740 int64_t rhsConst = operandExprStack.back()[getConstantIndex()]; 741 // TODO(bondhugula): handle division by zero at the same time the issue is 742 // fixed at other places. 743 assert(rhsConst > 0 && "RHS constant has to be positive"); 744 operandExprStack.pop_back(); 745 auto &lhs = operandExprStack.back(); 746 747 // Simplify the floordiv, ceildiv if possible by canceling out the greatest 748 // common divisors of the numerator and denominator. 749 uint64_t gcd = std::abs(rhsConst); 750 for (unsigned i = 0, e = lhs.size(); i < e; i++) 751 gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(lhs[i])); 752 // Simplify the numerator and the denominator. 753 if (gcd != 1) { 754 for (unsigned i = 0, e = lhs.size(); i < e; i++) 755 lhs[i] = lhs[i] / static_cast<int64_t>(gcd); 756 } 757 int64_t divisor = rhsConst / static_cast<int64_t>(gcd); 758 // If the divisor becomes 1, the updated LHS is the result. (The 759 // divisor can't be negative since rhsConst is positive). 760 if (divisor == 1) 761 return; 762 763 // If the divisor cannot be simplified to one, we will have to retain 764 // the ceil/floor expr (simplified up until here). Add an existential 765 // quantifier to express its result, i.e., expr1 div expr2 is replaced 766 // by a new identifier, q. 
767 MLIRContext *context = expr.getContext(); 768 auto a = toAffineExpr(lhs, numDims, numSymbols, localExprs, context); 769 auto b = getAffineConstantExpr(divisor, context); 770 771 int loc; 772 auto divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b); 773 if ((loc = findLocalId(divExpr)) == -1) { 774 if (!isCeil) { 775 SmallVector<int64_t, 8> dividend(lhs); 776 addLocalFloorDivId(dividend, divisor, divExpr); 777 } else { 778 // lhs ceildiv c <=> (lhs + c - 1) floordiv c 779 SmallVector<int64_t, 8> dividend(lhs); 780 dividend.back() += divisor - 1; 781 addLocalFloorDivId(dividend, divisor, divExpr); 782 } 783 } 784 // Set the expression on stack to the local var introduced to capture the 785 // result of the division (floor or ceil). 786 std::fill(lhs.begin(), lhs.end(), 0); 787 if (loc == -1) 788 lhs[getLocalVarStartIndex() + numLocals - 1] = 1; 789 else 790 lhs[getLocalVarStartIndex() + loc] = 1; 791 } 792 793 // Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr). 794 // The local identifier added is always a floordiv of a pure add/mul affine 795 // function of other identifiers, coefficients of which are specified in 796 // dividend and with respect to a positive constant divisor. localExpr is the 797 // simplified tree expression (AffineExpr) corresponding to the quantifier. 798 void SimpleAffineExprFlattener::addLocalFloorDivId(ArrayRef<int64_t> dividend, 799 int64_t divisor, 800 AffineExpr localExpr) { 801 assert(divisor > 0 && "positive constant divisor expected"); 802 for (auto &subExpr : operandExprStack) 803 subExpr.insert(subExpr.begin() + getLocalVarStartIndex() + numLocals, 0); 804 localExprs.push_back(localExpr); 805 numLocals++; 806 // dividend and divisor are not used here; an override of this method uses it. 
807 } 808 809 int SimpleAffineExprFlattener::findLocalId(AffineExpr localExpr) { 810 SmallVectorImpl<AffineExpr>::iterator it; 811 if ((it = llvm::find(localExprs, localExpr)) == localExprs.end()) 812 return -1; 813 return it - localExprs.begin(); 814 } 815 816 /// Simplify the affine expression by flattening it and reconstructing it. 817 AffineExpr mlir::simplifyAffineExpr(AffineExpr expr, unsigned numDims, 818 unsigned numSymbols) { 819 // TODO(bondhugula): only pure affine for now. The simplification here can 820 // be extended to semi-affine maps in the future. 821 if (!expr.isPureAffine()) 822 return expr; 823 824 SimpleAffineExprFlattener flattener(numDims, numSymbols); 825 flattener.walkPostOrder(expr); 826 ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back(); 827 auto simplifiedExpr = toAffineExpr(flattenedExpr, numDims, numSymbols, 828 flattener.localExprs, expr.getContext()); 829 flattener.operandExprStack.pop_back(); 830 assert(flattener.operandExprStack.empty()); 831 832 return simplifiedExpr; 833 } 834 835 // Flattens the expressions in map. Returns true on success or false 836 // if 'expr' was unable to be flattened (i.e., semi-affine expressions not 837 // handled yet). 838 static bool getFlattenedAffineExprs( 839 ArrayRef<AffineExpr> exprs, unsigned numDims, unsigned numSymbols, 840 std::vector<llvm::SmallVector<int64_t, 8>> *flattenedExprs) { 841 if (exprs.empty()) { 842 return true; 843 } 844 845 SimpleAffineExprFlattener flattener(numDims, numSymbols); 846 // Use the same flattener to simplify each expression successively. This way 847 // local identifiers / expressions are shared. 
848 for (auto expr : exprs) { 849 if (!expr.isPureAffine()) 850 return false; 851 852 flattener.walkPostOrder(expr); 853 } 854 855 flattenedExprs->clear(); 856 assert(flattener.operandExprStack.size() == exprs.size()); 857 flattenedExprs->assign(flattener.operandExprStack.begin(), 858 flattener.operandExprStack.end()); 859 860 return true; 861 } 862 863 // Flattens 'expr' into 'flattenedExpr'. Returns true on success or false 864 // if 'expr' was unable to be flattened (semi-affine expressions not handled 865 // yet). 866 bool mlir::getFlattenedAffineExpr( 867 AffineExpr expr, unsigned numDims, unsigned numSymbols, 868 llvm::SmallVectorImpl<int64_t> *flattenedExpr) { 869 std::vector<SmallVector<int64_t, 8>> flattenedExprs; 870 bool ret = 871 ::getFlattenedAffineExprs({expr}, numDims, numSymbols, &flattenedExprs); 872 *flattenedExpr = flattenedExprs[0]; 873 return ret; 874 } 875 876 /// Flattens the expressions in map. Returns true on success or false 877 /// if 'expr' was unable to be flattened (i.e., semi-affine expressions not 878 /// handled yet). 879 bool mlir::getFlattenedAffineExprs( 880 AffineMap map, std::vector<llvm::SmallVector<int64_t, 8>> *flattenedExprs) { 881 if (map.getNumResults() == 0) { 882 return true; 883 } 884 return ::getFlattenedAffineExprs(map.getResults(), map.getNumDims(), 885 map.getNumSymbols(), flattenedExprs); 886 } 887 888 bool mlir::getFlattenedAffineExprs( 889 IntegerSet set, 890 std::vector<llvm::SmallVector<int64_t, 8>> *flattenedExprs) { 891 if (set.getNumConstraints() == 0) { 892 return true; 893 } 894 return ::getFlattenedAffineExprs(set.getConstraints(), set.getNumDims(), 895 set.getNumSymbols(), flattenedExprs); 896 }