github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Analysis/SliceAnalysis.cpp (about)

     1  //===- SliceAnalysis.cpp - Analysis for Transitive UseDef chains ----------===//
     2  //
     3  // Copyright 2019 The MLIR Authors.
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //   http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  // =============================================================================
    17  //
    18  // This file implements Analysis functions specific to slicing in Function.
    19  //
    20  //===----------------------------------------------------------------------===//
    21  
    22  #include "mlir/Analysis/SliceAnalysis.h"
    23  #include "mlir/Analysis/VectorAnalysis.h"
    24  #include "mlir/Dialect/AffineOps/AffineOps.h"
    25  #include "mlir/Dialect/LoopOps/LoopOps.h"
    26  #include "mlir/IR/Function.h"
    27  #include "mlir/IR/Operation.h"
    28  #include "mlir/Support/Functional.h"
    29  #include "mlir/Support/LLVM.h"
    30  #include "mlir/Support/STLExtras.h"
    31  #include "llvm/ADT/SetVector.h"
    32  
    33  ///
    34  /// Implements Analysis functions specific to slicing in Function.
    35  ///
    36  
    37  using namespace mlir;
    38  
    39  using llvm::SetVector;
    40  
    41  static void getForwardSliceImpl(Operation *op,
    42                                  SetVector<Operation *> *forwardSlice,
    43                                  TransitiveFilter filter) {
    44    if (!op) {
    45      return;
    46    }
    47  
    48    // Evaluate whether we should keep this use.
    49    // This is useful in particular to implement scoping; i.e. return the
    50    // transitive forwardSlice in the current scope.
    51    if (!filter(op)) {
    52      return;
    53    }
    54  
    55    if (auto forOp = dyn_cast<AffineForOp>(op)) {
    56      for (auto *ownerInst : forOp.getInductionVar()->getUsers())
    57        if (forwardSlice->count(ownerInst) == 0)
    58          getForwardSliceImpl(ownerInst, forwardSlice, filter);
    59    } else if (auto forOp = dyn_cast<loop::ForOp>(op)) {
    60      for (auto *ownerInst : forOp.getInductionVar()->getUsers())
    61        if (forwardSlice->count(ownerInst) == 0)
    62          getForwardSliceImpl(ownerInst, forwardSlice, filter);
    63    } else {
    64      assert(op->getNumRegions() == 0 && "unexpected generic op with regions");
    65      assert(op->getNumResults() <= 1 && "unexpected multiple results");
    66      if (op->getNumResults() > 0) {
    67        for (auto *ownerInst : op->getResult(0)->getUsers())
    68          if (forwardSlice->count(ownerInst) == 0)
    69            getForwardSliceImpl(ownerInst, forwardSlice, filter);
    70      }
    71    }
    72  
    73    forwardSlice->insert(op);
    74  }
    75  
    76  void mlir::getForwardSlice(Operation *op, SetVector<Operation *> *forwardSlice,
    77                             TransitiveFilter filter) {
    78    getForwardSliceImpl(op, forwardSlice, filter);
    79    // Don't insert the top level operation, we just queried on it and don't
    80    // want it in the results.
    81    forwardSlice->remove(op);
    82  
    83    // Reverse to get back the actual topological order.
    84    // std::reverse does not work out of the box on SetVector and I want an
    85    // in-place swap based thing (the real std::reverse, not the LLVM adapter).
    86    std::vector<Operation *> v(forwardSlice->takeVector());
    87    forwardSlice->insert(v.rbegin(), v.rend());
    88  }
    89  
    90  static void getBackwardSliceImpl(Operation *op,
    91                                   SetVector<Operation *> *backwardSlice,
    92                                   TransitiveFilter filter) {
    93    if (!op)
    94      return;
    95  
    96    assert((op->getNumRegions() == 0 || isa<AffineForOp>(op) ||
    97            isa<loop::ForOp>(op)) &&
    98           "unexpected generic op with regions");
    99  
   100    // Evaluate whether we should keep this def.
   101    // This is useful in particular to implement scoping; i.e. return the
   102    // transitive forwardSlice in the current scope.
   103    if (!filter(op)) {
   104      return;
   105    }
   106  
   107    for (auto en : llvm::enumerate(op->getOperands())) {
   108      auto *operand = en.value();
   109      if (auto *blockArg = dyn_cast<BlockArgument>(operand)) {
   110        if (auto affIv = getForInductionVarOwner(operand)) {
   111          auto *affOp = affIv.getOperation();
   112          if (backwardSlice->count(affOp) == 0)
   113            getBackwardSliceImpl(affOp, backwardSlice, filter);
   114        } else if (auto loopIv = loop::getForInductionVarOwner(operand)) {
   115          auto *loopOp = loopIv.getOperation();
   116          if (backwardSlice->count(loopOp) == 0)
   117            getBackwardSliceImpl(loopOp, backwardSlice, filter);
   118        } else if (blockArg->getOwner() !=
   119                   &op->getParentOfType<FuncOp>().getBody().front()) {
   120          op->emitError("Unsupported CF for operand ") << en.index();
   121          llvm_unreachable("Unsupported control flow");
   122        }
   123        continue;
   124      }
   125      auto *op = operand->getDefiningOp();
   126      if (backwardSlice->count(op) == 0) {
   127        getBackwardSliceImpl(op, backwardSlice, filter);
   128      }
   129    }
   130  
   131    backwardSlice->insert(op);
   132  }
   133  
   134  void mlir::getBackwardSlice(Operation *op,
   135                              SetVector<Operation *> *backwardSlice,
   136                              TransitiveFilter filter) {
   137    getBackwardSliceImpl(op, backwardSlice, filter);
   138  
   139    // Don't insert the top level operation, we just queried on it and don't
   140    // want it in the results.
   141    backwardSlice->remove(op);
   142  }
   143  
   144  SetVector<Operation *> mlir::getSlice(Operation *op,
   145                                        TransitiveFilter backwardFilter,
   146                                        TransitiveFilter forwardFilter) {
   147    SetVector<Operation *> slice;
   148    slice.insert(op);
   149  
   150    unsigned currentIndex = 0;
   151    SetVector<Operation *> backwardSlice;
   152    SetVector<Operation *> forwardSlice;
   153    while (currentIndex != slice.size()) {
   154      auto *currentInst = (slice)[currentIndex];
   155      // Compute and insert the backwardSlice starting from currentInst.
   156      backwardSlice.clear();
   157      getBackwardSlice(currentInst, &backwardSlice, backwardFilter);
   158      slice.insert(backwardSlice.begin(), backwardSlice.end());
   159  
   160      // Compute and insert the forwardSlice starting from currentInst.
   161      forwardSlice.clear();
   162      getForwardSlice(currentInst, &forwardSlice, forwardFilter);
   163      slice.insert(forwardSlice.begin(), forwardSlice.end());
   164      ++currentIndex;
   165    }
   166    return topologicalSort(slice);
   167  }
   168  
   169  namespace {
   170  /// DFS post-order implementation that maintains a global count to work across
   171  /// multiple invocations, to help implement topological sort on multi-root DAGs.
   172  /// We traverse all operations but only record the ones that appear in
   173  /// `toSort` for the final result.
   174  struct DFSState {
   175    DFSState(const SetVector<Operation *> &set)
   176        : toSort(set), topologicalCounts(), seen() {}
   177    const SetVector<Operation *> &toSort;
   178    SmallVector<Operation *, 16> topologicalCounts;
   179    DenseSet<Operation *> seen;
   180  };
   181  } // namespace
   182  
   183  static void DFSPostorder(Operation *current, DFSState *state) {
   184    assert(current->getNumResults() <= 1 && "NYI: multi-result");
   185    if (current->getNumResults() > 0) {
   186      for (auto &u : current->getResult(0)->getUses()) {
   187        auto *op = u.getOwner();
   188        DFSPostorder(op, state);
   189      }
   190    }
   191    bool inserted;
   192    using IterTy = decltype(state->seen.begin());
   193    IterTy iter;
   194    std::tie(iter, inserted) = state->seen.insert(current);
   195    if (inserted) {
   196      if (state->toSort.count(current) > 0) {
   197        state->topologicalCounts.push_back(current);
   198      }
   199    }
   200  }
   201  
   202  SetVector<Operation *>
   203  mlir::topologicalSort(const SetVector<Operation *> &toSort) {
   204    if (toSort.empty()) {
   205      return toSort;
   206    }
   207  
   208    // Run from each root with global count and `seen` set.
   209    DFSState state(toSort);
   210    for (auto *s : toSort) {
   211      assert(toSort.count(s) == 1 && "NYI: multi-sets not supported");
   212      DFSPostorder(s, &state);
   213    }
   214  
   215    // Reorder and return.
   216    SetVector<Operation *> res;
   217    for (auto it = state.topologicalCounts.rbegin(),
   218              eit = state.topologicalCounts.rend();
   219         it != eit; ++it) {
   220      res.insert(*it);
   221    }
   222    return res;
   223  }