github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Dialect/Linalg/Transforms/LowerToLoops.cpp (about)

     1  //===- LowerToLoops.cpp - conversion from Linalg library ops to loops------===//
     2  //
     3  // Copyright 2019 The MLIR Authors.
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //   http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  // =============================================================================
    17  
    18  #include "mlir/Dialect/AffineOps/AffineOps.h"
    19  #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
    20  #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
    21  #include "mlir/Dialect/Linalg/Passes.h"
    22  #include "mlir/Dialect/Linalg/Utils/Intrinsics.h"
    23  #include "mlir/Dialect/Linalg/Utils/Utils.h"
    24  #include "mlir/Dialect/LoopOps/LoopOps.h"
    25  #include "mlir/Dialect/StandardOps/Ops.h"
    26  #include "mlir/EDSC/Helpers.h"
    27  #include "mlir/IR/AffineExpr.h"
    28  #include "mlir/IR/AffineMap.h"
    29  #include "mlir/IR/BlockAndValueMapping.h"
    30  #include "mlir/IR/OpImplementation.h"
    31  #include "mlir/Pass/Pass.h"
    32  #include "mlir/Support/LLVM.h"
    33  #include "mlir/Support/STLExtras.h"
    34  #include "mlir/Transforms/DialectConversion.h"
    35  #include "mlir/Transforms/FoldUtils.h"
    36  
    37  using namespace mlir;
    38  using namespace mlir::edsc;
    39  using namespace mlir::edsc::intrinsics;
    40  using namespace mlir::linalg;
    41  using namespace mlir::linalg::intrinsics;
    42  
    43  using IndexedLinalgValue = TemplatedIndexedValue<linalg_load, linalg_store>;
    44  using edsc::op::operator+;
    45  using edsc::op::operator==;
    46  
    47  static SmallVector<ValueHandle, 8>
    48  foldedAffineApplies(OpBuilder &b, Location loc, AffineMap map,
    49                      ArrayRef<Value *> vals, OperationFolder &folder) {
    50    assert(map.getNumSymbols() == 0);
    51    assert(map.getNumInputs() == vals.size());
    52    SmallVector<ValueHandle, 8> res;
    53    res.reserve(map.getNumResults());
    54    auto dims = map.getNumDims();
    55    for (auto e : map.getResults()) {
    56      auto exprMap = AffineMap::get(dims, 0, e);
    57      SmallVector<Value *, 4> operands(vals.begin(), vals.end());
    58      canonicalizeMapAndOperands(&exprMap, &operands);
    59      res.push_back(affine_apply(folder, exprMap, operands));
    60    }
    61    return res;
    62  }
    63  
    64  static SmallVector<Value *, 4> permuteIvs(ArrayRef<Value *> ivs,
    65                                            Optional<AffineMap> permutation,
    66                                            OperationFolder &state) {
    67    return permutation ? applyMapToValues(ScopedContext::getBuilder(),
    68                                          ScopedContext::getLocation(),
    69                                          permutation.getValue(), ivs, state)
    70                       : SmallVector<Value *, 4>(ivs.begin(), ivs.end());
    71  }
    72  
    73  // Creates a number of ranges equal to the number of results in `map`.
    74  // The returned ranges correspond to the loop ranges, in the proper order, for
    75  // which new loops will be created.
    76  static SmallVector<Value *, 4> emitLoopRanges(OpBuilder &b, Location loc,
    77                                                AffineMap map,
    78                                                ArrayRef<Value *> allViewSizes,
    79                                                OperationFolder &folder) {
    80    // Apply `map` to get view sizes in loop order.
    81    auto sizes = applyMapToValues(b, loc, map, allViewSizes, folder);
    82    // Create a new range with the applied tile sizes.
    83    ScopedContext scope(b, loc);
    84    SmallVector<Value *, 4> res;
    85    for (unsigned idx = 0, e = map.getNumResults(); idx < e; ++idx) {
    86      res.push_back(range(constant_index(folder, 0), sizes[idx],
    87                          constant_index(folder, 1)));
    88    }
    89    return res;
    90  }
    91  
    92  template <typename LinalgOpType> class LinalgScopedEmitter {};
    93  
    94  template <> class LinalgScopedEmitter<CopyOp> {
    95  public:
    96    static void emitScalarImplementation(ArrayRef<Value *> allIvs, CopyOp copyOp,
    97                                         OperationFolder &folder) {
    98      auto nPar = copyOp.getNumParallelLoops();
    99      assert(nPar == allIvs.size());
   100      auto inputIvs =
   101          permuteIvs(allIvs.take_front(nPar), copyOp.inputPermutation(), folder);
   102      auto outputIvs =
   103          permuteIvs(allIvs.take_front(nPar), copyOp.outputPermutation(), folder);
   104      SmallVector<IndexHandle, 8> iivs(inputIvs.begin(), inputIvs.end());
   105      SmallVector<IndexHandle, 8> oivs(outputIvs.begin(), outputIvs.end());
   106      IndexedLinalgValue O(copyOp.getOutput(0)), I(copyOp.getInput(0));
   107      // Emit the proper scalar assignment, whether we are dealing with a 0-D or
   108      // an n-D loop nest; with or without permutations.
   109      // clang-format off
   110      nPar > 0 ? O(oivs) = I(iivs) :
   111                 O() = I();
   112      // clang-format on
   113    }
   114  };
   115  
   116  template <> class LinalgScopedEmitter<FillOp> {
   117  public:
   118    static void emitScalarImplementation(ArrayRef<Value *> allIvs, FillOp fillOp,
   119                                         OperationFolder &folder) {
   120      auto nPar = fillOp.getNumParallelLoops();
   121      assert(nPar == allIvs.size());
   122      auto ivs =
   123          SmallVector<IndexHandle, 4>(allIvs.begin(), allIvs.begin() + nPar);
   124      IndexedLinalgValue O(fillOp.getOutput(0));
   125      // Emit the proper scalar assignment, whether we are dealing with a 0-D or
   126      // an n-D loop nest; with or without permutations.
   127      nPar > 0 ? O(ivs) = ValueHandle(fillOp.getValue())
   128               : O() = ValueHandle(fillOp.getValue());
   129    }
   130  };
   131  
   132  template <> class LinalgScopedEmitter<DotOp> {
   133  public:
   134    static void emitScalarImplementation(ArrayRef<Value *> allIvs, DotOp dotOp,
   135                                         OperationFolder &folder) {
   136      assert(allIvs.size() == 1);
   137      IndexHandle r_i(allIvs[0]);
   138      IndexedLinalgValue A(dotOp.getInput(0)), B(dotOp.getInput(1)),
   139          C(dotOp.getOutput(0));
   140      // Emit scalar form.
   141      C() = C() + A(r_i) * B(r_i);
   142    }
   143  };
   144  
   145  template <> class LinalgScopedEmitter<MatvecOp> {
   146  public:
   147    static void emitScalarImplementation(ArrayRef<Value *> allIvs,
   148                                         MatvecOp matvecOp,
   149                                         OperationFolder &folder) {
   150      assert(allIvs.size() == 2);
   151      IndexHandle i(allIvs[0]), r_j(allIvs[1]);
   152      IndexedLinalgValue A(matvecOp.getInput(0)), B(matvecOp.getInput(1)),
   153          C(matvecOp.getOutput(0));
   154      // Emit scalar form.
   155      C(i) = C(i) + A(i, r_j) * B(r_j);
   156    }
   157  };
   158  
   159  template <> class LinalgScopedEmitter<MatmulOp> {
   160  public:
   161    static void emitScalarImplementation(ArrayRef<Value *> allIvs,
   162                                         MatmulOp matmulOp,
   163                                         OperationFolder &folder) {
   164      assert(allIvs.size() == 3);
   165      IndexHandle i(allIvs[0]), j(allIvs[1]), r_k(allIvs[2]);
   166      IndexedLinalgValue A(matmulOp.getInput(0)), B(matmulOp.getInput(1)),
   167          C(matmulOp.getOutput(0));
   168      // Emit scalar form.
   169      C(i, j) = C(i, j) + A(i, r_k) * B(r_k, j);
   170    }
   171  };
   172  
   173  template <> class LinalgScopedEmitter<ConvOp> {
   174  public:
   175    static void emitScalarImplementation(ArrayRef<Value *> allIvs, ConvOp convOp,
   176                                         OperationFolder &folder) {
   177      auto b = ScopedContext::getBuilder();
   178      auto loc = ScopedContext::getLocation();
   179      auto maps = loopToOperandRangesMaps(convOp);
   180      SmallVector<ValueHandle, 8> fIdx(
   181          foldedAffineApplies(b, loc, maps[0], allIvs, folder));
   182      SmallVector<ValueHandle, 8> imIdx(
   183          foldedAffineApplies(b, loc, maps[1], allIvs, folder));
   184      SmallVector<ValueHandle, 8> oIdx(
   185          foldedAffineApplies(b, loc, maps[2], allIvs, folder));
   186      IndexedLinalgValue F(convOp.filter()), I(convOp.input()),
   187          O(convOp.output());
   188      // Emit scalar form.
   189      O(oIdx) += F(fIdx) * I(imIdx);
   190    }
   191  };
   192  
   193  // Emits the MLIR for the scalar part of the generic op by:
   194  //   1. Emitting linalg_load and linalg_store ops for each input and output
   195  //      view in order. This is achieved by applying the appropriate input or
   196  //      output map to the enclosing induction variables.
   197  //   2. Emitting a call to `op.fun()` that takes as arguments the scalars
   198  //      from point 1. above.
   199  //   3. Emitting linalg_store to store the results of 2. to the output
   200  //      views.
   201  //
   202  // An example output may resemble:
   203  //
   204  // ```
   205  //    loop.for %i = %c0 to %0 step %c1 {
   206  //      loop.for %j = %c0 to %1 step %c1 {
   207  //        loop.for %k = %c0 to %4 step %c1 {
   208  //          %11 = linalg.load %arg0[%i, %j] : !linalg.view<?x?xf32>
   209  //          %12 = linalg.load %arg1[%i, %j, %k] : !linalg.view<?x?x?xf32>
   210  //          %13 = linalg.load %arg2[%i, %k, %j] : !linalg.view<?x?x?xf32>
   211  //          %14:2 = call @foo(%11, %12, %13) : (f32, f32, f32) -> (f32, f32)
   212  //          linalg.store %14#0, %arg1[%i, %j, %k] : !linalg.view<?x?x?xf32>
   213  //          linalg.store %14#1, %arg2[%i, %k, %j] : !linalg.view<?x?x?xf32>
   214  //       }
   215  //      }
   216  //    }
   217  // ```
   218  template <> class LinalgScopedEmitter<GenericOp> {
   219  public:
   220    static void emitScalarImplementation(ArrayRef<Value *> allIvs,
   221                                         GenericOp genericOp,
   222                                         OperationFolder &folder) {
   223      auto b = ScopedContext::getBuilder();
   224      auto loc = ScopedContext::getLocation();
   225      using edsc::intrinsics::detail::ValueHandleArray;
   226      unsigned nInputs = genericOp.getNumInputs();
   227      unsigned nOutputs = genericOp.getNumOutputs();
   228      SmallVector<Value *, 4> indexedValues(nInputs + nOutputs);
   229  
   230      // 1.a. Emit linalg_load from input views.
   231      for (unsigned i = 0, e = nInputs; i < e; ++i) {
   232        ValueHandleArray indexing(foldedAffineApplies(
   233            b, loc, genericOp.getInputIndexingMap(i), allIvs, folder));
   234        indexedValues[i] = linalg_load(genericOp.getInput(i), indexing);
   235      }
   236  
   237      // 1.b. Emit linalg_load from output views.
   238      for (unsigned i = 0, e = nOutputs; i < e; ++i) {
   239        ValueHandleArray indexing(foldedAffineApplies(
   240            b, loc, genericOp.getOutputIndexingMap(i), allIvs, folder));
   241        indexedValues[nInputs + i] =
   242            linalg_load(genericOp.getOutput(i), indexing);
   243      }
   244  
   245      auto funcOp = genericOp.getFunction();
   246      if (funcOp) {
   247        // 2. Emit call.
   248        Operation *callOp = call(funcOp, indexedValues);
   249        assert(callOp->getNumResults() == genericOp.getNumOutputs());
   250  
   251        // 3. Emit linalg_store.
   252        for (unsigned i = 0, e = nOutputs; i < e; ++i) {
   253          ValueHandleArray indexing(foldedAffineApplies(
   254              b, loc, genericOp.getOutputIndexingMap(i), allIvs, folder));
   255          linalg_store(callOp->getResult(i), genericOp.getOutput(i), indexing);
   256        }
   257      } else {
   258        // TODO(ntv): When a region inliner exists, use it.
   259        // 2. Inline region, currently only works for a single basic block.
   260        BlockAndValueMapping map;
   261        auto &block = genericOp.region().front();
   262        for (auto it : llvm::zip(block.getArguments(), indexedValues))
   263          map.map(std::get<0>(it), std::get<1>(it));
   264        for (auto &op : block) {
   265          // Skip terminator.
   266          if (&op == &block.back())
   267            continue;
   268          assert(op.getNumRegions() == 0);
   269          auto *newOp = b.clone(op, map);
   270          for (auto it : llvm::zip(op.getResults(), newOp->getResults()))
   271            map.map(std::get<0>(it), std::get<1>(it));
   272        }
   273  
   274        // 3. Emit linalg_store.
   275        auto *yieldOp = cast<YieldOp>(block.back()).getOperation();
   276        assert(yieldOp->getNumOperands() == nOutputs);
   277        for (unsigned i = 0, e = nOutputs; i < e; ++i) {
   278          ValueHandleArray indexing(foldedAffineApplies(
   279              b, loc, genericOp.getOutputIndexingMap(i), allIvs, folder));
   280          linalg_store(map.lookup(yieldOp->getOperand(i)), genericOp.getOutput(i),
   281                       indexing);
   282        }
   283      }
   284    }
   285  };
   286  
