github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Transforms/Utils/FoldUtils.cpp (about)

     1  //===- FoldUtils.cpp ---- Fold Utilities ----------------------------------===//
     2  //
     3  // Copyright 2019 The MLIR Authors.
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //   http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  // =============================================================================
    17  //
    18  // This file defines various operation fold utilities. These utilities are
    19  // intended to be used by passes to unify and simplify their logic.
    20  //
    21  //===----------------------------------------------------------------------===//
    22  
    23  #include "mlir/Transforms/FoldUtils.h"
    24  
    25  #include "mlir/Dialect/StandardOps/Ops.h"
    26  #include "mlir/IR/Builders.h"
    27  #include "mlir/IR/Matchers.h"
    28  #include "mlir/IR/Operation.h"
    29  
    30  using namespace mlir;
    31  
    32  /// Given an operation, find the parent region that folded constants should be
    33  /// inserted into.
    34  static Region *getInsertionRegion(Operation *op) {
    35    while (Region *region = op->getParentRegion()) {
    36      // Insert in this region for any of the following scenarios:
    37      //  * The parent is unregistered, or is known to be isolated from above.
    38      //  * The parent is a top-level operation.
    39      auto *parentOp = region->getParentOp();
    40      if (!parentOp->isRegistered() || parentOp->isKnownIsolatedFromAbove() ||
    41          !parentOp->getBlock())
    42        return region;
    43      // Traverse up the parent looking for an insertion region.
    44      op = parentOp;
    45    }
    46    llvm_unreachable("expected valid insertion region");
    47  }
    48  
    49  /// A utility function used to materialize a constant for a given attribute and
    50  /// type. On success, a valid constant value is returned. Otherwise, null is
    51  /// returned
    52  static Operation *materializeConstant(Dialect *dialect, OpBuilder &builder,
    53                                        Attribute value, Type type,
    54                                        Location loc) {
    55    auto insertPt = builder.getInsertionPoint();
    56    (void)insertPt;
    57  
    58    // Ask the dialect to materialize a constant operation for this value.
    59    if (auto *constOp = dialect->materializeConstant(builder, value, type, loc)) {
    60      assert(insertPt == builder.getInsertionPoint());
    61      assert(matchPattern(constOp, m_Constant(&value)));
    62      return constOp;
    63    }
    64  
    65    // If the dialect is unable to materialize a constant, check to see if the
    66    // standard constant can be used.
    67    if (ConstantOp::isBuildableWith(value, type))
    68      return builder.create<ConstantOp>(loc, type, value);
    69    return nullptr;
    70  }
    71  
    72  //===----------------------------------------------------------------------===//
    73  // OperationFolder
    74  //===----------------------------------------------------------------------===//
    75  
    76  LogicalResult OperationFolder::tryToFold(
    77      Operation *op,
    78      llvm::function_ref<void(Operation *)> processGeneratedConstants,
    79      llvm::function_ref<void(Operation *)> preReplaceAction) {
    80    // If this is a unique'd constant, return failure as we know that it has
    81    // already been folded.
    82    if (referencedDialects.count(op))
    83      return failure();
    84  
    85    // Try to fold the operation.
    86    SmallVector<Value *, 8> results;
    87    if (failed(tryToFold(op, results, processGeneratedConstants)))
    88      return failure();
    89  
    90    // Constant folding succeeded. We will start replacing this op's uses and
    91    // eventually erase this op. Invoke the callback provided by the caller to
    92    // perform any pre-replacement action.
    93    if (preReplaceAction)
    94      preReplaceAction(op);
    95  
    96    // Check to see if the operation was just updated in place.
    97    if (results.empty())
    98      return success();
    99  
   100    // Otherwise, replace all of the result values and erase the operation.
   101    for (unsigned i = 0, e = results.size(); i != e; ++i)
   102      op->getResult(i)->replaceAllUsesWith(results[i]);
   103    op->erase();
   104    return success();
   105  }
   106  
   107  /// Notifies that the given constant `op` should be remove from this
   108  /// OperationFolder's internal bookkeeping.
   109  void OperationFolder::notifyRemoval(Operation *op) {
   110    // Check to see if this operation is uniqued within the folder.
   111    auto it = referencedDialects.find(op);
   112    if (it == referencedDialects.end())
   113      return;
   114  
   115    // Get the constant value for this operation, this is the value that was used
   116    // to unique the operation internally.
   117    Attribute constValue;
   118    matchPattern(op, m_Constant(&constValue));
   119    assert(constValue);
   120  
   121    // Get the constant map that this operation was uniqued in.
   122    auto &uniquedConstants = foldScopes[getInsertionRegion(op)];
   123  
   124    // Erase all of the references to this operation.
   125    auto type = op->getResult(0)->getType();
   126    for (auto *dialect : it->second)
   127      uniquedConstants.erase(std::make_tuple(dialect, constValue, type));
   128    referencedDialects.erase(it);
   129  }
   130  
   131  /// Tries to perform folding on the given `op`. If successful, populates
   132  /// `results` with the results of the folding.
   133  LogicalResult OperationFolder::tryToFold(
   134      Operation *op, SmallVectorImpl<Value *> &results,
   135      llvm::function_ref<void(Operation *)> processGeneratedConstants) {
   136    SmallVector<Attribute, 8> operandConstants;
   137    SmallVector<OpFoldResult, 8> foldResults;
   138  
   139    // Check to see if any operands to the operation is constant and whether
   140    // the operation knows how to constant fold itself.
   141    operandConstants.assign(op->getNumOperands(), Attribute());
   142    for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i)
   143      matchPattern(op->getOperand(i), m_Constant(&operandConstants[i]));
   144  
   145    // If this is a commutative binary operation with a constant on the left
   146    // side move it to the right side.
   147    if (operandConstants.size() == 2 && operandConstants[0] &&
   148        !operandConstants[1] && op->isCommutative()) {
   149      std::swap(op->getOpOperand(0), op->getOpOperand(1));
   150      std::swap(operandConstants[0], operandConstants[1]);
   151    }
   152  
   153    // Attempt to constant fold the operation.
   154    if (failed(op->fold(operandConstants, foldResults)))
   155      return failure();
   156  
   157    // Check to see if the operation was just updated in place.
   158    if (foldResults.empty())
   159      return success();
   160    assert(foldResults.size() == op->getNumResults());
   161  
   162    // Create a builder to insert new operations into the entry block of the
   163    // insertion region.
   164    auto *insertionRegion = getInsertionRegion(op);
   165    auto &entry = insertionRegion->front();
   166    OpBuilder builder(&entry, entry.begin());
   167  
   168    // Get the constant map for the insertion region of this operation.
   169    auto &uniquedConstants = foldScopes[insertionRegion];
   170  
   171    // Create the result constants and replace the results.
   172    auto *dialect = op->getDialect();
   173    for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
   174      assert(!foldResults[i].isNull() && "expected valid OpFoldResult");
   175  
   176      // Check if the result was an SSA value.
   177      if (auto *repl = foldResults[i].dyn_cast<Value *>()) {
   178        results.emplace_back(repl);
   179        continue;
   180      }
   181  
   182      // Check to see if there is a canonicalized version of this constant.
   183      auto *res = op->getResult(i);
   184      Attribute attrRepl = foldResults[i].get<Attribute>();
   185      if (auto *constOp =
   186              tryGetOrCreateConstant(uniquedConstants, dialect, builder, attrRepl,
   187                                     res->getType(), op->getLoc())) {
   188        results.push_back(constOp->getResult(0));
   189        continue;
   190      }
   191      // If materialization fails, cleanup any operations generated for the
   192      // previous results and return failure.
   193      for (Operation &op : llvm::make_early_inc_range(
   194               llvm::make_range(entry.begin(), builder.getInsertionPoint()))) {
   195        notifyRemoval(&op);
   196        op.erase();
   197      }
   198      return failure();
   199    }
   200  
   201    // Process any newly generated operations.
   202    if (processGeneratedConstants) {
   203      for (auto i = entry.begin(), e = builder.getInsertionPoint(); i != e; ++i)
   204        processGeneratedConstants(&*i);
   205    }
   206  
   207    return success();
   208  }
   209  
   210  /// Try to get or create a new constant entry. On success this returns the
   211  /// constant operation value, nullptr otherwise.
   212  Operation *OperationFolder::tryGetOrCreateConstant(
   213      ConstantMap &uniquedConstants, Dialect *dialect, OpBuilder &builder,
   214      Attribute value, Type type, Location loc) {
   215    // Check if an existing mapping already exists.
   216    auto constKey = std::make_tuple(dialect, value, type);
   217    auto *&constInst = uniquedConstants[constKey];
   218    if (constInst)
   219      return constInst;
   220  
   221    // If one doesn't exist, try to materialize one.
   222    if (!(constInst = materializeConstant(dialect, builder, value, type, loc)))
   223      return nullptr;
   224  
   225    // Check to see if the generated constant is in the expected dialect.
   226    auto *newDialect = constInst->getDialect();
   227    if (newDialect == dialect) {
   228      referencedDialects[constInst].push_back(dialect);
   229      return constInst;
   230    }
   231  
   232    // If it isn't, then we also need to make sure that the mapping for the new
   233    // dialect is valid.
   234    auto newKey = std::make_tuple(newDialect, value, type);
   235  
   236    // If an existing operation in the new dialect already exists, delete the
   237    // materialized operation in favor of the existing one.
   238    if (auto *existingOp = uniquedConstants.lookup(newKey)) {
   239      constInst->erase();
   240      referencedDialects[existingOp].push_back(dialect);
   241      return constInst = existingOp;
   242    }
   243  
   244    // Otherwise, update the new dialect to the materialized operation.
   245    referencedDialects[constInst].assign({dialect, newDialect});
   246    auto newIt = uniquedConstants.insert({newKey, constInst});
   247    return newIt.first->second;
   248  }