github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp (about)

     1  //===- GPUDialect.cpp - MLIR Dialect for GPU Kernels implementation -------===//
     2  //
     3  // Copyright 2019 The MLIR Authors.
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //   http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  // =============================================================================
    17  //
    18  // This file implements the GPU kernel-related dialect and its operations.
    19  //
    20  //===----------------------------------------------------------------------===//
    21  
    22  #include "mlir/Dialect/GPU/GPUDialect.h"
    23  #include "mlir/Dialect/StandardOps/Ops.h"
    24  #include "mlir/IR/Builders.h"
    25  #include "mlir/IR/Function.h"
    26  #include "mlir/IR/Module.h"
    27  #include "mlir/IR/OpImplementation.h"
    28  #include "mlir/IR/PatternMatch.h"
    29  #include "mlir/IR/StandardTypes.h"
    30  
    31  using namespace mlir;
    32  using namespace mlir::gpu;
    33  
    34  StringRef GPUDialect::getDialectName() { return "gpu"; }
    35  
    36  bool GPUDialect::isKernel(FuncOp function) {
    37    UnitAttr isKernelAttr =
    38        function.getAttrOfType<UnitAttr>(getKernelFuncAttrName());
    39    return static_cast<bool>(isKernelAttr);
    40  }
    41  
    42  GPUDialect::GPUDialect(MLIRContext *context)
    43      : Dialect(getDialectName(), context) {
    44    addOperations<LaunchOp, LaunchFuncOp,
    45  #define GET_OP_LIST
    46  #include "mlir/Dialect/GPU/GPUOps.cpp.inc"
    47                  >();
    48  }
    49  
    50  template <typename T> static LogicalResult verifyIndexOp(T op) {
    51    auto dimension = op.dimension();
    52    if (dimension != "x" && dimension != "y" && dimension != "z")
    53      return op.emitError("dimension \"") << dimension << "\" is invalid";
    54    return success();
    55  }
    56  
    57  #define GET_OP_CLASSES
    58  #include "mlir/Dialect/GPU/GPUOps.cpp.inc"
    59  
    60  //===----------------------------------------------------------------------===//
    61  // LaunchOp
    62  //===----------------------------------------------------------------------===//
    63  
    64  static SmallVector<Type, 4> getValueTypes(ArrayRef<Value *> values) {
    65    SmallVector<Type, 4> types;
    66    types.reserve(values.size());
    67    for (Value *v : values)
    68      types.push_back(v->getType());
    69    return types;
    70  }
    71  
    72  void LaunchOp::build(Builder *builder, OperationState *result, Value *gridSizeX,
    73                       Value *gridSizeY, Value *gridSizeZ, Value *blockSizeX,
    74                       Value *blockSizeY, Value *blockSizeZ,
    75                       ArrayRef<Value *> operands) {
    76    // Add grid and block sizes as op operands, followed by the data operands.
    77    result->addOperands(
    78        {gridSizeX, gridSizeY, gridSizeZ, blockSizeX, blockSizeY, blockSizeZ});
    79    result->addOperands(operands);
    80  
    81    // Create a kernel body region with kNumConfigRegionAttributes + N arguments,
    82    // where the first kNumConfigRegionAttributes arguments have `index` type and
    83    // the rest have the same types as the data operands.
    84    Region *kernelRegion = result->addRegion();
    85    Block *body = new Block();
    86    body->addArguments(
    87        std::vector<Type>(kNumConfigRegionAttributes, builder->getIndexType()));
    88    body->addArguments(getValueTypes(operands));
    89    kernelRegion->push_back(body);
    90  }
    91  
    92  Region &LaunchOp::getBody() { return getOperation()->getRegion(0); }
    93  
    94  KernelDim3 LaunchOp::getBlockIds() {
    95    assert(!getBody().getBlocks().empty() && "FuncOp body must not be empty.");
    96    auto args = getBody().getBlocks().front().getArguments();
    97    return KernelDim3{args[0], args[1], args[2]};
    98  }
    99  
   100  KernelDim3 LaunchOp::getThreadIds() {
   101    assert(!getBody().getBlocks().empty() && "FuncOp body must not be empty.");
   102    auto args = getBody().getBlocks().front().getArguments();
   103    return KernelDim3{args[3], args[4], args[5]};
   104  }
   105  
   106  KernelDim3 LaunchOp::getGridSize() {
   107    assert(!getBody().getBlocks().empty() && "FuncOp body must not be empty.");
   108    auto args = getBody().getBlocks().front().getArguments();
   109    return KernelDim3{args[6], args[7], args[8]};
   110  }
   111  
   112  KernelDim3 LaunchOp::getBlockSize() {
   113    assert(!getBody().getBlocks().empty() && "FuncOp body must not be empty.");
   114    auto args = getBody().getBlocks().front().getArguments();
   115    return KernelDim3{args[9], args[10], args[11]};
   116  }
   117  
   118  LaunchOp::operand_range LaunchOp::getKernelOperandValues() {
   119    return llvm::drop_begin(getOperands(), kNumConfigOperands);
   120  }
   121  
   122  LaunchOp::operand_type_range LaunchOp::getKernelOperandTypes() {
   123    return llvm::drop_begin(getOperandTypes(), kNumConfigOperands);
   124  }
   125  
   126  KernelDim3 LaunchOp::getGridSizeOperandValues() {
   127    return KernelDim3{getOperand(0), getOperand(1), getOperand(2)};
   128  }
   129  
   130  KernelDim3 LaunchOp::getBlockSizeOperandValues() {
   131    return KernelDim3{getOperand(3), getOperand(4), getOperand(5)};
   132  }
   133  
   134  llvm::iterator_range<Block::args_iterator> LaunchOp::getKernelArguments() {
   135    auto args = getBody().getBlocks().front().getArguments();
   136    return llvm::drop_begin(args, LaunchOp::kNumConfigRegionAttributes);
   137  }
   138  
   139  LogicalResult LaunchOp::verify() {
   140    // Kernel launch takes kNumConfigOperands leading operands for grid/block
   141    // sizes and transforms them into kNumConfigRegionAttributes region arguments
   142    // for block/thread identifiers and grid/block sizes.
   143    if (!getBody().empty()) {
   144      Block &entryBlock = getBody().front();
   145      if (entryBlock.getNumArguments() != kNumConfigOperands + getNumOperands())
   146        return emitError("unexpected number of region arguments");
   147    }
   148  
   149    // Block terminators without successors are expected to exit the kernel region
   150    // and must be `gpu.launch`.
   151    for (Block &block : getBody()) {
   152      if (block.empty())
   153        continue;
   154      if (block.back().getNumSuccessors() != 0)
   155        continue;
   156      if (!isa<gpu::Return>(&block.back())) {
   157        return block.back()
   158                   .emitError("expected 'gpu.terminator' or a terminator with "
   159                              "successors")
   160                   .attachNote(getLoc())
   161               << "in '" << getOperationName() << "' body region";
   162      }
   163    }
   164  
   165    return success();
   166  }
   167  
   168  // Pretty-print the kernel grid/block size assignment as
   169  //   (%iter-x, %iter-y, %iter-z) in
   170  //   (%size-x = %ssa-use, %size-y = %ssa-use, %size-z = %ssa-use)
   171  // where %size-* and %iter-* will correspond to the body region arguments.
   172  static void printSizeAssignment(OpAsmPrinter *p, KernelDim3 size,
   173                                  ArrayRef<Value *> operands, KernelDim3 ids) {
   174    *p << '(' << *ids.x << ", " << *ids.y << ", " << *ids.z << ") in (";
   175    *p << *size.x << " = " << *operands[0] << ", ";
   176    *p << *size.y << " = " << *operands[1] << ", ";
   177    *p << *size.z << " = " << *operands[2] << ')';
   178  }
   179  
   180  void LaunchOp::print(OpAsmPrinter *p) {
   181    SmallVector<Value *, 12> operandContainer(operand_begin(), operand_end());
   182    ArrayRef<Value *> operands(operandContainer);
   183  
   184    // Print the launch configuration.
   185    *p << getOperationName() << ' ' << getBlocksKeyword();
   186    printSizeAssignment(p, getGridSize(), operands.take_front(3), getBlockIds());
   187    *p << ' ' << getThreadsKeyword();
   188    printSizeAssignment(p, getBlockSize(), operands.slice(3, 3), getThreadIds());
   189  
   190    // From now on, the first kNumConfigOperands operands corresponding to grid
   191    // and block sizes are irrelevant, so we can drop them.
   192    operands = operands.drop_front(kNumConfigOperands);
   193  
   194    // Print the data argument remapping.
   195    if (!getBody().empty() && !operands.empty()) {
   196      *p << ' ' << getArgsKeyword() << '(';
   197      for (unsigned i = 0, e = operands.size(); i < e; ++i) {
   198        if (i != 0)
   199          *p << ", ";
   200        *p << *getBody().front().getArgument(kNumConfigRegionAttributes + i)
   201           << " = " << *operands[i];
   202      }
   203      *p << ") ";
   204    }
   205  
   206    // Print the types of data arguments.
   207    if (!operands.empty()) {
   208      *p << ": ";
   209      for (unsigned i = 0, e = operands.size(); i < e; ++i) {
   210        if (i != 0)
   211          *p << ", ";
   212        *p << operands[i]->getType();
   213      }
   214    }
   215  
   216    p->printRegion(getBody(), /*printEntryBlockArgs=*/false);
   217    p->printOptionalAttrDict(getAttrs());
   218  }
   219  
   220  // Parse the size assignment blocks for blocks and threads.  These have the form
   221  //   (%region_arg, %region_arg, %region_arg) in
   222  //   (%region_arg = %operand, %region_arg = %operand, %region_arg = %operand)
   223  // where %region_arg are percent-identifiers for the region arguments to be
   224  // introduced futher (SSA defs), and %operand are percent-identifiers for the
   225  // SSA value uses.
   226  static ParseResult
   227  parseSizeAssignment(OpAsmParser *parser,
   228                      MutableArrayRef<OpAsmParser::OperandType> sizes,
   229                      MutableArrayRef<OpAsmParser::OperandType> regionSizes,
   230                      MutableArrayRef<OpAsmParser::OperandType> indices) {
   231    assert(indices.size() == 3 && "space for three indices expected");
   232    SmallVector<OpAsmParser::OperandType, 3> args;
   233    if (parser->parseRegionArgumentList(args, /*requiredOperandCount=*/3,
   234                                        OpAsmParser::Delimiter::Paren) ||
   235        parser->parseKeyword("in") || parser->parseLParen())
   236      return failure();
   237    std::move(args.begin(), args.end(), indices.begin());
   238  
   239    for (int i = 0; i < 3; ++i) {
   240      if (i != 0 && parser->parseComma())
   241        return failure();
   242      if (parser->parseRegionArgument(regionSizes[i]) || parser->parseEqual() ||
   243          parser->parseOperand(sizes[i]))
   244        return failure();
   245    }
   246  
   247    return parser->parseRParen();
   248  }
   249  
   250  // Parses a Launch operation.
   251  // operation ::= `gpu.launch` `blocks` `(` ssa-id-list `)` `in` ssa-reassignment
   252  //                           `threads` `(` ssa-id-list `)` `in` ssa-reassignment
   253  //                             (`args` ssa-reassignment `:` type-list)?
   254  //                             region attr-dict?
   255  // ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)`
   256  ParseResult LaunchOp::parse(OpAsmParser *parser, OperationState *result) {
   257    // Sizes of the grid and block.
   258    SmallVector<OpAsmParser::OperandType, kNumConfigOperands> sizes(
   259        kNumConfigOperands);
   260    MutableArrayRef<OpAsmParser::OperandType> sizesRef(sizes);
   261  
   262    // Actual (data) operands passed to the kernel.
   263    SmallVector<OpAsmParser::OperandType, 4> dataOperands;
   264  
   265    // Region arguments to be created.
   266    SmallVector<OpAsmParser::OperandType, 16> regionArgs(
   267        kNumConfigRegionAttributes);
   268    MutableArrayRef<OpAsmParser::OperandType> regionArgsRef(regionArgs);
   269  
   270    // Parse the size assignment segments: the first segment assigns grid siezs
   271    // and defines values for block identifiers; the second segment assigns block
   272    // sies and defines values for thread identifiers.  In the region argument
   273    // list, identifiers preceed sizes, and block-related values preceed
   274    // thread-related values.
   275    if (parser->parseKeyword(getBlocksKeyword().data()) ||
   276        parseSizeAssignment(parser, sizesRef.take_front(3),
   277                            regionArgsRef.slice(6, 3),
   278                            regionArgsRef.slice(0, 3)) ||
   279        parser->parseKeyword(getThreadsKeyword().data()) ||
   280        parseSizeAssignment(parser, sizesRef.drop_front(3),
   281                            regionArgsRef.slice(9, 3),
   282                            regionArgsRef.slice(3, 3)) ||
   283        parser->resolveOperands(sizes, parser->getBuilder().getIndexType(),
   284                                result->operands))
   285      return failure();
   286  
   287    // If kernel argument renaming segment is present, parse it.  When present,
   288    // the segment should have at least one element.  If this segment is present,
   289    // so is the trailing type list.  Parse it as well and use the parsed types
   290    // to resolve the operands passed to the kernel arguments.
   291    SmallVector<Type, 4> dataTypes;
   292    if (!parser->parseOptionalKeyword(getArgsKeyword().data())) {
   293      llvm::SMLoc argsLoc = parser->getCurrentLocation();
   294  
   295      regionArgs.push_back({});
   296      dataOperands.push_back({});
   297      if (parser->parseLParen() ||
   298          parser->parseRegionArgument(regionArgs.back()) ||
   299          parser->parseEqual() || parser->parseOperand(dataOperands.back()))
   300        return failure();
   301  
   302      while (!parser->parseOptionalComma()) {
   303        regionArgs.push_back({});
   304        dataOperands.push_back({});
   305        if (parser->parseRegionArgument(regionArgs.back()) ||
   306            parser->parseEqual() || parser->parseOperand(dataOperands.back()))
   307          return failure();
   308      }
   309  
   310      if (parser->parseRParen() || parser->parseColonTypeList(dataTypes) ||
   311          parser->resolveOperands(dataOperands, dataTypes, argsLoc,
   312                                  result->operands))
   313        return failure();
   314    }
   315  
   316    // Introduce the body region and parse it.  The region has
   317    // kNumConfigRegionAttributes leading arguments that correspond to
   318    // block/thread identifiers and grid/block sizes, all of the `index` type.
   319    // Follow the actual kernel arguments.
   320    Type index = parser->getBuilder().getIndexType();
   321    dataTypes.insert(dataTypes.begin(), kNumConfigRegionAttributes, index);
   322    Region *body = result->addRegion();
   323    return failure(parser->parseRegion(*body, regionArgs, dataTypes) ||
   324                   parser->parseOptionalAttributeDict(result->attributes));
   325  }
   326  
   327  void LaunchOp::eraseKernelArgument(unsigned index) {
   328    Block &entryBlock = getBody().front();
   329    assert(index < entryBlock.getNumArguments() - kNumConfigRegionAttributes &&
   330           "kernel argument index overflow");
   331    entryBlock.eraseArgument(kNumConfigRegionAttributes + index);
   332    getOperation()->eraseOperand(kNumConfigOperands + index);
   333  }
   334  
   335  namespace {
   336  // Clone any known constants passed as operands to the kernel into its body.
   337  class PropagateConstantBounds : public OpRewritePattern<LaunchOp> {
   338    using OpRewritePattern<LaunchOp>::OpRewritePattern;
   339  
   340    PatternMatchResult matchAndRewrite(LaunchOp launchOp,
   341                                       PatternRewriter &rewriter) const override {
   342      auto oringInsertionPoint = rewriter.saveInsertionPoint();
   343      rewriter.setInsertionPointToStart(&launchOp.getBody().front());
   344  
   345      // Traverse operands passed to kernel and check if some of them are known
   346      // constants.  If so, clone the constant operation inside the kernel region
   347      // and use it instead of passing the value from the parent region.  Perform
   348      // the traversal in the inverse order to simplify index arithmetics when
   349      // dropping arguments.
   350      SmallVector<Value *, 8> operands(launchOp.getKernelOperandValues().begin(),
   351                                       launchOp.getKernelOperandValues().end());
   352      SmallVector<Value *, 8> kernelArgs(launchOp.getKernelArguments().begin(),
   353                                         launchOp.getKernelArguments().end());
   354      bool found = false;
   355      for (unsigned i = operands.size(); i > 0; --i) {
   356        unsigned index = i - 1;
   357        Value *operand = operands[index];
   358        if (!isa_and_nonnull<ConstantOp>(operand->getDefiningOp())) {
   359          continue;
   360        }
   361  
   362        found = true;
   363        Value *internalConstant =
   364            rewriter.clone(*operand->getDefiningOp())->getResult(0);
   365        Value *kernelArg = kernelArgs[index];
   366        kernelArg->replaceAllUsesWith(internalConstant);
   367        launchOp.eraseKernelArgument(index);
   368      }
   369      rewriter.restoreInsertionPoint(oringInsertionPoint);
   370  
   371      if (!found)
   372        return matchFailure();
   373  
   374      rewriter.updatedRootInPlace(launchOp);
   375      return matchSuccess();
   376    }
   377  };
   378  } // end namespace
   379  
   380  void LaunchOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
   381                                             MLIRContext *context) {
   382    results.insert<PropagateConstantBounds>(context);
   383  }
   384  
   385  //===----------------------------------------------------------------------===//
   386  // LaunchFuncOp
   387  //===----------------------------------------------------------------------===//
   388  
   389  void LaunchFuncOp::build(Builder *builder, OperationState *result,
   390                           FuncOp kernelFunc, Value *gridSizeX, Value *gridSizeY,
   391                           Value *gridSizeZ, Value *blockSizeX, Value *blockSizeY,
   392                           Value *blockSizeZ, ArrayRef<Value *> kernelOperands) {
   393    // Add grid and block sizes as op operands, followed by the data operands.
   394    result->addOperands(
   395        {gridSizeX, gridSizeY, gridSizeZ, blockSizeX, blockSizeY, blockSizeZ});
   396    result->addOperands(kernelOperands);
   397    result->addAttribute(getKernelAttrName(),
   398                         builder->getSymbolRefAttr(kernelFunc));
   399  }
   400  
   401  void LaunchFuncOp::build(Builder *builder, OperationState *result,
   402                           FuncOp kernelFunc, KernelDim3 gridSize,
   403                           KernelDim3 blockSize,
   404                           ArrayRef<Value *> kernelOperands) {
   405    build(builder, result, kernelFunc, gridSize.x, gridSize.y, gridSize.z,
   406          blockSize.x, blockSize.y, blockSize.z, kernelOperands);
   407  }
   408  
   409  StringRef LaunchFuncOp::kernel() {
   410    return getAttrOfType<SymbolRefAttr>(getKernelAttrName()).getValue();
   411  }
   412  
   413  unsigned LaunchFuncOp::getNumKernelOperands() {
   414    return getNumOperands() - kNumConfigOperands;
   415  }
   416  
   417  Value *LaunchFuncOp::getKernelOperand(unsigned i) {
   418    return getOperation()->getOperand(i + kNumConfigOperands);
   419  }
   420  
   421  KernelDim3 LaunchFuncOp::getGridSizeOperandValues() {
   422    return KernelDim3{getOperand(0), getOperand(1), getOperand(2)};
   423  }
   424  
   425  KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() {
   426    return KernelDim3{getOperand(3), getOperand(4), getOperand(5)};
   427  }
   428  
   429  LogicalResult LaunchFuncOp::verify() {
   430    auto kernelAttr = this->getAttr(getKernelAttrName());
   431    if (!kernelAttr) {
   432      return emitOpError("attribute 'kernel' must be specified");
   433    } else if (!kernelAttr.isa<SymbolRefAttr>()) {
   434      return emitOpError("attribute 'kernel' must be a function");
   435    }
   436  
   437    auto module = getParentOfType<ModuleOp>();
   438    FuncOp kernelFunc = module.lookupSymbol<FuncOp>(kernel());
   439    if (!kernelFunc)
   440      return emitError() << "kernel function '" << kernelAttr << "' is undefined";
   441  
   442    if (!kernelFunc.getAttrOfType<mlir::UnitAttr>(
   443            GPUDialect::getKernelFuncAttrName())) {
   444      return emitError("kernel function is missing the '")
   445             << GPUDialect::getKernelFuncAttrName() << "' attribute";
   446    }
   447    unsigned numKernelFuncArgs = kernelFunc.getNumArguments();
   448    if (getNumKernelOperands() != numKernelFuncArgs) {
   449      return emitOpError("got ")
   450             << getNumKernelOperands() << " kernel operands but expected "
   451             << numKernelFuncArgs;
   452    }
   453    auto functionType = kernelFunc.getType();
   454    for (unsigned i = 0; i < numKernelFuncArgs; ++i) {
   455      if (getKernelOperand(i)->getType() != functionType.getInput(i)) {
   456        return emitOpError("type of function argument ")
   457               << i << " does not match";
   458      }
   459    }
   460    return success();
   461  }