/// Rewrite pattern that lowers one Linalg library op of type `ConcreteOp` into
/// a loop nest whose innermost body is emitted by
/// `LinalgScopedEmitter<ConcreteOp>::emitScalarImplementation`.
template <typename ConcreteOp>
class LinalgRewritePattern : public RewritePattern {
public:
  /// Matches operations named `ConcreteOp::getOperationName()`, benefit 1.
  explicit LinalgRewritePattern(MLIRContext *context)
      : RewritePattern(ConcreteOp::getOperationName(), /*benefit=*/1, context) {
  }

  /// Emits the loop nest for `op`: parallel loops outermost, then reduction,
  /// then window loops. The op's scalar computation goes in the innermost
  /// body. Then `op` is replaced with no values. Both paths return
  /// matchSuccess().
  PatternMatchResult matchAndRewrite(Operation *op,
                                     PatternRewriter &rewriter) const override {
    // NOTE(review): IR is built through a standalone OpBuilder positioned at
    // `op` (and the EDSC ScopedContext on top of it), not through `rewriter`.
    // Presumably the rewriter is therefore not notified of the created ops;
    // confirm this is intended for the conversion driver.
    OpBuilder b(op);
    ScopedContext scope(b, op->getLoc());

    // The flattened loopToOperandRangesMaps is expected to be an invertible
    // permutation map. When `inversePermutation` yields a null map instead,
    // no loops are built and the scalar body is emitted once with no
    // induction variables. The specialized emitters assert on the number of
    // ivs they receive, so this path presumably only suits ops that need none
    // (e.g. 0-D copy/fill). TODO confirm.
    auto linalgOp = cast<ConcreteOp>(op);
    auto invertedMap =
        inversePermutation(concatAffineMaps(loopToOperandRangesMaps(linalgOp)));
    if (!invertedMap) {
      LinalgScopedEmitter<ConcreteOp>::emitScalarImplementation({}, linalgOp,
                                                                folder);
      rewriter.replaceOp(op, {});
      return matchSuccess();
    }

    // One induction variable per loop, laid out as
    // [parallel | reduction | window]. pivs/rivs/wivs view the three groups.
    auto nPar = linalgOp.getNumParallelLoops();
    auto nRed = linalgOp.getNumReductionLoops();
    auto nWin = linalgOp.getNumWindowLoops();
    SmallVector<IndexHandle, 4> allIvs(nPar + nRed + nWin);
    SmallVector<ValueHandle *, 4> allPIvs = makeIndexHandlePointers(allIvs);
    auto pivs = MutableArrayRef<ValueHandle *>(allPIvs).take_front(nPar);
    auto rivs = MutableArrayRef<ValueHandle *>(allPIvs)
                    .take_front(nPar + nRed)
                    .take_back(nRed);
    auto wivs = MutableArrayRef<ValueHandle *>(allPIvs).take_back(nWin);

    // Loop bounds: one range per loop, obtained by mapping
    // getViewSizes(linalgOp) through the inverted map.
    auto loopRanges =
        emitLoopRanges(scope.getBuilder(), scope.getLocation(), invertedMap,
                       getViewSizes(linalgOp), folder);
    assert(loopRanges.size() == pivs.size() + rivs.size() + wivs.size());

    // Nest parallel, reduction and window loops (outermost to innermost),
    // slicing `ranges` the same way the ivs were split. The innermost body
    // emits the scalar computation over all induction variables.
    // clang-format off
    ArrayRef<Value *> ranges(loopRanges);
    LoopNestRangeBuilder(pivs, ranges.take_front(nPar))([&] {
      LoopNestRangeBuilder(rivs, ranges.drop_back(nWin).take_back(nRed))([&] {
        LoopNestRangeBuilder(wivs, ranges.take_back(wivs.size()))(
          [&linalgOp, &allIvs, this] {
            auto allIvValues = extractValues(allIvs);
            LinalgScopedEmitter<ConcreteOp>::emitScalarImplementation(
                allIvValues, linalgOp, folder);
        });
      });
    });
    // clang-format on
    rewriter.replaceOp(op, {});
    return matchSuccess();
  }

