github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Transforms/LowerVectorTransfers.cpp

     1  //===- LowerVectorTransfers.cpp - LowerVectorTransfers Pass Impl ----------===//
     2  //
     3  // Copyright 2019 The MLIR Authors.
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //   http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  // =============================================================================
    17  //
    18  // This file implements target-dependent lowering of vector transfer operations.
    19  //
    20  //===----------------------------------------------------------------------===//
    21  
    22  #include <type_traits>
    23  
    24  #include "mlir/Analysis/AffineAnalysis.h"
    25  #include "mlir/Analysis/NestedMatcher.h"
    26  #include "mlir/Analysis/Utils.h"
    27  #include "mlir/Analysis/VectorAnalysis.h"
    28  #include "mlir/Dialect/StandardOps/Ops.h"
    29  #include "mlir/Dialect/VectorOps/VectorOps.h"
    30  #include "mlir/EDSC/Builders.h"
    31  #include "mlir/EDSC/Helpers.h"
    32  #include "mlir/IR/AffineExpr.h"
    33  #include "mlir/IR/AffineMap.h"
    34  #include "mlir/IR/Attributes.h"
    35  #include "mlir/IR/Builders.h"
    36  #include "mlir/IR/Location.h"
    37  #include "mlir/IR/Matchers.h"
    38  #include "mlir/IR/OperationSupport.h"
    39  #include "mlir/IR/PatternMatch.h"
    40  #include "mlir/IR/Types.h"
    41  #include "mlir/Pass/Pass.h"
    42  #include "mlir/Support/Functional.h"
    43  #include "mlir/Transforms/Passes.h"
    44  
    45  /// Implements lowering of VectorTransferReadOp and VectorTransferWriteOp to a
    46  /// proper abstraction for the hardware.
    47  ///
     48  /// For now, we only emit a simple loop nest that performs clipped pointwise
     49  /// copies between remote memory and a locally allocated buffer.
    50  ///
    51  /// Consider the case:
    52  ///
    53  /// ```mlir {.mlir}
    54  ///    // Read the slice `%A[%i0, %i1:%i1+256, %i2:%i2+32]` into
    55  ///    // vector<32x256xf32> and pad with %f0 to handle the boundary case:
    56  ///    %f0 = constant 0.0f : f32
    57  ///    affine.for %i0 = 0 to %0 {
    58  ///      affine.for %i1 = 0 to %1 step 256 {
    59  ///        affine.for %i2 = 0 to %2 step 32 {
    60  ///          %v = vector.transfer_read %A[%i0, %i1, %i2], (%f0)
    61  ///               {permutation_map: (d0, d1, d2) -> (d2, d1)} :
    62  ///               memref<?x?x?xf32>, vector<32x256xf32>
    63  ///    }}}
    64  /// ```
    65  ///
     66  /// The rewriters construct loops and indices that access MemRef A in a pattern
     67  /// resembling the following (while guaranteeing an always-full-tile
     68  /// abstraction):
    69  ///
    70  /// ```mlir {.mlir}
    71  ///    affine.for %d2 = 0 to 256 {
    72  ///      affine.for %d1 = 0 to 32 {
    73  ///        %s = %A[%i0, %i1 + %d1, %i2 + %d2] : f32
    74  ///        %tmp[%d2, %d1] = %s
    75  ///      }
    76  ///    }
    77  /// ```
    78  ///
    79  /// In the current state, only a clipping transfer is implemented by `clip`,
    80  /// which creates individual indexing expressions of the form:
    81  ///
    82  /// ```mlir-dsc
    83  ///    SELECT(i + ii < zero, zero, SELECT(i + ii < N, i + ii, N - one))
    84  /// ```
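         ///
         /// In terms of standard ops, each such clipped index amounts to a pair of
         /// comparisons and selects on `index` values, where `%idx` stands for `i + ii`
         /// (or just `i`) and `%N_m1` for `N - one`. A rough sketch (operand names are
         /// illustrative and op spellings approximate for this MLIR snapshot):
         ///
         /// ```mlir {.mlir}
         ///    %too_low  = cmpi "slt", %idx, %c0 : index    // idx < 0 ?
         ///    %in_range = cmpi "slt", %idx, %N : index     // idx < N ?
         ///    %hi       = select %in_range, %idx, %N_m1 : index
         ///    %clipped  = select %too_low, %c0, %hi : index
         /// ```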
    85  
    86  using namespace mlir;
    87  using vector::VectorTransferReadOp;
    88  using vector::VectorTransferWriteOp;
    89  
    90  #define DEBUG_TYPE "affine-lower-vector-transfers"
    91  
    92  namespace {
    93  
    94  /// Lowers VectorTransferOp into a combination of:
    95  ///   1. local memory allocation;
    96  ///   2. perfect loop nest over:
    97  ///      a. scalar load/stores from local buffers (viewed as a scalar memref);
     98  ///      b. scalar store/load to original memref (with clipping);
     99  ///   3. vector_load/store to/from the local buffer (viewed as a memref<1 x vector>);
   100  ///   4. local memory deallocation.
   101  /// Minor variations occur depending on whether a VectorTransferReadOp or
   102  /// a VectorTransferWriteOp is rewritten.
   103  template <typename VectorTransferOpTy>
   104  struct VectorTransferRewriter : public RewritePattern {
   105    explicit VectorTransferRewriter(MLIRContext *context)
   106        : RewritePattern(VectorTransferOpTy::getOperationName(), 1, context) {}
   107  
   108    /// Used for staging the transfer in a local scalar buffer.
   109    MemRefType tmpMemRefType(VectorTransferOpTy transfer) const {
   110      auto vectorType = transfer.getVectorType();
   111      return MemRefType::get(vectorType.getShape(), vectorType.getElementType(),
   112                             {}, 0);
   113    }
   114  
    115    /// View of tmpMemRefType as a single vector, used for the vector load/store
    116    /// to/from the tmp buffer.
   117    MemRefType vectorMemRefType(VectorTransferOpTy transfer) const {
   118      return MemRefType::get({1}, transfer.getVectorType(), {}, 0);
   119    }
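           // For example (shapes taken from the header example): a transfer of
           // vector<32x256xf32> is staged through a memref<32x256xf32> scalar buffer
           // (tmpMemRefType), and the single vector load/store goes through its
           // memref<1 x vector<32x256xf32>> view (vectorMemRefType).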
   120  
   121    /// Performs the rewrite.
   122    PatternMatchResult matchAndRewrite(Operation *op,
   123                                       PatternRewriter &rewriter) const override;
   124  };
   125  
    126  /// Analyzes `transfer` to find an access dimension along the fastest-varying
    127  /// remote MemRef dimension. If such a dimension with coalescing properties is
    128  /// found, `pivs` and `vectorView` are swapped so that the invocation of
    129  /// LoopNestBuilder captures it in the innermost loop.
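         ///
         /// For example, with the permutation map (d0, d1, d2) -> (d2, d1) from the
         /// header comment and a rank-3 remote MemRef, the result `d2` addresses the
         /// innermost remote dimension (remoteRank - 1), so the loop materializing
         /// that vector dimension is swapped into the innermost position.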
   130  template <typename VectorTransferOpTy>
   131  void coalesceCopy(VectorTransferOpTy transfer,
   132                    SmallVectorImpl<edsc::ValueHandle *> *pivs,
   133                    edsc::VectorView *vectorView) {
    134    // Rank of the remote memory access; coalescing behavior occurs on the
    135    // innermost memory dimension.
   136    auto remoteRank = transfer.getMemRefType().getRank();
    137    // Iterate over the result expressions of the permutation map to determine
   138    // the loop order for creating pointwise copies between remote and local
   139    // memories.
   140    int coalescedIdx = -1;
   141    auto exprs = transfer.getPermutationMap().getResults();
   142    for (auto en : llvm::enumerate(exprs)) {
   143      auto dim = en.value().template dyn_cast<AffineDimExpr>();
   144      if (!dim) {
   145        continue;
   146      }
   147      auto memRefDim = dim.getPosition();
   148      if (memRefDim == remoteRank - 1) {
    149        // memRefDim has coalescing properties; it should be swapped into the
    150        // last position.
   151        assert(coalescedIdx == -1 && "Unexpected > 1 coalesced indices");
   152        coalescedIdx = en.index();
   153      }
   154    }
   155    if (coalescedIdx >= 0) {
   156      std::swap(pivs->back(), (*pivs)[coalescedIdx]);
   157      vectorView->swapRanges(pivs->size() - 1, coalescedIdx);
   158    }
   159  }
   160  
   161  /// Emits remote memory accesses that are clipped to the boundaries of the
   162  /// MemRef.
   163  template <typename VectorTransferOpTy>
   164  llvm::SmallVector<edsc::ValueHandle, 8> clip(VectorTransferOpTy transfer,
   165                                               edsc::MemRefView &view,
   166                                               ArrayRef<edsc::IndexHandle> ivs) {
   167    using namespace mlir::edsc;
   168    using namespace edsc::op;
   169    using edsc::intrinsics::select;
   170  
   171    IndexHandle zero(index_t(0)), one(index_t(1));
   172    llvm::SmallVector<edsc::ValueHandle, 8> memRefAccess(transfer.getIndices());
   173    llvm::SmallVector<edsc::ValueHandle, 8> clippedScalarAccessExprs(
   174        memRefAccess.size(), edsc::IndexHandle());
   175  
    176    // Indices accessing remote memory are clipped, and their expressions are
    177    // returned in clippedScalarAccessExprs.
   178    for (unsigned memRefDim = 0; memRefDim < clippedScalarAccessExprs.size();
   179         ++memRefDim) {
   180      // Linear search on a small number of entries.
   181      int loopIndex = -1;
   182      auto exprs = transfer.getPermutationMap().getResults();
   183      for (auto en : llvm::enumerate(exprs)) {
   184        auto expr = en.value();
   185        auto dim = expr.template dyn_cast<AffineDimExpr>();
   186        // Sanity check.
   187        assert(
   188            (dim || expr.template cast<AffineConstantExpr>().getValue() == 0) &&
   189            "Expected dim or 0 in permutationMap");
   190        if (dim && memRefDim == dim.getPosition()) {
   191          loopIndex = en.index();
   192          break;
   193        }
   194      }
   195  
    196      // At the moment we cannot distinguish the unrolled dimensions that
    197      // implement the "always full" tile abstraction (and thus need no clipping)
    198      // from the other ones, so we conservatively clip everything.
   199      auto N = view.ub(memRefDim);
   200      auto i = memRefAccess[memRefDim];
   201      if (loopIndex < 0) {
   202        auto N_minus_1 = N - one;
   203        auto select_1 = select(i < N, i, N_minus_1);
   204        clippedScalarAccessExprs[memRefDim] = select(i < zero, zero, select_1);
   205      } else {
   206        auto ii = ivs[loopIndex];
   207        auto i_plus_ii = i + ii;
   208        auto N_minus_1 = N - one;
   209        auto select_1 = select(i_plus_ii < N, i_plus_ii, N_minus_1);
   210        clippedScalarAccessExprs[memRefDim] =
   211            select(i_plus_ii < zero, zero, select_1);
   212      }
   213    }
   214  
   215    return clippedScalarAccessExprs;
   216  }
   217  
   218  /// Lowers VectorTransferReadOp into a combination of:
   219  ///   1. local memory allocation;
   220  ///   2. perfect loop nest over:
    221  ///      a. scalar load from the original memref (with clipping);
    222  ///      b. scalar store to the local buffer (viewed as a scalar memref);
   223  ///   3. vector_load from local buffer (viewed as a memref<1 x vector>);
   224  ///   4. local memory deallocation.
   225  ///
   226  /// Lowers the data transfer part of a VectorTransferReadOp while ensuring no
   227  /// out-of-bounds accesses are possible. Out-of-bounds behavior is handled by
   228  /// clipping. This means that a given value in memory can be read multiple
   229  /// times and concurrently.
   230  ///
   231  /// Important notes about clipping and "full-tiles only" abstraction:
   232  /// =================================================================
   233  /// When using clipping for dealing with boundary conditions, the same edge
    234  /// value will appear multiple times (a.k.a. edge padding). This is fine if the
   235  /// subsequent vector operations are all data-parallel but **is generally
   236  /// incorrect** in the presence of reductions or extract operations.
   237  ///
   238  /// More generally, clipping is a scalar abstraction that is expected to work
   239  /// fine as a baseline for CPUs and GPUs but not for vector_load and DMAs.
   240  /// To deal with real vector_load and DMAs, a "padded allocation + view"
   241  /// abstraction with the ability to read out-of-memref-bounds (but still within
   242  /// the allocated region) is necessary.
   243  ///
   244  /// Whether using scalar loops or vector_load/DMAs to perform the transfer,
   245  /// junk values will be materialized in the vectors and generally need to be
   246  /// filtered out and replaced by the "neutral element". This neutral element is
   247  /// op-dependent so, in the future, we expect to create a vector filter and
   248  /// apply it to a splatted constant vector with the proper neutral element at
   249  /// each ssa-use. This filtering is not necessary for pure data-parallel
   250  /// operations.
   251  ///
    252  /// In the case of vector_store/DMAs, Read-Modify-Write will be required, which
    253  /// also has concurrency implications. Note that by using clipped scalar stores
    254  /// in the presence of data-parallel-only operations, we generate code that
    255  /// writes the same value multiple times to the edge locations.
   256  ///
   257  /// TODO(ntv): implement alternatives to clipping.
   258  /// TODO(ntv): support non-data-parallel operations.
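         ///
         /// For the header example (a vector<32x256xf32> read from memref<?x?x?xf32>),
         /// the emitted structure is roughly the sketch below. Op spellings and the
         /// loop construct are approximate for this MLIR snapshot, and `clip(...)` is
         /// pseudo-notation for the cmpi/select expressions produced by `clip`:
         ///
         /// ```mlir {.mlir}
         ///    %tmp = alloc() : memref<32x256xf32>
         ///    %view = vector_type_cast %tmp : memref<32x256xf32>, memref<1xvector<32x256xf32>>
         ///    // perfect loop nest over the vector dimensions, in coalesced order:
         ///    for %iv0, %iv1 ... {
         ///      %s = load %A[clip(...)] : memref<?x?x?xf32>
         ///      store %s, %tmp[%iv0, %iv1] : memref<32x256xf32>
         ///    }
         ///    %vec = load %view[%c0] : memref<1xvector<32x256xf32>>
         ///    dealloc %tmp : memref<32x256xf32>
         /// ```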
   259  
   260  /// Performs the rewrite.
   261  template <>
   262  PatternMatchResult
   263  VectorTransferRewriter<VectorTransferReadOp>::matchAndRewrite(
   264      Operation *op, PatternRewriter &rewriter) const {
   265    using namespace mlir::edsc;
   266    using namespace mlir::edsc::op;
   267    using namespace mlir::edsc::intrinsics;
   268    using IndexedValue =
   269        TemplatedIndexedValue<intrinsics::std_load, intrinsics::std_store>;
   270  
   271    VectorTransferReadOp transfer = cast<VectorTransferReadOp>(op);
   272  
   273    // 1. Setup all the captures.
   274    ScopedContext scope(rewriter, transfer.getLoc());
   275    IndexedValue remote(transfer.getMemRef());
   276    MemRefView view(transfer.getMemRef());
   277    VectorView vectorView(transfer.getVector());
   278    SmallVector<IndexHandle, 8> ivs = makeIndexHandles(vectorView.rank());
   279    SmallVector<ValueHandle *, 8> pivs =
   280        makeIndexHandlePointers(MutableArrayRef<IndexHandle>(ivs));
   281    coalesceCopy(transfer, &pivs, &vectorView);
   282  
   283    auto lbs = vectorView.getLbs();
   284    auto ubs = vectorView.getUbs();
   285    auto steps = vectorView.getSteps();
   286  
   287    // 2. Emit alloc-copy-load-dealloc.
   288    ValueHandle tmp = alloc(tmpMemRefType(transfer));
   289    IndexedValue local(tmp);
   290    ValueHandle vec = vector_type_cast(tmp, vectorMemRefType(transfer));
   291    LoopNestBuilder(pivs, lbs, ubs, steps)([&] {
   292      // Computes clippedScalarAccessExprs in the loop nest scope (ivs exist).
   293      local(ivs) = remote(clip(transfer, view, ivs));
   294    });
   295    ValueHandle vectorValue = std_load(vec, {constant_index(0)});
   296    (dealloc(tmp)); // vexing parse
   297  
   298    // 3. Propagate.
   299    rewriter.replaceOp(op, vectorValue.getValue());
   300    return matchSuccess();
   301  }
   302  
   303  /// Lowers VectorTransferWriteOp into a combination of:
   304  ///   1. local memory allocation;
   305  ///   2. vector_store to local buffer (viewed as a memref<1 x vector>);
   306  ///   3. perfect loop nest over:
    307  ///      a. scalar load from the local buffer (viewed as a scalar memref);
    308  ///      b. scalar store to the original memref (with clipping).
   309  ///   4. local memory deallocation.
   310  ///
   311  /// More specifically, lowers the data transfer part while ensuring no
   312  /// out-of-bounds accesses are possible. Out-of-bounds behavior is handled by
   313  /// clipping. This means that a given value in memory can be written to multiple
   314  /// times and concurrently.
   315  ///
    316  /// See `Important notes about clipping and "full-tiles only" abstraction` in
    317  /// the description of the VectorTransferReadOp lowering above.
   318  ///
   319  /// TODO(ntv): implement alternatives to clipping.
   320  /// TODO(ntv): support non-data-parallel operations.
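         ///
         /// The emitted structure mirrors the read case with the copy direction
         /// reversed (rough sketch, same caveats and pseudo-notation as above):
         ///
         /// ```mlir {.mlir}
         ///    %tmp = alloc() : memref<32x256xf32>
         ///    %view = vector_type_cast %tmp : memref<32x256xf32>, memref<1xvector<32x256xf32>>
         ///    store %vec, %view[%c0] : memref<1xvector<32x256xf32>>
         ///    for %iv0, %iv1 ... {
         ///      %s = load %tmp[%iv0, %iv1] : memref<32x256xf32>
         ///      store %s, %A[clip(...)] : memref<?x?x?xf32>
         ///    }
         ///    dealloc %tmp : memref<32x256xf32>
         /// ```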
   321  template <>
   322  PatternMatchResult
   323  VectorTransferRewriter<VectorTransferWriteOp>::matchAndRewrite(
   324      Operation *op, PatternRewriter &rewriter) const {
   325    using namespace mlir::edsc;
   326    using namespace mlir::edsc::op;
   327    using namespace mlir::edsc::intrinsics;
   328    using IndexedValue =
   329        TemplatedIndexedValue<intrinsics::std_load, intrinsics::std_store>;
   330  
   331    VectorTransferWriteOp transfer = cast<VectorTransferWriteOp>(op);
   332  
   333    // 1. Setup all the captures.
   334    ScopedContext scope(rewriter, transfer.getLoc());
   335    IndexedValue remote(transfer.getMemRef());
   336    MemRefView view(transfer.getMemRef());
   337    ValueHandle vectorValue(transfer.getVector());
   338    VectorView vectorView(transfer.getVector());
   339    SmallVector<IndexHandle, 8> ivs = makeIndexHandles(vectorView.rank());
   340    SmallVector<ValueHandle *, 8> pivs = makeIndexHandlePointers(ivs);
   341    coalesceCopy(transfer, &pivs, &vectorView);
   342  
   343    auto lbs = vectorView.getLbs();
   344    auto ubs = vectorView.getUbs();
   345    auto steps = vectorView.getSteps();
   346  
   347    // 2. Emit alloc-store-copy-dealloc.
   348    ValueHandle tmp = alloc(tmpMemRefType(transfer));
   349    IndexedValue local(tmp);
   350    ValueHandle vec = vector_type_cast(tmp, vectorMemRefType(transfer));
   351    std_store(vectorValue, vec, {constant_index(0)});
   352    LoopNestBuilder(pivs, lbs, ubs, steps)([&] {
   353      // Computes clippedScalarAccessExprs in the loop nest scope (ivs exist).
   354      remote(clip(transfer, view, ivs)) = local(ivs);
   355    });
   356    (dealloc(tmp)); // vexing parse...
   357  
   358    rewriter.replaceOp(op, llvm::None);
   359    return matchSuccess();
   360  }
   361  
   362  struct LowerVectorTransfersPass
   363      : public FunctionPass<LowerVectorTransfersPass> {
   364    void runOnFunction() {
   365      OwningRewritePatternList patterns;
   366      auto *context = &getContext();
   367      patterns.insert<VectorTransferRewriter<vector::VectorTransferReadOp>,
   368                      VectorTransferRewriter<vector::VectorTransferWriteOp>>(
   369          context);
   370      applyPatternsGreedily(getFunction(), patterns);
   371    }
   372  };
   373  
   374  } // end anonymous namespace
   375  
   376  std::unique_ptr<FunctionPassBase> mlir::createLowerVectorTransfersPass() {
   377    return std::make_unique<LowerVectorTransfersPass>();
   378  }
   379  
   380  static PassRegistration<LowerVectorTransfersPass>
   381      pass("affine-lower-vector-transfers",
   382           "Materializes vector transfer ops to a "
   383           "proper abstraction for the hardware");
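
         // Usage note (a sketch, assuming this pass is linked into an mlir-opt-style
         // driver): the lowering can be exercised from the command line through the
         // registered flag, e.g. `mlir-opt -affine-lower-vector-transfers foo.mlir`,
         // or scheduled programmatically by adding createLowerVectorTransfersPass()
         // to a pass manager.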