github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp (about)

     1  //===- LoopFusionUtils.cpp ---- Utilities for loop fusion ----------===//
     2  //
     3  // Copyright 2019 The MLIR Authors.
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //   http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  // =============================================================================
    17  //
    18  // This file implements loop fusion transformation utility functions.
    19  //
    20  //===----------------------------------------------------------------------===//
    21  
    22  #include "mlir/Transforms/LoopFusionUtils.h"
    23  
    24  #include "mlir/Analysis/AffineAnalysis.h"
    25  #include "mlir/Analysis/AffineStructures.h"
    26  #include "mlir/Analysis/LoopAnalysis.h"
    27  #include "mlir/Analysis/Utils.h"
    28  #include "mlir/Dialect/AffineOps/AffineOps.h"
    29  #include "mlir/Dialect/StandardOps/Ops.h"
    30  #include "mlir/IR/AffineExpr.h"
    31  #include "mlir/IR/AffineMap.h"
    32  #include "mlir/IR/BlockAndValueMapping.h"
    33  #include "mlir/IR/Builders.h"
    34  #include "mlir/IR/Function.h"
    35  #include "mlir/IR/Operation.h"
    36  #include "llvm/ADT/DenseMap.h"
    37  #include "llvm/ADT/SmallVector.h"
    38  #include "llvm/Support/Debug.h"
    39  #include "llvm/Support/raw_ostream.h"
    40  
    41  #define DEBUG_TYPE "loop-fusion-utils"
    42  
    43  using namespace mlir;
    44  
    45  // Gathers all load and store memref accesses in 'opA' into 'values', where
    46  // 'values[memref] == true' for each store operation.
    47  static void getLoadAndStoreMemRefAccesses(Operation *opA,
    48                                            DenseMap<Value *, bool> &values) {
    49    opA->walk([&](Operation *op) {
    50      if (auto loadOp = dyn_cast<AffineLoadOp>(op)) {
    51        if (values.count(loadOp.getMemRef()) == 0)
    52          values[loadOp.getMemRef()] = false;
    53      } else if (auto storeOp = dyn_cast<AffineStoreOp>(op)) {
    54        values[storeOp.getMemRef()] = true;
    55      }
    56    });
    57  }
    58  
    59  // Returns true if 'op' is a load or store operation which access an memref
    60  // accessed 'values' and at least one of the access is a store operation.
    61  // Returns false otherwise.
    62  static bool isDependentLoadOrStoreOp(Operation *op,
    63                                       DenseMap<Value *, bool> &values) {
    64    if (auto loadOp = dyn_cast<AffineLoadOp>(op)) {
    65      return values.count(loadOp.getMemRef()) > 0 &&
    66             values[loadOp.getMemRef()] == true;
    67    } else if (auto storeOp = dyn_cast<AffineStoreOp>(op)) {
    68      return values.count(storeOp.getMemRef()) > 0;
    69    }
    70    return false;
    71  }
    72  
    73  // Returns the first operation in range ('opA', 'opB') which has a data
    74  // dependence on 'opA'. Returns 'nullptr' of no dependence exists.
    75  static Operation *getFirstDependentOpInRange(Operation *opA, Operation *opB) {
    76    // Record memref values from all loads/store in loop nest rooted at 'opA'.
    77    // Map from memref value to bool which is true if store, false otherwise.
    78    DenseMap<Value *, bool> values;
    79    getLoadAndStoreMemRefAccesses(opA, values);
    80  
    81    // For each 'opX' in block in range ('opA', 'opB'), check if there is a data
    82    // dependence from 'opA' to 'opX' ('opA' and 'opX' access the same memref
    83    // and at least one of the accesses is a store).
    84    Operation *firstDepOp = nullptr;
    85    for (Block::iterator it = std::next(Block::iterator(opA));
    86         it != Block::iterator(opB); ++it) {
    87      Operation *opX = &(*it);
    88      opX->walk([&](Operation *op) {
    89        if (!firstDepOp && isDependentLoadOrStoreOp(op, values))
    90          firstDepOp = opX;
    91      });
    92      if (firstDepOp)
    93        break;
    94    }
    95    return firstDepOp;
    96  }
    97  
    98  // Returns the last operation 'opX' in range ('opA', 'opB'), for which there
    99  // exists a data dependence from 'opX' to 'opB'.
   100  // Returns 'nullptr' of no dependence exists.
   101  static Operation *getLastDependentOpInRange(Operation *opA, Operation *opB) {
   102    // Record memref values from all loads/store in loop nest rooted at 'opB'.
   103    // Map from memref value to bool which is true if store, false otherwise.
   104    DenseMap<Value *, bool> values;
   105    getLoadAndStoreMemRefAccesses(opB, values);
   106  
   107    // For each 'opX' in block in range ('opA', 'opB') in reverse order,
   108    // check if there is a data dependence from 'opX' to 'opB':
   109    // *) 'opX' and 'opB' access the same memref and at least one of the accesses
   110    //    is a store.
   111    // *) 'opX' produces an SSA Value which is used by 'opB'.
   112    Operation *lastDepOp = nullptr;
   113    for (Block::reverse_iterator it = std::next(Block::reverse_iterator(opB));
   114         it != Block::reverse_iterator(opA); ++it) {
   115      Operation *opX = &(*it);
   116      opX->walk([&](Operation *op) {
   117        if (isa<AffineLoadOp>(op) || isa<AffineStoreOp>(op)) {
   118          if (isDependentLoadOrStoreOp(op, values)) {
   119            lastDepOp = opX;
   120            return WalkResult::interrupt();
   121          }
   122          return WalkResult::advance();
   123        }
   124        for (auto *value : op->getResults()) {
   125          for (auto *user : value->getUsers()) {
   126            SmallVector<AffineForOp, 4> loops;
   127            // Check if any loop in loop nest surrounding 'user' is 'opB'.
   128            getLoopIVs(*user, &loops);
   129            if (llvm::is_contained(loops, cast<AffineForOp>(opB))) {
   130              lastDepOp = opX;
   131              return WalkResult::interrupt();
   132            }
   133          }
   134        }
   135        return WalkResult::advance();
   136      });
   137      if (lastDepOp)
   138        break;
   139    }
   140    return lastDepOp;
   141  }
   142  
   143  // Computes and returns an insertion point operation, before which the
   144  // the fused <srcForOp, dstForOp> loop nest can be inserted while preserving
   145  // dependences. Returns nullptr if no such insertion point is found.
   146  static Operation *getFusedLoopNestInsertionPoint(AffineForOp srcForOp,
   147                                                   AffineForOp dstForOp) {
   148    bool isSrcForOpBeforeDstForOp =
   149        srcForOp.getOperation()->isBeforeInBlock(dstForOp.getOperation());
   150    auto forOpA = isSrcForOpBeforeDstForOp ? srcForOp : dstForOp;
   151    auto forOpB = isSrcForOpBeforeDstForOp ? dstForOp : srcForOp;
   152  
   153    auto *firstDepOpA =
   154        getFirstDependentOpInRange(forOpA.getOperation(), forOpB.getOperation());
   155    auto *lastDepOpB =
   156        getLastDependentOpInRange(forOpA.getOperation(), forOpB.getOperation());
   157    // Block:
   158    //      ...
   159    //  |-- opA
   160    //  |   ...
   161    //  |   lastDepOpB --|
   162    //  |   ...          |
   163    //  |-> firstDepOpA  |
   164    //      ...          |
   165    //      opB <---------
   166    //
   167    // Valid insertion point range: (lastDepOpB, firstDepOpA)
   168    //
   169    if (firstDepOpA != nullptr) {
   170      if (lastDepOpB != nullptr) {
   171        if (firstDepOpA->isBeforeInBlock(lastDepOpB) || firstDepOpA == lastDepOpB)
   172          // No valid insertion point exists which preserves dependences.
   173          return nullptr;
   174      }
   175      // Return insertion point in valid range closest to 'opB'.
   176      // TODO(andydavis) Consider other insertion points in valid range.
   177      return firstDepOpA;
   178    }
   179    // No dependences from 'opA' to operation in range ('opA', 'opB'), return
   180    // 'opB' insertion point.
   181    return forOpB.getOperation();
   182  }
   183  
   184  // Gathers all load and store ops in loop nest rooted at 'forOp' into
   185  // 'loadAndStoreOps'.
   186  static bool
   187  gatherLoadsAndStores(AffineForOp forOp,
   188                       SmallVectorImpl<Operation *> &loadAndStoreOps) {
   189    bool hasIfOp = false;
   190    forOp.getOperation()->walk([&](Operation *op) {
   191      if (isa<AffineLoadOp>(op) || isa<AffineStoreOp>(op))
   192        loadAndStoreOps.push_back(op);
   193      else if (isa<AffineIfOp>(op))
   194        hasIfOp = true;
   195    });
   196    return !hasIfOp;
   197  }
   198  
   199  // TODO(andydavis) Prevent fusion of loop nests with side-effecting operations.
   200  FusionResult mlir::canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
   201                                  unsigned dstLoopDepth,
   202                                  ComputationSliceState *srcSlice) {
   203    // Return 'failure' if 'dstLoopDepth == 0'.
   204    if (dstLoopDepth == 0) {
   205      LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests at depth 0\n.");
   206      return FusionResult::FailPrecondition;
   207    }
   208    // Return 'failure' if 'srcForOp' and 'dstForOp' are not in the same block.
   209    auto *block = srcForOp.getOperation()->getBlock();
   210    if (block != dstForOp.getOperation()->getBlock()) {
   211      LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests in different blocks\n.");
   212      return FusionResult::FailPrecondition;
   213    }
   214  
   215    // Return 'failure' if no valid insertion point for fused loop nest in 'block'
   216    // exists which would preserve dependences.
   217    if (!getFusedLoopNestInsertionPoint(srcForOp, dstForOp)) {
   218      LLVM_DEBUG(llvm::dbgs() << "Fusion would violate dependences in block\n.");
   219      return FusionResult::FailBlockDependence;
   220    }
   221  
   222    // Check if 'srcForOp' precedeces 'dstForOp' in 'block'.
   223    bool isSrcForOpBeforeDstForOp =
   224        srcForOp.getOperation()->isBeforeInBlock(dstForOp.getOperation());
   225    // 'forOpA' executes before 'forOpB' in 'block'.
   226    auto forOpA = isSrcForOpBeforeDstForOp ? srcForOp : dstForOp;
   227    auto forOpB = isSrcForOpBeforeDstForOp ? dstForOp : srcForOp;
   228  
   229    // Gather all load and store from 'forOpA' which precedes 'forOpB' in 'block'.
   230    SmallVector<Operation *, 4> opsA;
   231    if (!gatherLoadsAndStores(forOpA, opsA)) {
   232      LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported.\n.");
   233      return FusionResult::FailPrecondition;
   234    }
   235  
   236    // Gather all load and store from 'forOpB' which succeeds 'forOpA' in 'block'.
   237    SmallVector<Operation *, 4> opsB;
   238    if (!gatherLoadsAndStores(forOpB, opsB)) {
   239      LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported.\n.");
   240      return FusionResult::FailPrecondition;
   241    }
   242  
   243    // Calculate the number of common loops surrounding 'srcForOp' and 'dstForOp'.
   244    unsigned numCommonLoops = mlir::getNumCommonSurroundingLoops(
   245        *srcForOp.getOperation(), *dstForOp.getOperation());
   246  
   247    // Compute union of computation slices computed between all pairs of ops
   248    // from 'forOpA' and 'forOpB'.
   249    if (failed(mlir::computeSliceUnion(opsA, opsB, dstLoopDepth, numCommonLoops,
   250                                       isSrcForOpBeforeDstForOp, srcSlice))) {
   251      LLVM_DEBUG(llvm::dbgs() << "computeSliceUnion failed\n");
   252      return FusionResult::FailPrecondition;
   253    }
   254  
   255    return FusionResult::Success;
   256  }
   257  
   258  /// Collect loop nest statistics (eg. loop trip count and operation count)
   259  /// in 'stats' for loop nest rooted at 'forOp'. Returns true on success,
   260  /// returns false otherwise.
   261  bool mlir::getLoopNestStats(AffineForOp forOpRoot, LoopNestStats *stats) {
   262    auto walkResult = forOpRoot.walk([&](AffineForOp forOp) {
   263      auto *childForOp = forOp.getOperation();
   264      auto *parentForOp = forOp.getOperation()->getParentOp();
   265      if (!llvm::isa<FuncOp>(parentForOp)) {
   266        if (!isa<AffineForOp>(parentForOp)) {
   267          LLVM_DEBUG(llvm::dbgs() << "Expected parent AffineForOp");
   268          return WalkResult::interrupt();
   269        }
   270        // Add mapping to 'forOp' from its parent AffineForOp.
   271        stats->loopMap[parentForOp].push_back(forOp);
   272      }
   273  
   274      // Record the number of op operations in the body of 'forOp'.
   275      unsigned count = 0;
   276      stats->opCountMap[childForOp] = 0;
   277      for (auto &op : *forOp.getBody()) {
   278        if (!isa<AffineForOp>(op) && !isa<AffineIfOp>(op))
   279          ++count;
   280      }
   281      stats->opCountMap[childForOp] = count;
   282  
   283      // Record trip count for 'forOp'. Set flag if trip count is not
   284      // constant.
   285      Optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp);
   286      if (!maybeConstTripCount.hasValue()) {
   287        // Currently only constant trip count loop nests are supported.
   288        LLVM_DEBUG(llvm::dbgs() << "Non-constant trip count unsupported");
   289        return WalkResult::interrupt();
   290      }
   291  
   292      stats->tripCountMap[childForOp] = maybeConstTripCount.getValue();
   293      return WalkResult::advance();
   294    });
   295    return !walkResult.wasInterrupted();
   296  }
   297  
   298  // Computes the total cost of the loop nest rooted at 'forOp'.
   299  // Currently, the total cost is computed by counting the total operation
   300  // instance count (i.e. total number of operations in the loop bodyloop
   301  // operation count * loop trip count) for the entire loop nest.
   302  // If 'tripCountOverrideMap' is non-null, overrides the trip count for loops
   303  // specified in the map when computing the total op instance count.
   304  // NOTEs: 1) This is used to compute the cost of computation slices, which are
   305  // sliced along the iteration dimension, and thus reduce the trip count.
   306  // If 'computeCostMap' is non-null, the total op count for forOps specified
   307  // in the map is increased (not overridden) by adding the op count from the
   308  // map to the existing op count for the for loop. This is done before
   309  // multiplying by the loop's trip count, and is used to model the cost of
   310  // inserting a sliced loop nest of known cost into the loop's body.
   311  // 2) This is also used to compute the cost of fusing a slice of some loop nest
   312  // within another loop.
   313  static int64_t getComputeCostHelper(
   314      Operation *forOp, LoopNestStats &stats,
   315      llvm::SmallDenseMap<Operation *, uint64_t, 8> *tripCountOverrideMap,
   316      DenseMap<Operation *, int64_t> *computeCostMap) {
   317    // 'opCount' is the total number operations in one iteration of 'forOp' body,
   318    // minus terminator op which is a no-op.
   319    int64_t opCount = stats.opCountMap[forOp] - 1;
   320    if (stats.loopMap.count(forOp) > 0) {
   321      for (auto childForOp : stats.loopMap[forOp]) {
   322        opCount += getComputeCostHelper(childForOp.getOperation(), stats,
   323                                        tripCountOverrideMap, computeCostMap);
   324      }
   325    }
   326    // Add in additional op instances from slice (if specified in map).
   327    if (computeCostMap != nullptr) {
   328      auto it = computeCostMap->find(forOp);
   329      if (it != computeCostMap->end()) {
   330        opCount += it->second;
   331      }
   332    }
   333    // Override trip count (if specified in map).
   334    int64_t tripCount = stats.tripCountMap[forOp];
   335    if (tripCountOverrideMap != nullptr) {
   336      auto it = tripCountOverrideMap->find(forOp);
   337      if (it != tripCountOverrideMap->end()) {
   338        tripCount = it->second;
   339      }
   340    }
   341    // Returns the total number of dynamic instances of operations in loop body.
   342    return tripCount * opCount;
   343  }
   344  
   345  // TODO(andydavis,b/126426796): extend this to handle multiple result maps.
   346  static Optional<uint64_t> getConstDifference(AffineMap lbMap, AffineMap ubMap) {
   347    assert(lbMap.getNumResults() == 1 && "expected single result bound map");
   348    assert(ubMap.getNumResults() == 1 && "expected single result bound map");
   349    assert(lbMap.getNumDims() == ubMap.getNumDims());
   350    assert(lbMap.getNumSymbols() == ubMap.getNumSymbols());
   351    AffineExpr lbExpr(lbMap.getResult(0));
   352    AffineExpr ubExpr(ubMap.getResult(0));
   353    auto loopSpanExpr = simplifyAffineExpr(ubExpr - lbExpr, lbMap.getNumDims(),
   354                                           lbMap.getNumSymbols());
   355    auto cExpr = loopSpanExpr.dyn_cast<AffineConstantExpr>();
   356    if (!cExpr)
   357      return None;
   358    return cExpr.getValue();
   359  }
   360  
   361  // Return the number of iterations in the given slice.
   362  static uint64_t getSliceIterationCount(
   363      const llvm::SmallDenseMap<Operation *, uint64_t, 8> &sliceTripCountMap) {
   364    uint64_t iterCount = 1;
   365    for (const auto &count : sliceTripCountMap) {
   366      iterCount *= count.second;
   367    }
   368    return iterCount;
   369  }
   370  
   371  // Builds a map 'tripCountMap' from AffineForOp to constant trip count for loop
   372  // nest surrounding represented by slice loop bounds in 'slice'.
   373  // Returns true on success, false otherwise (if a non-constant trip count
   374  // was encountered).
   375  // TODO(andydavis) Make this work with non-unit step loops.
   376  static bool buildSliceTripCountMap(
   377      ComputationSliceState *slice,
   378      llvm::SmallDenseMap<Operation *, uint64_t, 8> *tripCountMap) {
   379    unsigned numSrcLoopIVs = slice->ivs.size();
   380    // Populate map from AffineForOp -> trip count
   381    for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
   382      AffineForOp forOp = getForInductionVarOwner(slice->ivs[i]);
   383      auto *op = forOp.getOperation();
   384      AffineMap lbMap = slice->lbs[i];
   385      AffineMap ubMap = slice->ubs[i];
   386      if (lbMap == AffineMap() || ubMap == AffineMap()) {
   387        // The iteration of src loop IV 'i' was not sliced. Use full loop bounds.
   388        if (forOp.hasConstantLowerBound() && forOp.hasConstantUpperBound()) {
   389          (*tripCountMap)[op] =
   390              forOp.getConstantUpperBound() - forOp.getConstantLowerBound();
   391          continue;
   392        }
   393        Optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp);
   394        if (maybeConstTripCount.hasValue()) {
   395          (*tripCountMap)[op] = maybeConstTripCount.getValue();
   396          continue;
   397        }
   398        return false;
   399      }
   400      Optional<uint64_t> tripCount = getConstDifference(lbMap, ubMap);
   401      // Slice bounds are created with a constant ub - lb difference.
   402      if (!tripCount.hasValue())
   403        return false;
   404      (*tripCountMap)[op] = tripCount.getValue();
   405    }
   406    return true;
   407  }
   408  
   409  /// Computes the total cost of the loop nest rooted at 'forOp' using 'stats'.
   410  /// Currently, the total cost is computed by counting the total operation
   411  /// instance count (i.e. total number of operations in the loop body * loop
   412  /// trip count) for the entire loop nest.
   413  int64_t mlir::getComputeCost(AffineForOp forOp, LoopNestStats &stats) {
   414    return getComputeCostHelper(forOp.getOperation(), stats,
   415                                /*tripCountOverrideMap=*/nullptr,
   416                                /*computeCostMap=*/nullptr);
   417  }
   418  
   419  /// Computes and returns in 'computeCost', the total compute cost of fusing the
   420  /// 'slice' of the loop nest rooted at 'srcForOp' into 'dstForOp'. Currently,
   421  /// the total cost is computed by counting the total operation instance count
   422  /// (i.e. total number of operations in the loop body * loop trip count) for
   423  /// the entire loop nest.
   424  bool mlir::getFusionComputeCost(AffineForOp srcForOp, LoopNestStats &srcStats,
   425                                  AffineForOp dstForOp, LoopNestStats &dstStats,
   426                                  ComputationSliceState *slice,
   427                                  int64_t *computeCost) {
   428    llvm::SmallDenseMap<Operation *, uint64_t, 8> sliceTripCountMap;
   429    DenseMap<Operation *, int64_t> computeCostMap;
   430  
   431    // Build trip count map for computation slice.
   432    if (!buildSliceTripCountMap(slice, &sliceTripCountMap))
   433      return false;
   434    // Checks whether a store to load forwarding will happen.
   435    int64_t sliceIterationCount = getSliceIterationCount(sliceTripCountMap);
   436    assert(sliceIterationCount > 0);
   437    bool storeLoadFwdGuaranteed = (sliceIterationCount == 1);
   438    auto *insertPointParent = slice->insertPoint->getParentOp();
   439  
   440    // The store and loads to this memref will disappear.
   441    // TODO(andydavis) Add load coalescing to memref data flow opt pass.
   442    if (storeLoadFwdGuaranteed) {
   443      // Subtract from operation count the loads/store we expect load/store
   444      // forwarding to remove.
   445      unsigned storeCount = 0;
   446      llvm::SmallDenseSet<Value *, 4> storeMemrefs;
   447      srcForOp.getOperation()->walk([&](Operation *op) {
   448        if (auto storeOp = dyn_cast<AffineStoreOp>(op)) {
   449          storeMemrefs.insert(storeOp.getMemRef());
   450          ++storeCount;
   451        }
   452      });
   453      // Subtract out any store ops in single-iteration src slice loop nest.
   454      if (storeCount > 0)
   455        computeCostMap[insertPointParent] = -storeCount;
   456      // Subtract out any load users of 'storeMemrefs' nested below
   457      // 'insertPointParent'.
   458      for (auto *value : storeMemrefs) {
   459        for (auto *user : value->getUsers()) {
   460          if (auto loadOp = dyn_cast<AffineLoadOp>(user)) {
   461            SmallVector<AffineForOp, 4> loops;
   462            // Check if any loop in loop nest surrounding 'user' is
   463            // 'insertPointParent'.
   464            getLoopIVs(*user, &loops);
   465            if (llvm::is_contained(loops, cast<AffineForOp>(insertPointParent))) {
   466              if (auto forOp =
   467                      dyn_cast_or_null<AffineForOp>(user->getParentOp())) {
   468                if (computeCostMap.count(forOp) == 0)
   469                  computeCostMap[forOp] = 0;
   470                computeCostMap[forOp] -= 1;
   471              }
   472            }
   473          }
   474        }
   475      }
   476    }
   477  
   478    // Compute op instance count for the src loop nest with iteration slicing.
   479    int64_t sliceComputeCost = getComputeCostHelper(
   480        srcForOp.getOperation(), srcStats, &sliceTripCountMap, &computeCostMap);
   481  
   482    // Compute cost of fusion for this depth.
   483    computeCostMap[insertPointParent] = sliceComputeCost;
   484  
   485    *computeCost =
   486        getComputeCostHelper(dstForOp.getOperation(), dstStats,
   487                             /*tripCountOverrideMap=*/nullptr, &computeCostMap);
   488    return true;
   489  }