github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Transforms/LoopUnroll.cpp (about) 1 //===- LoopUnroll.cpp - Code to perform loop unrolling --------------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 // 18 // This file implements loop unrolling. 19 // 20 //===----------------------------------------------------------------------===// 21 22 #include "mlir/Transforms/Passes.h" 23 24 #include "mlir/Analysis/LoopAnalysis.h" 25 #include "mlir/Dialect/AffineOps/AffineOps.h" 26 #include "mlir/IR/AffineExpr.h" 27 #include "mlir/IR/AffineMap.h" 28 #include "mlir/IR/Builders.h" 29 #include "mlir/Pass/Pass.h" 30 #include "mlir/Transforms/LoopUtils.h" 31 #include "llvm/ADT/DenseMap.h" 32 #include "llvm/Support/CommandLine.h" 33 #include "llvm/Support/Debug.h" 34 35 using namespace mlir; 36 37 #define DEBUG_TYPE "affine-loop-unroll" 38 39 static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options"); 40 41 // Loop unrolling factor. 42 static llvm::cl::opt<unsigned> clUnrollFactor( 43 "unroll-factor", 44 llvm::cl::desc("Use this unroll factor for all loops being unrolled"), 45 llvm::cl::cat(clOptionsCategory)); 46 47 static llvm::cl::opt<bool> clUnrollFull("unroll-full", 48 llvm::cl::desc("Fully unroll loops"), 49 llvm::cl::cat(clOptionsCategory)); 50 51 static llvm::cl::opt<unsigned> clUnrollNumRepetitions( 52 "unroll-num-reps", 53 llvm::cl::desc("Unroll innermost loops repeatedly this many times"), 54 llvm::cl::cat(clOptionsCategory)); 55 56 static llvm::cl::opt<unsigned> clUnrollFullThreshold( 57 "unroll-full-threshold", llvm::cl::Hidden, 58 llvm::cl::desc( 59 "Unroll all loops with trip count less than or equal to this"), 60 llvm::cl::cat(clOptionsCategory)); 61 62 namespace { 63 /// Loop unrolling pass. Unrolls all innermost loops unless full unrolling and a 64 /// full unroll threshold was specified, in which case, fully unrolls all loops 65 /// with trip count less than the specified threshold. The latter is for testing 66 /// purposes, especially for testing outer loop unrolling. 67 struct LoopUnroll : public FunctionPass<LoopUnroll> { 68 const Optional<unsigned> unrollFactor; 69 const Optional<bool> unrollFull; 70 // Callback to obtain unroll factors; if this has a callable target, takes 71 // precedence over command-line argument or passed argument. 72 const std::function<unsigned(AffineForOp)> getUnrollFactor; 73 74 explicit LoopUnroll( 75 Optional<unsigned> unrollFactor = None, Optional<bool> unrollFull = None, 76 const std::function<unsigned(AffineForOp)> &getUnrollFactor = nullptr) 77 : unrollFactor(unrollFactor), unrollFull(unrollFull), 78 getUnrollFactor(getUnrollFactor) {} 79 80 void runOnFunction() override; 81 82 /// Unroll this for op. Returns failure if nothing was done. 83 LogicalResult runOnAffineForOp(AffineForOp forOp); 84 85 static const unsigned kDefaultUnrollFactor = 4; 86 }; 87 } // end anonymous namespace 88 89 void LoopUnroll::runOnFunction() { 90 // Gathers all innermost loops through a post order pruned walk. 91 struct InnermostLoopGatherer { 92 // Store innermost loops as we walk. 93 std::vector<AffineForOp> loops; 94 95 void walkPostOrder(FuncOp f) { 96 for (auto &b : f) 97 walkPostOrder(b.begin(), b.end()); 98 } 99 100 bool walkPostOrder(Block::iterator Start, Block::iterator End) { 101 bool hasInnerLoops = false; 102 // We need to walk all elements since all innermost loops need to be 103 // gathered as opposed to determining whether this list has any inner 104 // loops or not. 105 while (Start != End) 106 hasInnerLoops |= walkPostOrder(&(*Start++)); 107 return hasInnerLoops; 108 } 109 bool walkPostOrder(Operation *opInst) { 110 bool hasInnerLoops = false; 111 for (auto ®ion : opInst->getRegions()) 112 for (auto &block : region) 113 hasInnerLoops |= walkPostOrder(block.begin(), block.end()); 114 if (isa<AffineForOp>(opInst)) { 115 if (!hasInnerLoops) 116 loops.push_back(cast<AffineForOp>(opInst)); 117 return true; 118 } 119 return hasInnerLoops; 120 } 121 }; 122 123 if (clUnrollFull.getNumOccurrences() > 0 && 124 clUnrollFullThreshold.getNumOccurrences() > 0) { 125 // Store short loops as we walk. 126 std::vector<AffineForOp> loops; 127 128 // Gathers all loops with trip count <= minTripCount. Do a post order walk 129 // so that loops are gathered from innermost to outermost (or else unrolling 130 // an outer one may delete gathered inner ones). 131 getFunction().walk([&](AffineForOp forOp) { 132 Optional<uint64_t> tripCount = getConstantTripCount(forOp); 133 if (tripCount.hasValue() && tripCount.getValue() <= clUnrollFullThreshold) 134 loops.push_back(forOp); 135 }); 136 for (auto forOp : loops) 137 loopUnrollFull(forOp); 138 return; 139 } 140 141 unsigned numRepetitions = clUnrollNumRepetitions.getNumOccurrences() > 0 142 ? clUnrollNumRepetitions 143 : 1; 144 // If the call back is provided, we will recurse until no loops are found. 145 FuncOp func = getFunction(); 146 for (unsigned i = 0; i < numRepetitions || getUnrollFactor; i++) { 147 InnermostLoopGatherer ilg; 148 ilg.walkPostOrder(func); 149 auto &loops = ilg.loops; 150 if (loops.empty()) 151 break; 152 bool unrolled = false; 153 for (auto forOp : loops) 154 unrolled |= succeeded(runOnAffineForOp(forOp)); 155 if (!unrolled) 156 // Break out if nothing was unrolled. 157 break; 158 } 159 } 160 161 /// Unrolls a 'affine.for' op. Returns success if the loop was unrolled, 162 /// failure otherwise. The default unroll factor is 4. 163 LogicalResult LoopUnroll::runOnAffineForOp(AffineForOp forOp) { 164 // Use the function callback if one was provided. 165 if (getUnrollFactor) { 166 return loopUnrollByFactor(forOp, getUnrollFactor(forOp)); 167 } 168 // Unroll by the factor passed, if any. 169 if (unrollFactor.hasValue()) 170 return loopUnrollByFactor(forOp, unrollFactor.getValue()); 171 // Unroll by the command line factor if one was specified. 172 if (clUnrollFactor.getNumOccurrences() > 0) 173 return loopUnrollByFactor(forOp, clUnrollFactor); 174 // Unroll completely if full loop unroll was specified. 175 if (clUnrollFull.getNumOccurrences() > 0 || 176 (unrollFull.hasValue() && unrollFull.getValue())) 177 return loopUnrollFull(forOp); 178 179 // Unroll by four otherwise. 180 return loopUnrollByFactor(forOp, kDefaultUnrollFactor); 181 } 182 183 std::unique_ptr<FunctionPassBase> mlir::createLoopUnrollPass( 184 int unrollFactor, int unrollFull, 185 const std::function<unsigned(AffineForOp)> &getUnrollFactor) { 186 return std::make_unique<LoopUnroll>( 187 unrollFactor == -1 ? None : Optional<unsigned>(unrollFactor), 188 unrollFull == -1 ? None : Optional<bool>(unrollFull), getUnrollFactor); 189 } 190 191 static PassRegistration<LoopUnroll> pass("affine-loop-unroll", "Unroll loops");