github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Transforms/Utils/Utils.cpp

//===- Utils.cpp ---- Misc utilities for code and data transformation -----===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements miscellaneous transformation routines for non-loop IR
// structures.
//
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/Utils.h"

#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/AffineStructures.h"
#include "mlir/Analysis/Dominance.h"
#include "mlir/Analysis/Utils.h"
#include "mlir/Dialect/AffineOps/AffineOps.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Module.h"
#include "mlir/Support/MathExtras.h"
#include "llvm/ADT/DenseMap.h"

using namespace mlir;

/// Returns true if this operation dereferences one or more memrefs.
// Temporary utility: will be replaced when this is modeled through
// side-effects/op traits. TODO(b/117228571)
static bool isMemRefDereferencingOp(Operation &op) {
  return isa<AffineLoadOp>(op) || isa<AffineStoreOp>(op) ||
         isa<AffineDmaStartOp>(op) || isa<AffineDmaWaitOp>(op);
}
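
// For illustration: `affine.load %A[%i]` dereferences %A and is handled by the
// replacement utilities below, whereas an op that merely forwards the memref
// value (e.g., passing it to a call or returning it) does not dereference it,
// and replacement is refused for such uses.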

/// Returns the AffineMapAttr associated with the memory op `op`'s access on
/// `memref`.
static NamedAttribute getAffineMapAttrForMemRef(Operation *op, Value *memref) {
  if (auto loadOp = dyn_cast<AffineLoadOp>(op))
    return loadOp.getAffineMapAttrForMemRef(memref);
  if (auto storeOp = dyn_cast<AffineStoreOp>(op))
    return storeOp.getAffineMapAttrForMemRef(memref);
  if (auto dmaStart = dyn_cast<AffineDmaStartOp>(op))
    return dmaStart.getAffineMapAttrForMemRef(memref);
  assert(isa<AffineDmaWaitOp>(op));
  return cast<AffineDmaWaitOp>(op).getAffineMapAttrForMemRef(memref);
}

// Performs the replacement of `oldMemRef` with `newMemRef` in the single
// operation `op`: a new operation with the updated memref, access map, and
// operands is created in its place, and `op` is erased.
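// Illustrative example (based on the assertions below, with hypothetical SSA
// names): to rewrite `affine.load %A[%i, %j]` into
// `affine.load %Abuf[%t mod 2, %ii - %i, %j]`, pass the SSA value of
// `%t mod 2` in `extraIndices`, `%ii` in `extraOperands`, and
// indexRemap = (d0, d1, d2) -> (d0 - d1, d2): the remap's inputs are
// `extraOperands` followed by the old memref indices, and the new memref's
// indices are `extraIndices` followed by the remap's results.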
LogicalResult mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef,
                                             Operation *op,
                                             ArrayRef<Value *> extraIndices,
                                             AffineMap indexRemap,
                                             ArrayRef<Value *> extraOperands) {
  unsigned newMemRefRank = newMemRef->getType().cast<MemRefType>().getRank();
  (void)newMemRefRank; // unused in opt mode
  unsigned oldMemRefRank = oldMemRef->getType().cast<MemRefType>().getRank();
  (void)oldMemRefRank;
  if (indexRemap) {
    assert(indexRemap.getNumSymbols() == 0 && "pure dimensional map expected");
    assert(indexRemap.getNumInputs() == extraOperands.size() + oldMemRefRank);
    assert(indexRemap.getNumResults() + extraIndices.size() == newMemRefRank);
  } else {
    assert(oldMemRefRank + extraIndices.size() == newMemRefRank);
  }

  // Assert same elemental type.
  assert(oldMemRef->getType().cast<MemRefType>().getElementType() ==
         newMemRef->getType().cast<MemRefType>().getElementType());

  if (!isMemRefDereferencingOp(*op))
    // Failure: memref used in a non-dereferencing context (potentially
    // escapes); no replacement in these cases.
    return failure();

  SmallVector<unsigned, 2> usePositions;
  for (const auto &opEntry : llvm::enumerate(op->getOperands())) {
    if (opEntry.value() == oldMemRef)
      usePositions.push_back(opEntry.index());
  }

  // If the memref doesn't appear, there is nothing to do.
  if (usePositions.empty())
    return success();

  if (usePositions.size() > 1) {
    // TODO(mlir-team): extend it for this case when needed (rare).
    assert(false && "multiple dereferencing uses in a single op not supported");
    return failure();
  }

  unsigned memRefOperandPos = usePositions.front();

  OpBuilder builder(op);
  NamedAttribute oldMapAttrPair = getAffineMapAttrForMemRef(op, oldMemRef);
  AffineMap oldMap = oldMapAttrPair.second.cast<AffineMapAttr>().getValue();
  unsigned oldMapNumInputs = oldMap.getNumInputs();
  SmallVector<Value *, 4> oldMapOperands(
      op->operand_begin() + memRefOperandPos + 1,
      op->operand_begin() + memRefOperandPos + 1 + oldMapNumInputs);
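  // For example (illustrative): for `affine.load %A[%i + %j]`, `oldMap` is
  // (d0, d1) -> (d0 + d1) and `oldMapOperands` is {%i, %j}; a map's inputs sit
  // immediately after the memref in the op's operand list.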

  // Apply 'oldMemRefOperands = oldMap(oldMapOperands)'.
  SmallVector<Value *, 4> oldMemRefOperands;
  SmallVector<Value *, 4> affineApplyOps;
  oldMemRefOperands.reserve(oldMemRefRank);
  if (oldMap != builder.getMultiDimIdentityMap(oldMap.getNumDims())) {
    for (auto resultExpr : oldMap.getResults()) {
      auto singleResMap = builder.getAffineMap(
          oldMap.getNumDims(), oldMap.getNumSymbols(), resultExpr);
      auto afOp = builder.create<AffineApplyOp>(op->getLoc(), singleResMap,
                                                oldMapOperands);
      oldMemRefOperands.push_back(afOp);
      affineApplyOps.push_back(afOp);
    }
  } else {
    oldMemRefOperands.append(oldMapOperands.begin(), oldMapOperands.end());
  }
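  // E.g. (illustrative), a two-result map (d0, d1) -> (d0 + d1, d0 mod 2) is
  // materialized as two single-result affine.apply ops, one per result
  // expression, each taking the full operand list {%i, %j}.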

  // Construct new indices as a remap of the old ones if a remapping has been
  // provided. The indices of a memref come right after it, i.e.,
  // at position memRefOperandPos + 1.
  SmallVector<Value *, 4> remapOperands;
  remapOperands.reserve(extraOperands.size() + oldMemRefRank);
  remapOperands.append(extraOperands.begin(), extraOperands.end());
  remapOperands.append(oldMemRefOperands.begin(), oldMemRefOperands.end());

  SmallVector<Value *, 4> remapOutputs;
  remapOutputs.reserve(oldMemRefRank);

  if (indexRemap &&
      indexRemap != builder.getMultiDimIdentityMap(indexRemap.getNumDims())) {
    // Remapped indices.
    for (auto resultExpr : indexRemap.getResults()) {
      auto singleResMap = builder.getAffineMap(
          indexRemap.getNumDims(), indexRemap.getNumSymbols(), resultExpr);
      auto afOp = builder.create<AffineApplyOp>(op->getLoc(), singleResMap,
                                                remapOperands);
      remapOutputs.push_back(afOp);
      affineApplyOps.push_back(afOp);
    }
  } else {
    // No remapping specified.
    remapOutputs.append(remapOperands.begin(), remapOperands.end());
  }
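  // Continuing the earlier illustrative example: with extraOperands = {%ii},
  // old indices (%i, %j), and indexRemap = (d0, d1, d2) -> (d0 - d1, d2), the
  // remap inputs are (%ii, %i, %j) and remapOutputs becomes {%ii - %i, %j}.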

  SmallVector<Value *, 4> newMapOperands;
  newMapOperands.reserve(newMemRefRank);

  // Prepend 'extraIndices' in 'newMapOperands'.
  for (auto *extraIndex : extraIndices) {
    assert(extraIndex->getDefiningOp()->getNumResults() == 1 &&
           "single-result ops expected to generate these indices");
    assert((isValidDim(extraIndex) || isValidSymbol(extraIndex)) &&
           "invalid memory op index");
    newMapOperands.push_back(extraIndex);
  }

  // Append 'remapOutputs' to 'newMapOperands'.
  newMapOperands.append(remapOutputs.begin(), remapOutputs.end());

  // Create a new fully composed AffineMap for the new op to be created.
  assert(newMapOperands.size() == newMemRefRank);
  auto newMap = builder.getMultiDimIdentityMap(newMemRefRank);
  // TODO(b/136262594) Avoid creating/deleting temporary AffineApplyOps here.
  fullyComposeAffineMapAndOperands(&newMap, &newMapOperands);
  newMap = simplifyAffineMap(newMap);
  canonicalizeMapAndOperands(&newMap, &newMapOperands);
  // Remove any affine.apply ops that became dead as a result of composition.
  for (auto *value : affineApplyOps)
    if (value->use_empty())
      value->getDefiningOp()->erase();

  // Construct the new operation using this memref.
  OperationState state(op->getLoc(), op->getName());
  state.setOperandListToResizable(op->hasResizableOperandsList());
  state.operands.reserve(op->getNumOperands() + extraIndices.size());
  // Insert the non-memref operands.
  state.operands.append(op->operand_begin(),
                        op->operand_begin() + memRefOperandPos);
  // Insert the new memref value.
  state.operands.push_back(newMemRef);

  // Insert the new memref map operands.
  state.operands.append(newMapOperands.begin(), newMapOperands.end());

  // Insert the remaining operands unmodified.
  state.operands.append(op->operand_begin() + memRefOperandPos + 1 +
                            oldMapNumInputs,
                        op->operand_end());

  // Result types don't change. Both memrefs are of the same elemental type.
  state.types.reserve(op->getNumResults());
  for (auto *result : op->getResults())
    state.types.push_back(result->getType());

  // Add the attribute for 'newMap'; other attributes do not change.
  auto newMapAttr = builder.getAffineMapAttr(newMap);
  for (auto namedAttr : op->getAttrs()) {
    if (namedAttr.first == oldMapAttrPair.first) {
      state.attributes.push_back({namedAttr.first, newMapAttr});
    } else {
      state.attributes.push_back(namedAttr);
    }
  }

  // Create the new operation.
  auto *repOp = builder.createOperation(state);
  op->replaceAllUsesWith(repOp);
  op->erase();

  return success();
}

LogicalResult mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef,
                                             ArrayRef<Value *> extraIndices,
                                             AffineMap indexRemap,
                                             ArrayRef<Value *> extraOperands,
                                             Operation *domInstFilter,
                                             Operation *postDomInstFilter) {
  unsigned newMemRefRank = newMemRef->getType().cast<MemRefType>().getRank();
  (void)newMemRefRank; // unused in opt mode
  unsigned oldMemRefRank = oldMemRef->getType().cast<MemRefType>().getRank();
  (void)oldMemRefRank;
  if (indexRemap) {
    assert(indexRemap.getNumSymbols() == 0 && "pure dimensional map expected");
    assert(indexRemap.getNumInputs() == extraOperands.size() + oldMemRefRank);
    assert(indexRemap.getNumResults() + extraIndices.size() == newMemRefRank);
  } else {
    assert(oldMemRefRank + extraIndices.size() == newMemRefRank);
  }

  // Assert same elemental type.
  assert(oldMemRef->getType().cast<MemRefType>().getElementType() ==
         newMemRef->getType().cast<MemRefType>().getElementType());

  std::unique_ptr<DominanceInfo> domInfo;
  std::unique_ptr<PostDominanceInfo> postDomInfo;
  if (domInstFilter)
    domInfo = std::make_unique<DominanceInfo>(
        domInstFilter->getParentOfType<FuncOp>());

  if (postDomInstFilter)
    postDomInfo = std::make_unique<PostDominanceInfo>(
        postDomInstFilter->getParentOfType<FuncOp>());

  // Walk all uses of the old memref and collect the ops to rewrite. We use a
  // DenseSet since an operation could potentially have multiple uses of a
  // memref (although rare), and the replacement later is going to erase ops.
  DenseSet<Operation *> opsToReplace;
  for (auto *op : oldMemRef->getUsers()) {
    // Skip this use if it's not dominated by domInstFilter.
    if (domInstFilter && !domInfo->dominates(domInstFilter, op))
      continue;

    // Skip this use if it's not post-dominated by postDomInstFilter.
    if (postDomInstFilter && !postDomInfo->postDominates(postDomInstFilter, op))
      continue;

    // Skip deallocs: no replacement is necessary for them, and replacing the
    // memref at its other uses doesn't affect these deallocs.
    if (isa<DeallocOp>(op))
      continue;

    // Check if the memref was used in a non-dereferencing context. It is fine
    // for the memref to be used in a non-dereferencing way outside of the
    // region where this replacement is happening.
    if (!isMemRefDereferencingOp(*op))
      // Failure: memref used in a non-dereferencing op (potentially escapes);
      // no replacement in these cases.
      return failure();

    // Collect first and replace afterwards, since replacement erases the op
    // that holds the use, and that op could be the domInstFilter or
    // postDomInstFilter itself!
    opsToReplace.insert(op);
  }

  for (auto *op : opsToReplace) {
    if (failed(replaceAllMemRefUsesWith(oldMemRef, newMemRef, op, extraIndices,
                                        indexRemap, extraOperands)))
      assert(false && "memref replacement guaranteed to succeed here");
  }

  return success();
}
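
// Illustrative usage (a sketch with hypothetical values, not code from this
// file): replacing every use of a buffer `A` with a double buffer `Abuf` whose
// leading dimension is selected by `tModTwo` (an SSA value computing
// `%t mod 2`) could look like:
//
//   if (failed(replaceAllMemRefUsesWith(A, Abuf, /*extraIndices=*/{tModTwo},
//                                       /*indexRemap=*/AffineMap(),
//                                       /*extraOperands=*/{})))
//     return failure(); // e.g., `A` escapes via a non-dereferencing use.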

/// Given an operation, inserts one or more single-result affine apply
/// operations, whose results are exclusively used by this operation. The
/// operands of these newly created affine apply ops are guaranteed to be loop
/// iterators or terminal symbols of a function.
///
/// Before
///
/// affine.for %i = 0 to #map(%N)
///   %idx = affine.apply (d0) -> (d0 mod 2) (%i)
///   "send"(%idx, %A, ...)
///   "compute"(%idx)
///
/// After
///
/// affine.for %i = 0 to #map(%N)
///   %idx = affine.apply (d0) -> (d0 mod 2) (%i)
///   "send"(%idx, %A, ...)
///   %idx_ = affine.apply (d0) -> (d0 mod 2) (%i)
///   "compute"(%idx_)
///
/// This allows applying different transformations on send and compute (e.g.,
/// different shifts/delays).
///
/// Leaves `sliceOps` empty either if none of opInst's operands were the result
/// of an affine.apply (in which case there is no affine computation slice to
/// create), or if all of the affine.apply ops supplying opInst's operands have
/// no uses besides opInst itself; otherwise, the affine.apply operations
/// created are returned in the output argument `sliceOps`.
void mlir::createAffineComputationSlice(
    Operation *opInst, SmallVectorImpl<AffineApplyOp> *sliceOps) {
  // Collect all operands that are results of affine apply ops.
  SmallVector<Value *, 4> subOperands;
  subOperands.reserve(opInst->getNumOperands());
  for (auto *operand : opInst->getOperands())
    if (isa_and_nonnull<AffineApplyOp>(operand->getDefiningOp()))
      subOperands.push_back(operand);

  // Gather the sequence of AffineApplyOps reachable from 'subOperands'.
  SmallVector<Operation *, 4> affineApplyOps;
  getReachableAffineApplyOps(subOperands, affineApplyOps);
  // Skip transforming if there are no affine maps to compose.
  if (affineApplyOps.empty())
    return;

  // Check if all uses of the affine apply ops lie only in this op, in which
  // case there would be nothing to do.
  bool localized = true;
  for (auto *op : affineApplyOps) {
    for (auto *result : op->getResults()) {
      for (auto *user : result->getUsers()) {
        if (user != opInst) {
          localized = false;
          break;
        }
      }
    }
  }
  if (localized)
    return;

  OpBuilder builder(opInst);
  SmallVector<Value *, 4> composedOpOperands(subOperands);
  auto composedMap = builder.getMultiDimIdentityMap(composedOpOperands.size());
  fullyComposeAffineMapAndOperands(&composedMap, &composedOpOperands);

  // Create an affine.apply for each of the map results.
  sliceOps->reserve(composedMap.getNumResults());
  for (auto resultExpr : composedMap.getResults()) {
    auto singleResMap = builder.getAffineMap(
        composedMap.getNumDims(), composedMap.getNumSymbols(), resultExpr);
    sliceOps->push_back(builder.create<AffineApplyOp>(
        opInst->getLoc(), singleResMap, composedOpOperands));
  }

  // Construct the new operands, using the results of the composed affine apply
  // ops above in place of the existing ones (subOperands). They differ from
  // opInst's operands only at positions that held a value from 'subOperands';
  // each such operand is replaced by the corresponding result in 'sliceOps'.
  SmallVector<Value *, 4> newOperands(opInst->getOperands());
  for (unsigned i = 0, e = newOperands.size(); i < e; i++) {
    // Replace the subOperands from among the new operands.
    unsigned j, f;
    for (j = 0, f = subOperands.size(); j < f; j++) {
      if (newOperands[i] == subOperands[j])
        break;
    }
    if (j < subOperands.size()) {
      newOperands[i] = (*sliceOps)[j];
    }
  }
  for (unsigned idx = 0, e = newOperands.size(); idx < e; idx++) {
    opInst->setOperand(idx, newOperands[idx]);
  }
}