github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Dialect/Linalg/Transforms/LowerToLoops.cpp

//===- LowerToLoops.cpp - conversion from Linalg library ops to loops------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

#include "mlir/Dialect/AffineOps/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Utils/Intrinsics.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/LoopOps/LoopOps.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/EDSC/Helpers.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/STLExtras.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/FoldUtils.h"

using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace mlir::linalg;
using namespace mlir::linalg::intrinsics;

using IndexedLinalgValue = TemplatedIndexedValue<linalg_load, linalg_store>;
using edsc::op::operator+;
using edsc::op::operator==;

static SmallVector<ValueHandle, 8>
foldedAffineApplies(OpBuilder &b, Location loc, AffineMap map,
                    ArrayRef<Value *> vals, OperationFolder &folder) {
  assert(map.getNumSymbols() == 0);
  assert(map.getNumInputs() == vals.size());
  SmallVector<ValueHandle, 8> res;
  res.reserve(map.getNumResults());
  auto dims = map.getNumDims();
  for (auto e : map.getResults()) {
    auto exprMap = AffineMap::get(dims, 0, e);
    SmallVector<Value *, 4> operands(vals.begin(), vals.end());
    canonicalizeMapAndOperands(&exprMap, &operands);
    res.push_back(affine_apply(folder, exprMap, operands));
  }
  return res;
}

static SmallVector<Value *, 4> permuteIvs(ArrayRef<Value *> ivs,
                                          Optional<AffineMap> permutation,
                                          OperationFolder &state) {
  return permutation ? applyMapToValues(ScopedContext::getBuilder(),
                                        ScopedContext::getLocation(),
                                        permutation.getValue(), ivs, state)
                     : SmallVector<Value *, 4>(ivs.begin(), ivs.end());
}

// Creates a number of ranges equal to the number of results in `map`.
// The returned ranges correspond to the loop ranges, in the proper order, for
// which new loops will be created.
static SmallVector<Value *, 4> emitLoopRanges(OpBuilder &b, Location loc,
                                              AffineMap map,
                                              ArrayRef<Value *> allViewSizes,
                                              OperationFolder &folder) {
  // Apply `map` to get view sizes in loop order.
  auto sizes = applyMapToValues(b, loc, map, allViewSizes, folder);
  // Create a new range (zero lower bound, unit step) for each applied size.
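  // For each result of `map`, the `range` intrinsic below materializes one
  // linalg.range spanning the full loop extent. As a rough sketch (pseudo-IR,
  // not verbatim output), each iteration of the loop emits something like:
  //   %r = linalg.range %c0 : %size : %c1 : !linalg.range
  // where %c0 and %c1 come from the folded constant_index calls.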
  ScopedContext scope(b, loc);
  SmallVector<Value *, 4> res;
  for (unsigned idx = 0, e = map.getNumResults(); idx < e; ++idx) {
    res.push_back(range(constant_index(folder, 0), sizes[idx],
                        constant_index(folder, 1)));
  }
  return res;
}

template <typename LinalgOpType> class LinalgScopedEmitter {};

template <> class LinalgScopedEmitter<CopyOp> {
public:
  static void emitScalarImplementation(ArrayRef<Value *> allIvs, CopyOp copyOp,
                                       OperationFolder &folder) {
    auto nPar = copyOp.getNumParallelLoops();
    assert(nPar == allIvs.size());
    auto inputIvs =
        permuteIvs(allIvs.take_front(nPar), copyOp.inputPermutation(), folder);
    auto outputIvs =
        permuteIvs(allIvs.take_front(nPar), copyOp.outputPermutation(), folder);
    SmallVector<IndexHandle, 8> iivs(inputIvs.begin(), inputIvs.end());
    SmallVector<IndexHandle, 8> oivs(outputIvs.begin(), outputIvs.end());
    IndexedLinalgValue O(copyOp.getOutput(0)), I(copyOp.getInput(0));
    // Emit the proper scalar assignment, whether we are dealing with a 0-D or
    // an n-D loop nest; with or without permutations.
    // clang-format off
    nPar > 0 ? O(oivs) = I(iivs) :
               O() = I();
    // clang-format on
  }
};

template <> class LinalgScopedEmitter<FillOp> {
public:
  static void emitScalarImplementation(ArrayRef<Value *> allIvs, FillOp fillOp,
                                       OperationFolder &folder) {
    auto nPar = fillOp.getNumParallelLoops();
    assert(nPar == allIvs.size());
    auto ivs =
        SmallVector<IndexHandle, 4>(allIvs.begin(), allIvs.begin() + nPar);
    IndexedLinalgValue O(fillOp.getOutput(0));
    // Emit the proper scalar assignment, whether we are dealing with a 0-D or
    // an n-D loop nest; with or without permutations.
    nPar > 0 ? O(ivs) = ValueHandle(fillOp.getValue())
             : O() = ValueHandle(fillOp.getValue());
  }
};

template <> class LinalgScopedEmitter<DotOp> {
public:
  static void emitScalarImplementation(ArrayRef<Value *> allIvs, DotOp dotOp,
                                       OperationFolder &folder) {
    assert(allIvs.size() == 1);
    IndexHandle r_i(allIvs[0]);
    IndexedLinalgValue A(dotOp.getInput(0)), B(dotOp.getInput(1)),
        C(dotOp.getOutput(0));
    // Emit scalar form.
    C() = C() + A(r_i) * B(r_i);
  }
};

template <> class LinalgScopedEmitter<MatvecOp> {
public:
  static void emitScalarImplementation(ArrayRef<Value *> allIvs,
                                       MatvecOp matvecOp,
                                       OperationFolder &folder) {
    assert(allIvs.size() == 2);
    IndexHandle i(allIvs[0]), r_j(allIvs[1]);
    IndexedLinalgValue A(matvecOp.getInput(0)), B(matvecOp.getInput(1)),
        C(matvecOp.getOutput(0));
    // Emit scalar form.
    C(i) = C(i) + A(i, r_j) * B(r_j);
  }
};

template <> class LinalgScopedEmitter<MatmulOp> {
public:
  static void emitScalarImplementation(ArrayRef<Value *> allIvs,
                                       MatmulOp matmulOp,
                                       OperationFolder &folder) {
    assert(allIvs.size() == 3);
    IndexHandle i(allIvs[0]), j(allIvs[1]), r_k(allIvs[2]);
    IndexedLinalgValue A(matmulOp.getInput(0)), B(matmulOp.getInput(1)),
        C(matmulOp.getOutput(0));
    // Emit scalar form.
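    // As a sketch, once wrapped in the loop nest built by LinalgRewritePattern
    // below (i, j parallel; r_k reduction), this single statement corresponds
    // roughly to (pseudo-IR, not verbatim output):
    //   loop.for %i ... { loop.for %j ... { loop.for %r_k ... {
    //     C[%i, %j] = C[%i, %j] + A[%i, %r_k] * B[%r_k, %j]
    //   } } }
    // with the loads and stores going through IndexedLinalgValue, i.e.
    // linalg.load / linalg.store on the corresponding views.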
    C(i, j) = C(i, j) + A(i, r_k) * B(r_k, j);
  }
};

template <> class LinalgScopedEmitter<ConvOp> {
public:
  static void emitScalarImplementation(ArrayRef<Value *> allIvs, ConvOp convOp,
                                       OperationFolder &folder) {
    auto b = ScopedContext::getBuilder();
    auto loc = ScopedContext::getLocation();
    auto maps = loopToOperandRangesMaps(convOp);
    SmallVector<ValueHandle, 8> fIdx(
        foldedAffineApplies(b, loc, maps[0], allIvs, folder));
    SmallVector<ValueHandle, 8> imIdx(
        foldedAffineApplies(b, loc, maps[1], allIvs, folder));
    SmallVector<ValueHandle, 8> oIdx(
        foldedAffineApplies(b, loc, maps[2], allIvs, folder));
    IndexedLinalgValue F(convOp.filter()), I(convOp.input()),
        O(convOp.output());
    // Emit scalar form.
    O(oIdx) += F(fIdx) * I(imIdx);
  }
};

// Emits the MLIR for the scalar part of the generic op by:
//   1. Emitting linalg_load and linalg_store ops for each input and output
//      view in order. This is achieved by applying the appropriate input or
//      output map to the enclosing induction variables.
//   2. Emitting a call to `op.fun()` that takes as arguments the scalars
//      from point 1. above.
//   3. Emitting linalg_store to store the results of 2. to the output
//      views.
//
// An example output may resemble:
//
// ```
//    loop.for %i = %c0 to %0 step %c1 {
//      loop.for %j = %c0 to %1 step %c1 {
//        loop.for %k = %c0 to %4 step %c1 {
//          %11 = linalg.load %arg0[%i, %j] : !linalg.view<?x?xf32>
//          %12 = linalg.load %arg1[%i, %j, %k] : !linalg.view<?x?x?xf32>
//          %13 = linalg.load %arg2[%i, %k, %j] : !linalg.view<?x?x?xf32>
//          %14:2 = call @foo(%11, %12, %13) : (f32, f32, f32) -> (f32, f32)
//          linalg.store %14#0, %arg1[%i, %j, %k] : !linalg.view<?x?x?xf32>
//          linalg.store %14#1, %arg2[%i, %k, %j] : !linalg.view<?x?x?xf32>
//        }
//      }
//    }
// ```
template <> class LinalgScopedEmitter<GenericOp> {
public:
  static void emitScalarImplementation(ArrayRef<Value *> allIvs,
                                       GenericOp genericOp,
                                       OperationFolder &folder) {
    auto b = ScopedContext::getBuilder();
    auto loc = ScopedContext::getLocation();
    using edsc::intrinsics::detail::ValueHandleArray;
    unsigned nInputs = genericOp.getNumInputs();
    unsigned nOutputs = genericOp.getNumOutputs();
    SmallVector<Value *, 4> indexedValues(nInputs + nOutputs);

    // 1.a. Emit linalg_load from input views.
    for (unsigned i = 0, e = nInputs; i < e; ++i) {
      ValueHandleArray indexing(foldedAffineApplies(
          b, loc, genericOp.getInputIndexingMap(i), allIvs, folder));
      indexedValues[i] = linalg_load(genericOp.getInput(i), indexing);
    }

    // 1.b. Emit linalg_load from output views.
    for (unsigned i = 0, e = nOutputs; i < e; ++i) {
      ValueHandleArray indexing(foldedAffineApplies(
          b, loc, genericOp.getOutputIndexingMap(i), allIvs, folder));
      indexedValues[nInputs + i] =
          linalg_load(genericOp.getOutput(i), indexing);
    }

    auto funcOp = genericOp.getFunction();
    if (funcOp) {
      // 2. Emit call.
      Operation *callOp = call(funcOp, indexedValues);
      assert(callOp->getNumResults() == genericOp.getNumOutputs());

      // 3. Emit linalg_store.
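      // Each call result is written back through the matching output indexing
      // map (e.g. %14#0 and %14#1 in the example output above).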
      for (unsigned i = 0, e = nOutputs; i < e; ++i) {
        ValueHandleArray indexing(foldedAffineApplies(
            b, loc, genericOp.getOutputIndexingMap(i), allIvs, folder));
        linalg_store(callOp->getResult(i), genericOp.getOutput(i), indexing);
      }
    } else {
      // TODO(ntv): When a region inliner exists, use it.
      // 2. Inline region, currently only works for a single basic block.
      BlockAndValueMapping map;
      auto &block = genericOp.region().front();
      for (auto it : llvm::zip(block.getArguments(), indexedValues))
        map.map(std::get<0>(it), std::get<1>(it));
      for (auto &op : block) {
        // Skip terminator.
        if (&op == &block.back())
          continue;
        assert(op.getNumRegions() == 0);
        auto *newOp = b.clone(op, map);
        for (auto it : llvm::zip(op.getResults(), newOp->getResults()))
          map.map(std::get<0>(it), std::get<1>(it));
      }

      // 3. Emit linalg_store.
      auto *yieldOp = cast<YieldOp>(block.back()).getOperation();
      assert(yieldOp->getNumOperands() == nOutputs);
      for (unsigned i = 0, e = nOutputs; i < e; ++i) {
        ValueHandleArray indexing(foldedAffineApplies(
            b, loc, genericOp.getOutputIndexingMap(i), allIvs, folder));
        linalg_store(map.lookup(yieldOp->getOperand(i)), genericOp.getOutput(i),
                     indexing);
      }
    }
  }
};

template <typename ConcreteOp>
class LinalgRewritePattern : public RewritePattern {
public:
  explicit LinalgRewritePattern(MLIRContext *context)
      : RewritePattern(ConcreteOp::getOperationName(), /*benefit=*/1, context) {
  }

  PatternMatchResult matchAndRewrite(Operation *op,
                                     PatternRewriter &rewriter) const override {
    OpBuilder b(op);
    ScopedContext scope(b, op->getLoc());

    // The flattened loopToOperandRangesMaps is expected to be an invertible
    // permutation map (which is asserted in the inverse calculation).
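    // For example, for matmul the per-operand maps are
    //   A: (i, j, k) -> (i, k),  B: (i, j, k) -> (k, j),  C: (i, j, k) -> (i, j)
    // so the concatenation is (i, j, k) -> (i, k, k, j, i, j). Inverting this
    // permutation lets every loop bound be recovered from one of the operand
    // view sizes. This is an illustrative sketch of the convention, not a
    // statement about the exact AffineMap returned by inversePermutation.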
    auto linalgOp = cast<ConcreteOp>(op);
    auto invertedMap =
        inversePermutation(concatAffineMaps(loopToOperandRangesMaps(linalgOp)));
    if (!invertedMap) {
      LinalgScopedEmitter<ConcreteOp>::emitScalarImplementation({}, linalgOp,
                                                                folder);
      rewriter.replaceOp(op, {});
      return matchSuccess();
    }

    auto nPar = linalgOp.getNumParallelLoops();
    auto nRed = linalgOp.getNumReductionLoops();
    auto nWin = linalgOp.getNumWindowLoops();
    SmallVector<IndexHandle, 4> allIvs(nPar + nRed + nWin);
    SmallVector<ValueHandle *, 4> allPIvs = makeIndexHandlePointers(allIvs);
    auto pivs = MutableArrayRef<ValueHandle *>(allPIvs).take_front(nPar);
    auto rivs = MutableArrayRef<ValueHandle *>(allPIvs)
                    .take_front(nPar + nRed)
                    .take_back(nRed);
    auto wivs = MutableArrayRef<ValueHandle *>(allPIvs).take_back(nWin);

    auto loopRanges =
        emitLoopRanges(scope.getBuilder(), scope.getLocation(), invertedMap,
                       getViewSizes(linalgOp), folder);
    assert(loopRanges.size() == pivs.size() + rivs.size() + wivs.size());

    // clang-format off
    ArrayRef<Value *> ranges(loopRanges);
    LoopNestRangeBuilder(pivs, ranges.take_front(nPar))([&] {
      LoopNestRangeBuilder(rivs, ranges.drop_back(nWin).take_back(nRed))([&] {
        LoopNestRangeBuilder(wivs, ranges.take_back(wivs.size()))(
          [&linalgOp, &allIvs, this] {
            auto allIvValues = extractValues(allIvs);
            LinalgScopedEmitter<ConcreteOp>::emitScalarImplementation(
                allIvValues, linalgOp, folder);
        });
      });
    });
    // clang-format on
    rewriter.replaceOp(op, {});
    return matchSuccess();
  }

  mutable OperationFolder folder;
};

// Helper classes for type list expansion.
template <typename... LinalgOps> class ConversionList;

template <> class ConversionList<> {
public:
  static void build(OwningRewritePatternList &patterns, MLIRContext *ctx) {}
};

template <typename ConcreteOp, typename... LinalgOps>
class ConversionList<ConcreteOp, LinalgOps...> {
public:
  static void build(OwningRewritePatternList &patterns, MLIRContext *ctx) {
    patterns.insert<LinalgRewritePattern<ConcreteOp>>(ctx);
    ConversionList<LinalgOps...>::build(patterns, ctx);
  }
};

/// Populate the given list with patterns that convert from Linalg to loops.
static void
populateLinalgToLoopRewritePatterns(OwningRewritePatternList &patterns,
                                    MLIRContext *ctx) {
  ConversionList<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgLibraryOps.cpp.inc"
      >::build(patterns, ctx);
}

namespace {
struct LowerLinalgToLoopsPass : public FunctionPass<LowerLinalgToLoopsPass> {
  void runOnFunction();
};
} // namespace

void LowerLinalgToLoopsPass::runOnFunction() {
  OwningRewritePatternList patterns;
  populateLinalgToLoopRewritePatterns(patterns, &getContext());

  ConversionTarget target(getContext());
  target.addLegalDialect<AffineOpsDialect>();
  target.addLegalDialect<loop::LoopOpsDialect>();
  target.addLegalDialect<StandardOpsDialect>();
  if (failed(applyPartialConversion(getFunction(), target, patterns))) {
    signalPassFailure();
  }
}

std::unique_ptr<FunctionPassBase> mlir::linalg::createLowerLinalgToLoopsPass() {
  return std::make_unique<LowerLinalgToLoopsPass>();
}

static PassRegistration<LowerLinalgToLoopsPass>
    pass("linalg-lower-to-loops",
         "Lower the operations from the linalg dialect into loops");
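
// The PassRegistration above exposes this pass under the
// -linalg-lower-to-loops flag, so a typical invocation (assuming an mlir-opt
// binary built from this tree) would be:
//   mlir-opt -linalg-lower-to-loops input.mlir
// Programmatically, the factory can be added to a pass manager, e.g. (sketch;
// the exact PassManager setup depends on this MLIR snapshot):
//   pm.addPass(mlir::linalg::createLowerLinalgToLoopsPass());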