github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/IR/AffineMap.cpp (about)

     1  //===- AffineMap.cpp - MLIR Affine Map Classes ----------------------------===//
     2  //
     3  // Copyright 2019 The MLIR Authors.
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //   http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  // =============================================================================
    17  
    18  #include "mlir/IR/AffineMap.h"
    19  #include "AffineMapDetail.h"
    20  #include "mlir/IR/AffineExpr.h"
    21  #include "mlir/IR/Attributes.h"
    22  #include "mlir/IR/StandardTypes.h"
    23  #include "mlir/Support/Functional.h"
    24  #include "mlir/Support/LogicalResult.h"
    25  #include "mlir/Support/MathExtras.h"
    26  #include "llvm/ADT/StringRef.h"
    27  #include "llvm/Support/raw_ostream.h"
    28  
    29  using namespace mlir;
    30  
    31  namespace {
    32  
    33  // AffineExprConstantFolder evaluates an affine expression using constant
    34  // operands passed in 'operandConsts'. Returns an IntegerAttr attribute
    35  // representing the constant value of the affine expression evaluated on
    36  // constant 'operandConsts', or nullptr if it can't be folded.
    37  class AffineExprConstantFolder {
    38  public:
    39    AffineExprConstantFolder(unsigned numDims, ArrayRef<Attribute> operandConsts)
    40        : numDims(numDims), operandConsts(operandConsts) {}
    41  
    42    /// Attempt to constant fold the specified affine expr, or return null on
    43    /// failure.
    44    IntegerAttr constantFold(AffineExpr expr) {
    45      if (auto result = constantFoldImpl(expr))
    46        return IntegerAttr::get(IndexType::get(expr.getContext()), *result);
    47      return nullptr;
    48    }
    49  
    50  private:
    51    llvm::Optional<int64_t> constantFoldImpl(AffineExpr expr) {
    52      switch (expr.getKind()) {
    53      case AffineExprKind::Add:
    54        return constantFoldBinExpr(
    55            expr, [](int64_t lhs, int64_t rhs) { return lhs + rhs; });
    56      case AffineExprKind::Mul:
    57        return constantFoldBinExpr(
    58            expr, [](int64_t lhs, int64_t rhs) { return lhs * rhs; });
    59      case AffineExprKind::Mod:
    60        return constantFoldBinExpr(
    61            expr, [](int64_t lhs, int64_t rhs) { return mod(lhs, rhs); });
    62      case AffineExprKind::FloorDiv:
    63        return constantFoldBinExpr(
    64            expr, [](int64_t lhs, int64_t rhs) { return floorDiv(lhs, rhs); });
    65      case AffineExprKind::CeilDiv:
    66        return constantFoldBinExpr(
    67            expr, [](int64_t lhs, int64_t rhs) { return ceilDiv(lhs, rhs); });
    68      case AffineExprKind::Constant:
    69        return expr.cast<AffineConstantExpr>().getValue();
    70      case AffineExprKind::DimId:
    71        if (auto attr = operandConsts[expr.cast<AffineDimExpr>().getPosition()]
    72                            .dyn_cast_or_null<IntegerAttr>())
    73          return attr.getInt();
    74        return llvm::None;
    75      case AffineExprKind::SymbolId:
    76        if (auto attr = operandConsts[numDims +
    77                                      expr.cast<AffineSymbolExpr>().getPosition()]
    78                            .dyn_cast_or_null<IntegerAttr>())
    79          return attr.getInt();
    80        return llvm::None;
    81      }
    82      llvm_unreachable("Unknown AffineExpr");
    83    }
    84  
    85    // TODO: Change these to operate on APInts too.
    86    llvm::Optional<int64_t> constantFoldBinExpr(AffineExpr expr,
    87                                                int64_t (*op)(int64_t, int64_t)) {
    88      auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
    89      if (auto lhs = constantFoldImpl(binOpExpr.getLHS()))
    90        if (auto rhs = constantFoldImpl(binOpExpr.getRHS()))
    91          return op(*lhs, *rhs);
    92      return llvm::None;
    93    }
    94  
    95    // The number of dimension operands in AffineMap containing this expression.
    96    unsigned numDims;
    97    // The constant valued operands used to evaluate this AffineExpr.
    98    ArrayRef<Attribute> operandConsts;
    99  };
   100  
   101  } // end anonymous namespace
   102  
   103  /// Returns a single constant result affine map.
   104  AffineMap AffineMap::getConstantMap(int64_t val, MLIRContext *context) {
   105    return get(/*dimCount=*/0, /*symbolCount=*/0,
   106               {getAffineConstantExpr(val, context)});
   107  }
   108  
   109  AffineMap AffineMap::getMultiDimIdentityMap(unsigned numDims,
   110                                              MLIRContext *context) {
   111    SmallVector<AffineExpr, 4> dimExprs;
   112    dimExprs.reserve(numDims);
   113    for (unsigned i = 0; i < numDims; ++i)
   114      dimExprs.push_back(mlir::getAffineDimExpr(i, context));
   115    return get(/*dimCount=*/numDims, /*symbolCount=*/0, dimExprs);
   116  }
   117  
   118  MLIRContext *AffineMap::getContext() const { return map->context; }
   119  
   120  bool AffineMap::isIdentity() const {
   121    if (getNumDims() != getNumResults())
   122      return false;
   123    ArrayRef<AffineExpr> results = getResults();
   124    for (unsigned i = 0, numDims = getNumDims(); i < numDims; ++i) {
   125      auto expr = results[i].dyn_cast<AffineDimExpr>();
   126      if (!expr || expr.getPosition() != i)
   127        return false;
   128    }
   129    return true;
   130  }
   131  
   132  bool AffineMap::isEmpty() const {
   133    return getNumDims() == 0 && getNumSymbols() == 0 && getNumResults() == 0;
   134  }
   135  
   136  bool AffineMap::isSingleConstant() const {
   137    return getNumResults() == 1 && getResult(0).isa<AffineConstantExpr>();
   138  }
   139  
   140  int64_t AffineMap::getSingleConstantResult() const {
   141    assert(isSingleConstant() && "map must have a single constant result");
   142    return getResult(0).cast<AffineConstantExpr>().getValue();
   143  }
   144  
   145  unsigned AffineMap::getNumDims() const {
   146    assert(map && "uninitialized map storage");
   147    return map->numDims;
   148  }
   149  unsigned AffineMap::getNumSymbols() const {
   150    assert(map && "uninitialized map storage");
   151    return map->numSymbols;
   152  }
   153  unsigned AffineMap::getNumResults() const {
   154    assert(map && "uninitialized map storage");
   155    return map->results.size();
   156  }
   157  unsigned AffineMap::getNumInputs() const {
   158    assert(map && "uninitialized map storage");
   159    return map->numDims + map->numSymbols;
   160  }
   161  
   162  ArrayRef<AffineExpr> AffineMap::getResults() const {
   163    assert(map && "uninitialized map storage");
   164    return map->results;
   165  }
   166  AffineExpr AffineMap::getResult(unsigned idx) const {
   167    assert(map && "uninitialized map storage");
   168    return map->results[idx];
   169  }
   170  
   171  /// Folds the results of the application of an affine map on the provided
   172  /// operands to a constant if possible. Returns false if the folding happens,
   173  /// true otherwise.
   174  LogicalResult
   175  AffineMap::constantFold(ArrayRef<Attribute> operandConstants,
   176                          SmallVectorImpl<Attribute> &results) const {
   177    assert(getNumInputs() == operandConstants.size());
   178  
   179    // Fold each of the result expressions.
   180    AffineExprConstantFolder exprFolder(getNumDims(), operandConstants);
   181    // Constant fold each AffineExpr in AffineMap and add to 'results'.
   182    for (auto expr : getResults()) {
   183      auto folded = exprFolder.constantFold(expr);
   184      // If we didn't fold to a constant, then folding fails.
   185      if (!folded)
   186        return failure();
   187  
   188      results.push_back(folded);
   189    }
   190    assert(results.size() == getNumResults() &&
   191           "constant folding produced the wrong number of results");
   192    return success();
   193  }
   194  
   195  /// Walk all of the AffineExpr's in this mapping. Each node in an expression
   196  /// tree is visited in postorder.
   197  void AffineMap::walkExprs(std::function<void(AffineExpr)> callback) const {
   198    for (auto expr : getResults())
   199      expr.walk(callback);
   200  }
   201  
   202  /// This method substitutes any uses of dimensions and symbols (e.g.
   203  /// dim#0 with dimReplacements[0]) in subexpressions and returns the modified
   204  /// expression mapping.  Because this can be used to eliminate dims and
   205  /// symbols, the client needs to specify the number of dims and symbols in
   206  /// the result.  The returned map always has the same number of results.
   207  AffineMap AffineMap::replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements,
   208                                             ArrayRef<AffineExpr> symReplacements,
   209                                             unsigned numResultDims,
   210                                             unsigned numResultSyms) {
   211    SmallVector<AffineExpr, 8> results;
   212    results.reserve(getNumResults());
   213    for (auto expr : getResults())
   214      results.push_back(
   215          expr.replaceDimsAndSymbols(dimReplacements, symReplacements));
   216  
   217    return get(numResultDims, numResultSyms, results);
   218  }
   219  
   220  AffineMap AffineMap::compose(AffineMap map) {
   221    assert(getNumDims() == map.getNumResults() && "Number of results mismatch");
   222    // Prepare `map` by concatenating the symbols and rewriting its exprs.
   223    unsigned numDims = map.getNumDims();
   224    unsigned numSymbolsThisMap = getNumSymbols();
   225    unsigned numSymbols = numSymbolsThisMap + map.getNumSymbols();
   226    SmallVector<AffineExpr, 8> newDims(numDims);
   227    for (unsigned idx = 0; idx < numDims; ++idx) {
   228      newDims[idx] = getAffineDimExpr(idx, getContext());
   229    }
   230    SmallVector<AffineExpr, 8> newSymbols(numSymbols);
   231    for (unsigned idx = numSymbolsThisMap; idx < numSymbols; ++idx) {
   232      newSymbols[idx - numSymbolsThisMap] =
   233          getAffineSymbolExpr(idx, getContext());
   234    }
   235    auto newMap =
   236        map.replaceDimsAndSymbols(newDims, newSymbols, numDims, numSymbols);
   237    SmallVector<AffineExpr, 8> exprs;
   238    exprs.reserve(getResults().size());
   239    for (auto expr : getResults())
   240      exprs.push_back(expr.compose(newMap));
   241    return AffineMap::get(numDims, numSymbols, exprs);
   242  }
   243  
   244  bool AffineMap::isProjectedPermutation() {
   245    if (getNumSymbols() > 0)
   246      return false;
   247    SmallVector<bool, 8> seen(getNumInputs(), false);
   248    for (auto expr : getResults()) {
   249      if (auto dim = expr.dyn_cast<AffineDimExpr>()) {
   250        if (seen[dim.getPosition()])
   251          return false;
   252        seen[dim.getPosition()] = true;
   253        continue;
   254      }
   255      return false;
   256    }
   257    return true;
   258  }
   259  
   260  bool AffineMap::isPermutation() {
   261    if (getNumDims() != getNumResults())
   262      return false;
   263    return isProjectedPermutation();
   264  }
   265  
   266  AffineMap AffineMap::getSubMap(ArrayRef<unsigned> resultPos) {
   267    SmallVector<AffineExpr, 4> exprs;
   268    exprs.reserve(resultPos.size());
   269    for (auto idx : resultPos) {
   270      exprs.push_back(getResult(idx));
   271    }
   272    return AffineMap::get(getNumDims(), getNumSymbols(), exprs);
   273  }
   274  
   275  AffineMap mlir::simplifyAffineMap(AffineMap map) {
   276    SmallVector<AffineExpr, 8> exprs;
   277    for (auto e : map.getResults()) {
   278      exprs.push_back(
   279          simplifyAffineExpr(e, map.getNumDims(), map.getNumSymbols()));
   280    }
   281    return AffineMap::get(map.getNumDims(), map.getNumSymbols(), exprs);
   282  }
   283  
   284  AffineMap mlir::inversePermutation(AffineMap map) {
   285    if (!map)
   286      return map;
   287    assert(map.getNumSymbols() == 0 && "expected map without symbols");
   288    SmallVector<AffineExpr, 4> exprs(map.getNumDims());
   289    for (auto en : llvm::enumerate(map.getResults())) {
   290      auto expr = en.value();
   291      // Skip non-permutations.
   292      if (auto d = expr.dyn_cast<AffineDimExpr>()) {
   293        if (exprs[d.getPosition()])
   294          continue;
   295        exprs[d.getPosition()] = getAffineDimExpr(en.index(), d.getContext());
   296      }
   297    }
   298    SmallVector<AffineExpr, 4> seenExprs;
   299    seenExprs.reserve(map.getNumDims());
   300    for (auto expr : exprs)
   301      if (expr)
   302        seenExprs.push_back(expr);
   303    if (seenExprs.size() != map.getNumInputs())
   304      return AffineMap();
   305    return AffineMap::get(map.getNumResults(), 0, seenExprs);
   306  }
   307  
   308  AffineMap mlir::concatAffineMaps(ArrayRef<AffineMap> maps) {
   309    unsigned numResults = 0;
   310    for (auto m : maps)
   311      numResults += m ? m.getNumResults() : 0;
   312    unsigned numDims = 0;
   313    llvm::SmallVector<AffineExpr, 8> results;
   314    results.reserve(numResults);
   315    for (auto m : maps) {
   316      if (!m)
   317        continue;
   318      assert(m.getNumSymbols() == 0 && "expected map without symbols");
   319      results.append(m.getResults().begin(), m.getResults().end());
   320      numDims = std::max(m.getNumDims(), numDims);
   321    }
   322    return numDims == 0 ? AffineMap() : AffineMap::get(numDims, 0, results);
   323  }