github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp (about)

     1  //===- Fusion.cpp - Implementation of linalg Fusion -----------------------===//
     2  //
     3  // Copyright 2019 The MLIR Authors.
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //   http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  // =============================================================================
    17  //
    18  // This file implements the linalg dialect Fusion pass.
    19  //
    20  //===----------------------------------------------------------------------===//
    21  
    22  #include "mlir/EDSC/Helpers.h"
    23  #include "mlir/IR/AffineExpr.h"
    24  #include "mlir/IR/AffineMap.h"
    25  #include "mlir/IR/OpImplementation.h"
    26  #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
    27  #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
    28  #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
    29  #include "mlir/Dialect/Linalg/Passes.h"
    30  #include "mlir/Dialect/Linalg/Utils/Intrinsics.h"
    31  #include "mlir/Dialect/Linalg/Utils/Utils.h"
    32  #include "mlir/Pass/Pass.h"
    33  #include "mlir/Support/LLVM.h"
    34  #include "mlir/Support/STLExtras.h"
    35  #include "mlir/Transforms/FoldUtils.h"
    36  
    37  #include "llvm/ADT/SetVector.h"
    38  #include "llvm/Support/CommandLine.h"
    39  #include "llvm/Support/Debug.h"
    40  
    41  #define DEBUG_TYPE "linalg-fusion"
    42  
    43  using namespace mlir;
    44  using namespace mlir::edsc;
    45  using namespace mlir::edsc::intrinsics;
    46  using namespace mlir::linalg;
    47  using namespace mlir::linalg::intrinsics;
    48  
    49  using llvm::dbgs;
    50  
    51  /// Implements a simple high-level fusion pass of linalg library operations.
    52  ///
    53  /// In each block, linalg ops are processed in reverse textual order.
    54  /// Given a linalg op, fusion occurs by:
    55  ///   1. tiling the op by a given multi-dimensional tile size;
    56  ///   2. inspecting the linalg ops that write into the views read by the op in
    57  ///      step 1. This uses the SSA value of the views to determine producer-
    58  ///      consumer dependences: only identical SSA views are considered for
    59  ///      fusion at this point;
    60  ///   3. greedily fuse the producing linalg ops into the consuming loop tiles;
    61  ///   4. inspect the fused ops and determine whether they have other remaining
    62  ///      LinalgOp uses. If not, then erase the original producing linalg op.
    63  ///
    64  /// More advanced use cases, analyses as well as profitability heuristics are
    65  /// left for future work.
    66  
    67  static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options");
    68  static llvm::cl::list<unsigned> clTileSizes(
    69      "linalg-fusion-tile-sizes",
    70      llvm::cl::desc(
    71          "Tile sizes by which to tile linalg operations during linalg fusion"),
    72      llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated,
    73      llvm::cl::cat(clOptionsCategory));
    74  
    75  // Return a cloned version of `op` that operates on `loopRanges`, assumed to be
    76  // a subset of the original loop ranges of `op`.
    77  // This is achieved by applying the `loopToOperandRangesMaps` permutation maps
    78  // to the `loopRanges` in order to obtain view ranges.
    79  static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op,
    80                                      ArrayRef<SubViewOp::Range> loopRanges,
    81                                      OperationFolder &state) {
    82    ScopedContext scope(b, loc);
    83  
    84    auto maps = loopToOperandRangesMaps(op);
    85    SmallVector<Value *, 8> clonedViews;
    86    clonedViews.reserve(op.getNumInputsAndOutputs());
    87    // Iterate over the inputs and outputs in order.
    88    // Extract the subranges from the linearized ranges.
    89    SmallVector<Value *, 8> ios(op.getInputsAndOutputs());
    90    for (auto en : llvm::enumerate(ios)) {
    91      unsigned idx = en.index();
    92      auto map = maps[idx];
    93      LLVM_DEBUG(dbgs() << "map: " << map << "\n");
    94      Value *view = en.value();
    95      SmallVector<SubViewOp::Range, 8> viewRanges(map.getNumResults());
    96      for (auto en2 : llvm::enumerate(map.getResults())) {
    97        unsigned d = en2.index();
    98        // loopToOperandRangesMaps are permutations-only.
    99        unsigned loopPos = en2.value().cast<AffineDimExpr>().getPosition();
   100        viewRanges[d] = loopRanges[loopPos];
   101        LLVM_DEBUG(dbgs() << "\ni,j: " << en.index() << ", " << en2.index()
   102                          << "\t"
   103                          << "loopPos: " << loopPos << "\t" << viewRanges[d]);
   104      }
   105      // TODO(ntv) opportunities for folding/CSE here rather than build new IR.
   106      clonedViews.push_back(b.create<SubViewOp>(loc, view, viewRanges));
   107    }
   108    auto operands = getAssumedNonViewOperands(op);
   109    clonedViews.append(operands.begin(), operands.end());
   110    return op.create(b, loc, clonedViews, op.getAttrs());
   111  }
   112  
   113  struct ViewDimension {
   114    Value *view;
   115    unsigned dimension;
   116  };
   117  
   118  static ViewDimension getViewDefiningLoopRange(LinalgOp op, unsigned loopDepth) {
   119    auto maps = loopToOperandRangesMaps(op);
   120    SmallVector<Value *, 8> clonedViews;
   121    clonedViews.reserve(op.getNumInputsAndOutputs());
   122    // Iterate over the inputs and outputs in order.
   123    // Extract the subranges from the linearized ranges.
   124    SmallVector<Value *, 8> ios(op.getInputsAndOutputs());
   125    for (auto en : llvm::enumerate(ios)) {
   126      unsigned idx = en.index();
   127      auto map = maps[idx];
   128      LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange I/O idx: " << idx << "\n");
   129      LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange map: " << map << "\n");
   130      Value *view = en.value();
   131      SmallVector<Value *, 8> viewRanges(map.getNumResults(), nullptr);
   132      for (auto en2 : llvm::enumerate(map.getResults())) {
   133        if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition()) {
   134          LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange loopDepth: " << loopDepth
   135                            << "\n");
   136          LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange view: " << *view
   137                            << "\n");
   138          return ViewDimension{view, static_cast<unsigned>(en2.index())};
   139        }
   140      }
   141    }
   142    llvm_unreachable("Expect to be able to extract a view defining loop range");
   143  }
   144  
   145  static Optional<LinalgOp> fuse(Value *producedView, LinalgOp producer,
   146                                 LinalgOp consumer, LinalgOp tiledConsumer,
   147                                 OperationFolder &state) {
   148    auto maybeConsumerIdx = consumer.getIndexOfInput(producedView);
   149    if (!maybeConsumerIdx.hasValue())
   150      return llvm::None;
   151    unsigned consumerIdx = maybeConsumerIdx.getValue();
   152  
   153    auto maybeProducerIdx = producer.getIndexOfOutput(producedView);
   154    if (!maybeProducerIdx.hasValue())
   155      return llvm::None;
   156    unsigned producerIdx = maybeProducerIdx.getValue();
   157  
   158    // If the view is the same between consumer and tiledConsumer, this means we
   159    // don't have loops and the producer cannot be fused at this level.
   160    if (consumer.getInput(consumerIdx) == tiledConsumer.getInput(consumerIdx))
   161      return llvm::None;
   162  
   163    auto tiledConsumerSubView = dyn_cast_or_null<SubViewOp>(
   164        tiledConsumer.getInput(consumerIdx)->getDefiningOp());
   165  
   166    // If we don't have a slice, this also means we don't have loops and the
   167    // producer cannot be fused at this level.
   168    if (!tiledConsumerSubView)
   169      return llvm::None;
   170  
   171    // loopToOperandRangesMaps are permutations-only by construction:
   172    //   we can always identify a data dimension with a (at least one) loop
   173    //   dimension.
   174    AffineMap producerMap =
   175        loopToOperandRangesMaps(producer)[producer.getNumInputs() + producerIdx];
   176    LLVM_DEBUG(dbgs() << "Consumer Idx: " << consumerIdx << ", consumer map: "
   177                      << loopToOperandRangesMaps(consumer)[consumerIdx] << "\n");
   178    LLVM_DEBUG(dbgs() << "Producer Idx: " << producerIdx
   179                      << ", producer map: " << producerMap << "\n");
   180  
   181    unsigned nPar = producer.getNumParallelLoops();
   182    unsigned nRed = producer.getNumReductionLoops();
   183    unsigned nWin = producer.getNumWindowLoops();
   184    SmallVector<SubViewOp::Range, 8> loopRanges(nPar + nRed + nWin);
   185  
   186    // Iterate over dimensions identified by the producer map for `producerIdx`.
   187    // This defines a subset of the loop ranges that we need to complete later.
   188    for (auto en : llvm::enumerate(producerMap.getResults())) {
   189      unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition();
   190      loopRanges[posInProducerLoop] = tiledConsumerSubView.getRange(en.index());
   191    }
   192  
   193    OpBuilder b(tiledConsumer.getOperation());
   194    auto loc = tiledConsumer.getLoc();
   195    // Iterate over all dimensions. For the dimensions not identified by the
   196    // producer map for `producerIdx`, we need to explicitly compute the view that
   197    // defines the loop ranges using the `producer`.
   198    for (unsigned i = 0, nLoops = loopRanges.size(); i < nLoops; ++i) {
   199      if (loopRanges[i].min)
   200        LLVM_DEBUG(llvm::dbgs()
   201                   << "existing LoopRange: " << loopRanges[i] << "\n");
   202      else {
   203        auto viewDim = getViewDefiningLoopRange(producer, i);
   204        loopRanges[i] = SubViewOp::Range{
   205            state.create<ConstantIndexOp>(b, loc, 0),
   206            linalg::intrinsics::dim(viewDim.view, viewDim.dimension),
   207            state.create<ConstantIndexOp>(b, loc, 1)};
   208        LLVM_DEBUG(llvm::dbgs() << "new LoopRange: " << loopRanges[i] << "\n");
   209      }
   210    }
   211  
   212    return cloneWithLoopRanges(b, loc, producer, loopRanges, state);
   213  }
   214  
   215  // Encode structural fusion safety preconditions.
   216  // Some of these will be lifted in the future with better analysis.
   217  static bool isStructurallyFusableProducer(LinalgOp producer, Value *readView,
   218                                            LinalgOp consumer) {
   219    // If a producer has multiple outputs, the analysis needs to take the tiling
   220    // of other outputs into account.
   221    if (producer.getNumOutputs() != 1)
   222      return false;
   223    // Until subview analysis is available, same SSA value is required for fusion.
   224    if (producer.getOutput(0) != readView)
   225      return false;
   226    // No control-flow divergence supported. Only straightline op fusion allowed.
   227    // TODO(ntv) allow fusion when a dominance relation exists.
   228    if (producer.getOperation()->getBlock() !=
   229        consumer.getOperation()->getBlock())
   230      return false;
   231    return true;
   232  }
   233  
   234  static void fuseLinalgOps(FuncOp f, ArrayRef<int64_t> tileSizes) {
   235    OperationFolder state;
   236    DenseSet<Operation *> eraseSet;
   237  
   238    LLVM_DEBUG(f.print(dbgs() << "\nBefore linalg-fusion: \n"));
   239  
   240    // 1. Record the linalg ops so we can traverse them in reverse order.
   241    SmallVector<Operation *, 8> linalgOps;
   242    f.walk([&](LinalgOp op) { linalgOps.push_back(op.getOperation()); });
   243  
   244    // 2. Setup the dependences graph, aliases are populated lazily.
   245    Aliases aliases;
   246    LinalgDependenceGraph G(aliases, linalgOps);
   247  
   248    // 2. For each original linalg op (in reverse order to allow chained
   249    // fusions).
   250    for (auto *op : llvm::reverse(linalgOps)) {
   251      auto consumer = cast<LinalgOp>(op);
   252      LLVM_DEBUG(dbgs() << "\n******\nStart processing:\t" << *op);
   253      // 3. If marked for erasure, it has already been fused. Skip fusing op.
   254      if (eraseSet.count(op) > 0) {
   255        LLVM_DEBUG(dbgs() << "\nAlready fused and marked for erasure, skip.");
   256        continue;
   257      }
   258  
   259      // 4. Apply loop tiling to enable fusion. If unsuccessful, skip fusing op.
   260      auto tiledOp = tileLinalgOp(op, tileSizes, state);
   261      if (!tiledOp) {
   262        LLVM_DEBUG(dbgs() << "\nTile sizes did not produce loops, skip.");
   263        continue;
   264      }
   265  
   266      // 5. For now, we only fuse RAW dependences.
   267      SmallVector<Operation *, 8> fusedProducers;
   268      SmallVector<Value *, 8> fusedViews;
   269      for (auto dependence : G.getDependencesInto(
   270               consumer, LinalgDependenceGraph::DependenceType::RAW)) {
   271        auto producer = cast<LinalgOp>(dependence.dependentOpView.op);
   272        LLVM_DEBUG(dbgs() << "\n***Consider producer:\t"
   273                          << *producer.getOperation() << "\n");
   274  
   275        // a. For now we require fusion on identical SSA values, this allows us to
   276        // not worry about partial writes etc.
   277        // TODO(ntv) support more elaborate fusion with non identical SSA values.
   278        auto *view = dependence.indexingView;
   279        if (view != dependence.dependentOpView.view) {
   280          LLVM_DEBUG(dbgs() << "\nviews are different SSA values, skip.");
   281          continue;
   282        }
   283        // b. Make some simple structural checks that alleviate the need for more
   284        // complex analyses.
   285        if (!isStructurallyFusableProducer(producer, view, op)) {
   286          LLVM_DEBUG(dbgs() << "\n***Not fusable:\t" << *producer.getOperation());
   287          continue;
   288        }
   289        // c. Check for fusion-preventing write that would violate dependences.
   290        // `view` is a producer write that cannot bypass any other write or read.
   291        bool preventFusion = false;
   292        for (auto *op : G.findCoveringDependences(producer, consumer))
   293          if (eraseSet.count(op) == 0) {
   294            preventFusion = true;
   295            LLVM_DEBUG(dbgs() << "\n***Found fusion preventing dep via: " << *op);
   296            break;
   297          }
   298        if (preventFusion)
   299          continue;
   300  
   301        // 6. Try to fuse `producer` just before `tiledOp`.
   302        LLVM_DEBUG(f.print(dbgs() << "\nBefore tiledOp-fusion: \n"));
   303  
   304        auto tOp = tiledOp->op;
   305        OpBuilder builder(tOp.getOperation());
   306        ScopedContext scope(builder, tOp.getLoc());
   307        LLVM_DEBUG(dbgs() << "Try fuse into tiled consumer: " << *tOp << "\n");
   308        auto maybeFusedProducer = fuse(view, producer, op, tOp, state);
   309        if (!maybeFusedProducer) {
   310          LLVM_DEBUG(dbgs() << "\nFusion did not do anything, skip.");
   311          continue;
   312        }
   313  
   314        fusedProducers.push_back(producer.getOperation());
   315        fusedViews.push_back(view);
   316      }
   317  
   318      // 7. If no fusion occurred, or a drop the outer tiled loop which undoes
   319      // everything we did.
   320      if (fusedProducers.empty()) {
   321        tiledOp->loops[0].erase();
   322        continue;
   323      }
   324  
   325      eraseSet.insert(op);
   326      eraseSet.insert(fusedProducers.begin(), fusedProducers.end());
   327    }
   328  
   329    for (auto *op : eraseSet)
   330      op->erase();
   331  
   332    LLVM_DEBUG(f.print(dbgs() << "\nAfter linalg-fusion: \n"));
   333  }
   334  
   335  namespace {
   336  struct LinalgFusionPass : public FunctionPass<LinalgFusionPass> {
   337    LinalgFusionPass() = default;
   338    LinalgFusionPass(ArrayRef<int64_t> sizes);
   339  
   340    void runOnFunction() { fuseLinalgOps(getFunction(), tileSizes); }
   341  
   342    SmallVector<int64_t, 8> tileSizes;
   343  };
   344  } // namespace
   345  
   346  LinalgFusionPass::LinalgFusionPass(ArrayRef<int64_t> sizes)
   347      : LinalgFusionPass() {
   348    if (!sizes.empty())
   349      this->tileSizes.assign(sizes.begin(), sizes.end());
   350  }
   351  
   352  std::unique_ptr<FunctionPassBase>
   353  mlir::linalg::createLinalgFusionPass(ArrayRef<int64_t> tileSizes) {
   354    return std::make_unique<LinalgFusionPass>(tileSizes);
   355  }
   356  
   357  static PassRegistration<LinalgFusionPass>
   358      pass("linalg-fusion", "Fuse operations in the linalg dialect", [] {
   359        auto pass = std::make_unique<LinalgFusionPass>();
   360        pass->tileSizes.assign(clTileSizes.begin(), clTileSizes.end());
   361        return pass;
   362      });