github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp (about)

     1  //===- GPUToSPIRV.cpp - MLIR SPIR-V lowering passes -----------------------===//
     2  //
     3  // Copyright 2019 The MLIR Authors.
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //   http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  // =============================================================================
    17  //
    18  // This file implements a pass to convert a kernel function in the GPU Dialect
    19  // into a spv.module operation
    20  //
    21  //===----------------------------------------------------------------------===//
    22  #include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h"
    23  #include "mlir/Dialect/GPU/GPUDialect.h"
    24  #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
    25  #include "mlir/Dialect/SPIRV/SPIRVOps.h"
    26  #include "mlir/Pass/Pass.h"
    27  
    28  using namespace mlir;
    29  
    30  namespace {
    31  
    32  /// Pattern lowering GPU block/thread size/id to loading SPIR-V invocation
    33  /// builin variables.
    34  template <typename OpTy, spirv::BuiltIn builtin>
    35  class LaunchConfigConversion : public SPIRVOpLowering<OpTy> {
    36  public:
    37    using SPIRVOpLowering<OpTy>::SPIRVOpLowering;
    38  
    39    PatternMatchResult
    40    matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
    41                    ConversionPatternRewriter &rewriter) const override;
    42  };
    43  
    44  /// Pattern to convert a kernel function in GPU dialect (a FuncOp with the
    45  /// attribute gpu.kernel) within a spv.module.
    46  class KernelFnConversion final : public SPIRVOpLowering<FuncOp> {
    47  public:
    48    using SPIRVOpLowering<FuncOp>::SPIRVOpLowering;
    49  
    50    PatternMatchResult
    51    matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
    52                    ConversionPatternRewriter &rewriter) const override;
    53  };
    54  } // namespace
    55  
    56  template <typename OpTy, spirv::BuiltIn builtin>
    57  PatternMatchResult LaunchConfigConversion<OpTy, builtin>::matchAndRewrite(
    58      Operation *op, ArrayRef<Value *> operands,
    59      ConversionPatternRewriter &rewriter) const {
    60    auto dimAttr = op->getAttrOfType<StringAttr>("dimension");
    61    if (!dimAttr) {
    62      return this->matchFailure();
    63    }
    64    int32_t index = 0;
    65    if (dimAttr.getValue() == "x") {
    66      index = 0;
    67    } else if (dimAttr.getValue() == "y") {
    68      index = 1;
    69    } else if (dimAttr.getValue() == "z") {
    70      index = 2;
    71    } else {
    72      return this->matchFailure();
    73    }
    74  
    75    // SPIR-V invocation builtin variables are a vector of type <3xi32>
    76    auto spirvBuiltin = this->loadFromBuiltinVariable(op, builtin, rewriter);
    77    rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
    78        op, rewriter.getIntegerType(32), spirvBuiltin,
    79        rewriter.getI32ArrayAttr({index}));
    80    return this->matchSuccess();
    81  }
    82  
    83  PatternMatchResult
    84  KernelFnConversion::matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
    85                                      ConversionPatternRewriter &rewriter) const {
    86    auto funcOp = cast<FuncOp>(op);
    87    FuncOp newFuncOp;
    88    if (!gpu::GPUDialect::isKernel(funcOp)) {
    89      return succeeded(lowerFunction(funcOp, operands, &typeConverter, rewriter,
    90                                     newFuncOp))
    91                 ? matchSuccess()
    92                 : matchFailure();
    93    }
    94  
    95    if (failed(lowerAsEntryFunction(funcOp, operands, &typeConverter, rewriter,
    96                                    newFuncOp))) {
    97      return matchFailure();
    98    }
    99    newFuncOp.getOperation()->removeAttr(Identifier::get(
   100        gpu::GPUDialect::getKernelFuncAttrName(), op->getContext()));
   101    return matchSuccess();
   102  }
   103  
   104  namespace {
   105  /// Pass to lower GPU Dialect to SPIR-V. The pass only converts those functions
   106  /// that have the "gpu.kernel" attribute, i.e. those functions that are
   107  /// referenced in gpu::LaunchKernelOp operations. For each such function
   108  ///
   109  /// 1) Create a spirv::ModuleOp, and clone the function into spirv::ModuleOp
   110  /// (the original function is still needed by the gpu::LaunchKernelOp, so cannot
   111  /// replace it).
   112  ///
   113  /// 2) Lower the body of the spirv::ModuleOp.
   114  class GPUToSPIRVPass : public ModulePass<GPUToSPIRVPass> {
   115    void runOnModule() override;
   116  };
   117  } // namespace
   118  
   119  void GPUToSPIRVPass::runOnModule() {
   120    auto context = &getContext();
   121    auto module = getModule();
   122  
   123    SmallVector<Operation *, 4> spirvModules;
   124    for (auto funcOp : module.getOps<FuncOp>()) {
   125      if (gpu::GPUDialect::isKernel(funcOp)) {
   126        OpBuilder builder(module.getBodyRegion());
   127        // Create a new spirv::ModuleOp for this function, and clone the
   128        // function into it.
   129        // TODO : Generalize this to account for different extensions,
   130        // capabilities, extended_instruction_sets, other addressing models
   131        // and memory models.
   132        auto spvModule = builder.create<spirv::ModuleOp>(
   133            funcOp.getLoc(),
   134            builder.getI32IntegerAttr(
   135                static_cast<int32_t>(spirv::AddressingModel::Logical)),
   136            builder.getI32IntegerAttr(
   137                static_cast<int32_t>(spirv::MemoryModel::VulkanKHR)));
   138        OpBuilder moduleBuilder(spvModule.getOperation()->getRegion(0));
   139        moduleBuilder.clone(*funcOp.getOperation());
   140        spirvModules.push_back(spvModule);
   141      }
   142    }
   143  
   144    /// Dialect conversion to lower the functions with the spirv::ModuleOps.
   145    SPIRVBasicTypeConverter basicTypeConverter(context);
   146    SPIRVTypeConverter typeConverter(&basicTypeConverter);
   147    OwningRewritePatternList patterns;
   148    patterns.insert<
   149        KernelFnConversion,
   150        LaunchConfigConversion<gpu::BlockDim, spirv::BuiltIn::WorkgroupSize>,
   151        LaunchConfigConversion<gpu::BlockId, spirv::BuiltIn::WorkgroupId>,
   152        LaunchConfigConversion<gpu::GridDim, spirv::BuiltIn::NumWorkgroups>,
   153        LaunchConfigConversion<gpu::ThreadId, spirv::BuiltIn::LocalInvocationId>>(
   154        context, typeConverter);
   155    populateStandardToSPIRVPatterns(context, patterns);
   156  
   157    ConversionTarget target(*context);
   158    target.addLegalDialect<spirv::SPIRVDialect>();
   159    target.addDynamicallyLegalOp<FuncOp>([&](FuncOp Op) {
   160      return basicTypeConverter.isSignatureLegal(Op.getType());
   161    });
   162  
   163    if (failed(applyFullConversion(spirvModules, target, patterns,
   164                                   &typeConverter))) {
   165      return signalPassFailure();
   166    }
   167  }
   168  
   169  ModulePassBase *createGPUToSPIRVPass() { return new GPUToSPIRVPass(); }
   170  
   171  static PassRegistration<GPUToSPIRVPass>
   172      pass("convert-gpu-to-spirv", "Convert GPU dialect to SPIR-V dialect");