//===- PipelineDataTransfer.cpp --- Pass for pipelining data movement ---*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements a pass to pipeline data transfers.
//
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/Passes.h"

#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/LoopAnalysis.h"
#include "mlir/Analysis/Utils.h"
#include "mlir/Dialect/AffineOps/AffineOps.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/Builders.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/LoopUtils.h"
#include "mlir/Transforms/Utils.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "affine-pipeline-data-transfer"

using namespace mlir;

namespace {

struct PipelineDataTransfer : public FunctionPass<PipelineDataTransfer> {
  void runOnFunction() override;
  void runOnAffineForOp(AffineForOp forOp);

  std::vector<AffineForOp> forOps;
};

} // end anonymous namespace

/// Creates a pass to pipeline explicit movement of data across levels of the
/// memory hierarchy.
std::unique_ptr<FunctionPassBase> mlir::createPipelineDataTransferPass() {
  return std::make_unique<PipelineDataTransfer>();
}

// Returns the position of the tag memref operand given a DMA operation.
// Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are
// added. TODO(b/117228571)
static unsigned getTagMemRefPos(Operation &dmaInst) {
  assert(isa<AffineDmaStartOp>(dmaInst) || isa<AffineDmaWaitOp>(dmaInst));
  if (auto dmaStartOp = dyn_cast<AffineDmaStartOp>(dmaInst)) {
    return dmaStartOp.getTagMemRefOperandIndex();
  }
  // First operand for a dma finish operation.
  return 0;
}

/// Doubles the buffer of the supplied memref on the specified 'affine.for'
/// operation by adding a leading dimension of size two to the memref.
/// Replaces all uses of the old memref by the new one while indexing the newly
/// added dimension by the loop IV of the specified 'affine.for' operation
/// modulo 2. Returns false if such a replacement cannot be performed.
static bool doubleBuffer(Value *oldMemRef, AffineForOp forOp) {
  auto *forBody = forOp.getBody();
  OpBuilder bInner(forBody, forBody->begin());
  bInner.setInsertionPoint(forBody, forBody->begin());

  // Doubles the shape with a leading dimension extent of 2.
  auto doubleShape = [&](MemRefType oldMemRefType) -> MemRefType {
    // Add the leading dimension in the shape for the double buffer.
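    // E.g. (illustrative types only, not taken from a test case): a buffer of
    // type memref<256x32xf32, 1> becomes memref<2x256x32xf32, 1>; the element
    // type and memory space are preserved.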
    ArrayRef<int64_t> oldShape = oldMemRefType.getShape();
    SmallVector<int64_t, 4> newShape(1 + oldMemRefType.getRank());
    newShape[0] = 2;
    std::copy(oldShape.begin(), oldShape.end(), newShape.begin() + 1);
    auto newMemRefType =
        bInner.getMemRefType(newShape, oldMemRefType.getElementType(), {},
                             oldMemRefType.getMemorySpace());
    return newMemRefType;
  };

  auto oldMemRefType = oldMemRef->getType().cast<MemRefType>();
  auto newMemRefType = doubleShape(oldMemRefType);

  // The double buffer is allocated right before 'forInst'.
  auto *forInst = forOp.getOperation();
  OpBuilder bOuter(forInst);
  // Put together alloc operands for any dynamic dimensions of the memref.
  SmallVector<Value *, 4> allocOperands;
  unsigned dynamicDimCount = 0;
  for (auto dimSize : oldMemRefType.getShape()) {
    if (dimSize == -1)
      allocOperands.push_back(bOuter.create<DimOp>(forInst->getLoc(), oldMemRef,
                                                   dynamicDimCount++));
  }

  // Create and place the alloc right before the 'affine.for' operation.
  Value *newMemRef =
      bOuter.create<AllocOp>(forInst->getLoc(), newMemRefType, allocOperands);

  // Create 'iv mod 2' value to index the leading dimension.
  auto d0 = bInner.getAffineDimExpr(0);
  int64_t step = forOp.getStep();
  auto modTwoMap = bInner.getAffineMap(/*dimCount=*/1, /*symbolCount=*/0,
                                       {d0.floorDiv(step) % 2});
  auto ivModTwoOp = bInner.create<AffineApplyOp>(forOp.getLoc(), modTwoMap,
                                                 forOp.getInductionVar());

  // replaceAllMemRefUsesWith will succeed unless the forOp body has
  // non-dereferencing uses of the memref (dealloc's are fine though).
  if (failed(replaceAllMemRefUsesWith(
          oldMemRef, newMemRef,
          /*extraIndices=*/{ivModTwoOp},
          /*indexRemap=*/AffineMap(),
          /*extraOperands=*/{},
          /*domInstFilter=*/&*forOp.getBody()->begin()))) {
    LLVM_DEBUG(
        forOp.emitError("memref replacement for double buffering failed"));
    ivModTwoOp.erase();
    return false;
  }
  // Insert the dealloc op right after the for loop.
  bOuter.setInsertionPoint(forInst->getBlock(),
                           std::next(Block::iterator(forInst)));
  bOuter.create<DeallocOp>(forInst->getLoc(), newMemRef);

  return true;
}

/// Walks all 'affine.for' ops in the function and attempts to pipeline the
/// DMAs in each of them.
void PipelineDataTransfer::runOnFunction() {
  // Do a post order walk so that inner loop DMAs are processed first. This is
  // necessary since 'affine.for' operations nested within would otherwise
  // become invalid (erased) when the outer loop is pipelined (the pipelined one
  // gets deleted and replaced by a prologue, a new steady-state loop and an
  // epilogue).
  forOps.clear();
  getFunction().walk([&](AffineForOp forOp) { forOps.push_back(forOp); });
  for (auto forOp : forOps)
    runOnAffineForOp(forOp);
}

// Check if tags of the dma start op and dma wait op match.
static bool checkTagMatch(AffineDmaStartOp startOp, AffineDmaWaitOp waitOp) {
  if (startOp.getTagMemRef() != waitOp.getTagMemRef())
    return false;
  auto startIndices = startOp.getTagIndices();
  auto waitIndices = waitOp.getTagIndices();
  // Both of these have the same number of indices since they correspond to the
  // same tag memref.
  for (auto it = startIndices.begin(), wIt = waitIndices.begin(),
            e = startIndices.end();
       it != e; ++it, ++wIt) {
    // Keep it simple for now, just checking if indices match.
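    // E.g. (schematic, not taken from an actual test case): an
    // affine.dma_start and an affine.dma_wait that both index the tag memref
    // as %tag[%zero], with %zero being the same SSA value in both, are
    // treated as a matching pair.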
    // TODO(mlir-team): this would in general need to check if there is no
    // intervening write writing to the same tag location, i.e., memory last
    // write/data flow analysis. This is however sufficient/powerful enough for
    // now since the DMA generation pass or the input for it will always have
    // start/wait with matching tags (same SSA operand indices).
    if (*it != *wIt)
      return false;
  }
  return true;
}

// Identify matching DMA start/finish operations to overlap computation with.
static void findMatchingStartFinishInsts(
    AffineForOp forOp,
    SmallVectorImpl<std::pair<Operation *, Operation *>> &startWaitPairs) {

  // Collect outgoing DMA operations - needed to check for dependences below.
  SmallVector<AffineDmaStartOp, 4> outgoingDmaOps;
  for (auto &op : *forOp.getBody()) {
    auto dmaStartOp = dyn_cast<AffineDmaStartOp>(op);
    if (dmaStartOp && dmaStartOp.isSrcMemorySpaceFaster())
      outgoingDmaOps.push_back(dmaStartOp);
  }

  SmallVector<Operation *, 4> dmaStartInsts, dmaFinishInsts;
  for (auto &op : *forOp.getBody()) {
    // Collect DMA finish operations.
    if (isa<AffineDmaWaitOp>(op)) {
      dmaFinishInsts.push_back(&op);
      continue;
    }
    auto dmaStartOp = dyn_cast<AffineDmaStartOp>(op);
    if (!dmaStartOp)
      continue;

    // Only DMAs incoming into higher memory spaces are pipelined for now.
    // TODO(bondhugula): handle outgoing DMA pipelining.
    if (!dmaStartOp.isDestMemorySpaceFaster())
      continue;

    // Check for dependence with outgoing DMAs. Doing this conservatively.
    // TODO(andydavis,bondhugula): use the dependence analysis to check for
    // dependences between an incoming and outgoing DMA in the same iteration.
    auto it = outgoingDmaOps.begin();
    for (; it != outgoingDmaOps.end(); ++it) {
      if (it->getDstMemRef() == dmaStartOp.getSrcMemRef())
        break;
    }
    if (it != outgoingDmaOps.end())
      continue;

    // We only double buffer if the buffer is not live out of loop.
    auto *memref = dmaStartOp.getOperand(dmaStartOp.getFasterMemPos());
    bool escapingUses = false;
    for (auto *user : memref->getUsers()) {
      // We can double buffer regardless of dealloc's outside the loop.
      if (isa<DeallocOp>(user))
        continue;
      if (!forOp.getBody()->findAncestorInstInBlock(*user)) {
        LLVM_DEBUG(llvm::dbgs()
                   << "can't pipeline: buffer is live out of loop\n";);
        escapingUses = true;
        break;
      }
    }
    if (!escapingUses)
      dmaStartInsts.push_back(&op);
  }

  // For each start operation, we look for a matching finish operation.
  for (auto *dmaStartInst : dmaStartInsts) {
    for (auto *dmaFinishInst : dmaFinishInsts) {
      if (checkTagMatch(cast<AffineDmaStartOp>(dmaStartInst),
                        cast<AffineDmaWaitOp>(dmaFinishInst))) {
        startWaitPairs.push_back({dmaStartInst, dmaFinishInst});
        break;
      }
    }
  }
}

/// Overlap DMA transfers with computation in this loop. If successful,
/// 'forOp' is deleted, and a prologue, a new pipelined loop, and epilogue are
/// inserted right before where it was.
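///
/// Schematically (an illustrative sketch, not the verbatim output of this
/// pass), a body of the form "dma_start; dma_wait; compute" becomes, after
/// double buffering and shifting:
///
///   dma_start into buf[0]                  // prologue: issue first transfer
///   affine.for %i = ... {                  // steady-state loop
///     dma_start into buf[%i mod 2]         // fetch the next chunk
///     dma_wait on buf[(%i - 1) mod 2]      // previous chunk has arrived
///     compute on buf[(%i - 1) mod 2]
///   }
///   dma_wait on buf[last mod 2]            // epilogue: drain and compute on
///   compute on buf[last mod 2]             // the final chunk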
void PipelineDataTransfer::runOnAffineForOp(AffineForOp forOp) {
  auto mayBeConstTripCount = getConstantTripCount(forOp);
  if (!mayBeConstTripCount.hasValue()) {
    LLVM_DEBUG(forOp.emitRemark("won't pipeline due to unknown trip count"));
    return;
  }

  SmallVector<std::pair<Operation *, Operation *>, 4> startWaitPairs;
  findMatchingStartFinishInsts(forOp, startWaitPairs);

  if (startWaitPairs.empty()) {
    LLVM_DEBUG(forOp.emitRemark("no dma start/finish pairs"));
    return;
  }

  // Double the buffers for the higher memory space memref's.
  // Identify memref's to replace by scanning through all DMA start
  // operations. A DMA start operation has two memref's - the one from the
  // higher level of the memory hierarchy is the one to double buffer.
  // TODO(bondhugula): check whether double-buffering is even necessary.
  // TODO(bondhugula): make this work with different layouts: we assume here
  // that the dimension added for the double buffering is the outermost one.
  for (auto &pair : startWaitPairs) {
    auto *dmaStartInst = pair.first;
    Value *oldMemRef = dmaStartInst->getOperand(
        cast<AffineDmaStartOp>(dmaStartInst).getFasterMemPos());
    if (!doubleBuffer(oldMemRef, forOp)) {
      // Normally, double buffering should not fail because we already checked
      // that there are no uses outside the loop.
      LLVM_DEBUG(llvm::dbgs()
                 << "double buffering failed for " << dmaStartInst << "\n";);
      // IR still valid and semantically correct.
      return;
    }
    // If the old memref has no more uses, remove its 'dead' alloc if it was
    // alloc'ed. (note: DMA buffers are rarely function live-in; but a 'dim'
    // operation could have been used on it if it was dynamically shaped in
    // order to create the double buffer above.)
    // '-canonicalize' does this in a more general way, but we'll anyway do the
    // simple/common case so that the output / test cases look clear.
    if (auto *allocInst = oldMemRef->getDefiningOp()) {
      if (oldMemRef->use_empty()) {
        allocInst->erase();
      } else if (oldMemRef->hasOneUse()) {
        if (auto dealloc = dyn_cast<DeallocOp>(*oldMemRef->user_begin())) {
          dealloc.erase();
          oldMemRef->getDefiningOp()->erase();
        }
      }
    }
  }

  // Double the buffers for tag memrefs.
  for (auto &pair : startWaitPairs) {
    auto *dmaFinishInst = pair.second;
    Value *oldTagMemRef =
        dmaFinishInst->getOperand(getTagMemRefPos(*dmaFinishInst));
    if (!doubleBuffer(oldTagMemRef, forOp)) {
      LLVM_DEBUG(llvm::dbgs() << "tag double buffering failed\n";);
      return;
    }
    // If the old tag has no more uses, remove its 'dead' alloc if it was
    // alloc'ed.
    if (oldTagMemRef->use_empty())
      if (auto *allocInst = oldTagMemRef->getDefiningOp())
        allocInst->erase();
  }

  // Double buffering would have invalidated all the old DMA start/wait insts.
  startWaitPairs.clear();
  findMatchingStartFinishInsts(forOp, startWaitPairs);

  // Store the shift for each operation for later lookup when handling
  // AffineApplyOp's.
  DenseMap<Operation *, unsigned> instShiftMap;
  for (auto &pair : startWaitPairs) {
    auto *dmaStartInst = pair.first;
    assert(isa<AffineDmaStartOp>(dmaStartInst));
    instShiftMap[dmaStartInst] = 0;
    // Set shifts for DMA start op's affine operand computation slices to 0.
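    // (These slices are the affine.apply op's that exclusively feed this DMA
    // start's operands; they have to stay in the same pipeline stage as the
    // start op itself, otherwise the skewing below would place an operand's
    // definition after its use.)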
    SmallVector<AffineApplyOp, 4> sliceOps;
    mlir::createAffineComputationSlice(dmaStartInst, &sliceOps);
    if (!sliceOps.empty()) {
      for (auto sliceOp : sliceOps) {
        instShiftMap[sliceOp.getOperation()] = 0;
      }
    } else {
      // If a slice wasn't created, the reachable affine.apply op's from its
      // operands are the ones that go with it.
      SmallVector<Operation *, 4> affineApplyInsts;
      SmallVector<Value *, 4> operands(dmaStartInst->getOperands());
      getReachableAffineApplyOps(operands, affineApplyInsts);
      for (auto *op : affineApplyInsts) {
        instShiftMap[op] = 0;
      }
    }
  }
  // Everything else (including compute ops and dma finish) is shifted by one.
  for (auto &op : *forOp.getBody()) {
    if (instShiftMap.find(&op) == instShiftMap.end()) {
      instShiftMap[&op] = 1;
    }
  }

  // Get shifts stored in the map.
  std::vector<uint64_t> shifts(forOp.getBody()->getOperations().size());
  unsigned s = 0;
  for (auto &op : *forOp.getBody()) {
    assert(instShiftMap.find(&op) != instShiftMap.end());
    shifts[s++] = instShiftMap[&op];

    // Tag operations with shifts for debugging purposes.
    LLVM_DEBUG({
      OpBuilder b(&op);
      op.setAttr("shift", b.getI64IntegerAttr(shifts[s - 1]));
    });
  }

  if (!isInstwiseShiftValid(forOp, shifts)) {
    // Violates dependences.
    LLVM_DEBUG(llvm::dbgs() << "shifts invalid - unexpected\n";);
    return;
  }

  if (failed(instBodySkew(forOp, shifts))) {
    LLVM_DEBUG(llvm::dbgs() << "op body skewing failed - unexpected\n";);
    return;
  }
}

static PassRegistration<PipelineDataTransfer> pass(
    "affine-pipeline-data-transfer",
    "Pipeline non-blocking data transfers between explicitly managed levels of "
    "the memory hierarchy");
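
// Usage note (assuming the standard mlir-opt driver built from this tree):
// once registered as above, the pass can be exercised standalone with, e.g.,
//   mlir-opt input.mlir -affine-pipeline-data-transfer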