github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp (about)

     1  //===- DependenceAnalysis.cpp - Dependence analysis on SSA views ----------===//
     2  //
     3  // Copyright 2019 The MLIR Authors.
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //   http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  // =============================================================================
    17  //
    18  // This file implements view-based alias and dependence analyses.
    19  //
    20  //===----------------------------------------------------------------------===//
    21  
    22  #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
    23  #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
    24  
    25  #include "llvm/Support/CommandLine.h"
    26  #include "llvm/Support/Debug.h"
    27  
    28  #define DEBUG_TYPE "linalg-dependence-analysis"
    29  
    30  using namespace mlir;
    31  using namespace mlir::linalg;
    32  
    33  using llvm::dbgs;
    34  
    35  Value *Aliases::find(Value *v) {
    36    if (isa<BlockArgument>(v))
    37      return v;
    38  
    39    auto it = aliases.find(v);
    40    if (it != aliases.end()) {
    41      assert(((isa<BlockArgument>(it->getSecond()) &&
    42               it->getSecond()->getType().isa<ViewType>()) ||
    43              it->getSecond()->getType().isa<BufferType>()) &&
    44             "Buffer or block argument expected");
    45      return it->getSecond();
    46    }
    47  
    48    while (true) {
    49      if (isa<BlockArgument>(v))
    50        return v;
    51      if (auto slice = dyn_cast_or_null<SliceOp>(v->getDefiningOp())) {
    52        auto it = aliases.insert(std::make_pair(v, find(slice.view())));
    53        return it.first->second;
    54      }
    55      if (auto view = dyn_cast_or_null<ViewOp>(v->getDefiningOp())) {
    56        auto it = aliases.insert(std::make_pair(v, view.buffer()));
    57        return it.first->second;
    58      }
    59      if (auto view = dyn_cast_or_null<SubViewOp>(v->getDefiningOp())) {
    60        v = view.getView();
    61        continue;
    62      }
    63      llvm::errs() << "View alias analysis reduces to: " << *v << "\n";
    64      llvm_unreachable("unsupported view alias case");
    65    }
    66  }
    67  
    68  LinalgDependenceGraph::LinalgDependenceGraph(Aliases &aliases,
    69                                               ArrayRef<Operation *> ops)
    70      : aliases(aliases), linalgOps(ops.begin(), ops.end()) {
    71    for (auto en : llvm::enumerate(linalgOps)) {
    72      assert(isa<LinalgOp>(en.value()) && "Expected value for LinalgOp");
    73      linalgOpPositions.insert(std::make_pair(en.value(), en.index()));
    74    }
    75    for (unsigned i = 0, e = ops.size(); i < e; ++i) {
    76      for (unsigned j = i + 1; j < e; ++j) {
    77        addDependencesBetween(cast<LinalgOp>(ops[i]), cast<LinalgOp>(ops[j]));
    78      }
    79    }
    80  }
    81  
    82  void LinalgDependenceGraph::addDependenceElem(DependenceType dt,
    83                                                LinalgOpView indexingOpView,
    84                                                LinalgOpView dependentOpView) {
    85    LLVM_DEBUG(dbgs() << "\nAdd dep type " << dt << ":\t" << *indexingOpView.op
    86                      << " -> " << *dependentOpView.op);
    87    dependencesFromGraphs[dt][indexingOpView.op].push_back(
    88        LinalgDependenceGraphElem{dependentOpView, indexingOpView.view});
    89    dependencesIntoGraphs[dt][dependentOpView.op].push_back(
    90        LinalgDependenceGraphElem{indexingOpView, dependentOpView.view});
    91  }
    92  
    93  LinalgDependenceGraph::dependence_range
    94  LinalgDependenceGraph::getDependencesFrom(
    95      LinalgOp src, LinalgDependenceGraph::DependenceType dt) {
    96    return getDependencesFrom(src.getOperation(), dt);
    97  }
    98  
    99  LinalgDependenceGraph::dependence_range
   100  LinalgDependenceGraph::getDependencesFrom(
   101      Operation *src, LinalgDependenceGraph::DependenceType dt) {
   102    auto &vec = dependencesFromGraphs[dt][src];
   103    return llvm::make_range(vec.begin(), vec.end());
   104  }
   105  
   106  LinalgDependenceGraph::dependence_range
   107  LinalgDependenceGraph::getDependencesInto(
   108      LinalgOp dst, LinalgDependenceGraph::DependenceType dt) {
   109    return getDependencesInto(dst.getOperation(), dt);
   110  }
   111  
   112  LinalgDependenceGraph::dependence_range
   113  LinalgDependenceGraph::getDependencesInto(
   114      Operation *dst, LinalgDependenceGraph::DependenceType dt) {
   115    auto &vec = dependencesIntoGraphs[dt][dst];
   116    return llvm::make_range(vec.begin(), vec.end());
   117  }
   118  
   119  void LinalgDependenceGraph::addDependencesBetween(LinalgOp src, LinalgOp dst) {
   120    for (auto *srcView : src.getOutputs()) { // W
   121      // RAW graph
   122      for (auto *dstView : dst.getInputs()) {  // R
   123        if (aliases.alias(srcView, dstView)) { // if alias, fill RAW
   124          addDependenceElem(DependenceType::RAW,
   125                            LinalgOpView{src.getOperation(), srcView},
   126                            LinalgOpView{dst.getOperation(), dstView});
   127        }
   128      }
   129      // WAW graph
   130      for (auto *dstView : dst.getOutputs()) { // W
   131        if (aliases.alias(srcView, dstView)) { // if alias, fill WAW
   132          addDependenceElem(DependenceType::WAW,
   133                            LinalgOpView{src.getOperation(), srcView},
   134                            LinalgOpView{dst.getOperation(), dstView});
   135        }
   136      }
   137    }
   138    for (auto *srcView : src.getInputs()) { // R
   139      // RAR graph
   140      for (auto *dstView : dst.getInputs()) {  // R
   141        if (aliases.alias(srcView, dstView)) { // if alias, fill RAR
   142          addDependenceElem(DependenceType::RAR,
   143                            LinalgOpView{src.getOperation(), srcView},
   144                            LinalgOpView{dst.getOperation(), dstView});
   145        }
   146      }
   147      // WAR graph
   148      for (auto *dstView : dst.getOutputs()) { // W
   149        if (aliases.alias(srcView, dstView)) { // if alias, fill WAR
   150          addDependenceElem(DependenceType::WAR,
   151                            LinalgOpView{src.getOperation(), srcView},
   152                            LinalgOpView{dst.getOperation(), dstView});
   153        }
   154      }
   155    }
   156  }
   157  
   158  SmallVector<Operation *, 8>
   159  LinalgDependenceGraph::findCoveringDependences(LinalgOp srcLinalgOp,
   160                                                 LinalgOp dstLinalgOp) {
   161    return findOperationsWithCoveringDependences(
   162        srcLinalgOp, dstLinalgOp, nullptr,
   163        {DependenceType::WAW, DependenceType::WAR, DependenceType::RAW});
   164  }
   165  
   166  SmallVector<Operation *, 8>
   167  LinalgDependenceGraph::findCoveringWrites(LinalgOp srcLinalgOp,
   168                                            LinalgOp dstLinalgOp, Value *view) {
   169    return findOperationsWithCoveringDependences(
   170        srcLinalgOp, dstLinalgOp, view,
   171        {DependenceType::WAW, DependenceType::WAR});
   172  }
   173  
   174  SmallVector<Operation *, 8>
   175  LinalgDependenceGraph::findCoveringReads(LinalgOp srcLinalgOp,
   176                                           LinalgOp dstLinalgOp, Value *view) {
   177    return findOperationsWithCoveringDependences(
   178        srcLinalgOp, dstLinalgOp, view,
   179        {DependenceType::RAR, DependenceType::RAW});
   180  }
   181  
   182  SmallVector<Operation *, 8>
   183  LinalgDependenceGraph::findOperationsWithCoveringDependences(
   184      LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, Value *view,
   185      ArrayRef<DependenceType> types) {
   186    auto *src = srcLinalgOp.getOperation();
   187    auto *dst = dstLinalgOp.getOperation();
   188    auto srcPos = linalgOpPositions[src];
   189    auto dstPos = linalgOpPositions[dst];
   190    assert(srcPos < dstPos && "expected dst after src in IR traversal order");
   191  
   192    SmallVector<Operation *, 8> res;
   193    // Consider an intermediate interleaved `interim` op, look for any dependence
   194    // to an aliasing view on a src -> op -> dst path.
   195    // TODO(ntv) we are not considering paths yet, just interleaved positions.
   196    for (auto dt : types) {
   197      for (auto dependence : getDependencesFrom(src, dt)) {
   198        auto interimPos = linalgOpPositions[dependence.dependentOpView.op];
   199        // Skip if not interleaved.
   200        if (interimPos >= dstPos || interimPos <= srcPos)
   201          continue;
   202        if (view && !aliases.alias(view, dependence.indexingView))
   203          continue;
   204        auto *op = dependence.dependentOpView.op;
   205        LLVM_DEBUG(dbgs() << "\n***Found covering dependence of type " << dt
   206                          << ": " << *src << " -> " << *op << " on "
   207                          << *dependence.indexingView);
   208        res.push_back(op);
   209      }
   210    }
   211    return res;
   212  }