github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Transforms/MaterializeVectors.cpp

     1  //===- MaterializeVectors.cpp - MaterializeVectors Pass Impl --------------===//
     2  //
     3  // Copyright 2019 The MLIR Authors.
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //   http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  // =============================================================================
    17  //
    18  // This file implements target-dependent materialization of super-vectors to
    19  // vectors of the proper size for the hardware.
    20  //
    21  //===----------------------------------------------------------------------===//
    22  
    23  #include "mlir/Analysis/AffineAnalysis.h"
    24  #include "mlir/Analysis/Dominance.h"
    25  #include "mlir/Analysis/LoopAnalysis.h"
    26  #include "mlir/Analysis/NestedMatcher.h"
    27  #include "mlir/Analysis/SliceAnalysis.h"
    28  #include "mlir/Analysis/Utils.h"
    29  #include "mlir/Analysis/VectorAnalysis.h"
    30  #include "mlir/Dialect/AffineOps/AffineOps.h"
    31  #include "mlir/Dialect/StandardOps/Ops.h"
    32  #include "mlir/Dialect/VectorOps/VectorOps.h"
    33  #include "mlir/IR/AffineExpr.h"
    34  #include "mlir/IR/AffineMap.h"
    35  #include "mlir/IR/Attributes.h"
    36  #include "mlir/IR/Builders.h"
    37  #include "mlir/IR/Location.h"
    38  #include "mlir/IR/OperationSupport.h"
    39  #include "mlir/IR/Types.h"
    40  #include "mlir/Pass/Pass.h"
    41  #include "mlir/Support/Functional.h"
    42  #include "mlir/Support/LLVM.h"
    43  #include "mlir/Transforms/Passes.h"
    44  
    45  #include "llvm/Support/CommandLine.h"
    46  #include "llvm/Support/Debug.h"
    47  #include "llvm/Support/raw_ostream.h"
    48  
    49  ///
    50  /// Implements target-dependent materialization of virtual super-vectors to
    51  /// vectors of the proper size for the hardware.
    52  ///
    53  /// While the physical vector size is target-dependent, the pass is written in
    54  /// a target-independent way: the target vector size is specified as a parameter
    55  /// to the pass. This pass is thus a partial lowering that opens the "greybox"
    56  /// that is the super-vector abstraction. In particular, this pass can turn the
    57  /// vector.transfer_read and vector.transfer_write ops into either:
    58  ///   1. a loop nest with scalar or vector load/store operations; or
    59  ///   2. a loop-nest with DmaStartOp / DmaWaitOp; or
    60  ///   3. a pre-existing blackbox library call that can be written manually or
    61  ///      synthesized using search and superoptimization.
    62  /// An important requirement that each of these 3 target lowering abstractions
    63  /// must handle is "non-effecting" padding with the proper neutral element, in
    64  /// order to guarantee that all "partial tiles" are effectively "full tiles" in
    65  /// practice.
    66  ///
    67  /// In particular, this pass is an MLIR-to-MLIR rewrite and does not concern itself
    68  /// with target-specific instruction-selection and register allocation. These
    69  /// will happen downstream in LLVM.
    70  ///
    71  /// In this sense, despite performing lowering to a target-dependent size, this
    72  /// pass is still target-agnostic.
    73  ///
    74  /// Implementation details
    75  /// ======================
    76  /// The current decisions made by the super-vectorization pass guarantee that
    77  /// use-def chains do not escape an enclosing vectorized AffineForOp. In other
    78  /// words, this pass operates on a scoped program slice. Furthermore, since we
    79  /// do not vectorize in the presence of conditionals for now, sliced chains are
    80  /// guaranteed not to escape the innermost scope, which has to be either the top
    81  /// Function scope or the innermost loop scope, by construction. As a
    82  /// consequence, the implementation just starts from vector.transfer_write
    83  /// operations and builds the slice scoped to the innermost loop enclosing the
    84  /// current vector.transfer_write. These assumptions and the implementation
    85  /// details are subject to revision in the future.
    86  ///
    87  /// Example
    88  /// ========
    89  /// In the following, the single vector.transfer_write op operates on a
    90  /// vector<4x4x4xf32>. Let's assume the HW supports vector<4x4xf32>.
    91  /// Materialization is achieved by instantiating each occurrence of the leading
    92  /// dimension of vector<4x4x4xf32> into a vector<4x4xf32>.
    93  /// The program transformation that implements this instantiation is a
    94  /// multi-loop unroll-and-jam (it can be partial or full depending on the ratio
    95  /// of super-vector shape to HW-vector shape).
    96  ///
    97  /// As a simple case, the following:
    98  ///
    99  /// ```mlir
   100  ///    mlfunc @materialize(%M : index, %N : index, %O : index, %P : index) {
   101  ///      %A = alloc (%M, %N, %O, %P) : memref<?x?x?x?xf32>
   102  ///      %f1 = constant dense<vector<4x4x4xf32>, 1.000000e+00> : vector<4x4x4xf32>
   103  ///      affine.for %i0 = 0 to %M step 4 {
   104  ///        affine.for %i1 = 0 to %N step 4 {
   105  ///          affine.for %i2 = 0 to %O {
   106  ///            affine.for %i3 = 0 to %P step 4 {
   107  ///              vector.transfer_write %f1, %A[%i0, %i1, %i2, %i3]
   108  ///                {permutation_map: (d0, d1, d2, d3) -> (d3, d1, d0)} :
   109  ///                 vector<4x4x4xf32>, memref<?x?x?x?xf32>
   110  ///      }}}}
   111  ///      return
   112  ///    }
   113  /// ```
   114  ///
   115  /// is instantiated by unroll-and-jam (just unroll in this case) into:
   116  ///
   117  /// ```mlir
   118  ///    mlfunc @materialize(%M : index, %N : index, %O : index, %P : index) {
   119  ///      %A = alloc (%M, %N, %O, %P) : memref<?x?x?x?xf32, 0>
   120  ///      %f1 = constant dense<vector<4x4xf32>, 1.000000e+00> : vector<4x4xf32>
   121  ///       affine.for %i0 = 0 to %arg0 step 4 {
   122  ///         affine.for %i1 = 0 to %arg1 step 4 {
   123  ///           affine.for %i2 = 0 to %arg2 {
   124  ///             affine.for %i3 = 0 to %arg3 step 4 {
   125  ///               vector.transfer_write %f1, %0[%i0, %i1, %i2, %i3]
   126  ///                 {permutation_map: (d0, d1, d2, d3) -> (d1, d0)} :
   127  ///                 vector<4x4xf32>, memref<?x?x?x?xf32>
   128  ///               %i3p1 = affine.apply (d0) -> (d0 + 1)(%i3)
   129  ///               vector.transfer_write {{.*}}, %0[%i0, %i1, %i2, %i3p1]
   130  ///                 {permutation_map: (d0, d1, d2, d3) -> (d1, d0)} :
   131  ///                 vector<4x4xf32>, memref<?x?x?x?xf32>
   132  ///               %i3p2 = affine.apply (d0) -> (d0 + 2)(%i3)
   133  ///               vector.transfer_write {{.*}}, %0[%i0, %i1, %i2, %i3p2]
   134  ///                 {permutation_map: (d0, d1, d2, d3) -> (d1, d0)} :
   135  ///                 vector<4x4xf32>, memref<?x?x?x?xf32>
   136  ///               %i3p3 = affine.apply (d0) -> (d0 + 3)(%i3)
   137  ///               vector.transfer_write {{.*}}, %0[%i0, %i1, %i2, %i3p3]
   138  ///                 {permutation_map: (d0, d1, d2, d3) -> (d1, d0)} :
   139  ///                 vector<4x4xf32>, memref<?x?x?x?xf32>
   140  ///      }}}}
   141  ///      return
   142  ///    }
   143  /// ```
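///
/// (Added note, for illustration; the numbers follow directly from the shapes
/// in the example above.) The ratio of the super-vector shape vector<4x4x4xf32>
/// to the HW vector shape vector<4x4xf32> is {4, 1, 1}, so 4 hardware-vector
/// instances are generated. This is why the materialized code contains 4
/// vector.transfer_write ops at offsets %i3 + 0, 1, 2, 3, and why the fully
/// instantiated leading dimension is projected out of the permutation map,
/// (d3, d1, d0) becoming (d1, d0).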
   144  
   145  using llvm::dbgs;
   146  using llvm::SetVector;
   147  
   148  using namespace mlir;
   149  using vector::VectorTransferReadOp;
   150  using vector::VectorTransferWriteOp;
   151  
   152  using functional::makePtrDynCaster;
   153  using functional::map;
   154  
   155  static llvm::cl::list<int>
   156      clVectorSize("vector-size",
   157                   llvm::cl::desc("Specify the HW vector size for vectorization"),
   158                   llvm::cl::ZeroOrMore);
   159  
   160  #define DEBUG_TYPE "materialize-vect"
   161  
   162  namespace {
   163  struct MaterializationState {
   164    /// In practice, the determination of the HW-specific vector type to use when
   165    /// lowering a super-vector type must be based on the elemental type. The
   166    /// elemental type must be retrieved from the super-vector type. In the future
   167    /// information about hardware vector type for a particular elemental type
   168    /// will be part of the contract between MLIR and the backend.
   169    ///
   170    /// For example, 8xf32 has the same size as 16xf16 but the targeted HW itself
   171  /// may exhibit the following properties:
   172  /// 1. a special unit for a 128xf16 datapath;
   173  /// 2. no F16 FPU support on the regular 8xf32/16xf16 vector datapath.
   174    ///
   175    /// For now, we just assume hwVectorSize has the proper information regardless
   176    /// of the type and we assert everything is f32.
   177    /// TODO(ntv): relax the assumptions on admissible element type once a
   178    /// contract exists.
   179    MaterializationState(SmallVector<int64_t, 8> sizes) : hwVectorSize(sizes) {}
   180  
   181    SmallVector<int64_t, 8> hwVectorSize;
   182    VectorType superVectorType;
   183    VectorType hwVectorType;
   184    SmallVector<unsigned, 8> hwVectorInstance;
   185    DenseMap<Value *, Value *> *substitutionsMap;
   186  };
   187  
   188  /// Base state for the vector materialization pass.
   189  /// Command line arguments are preempted by non-empty pass arguments.
   190  struct MaterializeVectorsPass : public FunctionPass<MaterializeVectorsPass> {
   191    MaterializeVectorsPass()
   192        : hwVectorSize(clVectorSize.begin(), clVectorSize.end()) {}
   193    MaterializeVectorsPass(ArrayRef<int64_t> hwVectorSize)
   194        : MaterializeVectorsPass() {
   195      if (!hwVectorSize.empty())
   196        this->hwVectorSize.assign(hwVectorSize.begin(), hwVectorSize.end());
   197    }
   198  
   199    SmallVector<int64_t, 8> hwVectorSize;
   200    void runOnFunction() override;
   201  };
   202  
   203  } // end anonymous namespace
   204  
   205  /// Given a shape with sizes greater than 0 along all dimensions,
   206  /// returns the distance, in number of elements, between a slice in a dimension
   207  /// and the next slice in the same dimension.
   208  ///   e.g. shape[3, 4, 5] -> strides[20, 5, 1]
   209  static SmallVector<unsigned, 8> makeStrides(ArrayRef<unsigned> shape) {
   210    SmallVector<unsigned, 8> tmp;
   211    tmp.reserve(shape.size());
   212    unsigned running = 1;
   213    for (auto rit = shape.rbegin(), reit = shape.rend(); rit != reit; ++rit) {
   214      assert(*rit > 0 && "size must be greater than 0 along all dimensions of "
   215                         "shape");
   216      tmp.push_back(running);
   217      running *= *rit;
   218    }
   219    return SmallVector<unsigned, 8>(tmp.rbegin(), tmp.rend());
   220  }
   221  
   222  /// Given a shape with sizes greater than 0 along all dimensions, returns the
   223  /// delinearized components of linearIndex along shape.
   224  static SmallVector<unsigned, 8> delinearize(unsigned linearIndex,
   225                                              ArrayRef<unsigned> shape) {
   226    SmallVector<unsigned, 8> res;
   227    res.reserve(shape.size());
   228    auto strides = makeStrides(shape);
   229    for (unsigned idx = 0; idx < strides.size(); ++idx) {
   230      assert(strides[idx] > 0);
   231      auto val = linearIndex / strides[idx];
   232      res.push_back(val);
   233      assert(val < shape[idx] && "delinearization is out of bounds");
   234      linearIndex %= strides[idx];
   235    }
   236    // Sanity check.
   237    assert(linearIndex == 0 && "linear index constructed from shape must "
   238                               "have 0 remainder after delinearization");
   239    return res;
   240  }
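
// Worked example (added for illustration, not part of the original comments):
// with shape [3, 4, 5], makeStrides returns [20, 5, 1], so
// delinearize(37, [3, 4, 5]) yields [1, 3, 2], because 37 = 1*20 + 3*5 + 2*1.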
   241  
   242  static Operation *instantiate(OpBuilder b, Operation *opInst,
   243                                VectorType hwVectorType,
   244                                DenseMap<Value *, Value *> *substitutionsMap);
   245  
   246  /// Not all Values belong to a program slice scoped within the immediately
   247  /// enclosing loop.
   248  /// One simple example is constants defined outside the innermost loop scope.
   249  /// For such cases the substitutionsMap has no entry and we allow an additional
   250  /// insertion.
   251  /// For now, this is limited to ConstantOp because we do not vectorize loop
   252  /// indices; this will need to be extended in the future.
   253  ///
   254  /// If substitution fails, returns nullptr.
   255  static Value *substitute(Value *v, VectorType hwVectorType,
   256                           DenseMap<Value *, Value *> *substitutionsMap) {
   257    auto it = substitutionsMap->find(v);
   258    if (it == substitutionsMap->end()) {
   259      auto *opInst = v->getDefiningOp();
   260      if (isa<ConstantOp>(opInst)) {
   261        OpBuilder b(opInst);
   262        auto *op = instantiate(b, opInst, hwVectorType, substitutionsMap);
   263        auto res = substitutionsMap->insert(std::make_pair(v, op->getResult(0)));
   264        assert(res.second && "Insertion failed");
   265        return res.first->second;
   266      }
   267      v->getDefiningOp()->emitError("Missing substitution");
   268      return nullptr;
   269    }
   270    return it->second;
   271  }
   272  
   273  /// Returns a list of single result AffineApplyOps that reindex the
   274  /// `memRefIndices` by the multi-dimensional `hwVectorInstance`. This is used by
   275  /// the function that materializes a vector.transfer operation to use hardware
   276  /// vector types instead of super-vector types.
   277  ///
   278  /// The general problem this function solves is as follows:
   279  /// Assume a vector.transfer operation at the super-vector granularity that has
   280  /// `l` enclosing loops (AffineForOp). Assume the vector transfer operation
   281  /// operates on a MemRef of rank `r`, a super-vector of rank `s` and a hardware
   282  /// vector of rank `h`. For the purpose of illustration assume l==4, r==3, s==2,
   283  /// h==1 and that the super-vector is vector<3x32xf32> and the hardware vector
   284  /// is vector<8xf32>. Assume the following MLIR snippet after
   285  /// super-vectorization has been applied:
   286  ///
   287  /// ```mlir
   288  /// affine.for %i0 = 0 to %M {
   289  ///   affine.for %i1 = 0 to %N step 3 {
   290  ///     affine.for %i2 = 0 to %O {
   291  ///       affine.for %i3 = 0 to %P step 32 {
   292  ///         %r = vector.transfer_read(%A, map0(%i..), map1(%i..), map2(%i..)) :
   293  ///              vector<3x32xf32>, memref<?x?x?xf32>
   294  ///         ...
   295  /// }}}}
   296  /// ```
   297  ///
   298  /// where map denotes an AffineMap operating on enclosing loops with properties
   299  /// compatible for vectorization (i.e. some contiguity left unspecified here).
   300  /// Note that the vectorized loops are %i1 and %i3.
   301  /// This function translates the vector.transfer_read operation to multiple
   302  /// instances of vector.transfer_read that operate on vector<8xf32>.
   303  ///
   304  /// Without loss of generality, we assume hwVectorInstance is: {2, 1}.
   305  /// The only constraint on hwVectorInstance is that it belongs to
   306  ///   [0, 2] x [0, 3], which is the span of the ratio of super-vector shape to
   307  /// hardware vector shape in our example.
   308  ///
   309  /// This function instantiates the iteration <2, 1> of vector.transfer_read
   310  /// into the set of operations in pseudo-MLIR:
   311  ///
   312  /// ```mlir
   313  ///   #map2 = (d0, d1, d2, d3) -> (d0, d1 + 2, d2, d3 + 1 * 8)
   314  ///   #map3 = #map o #map2 // where o denotes composition
   315  ///   aff0 = affine.apply #map3.0(%i..)
   316  ///   aff1 = affine.apply #map3.1(%i..)
   317  ///   aff2 = affine.apply #map3.2(%i..)
   318  ///   %r = vector.transfer_read(%A, %aff0, %aff1, %aff2):
   319  ///        vector<8xf32>, memref<?x?x?xf32>
   320  /// ```
   321  ///
   322  /// Practical considerations
   323  /// ========================
   324  /// For now, `map` is assumed to be the identity map and the indices are
   325  /// specified just as vector.transfer_read %A[%i0, %i1, %i2, %i3]. This will be
   326  /// extended in the future once we have a proper Op for vector transfers.
   327  /// Additionally, the example above is specified in pseudo-MLIR form; once we
   328  /// have proper support for generic maps we can generate the code and show
   329  /// actual MLIR.
   330  ///
   331  /// TODO(ntv): support a concrete AffineMap and compose with it.
   332  /// TODO(ntv): these implementation details should be captured in a
   333  /// vectorization trait at the op level directly.
   334  static SmallVector<mlir::Value *, 8>
   335  reindexAffineIndices(OpBuilder b, VectorType hwVectorType,
   336                       ArrayRef<unsigned> hwVectorInstance,
   337                       ArrayRef<Value *> memrefIndices) {
   338    auto vectorShape = hwVectorType.getShape();
   339    assert(hwVectorInstance.size() >= vectorShape.size());
   340  
   341    unsigned numIndices = memrefIndices.size();
   342    auto numMemRefIndices = numIndices - hwVectorInstance.size();
   343    auto numVectorIndices = hwVectorInstance.size() - vectorShape.size();
   344  
   345    SmallVector<AffineExpr, 8> affineExprs;
   346    // TODO(ntv): support a concrete map and composition.
   347    unsigned i = 0;
   348    // The first numMemRefIndices correspond to AffineForOps that have not been
   349    // vectorized; the transformation is the identity on those.
   350    for (i = 0; i < numMemRefIndices; ++i) {
   351      auto d_i = b.getAffineDimExpr(i);
   352      affineExprs.push_back(d_i);
   353    }
   354    // The next numVectorIndices correspond to super-vector dimensions that
   355    // do not have a hardware vector dimension counterpart. For those we only
   356    // need to increment the index by the corresponding hwVectorInstance.
   357    for (i = numMemRefIndices; i < numMemRefIndices + numVectorIndices; ++i) {
   358      auto d_i = b.getAffineDimExpr(i);
   359      auto offset = hwVectorInstance[i - numMemRefIndices];
   360      affineExprs.push_back(d_i + offset);
   361    }
   362    // The remaining indices correspond to super-vector dimensions that
   363    // have a hardware vector dimension counterpart. For those we need to
   364    // increment the index by "hwVectorInstance" multiples of the corresponding
   365    // hardware vector size.
   366    for (; i < numIndices; ++i) {
   367      auto d_i = b.getAffineDimExpr(i);
   368      auto offset = hwVectorInstance[i - numMemRefIndices];
   369      auto stride = vectorShape[i - numMemRefIndices - numVectorIndices];
   370      affineExprs.push_back(d_i + offset * stride);
   371    }
   372  
   373    // Create a bunch of single result AffineApplyOp.
   374    SmallVector<mlir::Value *, 8> res;
   375    res.reserve(affineExprs.size());
   376    for (auto expr : affineExprs) {
   377      auto map = AffineMap::get(numIndices, 0, expr);
   378      res.push_back(makeComposedAffineApply(b, b.getInsertionPoint()->getLoc(),
   379                                            map, memrefIndices));
   380    }
   381    return res;
   382  }
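
// Worked example (added for illustration), tying the code above back to the
// l==4, r==3, s==2, h==1 scenario in the preceding comment: memrefIndices has
// r == 3 entries and hwVectorInstance == {2, 1} has s == 2 entries, so
// numMemRefIndices == 1 and numVectorIndices == 1. Index 0 stays d0, index 1
// becomes d1 + 2 (offset by hwVectorInstance[0]), and index 2 becomes
// d2 + 1 * 8 (offset by hwVectorInstance[1] times the hardware vector size 8),
// matching the offsets in #map2 above.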
   383  
   384  /// Returns attributes with the following substitutions applied:
   385  ///   - constant splat is replaced by constant splat of `hwVectorType`.
   386  /// TODO(ntv): add more substitutions on a per-need basis.
   387  static SmallVector<NamedAttribute, 1>
   388  materializeAttributes(Operation *opInst, VectorType hwVectorType) {
   389    SmallVector<NamedAttribute, 1> res;
   390    for (auto a : opInst->getAttrs()) {
   391      if (auto splat = a.second.dyn_cast<SplatElementsAttr>()) {
   392        auto attr = SplatElementsAttr::get(hwVectorType, splat.getSplatValue());
   393        res.push_back(NamedAttribute(a.first, attr));
   394      } else {
   395        res.push_back(a);
   396      }
   397    }
   398    return res;
   399  }
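
// For illustration (hypothetical values, not from the original comments): a
// splat attribute of 1.0 over the super-vector type vector<4x4x4xf32> is
// rewritten here into a splat of 1.0 over `hwVectorType`, e.g. vector<4x4xf32>
// in the example from the file header. All other attributes pass through
// unchanged.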
   400  
   401  /// Creates an instantiated version of `opInst`.
   402  /// Ops other than VectorTransferReadOp/VectorTransferWriteOp require no
   403  /// affine reindexing. Just substitute their Value operands and be done. For
   404  /// this case the actual instance is irrelevant. Just use the values in
   405  /// substitutionsMap.
   406  ///
   407  /// If the underlying substitution fails, this fails too and returns nullptr.
   408  static Operation *instantiate(OpBuilder b, Operation *opInst,
   409                                VectorType hwVectorType,
   410                                DenseMap<Value *, Value *> *substitutionsMap) {
   411    assert(!isa<VectorTransferReadOp>(opInst) &&
   412           "Should call the function specialized for VectorTransferReadOp");
   413    assert(!isa<VectorTransferWriteOp>(opInst) &&
   414           "Should call the function specialized for VectorTransferWriteOp");
   415    if (opInst->getNumRegions() != 0)
   416      return nullptr;
   417  
   418    bool fail = false;
   419    auto operands = map(
   420        [hwVectorType, substitutionsMap, &fail](Value *v) -> Value * {
   421          auto *res =
   422              fail ? nullptr : substitute(v, hwVectorType, substitutionsMap);
   423          fail |= !res;
   424          return res;
   425        },
   426        opInst->getOperands());
   427    if (fail)
   428      return nullptr;
   429  
   430    auto attrs = materializeAttributes(opInst, hwVectorType);
   431  
   432    OperationState state(opInst->getLoc(), opInst->getName().getStringRef(),
   433                         operands, {hwVectorType}, attrs);
   434    return b.createOperation(state);
   435  }
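
// For illustration (hypothetical input, not from the original comments): a
// pointwise super-vector op such as
//   %s = addf %a, %b : vector<4x4x4xf32>
// is cloned by this function, once per hardware-vector instance, into
//   %s_i = addf %a_i, %b_i : vector<4x4xf32>
// where %a_i and %b_i are the values previously registered in substitutionsMap.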
   436  
   437  /// Computes the permutationMap required for a VectorTransferOp from the memref
   438  /// to the `hwVectorType`.
   439  /// This is achieved by returning the projection of the permutationMap along the
   440  /// dimensions of the super-vector type that remain in the hwVectorType.
   441  /// In particular, if a dimension is fully instantiated (i.e. unrolled) then it
   442  /// is projected out in the final result.
   443  template <typename VectorTransferOpTy>
   444  static AffineMap projectedPermutationMap(VectorTransferOpTy transfer,
   445                                           VectorType hwVectorType) {
   446    static_assert(
   447        std::is_same<VectorTransferOpTy, VectorTransferReadOp>::value ||
   448            std::is_same<VectorTransferOpTy, VectorTransferWriteOp>::value,
   449        "Must be called on a VectorTransferOp");
   450    auto superVectorType = transfer.getVectorType();
   451    auto optionalRatio = shapeRatio(superVectorType, hwVectorType);
   452    assert(optionalRatio &&
   453           (optionalRatio->size() == superVectorType.getShape().size()) &&
   454           "Shape and ratio not of the same size");
   455    unsigned dim = 0;
   456    SmallVector<AffineExpr, 4> keep;
   457    MLIRContext *context = transfer.getContext();
   458    functional::zipApply(
   459        [&dim, &keep, context](int64_t shape, int64_t ratio) {
   460          assert(shape >= ratio && "shape dim must be greater than ratio dim");
   461          if (shape != ratio) {
   462          // HW vector is not fully instantiated along this dim; keep it.
   463            keep.push_back(getAffineDimExpr(dim, context));
   464          }
   465          ++dim;
   466        },
   467        superVectorType.getShape(), *optionalRatio);
   468    auto permutationMap = transfer.getPermutationMap();
   469    LLVM_DEBUG(permutationMap.print(dbgs() << "\npermutationMap: "));
   470    if (keep.empty()) {
   471      return permutationMap;
   472    }
   473    auto projectionMap = AffineMap::get(optionalRatio->size(), 0, keep);
   474    LLVM_DEBUG(projectionMap.print(dbgs() << "\nprojectionMap: "));
   475    return simplifyAffineMap(projectionMap.compose(permutationMap));
   476  }
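
// Worked example (added for illustration), reusing the vector<3x32xf32>
// super-vector and vector<8xf32> hardware vector from the reindexAffineIndices
// comment: the shape ratio is {3, 4}. Dimension 0 has shape == ratio == 3, is
// fully instantiated, and is projected out; dimension 1 has shape 32 != ratio 4
// and is kept. The projection map is thus (d0, d1) -> (d1), which is then
// composed with the transfer's permutation map.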
   477  
   478  /// Creates an instantiated version of `read` for the instance of
   479  /// `hwVectorInstance` when lowering from a super-vector type to
   480  /// `hwVectorType`. `hwVectorInstance` represents one particular instance of
   481  /// `hwVectorType` in the covering of the super-vector type. For a more
   482  /// detailed description of the problem, see the description of
   483  /// reindexAffineIndices.
   484  static Operation *instantiate(OpBuilder b, VectorTransferReadOp read,
   485                                VectorType hwVectorType,
   486                                ArrayRef<unsigned> hwVectorInstance,
   487                                DenseMap<Value *, Value *> *substitutionsMap) {
   488    SmallVector<Value *, 8> indices =
   489        map(makePtrDynCaster<Value>(), read.getIndices());
   490    auto affineIndices =
   491        reindexAffineIndices(b, hwVectorType, hwVectorInstance, indices);
   492    auto map = projectedPermutationMap(read, hwVectorType);
   493    if (!map) {
   494      return nullptr;
   495    }
   496    auto cloned = b.create<VectorTransferReadOp>(read.getLoc(), hwVectorType,
   497                                                 read.getMemRef(), affineIndices,
   498                                                 map, read.getPaddingValue());
   499    return cloned.getOperation();
   500  }
   501  
   502  /// Creates an instantiated version of `write` for the instance of
   503  /// `hwVectorInstance` when lowering from a super-vector type to
   504  /// `hwVectorType`. `hwVectorInstance` represents one particular instance of
   505  /// `hwVectorType` in the covering of the super-vector type. For a more
   506  /// detailed description of the problem, see the description of
   507  /// reindexAffineIndices.
   508  static Operation *instantiate(OpBuilder b, VectorTransferWriteOp write,
   509                                VectorType hwVectorType,
   510                                ArrayRef<unsigned> hwVectorInstance,
   511                                DenseMap<Value *, Value *> *substitutionsMap) {
   512    SmallVector<Value *, 8> indices =
   513        map(makePtrDynCaster<Value>(), write.getIndices());
   514    auto affineIndices =
   515        reindexAffineIndices(b, hwVectorType, hwVectorInstance, indices);
   516    auto cloned = b.create<VectorTransferWriteOp>(
   517        write.getLoc(),
   518        substitute(write.getVector(), hwVectorType, substitutionsMap),
   519        write.getMemRef(), affineIndices,
   520        projectedPermutationMap(write, hwVectorType));
   521    return cloned.getOperation();
   522  }
   523  
   524  /// Clones and inserts an instantiated version of `op` for the current
   525  /// hardware vector instance.
   526  /// The multi-dimensional `hwVectorInstance` belongs to the shapeRatio of
   527  /// super-vector type to hw vector type.
   528  /// A cloned instance of `op` is formed as follows:
   529  ///   1. vector.transfer_read: the return `superVectorType` is replaced by
   530  ///      `hwVectorType`. Additionally, affine indices are reindexed with
   531  ///      `reindexAffineIndices` using `hwVectorInstance` and vector type
   532  ///      information;
   533  ///   2. vector.transfer_write: the `valueToStore` type is simply substituted.
   534  ///      Since we operate on a topologically sorted slice, a substitution must
   535  ///      have been registered for non-constant ops. Additionally, affine indices
   536  ///      are reindexed in the same way as for vector.transfer_read;
   537  ///   3. constant ops are splats of the super-vector type by construction.
   538  ///      They are cloned to a splat on the hw vector type with the same value;
   539  ///   4. remaining ops are cloned to a version of the op that returns a hw vector
   540  ///      type, all operands are substituted according to `substitutions`. Thanks
   541  ///      to the topological order of a slice, the substitution is always
   542  ///      possible.
   543  ///
   544  /// Returns true on failure.
   545  static bool instantiateMaterialization(Operation *op,
   546                                         MaterializationState *state) {
   547    LLVM_DEBUG(dbgs() << "\ninstantiate: " << *op);
   548  
   549    // Create a builder here for unroll-and-jam effects.
   550    OpBuilder b(op);
   551    // AffineApplyOps are ignored: instantiating the proper vector op will take
   552    // care of them by composing affine maps properly.
   553    if (isa<AffineApplyOp>(op)) {
   554      return false;
   555    }
   556    if (op->getNumRegions() != 0)
   557      return op->emitError("NYI path Op with region"), true;
   558  
   559    if (auto write = dyn_cast<VectorTransferWriteOp>(op)) {
   560      auto *clone = instantiate(b, write, state->hwVectorType,
   561                                state->hwVectorInstance, state->substitutionsMap);
   562      return clone == nullptr;
   563    }
   564    if (auto read = dyn_cast<VectorTransferReadOp>(op)) {
   565      auto *clone = instantiate(b, read, state->hwVectorType,
   566                                state->hwVectorInstance, state->substitutionsMap);
   567      if (!clone) {
   568        return true;
   569      }
   570      state->substitutionsMap->insert(
   571          std::make_pair(read.getResult(), clone->getResult(0)));
   572      return false;
   573    }
   574    // The only ops with 0 results that can reach this point are, by construction,
   575    // VectorTransferWriteOps, and those were handled above. Ops with >= 2 results
   576    // are not yet supported. So only ops with exactly 1 result are handled here.
   577    if (op->getNumResults() != 1) {
   578      return op->emitError("NYI: ops with != 1 results"), true;
   579    }
   580    if (op->getResult(0)->getType() != state->superVectorType) {
   581      return op->emitError("Op does not return a supervector."), true;
   582    }
   583    auto *clone =
   584        instantiate(b, op, state->hwVectorType, state->substitutionsMap);
   585    if (!clone) {
   586      return true;
   587    }
   588    state->substitutionsMap->insert(
   589        std::make_pair(op->getResult(0), clone->getResult(0)));
   590    return false;
   591  }
   592  
   593  /// Takes a slice and rewrites the operations in it so that occurrences
   594  /// of `superVectorType` are replaced by `hwVectorType`.
   595  ///
   596  /// Implementation
   597  /// ==============
   598  ///   1. computes the shape ratio of super-vector to HW vector shapes. This
   599  ///      gives for each op in the slice, how many instantiations are required
   600  ///      in each dimension;
   601  ///   2. performs the concrete materialization. Note that in a first
   602  ///      implementation we use full unrolling because it pragmatically removes
   603  ///      the need to explicitly materialize an AllocOp. Thanks to the properties
   604  ///      of super-vectors, this unrolling is always possible and simple:
   605  ///      vectorizing to a super-vector abstraction already achieved the
   606  ///      equivalent of loop strip-mining + loop sinking and encoded this in the
   607  ///      vector type.
   608  ///
   609  /// Returns true on failure.
   610  ///
   611  /// TODO(ntv): materialized allocs.
   612  /// TODO(ntv): full loops + materialized allocs.
   613  /// TODO(ntv): partial unrolling + materialized allocs.
   614  static bool emitSlice(MaterializationState *state,
   615                        SetVector<Operation *> *slice) {
   616    auto ratio = shapeRatio(state->superVectorType, state->hwVectorType);
   617    assert(ratio.hasValue() &&
   618           "ratio of super-vector to HW-vector shape is not integral");
   619    // The number of integer points in a hyperrectangular region is:
   620    // shape[0] * strides[0].
   621    auto numValueToUnroll = (*ratio)[0] * makeStrides(*ratio)[0];
   622    // Full unrolling to hardware vectors in a first approximation.
   623    for (unsigned idx = 0; idx < numValueToUnroll; ++idx) {
   624      // Fresh RAII instanceIndices and substitutionsMap.
   625      MaterializationState scopedState = *state;
   626      scopedState.hwVectorInstance = delinearize(idx, *ratio);
   627      DenseMap<Value *, Value *> substitutionMap;
   628      scopedState.substitutionsMap = &substitutionMap;
   629      // The slice is topologically sorted; we can just clone its ops in order.
   630      for (auto *op : *slice) {
   631        auto fail = instantiateMaterialization(op, &scopedState);
   632        if (fail) {
   633          op->emitError("Unhandled super-vector materialization failure");
   634          return true;
   635        }
   636      }
   637    }
   638  
   639    LLVM_DEBUG(dbgs() << "\nFunction is now\n");
   640    LLVM_DEBUG((*slice)[0]->getParentOfType<FuncOp>().print(dbgs()));
   641  
   642    // The slice is topologically sorted; we can just erase its ops in reverse
   643    // order. A reverse iterator does not work here with a simple operator*
   644    // dereference.
   645    for (int idx = slice->size() - 1; idx >= 0; --idx) {
   646      LLVM_DEBUG(dbgs() << "\nErase: ");
   647      LLVM_DEBUG((*slice)[idx]->print(dbgs()));
   648      (*slice)[idx]->erase();
   649    }
   650    return false;
   651  }
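
// Worked example (added for illustration): for the vector<4x4x4xf32> to
// vector<4x4xf32> materialization from the file header, the ratio is {4, 1, 1},
// makeStrides(*ratio) is {1, 1, 1}, and numValueToUnroll = 4 * 1 = 4, i.e. the
// slice is cloned once for each hardware-vector instance {0, 0, 0} ... {3, 0, 0}.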
   652  
   653  /// Materializes super-vector types into concrete hw vector types as follows:
   654  ///   1. start from super-vector terminators (current vector.transfer_write
   655  ///      ops);
   656  ///   2. collect all the operations that can be reached by transitive use-defs
   657  ///      chains;
   658  ///   3. get the superVectorType for this particular terminator and the
   659  ///      corresponding hardware vector type (for now limited to F32)
   660  ///      TODO(ntv): be more general than F32.
   661  ///   4. emit the transitive useDef set to operate on the finer-grain vector
   662  ///      types.
   663  ///
   664  /// Notes
   665  /// =====
   666  /// The `slice` is sorted in topological order by construction.
   667  /// Additionally, this set is limited to operations in the same lexical scope
   668  /// because we currently disallow vectorization of defs that come from another
   669  /// scope.
   670  /// TODO(ntv): please document return value.
   671  static bool materialize(FuncOp f, const SetVector<Operation *> &terminators,
   672                          MaterializationState *state) {
   673    DenseSet<Operation *> seen;
   674    DominanceInfo domInfo(f);
   675    for (auto *term : terminators) {
   676      // Short-circuit test: a given terminator may already have been reached by
   677      // some other, previously processed, transitive use-def chain.
   678      if (seen.count(term) > 0) {
   679        continue;
   680      }
   681  
   682      auto terminator = cast<VectorTransferWriteOp>(term);
   683      LLVM_DEBUG(dbgs() << "\nFrom terminator:" << *term);
   684  
   685      // Get the transitive use-defs starting from terminator, limited to the
   686      // current enclosing scope of the terminator. See the note at the top of this
   687      // function for the justification of this restriction.
   688      // TODO(ntv): relax scoping constraints.
   689      auto *enclosingScope = term->getParentOp();
   690      auto keepIfInSameScope = [enclosingScope, &domInfo](Operation *op) {
   691        assert(op && "NULL op");
   692        if (!enclosingScope) {
   693        // By construction, every op is always under the top scope (null scope).
   694          return true;
   695        }
   696        return domInfo.properlyDominates(enclosingScope, op);
   697      };
   698      SetVector<Operation *> slice =
   699          getSlice(term, keepIfInSameScope, keepIfInSameScope);
   700      assert(!slice.empty());
   701  
   702      // Sanity checks: transitive slice must be completely disjoint from
   703      // what we have seen so far.
   704      LLVM_DEBUG(dbgs() << "\nTransitive use-defs:");
   705      for (auto *ud : slice) {
   706        LLVM_DEBUG(dbgs() << "\nud:" << *ud);
   707        assert(seen.count(ud) == 0 &&
   708               "Transitive use-defs not disjoint from already seen");
   709        seen.insert(ud);
   710      }
   711  
   712      // Emit the current slice.
   713      // Set scoped super-vector and corresponding hw vector types.
   714      state->superVectorType = terminator.getVectorType();
   715      assert((state->superVectorType.getElementType() ==
   716              FloatType::getF32(term->getContext())) &&
   717             "Only f32 supported for now");
   718      state->hwVectorType = VectorType::get(
   719          state->hwVectorSize, state->superVectorType.getElementType());
   720      auto fail = emitSlice(state, &slice);
   721      if (fail) {
   722        return true;
   723      }
   724      LLVM_DEBUG(dbgs() << "\nMLFunction is now\n");
   725      LLVM_DEBUG(f.print(dbgs()));
   726    }
   727    return false;
   728  }
   729  
   730  void MaterializeVectorsPass::runOnFunction() {
   731    // Thread-safe RAII local context, BumpPtrAllocator freed on exit.
   732    NestedPatternContext mlContext;
   733  
   734    // TODO(ntv): Check to see if this supports arbitrary top-level code.
   735    FuncOp f = getFunction();
   736    if (f.getBlocks().size() != 1)
   737      return;
   738  
   739    using matcher::Op;
   740    LLVM_DEBUG(dbgs() << "\nMaterializeVectors on Function\n");
   741    LLVM_DEBUG(f.print(dbgs()));
   742  
   743    MaterializationState state(hwVectorSize);
   744    // Get the hardware vector type.
   745    // TODO(ntv): get elemental type from super-vector type rather than force f32.
   746    auto subVectorType =
   747        VectorType::get(hwVectorSize, FloatType::getF32(&getContext()));
   748  
   749    // Capture terminators; i.e. vector.transfer_write ops involving a strict
   750    // super-vector of subVectorType.
   751    auto filter = [subVectorType](Operation &op) {
   752      if (!isa<VectorTransferWriteOp>(op)) {
   753        return false;
   754      }
   755      return matcher::operatesOnSuperVectorsOf(op, subVectorType);
   756    };
   757    auto pat = Op(filter);
   758    SmallVector<NestedMatch, 8> matches;
   759    pat.match(f, &matches);
   760    SetVector<Operation *> terminators;
   761    for (auto m : matches) {
   762      terminators.insert(m.getMatchedOperation());
   763    }
   764  
   765    if (materialize(f, terminators, &state))
   766      signalPassFailure();
   767  }
   768  
   769  std::unique_ptr<FunctionPassBase>
   770  mlir::createMaterializeVectorsPass(llvm::ArrayRef<int64_t> vectorSize) {
   771    return std::make_unique<MaterializeVectorsPass>(vectorSize);
   772  }
   773  
   774  static PassRegistration<MaterializeVectorsPass>
   775      pass("affine-materialize-vectors",
   776           "Materializes super-vectors to vectors of the "
   777           "proper size for the hardware");
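
// Typical usage (sketch only, assuming the standard mlir-opt driver with this
// pass linked in; not part of the original file):
//   mlir-opt -affine-materialize-vectors -vector-size=4 -vector-size=4 in.mlir
// Each -vector-size occurrence contributes one dimension of the hardware
// vector shape, i.e. vector<4x4xf32> here (the element type is currently
// forced to f32, see runOnFunction above).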
   778  
   779  #undef DEBUG_TYPE