github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp (about)

     1  //===- LowerGpuOpsToNVVMOps.cpp - MLIR GPU to NVVM lowering passes --------===//
     2  //
     3  // Copyright 2019 The MLIR Authors.
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //   http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  // =============================================================================
    17  //
    18  // This file implements a pass to generate NVVMIR operations for higher-level
    19  // GPU operations.
    20  //
    21  //===----------------------------------------------------------------------===//
    22  
    23  #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
    24  
    25  #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
    26  #include "mlir/Dialect/GPU/GPUDialect.h"
    27  #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
    28  #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
    29  #include "mlir/IR/Builders.h"
    30  #include "mlir/IR/StandardTypes.h"
    31  #include "mlir/Pass/Pass.h"
    32  #include "mlir/Pass/PassRegistry.h"
    33  #include "mlir/Transforms/DialectConversion.h"
    34  
    35  #include "llvm/ADT/StringSwitch.h"
    36  
    37  using namespace mlir;
    38  
    39  namespace {
    40  
    41  // Rewriting that replaces the types of a LaunchFunc operation with their
    42  // LLVM counterparts.
    43  struct GPULaunchFuncOpLowering : public LLVMOpLowering {
    44  public:
    45    explicit GPULaunchFuncOpLowering(LLVMTypeConverter &lowering_)
    46        : LLVMOpLowering(gpu::LaunchFuncOp::getOperationName(),
    47                         lowering_.getDialect()->getContext(), lowering_) {}
    48  
    49    // Convert the kernel arguments to an LLVM type, preserve the rest.
    50    PatternMatchResult
    51    matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
    52                    ConversionPatternRewriter &rewriter) const override {
    53      rewriter.clone(*op)->setOperands(operands);
    54      return rewriter.replaceOp(op, llvm::None), matchSuccess();
    55    }
    56  };
    57  
    58  // Rewriting that replaces Op with XOp, YOp, or ZOp depending on the dimension
    59  // that Op operates on.  Op is assumed to return an `std.index` value and
    60  // XOp, YOp and ZOp are assumed to return an `llvm.i32` value.  Depending on
    61  // `indexBitwidth`, sign-extend or truncate the resulting value to match the
    62  // bitwidth expected by the consumers of the value.
    63  template <typename Op, typename XOp, typename YOp, typename ZOp>
    64  struct GPUIndexIntrinsicOpLowering : public LLVMOpLowering {
    65  private:
    66    enum dimension { X = 0, Y = 1, Z = 2, invalid };
    67    unsigned indexBitwidth;
    68  
    69    static dimension dimensionToIndex(Op op) {
    70      return llvm::StringSwitch<dimension>(op.dimension())
    71          .Case("x", X)
    72          .Case("y", Y)
    73          .Case("z", Z)
    74          .Default(invalid);
    75    }
    76  
    77    static unsigned getIndexBitWidth(LLVMTypeConverter &lowering) {
    78      auto dialect = lowering.getDialect();
    79      return dialect->getLLVMModule().getDataLayout().getPointerSizeInBits();
    80    }
    81  
    82  public:
    83    explicit GPUIndexIntrinsicOpLowering(LLVMTypeConverter &lowering_)
    84        : LLVMOpLowering(Op::getOperationName(),
    85                         lowering_.getDialect()->getContext(), lowering_),
    86          indexBitwidth(getIndexBitWidth(lowering_)) {}
    87  
    88    // Convert the kernel arguments to an LLVM type, preserve the rest.
    89    PatternMatchResult
    90    matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
    91                    ConversionPatternRewriter &rewriter) const override {
    92      auto loc = op->getLoc();
    93      auto dialect = lowering.getDialect();
    94      Value *newOp;
    95      switch (dimensionToIndex(cast<Op>(op))) {
    96      case X:
    97        newOp = rewriter.create<XOp>(loc, LLVM::LLVMType::getInt32Ty(dialect));
    98        break;
    99      case Y:
   100        newOp = rewriter.create<YOp>(loc, LLVM::LLVMType::getInt32Ty(dialect));
   101        break;
   102      case Z:
   103        newOp = rewriter.create<ZOp>(loc, LLVM::LLVMType::getInt32Ty(dialect));
   104        break;
   105      default:
   106        return matchFailure();
   107      }
   108  
   109      if (indexBitwidth > 32) {
   110        newOp = rewriter.create<LLVM::SExtOp>(
   111            loc, LLVM::LLVMType::getIntNTy(dialect, indexBitwidth), newOp);
   112      } else if (indexBitwidth < 32) {
   113        newOp = rewriter.create<LLVM::TruncOp>(
   114            loc, LLVM::LLVMType::getIntNTy(dialect, indexBitwidth), newOp);
   115      }
   116  
   117      rewriter.replaceOp(op, {newOp});
   118      return matchSuccess();
   119    }
   120  };
   121  
   122  // A pass that replaces all occurences of GPU operations with their
   123  // corresponding NVVM equivalent.
   124  //
   125  // This pass does not handle launching of kernels. Instead, it is meant to be
   126  // used on the body region of a launch or the body region of a kernel
   127  // function.
   128  class LowerGpuOpsToNVVMOpsPass : public ModulePass<LowerGpuOpsToNVVMOpsPass> {
   129  public:
   130    void runOnModule() override {
   131      ModuleOp m = getModule();
   132  
   133      OwningRewritePatternList patterns;
   134      LLVMTypeConverter converter(m.getContext());
   135      populateGpuToNVVMConversionPatterns(converter, patterns);
   136  
   137      ConversionTarget target(getContext());
   138      target.addLegalDialect<LLVM::LLVMDialect>();
   139      target.addLegalDialect<NVVM::NVVMDialect>();
   140      target.addDynamicallyLegalOp<FuncOp>(
   141          [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
   142      if (failed(applyPartialConversion(m, target, patterns, &converter)))
   143        signalPassFailure();
   144    }
   145  };
   146  
   147  } // anonymous namespace
   148  
   149  /// Collect a set of patterns to convert from the GPU dialect to NVVM.
   150  void mlir::populateGpuToNVVMConversionPatterns(
   151      LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
   152    patterns
   153        .insert<GPULaunchFuncOpLowering,
   154                GPUIndexIntrinsicOpLowering<gpu::ThreadId, NVVM::ThreadIdXOp,
   155                                            NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>,
   156                GPUIndexIntrinsicOpLowering<gpu::BlockDim, NVVM::BlockDimXOp,
   157                                            NVVM::BlockDimYOp, NVVM::BlockDimZOp>,
   158                GPUIndexIntrinsicOpLowering<gpu::BlockId, NVVM::BlockIdXOp,
   159                                            NVVM::BlockIdYOp, NVVM::BlockIdZOp>,
   160                GPUIndexIntrinsicOpLowering<gpu::GridDim, NVVM::GridDimXOp,
   161                                            NVVM::GridDimYOp, NVVM::GridDimZOp>>(
   162            converter);
   163  }
   164  
   165  std::unique_ptr<ModulePassBase> mlir::createLowerGpuOpsToNVVMOpsPass() {
   166    return std::make_unique<LowerGpuOpsToNVVMOpsPass>();
   167  }
   168  
   169  static PassRegistration<LowerGpuOpsToNVVMOpsPass>
   170      pass("lower-gpu-ops-to-nvvm-ops",
   171           "Generate NVVM operations for gpu operations");