github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp (about) 1 //===- Fusion.cpp - Implementation of linalg Fusion -----------------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 // 18 // This file implements the linalg dialect Fusion pass. 19 // 20 //===----------------------------------------------------------------------===// 21 22 #include "mlir/EDSC/Helpers.h" 23 #include "mlir/IR/AffineExpr.h" 24 #include "mlir/IR/AffineMap.h" 25 #include "mlir/IR/OpImplementation.h" 26 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" 27 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 28 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" 29 #include "mlir/Dialect/Linalg/Passes.h" 30 #include "mlir/Dialect/Linalg/Utils/Intrinsics.h" 31 #include "mlir/Dialect/Linalg/Utils/Utils.h" 32 #include "mlir/Pass/Pass.h" 33 #include "mlir/Support/LLVM.h" 34 #include "mlir/Support/STLExtras.h" 35 #include "mlir/Transforms/FoldUtils.h" 36 37 #include "llvm/ADT/SetVector.h" 38 #include "llvm/Support/CommandLine.h" 39 #include "llvm/Support/Debug.h" 40 41 #define DEBUG_TYPE "linalg-fusion" 42 43 using namespace mlir; 44 using namespace mlir::edsc; 45 using namespace mlir::edsc::intrinsics; 46 using namespace mlir::linalg; 47 using namespace mlir::linalg::intrinsics; 48 49 using llvm::dbgs; 50 51 /// Implements a simple high-level fusion pass of linalg library operations. 52 /// 53 /// In each block, linalg ops are processed in reverse textual order. 54 /// Given a linalg op, fusion occurs by: 55 /// 1. tiling the op by a given multi-dimensional tile size; 56 /// 2. inspecting the linalg ops that write into the views read by the op in 57 /// step 1. This uses the SSA value of the views to determine producer- 58 /// consumer dependences: only identical SSA views are considered for 59 /// fusion at this point; 60 /// 3. greedily fuse the producing linalg ops into the consuming loop tiles; 61 /// 4. inspect the fused ops and determine whether they have other remaining 62 /// LinalgOp uses. If not, then erase the original producing linalg op. 63 /// 64 /// More advanced use cases, analyses as well as profitability heuristics are 65 /// left for future work. 66 67 static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options"); 68 static llvm::cl::list<unsigned> clTileSizes( 69 "linalg-fusion-tile-sizes", 70 llvm::cl::desc( 71 "Tile sizes by which to tile linalg operations during linalg fusion"), 72 llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated, 73 llvm::cl::cat(clOptionsCategory)); 74 75 // Return a cloned version of `op` that operates on `loopRanges`, assumed to be 76 // a subset of the original loop ranges of `op`. 77 // This is achieved by applying the `loopToOperandRangesMaps` permutation maps 78 // to the `loopRanges` in order to obtain view ranges. 79 static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op, 80 ArrayRef<SubViewOp::Range> loopRanges, 81 OperationFolder &state) { 82 ScopedContext scope(b, loc); 83 84 auto maps = loopToOperandRangesMaps(op); 85 SmallVector<Value *, 8> clonedViews; 86 clonedViews.reserve(op.getNumInputsAndOutputs()); 87 // Iterate over the inputs and outputs in order. 88 // Extract the subranges from the linearized ranges. 89 SmallVector<Value *, 8> ios(op.getInputsAndOutputs()); 90 for (auto en : llvm::enumerate(ios)) { 91 unsigned idx = en.index(); 92 auto map = maps[idx]; 93 LLVM_DEBUG(dbgs() << "map: " << map << "\n"); 94 Value *view = en.value(); 95 SmallVector<SubViewOp::Range, 8> viewRanges(map.getNumResults()); 96 for (auto en2 : llvm::enumerate(map.getResults())) { 97 unsigned d = en2.index(); 98 // loopToOperandRangesMaps are permutations-only. 99 unsigned loopPos = en2.value().cast<AffineDimExpr>().getPosition(); 100 viewRanges[d] = loopRanges[loopPos]; 101 LLVM_DEBUG(dbgs() << "\ni,j: " << en.index() << ", " << en2.index() 102 << "\t" 103 << "loopPos: " << loopPos << "\t" << viewRanges[d]); 104 } 105 // TODO(ntv) opportunities for folding/CSE here rather than build new IR. 106 clonedViews.push_back(b.create<SubViewOp>(loc, view, viewRanges)); 107 } 108 auto operands = getAssumedNonViewOperands(op); 109 clonedViews.append(operands.begin(), operands.end()); 110 return op.create(b, loc, clonedViews, op.getAttrs()); 111 } 112 113 struct ViewDimension { 114 Value *view; 115 unsigned dimension; 116 }; 117 118 static ViewDimension getViewDefiningLoopRange(LinalgOp op, unsigned loopDepth) { 119 auto maps = loopToOperandRangesMaps(op); 120 SmallVector<Value *, 8> clonedViews; 121 clonedViews.reserve(op.getNumInputsAndOutputs()); 122 // Iterate over the inputs and outputs in order. 123 // Extract the subranges from the linearized ranges. 124 SmallVector<Value *, 8> ios(op.getInputsAndOutputs()); 125 for (auto en : llvm::enumerate(ios)) { 126 unsigned idx = en.index(); 127 auto map = maps[idx]; 128 LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange I/O idx: " << idx << "\n"); 129 LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange map: " << map << "\n"); 130 Value *view = en.value(); 131 SmallVector<Value *, 8> viewRanges(map.getNumResults(), nullptr); 132 for (auto en2 : llvm::enumerate(map.getResults())) { 133 if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition()) { 134 LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange loopDepth: " << loopDepth 135 << "\n"); 136 LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange view: " << *view 137 << "\n"); 138 return ViewDimension{view, static_cast<unsigned>(en2.index())}; 139 } 140 } 141 } 142 llvm_unreachable("Expect to be able to extract a view defining loop range"); 143 } 144 145 static Optional<LinalgOp> fuse(Value *producedView, LinalgOp producer, 146 LinalgOp consumer, LinalgOp tiledConsumer, 147 OperationFolder &state) { 148 auto maybeConsumerIdx = consumer.getIndexOfInput(producedView); 149 if (!maybeConsumerIdx.hasValue()) 150 return llvm::None; 151 unsigned consumerIdx = maybeConsumerIdx.getValue(); 152 153 auto maybeProducerIdx = producer.getIndexOfOutput(producedView); 154 if (!maybeProducerIdx.hasValue()) 155 return llvm::None; 156 unsigned producerIdx = maybeProducerIdx.getValue(); 157 158 // If the view is the same between consumer and tiledConsumer, this means we 159 // don't have loops and the producer cannot be fused at this level. 160 if (consumer.getInput(consumerIdx) == tiledConsumer.getInput(consumerIdx)) 161 return llvm::None; 162 163 auto tiledConsumerSubView = dyn_cast_or_null<SubViewOp>( 164 tiledConsumer.getInput(consumerIdx)->getDefiningOp()); 165 166 // If we don't have a slice, this also means we don't have loops and the 167 // producer cannot be fused at this level. 168 if (!tiledConsumerSubView) 169 return llvm::None; 170 171 // loopToOperandRangesMaps are permutations-only by construction: 172 // we can always identify a data dimension with a (at least one) loop 173 // dimension. 174 AffineMap producerMap = 175 loopToOperandRangesMaps(producer)[producer.getNumInputs() + producerIdx]; 176 LLVM_DEBUG(dbgs() << "Consumer Idx: " << consumerIdx << ", consumer map: " 177 << loopToOperandRangesMaps(consumer)[consumerIdx] << "\n"); 178 LLVM_DEBUG(dbgs() << "Producer Idx: " << producerIdx 179 << ", producer map: " << producerMap << "\n"); 180 181 unsigned nPar = producer.getNumParallelLoops(); 182 unsigned nRed = producer.getNumReductionLoops(); 183 unsigned nWin = producer.getNumWindowLoops(); 184 SmallVector<SubViewOp::Range, 8> loopRanges(nPar + nRed + nWin); 185 186 // Iterate over dimensions identified by the producer map for `producerIdx`. 187 // This defines a subset of the loop ranges that we need to complete later. 188 for (auto en : llvm::enumerate(producerMap.getResults())) { 189 unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition(); 190 loopRanges[posInProducerLoop] = tiledConsumerSubView.getRange(en.index()); 191 } 192 193 OpBuilder b(tiledConsumer.getOperation()); 194 auto loc = tiledConsumer.getLoc(); 195 // Iterate over all dimensions. For the dimensions not identified by the 196 // producer map for `producerIdx`, we need to explicitly compute the view that 197 // defines the loop ranges using the `producer`. 198 for (unsigned i = 0, nLoops = loopRanges.size(); i < nLoops; ++i) { 199 if (loopRanges[i].min) 200 LLVM_DEBUG(llvm::dbgs() 201 << "existing LoopRange: " << loopRanges[i] << "\n"); 202 else { 203 auto viewDim = getViewDefiningLoopRange(producer, i); 204 loopRanges[i] = SubViewOp::Range{ 205 state.create<ConstantIndexOp>(b, loc, 0), 206 linalg::intrinsics::dim(viewDim.view, viewDim.dimension), 207 state.create<ConstantIndexOp>(b, loc, 1)}; 208 LLVM_DEBUG(llvm::dbgs() << "new LoopRange: " << loopRanges[i] << "\n"); 209 } 210 } 211 212 return cloneWithLoopRanges(b, loc, producer, loopRanges, state); 213 } 214 215 // Encode structural fusion safety preconditions. 216 // Some of these will be lifted in the future with better analysis. 217 static bool isStructurallyFusableProducer(LinalgOp producer, Value *readView, 218 LinalgOp consumer) { 219 // If a producer has multiple outputs, the analysis needs to take the tiling 220 // of other outputs into account. 221 if (producer.getNumOutputs() != 1) 222 return false; 223 // Until subview analysis is available, same SSA value is required for fusion. 224 if (producer.getOutput(0) != readView) 225 return false; 226 // No control-flow divergence supported. Only straightline op fusion allowed. 227 // TODO(ntv) allow fusion when a dominance relation exists. 228 if (producer.getOperation()->getBlock() != 229 consumer.getOperation()->getBlock()) 230 return false; 231 return true; 232 } 233 234 static void fuseLinalgOps(FuncOp f, ArrayRef<int64_t> tileSizes) { 235 OperationFolder state; 236 DenseSet<Operation *> eraseSet; 237 238 LLVM_DEBUG(f.print(dbgs() << "\nBefore linalg-fusion: \n")); 239 240 // 1. Record the linalg ops so we can traverse them in reverse order. 241 SmallVector<Operation *, 8> linalgOps; 242 f.walk([&](LinalgOp op) { linalgOps.push_back(op.getOperation()); }); 243 244 // 2. Setup the dependences graph, aliases are populated lazily. 245 Aliases aliases; 246 LinalgDependenceGraph G(aliases, linalgOps); 247 248 // 2. For each original linalg op (in reverse order to allow chained 249 // fusions). 250 for (auto *op : llvm::reverse(linalgOps)) { 251 auto consumer = cast<LinalgOp>(op); 252 LLVM_DEBUG(dbgs() << "\n******\nStart processing:\t" << *op); 253 // 3. If marked for erasure, it has already been fused. Skip fusing op. 254 if (eraseSet.count(op) > 0) { 255 LLVM_DEBUG(dbgs() << "\nAlready fused and marked for erasure, skip."); 256 continue; 257 } 258 259 // 4. Apply loop tiling to enable fusion. If unsuccessful, skip fusing op. 260 auto tiledOp = tileLinalgOp(op, tileSizes, state); 261 if (!tiledOp) { 262 LLVM_DEBUG(dbgs() << "\nTile sizes did not produce loops, skip."); 263 continue; 264 } 265 266 // 5. For now, we only fuse RAW dependences. 267 SmallVector<Operation *, 8> fusedProducers; 268 SmallVector<Value *, 8> fusedViews; 269 for (auto dependence : G.getDependencesInto( 270 consumer, LinalgDependenceGraph::DependenceType::RAW)) { 271 auto producer = cast<LinalgOp>(dependence.dependentOpView.op); 272 LLVM_DEBUG(dbgs() << "\n***Consider producer:\t" 273 << *producer.getOperation() << "\n"); 274 275 // a. For now we require fusion on identical SSA values, this allows us to 276 // not worry about partial writes etc. 277 // TODO(ntv) support more elaborate fusion with non identical SSA values. 278 auto *view = dependence.indexingView; 279 if (view != dependence.dependentOpView.view) { 280 LLVM_DEBUG(dbgs() << "\nviews are different SSA values, skip."); 281 continue; 282 } 283 // b. Make some simple structural checks that alleviate the need for more 284 // complex analyses. 285 if (!isStructurallyFusableProducer(producer, view, op)) { 286 LLVM_DEBUG(dbgs() << "\n***Not fusable:\t" << *producer.getOperation()); 287 continue; 288 } 289 // c. Check for fusion-preventing write that would violate dependences. 290 // `view` is a producer write that cannot bypass any other write or read. 291 bool preventFusion = false; 292 for (auto *op : G.findCoveringDependences(producer, consumer)) 293 if (eraseSet.count(op) == 0) { 294 preventFusion = true; 295 LLVM_DEBUG(dbgs() << "\n***Found fusion preventing dep via: " << *op); 296 break; 297 } 298 if (preventFusion) 299 continue; 300 301 // 6. Try to fuse `producer` just before `tiledOp`. 302 LLVM_DEBUG(f.print(dbgs() << "\nBefore tiledOp-fusion: \n")); 303 304 auto tOp = tiledOp->op; 305 OpBuilder builder(tOp.getOperation()); 306 ScopedContext scope(builder, tOp.getLoc()); 307 LLVM_DEBUG(dbgs() << "Try fuse into tiled consumer: " << *tOp << "\n"); 308 auto maybeFusedProducer = fuse(view, producer, op, tOp, state); 309 if (!maybeFusedProducer) { 310 LLVM_DEBUG(dbgs() << "\nFusion did not do anything, skip."); 311 continue; 312 } 313 314 fusedProducers.push_back(producer.getOperation()); 315 fusedViews.push_back(view); 316 } 317 318 // 7. If no fusion occurred, or a drop the outer tiled loop which undoes 319 // everything we did. 320 if (fusedProducers.empty()) { 321 tiledOp->loops[0].erase(); 322 continue; 323 } 324 325 eraseSet.insert(op); 326 eraseSet.insert(fusedProducers.begin(), fusedProducers.end()); 327 } 328 329 for (auto *op : eraseSet) 330 op->erase(); 331 332 LLVM_DEBUG(f.print(dbgs() << "\nAfter linalg-fusion: \n")); 333 } 334 335 namespace { 336 struct LinalgFusionPass : public FunctionPass<LinalgFusionPass> { 337 LinalgFusionPass() = default; 338 LinalgFusionPass(ArrayRef<int64_t> sizes); 339 340 void runOnFunction() { fuseLinalgOps(getFunction(), tileSizes); } 341 342 SmallVector<int64_t, 8> tileSizes; 343 }; 344 } // namespace 345 346 LinalgFusionPass::LinalgFusionPass(ArrayRef<int64_t> sizes) 347 : LinalgFusionPass() { 348 if (!sizes.empty()) 349 this->tileSizes.assign(sizes.begin(), sizes.end()); 350 } 351 352 std::unique_ptr<FunctionPassBase> 353 mlir::linalg::createLinalgFusionPass(ArrayRef<int64_t> tileSizes) { 354 return std::make_unique<LinalgFusionPass>(tileSizes); 355 } 356 357 static PassRegistration<LinalgFusionPass> 358 pass("linalg-fusion", "Fuse operations in the linalg dialect", [] { 359 auto pass = std::make_unique<LinalgFusionPass>(); 360 pass->tileSizes.assign(clTileSizes.begin(), clTileSizes.end()); 361 return pass; 362 });