github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Transforms/LoopUnroll.cpp (about)

     1  //===- LoopUnroll.cpp - Code to perform loop unrolling --------------------===//
     2  //
     3  // Copyright 2019 The MLIR Authors.
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //   http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  // =============================================================================
    17  //
    18  // This file implements loop unrolling.
    19  //
    20  //===----------------------------------------------------------------------===//
    21  
    22  #include "mlir/Transforms/Passes.h"
    23  
    24  #include "mlir/Analysis/LoopAnalysis.h"
    25  #include "mlir/Dialect/AffineOps/AffineOps.h"
    26  #include "mlir/IR/AffineExpr.h"
    27  #include "mlir/IR/AffineMap.h"
    28  #include "mlir/IR/Builders.h"
    29  #include "mlir/Pass/Pass.h"
    30  #include "mlir/Transforms/LoopUtils.h"
    31  #include "llvm/ADT/DenseMap.h"
    32  #include "llvm/Support/CommandLine.h"
    33  #include "llvm/Support/Debug.h"
    34  
    35  using namespace mlir;
    36  
    37  #define DEBUG_TYPE "affine-loop-unroll"
    38  
    39  static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options");
    40  
    41  // Loop unrolling factor.
    42  static llvm::cl::opt<unsigned> clUnrollFactor(
    43      "unroll-factor",
    44      llvm::cl::desc("Use this unroll factor for all loops being unrolled"),
    45      llvm::cl::cat(clOptionsCategory));
    46  
    47  static llvm::cl::opt<bool> clUnrollFull("unroll-full",
    48                                          llvm::cl::desc("Fully unroll loops"),
    49                                          llvm::cl::cat(clOptionsCategory));
    50  
    51  static llvm::cl::opt<unsigned> clUnrollNumRepetitions(
    52      "unroll-num-reps",
    53      llvm::cl::desc("Unroll innermost loops repeatedly this many times"),
    54      llvm::cl::cat(clOptionsCategory));
    55  
    56  static llvm::cl::opt<unsigned> clUnrollFullThreshold(
    57      "unroll-full-threshold", llvm::cl::Hidden,
    58      llvm::cl::desc(
    59          "Unroll all loops with trip count less than or equal to this"),
    60      llvm::cl::cat(clOptionsCategory));
    61  
    62  namespace {
    63  /// Loop unrolling pass. Unrolls all innermost loops unless full unrolling and a
    64  /// full unroll threshold was specified, in which case, fully unrolls all loops
    65  /// with trip count less than the specified threshold. The latter is for testing
    66  /// purposes, especially for testing outer loop unrolling.
    67  struct LoopUnroll : public FunctionPass<LoopUnroll> {
    68    const Optional<unsigned> unrollFactor;
    69    const Optional<bool> unrollFull;
    70    // Callback to obtain unroll factors; if this has a callable target, takes
    71    // precedence over command-line argument or passed argument.
    72    const std::function<unsigned(AffineForOp)> getUnrollFactor;
    73  
    74    explicit LoopUnroll(
    75        Optional<unsigned> unrollFactor = None, Optional<bool> unrollFull = None,
    76        const std::function<unsigned(AffineForOp)> &getUnrollFactor = nullptr)
    77        : unrollFactor(unrollFactor), unrollFull(unrollFull),
    78          getUnrollFactor(getUnrollFactor) {}
    79  
    80    void runOnFunction() override;
    81  
    82    /// Unroll this for op. Returns failure if nothing was done.
    83    LogicalResult runOnAffineForOp(AffineForOp forOp);
    84  
    85    static const unsigned kDefaultUnrollFactor = 4;
    86  };
    87  } // end anonymous namespace
    88  
    89  void LoopUnroll::runOnFunction() {
    90    // Gathers all innermost loops through a post order pruned walk.
    91    struct InnermostLoopGatherer {
    92      // Store innermost loops as we walk.
    93      std::vector<AffineForOp> loops;
    94  
    95      void walkPostOrder(FuncOp f) {
    96        for (auto &b : f)
    97          walkPostOrder(b.begin(), b.end());
    98      }
    99  
   100      bool walkPostOrder(Block::iterator Start, Block::iterator End) {
   101        bool hasInnerLoops = false;
   102        // We need to walk all elements since all innermost loops need to be
   103        // gathered as opposed to determining whether this list has any inner
   104        // loops or not.
   105        while (Start != End)
   106          hasInnerLoops |= walkPostOrder(&(*Start++));
   107        return hasInnerLoops;
   108      }
   109      bool walkPostOrder(Operation *opInst) {
   110        bool hasInnerLoops = false;
   111        for (auto &region : opInst->getRegions())
   112          for (auto &block : region)
   113            hasInnerLoops |= walkPostOrder(block.begin(), block.end());
   114        if (isa<AffineForOp>(opInst)) {
   115          if (!hasInnerLoops)
   116            loops.push_back(cast<AffineForOp>(opInst));
   117          return true;
   118        }
   119        return hasInnerLoops;
   120      }
   121    };
   122  
   123    if (clUnrollFull.getNumOccurrences() > 0 &&
   124        clUnrollFullThreshold.getNumOccurrences() > 0) {
   125      // Store short loops as we walk.
   126      std::vector<AffineForOp> loops;
   127  
   128      // Gathers all loops with trip count <= minTripCount. Do a post order walk
   129      // so that loops are gathered from innermost to outermost (or else unrolling
   130      // an outer one may delete gathered inner ones).
   131      getFunction().walk([&](AffineForOp forOp) {
   132        Optional<uint64_t> tripCount = getConstantTripCount(forOp);
   133        if (tripCount.hasValue() && tripCount.getValue() <= clUnrollFullThreshold)
   134          loops.push_back(forOp);
   135      });
   136      for (auto forOp : loops)
   137        loopUnrollFull(forOp);
   138      return;
   139    }
   140  
   141    unsigned numRepetitions = clUnrollNumRepetitions.getNumOccurrences() > 0
   142                                  ? clUnrollNumRepetitions
   143                                  : 1;
   144    // If the call back is provided, we will recurse until no loops are found.
   145    FuncOp func = getFunction();
   146    for (unsigned i = 0; i < numRepetitions || getUnrollFactor; i++) {
   147      InnermostLoopGatherer ilg;
   148      ilg.walkPostOrder(func);
   149      auto &loops = ilg.loops;
   150      if (loops.empty())
   151        break;
   152      bool unrolled = false;
   153      for (auto forOp : loops)
   154        unrolled |= succeeded(runOnAffineForOp(forOp));
   155      if (!unrolled)
   156        // Break out if nothing was unrolled.
   157        break;
   158    }
   159  }
   160  
   161  /// Unrolls a 'affine.for' op. Returns success if the loop was unrolled,
   162  /// failure otherwise. The default unroll factor is 4.
   163  LogicalResult LoopUnroll::runOnAffineForOp(AffineForOp forOp) {
   164    // Use the function callback if one was provided.
   165    if (getUnrollFactor) {
   166      return loopUnrollByFactor(forOp, getUnrollFactor(forOp));
   167    }
   168    // Unroll by the factor passed, if any.
   169    if (unrollFactor.hasValue())
   170      return loopUnrollByFactor(forOp, unrollFactor.getValue());
   171    // Unroll by the command line factor if one was specified.
   172    if (clUnrollFactor.getNumOccurrences() > 0)
   173      return loopUnrollByFactor(forOp, clUnrollFactor);
   174    // Unroll completely if full loop unroll was specified.
   175    if (clUnrollFull.getNumOccurrences() > 0 ||
   176        (unrollFull.hasValue() && unrollFull.getValue()))
   177      return loopUnrollFull(forOp);
   178  
   179    // Unroll by four otherwise.
   180    return loopUnrollByFactor(forOp, kDefaultUnrollFactor);
   181  }
   182  
   183  std::unique_ptr<FunctionPassBase> mlir::createLoopUnrollPass(
   184      int unrollFactor, int unrollFull,
   185      const std::function<unsigned(AffineForOp)> &getUnrollFactor) {
   186    return std::make_unique<LoopUnroll>(
   187        unrollFactor == -1 ? None : Optional<unsigned>(unrollFactor),
   188        unrollFull == -1 ? None : Optional<bool>(unrollFull), getUnrollFactor);
   189  }
   190  
   191  static PassRegistration<LoopUnroll> pass("affine-loop-unroll", "Unroll loops");