github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Dialect/SDBM/SDBMExpr.cpp (about) 1 //===- SDBMExpr.cpp - MLIR SDBM Expression implementation -----------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 // 18 // A striped difference-bound matrix (SDBM) expression is a constant expression, 19 // an identifier, a binary expression with constant RHS and +, stripe operators 20 // or a difference expression between two identifiers. 21 // 22 //===----------------------------------------------------------------------===// 23 24 #include "mlir/Dialect/SDBM/SDBMExpr.h" 25 #include "SDBMExprDetail.h" 26 #include "mlir/Dialect/SDBM/SDBMDialect.h" 27 #include "mlir/IR/AffineExpr.h" 28 #include "mlir/IR/AffineExprVisitor.h" 29 30 #include "llvm/Support/raw_ostream.h" 31 32 using namespace mlir; 33 34 namespace { 35 /// A simple compositional matcher for AffineExpr 36 /// 37 /// Example usage: 38 /// 39 /// ```c++ 40 /// AffineExprMatcher x, C, m; 41 /// AffineExprMatcher pattern1 = ((x % C) * m) + x; 42 /// AffineExprMatcher pattern2 = x + ((x % C) * m); 43 /// if (pattern1.match(expr) || pattern2.match(expr)) { 44 /// ... 45 /// } 46 /// ``` 47 class AffineExprMatcherStorage; 48 class AffineExprMatcher { 49 public: 50 AffineExprMatcher(); 51 AffineExprMatcher(const AffineExprMatcher &other); 52 53 AffineExprMatcher operator+(AffineExprMatcher other) { 54 return AffineExprMatcher(AffineExprKind::Add, *this, other); 55 } 56 AffineExprMatcher operator*(AffineExprMatcher other) { 57 return AffineExprMatcher(AffineExprKind::Mul, *this, other); 58 } 59 AffineExprMatcher floorDiv(AffineExprMatcher other) { 60 return AffineExprMatcher(AffineExprKind::FloorDiv, *this, other); 61 } 62 AffineExprMatcher ceilDiv(AffineExprMatcher other) { 63 return AffineExprMatcher(AffineExprKind::CeilDiv, *this, other); 64 } 65 AffineExprMatcher operator%(AffineExprMatcher other) { 66 return AffineExprMatcher(AffineExprKind::Mod, *this, other); 67 } 68 69 AffineExpr match(AffineExpr expr); 70 AffineExpr matched(); 71 Optional<int> getMatchedConstantValue(); 72 73 private: 74 AffineExprMatcher(AffineExprKind k, AffineExprMatcher a, AffineExprMatcher b); 75 AffineExprKind kind; // only used to match in binary op cases. 76 // A shared_ptr allows multiple references to same matcher storage without 77 // worrying about ownership or dealing with an arena. To be cleaned up if we 78 // go with this. 79 std::shared_ptr<AffineExprMatcherStorage> storage; 80 }; 81 82 class AffineExprMatcherStorage { 83 public: 84 AffineExprMatcherStorage() {} 85 AffineExprMatcherStorage(const AffineExprMatcherStorage &other) 86 : subExprs(other.subExprs.begin(), other.subExprs.end()), 87 matched(other.matched) {} 88 AffineExprMatcherStorage(ArrayRef<AffineExprMatcher> exprs) 89 : subExprs(exprs.begin(), exprs.end()) {} 90 AffineExprMatcherStorage(AffineExprMatcher &a, AffineExprMatcher &b) 91 : subExprs({a, b}) {} 92 llvm::SmallVector<AffineExprMatcher, 0> subExprs; 93 AffineExpr matched; 94 }; 95 } // namespace 96 97 AffineExprMatcher::AffineExprMatcher() 98 : kind(AffineExprKind::Constant), storage(new AffineExprMatcherStorage()) {} 99 100 AffineExprMatcher::AffineExprMatcher(const AffineExprMatcher &other) 101 : kind(other.kind), storage(other.storage) {} 102 103 Optional<int> AffineExprMatcher::getMatchedConstantValue() { 104 if (auto cst = storage->matched.dyn_cast<AffineConstantExpr>()) 105 return cst.getValue(); 106 return None; 107 } 108 109 AffineExpr AffineExprMatcher::match(AffineExpr expr) { 110 if (kind > AffineExprKind::LAST_AFFINE_BINARY_OP) { 111 if (storage->matched) 112 if (storage->matched != expr) 113 return AffineExpr(); 114 storage->matched = expr; 115 return storage->matched; 116 } 117 if (kind != expr.getKind()) { 118 return AffineExpr(); 119 } 120 if (auto bin = expr.dyn_cast<AffineBinaryOpExpr>()) { 121 if (!storage->subExprs.empty() && 122 !storage->subExprs[0].match(bin.getLHS())) { 123 return AffineExpr(); 124 } 125 if (!storage->subExprs.empty() && 126 !storage->subExprs[1].match(bin.getRHS())) { 127 return AffineExpr(); 128 } 129 if (storage->matched) 130 if (storage->matched != expr) 131 return AffineExpr(); 132 storage->matched = expr; 133 return storage->matched; 134 } 135 llvm_unreachable("binary expected"); 136 } 137 138 AffineExpr AffineExprMatcher::matched() { return storage->matched; } 139 140 AffineExprMatcher::AffineExprMatcher(AffineExprKind k, AffineExprMatcher a, 141 AffineExprMatcher b) 142 : kind(k), storage(new AffineExprMatcherStorage(a, b)) { 143 storage->subExprs.push_back(a); 144 storage->subExprs.push_back(b); 145 } 146 147 //===----------------------------------------------------------------------===// 148 // SDBMExpr 149 //===----------------------------------------------------------------------===// 150 151 SDBMExprKind SDBMExpr::getKind() const { return impl->getKind(); } 152 153 MLIRContext *SDBMExpr::getContext() const { 154 return impl->dialect->getContext(); 155 } 156 157 SDBMDialect *SDBMExpr::getDialect() const { return impl->dialect; } 158 159 void SDBMExpr::print(raw_ostream &os) const { 160 struct Printer : public SDBMVisitor<Printer> { 161 Printer(raw_ostream &ostream) : prn(ostream) {} 162 163 void visitSum(SDBMSumExpr expr) { 164 visitVarying(expr.getLHS()); 165 prn << " + "; 166 visitConstant(expr.getRHS()); 167 } 168 void visitDiff(SDBMDiffExpr expr) { 169 visitPositive(expr.getLHS()); 170 prn << " - "; 171 visitPositive(expr.getRHS()); 172 } 173 void visitDim(SDBMDimExpr expr) { prn << 'd' << expr.getPosition(); } 174 void visitSymbol(SDBMSymbolExpr expr) { prn << 's' << expr.getPosition(); } 175 void visitStripe(SDBMStripeExpr expr) { 176 visitPositive(expr.getVar()); 177 prn << " # "; 178 visitConstant(expr.getStripeFactor()); 179 } 180 void visitNeg(SDBMNegExpr expr) { 181 prn << '-'; 182 visitPositive(expr.getVar()); 183 } 184 void visitConstant(SDBMConstantExpr expr) { prn << expr.getValue(); } 185 186 raw_ostream &prn; 187 }; 188 Printer printer(os); 189 printer.visit(*this); 190 } 191 192 void SDBMExpr::dump() const { 193 print(llvm::errs()); 194 llvm::errs() << '\n'; 195 } 196 197 namespace { 198 // Helper class to perform negation of an SDBM expression. 199 struct SDBMNegator : public SDBMVisitor<SDBMNegator, SDBMExpr> { 200 // Any positive expression is wrapped into a negation expression. 201 // -(x) = -x 202 SDBMExpr visitPositive(SDBMPositiveExpr expr) { 203 return SDBMNegExpr::get(expr); 204 } 205 // A negation expression is unwrapped. 206 // -(-x) = x 207 SDBMExpr visitNeg(SDBMNegExpr expr) { return expr.getVar(); } 208 // The value of the constant is negated. 209 SDBMExpr visitConstant(SDBMConstantExpr expr) { 210 return SDBMConstantExpr::get(expr.getDialect(), -expr.getValue()); 211 } 212 // Both terms of the sum are negated recursively. 213 SDBMExpr visitSum(SDBMSumExpr expr) { 214 return SDBMSumExpr::get(visit(expr.getLHS()).cast<SDBMVaryingExpr>(), 215 visit(expr.getRHS()).cast<SDBMConstantExpr>()); 216 } 217 // Terms of a difference are interchanged. 218 // -(x - y) = y - x 219 SDBMExpr visitDiff(SDBMDiffExpr expr) { 220 return SDBMDiffExpr::get(expr.getRHS(), expr.getLHS()); 221 } 222 }; 223 } // namespace 224 225 SDBMExpr SDBMExpr::operator-() { return SDBMNegator().visit(*this); } 226 227 //===----------------------------------------------------------------------===// 228 // SDBMSumExpr 229 //===----------------------------------------------------------------------===// 230 231 SDBMSumExpr SDBMSumExpr::get(SDBMVaryingExpr lhs, SDBMConstantExpr rhs) { 232 assert(lhs && "expected SDBM variable expression"); 233 assert(rhs && "expected SDBM constant"); 234 235 // If LHS of a sum is another sum, fold the constant RHS parts. 236 if (auto lhsSum = lhs.dyn_cast<SDBMSumExpr>()) { 237 lhs = lhsSum.getLHS(); 238 rhs = SDBMConstantExpr::get(rhs.getDialect(), 239 rhs.getValue() + lhsSum.getRHS().getValue()); 240 } 241 242 StorageUniquer &uniquer = lhs.getDialect()->getUniquer(); 243 return uniquer.get<detail::SDBMBinaryExprStorage>( 244 /*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Add), lhs, rhs); 245 } 246 247 SDBMVaryingExpr SDBMSumExpr::getLHS() const { 248 return static_cast<ImplType *>(impl)->lhs; 249 } 250 251 SDBMConstantExpr SDBMSumExpr::getRHS() const { 252 return static_cast<ImplType *>(impl)->rhs; 253 } 254 255 AffineExpr SDBMExpr::getAsAffineExpr() const { 256 struct Converter : public SDBMVisitor<Converter, AffineExpr> { 257 AffineExpr visitSum(SDBMSumExpr expr) { 258 AffineExpr lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS()); 259 return lhs + rhs; 260 } 261 262 AffineExpr visitStripe(SDBMStripeExpr expr) { 263 AffineExpr lhs = visit(expr.getVar()), 264 rhs = visit(expr.getStripeFactor()); 265 return lhs - (lhs % rhs); 266 } 267 268 AffineExpr visitDiff(SDBMDiffExpr expr) { 269 AffineExpr lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS()); 270 return lhs - rhs; 271 } 272 273 AffineExpr visitDim(SDBMDimExpr expr) { 274 return getAffineDimExpr(expr.getPosition(), expr.getContext()); 275 } 276 277 AffineExpr visitSymbol(SDBMSymbolExpr expr) { 278 return getAffineSymbolExpr(expr.getPosition(), expr.getContext()); 279 } 280 281 AffineExpr visitNeg(SDBMNegExpr expr) { 282 return getAffineBinaryOpExpr(AffineExprKind::Mul, 283 getAffineConstantExpr(-1, expr.getContext()), 284 visit(expr.getVar())); 285 } 286 287 AffineExpr visitConstant(SDBMConstantExpr expr) { 288 return getAffineConstantExpr(expr.getValue(), expr.getContext()); 289 } 290 } converter; 291 return converter.visit(*this); 292 } 293 294 Optional<SDBMExpr> SDBMExpr::tryConvertAffineExpr(AffineExpr affine) { 295 struct Converter : public AffineExprVisitor<Converter, SDBMExpr> { 296 SDBMExpr visitAddExpr(AffineBinaryOpExpr expr) { 297 // Attempt to recover a stripe expression. Because AffineExprs don't have 298 // a first-class difference kind, we check for both x + -1 * (x mod C) and 299 // -1 * (x mod C) + x cases. 300 AffineExprMatcher x, C, m; 301 AffineExprMatcher pattern1 = ((x % C) * m) + x; 302 AffineExprMatcher pattern2 = x + ((x % C) * m); 303 if ((pattern1.match(expr) && m.getMatchedConstantValue() == -1) || 304 (pattern2.match(expr) && m.getMatchedConstantValue() == -1)) { 305 if (auto convertedLHS = visit(x.matched())) { 306 // TODO(ntv): return convertedLHS.stripe(C); 307 return SDBMStripeExpr::get( 308 convertedLHS.cast<SDBMPositiveExpr>(), 309 visit(C.matched()).cast<SDBMConstantExpr>()); 310 } 311 } 312 auto lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS()); 313 if (!lhs || !rhs) 314 return {}; 315 316 // In a "add" AffineExpr, the constant always appears on the right. If 317 // there were two constants, they would have been folded away. 318 assert(!lhs.isa<SDBMConstantExpr>() && "non-canonical affine expression"); 319 auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>(); 320 321 // SDBM accepts LHS variables and RHS constants in a sum. 322 auto lhsVar = lhs.dyn_cast<SDBMVaryingExpr>(); 323 auto rhsVar = rhs.dyn_cast<SDBMVaryingExpr>(); 324 if (rhsConstant && lhsVar) 325 return SDBMSumExpr::get(lhsVar, rhsConstant); 326 327 // The sum of a negated variable and a non-negated variable is a 328 // difference, supported as a special kind in SDBM. Because AffineExprs 329 // don't have first-class difference kind, check both LHS and RHS for 330 // negation. 331 auto lhsPos = lhs.dyn_cast<SDBMPositiveExpr>(); 332 auto rhsPos = rhs.dyn_cast<SDBMPositiveExpr>(); 333 auto lhsNeg = lhs.dyn_cast<SDBMNegExpr>(); 334 auto rhsNeg = rhs.dyn_cast<SDBMNegExpr>(); 335 if (lhsNeg && rhsVar) 336 return SDBMDiffExpr::get(rhsPos, lhsNeg.getVar()); 337 if (rhsNeg && lhsVar) 338 return SDBMDiffExpr::get(lhsPos, rhsNeg.getVar()); 339 340 // Other cases don't fit into SDBM. 341 return {}; 342 } 343 344 SDBMExpr visitMulExpr(AffineBinaryOpExpr expr) { 345 // Attempt to recover a stripe expression "x # C = (x floordiv C) * C". 346 AffineExprMatcher x, C; 347 AffineExprMatcher pattern = (x.floorDiv(C)) * C; 348 if (pattern.match(expr)) { 349 if (SDBMExpr converted = visit(x.matched())) { 350 if (auto varConverted = converted.dyn_cast<SDBMPositiveExpr>()) 351 // TODO(ntv): return varConverted.stripe(C.getConstantValue()); 352 return SDBMStripeExpr::get( 353 varConverted, 354 SDBMConstantExpr::get(dialect, 355 C.getMatchedConstantValue().getValue())); 356 } 357 } 358 359 auto lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS()); 360 if (!lhs || !rhs) 361 return {}; 362 363 // In a "mul" AffineExpr, the constant always appears on the right. If 364 // there were two constants, they would have been folded away. 365 assert(!lhs.isa<SDBMConstantExpr>() && "non-canonical affine expression"); 366 auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>(); 367 if (!rhsConstant) 368 return {}; 369 370 // The only supported "multiplication" expression is an SDBM is dimension 371 // negation, that is a product of dimension and constant -1. 372 auto lhsVar = lhs.dyn_cast<SDBMPositiveExpr>(); 373 if (lhsVar && rhsConstant.getValue() == -1) 374 return SDBMNegExpr::get(lhsVar); 375 376 // Other multiplications are not allowed in SDBM. 377 return {}; 378 } 379 380 SDBMExpr visitModExpr(AffineBinaryOpExpr expr) { 381 auto lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS()); 382 if (!lhs || !rhs) 383 return {}; 384 385 // 'mod' can only be converted to SDBM if its LHS is a variable 386 // and its RHS is a constant. Then it `x mod c = x - x stripe c`. 387 auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>(); 388 auto lhsVar = rhs.dyn_cast<SDBMPositiveExpr>(); 389 if (!lhsVar || !rhsConstant) 390 return {}; 391 return SDBMDiffExpr::get(lhsVar, 392 SDBMStripeExpr::get(lhsVar, rhsConstant)); 393 } 394 395 // `a floordiv b = (a stripe b) / b`, but we have no division in SDBM 396 SDBMExpr visitFloorDivExpr(AffineBinaryOpExpr expr) { return {}; } 397 SDBMExpr visitCeilDivExpr(AffineBinaryOpExpr expr) { return {}; } 398 399 // Dimensions, symbols and constants are converted trivially. 400 SDBMExpr visitConstantExpr(AffineConstantExpr expr) { 401 return SDBMConstantExpr::get(dialect, expr.getValue()); 402 } 403 SDBMExpr visitDimExpr(AffineDimExpr expr) { 404 return SDBMDimExpr::get(dialect, expr.getPosition()); 405 } 406 SDBMExpr visitSymbolExpr(AffineSymbolExpr expr) { 407 return SDBMSymbolExpr::get(dialect, expr.getPosition()); 408 } 409 410 SDBMDialect *dialect; 411 } converter; 412 converter.dialect = affine.getContext()->getRegisteredDialect<SDBMDialect>(); 413 414 if (auto result = converter.visit(affine)) 415 return result; 416 return None; 417 } 418 419 //===----------------------------------------------------------------------===// 420 // SDBMDiffExpr 421 //===----------------------------------------------------------------------===// 422 423 SDBMDiffExpr SDBMDiffExpr::get(SDBMPositiveExpr lhs, SDBMPositiveExpr rhs) { 424 assert(lhs && "expected SDBM dimension"); 425 assert(rhs && "expected SDBM dimension"); 426 427 StorageUniquer &uniquer = lhs.getDialect()->getUniquer(); 428 return uniquer.get<detail::SDBMDiffExprStorage>( 429 /*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Diff), lhs, rhs); 430 } 431 432 SDBMPositiveExpr SDBMDiffExpr::getLHS() const { 433 return static_cast<ImplType *>(impl)->lhs; 434 } 435 436 SDBMPositiveExpr SDBMDiffExpr::getRHS() const { 437 return static_cast<ImplType *>(impl)->rhs; 438 } 439 440 //===----------------------------------------------------------------------===// 441 // SDBMStripeExpr 442 //===----------------------------------------------------------------------===// 443 444 SDBMStripeExpr SDBMStripeExpr::get(SDBMPositiveExpr var, 445 SDBMConstantExpr stripeFactor) { 446 assert(var && "expected SDBM variable expression"); 447 assert(stripeFactor && "expected non-null stripe factor"); 448 if (stripeFactor.getValue() <= 0) 449 llvm::report_fatal_error("non-positive stripe factor"); 450 451 StorageUniquer &uniquer = var.getDialect()->getUniquer(); 452 return uniquer.get<detail::SDBMBinaryExprStorage>( 453 /*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Stripe), var, 454 stripeFactor); 455 } 456 457 SDBMPositiveExpr SDBMStripeExpr::getVar() const { 458 if (SDBMVaryingExpr lhs = static_cast<ImplType *>(impl)->lhs) 459 return lhs.cast<SDBMPositiveExpr>(); 460 return {}; 461 } 462 463 SDBMConstantExpr SDBMStripeExpr::getStripeFactor() const { 464 return static_cast<ImplType *>(impl)->rhs; 465 } 466 467 //===----------------------------------------------------------------------===// 468 // SDBMInputExpr 469 //===----------------------------------------------------------------------===// 470 471 unsigned SDBMInputExpr::getPosition() const { 472 return static_cast<ImplType *>(impl)->position; 473 } 474 475 //===----------------------------------------------------------------------===// 476 // SDBMDimExpr 477 //===----------------------------------------------------------------------===// 478 479 SDBMDimExpr SDBMDimExpr::get(SDBMDialect *dialect, unsigned position) { 480 assert(dialect && "expected non-null dialect"); 481 482 auto assignDialect = [dialect](detail::SDBMPositiveExprStorage *storage) { 483 storage->dialect = dialect; 484 }; 485 486 StorageUniquer &uniquer = dialect->getUniquer(); 487 return uniquer.get<detail::SDBMPositiveExprStorage>( 488 assignDialect, static_cast<unsigned>(SDBMExprKind::DimId), position); 489 } 490 491 //===----------------------------------------------------------------------===// 492 // SDBMSymbolExpr 493 //===----------------------------------------------------------------------===// 494 495 SDBMSymbolExpr SDBMSymbolExpr::get(SDBMDialect *dialect, unsigned position) { 496 assert(dialect && "expected non-null dialect"); 497 498 auto assignDialect = [dialect](detail::SDBMPositiveExprStorage *storage) { 499 storage->dialect = dialect; 500 }; 501 502 StorageUniquer &uniquer = dialect->getUniquer(); 503 return uniquer.get<detail::SDBMPositiveExprStorage>( 504 assignDialect, static_cast<unsigned>(SDBMExprKind::SymbolId), position); 505 } 506 507 //===----------------------------------------------------------------------===// 508 // SDBMConstantExpr 509 //===----------------------------------------------------------------------===// 510 511 SDBMConstantExpr SDBMConstantExpr::get(SDBMDialect *dialect, int64_t value) { 512 assert(dialect && "expected non-null dialect"); 513 514 auto assignCtx = [dialect](detail::SDBMConstantExprStorage *storage) { 515 storage->dialect = dialect; 516 }; 517 518 StorageUniquer &uniquer = dialect->getUniquer(); 519 return uniquer.get<detail::SDBMConstantExprStorage>( 520 assignCtx, static_cast<unsigned>(SDBMExprKind::Constant), value); 521 } 522 523 int64_t SDBMConstantExpr::getValue() const { 524 return static_cast<ImplType *>(impl)->constant; 525 } 526 527 //===----------------------------------------------------------------------===// 528 // SDBMNegExpr 529 //===----------------------------------------------------------------------===// 530 531 SDBMNegExpr SDBMNegExpr::get(SDBMPositiveExpr var) { 532 assert(var && "expected non-null SDBM variable expression"); 533 534 StorageUniquer &uniquer = var.getDialect()->getUniquer(); 535 return uniquer.get<detail::SDBMNegExprStorage>( 536 /*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Neg), var); 537 } 538 539 SDBMPositiveExpr SDBMNegExpr::getVar() const { 540 return static_cast<ImplType *>(impl)->dim; 541 } 542 543 namespace mlir { 544 namespace ops_assertions { 545 546 SDBMExpr operator+(SDBMExpr lhs, SDBMExpr rhs) { 547 // If one of the operands is a negation, take a difference rather than a sum. 548 auto lhsNeg = lhs.dyn_cast<SDBMNegExpr>(); 549 auto rhsNeg = rhs.dyn_cast<SDBMNegExpr>(); 550 assert(!(lhsNeg && rhsNeg) && "a sum of negated expressions is a negation of " 551 "a sum of variables and not a correct SDBM"); 552 if (lhsNeg) 553 return rhs - lhsNeg.getVar(); 554 if (rhsNeg) 555 return lhs - rhsNeg.getVar(); 556 557 // If LHS is a constant and RHS is not, swap the order to get into a supported 558 // sum case. From now on, RHS must be a constant. 559 auto lhsConstant = lhs.dyn_cast<SDBMConstantExpr>(); 560 auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>(); 561 if (!rhsConstant && lhsConstant) { 562 std::swap(lhs, rhs); 563 std::swap(lhsConstant, rhsConstant); 564 } 565 assert(rhsConstant && "at least one operand must be a constant"); 566 567 // If LHS is another sum, first compute the sum of its variable 568 // part with the other argument and then add the constant part to enable 569 // constant folding (the variable part may, e.g., be a negation that requires 570 // to enter this function again). 571 auto lhsSum = lhs.dyn_cast<SDBMSumExpr>(); 572 if (lhsSum) 573 return lhsSum.getLHS() + 574 (lhsSum.getRHS().getValue() + rhsConstant.getValue()); 575 576 // Constant-fold if LHS is a constant. 577 if (lhsConstant) 578 return SDBMConstantExpr::get(lhs.getDialect(), lhsConstant.getValue() + 579 rhsConstant.getValue()); 580 581 // Fold x + 0 == x. 582 if (rhsConstant.getValue() == 0) 583 return lhs; 584 585 return SDBMSumExpr::get(lhs.cast<SDBMVaryingExpr>(), 586 rhs.cast<SDBMConstantExpr>()); 587 } 588 589 SDBMExpr operator-(SDBMExpr lhs, SDBMExpr rhs) { 590 // Fold x - x == 0. 591 if (lhs == rhs) 592 return SDBMConstantExpr::get(lhs.getDialect(), 0); 593 594 // LHS and RHS may be constants. 595 auto lhsConstant = lhs.dyn_cast<SDBMConstantExpr>(); 596 auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>(); 597 598 // Constant fold if both LHS and RHS are constants. 599 if (lhsConstant && rhsConstant) 600 return SDBMConstantExpr::get(lhs.getDialect(), lhsConstant.getValue() - 601 rhsConstant.getValue()); 602 603 // Replace a difference with a sum with a negated value if one of LHS and RHS 604 // is a constant: 605 // x - C == x + (-C); 606 // C - x == -x + C. 607 // This calls into operator+ for further simplification. 608 if (rhsConstant) 609 return lhs + (-rhsConstant); 610 if (lhsConstant) 611 return -rhs + lhsConstant; 612 613 // Hoist constant factors outside the difference if any of sides is a sum: 614 // (x + A) - (y - B) == x - y + (A - B). 615 // If either LHS or RHS is a sum, collect the constant values separately and 616 // update LHS and RHS to point to the variable part of the sum. 617 auto lhsSum = lhs.dyn_cast<SDBMSumExpr>(); 618 auto rhsSum = rhs.dyn_cast<SDBMSumExpr>(); 619 int64_t value = 0; 620 if (lhsSum) { 621 value += lhsSum.getRHS().getValue(); 622 lhs = lhsSum.getLHS(); 623 } 624 if (rhsSum) { 625 value -= rhsSum.getRHS().getValue(); 626 rhs = rhsSum.getLHS(); 627 } 628 629 // This calls into operator+ for futher simplification in case value == 0. 630 return SDBMDiffExpr::get(lhs.cast<SDBMPositiveExpr>(), 631 rhs.cast<SDBMPositiveExpr>()) + 632 value; 633 } 634 635 SDBMExpr stripe(SDBMExpr expr, SDBMExpr factor) { 636 auto constantFactor = factor.cast<SDBMConstantExpr>(); 637 assert(constantFactor.getValue() > 0 && "non-positive stripe"); 638 639 // Fold x # 1 = x. 640 if (constantFactor.getValue() == 1) 641 return expr; 642 643 return SDBMStripeExpr::get(expr.cast<SDBMPositiveExpr>(), constantFactor); 644 } 645 646 } // namespace ops_assertions 647 } // namespace mlir