github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Transforms/LowerAffine.cpp (about) 1 //===- LowerAffine.cpp - Lower affine constructs to primitives ------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 // 18 // This file lowers affine constructs (If and For statements, AffineApply 19 // operations) within a function into their standard If and For equivalent ops. 20 // 21 //===----------------------------------------------------------------------===// 22 23 #include "mlir/Transforms/LowerAffine.h" 24 #include "mlir/Dialect/AffineOps/AffineOps.h" 25 #include "mlir/Dialect/LoopOps/LoopOps.h" 26 #include "mlir/Dialect/StandardOps/Ops.h" 27 #include "mlir/IR/AffineExprVisitor.h" 28 #include "mlir/IR/BlockAndValueMapping.h" 29 #include "mlir/IR/Builders.h" 30 #include "mlir/IR/IntegerSet.h" 31 #include "mlir/IR/MLIRContext.h" 32 #include "mlir/Pass/Pass.h" 33 #include "mlir/Support/Functional.h" 34 #include "mlir/Transforms/DialectConversion.h" 35 #include "mlir/Transforms/Passes.h" 36 37 using namespace mlir; 38 39 namespace { 40 // Visit affine expressions recursively and build the sequence of operations 41 // that correspond to it. Visitation functions return an Value of the 42 // expression subtree they visited or `nullptr` on error. 43 class AffineApplyExpander 44 : public AffineExprVisitor<AffineApplyExpander, Value *> { 45 public: 46 // This internal class expects arguments to be non-null, checks must be 47 // performed at the call site. 48 AffineApplyExpander(OpBuilder &builder, ArrayRef<Value *> dimValues, 49 ArrayRef<Value *> symbolValues, Location loc) 50 : builder(builder), dimValues(dimValues), symbolValues(symbolValues), 51 loc(loc) {} 52 53 template <typename OpTy> Value *buildBinaryExpr(AffineBinaryOpExpr expr) { 54 auto lhs = visit(expr.getLHS()); 55 auto rhs = visit(expr.getRHS()); 56 if (!lhs || !rhs) 57 return nullptr; 58 auto op = builder.create<OpTy>(loc, lhs, rhs); 59 return op.getResult(); 60 } 61 62 Value *visitAddExpr(AffineBinaryOpExpr expr) { 63 return buildBinaryExpr<AddIOp>(expr); 64 } 65 66 Value *visitMulExpr(AffineBinaryOpExpr expr) { 67 return buildBinaryExpr<MulIOp>(expr); 68 } 69 70 // Euclidean modulo operation: negative RHS is not allowed. 71 // Remainder of the euclidean integer division is always non-negative. 72 // 73 // Implemented as 74 // 75 // a mod b = 76 // let remainder = srem a, b; 77 // negative = a < 0 in 78 // select negative, remainder + b, remainder. 79 Value *visitModExpr(AffineBinaryOpExpr expr) { 80 auto rhsConst = expr.getRHS().dyn_cast<AffineConstantExpr>(); 81 if (!rhsConst) { 82 emitError( 83 loc, 84 "semi-affine expressions (modulo by non-const) are not supported"); 85 return nullptr; 86 } 87 if (rhsConst.getValue() <= 0) { 88 emitError(loc, "modulo by non-positive value is not supported"); 89 return nullptr; 90 } 91 92 auto lhs = visit(expr.getLHS()); 93 auto rhs = visit(expr.getRHS()); 94 assert(lhs && rhs && "unexpected affine expr lowering failure"); 95 96 Value *remainder = builder.create<RemISOp>(loc, lhs, rhs); 97 Value *zeroCst = builder.create<ConstantIndexOp>(loc, 0); 98 Value *isRemainderNegative = 99 builder.create<CmpIOp>(loc, CmpIPredicate::SLT, remainder, zeroCst); 100 Value *correctedRemainder = builder.create<AddIOp>(loc, remainder, rhs); 101 Value *result = builder.create<SelectOp>(loc, isRemainderNegative, 102 correctedRemainder, remainder); 103 return result; 104 } 105 106 // Floor division operation (rounds towards negative infinity). 107 // 108 // For positive divisors, it can be implemented without branching and with a 109 // single division operation as 110 // 111 // a floordiv b = 112 // let negative = a < 0 in 113 // let absolute = negative ? -a - 1 : a in 114 // let quotient = absolute / b in 115 // negative ? -quotient - 1 : quotient 116 Value *visitFloorDivExpr(AffineBinaryOpExpr expr) { 117 auto rhsConst = expr.getRHS().dyn_cast<AffineConstantExpr>(); 118 if (!rhsConst) { 119 emitError( 120 loc, 121 "semi-affine expressions (division by non-const) are not supported"); 122 return nullptr; 123 } 124 if (rhsConst.getValue() <= 0) { 125 emitError(loc, "division by non-positive value is not supported"); 126 return nullptr; 127 } 128 129 auto lhs = visit(expr.getLHS()); 130 auto rhs = visit(expr.getRHS()); 131 assert(lhs && rhs && "unexpected affine expr lowering failure"); 132 133 Value *zeroCst = builder.create<ConstantIndexOp>(loc, 0); 134 Value *noneCst = builder.create<ConstantIndexOp>(loc, -1); 135 Value *negative = 136 builder.create<CmpIOp>(loc, CmpIPredicate::SLT, lhs, zeroCst); 137 Value *negatedDecremented = builder.create<SubIOp>(loc, noneCst, lhs); 138 Value *dividend = 139 builder.create<SelectOp>(loc, negative, negatedDecremented, lhs); 140 Value *quotient = builder.create<DivISOp>(loc, dividend, rhs); 141 Value *correctedQuotient = builder.create<SubIOp>(loc, noneCst, quotient); 142 Value *result = 143 builder.create<SelectOp>(loc, negative, correctedQuotient, quotient); 144 return result; 145 } 146 147 // Ceiling division operation (rounds towards positive infinity). 148 // 149 // For positive divisors, it can be implemented without branching and with a 150 // single division operation as 151 // 152 // a ceildiv b = 153 // let negative = a <= 0 in 154 // let absolute = negative ? -a : a - 1 in 155 // let quotient = absolute / b in 156 // negative ? -quotient : quotient + 1 157 Value *visitCeilDivExpr(AffineBinaryOpExpr expr) { 158 auto rhsConst = expr.getRHS().dyn_cast<AffineConstantExpr>(); 159 if (!rhsConst) { 160 emitError(loc) << "semi-affine expressions (division by non-const) are " 161 "not supported"; 162 return nullptr; 163 } 164 if (rhsConst.getValue() <= 0) { 165 emitError(loc, "division by non-positive value is not supported"); 166 return nullptr; 167 } 168 auto lhs = visit(expr.getLHS()); 169 auto rhs = visit(expr.getRHS()); 170 assert(lhs && rhs && "unexpected affine expr lowering failure"); 171 172 Value *zeroCst = builder.create<ConstantIndexOp>(loc, 0); 173 Value *oneCst = builder.create<ConstantIndexOp>(loc, 1); 174 Value *nonPositive = 175 builder.create<CmpIOp>(loc, CmpIPredicate::SLE, lhs, zeroCst); 176 Value *negated = builder.create<SubIOp>(loc, zeroCst, lhs); 177 Value *decremented = builder.create<SubIOp>(loc, lhs, oneCst); 178 Value *dividend = 179 builder.create<SelectOp>(loc, nonPositive, negated, decremented); 180 Value *quotient = builder.create<DivISOp>(loc, dividend, rhs); 181 Value *negatedQuotient = builder.create<SubIOp>(loc, zeroCst, quotient); 182 Value *incrementedQuotient = builder.create<AddIOp>(loc, quotient, oneCst); 183 Value *result = builder.create<SelectOp>(loc, nonPositive, negatedQuotient, 184 incrementedQuotient); 185 return result; 186 } 187 188 Value *visitConstantExpr(AffineConstantExpr expr) { 189 auto valueAttr = 190 builder.getIntegerAttr(builder.getIndexType(), expr.getValue()); 191 auto op = 192 builder.create<ConstantOp>(loc, builder.getIndexType(), valueAttr); 193 return op.getResult(); 194 } 195 196 Value *visitDimExpr(AffineDimExpr expr) { 197 assert(expr.getPosition() < dimValues.size() && 198 "affine dim position out of range"); 199 return dimValues[expr.getPosition()]; 200 } 201 202 Value *visitSymbolExpr(AffineSymbolExpr expr) { 203 assert(expr.getPosition() < symbolValues.size() && 204 "symbol dim position out of range"); 205 return symbolValues[expr.getPosition()]; 206 } 207 208 private: 209 OpBuilder &builder; 210 ArrayRef<Value *> dimValues; 211 ArrayRef<Value *> symbolValues; 212 213 Location loc; 214 }; 215 } // namespace 216 217 // Create a sequence of operations that implement the `expr` applied to the 218 // given dimension and symbol values. 219 mlir::Value *mlir::expandAffineExpr(OpBuilder &builder, Location loc, 220 AffineExpr expr, 221 ArrayRef<Value *> dimValues, 222 ArrayRef<Value *> symbolValues) { 223 return AffineApplyExpander(builder, dimValues, symbolValues, loc).visit(expr); 224 } 225 226 // Create a sequence of operations that implement the `affineMap` applied to 227 // the given `operands` (as it it were an AffineApplyOp). 228 Optional<SmallVector<Value *, 8>> static expandAffineMap( 229 OpBuilder &builder, Location loc, AffineMap affineMap, 230 ArrayRef<Value *> operands) { 231 auto numDims = affineMap.getNumDims(); 232 auto expanded = functional::map( 233 [numDims, &builder, loc, operands](AffineExpr expr) { 234 return expandAffineExpr(builder, loc, expr, 235 operands.take_front(numDims), 236 operands.drop_front(numDims)); 237 }, 238 affineMap.getResults()); 239 if (llvm::all_of(expanded, [](Value *v) { return v; })) 240 return expanded; 241 return None; 242 } 243 244 // Given a range of values, emit the code that reduces them with "min" or "max" 245 // depending on the provided comparison predicate. The predicate defines which 246 // comparison to perform, "lt" for "min", "gt" for "max" and is used for the 247 // `cmpi` operation followed by the `select` operation: 248 // 249 // %cond = cmpi "predicate" %v0, %v1 250 // %result = select %cond, %v0, %v1 251 // 252 // Multiple values are scanned in a linear sequence. This creates a data 253 // dependences that wouldn't exist in a tree reduction, but is easier to 254 // recognize as a reduction by the subsequent passes. 255 static Value *buildMinMaxReductionSeq(Location loc, CmpIPredicate predicate, 256 ArrayRef<Value *> values, 257 OpBuilder &builder) { 258 assert(!llvm::empty(values) && "empty min/max chain"); 259 260 auto valueIt = values.begin(); 261 Value *value = *valueIt++; 262 for (; valueIt != values.end(); ++valueIt) { 263 auto cmpOp = builder.create<CmpIOp>(loc, predicate, value, *valueIt); 264 value = builder.create<SelectOp>(loc, cmpOp.getResult(), value, *valueIt); 265 } 266 267 return value; 268 } 269 270 // Emit instructions that correspond to the affine map in the lower bound 271 // applied to the respective operands, and compute the maximum value across 272 // the results. 273 Value *mlir::lowerAffineLowerBound(AffineForOp op, OpBuilder &builder) { 274 SmallVector<Value *, 8> boundOperands(op.getLowerBoundOperands()); 275 auto lbValues = expandAffineMap(builder, op.getLoc(), op.getLowerBoundMap(), 276 boundOperands); 277 if (!lbValues) 278 return nullptr; 279 return buildMinMaxReductionSeq(op.getLoc(), CmpIPredicate::SGT, *lbValues, 280 builder); 281 } 282 283 // Emit instructions that correspond to the affine map in the upper bound 284 // applied to the respective operands, and compute the minimum value across 285 // the results. 286 Value *mlir::lowerAffineUpperBound(AffineForOp op, OpBuilder &builder) { 287 SmallVector<Value *, 8> boundOperands(op.getUpperBoundOperands()); 288 auto ubValues = expandAffineMap(builder, op.getLoc(), op.getUpperBoundMap(), 289 boundOperands); 290 if (!ubValues) 291 return nullptr; 292 return buildMinMaxReductionSeq(op.getLoc(), CmpIPredicate::SLT, *ubValues, 293 builder); 294 } 295 296 namespace { 297 // Affine terminators are removed. 298 class AffineTerminatorLowering : public OpRewritePattern<AffineTerminatorOp> { 299 public: 300 using OpRewritePattern<AffineTerminatorOp>::OpRewritePattern; 301 302 PatternMatchResult matchAndRewrite(AffineTerminatorOp op, 303 PatternRewriter &rewriter) const override { 304 rewriter.replaceOpWithNewOp<loop::TerminatorOp>(op); 305 return matchSuccess(); 306 } 307 }; 308 309 class AffineForLowering : public OpRewritePattern<AffineForOp> { 310 public: 311 using OpRewritePattern<AffineForOp>::OpRewritePattern; 312 313 PatternMatchResult matchAndRewrite(AffineForOp op, 314 PatternRewriter &rewriter) const override { 315 Location loc = op.getLoc(); 316 Value *lowerBound = lowerAffineLowerBound(op, rewriter); 317 Value *upperBound = lowerAffineUpperBound(op, rewriter); 318 Value *step = rewriter.create<ConstantIndexOp>(loc, op.getStep()); 319 auto f = rewriter.create<loop::ForOp>(loc, lowerBound, upperBound, step); 320 f.region().getBlocks().clear(); 321 rewriter.inlineRegionBefore(op.region(), f.region(), f.region().end()); 322 rewriter.replaceOp(op, {}); 323 return matchSuccess(); 324 } 325 }; 326 327 class AffineIfLowering : public OpRewritePattern<AffineIfOp> { 328 public: 329 using OpRewritePattern<AffineIfOp>::OpRewritePattern; 330 331 PatternMatchResult matchAndRewrite(AffineIfOp op, 332 PatternRewriter &rewriter) const override { 333 auto loc = op.getLoc(); 334 335 // Now we just have to handle the condition logic. 336 auto integerSet = op.getIntegerSet(); 337 Value *zeroConstant = rewriter.create<ConstantIndexOp>(loc, 0); 338 SmallVector<Value *, 8> operands(op.getOperation()->getOperands()); 339 auto operandsRef = llvm::makeArrayRef(operands); 340 341 // Calculate cond as a conjunction without short-circuiting. 342 Value *cond = nullptr; 343 for (unsigned i = 0, e = integerSet.getNumConstraints(); i < e; ++i) { 344 AffineExpr constraintExpr = integerSet.getConstraint(i); 345 bool isEquality = integerSet.isEq(i); 346 347 // Build and apply an affine expression 348 auto numDims = integerSet.getNumDims(); 349 Value *affResult = expandAffineExpr(rewriter, loc, constraintExpr, 350 operandsRef.take_front(numDims), 351 operandsRef.drop_front(numDims)); 352 if (!affResult) 353 return matchFailure(); 354 auto pred = isEquality ? CmpIPredicate::EQ : CmpIPredicate::SGE; 355 Value *cmpVal = 356 rewriter.create<CmpIOp>(loc, pred, affResult, zeroConstant); 357 cond = 358 cond ? rewriter.create<AndOp>(loc, cond, cmpVal).getResult() : cmpVal; 359 } 360 cond = cond ? cond 361 : rewriter.create<ConstantIntOp>(loc, /*value=*/1, /*width=*/1); 362 363 bool hasElseRegion = !op.elseRegion().empty(); 364 auto ifOp = rewriter.create<loop::IfOp>(loc, cond, hasElseRegion); 365 rewriter.inlineRegionBefore(op.thenRegion(), &ifOp.thenRegion().back()); 366 ifOp.thenRegion().back().erase(); 367 if (hasElseRegion) { 368 rewriter.inlineRegionBefore(op.elseRegion(), &ifOp.elseRegion().back()); 369 ifOp.elseRegion().back().erase(); 370 } 371 372 // Ok, we're done! 373 rewriter.replaceOp(op, {}); 374 return matchSuccess(); 375 } 376 }; 377 378 // Convert an "affine.apply" operation into a sequence of arithmetic 379 // operations using the StandardOps dialect. 380 class AffineApplyLowering : public OpRewritePattern<AffineApplyOp> { 381 public: 382 using OpRewritePattern<AffineApplyOp>::OpRewritePattern; 383 384 virtual PatternMatchResult 385 matchAndRewrite(AffineApplyOp op, PatternRewriter &rewriter) const override { 386 auto maybeExpandedMap = 387 expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), 388 llvm::to_vector<8>(op.getOperands())); 389 if (!maybeExpandedMap) 390 return matchFailure(); 391 rewriter.replaceOp(op, *maybeExpandedMap); 392 return matchSuccess(); 393 } 394 }; 395 396 // Apply the affine map from an 'affine.load' operation to its operands, and 397 // feed the results to a newly created 'std.load' operation (which replaces the 398 // original 'affine.load'). 399 class AffineLoadLowering : public OpRewritePattern<AffineLoadOp> { 400 public: 401 using OpRewritePattern<AffineLoadOp>::OpRewritePattern; 402 403 virtual PatternMatchResult 404 matchAndRewrite(AffineLoadOp op, PatternRewriter &rewriter) const override { 405 // Expand affine map from 'affineLoadOp'. 406 SmallVector<Value *, 8> indices(op.getIndices()); 407 auto maybeExpandedMap = 408 expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices); 409 if (!maybeExpandedMap) 410 return matchFailure(); 411 412 // Build std.load memref[expandedMap.results]. 413 rewriter.replaceOpWithNewOp<LoadOp>(op, op.getMemRef(), *maybeExpandedMap); 414 return matchSuccess(); 415 } 416 }; 417 418 // Apply the affine map from an 'affine.store' operation to its operands, and 419 // feed the results to a newly created 'std.store' operation (which replaces the 420 // original 'affine.store'). 421 class AffineStoreLowering : public OpRewritePattern<AffineStoreOp> { 422 public: 423 using OpRewritePattern<AffineStoreOp>::OpRewritePattern; 424 425 virtual PatternMatchResult 426 matchAndRewrite(AffineStoreOp op, PatternRewriter &rewriter) const override { 427 // Expand affine map from 'affineStoreOp'. 428 SmallVector<Value *, 8> indices(op.getIndices()); 429 auto maybeExpandedMap = 430 expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices); 431 if (!maybeExpandedMap) 432 return matchFailure(); 433 434 // Build std.store valutToStore, memref[expandedMap.results]. 435 rewriter.replaceOpWithNewOp<StoreOp>(op, op.getValueToStore(), 436 op.getMemRef(), *maybeExpandedMap); 437 return matchSuccess(); 438 } 439 }; 440 441 // Apply the affine maps from an 'affine.dma_start' operation to each of their 442 // respective map operands, and feed the results to a newly created 443 // 'std.dma_start' operation (which replaces the original 'affine.dma_start'). 444 class AffineDmaStartLowering : public OpRewritePattern<AffineDmaStartOp> { 445 public: 446 using OpRewritePattern<AffineDmaStartOp>::OpRewritePattern; 447 448 virtual PatternMatchResult 449 matchAndRewrite(AffineDmaStartOp op, 450 PatternRewriter &rewriter) const override { 451 SmallVector<Value *, 8> operands(op.getOperands()); 452 auto operandsRef = llvm::makeArrayRef(operands); 453 454 // Expand affine map for DMA source memref. 455 auto maybeExpandedSrcMap = expandAffineMap( 456 rewriter, op.getLoc(), op.getSrcMap(), 457 operandsRef.drop_front(op.getSrcMemRefOperandIndex() + 1)); 458 if (!maybeExpandedSrcMap) 459 return matchFailure(); 460 // Expand affine map for DMA destination memref. 461 auto maybeExpandedDstMap = expandAffineMap( 462 rewriter, op.getLoc(), op.getDstMap(), 463 operandsRef.drop_front(op.getDstMemRefOperandIndex() + 1)); 464 if (!maybeExpandedDstMap) 465 return matchFailure(); 466 // Expand affine map for DMA tag memref. 467 auto maybeExpandedTagMap = expandAffineMap( 468 rewriter, op.getLoc(), op.getTagMap(), 469 operandsRef.drop_front(op.getTagMemRefOperandIndex() + 1)); 470 if (!maybeExpandedTagMap) 471 return matchFailure(); 472 473 // Build std.dma_start operation with affine map results. 474 rewriter.replaceOpWithNewOp<DmaStartOp>( 475 op, op.getSrcMemRef(), *maybeExpandedSrcMap, op.getDstMemRef(), 476 *maybeExpandedDstMap, op.getNumElements(), op.getTagMemRef(), 477 *maybeExpandedTagMap, op.getStride(), op.getNumElementsPerStride()); 478 return matchSuccess(); 479 } 480 }; 481 482 // Apply the affine map from an 'affine.dma_wait' operation tag memref, 483 // and feed the results to a newly created 'std.dma_wait' operation (which 484 // replaces the original 'affine.dma_wait'). 485 class AffineDmaWaitLowering : public OpRewritePattern<AffineDmaWaitOp> { 486 public: 487 using OpRewritePattern<AffineDmaWaitOp>::OpRewritePattern; 488 489 virtual PatternMatchResult 490 matchAndRewrite(AffineDmaWaitOp op, 491 PatternRewriter &rewriter) const override { 492 // Expand affine map for DMA tag memref. 493 SmallVector<Value *, 8> indices(op.getTagIndices()); 494 auto maybeExpandedTagMap = 495 expandAffineMap(rewriter, op.getLoc(), op.getTagMap(), indices); 496 if (!maybeExpandedTagMap) 497 return matchFailure(); 498 499 // Build std.dma_wait operation with affine map results. 500 rewriter.replaceOpWithNewOp<DmaWaitOp>( 501 op, op.getTagMemRef(), *maybeExpandedTagMap, op.getNumElements()); 502 return matchSuccess(); 503 } 504 }; 505 506 } // end namespace 507 508 void mlir::populateAffineToStdConversionPatterns( 509 OwningRewritePatternList &patterns, MLIRContext *ctx) { 510 patterns 511 .insert<AffineApplyLowering, AffineDmaStartLowering, 512 AffineDmaWaitLowering, AffineLoadLowering, AffineStoreLowering, 513 AffineForLowering, AffineIfLowering, AffineTerminatorLowering>( 514 ctx); 515 } 516 517 namespace { 518 class LowerAffinePass : public FunctionPass<LowerAffinePass> { 519 void runOnFunction() override { 520 OwningRewritePatternList patterns; 521 populateAffineToStdConversionPatterns(patterns, &getContext()); 522 ConversionTarget target(getContext()); 523 target.addLegalDialect<loop::LoopOpsDialect, StandardOpsDialect>(); 524 if (failed(applyPartialConversion(getFunction(), target, patterns))) 525 signalPassFailure(); 526 } 527 }; 528 } // namespace 529 530 /// Lowers If and For operations within a function into their lower level CFG 531 /// equivalent blocks. 532 std::unique_ptr<FunctionPassBase> mlir::createLowerAffinePass() { 533 return std::make_unique<LowerAffinePass>(); 534 } 535 536 static PassRegistration<LowerAffinePass> 537 pass("lower-affine", 538 "Lower If, For, AffineApply operations to primitive equivalents");