github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Dialect/SDBM/SDBMExpr.cpp (about)

     1  //===- SDBMExpr.cpp - MLIR SDBM Expression implementation -----------------===//
     2  //
     3  // Copyright 2019 The MLIR Authors.
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //   http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  // =============================================================================
    17  //
    18  // A striped difference-bound matrix (SDBM) expression is a constant expression,
    19  // an identifier, a binary expression with constant RHS and +, stripe operators
    20  // or a difference expression between two identifiers.
    21  //
    22  //===----------------------------------------------------------------------===//
    23  
    24  #include "mlir/Dialect/SDBM/SDBMExpr.h"
    25  #include "SDBMExprDetail.h"
    26  #include "mlir/Dialect/SDBM/SDBMDialect.h"
    27  #include "mlir/IR/AffineExpr.h"
    28  #include "mlir/IR/AffineExprVisitor.h"
    29  
    30  #include "llvm/Support/raw_ostream.h"
    31  
    32  using namespace mlir;
    33  
    34  namespace {
    35  /// A simple compositional matcher for AffineExpr
    36  ///
    37  /// Example usage:
    38  ///
    39  /// ```c++
    40  ///    AffineExprMatcher x, C, m;
    41  ///    AffineExprMatcher pattern1 = ((x % C) * m) + x;
    42  ///    AffineExprMatcher pattern2 = x + ((x % C) * m);
    43  ///    if (pattern1.match(expr) || pattern2.match(expr)) {
    44  ///      ...
    45  ///    }
    46  /// ```
    47  class AffineExprMatcherStorage;
    48  class AffineExprMatcher {
    49  public:
    50    AffineExprMatcher();
    51    AffineExprMatcher(const AffineExprMatcher &other);
    52  
    53    AffineExprMatcher operator+(AffineExprMatcher other) {
    54      return AffineExprMatcher(AffineExprKind::Add, *this, other);
    55    }
    56    AffineExprMatcher operator*(AffineExprMatcher other) {
    57      return AffineExprMatcher(AffineExprKind::Mul, *this, other);
    58    }
    59    AffineExprMatcher floorDiv(AffineExprMatcher other) {
    60      return AffineExprMatcher(AffineExprKind::FloorDiv, *this, other);
    61    }
    62    AffineExprMatcher ceilDiv(AffineExprMatcher other) {
    63      return AffineExprMatcher(AffineExprKind::CeilDiv, *this, other);
    64    }
    65    AffineExprMatcher operator%(AffineExprMatcher other) {
    66      return AffineExprMatcher(AffineExprKind::Mod, *this, other);
    67    }
    68  
    69    AffineExpr match(AffineExpr expr);
    70    AffineExpr matched();
    71    Optional<int> getMatchedConstantValue();
    72  
    73  private:
    74    AffineExprMatcher(AffineExprKind k, AffineExprMatcher a, AffineExprMatcher b);
    75    AffineExprKind kind; // only used to match in binary op cases.
    76    // A shared_ptr allows multiple references to same matcher storage without
    77    // worrying about ownership or dealing with an arena. To be cleaned up if we
    78    // go with this.
    79    std::shared_ptr<AffineExprMatcherStorage> storage;
    80  };
    81  
    82  class AffineExprMatcherStorage {
    83  public:
    84    AffineExprMatcherStorage() {}
    85    AffineExprMatcherStorage(const AffineExprMatcherStorage &other)
    86        : subExprs(other.subExprs.begin(), other.subExprs.end()),
    87          matched(other.matched) {}
    88    AffineExprMatcherStorage(ArrayRef<AffineExprMatcher> exprs)
    89        : subExprs(exprs.begin(), exprs.end()) {}
    90    AffineExprMatcherStorage(AffineExprMatcher &a, AffineExprMatcher &b)
    91        : subExprs({a, b}) {}
    92    llvm::SmallVector<AffineExprMatcher, 0> subExprs;
    93    AffineExpr matched;
    94  };
    95  } // namespace
    96  
    97  AffineExprMatcher::AffineExprMatcher()
    98      : kind(AffineExprKind::Constant), storage(new AffineExprMatcherStorage()) {}
    99  
   100  AffineExprMatcher::AffineExprMatcher(const AffineExprMatcher &other)
   101      : kind(other.kind), storage(other.storage) {}
   102  
   103  Optional<int> AffineExprMatcher::getMatchedConstantValue() {
   104    if (auto cst = storage->matched.dyn_cast<AffineConstantExpr>())
   105      return cst.getValue();
   106    return None;
   107  }
   108  
   109  AffineExpr AffineExprMatcher::match(AffineExpr expr) {
   110    if (kind > AffineExprKind::LAST_AFFINE_BINARY_OP) {
   111      if (storage->matched)
   112        if (storage->matched != expr)
   113          return AffineExpr();
   114      storage->matched = expr;
   115      return storage->matched;
   116    }
   117    if (kind != expr.getKind()) {
   118      return AffineExpr();
   119    }
   120    if (auto bin = expr.dyn_cast<AffineBinaryOpExpr>()) {
   121      if (!storage->subExprs.empty() &&
   122          !storage->subExprs[0].match(bin.getLHS())) {
   123        return AffineExpr();
   124      }
   125      if (!storage->subExprs.empty() &&
   126          !storage->subExprs[1].match(bin.getRHS())) {
   127        return AffineExpr();
   128      }
   129      if (storage->matched)
   130        if (storage->matched != expr)
   131          return AffineExpr();
   132      storage->matched = expr;
   133      return storage->matched;
   134    }
   135    llvm_unreachable("binary expected");
   136  }
   137  
   138  AffineExpr AffineExprMatcher::matched() { return storage->matched; }
   139  
   140  AffineExprMatcher::AffineExprMatcher(AffineExprKind k, AffineExprMatcher a,
   141                                       AffineExprMatcher b)
   142      : kind(k), storage(new AffineExprMatcherStorage(a, b)) {
   143    storage->subExprs.push_back(a);
   144    storage->subExprs.push_back(b);
   145  }
   146  
   147  //===----------------------------------------------------------------------===//
   148  // SDBMExpr
   149  //===----------------------------------------------------------------------===//
   150  
   151  SDBMExprKind SDBMExpr::getKind() const { return impl->getKind(); }
   152  
   153  MLIRContext *SDBMExpr::getContext() const {
   154    return impl->dialect->getContext();
   155  }
   156  
   157  SDBMDialect *SDBMExpr::getDialect() const { return impl->dialect; }
   158  
   159  void SDBMExpr::print(raw_ostream &os) const {
   160    struct Printer : public SDBMVisitor<Printer> {
   161      Printer(raw_ostream &ostream) : prn(ostream) {}
   162  
   163      void visitSum(SDBMSumExpr expr) {
   164        visitVarying(expr.getLHS());
   165        prn << " + ";
   166        visitConstant(expr.getRHS());
   167      }
   168      void visitDiff(SDBMDiffExpr expr) {
   169        visitPositive(expr.getLHS());
   170        prn << " - ";
   171        visitPositive(expr.getRHS());
   172      }
   173      void visitDim(SDBMDimExpr expr) { prn << 'd' << expr.getPosition(); }
   174      void visitSymbol(SDBMSymbolExpr expr) { prn << 's' << expr.getPosition(); }
   175      void visitStripe(SDBMStripeExpr expr) {
   176        visitPositive(expr.getVar());
   177        prn << " # ";
   178        visitConstant(expr.getStripeFactor());
   179      }
   180      void visitNeg(SDBMNegExpr expr) {
   181        prn << '-';
   182        visitPositive(expr.getVar());
   183      }
   184      void visitConstant(SDBMConstantExpr expr) { prn << expr.getValue(); }
   185  
   186      raw_ostream &prn;
   187    };
   188    Printer printer(os);
   189    printer.visit(*this);
   190  }
   191  
   192  void SDBMExpr::dump() const {
   193    print(llvm::errs());
   194    llvm::errs() << '\n';
   195  }
   196  
   197  namespace {
   198  // Helper class to perform negation of an SDBM expression.
   199  struct SDBMNegator : public SDBMVisitor<SDBMNegator, SDBMExpr> {
   200    // Any positive expression is wrapped into a negation expression.
   201    //  -(x) = -x
   202    SDBMExpr visitPositive(SDBMPositiveExpr expr) {
   203      return SDBMNegExpr::get(expr);
   204    }
   205    // A negation expression is unwrapped.
   206    //  -(-x) = x
   207    SDBMExpr visitNeg(SDBMNegExpr expr) { return expr.getVar(); }
   208    // The value of the constant is negated.
   209    SDBMExpr visitConstant(SDBMConstantExpr expr) {
   210      return SDBMConstantExpr::get(expr.getDialect(), -expr.getValue());
   211    }
   212    // Both terms of the sum are negated recursively.
   213    SDBMExpr visitSum(SDBMSumExpr expr) {
   214      return SDBMSumExpr::get(visit(expr.getLHS()).cast<SDBMVaryingExpr>(),
   215                              visit(expr.getRHS()).cast<SDBMConstantExpr>());
   216    }
   217    // Terms of a difference are interchanged.
   218    //  -(x - y) = y - x
   219    SDBMExpr visitDiff(SDBMDiffExpr expr) {
   220      return SDBMDiffExpr::get(expr.getRHS(), expr.getLHS());
   221    }
   222  };
   223  } // namespace
   224  
   225  SDBMExpr SDBMExpr::operator-() { return SDBMNegator().visit(*this); }
   226  
   227  //===----------------------------------------------------------------------===//
   228  // SDBMSumExpr
   229  //===----------------------------------------------------------------------===//
   230  
   231  SDBMSumExpr SDBMSumExpr::get(SDBMVaryingExpr lhs, SDBMConstantExpr rhs) {
   232    assert(lhs && "expected SDBM variable expression");
   233    assert(rhs && "expected SDBM constant");
   234  
   235    // If LHS of a sum is another sum, fold the constant RHS parts.
   236    if (auto lhsSum = lhs.dyn_cast<SDBMSumExpr>()) {
   237      lhs = lhsSum.getLHS();
   238      rhs = SDBMConstantExpr::get(rhs.getDialect(),
   239                                  rhs.getValue() + lhsSum.getRHS().getValue());
   240    }
   241  
   242    StorageUniquer &uniquer = lhs.getDialect()->getUniquer();
   243    return uniquer.get<detail::SDBMBinaryExprStorage>(
   244        /*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Add), lhs, rhs);
   245  }
   246  
   247  SDBMVaryingExpr SDBMSumExpr::getLHS() const {
   248    return static_cast<ImplType *>(impl)->lhs;
   249  }
   250  
   251  SDBMConstantExpr SDBMSumExpr::getRHS() const {
   252    return static_cast<ImplType *>(impl)->rhs;
   253  }
   254  
   255  AffineExpr SDBMExpr::getAsAffineExpr() const {
   256    struct Converter : public SDBMVisitor<Converter, AffineExpr> {
   257      AffineExpr visitSum(SDBMSumExpr expr) {
   258        AffineExpr lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS());
   259        return lhs + rhs;
   260      }
   261  
   262      AffineExpr visitStripe(SDBMStripeExpr expr) {
   263        AffineExpr lhs = visit(expr.getVar()),
   264                   rhs = visit(expr.getStripeFactor());
   265        return lhs - (lhs % rhs);
   266      }
   267  
   268      AffineExpr visitDiff(SDBMDiffExpr expr) {
   269        AffineExpr lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS());
   270        return lhs - rhs;
   271      }
   272  
   273      AffineExpr visitDim(SDBMDimExpr expr) {
   274        return getAffineDimExpr(expr.getPosition(), expr.getContext());
   275      }
   276  
   277      AffineExpr visitSymbol(SDBMSymbolExpr expr) {
   278        return getAffineSymbolExpr(expr.getPosition(), expr.getContext());
   279      }
   280  
   281      AffineExpr visitNeg(SDBMNegExpr expr) {
   282        return getAffineBinaryOpExpr(AffineExprKind::Mul,
   283                                     getAffineConstantExpr(-1, expr.getContext()),
   284                                     visit(expr.getVar()));
   285      }
   286  
   287      AffineExpr visitConstant(SDBMConstantExpr expr) {
   288        return getAffineConstantExpr(expr.getValue(), expr.getContext());
   289      }
   290    } converter;
   291    return converter.visit(*this);
   292  }
   293  
   294  Optional<SDBMExpr> SDBMExpr::tryConvertAffineExpr(AffineExpr affine) {
   295    struct Converter : public AffineExprVisitor<Converter, SDBMExpr> {
   296      SDBMExpr visitAddExpr(AffineBinaryOpExpr expr) {
   297        // Attempt to recover a stripe expression.  Because AffineExprs don't have
   298        // a first-class difference kind, we check for both x + -1 * (x mod C) and
   299        // -1 * (x mod C) + x cases.
   300        AffineExprMatcher x, C, m;
   301        AffineExprMatcher pattern1 = ((x % C) * m) + x;
   302        AffineExprMatcher pattern2 = x + ((x % C) * m);
   303        if ((pattern1.match(expr) && m.getMatchedConstantValue() == -1) ||
   304            (pattern2.match(expr) && m.getMatchedConstantValue() == -1)) {
   305          if (auto convertedLHS = visit(x.matched())) {
   306            // TODO(ntv): return convertedLHS.stripe(C);
   307            return SDBMStripeExpr::get(
   308                convertedLHS.cast<SDBMPositiveExpr>(),
   309                visit(C.matched()).cast<SDBMConstantExpr>());
   310          }
   311        }
   312        auto lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS());
   313        if (!lhs || !rhs)
   314          return {};
   315  
   316        // In a "add" AffineExpr, the constant always appears on the right.  If
   317        // there were two constants, they would have been folded away.
   318        assert(!lhs.isa<SDBMConstantExpr>() && "non-canonical affine expression");
   319        auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>();
   320  
   321        // SDBM accepts LHS variables and RHS constants in a sum.
   322        auto lhsVar = lhs.dyn_cast<SDBMVaryingExpr>();
   323        auto rhsVar = rhs.dyn_cast<SDBMVaryingExpr>();
   324        if (rhsConstant && lhsVar)
   325          return SDBMSumExpr::get(lhsVar, rhsConstant);
   326  
   327        // The sum of a negated variable and a non-negated variable is a
   328        // difference, supported as a special kind in SDBM.  Because AffineExprs
   329        // don't have first-class difference kind, check both LHS and RHS for
   330        // negation.
   331        auto lhsPos = lhs.dyn_cast<SDBMPositiveExpr>();
   332        auto rhsPos = rhs.dyn_cast<SDBMPositiveExpr>();
   333        auto lhsNeg = lhs.dyn_cast<SDBMNegExpr>();
   334        auto rhsNeg = rhs.dyn_cast<SDBMNegExpr>();
   335        if (lhsNeg && rhsVar)
   336          return SDBMDiffExpr::get(rhsPos, lhsNeg.getVar());
   337        if (rhsNeg && lhsVar)
   338          return SDBMDiffExpr::get(lhsPos, rhsNeg.getVar());
   339  
   340        // Other cases don't fit into SDBM.
   341        return {};
   342      }
   343  
   344      SDBMExpr visitMulExpr(AffineBinaryOpExpr expr) {
   345        // Attempt to recover a stripe expression "x # C = (x floordiv C) * C".
   346        AffineExprMatcher x, C;
   347        AffineExprMatcher pattern = (x.floorDiv(C)) * C;
   348        if (pattern.match(expr)) {
   349          if (SDBMExpr converted = visit(x.matched())) {
   350            if (auto varConverted = converted.dyn_cast<SDBMPositiveExpr>())
   351              // TODO(ntv): return varConverted.stripe(C.getConstantValue());
   352              return SDBMStripeExpr::get(
   353                  varConverted,
   354                  SDBMConstantExpr::get(dialect,
   355                                        C.getMatchedConstantValue().getValue()));
   356          }
   357        }
   358  
   359        auto lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS());
   360        if (!lhs || !rhs)
   361          return {};
   362  
   363        // In a "mul" AffineExpr, the constant always appears on the right.  If
   364        // there were two constants, they would have been folded away.
   365        assert(!lhs.isa<SDBMConstantExpr>() && "non-canonical affine expression");
   366        auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>();
   367        if (!rhsConstant)
   368          return {};
   369  
   370        // The only supported "multiplication" expression is an SDBM is dimension
   371        // negation, that is a product of dimension and constant -1.
   372        auto lhsVar = lhs.dyn_cast<SDBMPositiveExpr>();
   373        if (lhsVar && rhsConstant.getValue() == -1)
   374          return SDBMNegExpr::get(lhsVar);
   375  
   376        // Other multiplications are not allowed in SDBM.
   377        return {};
   378      }
   379  
   380      SDBMExpr visitModExpr(AffineBinaryOpExpr expr) {
   381        auto lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS());
   382        if (!lhs || !rhs)
   383          return {};
   384  
   385        // 'mod' can only be converted to SDBM if its LHS is a variable
   386        // and its RHS is a constant.  Then it `x mod c = x - x stripe c`.
   387        auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>();
   388        auto lhsVar = rhs.dyn_cast<SDBMPositiveExpr>();
   389        if (!lhsVar || !rhsConstant)
   390          return {};
   391        return SDBMDiffExpr::get(lhsVar,
   392                                 SDBMStripeExpr::get(lhsVar, rhsConstant));
   393      }
   394  
   395      // `a floordiv b = (a stripe b) / b`, but we have no division in SDBM
   396      SDBMExpr visitFloorDivExpr(AffineBinaryOpExpr expr) { return {}; }
   397      SDBMExpr visitCeilDivExpr(AffineBinaryOpExpr expr) { return {}; }
   398  
   399      // Dimensions, symbols and constants are converted trivially.
   400      SDBMExpr visitConstantExpr(AffineConstantExpr expr) {
   401        return SDBMConstantExpr::get(dialect, expr.getValue());
   402      }
   403      SDBMExpr visitDimExpr(AffineDimExpr expr) {
   404        return SDBMDimExpr::get(dialect, expr.getPosition());
   405      }
   406      SDBMExpr visitSymbolExpr(AffineSymbolExpr expr) {
   407        return SDBMSymbolExpr::get(dialect, expr.getPosition());
   408      }
   409  
   410      SDBMDialect *dialect;
   411    } converter;
   412    converter.dialect = affine.getContext()->getRegisteredDialect<SDBMDialect>();
   413  
   414    if (auto result = converter.visit(affine))
   415      return result;
   416    return None;
   417  }
   418  
   419  //===----------------------------------------------------------------------===//
   420  // SDBMDiffExpr
   421  //===----------------------------------------------------------------------===//
   422  
   423  SDBMDiffExpr SDBMDiffExpr::get(SDBMPositiveExpr lhs, SDBMPositiveExpr rhs) {
   424    assert(lhs && "expected SDBM dimension");
   425    assert(rhs && "expected SDBM dimension");
   426  
   427    StorageUniquer &uniquer = lhs.getDialect()->getUniquer();
   428    return uniquer.get<detail::SDBMDiffExprStorage>(
   429        /*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Diff), lhs, rhs);
   430  }
   431  
   432  SDBMPositiveExpr SDBMDiffExpr::getLHS() const {
   433    return static_cast<ImplType *>(impl)->lhs;
   434  }
   435  
   436  SDBMPositiveExpr SDBMDiffExpr::getRHS() const {
   437    return static_cast<ImplType *>(impl)->rhs;
   438  }
   439  
   440  //===----------------------------------------------------------------------===//
   441  // SDBMStripeExpr
   442  //===----------------------------------------------------------------------===//
   443  
   444  SDBMStripeExpr SDBMStripeExpr::get(SDBMPositiveExpr var,
   445                                     SDBMConstantExpr stripeFactor) {
   446    assert(var && "expected SDBM variable expression");
   447    assert(stripeFactor && "expected non-null stripe factor");
   448    if (stripeFactor.getValue() <= 0)
   449      llvm::report_fatal_error("non-positive stripe factor");
   450  
   451    StorageUniquer &uniquer = var.getDialect()->getUniquer();
   452    return uniquer.get<detail::SDBMBinaryExprStorage>(
   453        /*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Stripe), var,
   454        stripeFactor);
   455  }
   456  
   457  SDBMPositiveExpr SDBMStripeExpr::getVar() const {
   458    if (SDBMVaryingExpr lhs = static_cast<ImplType *>(impl)->lhs)
   459      return lhs.cast<SDBMPositiveExpr>();
   460    return {};
   461  }
   462  
   463  SDBMConstantExpr SDBMStripeExpr::getStripeFactor() const {
   464    return static_cast<ImplType *>(impl)->rhs;
   465  }
   466  
   467  //===----------------------------------------------------------------------===//
   468  // SDBMInputExpr
   469  //===----------------------------------------------------------------------===//
   470  
   471  unsigned SDBMInputExpr::getPosition() const {
   472    return static_cast<ImplType *>(impl)->position;
   473  }
   474  
   475  //===----------------------------------------------------------------------===//
   476  // SDBMDimExpr
   477  //===----------------------------------------------------------------------===//
   478  
   479  SDBMDimExpr SDBMDimExpr::get(SDBMDialect *dialect, unsigned position) {
   480    assert(dialect && "expected non-null dialect");
   481  
   482    auto assignDialect = [dialect](detail::SDBMPositiveExprStorage *storage) {
   483      storage->dialect = dialect;
   484    };
   485  
   486    StorageUniquer &uniquer = dialect->getUniquer();
   487    return uniquer.get<detail::SDBMPositiveExprStorage>(
   488        assignDialect, static_cast<unsigned>(SDBMExprKind::DimId), position);
   489  }
   490  
   491  //===----------------------------------------------------------------------===//
   492  // SDBMSymbolExpr
   493  //===----------------------------------------------------------------------===//
   494  
   495  SDBMSymbolExpr SDBMSymbolExpr::get(SDBMDialect *dialect, unsigned position) {
   496    assert(dialect && "expected non-null dialect");
   497  
   498    auto assignDialect = [dialect](detail::SDBMPositiveExprStorage *storage) {
   499      storage->dialect = dialect;
   500    };
   501  
   502    StorageUniquer &uniquer = dialect->getUniquer();
   503    return uniquer.get<detail::SDBMPositiveExprStorage>(
   504        assignDialect, static_cast<unsigned>(SDBMExprKind::SymbolId), position);
   505  }
   506  
   507  //===----------------------------------------------------------------------===//
   508  // SDBMConstantExpr
   509  //===----------------------------------------------------------------------===//
   510  
   511  SDBMConstantExpr SDBMConstantExpr::get(SDBMDialect *dialect, int64_t value) {
   512    assert(dialect && "expected non-null dialect");
   513  
   514    auto assignCtx = [dialect](detail::SDBMConstantExprStorage *storage) {
   515      storage->dialect = dialect;
   516    };
   517  
   518    StorageUniquer &uniquer = dialect->getUniquer();
   519    return uniquer.get<detail::SDBMConstantExprStorage>(
   520        assignCtx, static_cast<unsigned>(SDBMExprKind::Constant), value);
   521  }
   522  
   523  int64_t SDBMConstantExpr::getValue() const {
   524    return static_cast<ImplType *>(impl)->constant;
   525  }
   526  
   527  //===----------------------------------------------------------------------===//
   528  // SDBMNegExpr
   529  //===----------------------------------------------------------------------===//
   530  
   531  SDBMNegExpr SDBMNegExpr::get(SDBMPositiveExpr var) {
   532    assert(var && "expected non-null SDBM variable expression");
   533  
   534    StorageUniquer &uniquer = var.getDialect()->getUniquer();
   535    return uniquer.get<detail::SDBMNegExprStorage>(
   536        /*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Neg), var);
   537  }
   538  
   539  SDBMPositiveExpr SDBMNegExpr::getVar() const {
   540    return static_cast<ImplType *>(impl)->dim;
   541  }
   542  
   543  namespace mlir {
   544  namespace ops_assertions {
   545  
   546  SDBMExpr operator+(SDBMExpr lhs, SDBMExpr rhs) {
   547    // If one of the operands is a negation, take a difference rather than a sum.
   548    auto lhsNeg = lhs.dyn_cast<SDBMNegExpr>();
   549    auto rhsNeg = rhs.dyn_cast<SDBMNegExpr>();
   550    assert(!(lhsNeg && rhsNeg) && "a sum of negated expressions is a negation of "
   551                                  "a sum of variables and not a correct SDBM");
   552    if (lhsNeg)
   553      return rhs - lhsNeg.getVar();
   554    if (rhsNeg)
   555      return lhs - rhsNeg.getVar();
   556  
   557    // If LHS is a constant and RHS is not, swap the order to get into a supported
   558    // sum case.  From now on, RHS must be a constant.
   559    auto lhsConstant = lhs.dyn_cast<SDBMConstantExpr>();
   560    auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>();
   561    if (!rhsConstant && lhsConstant) {
   562      std::swap(lhs, rhs);
   563      std::swap(lhsConstant, rhsConstant);
   564    }
   565    assert(rhsConstant && "at least one operand must be a constant");
   566  
   567    // If LHS is another sum, first compute the sum of its variable
   568    // part with the other argument and then add the constant part to enable
   569    // constant folding (the variable part may, e.g., be a negation that requires
   570    // to enter this function again).
   571    auto lhsSum = lhs.dyn_cast<SDBMSumExpr>();
   572    if (lhsSum)
   573      return lhsSum.getLHS() +
   574             (lhsSum.getRHS().getValue() + rhsConstant.getValue());
   575  
   576    // Constant-fold if LHS is a constant.
   577    if (lhsConstant)
   578      return SDBMConstantExpr::get(lhs.getDialect(), lhsConstant.getValue() +
   579                                                         rhsConstant.getValue());
   580  
   581    // Fold x + 0 == x.
   582    if (rhsConstant.getValue() == 0)
   583      return lhs;
   584  
   585    return SDBMSumExpr::get(lhs.cast<SDBMVaryingExpr>(),
   586                            rhs.cast<SDBMConstantExpr>());
   587  }
   588  
   589  SDBMExpr operator-(SDBMExpr lhs, SDBMExpr rhs) {
   590    // Fold x - x == 0.
   591    if (lhs == rhs)
   592      return SDBMConstantExpr::get(lhs.getDialect(), 0);
   593  
   594    // LHS and RHS may be constants.
   595    auto lhsConstant = lhs.dyn_cast<SDBMConstantExpr>();
   596    auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>();
   597  
   598    // Constant fold if both LHS and RHS are constants.
   599    if (lhsConstant && rhsConstant)
   600      return SDBMConstantExpr::get(lhs.getDialect(), lhsConstant.getValue() -
   601                                                         rhsConstant.getValue());
   602  
   603    // Replace a difference with a sum with a negated value if one of LHS and RHS
   604    // is a constant:
   605    //   x - C == x + (-C);
   606    //   C - x == -x + C.
   607    // This calls into operator+ for further simplification.
   608    if (rhsConstant)
   609      return lhs + (-rhsConstant);
   610    if (lhsConstant)
   611      return -rhs + lhsConstant;
   612  
   613    // Hoist constant factors outside the difference if any of sides is a sum:
   614    //   (x + A) - (y - B) == x - y + (A - B).
   615    // If either LHS or RHS is a sum, collect the constant values separately and
   616    // update LHS and RHS to point to the variable part of the sum.
   617    auto lhsSum = lhs.dyn_cast<SDBMSumExpr>();
   618    auto rhsSum = rhs.dyn_cast<SDBMSumExpr>();
   619    int64_t value = 0;
   620    if (lhsSum) {
   621      value += lhsSum.getRHS().getValue();
   622      lhs = lhsSum.getLHS();
   623    }
   624    if (rhsSum) {
   625      value -= rhsSum.getRHS().getValue();
   626      rhs = rhsSum.getLHS();
   627    }
   628  
   629    // This calls into operator+ for futher simplification in case value == 0.
   630    return SDBMDiffExpr::get(lhs.cast<SDBMPositiveExpr>(),
   631                             rhs.cast<SDBMPositiveExpr>()) +
   632           value;
   633  }
   634  
   635  SDBMExpr stripe(SDBMExpr expr, SDBMExpr factor) {
   636    auto constantFactor = factor.cast<SDBMConstantExpr>();
   637    assert(constantFactor.getValue() > 0 && "non-positive stripe");
   638  
   639    // Fold x # 1 = x.
   640    if (constantFactor.getValue() == 1)
   641      return expr;
   642  
   643    return SDBMStripeExpr::get(expr.cast<SDBMPositiveExpr>(), constantFactor);
   644  }
   645  
   646  } // namespace ops_assertions
   647  } // namespace mlir