github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Transforms/LoopUnrollAndJam.cpp (about)

     1  //===- LoopUnrollAndJam.cpp - Code to perform loop unroll and jam ---------===//
     2  //
     3  // Copyright 2019 The MLIR Authors.
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //   http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  // =============================================================================
    17  //
    18  // This file implements loop unroll and jam. Unroll and jam is a transformation
    19  // that improves locality, in particular, register reuse, while also improving
    20  // operation level parallelism. The example below shows what it does in nearly
    21  // the general case. Loop unroll and jam currently works if the bounds of the
    22  // loops inner to the loop being unroll-jammed do not depend on the latter.
    23  //
    24  // Before      After unroll and jam of i by factor 2:
    25  //
    26  //             for i, step = 2
    27  // for i         S1(i);
    28  //   S1;         S2(i);
    29  //   S2;         S1(i+1);
    30  //   for j       S2(i+1);
    31  //     S3;       for j
    32  //     S4;         S3(i, j);
    33  //   S5;           S4(i, j);
    34  //   S6;           S3(i+1, j);
    35  //                 S4(i+1, j);
    36  //               S5(i);
    37  //               S6(i);
    38  //               S5(i+1);
    39  //               S6(i+1);
    40  //
    41  // Note: 'if/else' ops are not jammed; they are copied as a whole, so the
    42  // bodies of any loops nested inside such ops are not jammed either.
    43  //===----------------------------------------------------------------------===//
    44  #include "mlir/Transforms/Passes.h"
    45  
    46  #include "mlir/Analysis/LoopAnalysis.h"
    47  #include "mlir/Dialect/AffineOps/AffineOps.h"
    48  #include "mlir/IR/AffineExpr.h"
    49  #include "mlir/IR/AffineMap.h"
    50  #include "mlir/IR/BlockAndValueMapping.h"
    51  #include "mlir/IR/Builders.h"
    52  #include "mlir/Pass/Pass.h"
    53  #include "mlir/Transforms/LoopUtils.h"
    54  #include "llvm/ADT/DenseMap.h"
    55  #include "llvm/Support/CommandLine.h"
    56  
    57  using namespace mlir;
    58  
    59  #define DEBUG_TYPE "affine-loop-unroll-jam"
    60  
    61  static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options");
    62  
    63  // Loop unroll and jam factor.
    64  static llvm::cl::opt<unsigned>
    65      clUnrollJamFactor("unroll-jam-factor", llvm::cl::Hidden,
    66                        llvm::cl::desc("Use this unroll jam factor for all loops"
    67                                       " (default 4)"),
    68                        llvm::cl::cat(clOptionsCategory));
    69  
    70  namespace {
    71  /// Loop unroll jam pass. Currently, this just unroll jams the first
    72  /// outer loop in a Function.
    73  struct LoopUnrollAndJam : public FunctionPass<LoopUnrollAndJam> {
    74    Optional<unsigned> unrollJamFactor;
    75    static const unsigned kDefaultUnrollJamFactor = 4;
    76  
    77    explicit LoopUnrollAndJam(Optional<unsigned> unrollJamFactor = None)
    78        : unrollJamFactor(unrollJamFactor) {}
    79  
    80    void runOnFunction() override;
    81    LogicalResult runOnAffineForOp(AffineForOp forOp);
    82  };
    83  } // end anonymous namespace
    84  
    85  std::unique_ptr<FunctionPassBase>
    86  mlir::createLoopUnrollAndJamPass(int unrollJamFactor) {
    87    return std::make_unique<LoopUnrollAndJam>(
    88        unrollJamFactor == -1 ? None : Optional<unsigned>(unrollJamFactor));
    89  }
    90  
    91  void LoopUnrollAndJam::runOnFunction() {
    92    // Currently, just the outermost loop from the first loop nest is
    93    // unroll-and-jammed by this pass. However, runOnAffineForOp can be called on
    94    // any for operation.
    95    auto &entryBlock = getFunction().front();
    96    if (auto forOp = dyn_cast<AffineForOp>(entryBlock.front()))
    97      runOnAffineForOp(forOp);
    98  }
    99  
   100  /// Unroll and jam a 'affine.for' op. Default unroll jam factor is
   101  /// kDefaultUnrollJamFactor. Return failure if nothing was done.
   102  LogicalResult LoopUnrollAndJam::runOnAffineForOp(AffineForOp forOp) {
   103    // Unroll and jam by the factor that was passed if any.
   104    if (unrollJamFactor.hasValue())
   105      return loopUnrollJamByFactor(forOp, unrollJamFactor.getValue());
   106    // Otherwise, unroll jam by the command-line factor if one was specified.
   107    if (clUnrollJamFactor.getNumOccurrences() > 0)
   108      return loopUnrollJamByFactor(forOp, clUnrollJamFactor);
   109  
   110    // Unroll and jam by four otherwise.
   111    return loopUnrollJamByFactor(forOp, kDefaultUnrollJamFactor);
   112  }
   113  
   114  LogicalResult mlir::loopUnrollJamUpToFactor(AffineForOp forOp,
   115                                              uint64_t unrollJamFactor) {
   116    Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
   117  
   118    if (mayBeConstantTripCount.hasValue() &&
   119        mayBeConstantTripCount.getValue() < unrollJamFactor)
   120      return loopUnrollJamByFactor(forOp, mayBeConstantTripCount.getValue());
   121    return loopUnrollJamByFactor(forOp, unrollJamFactor);
   122  }
   123  
   124  /// Unrolls and jams this loop by the specified factor.
   125  LogicalResult mlir::loopUnrollJamByFactor(AffineForOp forOp,
   126                                            uint64_t unrollJamFactor) {
   127    // Gathers all maximal sub-blocks of operations that do not themselves
   128    // include a for op (a operation could have a descendant for op though
   129    // in its tree).  Ignore the block terminators.
   130    struct JamBlockGatherer {
   131      // Store iterators to the first and last op of each sub-block found.
   132      std::vector<std::pair<Block::iterator, Block::iterator>> subBlocks;
   133  
   134      // This is a linear time walk.
   135      void walk(Operation *op) {
   136        for (auto &region : op->getRegions())
   137          for (auto &block : region)
   138            walk(block);
   139      }
   140      void walk(Block &block) {
   141        for (auto it = block.begin(), e = std::prev(block.end()); it != e;) {
   142          auto subBlockStart = it;
   143          while (it != e && !isa<AffineForOp>(&*it))
   144            ++it;
   145          if (it != subBlockStart)
   146            subBlocks.push_back({subBlockStart, std::prev(it)});
   147          // Process all for insts that appear next.
   148          while (it != e && isa<AffineForOp>(&*it))
   149            walk(&*it++);
   150        }
   151      }
   152    };
   153  
   154    assert(unrollJamFactor >= 1 && "unroll jam factor should be >= 1");
   155  
   156    if (unrollJamFactor == 1)
   157      return promoteIfSingleIteration(forOp);
   158  
   159    if (forOp.getBody()->empty() ||
   160        forOp.getBody()->begin() == std::prev(forOp.getBody()->end()))
   161      return failure();
   162  
   163    // Loops where both lower and upper bounds are multi-result maps won't be
   164    // unrolled (since the trip can't be expressed as an affine function in
   165    // general).
   166    // TODO(mlir-team): this may not be common, but we could support the case
   167    // where the lower bound is a multi-result map and the ub is a single result
   168    // one.
   169    if (forOp.getLowerBoundMap().getNumResults() != 1)
   170      return failure();
   171  
   172    Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
   173    // If the trip count is lower than the unroll jam factor, no unroll jam.
   174    if (mayBeConstantTripCount.hasValue() &&
   175        mayBeConstantTripCount.getValue() < unrollJamFactor)
   176      return failure();
   177  
   178    auto *forInst = forOp.getOperation();
   179  
   180    // Gather all sub-blocks to jam upon the loop being unrolled.
   181    JamBlockGatherer jbg;
   182    jbg.walk(forInst);
   183    auto &subBlocks = jbg.subBlocks;
   184  
   185    // Generate the cleanup loop if trip count isn't a multiple of
   186    // unrollJamFactor.
   187    if (getLargestDivisorOfTripCount(forOp) % unrollJamFactor != 0) {
   188      // Insert the cleanup loop right after 'forOp'.
   189      OpBuilder builder(forInst->getBlock(), std::next(Block::iterator(forInst)));
   190      auto cleanupAffineForOp = cast<AffineForOp>(builder.clone(*forInst));
   191      // Adjust the lower bound of the cleanup loop; its upper bound is the same
   192      // as the original loop's upper bound.
   193      AffineMap cleanupMap;
   194      SmallVector<Value *, 4> cleanupOperands;
   195      getCleanupLoopLowerBound(forOp, unrollJamFactor, &cleanupMap,
   196                               &cleanupOperands, builder);
   197      cleanupAffineForOp.setLowerBound(cleanupOperands, cleanupMap);
   198  
   199      // Promote the cleanup loop if it has turned into a single iteration loop.
   200      promoteIfSingleIteration(cleanupAffineForOp);
   201  
   202      // Adjust the upper bound of the original loop - it will be the same as the
   203      // cleanup loop's lower bound. Its lower bound remains unchanged.
   204      forOp.setUpperBound(cleanupOperands, cleanupMap);
   205    }
   206  
   207    // Scale the step of loop being unroll-jammed by the unroll-jam factor.
   208    int64_t step = forOp.getStep();
   209    forOp.setStep(step * unrollJamFactor);
   210  
   211    auto *forOpIV = forOp.getInductionVar();
   212    // Unroll and jam (appends unrollJamFactor-1 additional copies).
   213    for (unsigned i = 1; i < unrollJamFactor; i++) {
   214      // Operand map persists across all sub-blocks.
   215      BlockAndValueMapping operandMapping;
   216      for (auto &subBlock : subBlocks) {
   217        // Builder to insert unroll-jammed bodies. Insert right at the end of
   218        // sub-block.
   219        OpBuilder builder(subBlock.first->getBlock(), std::next(subBlock.second));
   220  
   221        // If the induction variable is used, create a remapping to the value for
   222        // this unrolled instance.
   223        if (!forOpIV->use_empty()) {
   224          // iv' = iv + i, i = 1 to unrollJamFactor-1.
   225          auto d0 = builder.getAffineDimExpr(0);
   226          auto bumpMap = builder.getAffineMap(1, 0, {d0 + i * step});
   227          auto ivUnroll =
   228              builder.create<AffineApplyOp>(forInst->getLoc(), bumpMap, forOpIV);
   229          operandMapping.map(forOpIV, ivUnroll);
   230        }
   231        // Clone the sub-block being unroll-jammed.
   232        for (auto it = subBlock.first; it != std::next(subBlock.second); ++it) {
   233          builder.clone(*it, operandMapping);
   234        }
   235      }
   236    }
   237  
   238    // Promote the loop body up if this has turned into a single iteration loop.
   239    promoteIfSingleIteration(forOp);
   240    return success();
   241  }
   242  
   243  static PassRegistration<LoopUnrollAndJam> pass("affine-loop-unroll-jam",
   244                                                 "Unroll and jam loops");