github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.cpp (about)

     1  //===- ConvertControlFlowToCFG.cpp - ControlFlow to CFG conversion --------===//
     2  //
     3  // Copyright 2019 The MLIR Authors.
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //   http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  // =============================================================================
    17  //
    18  // This file implements a pass to convert loop.for, loop.if and loop.terminator
    19  // ops into standard CFG ops.
    20  //
    21  //===----------------------------------------------------------------------===//
    22  
    23  #include "mlir/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.h"
    24  #include "mlir/Dialect/LoopOps/LoopOps.h"
    25  #include "mlir/Dialect/StandardOps/Ops.h"
    26  #include "mlir/IR/Builders.h"
    27  #include "mlir/IR/MLIRContext.h"
    28  #include "mlir/IR/Module.h"
    29  #include "mlir/IR/PatternMatch.h"
    30  #include "mlir/Pass/Pass.h"
    31  #include "mlir/Support/Functional.h"
    32  #include "mlir/Transforms/DialectConversion.h"
    33  #include "mlir/Transforms/Passes.h"
    34  #include "mlir/Transforms/Utils.h"
    35  
    36  using namespace mlir;
    37  using namespace mlir::loop;
    38  
    39  namespace {
    40  
    41  struct ControlFlowToCFGPass : public FunctionPass<ControlFlowToCFGPass> {
    42    void runOnFunction() override;
    43  };
    44  
    45  // Create a CFG subgraph for the loop around its body blocks (if the body
    46  // contained other loops, they have been already lowered to a flow of blocks).
    47  // Maintain the invariants that a CFG subgraph created for any loop has a single
    48  // entry and a single exit, and that the entry/exit blocks are respectively
    49  // first/last blocks in the parent region.  The original loop operation is
    50  // replaced by the initialization operations that set up the initial value of
    51  // the loop induction variable (%iv) and computes the loop bounds that are loop-
    52  // invariant for affine loops.  The operations following the original loop.for
    53  // are split out into a separate continuation (exit) block. A condition block is
    54  // created before the continuation block. It checks the exit condition of the
    55  // loop and branches either to the continuation block, or to the first block of
    56  // the body. Induction variable modification is appended to the last block of
    57  // the body (which is the exit block from the body subgraph thanks to the
    58  // invariant we maintain) along with a branch that loops back to the condition
    59  // block.
    60  //
    61  //      +---------------------------------+
    62  //      |   <code before the ForOp>       |
    63  //      |   <compute initial %iv value>   |
    64  //      |   br cond(%iv)                  |
    65  //      +---------------------------------+
    66  //             |
    67  //  -------|   |
    68  //  |      v   v
    69  //  |   +--------------------------------+
    70  //  |   | cond(%iv):                     |
    71  //  |   |   <compare %iv to upper bound> |
    72  //  |   |   cond_br %r, body, end        |
    73  //  |   +--------------------------------+
    74  //  |          |               |
    75  //  |          |               -------------|
    76  //  |          v                            |
    77  //  |   +--------------------------------+  |
    78  //  |   | body-first:                    |  |
    79  //  |   |   <body contents>              |  |
    80  //  |   +--------------------------------+  |
    81  //  |                   |                   |
    82  //  |                  ...                  |
    83  //  |                   |                   |
    84  //  |   +--------------------------------+  |
    85  //  |   | body-last:                     |  |
    86  //  |   |   <body contents>              |  |
    87  //  |   |   %new_iv =<add step to %iv>   |  |
    88  //  |   |   br cond(%new_iv)             |  |
    89  //  |   +--------------------------------+  |
    90  //  |          |                            |
    91  //  |-----------        |--------------------
    92  //                      v
    93  //      +--------------------------------+
    94  //      | end:                           |
    95  //      |   <code after the ForOp> |
    96  //      +--------------------------------+
    97  //
    98  struct ForLowering : public OpRewritePattern<ForOp> {
    99    using OpRewritePattern<ForOp>::OpRewritePattern;
   100  
   101    PatternMatchResult matchAndRewrite(ForOp forOp,
   102                                       PatternRewriter &rewriter) const override;
   103  };
   104  
   105  // Create a CFG subgraph for the loop.if operation (including its "then" and
   106  // optional "else" operation blocks).  We maintain the invariants that the
   107  // subgraph has a single entry and a single exit point, and that the entry/exit
   108  // blocks are respectively the first/last block of the enclosing region. The
   109  // operations following the loop.if are split into a continuation (subgraph
   110  // exit) block. The condition is lowered to a chain of blocks that implement the
   111  // short-circuit scheme.  Condition blocks are created by splitting out an empty
   112  // block from the block that contains the loop.if operation.  They
   113  // conditionally branch to either the first block of the "then" region, or to
   114  // the first block of the "else" region.  If the latter is absent, they branch
   115  // to the continuation block instead.  The last blocks of "then" and "else"
   116  // regions (which are known to be exit blocks thanks to the invariant we
   117  // maintain).
   118  //
   119  //      +--------------------------------+
   120  //      | <code before the IfOp>         |
   121  //      | cond_br %cond, %then, %else    |
   122  //      +--------------------------------+
   123  //             |              |
   124  //             |              --------------|
   125  //             v                            |
   126  //      +--------------------------------+  |
   127  //      | then:                          |  |
   128  //      |   <then contents>              |  |
   129  //      |   br continue                  |  |
   130  //      +--------------------------------+  |
   131  //             |                            |
   132  //   |----------               |-------------
   133  //   |                         V
   134  //   |  +--------------------------------+
   135  //   |  | else:                          |
   136  //   |  |   <else contents>              |
   137  //   |  |   br continue                  |
   138  //   |  +--------------------------------+
   139  //   |         |
   140  //   ------|   |
   141  //         v   v
   142  //      +--------------------------------+
   143  //      | continue:                      |
   144  //      |   <code after the IfOp>  |
   145  //      +--------------------------------+
   146  //
   147  struct IfLowering : public OpRewritePattern<IfOp> {
   148    using OpRewritePattern<IfOp>::OpRewritePattern;
   149  
   150    PatternMatchResult matchAndRewrite(IfOp ifOp,
   151                                       PatternRewriter &rewriter) const override;
   152  };
   153  
   154  struct TerminatorLowering : public OpRewritePattern<TerminatorOp> {
   155    using OpRewritePattern<TerminatorOp>::OpRewritePattern;
   156  
   157    PatternMatchResult matchAndRewrite(TerminatorOp op,
   158                                       PatternRewriter &rewriter) const override {
   159      rewriter.replaceOp(op, {});
   160      return matchSuccess();
   161    }
   162  };
   163  } // namespace
   164  
   165  PatternMatchResult
   166  ForLowering::matchAndRewrite(ForOp forOp, PatternRewriter &rewriter) const {
   167    Location loc = forOp.getLoc();
   168  
   169    // Start by splitting the block containing the 'loop.for' into two parts.
   170    // The part before will get the init code, the part after will be the end
   171    // point.
   172    auto *initBlock = rewriter.getInsertionBlock();
   173    auto initPosition = rewriter.getInsertionPoint();
   174    auto *endBlock = rewriter.splitBlock(initBlock, initPosition);
   175  
   176    // Use the first block of the loop body as the condition block since it is
   177    // the block that has the induction variable as its argument.  Split out
   178    // all operations from the first block into a new block.  Move all body
   179    // blocks from the loop body region to the region containing the loop.
   180    auto *conditionBlock = &forOp.region().front();
   181    auto *firstBodyBlock =
   182        rewriter.splitBlock(conditionBlock, conditionBlock->begin());
   183    auto *lastBodyBlock = &forOp.region().back();
   184    rewriter.inlineRegionBefore(forOp.region(), endBlock);
   185    auto *iv = conditionBlock->getArgument(0);
   186  
   187    // Append the induction variable stepping logic to the last body block and
   188    // branch back to the condition block.  Construct an expression f :
   189    // (x -> x+step) and apply this expression to the induction variable.
   190    rewriter.setInsertionPointToEnd(lastBodyBlock);
   191    auto *step = forOp.step();
   192    auto *stepped = rewriter.create<AddIOp>(loc, iv, step).getResult();
   193    if (!stepped)
   194      return matchFailure();
   195    rewriter.create<BranchOp>(loc, conditionBlock, stepped);
   196  
   197    // Compute loop bounds before branching to the condition.
   198    rewriter.setInsertionPointToEnd(initBlock);
   199    Value *lowerBound = forOp.lowerBound();
   200    Value *upperBound = forOp.upperBound();
   201    if (!lowerBound || !upperBound)
   202      return matchFailure();
   203    rewriter.create<BranchOp>(loc, conditionBlock, lowerBound);
   204  
   205    // With the body block done, we can fill in the condition block.
   206    rewriter.setInsertionPointToEnd(conditionBlock);
   207    auto comparison =
   208        rewriter.create<CmpIOp>(loc, CmpIPredicate::SLT, iv, upperBound);
   209  
   210    rewriter.create<CondBranchOp>(loc, comparison, firstBodyBlock,
   211                                  ArrayRef<Value *>(), endBlock,
   212                                  ArrayRef<Value *>());
   213    // Ok, we're done!
   214    rewriter.replaceOp(forOp, {});
   215    return matchSuccess();
   216  }
   217  
   218  PatternMatchResult
   219  IfLowering::matchAndRewrite(IfOp ifOp, PatternRewriter &rewriter) const {
   220    auto loc = ifOp.getLoc();
   221  
   222    // Start by splitting the block containing the 'loop.if' into two parts.
   223    // The part before will contain the condition, the part after will be the
   224    // continuation point.
   225    auto *condBlock = rewriter.getInsertionBlock();
   226    auto opPosition = rewriter.getInsertionPoint();
   227    auto *continueBlock = rewriter.splitBlock(condBlock, opPosition);
   228  
   229    // Move blocks from the "then" region to the region containing 'loop.if',
   230    // place it before the continuation block, and branch to it.
   231    auto &thenRegion = ifOp.thenRegion();
   232    auto *thenBlock = &thenRegion.front();
   233    rewriter.setInsertionPointToEnd(&thenRegion.back());
   234    rewriter.create<BranchOp>(loc, continueBlock);
   235    rewriter.inlineRegionBefore(thenRegion, continueBlock);
   236  
   237    // Move blocks from the "else" region (if present) to the region containing
   238    // 'loop.if', place it before the continuation block and branch to it.  It
   239    // will be placed after the "then" regions.
   240    auto *elseBlock = continueBlock;
   241    auto &elseRegion = ifOp.elseRegion();
   242    if (!elseRegion.empty()) {
   243      elseBlock = &elseRegion.front();
   244      rewriter.setInsertionPointToEnd(&elseRegion.back());
   245      rewriter.create<BranchOp>(loc, continueBlock);
   246      rewriter.inlineRegionBefore(elseRegion, continueBlock);
   247    }
   248  
   249    rewriter.setInsertionPointToEnd(condBlock);
   250    rewriter.create<CondBranchOp>(loc, ifOp.condition(), thenBlock,
   251                                  /*trueArgs=*/ArrayRef<Value *>(), elseBlock,
   252                                  /*falseArgs=*/ArrayRef<Value *>());
   253  
   254    // Ok, we're done!
   255    rewriter.replaceOp(ifOp, {});
   256    return matchSuccess();
   257  }
   258  
   259  void mlir::populateLoopToStdConversionPatterns(
   260      OwningRewritePatternList &patterns, MLIRContext *ctx) {
   261    patterns.insert<ForLowering, IfLowering, TerminatorLowering>(ctx);
   262  }
   263  
   264  void ControlFlowToCFGPass::runOnFunction() {
   265    OwningRewritePatternList patterns;
   266    populateLoopToStdConversionPatterns(patterns, &getContext());
   267    ConversionTarget target(getContext());
   268    target.addLegalDialect<StandardOpsDialect>();
   269    if (failed(applyPartialConversion(getFunction(), target, patterns)))
   270      signalPassFailure();
   271  }
   272  
   273  FunctionPassBase *mlir::createConvertToCFGPass() {
   274    return new ControlFlowToCFGPass();
   275  }
   276  
   277  static PassRegistration<ControlFlowToCFGPass>
   278      pass("lower-to-cfg", "Convert control flow operations to ");