github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp (about) 1 //===- GPUToSPIRV.cpp - MLIR SPIR-V lowering passes -----------------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 // 18 // This file implements a pass to convert a kernel function in the GPU Dialect 19 // into a spv.module operation 20 // 21 //===----------------------------------------------------------------------===// 22 #include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h" 23 #include "mlir/Dialect/GPU/GPUDialect.h" 24 #include "mlir/Dialect/SPIRV/SPIRVDialect.h" 25 #include "mlir/Dialect/SPIRV/SPIRVOps.h" 26 #include "mlir/Pass/Pass.h" 27 28 using namespace mlir; 29 30 namespace { 31 32 /// Pattern lowering GPU block/thread size/id to loading SPIR-V invocation 33 /// builin variables. 34 template <typename OpTy, spirv::BuiltIn builtin> 35 class LaunchConfigConversion : public SPIRVOpLowering<OpTy> { 36 public: 37 using SPIRVOpLowering<OpTy>::SPIRVOpLowering; 38 39 PatternMatchResult 40 matchAndRewrite(Operation *op, ArrayRef<Value *> operands, 41 ConversionPatternRewriter &rewriter) const override; 42 }; 43 44 /// Pattern to convert a kernel function in GPU dialect (a FuncOp with the 45 /// attribute gpu.kernel) within a spv.module. 46 class KernelFnConversion final : public SPIRVOpLowering<FuncOp> { 47 public: 48 using SPIRVOpLowering<FuncOp>::SPIRVOpLowering; 49 50 PatternMatchResult 51 matchAndRewrite(Operation *op, ArrayRef<Value *> operands, 52 ConversionPatternRewriter &rewriter) const override; 53 }; 54 } // namespace 55 56 template <typename OpTy, spirv::BuiltIn builtin> 57 PatternMatchResult LaunchConfigConversion<OpTy, builtin>::matchAndRewrite( 58 Operation *op, ArrayRef<Value *> operands, 59 ConversionPatternRewriter &rewriter) const { 60 auto dimAttr = op->getAttrOfType<StringAttr>("dimension"); 61 if (!dimAttr) { 62 return this->matchFailure(); 63 } 64 int32_t index = 0; 65 if (dimAttr.getValue() == "x") { 66 index = 0; 67 } else if (dimAttr.getValue() == "y") { 68 index = 1; 69 } else if (dimAttr.getValue() == "z") { 70 index = 2; 71 } else { 72 return this->matchFailure(); 73 } 74 75 // SPIR-V invocation builtin variables are a vector of type <3xi32> 76 auto spirvBuiltin = this->loadFromBuiltinVariable(op, builtin, rewriter); 77 rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>( 78 op, rewriter.getIntegerType(32), spirvBuiltin, 79 rewriter.getI32ArrayAttr({index})); 80 return this->matchSuccess(); 81 } 82 83 PatternMatchResult 84 KernelFnConversion::matchAndRewrite(Operation *op, ArrayRef<Value *> operands, 85 ConversionPatternRewriter &rewriter) const { 86 auto funcOp = cast<FuncOp>(op); 87 FuncOp newFuncOp; 88 if (!gpu::GPUDialect::isKernel(funcOp)) { 89 return succeeded(lowerFunction(funcOp, operands, &typeConverter, rewriter, 90 newFuncOp)) 91 ? matchSuccess() 92 : matchFailure(); 93 } 94 95 if (failed(lowerAsEntryFunction(funcOp, operands, &typeConverter, rewriter, 96 newFuncOp))) { 97 return matchFailure(); 98 } 99 newFuncOp.getOperation()->removeAttr(Identifier::get( 100 gpu::GPUDialect::getKernelFuncAttrName(), op->getContext())); 101 return matchSuccess(); 102 } 103 104 namespace { 105 /// Pass to lower GPU Dialect to SPIR-V. The pass only converts those functions 106 /// that have the "gpu.kernel" attribute, i.e. those functions that are 107 /// referenced in gpu::LaunchKernelOp operations. For each such function 108 /// 109 /// 1) Create a spirv::ModuleOp, and clone the function into spirv::ModuleOp 110 /// (the original function is still needed by the gpu::LaunchKernelOp, so cannot 111 /// replace it). 112 /// 113 /// 2) Lower the body of the spirv::ModuleOp. 114 class GPUToSPIRVPass : public ModulePass<GPUToSPIRVPass> { 115 void runOnModule() override; 116 }; 117 } // namespace 118 119 void GPUToSPIRVPass::runOnModule() { 120 auto context = &getContext(); 121 auto module = getModule(); 122 123 SmallVector<Operation *, 4> spirvModules; 124 for (auto funcOp : module.getOps<FuncOp>()) { 125 if (gpu::GPUDialect::isKernel(funcOp)) { 126 OpBuilder builder(module.getBodyRegion()); 127 // Create a new spirv::ModuleOp for this function, and clone the 128 // function into it. 129 // TODO : Generalize this to account for different extensions, 130 // capabilities, extended_instruction_sets, other addressing models 131 // and memory models. 132 auto spvModule = builder.create<spirv::ModuleOp>( 133 funcOp.getLoc(), 134 builder.getI32IntegerAttr( 135 static_cast<int32_t>(spirv::AddressingModel::Logical)), 136 builder.getI32IntegerAttr( 137 static_cast<int32_t>(spirv::MemoryModel::VulkanKHR))); 138 OpBuilder moduleBuilder(spvModule.getOperation()->getRegion(0)); 139 moduleBuilder.clone(*funcOp.getOperation()); 140 spirvModules.push_back(spvModule); 141 } 142 } 143 144 /// Dialect conversion to lower the functions with the spirv::ModuleOps. 145 SPIRVBasicTypeConverter basicTypeConverter(context); 146 SPIRVTypeConverter typeConverter(&basicTypeConverter); 147 OwningRewritePatternList patterns; 148 patterns.insert< 149 KernelFnConversion, 150 LaunchConfigConversion<gpu::BlockDim, spirv::BuiltIn::WorkgroupSize>, 151 LaunchConfigConversion<gpu::BlockId, spirv::BuiltIn::WorkgroupId>, 152 LaunchConfigConversion<gpu::GridDim, spirv::BuiltIn::NumWorkgroups>, 153 LaunchConfigConversion<gpu::ThreadId, spirv::BuiltIn::LocalInvocationId>>( 154 context, typeConverter); 155 populateStandardToSPIRVPatterns(context, patterns); 156 157 ConversionTarget target(*context); 158 target.addLegalDialect<spirv::SPIRVDialect>(); 159 target.addDynamicallyLegalOp<FuncOp>([&](FuncOp Op) { 160 return basicTypeConverter.isSignatureLegal(Op.getType()); 161 }); 162 163 if (failed(applyFullConversion(spirvModules, target, patterns, 164 &typeConverter))) { 165 return signalPassFailure(); 166 } 167 } 168 169 ModulePassBase *createGPUToSPIRVPass() { return new GPUToSPIRVPass(); } 170 171 static PassRegistration<GPUToSPIRVPass> 172 pass("convert-gpu-to-spirv", "Convert GPU dialect to SPIR-V dialect");