github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Dialect/LoopOps/LoopOps.cpp (about)

     1  //===- Ops.cpp - Loop MLIR Operations -------------------------------------===//
     2  //
     3  // Copyright 2019 The MLIR Authors.
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //   http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  // =============================================================================
    17  
    18  #include "mlir/Dialect/LoopOps/LoopOps.h"
    19  #include "mlir/Dialect/StandardOps/Ops.h"
    20  #include "mlir/IR/AffineExpr.h"
    21  #include "mlir/IR/AffineMap.h"
    22  #include "mlir/IR/Builders.h"
    23  #include "mlir/IR/Function.h"
    24  #include "mlir/IR/Matchers.h"
    25  #include "mlir/IR/Module.h"
    26  #include "mlir/IR/OpImplementation.h"
    27  #include "mlir/IR/PatternMatch.h"
    28  #include "mlir/IR/StandardTypes.h"
    29  #include "mlir/IR/Value.h"
    30  #include "mlir/Support/MathExtras.h"
    31  #include "mlir/Support/STLExtras.h"
    32  
    33  using namespace mlir;
    34  using namespace mlir::loop;
    35  
    36  //===----------------------------------------------------------------------===//
    37  // LoopOpsDialect
    38  //===----------------------------------------------------------------------===//
    39  
    40  LoopOpsDialect::LoopOpsDialect(MLIRContext *context)
    41      : Dialect(getDialectNamespace(), context) {
    42    addOperations<
    43  #define GET_OP_LIST
    44  #include "mlir/Dialect/LoopOps/LoopOps.cpp.inc"
    45        >();
    46  }
    47  
    48  //===----------------------------------------------------------------------===//
    49  // ForOp
    50  //===----------------------------------------------------------------------===//
    51  
    52  void ForOp::build(Builder *builder, OperationState *result, Value *lb,
    53                    Value *ub, Value *step) {
    54    result->addOperands({lb, ub, step});
    55    Region *bodyRegion = result->addRegion();
    56    ForOp::ensureTerminator(*bodyRegion, *builder, result->location);
    57    bodyRegion->front().addArgument(builder->getIndexType());
    58  }
    59  
    60  LogicalResult verify(ForOp op) {
    61    if (auto cst = dyn_cast_or_null<ConstantIndexOp>(op.step()->getDefiningOp()))
    62      if (cst.getValue() <= 0)
    63        return op.emitOpError("constant step operand must be nonnegative");
    64  
    65    // Check that the body defines as single block argument for the induction
    66    // variable.
    67    auto *body = op.getBody();
    68    if (body->getNumArguments() != 1 ||
    69        !body->getArgument(0)->getType().isIndex())
    70      return op.emitOpError("expected body to have a single index argument for "
    71                            "the induction variable");
    72    return success();
    73  }
    74  
    75  static void print(OpAsmPrinter *p, ForOp op) {
    76    *p << op.getOperationName() << " " << *op.getInductionVar() << " = "
    77       << *op.lowerBound() << " to " << *op.upperBound() << " step "
    78       << *op.step();
    79    p->printRegion(op.region(),
    80                   /*printEntryBlockArgs=*/false,
    81                   /*printBlockTerminators=*/false);
    82    p->printOptionalAttrDict(op.getAttrs());
    83  }
    84  
    85  static ParseResult parseForOp(OpAsmParser *parser, OperationState *result) {
    86    auto &builder = parser->getBuilder();
    87    OpAsmParser::OperandType inductionVariable, lb, ub, step;
    88    // Parse the induction variable followed by '='.
    89    if (parser->parseRegionArgument(inductionVariable) || parser->parseEqual())
    90      return failure();
    91  
    92    // Parse loop bounds.
    93    Type indexType = builder.getIndexType();
    94    if (parser->parseOperand(lb) ||
    95        parser->resolveOperand(lb, indexType, result->operands) ||
    96        parser->parseKeyword("to") || parser->parseOperand(ub) ||
    97        parser->resolveOperand(ub, indexType, result->operands) ||
    98        parser->parseKeyword("step") || parser->parseOperand(step) ||
    99        parser->resolveOperand(step, indexType, result->operands))
   100      return failure();
   101  
   102    // Parse the body region.
   103    Region *body = result->addRegion();
   104    if (parser->parseRegion(*body, inductionVariable, indexType))
   105      return failure();
   106  
   107    ForOp::ensureTerminator(*body, builder, result->location);
   108  
   109    // Parse the optional attribute list.
   110    if (parser->parseOptionalAttributeDict(result->attributes))
   111      return failure();
   112  
   113    return success();
   114  }
   115  
   116  ForOp mlir::loop::getForInductionVarOwner(Value *val) {
   117    auto *ivArg = dyn_cast<BlockArgument>(val);
   118    if (!ivArg)
   119      return ForOp();
   120    assert(ivArg->getOwner() && "unlinked block argument");
   121    auto *containingInst = ivArg->getOwner()->getParentOp();
   122    return dyn_cast_or_null<ForOp>(containingInst);
   123  }
   124  
   125  //===----------------------------------------------------------------------===//
   126  // IfOp
   127  //===----------------------------------------------------------------------===//
   128  
   129  void IfOp::build(Builder *builder, OperationState *result, Value *cond,
   130                   bool withElseRegion) {
   131    result->addOperands(cond);
   132    Region *thenRegion = result->addRegion();
   133    Region *elseRegion = result->addRegion();
   134    IfOp::ensureTerminator(*thenRegion, *builder, result->location);
   135    if (withElseRegion)
   136      IfOp::ensureTerminator(*elseRegion, *builder, result->location);
   137  }
   138  
   139  static LogicalResult verify(IfOp op) {
   140    // Verify that the entry of each child region does not have arguments.
   141    for (auto &region : op.getOperation()->getRegions()) {
   142      if (region.empty())
   143        continue;
   144  
   145      for (auto &b : region)
   146        if (b.getNumArguments() != 0)
   147          return op.emitOpError(
   148              "requires that child entry blocks have no arguments");
   149    }
   150    return success();
   151  }
   152  
   153  static ParseResult parseIfOp(OpAsmParser *parser, OperationState *result) {
   154    // Create the regions for 'then'.
   155    result->regions.reserve(2);
   156    Region *thenRegion = result->addRegion();
   157    Region *elseRegion = result->addRegion();
   158  
   159    auto &builder = parser->getBuilder();
   160    OpAsmParser::OperandType cond;
   161    Type i1Type = builder.getIntegerType(1);
   162    if (parser->parseOperand(cond) ||
   163        parser->resolveOperand(cond, i1Type, result->operands))
   164      return failure();
   165  
   166    // Parse the 'then' region.
   167    if (parser->parseRegion(*thenRegion, {}, {}))
   168      return failure();
   169    IfOp::ensureTerminator(*thenRegion, parser->getBuilder(), result->location);
   170  
   171    // If we find an 'else' keyword then parse the 'else' region.
   172    if (!parser->parseOptionalKeyword("else")) {
   173      if (parser->parseRegion(*elseRegion, {}, {}))
   174        return failure();
   175      IfOp::ensureTerminator(*elseRegion, parser->getBuilder(), result->location);
   176    }
   177  
   178    // Parse the optional attribute list.
   179    if (parser->parseOptionalAttributeDict(result->attributes))
   180      return failure();
   181  
   182    return success();
   183  }
   184  
   185  static void print(OpAsmPrinter *p, IfOp op) {
   186    *p << IfOp::getOperationName() << " " << *op.condition();
   187    p->printRegion(op.thenRegion(),
   188                   /*printEntryBlockArgs=*/false,
   189                   /*printBlockTerminators=*/false);
   190  
   191    // Print the 'else' regions if it exists and has a block.
   192    auto &elseRegion = op.elseRegion();
   193    if (!elseRegion.empty()) {
   194      *p << " else";
   195      p->printRegion(elseRegion,
   196                     /*printEntryBlockArgs=*/false,
   197                     /*printBlockTerminators=*/false);
   198    }
   199  
   200    p->printOptionalAttrDict(op.getAttrs());
   201  }
   202  
   203  //===----------------------------------------------------------------------===//
   204  // TableGen'd op method definitions
   205  //===----------------------------------------------------------------------===//
   206  
   207  #define GET_OP_CLASSES
   208  #include "mlir/Dialect/LoopOps/LoopOps.cpp.inc"