github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Transforms/CSE.cpp (about)

     1  //===- CSE.cpp - Common Sub-expression Elimination ------------------------===//
     2  //
     3  // Copyright 2019 The MLIR Authors.
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //   http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  // =============================================================================
    17  //
    18  // This transformation pass performs a simple common sub-expression elimination
    19  // algorithm on operations within a function.
    20  //
    21  //===----------------------------------------------------------------------===//
    22  
    23  #include "mlir/Analysis/Dominance.h"
    24  #include "mlir/IR/Attributes.h"
    25  #include "mlir/IR/Builders.h"
    26  #include "mlir/IR/Function.h"
    27  #include "mlir/Pass/Pass.h"
    28  #include "mlir/Support/Functional.h"
    29  #include "mlir/Transforms/Passes.h"
    30  #include "mlir/Transforms/Utils.h"
    31  #include "llvm/ADT/DenseMapInfo.h"
    32  #include "llvm/ADT/Hashing.h"
    33  #include "llvm/ADT/ScopedHashTable.h"
    34  #include "llvm/Support/Allocator.h"
    35  #include "llvm/Support/RecyclingAllocator.h"
    36  #include <deque>
    37  using namespace mlir;
    38  
    39  namespace {
    40  // TODO(riverriddle) Handle commutative operations.
    41  struct SimpleOperationInfo : public llvm::DenseMapInfo<Operation *> {
    42    static unsigned getHashValue(const Operation *opC) {
    43      auto *op = const_cast<Operation *>(opC);
    44      // Hash the operations based upon their:
    45      //   - Operation Name
    46      //   - Attributes
    47      //   - Result Types
    48      //   - Operands
    49      return hash_combine(
    50          op->getName(), op->getAttrs(),
    51          hash_combine_range(op->result_type_begin(), op->result_type_end()),
    52          hash_combine_range(op->operand_begin(), op->operand_end()));
    53    }
    54    static bool isEqual(const Operation *lhsC, const Operation *rhsC) {
    55      auto *lhs = const_cast<Operation *>(lhsC);
    56      auto *rhs = const_cast<Operation *>(rhsC);
    57      if (lhs == rhs)
    58        return true;
    59      if (lhs == getTombstoneKey() || lhs == getEmptyKey() ||
    60          rhs == getTombstoneKey() || rhs == getEmptyKey())
    61        return false;
    62  
    63      // Compare the operation name.
    64      if (lhs->getName() != rhs->getName())
    65        return false;
    66      // Check operand and result type counts.
    67      if (lhs->getNumOperands() != rhs->getNumOperands() ||
    68          lhs->getNumResults() != rhs->getNumResults())
    69        return false;
    70      // Compare attributes.
    71      if (lhs->getAttrs() != rhs->getAttrs())
    72        return false;
    73      // Compare operands.
    74      if (!std::equal(lhs->operand_begin(), lhs->operand_end(),
    75                      rhs->operand_begin()))
    76        return false;
    77      // Compare result types.
    78      return std::equal(lhs->result_type_begin(), lhs->result_type_end(),
    79                        rhs->result_type_begin());
    80    }
    81  };
    82  } // end anonymous namespace
    83  
    84  namespace {
    85  /// Simple common sub-expression elimination.
    86  struct CSE : public FunctionPass<CSE> {
    87    CSE() = default;
    88    CSE(const CSE &) {}
    89  
    90    /// Shared implementation of operation elimination and scoped map definitions.
    91    using AllocatorTy = llvm::RecyclingAllocator<
    92        llvm::BumpPtrAllocator,
    93        llvm::ScopedHashTableVal<Operation *, Operation *>>;
    94    using ScopedMapTy = llvm::ScopedHashTable<Operation *, Operation *,
    95                                              SimpleOperationInfo, AllocatorTy>;
    96  
    97    /// Represents a single entry in the depth first traversal of a CFG.
    98    struct CFGStackNode {
    99      CFGStackNode(ScopedMapTy &knownValues, DominanceInfoNode *node)
   100          : scope(knownValues), node(node), childIterator(node->begin()),
   101            processed(false) {}
   102  
   103      /// Scope for the known values.
   104      ScopedMapTy::ScopeTy scope;
   105  
   106      DominanceInfoNode *node;
   107      DominanceInfoNode::iterator childIterator;
   108  
   109      /// If this node has been fully processed yet or not.
   110      bool processed;
   111    };
   112  
   113    /// Attempt to eliminate a redundant operation. Returns success if the
   114    /// operation was marked for removal, failure otherwise.
   115    LogicalResult simplifyOperation(ScopedMapTy &knownValues, Operation *op);
   116  
   117    void simplifyBlock(ScopedMapTy &knownValues, DominanceInfo &domInfo,
   118                       Block *bb);
   119    void simplifyRegion(ScopedMapTy &knownValues, DominanceInfo &domInfo,
   120                        Region &region);
   121  
   122    void runOnFunction() override;
   123  
   124  private:
   125    /// Operations marked as dead and to be erased.
   126    std::vector<Operation *> opsToErase;
   127  };
   128  } // end anonymous namespace
   129  
   130  /// Attempt to eliminate a redundant operation.
   131  LogicalResult CSE::simplifyOperation(ScopedMapTy &knownValues, Operation *op) {
   132    // Don't simplify operations with nested blocks. We don't currently model
   133    // equality comparisons correctly among other things. It is also unclear
   134    // whether we would want to CSE such operations.
   135    if (op->getNumRegions() != 0)
   136      return failure();
   137  
   138    // TODO(riverriddle) We currently only eliminate non side-effecting
   139    // operations.
   140    if (!op->hasNoSideEffect())
   141      return failure();
   142  
   143    // If the operation is already trivially dead just add it to the erase list.
   144    if (op->use_empty()) {
   145      opsToErase.push_back(op);
   146      return success();
   147    }
   148  
   149    // Look for an existing definition for the operation.
   150    if (auto *existing = knownValues.lookup(op)) {
   151      // If we find one then replace all uses of the current operation with the
   152      // existing one and mark it for deletion.
   153      op->replaceAllUsesWith(existing);
   154      opsToErase.push_back(op);
   155  
   156      // If the existing operation has an unknown location and the current
   157      // operation doesn't, then set the existing op's location to that of the
   158      // current op.
   159      if (existing->getLoc().isa<UnknownLoc>() &&
   160          !op->getLoc().isa<UnknownLoc>()) {
   161        existing->setLoc(op->getLoc());
   162      }
   163      return success();
   164    }
   165  
   166    // Otherwise, we add this operation to the known values map.
   167    knownValues.insert(op, op);
   168    return failure();
   169  }
   170  
   171  void CSE::simplifyBlock(ScopedMapTy &knownValues, DominanceInfo &domInfo,
   172                          Block *bb) {
   173    for (auto &inst : *bb) {
   174      // If the operation is simplified, we don't process any held regions.
   175      if (succeeded(simplifyOperation(knownValues, &inst)))
   176        continue;
   177  
   178      // If this operation is isolated above, we can't process nested regions with
   179      // the given 'knownValues' map. This would cause the insertion of implicit
   180      // captures in explicit capture only regions.
   181      if (!inst.isRegistered() || inst.isKnownIsolatedFromAbove()) {
   182        ScopedMapTy nestedKnownValues;
   183        for (auto &region : inst.getRegions())
   184          simplifyRegion(nestedKnownValues, domInfo, region);
   185        continue;
   186      }
   187  
   188      // Otherwise, process nested regions normally.
   189      for (auto &region : inst.getRegions())
   190        simplifyRegion(knownValues, domInfo, region);
   191    }
   192  }
   193  
   194  void CSE::simplifyRegion(ScopedMapTy &knownValues, DominanceInfo &domInfo,
   195                           Region &region) {
   196    // If the region is empty there is nothing to do.
   197    if (region.empty())
   198      return;
   199  
   200    // If the region only contains one block, then simplify it directly.
   201    if (std::next(region.begin()) == region.end()) {
   202      ScopedMapTy::ScopeTy scope(knownValues);
   203      simplifyBlock(knownValues, domInfo, &region.front());
   204      return;
   205    }
   206  
   207    // Note, deque is being used here because there was significant performance
   208    // gains over vector when the container becomes very large due to the
   209    // specific access patterns. If/when these performance issues are no
   210    // longer a problem we can change this to vector. For more information see
   211    // the llvm mailing list discussion on this:
   212    // http://lists.llvm.org/pipermail/llvm-commits/Week-of-Mon-20120116/135228.html
   213    std::deque<std::unique_ptr<CFGStackNode>> stack;
   214  
   215    // Process the nodes of the dom tree for this region.
   216    stack.emplace_back(std::make_unique<CFGStackNode>(
   217        knownValues, domInfo.getRootNode(&region)));
   218  
   219    while (!stack.empty()) {
   220      auto &currentNode = stack.back();
   221  
   222      // Check to see if we need to process this node.
   223      if (!currentNode->processed) {
   224        currentNode->processed = true;
   225        simplifyBlock(knownValues, domInfo, currentNode->node->getBlock());
   226      }
   227  
   228      // Otherwise, check to see if we need to process a child node.
   229      if (currentNode->childIterator != currentNode->node->end()) {
   230        auto *childNode = *(currentNode->childIterator++);
   231        stack.emplace_back(
   232            std::make_unique<CFGStackNode>(knownValues, childNode));
   233      } else {
   234        // Finally, if the node and all of its children have been processed
   235        // then we delete the node.
   236        stack.pop_back();
   237      }
   238    }
   239  }
   240  
   241  void CSE::runOnFunction() {
   242    /// A scoped hash table of defining operations within a function.
   243    ScopedMapTy knownValues;
   244    simplifyRegion(knownValues, getAnalysis<DominanceInfo>(),
   245                   getFunction().getBody());
   246  
   247    // If no operations were erased, then we mark all analyses as preserved.
   248    if (opsToErase.empty())
   249      return markAllAnalysesPreserved();
   250  
   251    /// Erase any operations that were marked as dead during simplification.
   252    for (auto *op : opsToErase)
   253      op->erase();
   254    opsToErase.clear();
   255  
   256    // We currently don't remove region operations, so mark dominance as
   257    // preserved.
   258    markAnalysesPreserved<DominanceInfo, PostDominanceInfo>();
   259  }
   260  
   261  std::unique_ptr<FunctionPassBase> mlir::createCSEPass() {
   262    return std::make_unique<CSE>();
   263  }
   264  
   /// Static registration object for the CSE pass, using the argument string
   /// "cse" and the description given below.
   265  static PassRegistration<CSE>
   266      pass("cse", "Eliminate common sub-expressions in functions");