github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Transforms/PipelineDataTransfer.cpp (about)

     1  //===- PipelineDataTransfer.cpp --- Pass for pipelining data movement ---*-===//
     2  //
     3  // Copyright 2019 The MLIR Authors.
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //   http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  // =============================================================================
    17  //
    18  // This file implements a pass to pipeline data transfers.
    19  //
    20  //===----------------------------------------------------------------------===//
    21  
    22  #include "mlir/Transforms/Passes.h"
    23  
    24  #include "mlir/Analysis/AffineAnalysis.h"
    25  #include "mlir/Analysis/LoopAnalysis.h"
    26  #include "mlir/Analysis/Utils.h"
    27  #include "mlir/Dialect/AffineOps/AffineOps.h"
    28  #include "mlir/Dialect/StandardOps/Ops.h"
    29  #include "mlir/IR/Builders.h"
    30  #include "mlir/Pass/Pass.h"
    31  #include "mlir/Transforms/LoopUtils.h"
    32  #include "mlir/Transforms/Utils.h"
    33  #include "llvm/ADT/DenseMap.h"
    34  #include "llvm/Support/Debug.h"
    35  #define DEBUG_TYPE "affine-pipeline-data-transfer"
    36  
    37  using namespace mlir;
    38  
    39  namespace {
    40  
    41  struct PipelineDataTransfer : public FunctionPass<PipelineDataTransfer> {
    42    void runOnFunction() override;
    43    void runOnAffineForOp(AffineForOp forOp);
    44  
    45    std::vector<AffineForOp> forOps;
    46  };
    47  
    48  } // end anonymous namespace
    49  
    50  /// Creates a pass to pipeline explicit movement of data across levels of the
    51  /// memory hierarchy.
    52  std::unique_ptr<FunctionPassBase> mlir::createPipelineDataTransferPass() {
    53    return std::make_unique<PipelineDataTransfer>();
    54  }
    55  
    56  // Returns the position of the tag memref operand given a DMA operation.
    57  // Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are
    58  // added.  TODO(b/117228571)
    59  static unsigned getTagMemRefPos(Operation &dmaInst) {
    60    assert(isa<AffineDmaStartOp>(dmaInst) || isa<AffineDmaWaitOp>(dmaInst));
    61    if (auto dmaStartOp = dyn_cast<AffineDmaStartOp>(dmaInst)) {
    62      return dmaStartOp.getTagMemRefOperandIndex();
    63    }
    64    // First operand for a dma finish operation.
    65    return 0;
    66  }
    67  
    68  /// Doubles the buffer of the supplied memref on the specified 'affine.for'
    69  /// operation by adding a leading dimension of size two to the memref.
    70  /// Replaces all uses of the old memref by the new one while indexing the newly
    71  /// added dimension by the loop IV of the specified 'affine.for' operation
    72  /// modulo 2. Returns false if such a replacement cannot be performed.
    73  static bool doubleBuffer(Value *oldMemRef, AffineForOp forOp) {
    74    auto *forBody = forOp.getBody();
    75    OpBuilder bInner(forBody, forBody->begin());
    76    bInner.setInsertionPoint(forBody, forBody->begin());
    77  
    78    // Doubles the shape with a leading dimension extent of 2.
    79    auto doubleShape = [&](MemRefType oldMemRefType) -> MemRefType {
    80      // Add the leading dimension in the shape for the double buffer.
    81      ArrayRef<int64_t> oldShape = oldMemRefType.getShape();
    82      SmallVector<int64_t, 4> newShape(1 + oldMemRefType.getRank());
    83      newShape[0] = 2;
    84      std::copy(oldShape.begin(), oldShape.end(), newShape.begin() + 1);
    85      auto newMemRefType =
    86          bInner.getMemRefType(newShape, oldMemRefType.getElementType(), {},
    87                               oldMemRefType.getMemorySpace());
    88      return newMemRefType;
    89    };
    90  
    91    auto oldMemRefType = oldMemRef->getType().cast<MemRefType>();
    92    auto newMemRefType = doubleShape(oldMemRefType);
    93  
    94    // The double buffer is allocated right before 'forInst'.
    95    auto *forInst = forOp.getOperation();
    96    OpBuilder bOuter(forInst);
    97    // Put together alloc operands for any dynamic dimensions of the memref.
    98    SmallVector<Value *, 4> allocOperands;
    99    unsigned dynamicDimCount = 0;
   100    for (auto dimSize : oldMemRefType.getShape()) {
   101      if (dimSize == -1)
   102        allocOperands.push_back(bOuter.create<DimOp>(forInst->getLoc(), oldMemRef,
   103                                                     dynamicDimCount++));
   104    }
   105  
   106    // Create and place the alloc right before the 'affine.for' operation.
   107    Value *newMemRef =
   108        bOuter.create<AllocOp>(forInst->getLoc(), newMemRefType, allocOperands);
   109  
   110    // Create 'iv mod 2' value to index the leading dimension.
   111    auto d0 = bInner.getAffineDimExpr(0);
   112    int64_t step = forOp.getStep();
   113    auto modTwoMap = bInner.getAffineMap(/*dimCount=*/1, /*symbolCount=*/0,
   114                                         {d0.floorDiv(step) % 2});
   115    auto ivModTwoOp = bInner.create<AffineApplyOp>(forOp.getLoc(), modTwoMap,
   116                                                   forOp.getInductionVar());
   117  
   118    // replaceAllMemRefUsesWith will succeed unless the forOp body has
   119    // non-dereferencing uses of the memref (dealloc's are fine though).
   120    if (failed(replaceAllMemRefUsesWith(
   121            oldMemRef, newMemRef,
   122            /*extraIndices=*/{ivModTwoOp},
   123            /*indexRemap=*/AffineMap(),
   124            /*extraOperands=*/{},
   125            /*domInstFilter=*/&*forOp.getBody()->begin()))) {
   126      LLVM_DEBUG(
   127          forOp.emitError("memref replacement for double buffering failed"));
   128      ivModTwoOp.erase();
   129      return false;
   130    }
   131    // Insert the dealloc op right after the for loop.
   132    bOuter.setInsertionPoint(forInst->getBlock(),
   133                             std::next(Block::iterator(forInst)));
   134    bOuter.create<DeallocOp>(forInst->getLoc(), newMemRef);
   135  
   136    return true;
   137  }
   138  
   139  /// Returns success if the IR is in a valid state.
   140  void PipelineDataTransfer::runOnFunction() {
   141    // Do a post order walk so that inner loop DMAs are processed first. This is
   142    // necessary since 'affine.for' operations nested within would otherwise
   143    // become invalid (erased) when the outer loop is pipelined (the pipelined one
   144    // gets deleted and replaced by a prologue, a new steady-state loop and an
   145    // epilogue).
   146    forOps.clear();
   147    getFunction().walk([&](AffineForOp forOp) { forOps.push_back(forOp); });
   148    for (auto forOp : forOps)
   149      runOnAffineForOp(forOp);
   150  }
   151  
   152  // Check if tags of the dma start op and dma wait op match.
   153  static bool checkTagMatch(AffineDmaStartOp startOp, AffineDmaWaitOp waitOp) {
   154    if (startOp.getTagMemRef() != waitOp.getTagMemRef())
   155      return false;
   156    auto startIndices = startOp.getTagIndices();
   157    auto waitIndices = waitOp.getTagIndices();
   158    // Both of these have the same number of indices since they correspond to the
   159    // same tag memref.
   160    for (auto it = startIndices.begin(), wIt = waitIndices.begin(),
   161              e = startIndices.end();
   162         it != e; ++it, ++wIt) {
   163      // Keep it simple for now, just checking if indices match.
   164      // TODO(mlir-team): this would in general need to check if there is no
   165      // intervening write writing to the same tag location, i.e., memory last
   166      // write/data flow analysis. This is however sufficient/powerful enough for
   167      // now since the DMA generation pass or the input for it will always have
   168      // start/wait with matching tags (same SSA operand indices).
   169      if (*it != *wIt)
   170        return false;
   171    }
   172    return true;
   173  }
   174  
   175  // Identify matching DMA start/finish operations to overlap computation with.
   176  static void findMatchingStartFinishInsts(
   177      AffineForOp forOp,
   178      SmallVectorImpl<std::pair<Operation *, Operation *>> &startWaitPairs) {
   179  
   180    // Collect outgoing DMA operations - needed to check for dependences below.
   181    SmallVector<AffineDmaStartOp, 4> outgoingDmaOps;
   182    for (auto &op : *forOp.getBody()) {
   183      auto dmaStartOp = dyn_cast<AffineDmaStartOp>(op);
   184      if (dmaStartOp && dmaStartOp.isSrcMemorySpaceFaster())
   185        outgoingDmaOps.push_back(dmaStartOp);
   186    }
   187  
   188    SmallVector<Operation *, 4> dmaStartInsts, dmaFinishInsts;
   189    for (auto &op : *forOp.getBody()) {
   190      // Collect DMA finish operations.
   191      if (isa<AffineDmaWaitOp>(op)) {
   192        dmaFinishInsts.push_back(&op);
   193        continue;
   194      }
   195      auto dmaStartOp = dyn_cast<AffineDmaStartOp>(op);
   196      if (!dmaStartOp)
   197        continue;
   198  
   199      // Only DMAs incoming into higher memory spaces are pipelined for now.
   200      // TODO(bondhugula): handle outgoing DMA pipelining.
   201      if (!dmaStartOp.isDestMemorySpaceFaster())
   202        continue;
   203  
   204      // Check for dependence with outgoing DMAs. Doing this conservatively.
   205      // TODO(andydavis,bondhugula): use the dependence analysis to check for
   206      // dependences between an incoming and outgoing DMA in the same iteration.
   207      auto it = outgoingDmaOps.begin();
   208      for (; it != outgoingDmaOps.end(); ++it) {
   209        if (it->getDstMemRef() == dmaStartOp.getSrcMemRef())
   210          break;
   211      }
   212      if (it != outgoingDmaOps.end())
   213        continue;
   214  
   215      // We only double buffer if the buffer is not live out of loop.
   216      auto *memref = dmaStartOp.getOperand(dmaStartOp.getFasterMemPos());
   217      bool escapingUses = false;
   218      for (auto *user : memref->getUsers()) {
   219        // We can double buffer regardless of dealloc's outside the loop.
   220        if (isa<DeallocOp>(user))
   221          continue;
   222        if (!forOp.getBody()->findAncestorInstInBlock(*user)) {
   223          LLVM_DEBUG(llvm::dbgs()
   224                         << "can't pipeline: buffer is live out of loop\n";);
   225          escapingUses = true;
   226          break;
   227        }
   228      }
   229      if (!escapingUses)
   230        dmaStartInsts.push_back(&op);
   231    }
   232  
   233    // For each start operation, we look for a matching finish operation.
   234    for (auto *dmaStartInst : dmaStartInsts) {
   235      for (auto *dmaFinishInst : dmaFinishInsts) {
   236        if (checkTagMatch(cast<AffineDmaStartOp>(dmaStartInst),
   237                          cast<AffineDmaWaitOp>(dmaFinishInst))) {
   238          startWaitPairs.push_back({dmaStartInst, dmaFinishInst});
   239          break;
   240        }
   241      }
   242    }
   243  }
   244  
   245  /// Overlap DMA transfers with computation in this loop. If successful,
   246  /// 'forOp' is deleted, and a prologue, a new pipelined loop, and epilogue are
   247  /// inserted right before where it was.
   248  void PipelineDataTransfer::runOnAffineForOp(AffineForOp forOp) {
   249    auto mayBeConstTripCount = getConstantTripCount(forOp);
   250    if (!mayBeConstTripCount.hasValue()) {
   251      LLVM_DEBUG(
   252          forOp.emitRemark("won't pipeline due to unknown trip count loop"));
   253      return;
   254    }
   255  
   256    SmallVector<std::pair<Operation *, Operation *>, 4> startWaitPairs;
   257    findMatchingStartFinishInsts(forOp, startWaitPairs);
   258  
   259    if (startWaitPairs.empty()) {
   260      LLVM_DEBUG(forOp.emitRemark("No dma start/finish pairs\n"));
   261      return;
   262    }
   263  
   264    // Double the buffers for the higher memory space memref's.
   265    // Identify memref's to replace by scanning through all DMA start
   266    // operations. A DMA start operation has two memref's - the one from the
   267    // higher level of memory hierarchy is the one to double buffer.
   268    // TODO(bondhugula): check whether double-buffering is even necessary.
   269    // TODO(bondhugula): make this work with different layouts: assuming here that
   270    // the dimension we are adding here for the double buffering is the outermost
   271    // dimension.
   272    for (auto &pair : startWaitPairs) {
   273      auto *dmaStartInst = pair.first;
   274      Value *oldMemRef = dmaStartInst->getOperand(
   275          cast<AffineDmaStartOp>(dmaStartInst).getFasterMemPos());
   276      if (!doubleBuffer(oldMemRef, forOp)) {
   277        // Normally, double buffering should not fail because we already checked
   278        // that there are no uses outside.
   279        LLVM_DEBUG(llvm::dbgs()
   280                       << "double buffering failed for" << dmaStartInst << "\n";);
   281        // IR still valid and semantically correct.
   282        return;
   283      }
   284      // If the old memref has no more uses, remove its 'dead' alloc if it was
   285      // alloc'ed. (note: DMA buffers are rarely function live-in; but a 'dim'
   286      // operation could have been used on it if it was dynamically shaped in
   287      // order to create the double buffer above.)
   288      // '-canonicalize' does this in a more general way, but we'll anyway do the
   289      // simple/common case so that the output / test cases looks clear.
   290      if (auto *allocInst = oldMemRef->getDefiningOp()) {
   291        if (oldMemRef->use_empty()) {
   292          allocInst->erase();
   293        } else if (oldMemRef->hasOneUse()) {
   294          if (auto dealloc = dyn_cast<DeallocOp>(*oldMemRef->user_begin())) {
   295            dealloc.erase();
   296            oldMemRef->getDefiningOp()->erase();
   297          }
   298        }
   299      }
   300    }
   301  
   302    // Double the buffers for tag memrefs.
   303    for (auto &pair : startWaitPairs) {
   304      auto *dmaFinishInst = pair.second;
   305      Value *oldTagMemRef =
   306          dmaFinishInst->getOperand(getTagMemRefPos(*dmaFinishInst));
   307      if (!doubleBuffer(oldTagMemRef, forOp)) {
   308        LLVM_DEBUG(llvm::dbgs() << "tag double buffering failed\n";);
   309        return;
   310      }
   311      // If the old tag has no more uses, remove its 'dead' alloc if it was
   312      // alloc'ed.
   313      if (oldTagMemRef->use_empty())
   314        if (auto *allocInst = oldTagMemRef->getDefiningOp())
   315          allocInst->erase();
   316    }
   317  
   318    // Double buffering would have invalidated all the old DMA start/wait insts.
   319    startWaitPairs.clear();
   320    findMatchingStartFinishInsts(forOp, startWaitPairs);
   321  
   322    // Store shift for operation for later lookup for AffineApplyOp's.
   323    DenseMap<Operation *, unsigned> instShiftMap;
   324    for (auto &pair : startWaitPairs) {
   325      auto *dmaStartInst = pair.first;
   326      assert(isa<AffineDmaStartOp>(dmaStartInst));
   327      instShiftMap[dmaStartInst] = 0;
   328      // Set shifts for DMA start op's affine operand computation slices to 0.
   329      SmallVector<AffineApplyOp, 4> sliceOps;
   330      mlir::createAffineComputationSlice(dmaStartInst, &sliceOps);
   331      if (!sliceOps.empty()) {
   332        for (auto sliceOp : sliceOps) {
   333          instShiftMap[sliceOp.getOperation()] = 0;
   334        }
   335      } else {
   336        // If a slice wasn't created, the reachable affine.apply op's from its
   337        // operands are the ones that go with it.
   338        SmallVector<Operation *, 4> affineApplyInsts;
   339        SmallVector<Value *, 4> operands(dmaStartInst->getOperands());
   340        getReachableAffineApplyOps(operands, affineApplyInsts);
   341        for (auto *op : affineApplyInsts) {
   342          instShiftMap[op] = 0;
   343        }
   344      }
   345    }
   346    // Everything else (including compute ops and dma finish) are shifted by one.
   347    for (auto &op : *forOp.getBody()) {
   348      if (instShiftMap.find(&op) == instShiftMap.end()) {
   349        instShiftMap[&op] = 1;
   350      }
   351    }
   352  
   353    // Get shifts stored in map.
   354    std::vector<uint64_t> shifts(forOp.getBody()->getOperations().size());
   355    unsigned s = 0;
   356    for (auto &op : *forOp.getBody()) {
   357      assert(instShiftMap.find(&op) != instShiftMap.end());
   358      shifts[s++] = instShiftMap[&op];
   359  
   360      // Tagging operations with shifts for debugging purposes.
   361      LLVM_DEBUG({
   362        OpBuilder b(&op);
   363        op.setAttr("shift", b.getI64IntegerAttr(shifts[s - 1]));
   364      });
   365    }
   366  
   367    if (!isInstwiseShiftValid(forOp, shifts)) {
   368      // Violates dependences.
   369      LLVM_DEBUG(llvm::dbgs() << "Shifts invalid - unexpected\n";);
   370      return;
   371    }
   372  
   373    if (failed(instBodySkew(forOp, shifts))) {
   374      LLVM_DEBUG(llvm::dbgs() << "op body skewing failed - unexpected\n";);
   375      return;
   376    }
   377  }
   378  
   379  static PassRegistration<PipelineDataTransfer> pass(
   380      "affine-pipeline-data-transfer",
   381      "Pipeline non-blocking data transfers between explicitly managed levels of "
   382      "the memory hierarchy");