  /// Folder handed to every emission helper. It is `mutable` because
  /// `matchAndRewrite` is const, and it is shared by all rewrites performed
  /// by this pattern instance.
  mutable OperationFolder folder;
};
   346  
   347  // Helper classes for type list expansion.
   348  template <typename... LinalgOps> class ConversionList;
   349  
   350  template <> class ConversionList<> {
   351  public:
   352    static void build(OwningRewritePatternList &patterns, MLIRContext *ctx) {}
   353  };
   354  
   355  template <typename ConcreteOp, typename... LinalgOps>
   356  class ConversionList<ConcreteOp, LinalgOps...> {
   357  public:
   358    static void build(OwningRewritePatternList &patterns, MLIRContext *ctx) {
   359      patterns.insert<LinalgRewritePattern<ConcreteOp>>(ctx);
   360      ConversionList<LinalgOps...>::build(patterns, ctx);
   361    }
   362  };
   363  
/// Populate the given list with patterns that lower Linalg library ops to
/// loop nests: one LinalgRewritePattern per op in the generated op list.
static void
populateLinalgToLoopRewritePatterns(OwningRewritePatternList &patterns,
                                    MLIRContext *ctx) {
  // The template arguments are the op classes produced by including the
  // generated LinalgLibraryOps.cpp.inc with GET_OP_LIST defined.
  ConversionList<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgLibraryOps.cpp.inc"
      >::build(patterns, ctx);
}
   373  
   374  namespace {
   375  struct LowerLinalgToLoopsPass : public FunctionPass<LowerLinalgToLoopsPass> {
   376    void runOnFunction();
   377  };
   378  } // namespace
   379  
   380  void LowerLinalgToLoopsPass::runOnFunction() {
   381    OwningRewritePatternList patterns;
   382    populateLinalgToLoopRewritePatterns(patterns, &getContext());
   383  
   384    ConversionTarget target(getContext());
   385    target.addLegalDialect<AffineOpsDialect>();
   386    target.addLegalDialect<loop::LoopOpsDialect>();
   387    target.addLegalDialect<StandardOpsDialect>();
   388    if (failed(applyPartialConversion(getFunction(), target, patterns))) {
   389      signalPassFailure();
   390    }
   391  }
   392  
   393  std::unique_ptr<FunctionPassBase> mlir::linalg::createLowerLinalgToLoopsPass() {
   394    return std::make_unique<LowerLinalgToLoopsPass>();
   395  }
   396  
   397  static PassRegistration<LowerLinalgToLoopsPass>
   398      pass("linalg-lower-to-loops",
   399           "Lower the operations from the linalg dialect into loops");