github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Transforms/LoopCoalescing.cpp (about)

     1  //===- LoopCoalescing.cpp - Pass transforming loop nests into single loops-===//
     2  //
     3  // Copyright 2019 The MLIR Authors.
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //   http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  // =============================================================================
    17  
    18  #include "mlir/Dialect/LoopOps/LoopOps.h"
    19  #include "mlir/Dialect/StandardOps/Ops.h"
    20  #include "mlir/Pass/Pass.h"
    21  #include "mlir/Transforms/LoopUtils.h"
    22  #include "mlir/Transforms/Passes.h"
    23  #include "mlir/Transforms/RegionUtils.h"
    24  #include "llvm/Support/Debug.h"
    25  
    26  #define PASS_NAME "loop-coalescing"
    27  #define DEBUG_TYPE PASS_NAME
    28  
    29  using namespace mlir;
    30  
    31  namespace {
    32  class LoopCoalescingPass : public FunctionPass<LoopCoalescingPass> {
    33  public:
    34    void runOnFunction() override {
    35      FuncOp func = getFunction();
    36  
    37      func.walk([](loop::ForOp op) {
    38        // Ignore nested loops.
    39        if (op.getParentOfType<loop::ForOp>())
    40          return;
    41  
    42        SmallVector<loop::ForOp, 4> loops;
    43        getPerfectlyNestedLoops(loops, op);
    44        LLVM_DEBUG(llvm::dbgs()
    45                   << "found a perfect nest of depth " << loops.size() << '\n');
    46  
    47        // Look for a band of loops that can be coalesced, i.e. perfectly nested
    48        // loops with bounds defined above some loop.
    49        // 1. For each loop, find above which parent loop its operands are
    50        // defined.
    51        SmallVector<unsigned, 4> operandsDefinedAbove(loops.size());
    52        for (unsigned i = 0, e = loops.size(); i < e; ++i) {
    53          operandsDefinedAbove[i] = i;
    54          for (unsigned j = 0; j < i; ++j) {
    55            if (areValuesDefinedAbove(loops[i].getOperands(),
    56                                      loops[j].region())) {
    57              operandsDefinedAbove[i] = j;
    58              break;
    59            }
    60          }
    61          LLVM_DEBUG(llvm::dbgs()
    62                     << "  bounds of loop " << i << " are known above depth "
    63                     << operandsDefinedAbove[i] << '\n');
    64        }
    65  
    66        // 2. Identify bands of loops such that the operands of all of them are
    67        // defined above the first loop in the band.  Traverse the nest bottom-up
    68        // so that modifications don't invalidate the inner loops.
    69        for (unsigned end = loops.size(); end > 0; --end) {
    70          unsigned start = 0;
    71          for (; start < end - 1; ++start) {
    72            auto maxPos =
    73                *std::max_element(std::next(operandsDefinedAbove.begin(), start),
    74                                  std::next(operandsDefinedAbove.begin(), end));
    75            if (maxPos > start)
    76              continue;
    77  
    78            assert(maxPos == start &&
    79                   "expected loop bounds to be known at the start of the band");
    80            LLVM_DEBUG(llvm::dbgs() << "  found coalesceable band from " << start
    81                                    << " to " << end << '\n');
    82  
    83            auto band =
    84                llvm::makeMutableArrayRef(loops.data() + start, end - start);
    85            coalesceLoops(band);
    86            break;
    87          }
    88          // If a band was found and transformed, keep looking at the loops above
    89          // the outermost transformed loop.
    90          if (start != end - 1)
    91            end = start + 1;
    92        }
    93      });
    94    }
    95  };
    96  
    97  } // namespace
    98  
    99  std::unique_ptr<FunctionPassBase> mlir::createLoopCoalescingPass() {
   100    return std::make_unique<LoopCoalescingPass>();
   101  }
   102  
   103  static PassRegistration<LoopCoalescingPass>
   104      reg(PASS_NAME,
   105          "coalesce nested loops with independent bounds into a single loop");