github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp (about) 1 //===- ConvertStandardToSPIRV.cpp - Standard to SPIR-V dialect conversion--===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 // 18 // This file implements a pass to convert MLIR standard and builtin dialects 19 // into the SPIR-V dialect. 20 // 21 //===----------------------------------------------------------------------===// 22 #include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h" 23 #include "mlir/Dialect/SPIRV/SPIRVDialect.h" 24 #include "mlir/Dialect/SPIRV/SPIRVOps.h" 25 #include "mlir/Dialect/StandardOps/Ops.h" 26 27 using namespace mlir; 28 29 //===----------------------------------------------------------------------===// 30 // Type Conversion 31 //===----------------------------------------------------------------------===// 32 33 SPIRVBasicTypeConverter::SPIRVBasicTypeConverter(MLIRContext *context) 34 : spirvDialect(context->getRegisteredDialect<spirv::SPIRVDialect>()) {} 35 36 Type SPIRVBasicTypeConverter::convertType(Type t) { 37 // Check if the type is SPIR-V supported. If so return the type. 38 if (spirvDialect->isValidSPIRVType(t)) { 39 return t; 40 } 41 42 if (auto indexType = t.dyn_cast<IndexType>()) { 43 // Return I32 for index types. 44 return IntegerType::get(32, t.getContext()); 45 } 46 47 if (auto memRefType = t.dyn_cast<MemRefType>()) { 48 if (memRefType.hasStaticShape()) { 49 // Convert MemrefType to a multi-dimensional spv.array if size is known. 50 auto elementType = memRefType.getElementType(); 51 for (auto size : reverse(memRefType.getShape())) { 52 elementType = spirv::ArrayType::get(elementType, size); 53 } 54 // TODO(ravishankarm) : For now hard-coding this to be StorageBuffer. Need 55 // to support other Storage Classes. 56 return spirv::PointerType::get(elementType, 57 spirv::StorageClass::StorageBuffer); 58 } 59 } 60 return Type(); 61 } 62 63 //===----------------------------------------------------------------------===// 64 // Entry Function signature Conversion 65 //===----------------------------------------------------------------------===// 66 67 LogicalResult 68 SPIRVTypeConverter::convertSignatureArg(unsigned inputNo, Type type, 69 SignatureConversion &result) { 70 // Try to convert the given input type. 71 auto convertedType = basicTypeConverter->convertType(type); 72 // TODO(ravishankarm) : Vulkan spec requires these to be a 73 // spirv::StructType. This is not a SPIR-V requirement, so just making this a 74 // pointer type for now. 75 if (!convertedType) 76 return failure(); 77 // For arguments to entry functions, convert the type into a pointer type if 78 // it is already not one, unless the original type was an index type. 79 // TODO(ravishankarm): For arguments that are of index type, keep the 80 // arguments as the scalar converted type, i.e. i32. These are still not 81 // handled effectively. These are potentially best handled as specialization 82 // constants. 83 if (!convertedType.isa<spirv::PointerType>() && !type.isa<IndexType>()) { 84 // TODO(ravishankarm) : For now hard-coding this to be StorageBuffer. Need 85 // to support other Storage classes. 86 convertedType = spirv::PointerType::get(convertedType, 87 spirv::StorageClass::StorageBuffer); 88 } 89 90 // Add the new inputs. 91 result.addInputs(inputNo, convertedType); 92 return success(); 93 } 94 95 static LogicalResult lowerFunctionImpl( 96 FuncOp funcOp, ArrayRef<Value *> operands, 97 ConversionPatternRewriter &rewriter, TypeConverter *typeConverter, 98 TypeConverter::SignatureConversion &signatureConverter, FuncOp &newFuncOp) { 99 auto fnType = funcOp.getType(); 100 101 if (fnType.getNumResults()) { 102 return funcOp.emitError("SPIR-V dialect only supports functions with no " 103 "return values right now"); 104 } 105 106 for (auto &argType : enumerate(fnType.getInputs())) { 107 // Get the type of the argument 108 if (failed(typeConverter->convertSignatureArg( 109 argType.index(), argType.value(), signatureConverter))) { 110 return funcOp.emitError("unable to convert argument type ") 111 << argType.value() << " to SPIR-V type"; 112 } 113 } 114 115 // Create a new function with an updated signature. 116 newFuncOp = rewriter.cloneWithoutRegions(funcOp); 117 rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), 118 newFuncOp.end()); 119 newFuncOp.setType(FunctionType::get(signatureConverter.getConvertedTypes(), 120 llvm::None, funcOp.getContext())); 121 122 // Tell the rewriter to convert the region signature. 123 rewriter.applySignatureConversion(&newFuncOp.getBody(), signatureConverter); 124 rewriter.replaceOp(funcOp.getOperation(), llvm::None); 125 return success(); 126 } 127 128 namespace mlir { 129 LogicalResult lowerFunction(FuncOp funcOp, ArrayRef<Value *> operands, 130 SPIRVTypeConverter *typeConverter, 131 ConversionPatternRewriter &rewriter, 132 FuncOp &newFuncOp) { 133 auto fnType = funcOp.getType(); 134 TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs()); 135 return lowerFunctionImpl(funcOp, operands, rewriter, 136 typeConverter->getBasicTypeConverter(), 137 signatureConverter, newFuncOp); 138 } 139 140 LogicalResult lowerAsEntryFunction(FuncOp funcOp, ArrayRef<Value *> operands, 141 SPIRVTypeConverter *typeConverter, 142 ConversionPatternRewriter &rewriter, 143 FuncOp &newFuncOp) { 144 auto fnType = funcOp.getType(); 145 TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs()); 146 if (failed(lowerFunctionImpl(funcOp, operands, rewriter, typeConverter, 147 signatureConverter, newFuncOp))) { 148 return failure(); 149 } 150 // Create spv.globalVariable ops for each of the arguments. These need to be 151 // bound by the runtime. For now use descriptor_set 0, and arg number as the 152 // binding number. 153 auto module = funcOp.getParentOfType<spirv::ModuleOp>(); 154 if (!module) { 155 return funcOp.emitError("expected op to be within a spv.module"); 156 } 157 auto ip = rewriter.saveInsertionPoint(); 158 rewriter.setInsertionPointToStart(&module.getBlock()); 159 SmallVector<Attribute, 4> interface; 160 for (auto &convertedArgType : 161 llvm::enumerate(signatureConverter.getConvertedTypes())) { 162 // TODO(ravishankarm) : The arguments to the converted function are either 163 // spirv::PointerType or i32 type, the latter due to conversion of index 164 // type to i32. Eventually entry function should be of signature 165 // void(void). Arguments converted to spirv::PointerType, will be made 166 // variables and those converted to i32 will be made specialization 167 // constants. Latter is not implemented. 168 if (!convertedArgType.value().isa<spirv::PointerType>()) { 169 continue; 170 } 171 std::string varName = funcOp.getName().str() + "_arg_" + 172 std::to_string(convertedArgType.index()); 173 auto variableOp = rewriter.create<spirv::GlobalVariableOp>( 174 funcOp.getLoc(), rewriter.getTypeAttr(convertedArgType.value()), 175 rewriter.getStringAttr(varName), nullptr); 176 variableOp.setAttr("descriptor_set", rewriter.getI32IntegerAttr(0)); 177 variableOp.setAttr("binding", 178 rewriter.getI32IntegerAttr(convertedArgType.index())); 179 interface.push_back(rewriter.getSymbolRefAttr(variableOp.sym_name())); 180 } 181 // Create an entry point instruction for this function. 182 // TODO(ravishankarm) : Add execution mode for the entry function 183 rewriter.setInsertionPoint(&(module.getBlock().back())); 184 rewriter.create<spirv::EntryPointOp>( 185 funcOp.getLoc(), 186 rewriter.getI32IntegerAttr( 187 static_cast<int32_t>(spirv::ExecutionModel::GLCompute)), 188 rewriter.getSymbolRefAttr(newFuncOp.getName()), 189 rewriter.getArrayAttr(interface)); 190 rewriter.restoreInsertionPoint(ip); 191 return success(); 192 } 193 } // namespace mlir 194 195 //===----------------------------------------------------------------------===// 196 // Operation conversion 197 //===----------------------------------------------------------------------===// 198 199 namespace { 200 201 /// Convert integer binary operations to SPIR-V operations. Cannot use tablegen 202 /// for this. If the integer operation is on variables of IndexType, the type of 203 /// the return value of the replacement operation differs from that of the 204 /// replaced operation. This is not handled in tablegen-based pattern 205 /// specification. 206 template <typename StdOp, typename SPIRVOp> 207 class IntegerOpConversion final : public ConversionPattern { 208 public: 209 IntegerOpConversion(MLIRContext *context) 210 : ConversionPattern(StdOp::getOperationName(), 1, context) {} 211 212 PatternMatchResult 213 matchAndRewrite(Operation *op, ArrayRef<Value *> operands, 214 ConversionPatternRewriter &rewriter) const override { 215 rewriter.template replaceOpWithNewOp<SPIRVOp>( 216 op, operands[0]->getType(), operands, ArrayRef<NamedAttribute>()); 217 return this->matchSuccess(); 218 } 219 }; 220 221 /// Convert load -> spv.LoadOp. The operands of the replaced operation are of 222 /// IndexType while that of the replacement operation are of type i32. This is 223 /// not suppored in tablegen based pattern specification. 224 // TODO(ravishankarm) : These could potentially be templated on the operation 225 // being converted, since the same logic should work for linalg.load. 226 class LoadOpConversion final : public ConversionPattern { 227 public: 228 LoadOpConversion(MLIRContext *context) 229 : ConversionPattern(LoadOp::getOperationName(), 1, context) {} 230 231 PatternMatchResult 232 matchAndRewrite(Operation *op, ArrayRef<Value *> operands, 233 ConversionPatternRewriter &rewriter) const override { 234 LoadOpOperandAdaptor loadOperands(operands); 235 auto basePtr = loadOperands.memref(); 236 auto ptrType = basePtr->getType().dyn_cast<spirv::PointerType>(); 237 if (!ptrType) { 238 return matchFailure(); 239 } 240 auto loadPtr = rewriter.create<spirv::AccessChainOp>( 241 op->getLoc(), basePtr, loadOperands.indices()); 242 auto loadPtrType = loadPtr.getType().cast<spirv::PointerType>(); 243 rewriter.replaceOpWithNewOp<spirv::LoadOp>( 244 op, loadPtrType.getPointeeType(), loadPtr, /*memory_access =*/nullptr, 245 /*alignment =*/nullptr); 246 return matchSuccess(); 247 } 248 }; 249 250 /// Convert return -> spv.Return. 251 class ReturnToSPIRVConversion : public ConversionPattern { 252 public: 253 ReturnToSPIRVConversion(MLIRContext *context) 254 : ConversionPattern(ReturnOp::getOperationName(), 1, context) {} 255 virtual PatternMatchResult 256 matchAndRewrite(Operation *op, ArrayRef<Value *> operands, 257 ConversionPatternRewriter &rewriter) const override { 258 if (op->getNumOperands()) { 259 return matchFailure(); 260 } 261 rewriter.replaceOpWithNewOp<spirv::ReturnOp>(op); 262 return matchSuccess(); 263 } 264 }; 265 266 /// Convert store -> spv.StoreOp. The operands of the replaced operation are of 267 /// IndexType while that of the replacement operation are of type i32. This is 268 /// not suppored in tablegen based pattern specification. 269 // TODO(ravishankarm) : These could potentially be templated on the operation 270 // being converted, since the same logic should work for linalg.store. 271 class StoreOpConversion final : public ConversionPattern { 272 public: 273 StoreOpConversion(MLIRContext *context) 274 : ConversionPattern(StoreOp::getOperationName(), 1, context) {} 275 276 PatternMatchResult 277 matchAndRewrite(Operation *op, ArrayRef<Value *> operands, 278 ConversionPatternRewriter &rewriter) const override { 279 StoreOpOperandAdaptor storeOperands(operands); 280 auto value = storeOperands.value(); 281 auto basePtr = storeOperands.memref(); 282 auto ptrType = basePtr->getType().dyn_cast<spirv::PointerType>(); 283 if (!ptrType) { 284 return matchFailure(); 285 } 286 auto storePtr = rewriter.create<spirv::AccessChainOp>( 287 op->getLoc(), basePtr, storeOperands.indices()); 288 rewriter.replaceOpWithNewOp<spirv::StoreOp>(op, storePtr, value, 289 /*memory_access =*/nullptr, 290 /*alignment =*/nullptr); 291 return matchSuccess(); 292 } 293 }; 294 295 } // namespace 296 297 namespace { 298 /// Import the Standard Ops to SPIR-V Patterns. 299 #include "StandardToSPIRV.cpp.inc" 300 } // namespace 301 302 namespace mlir { 303 void populateStandardToSPIRVPatterns(MLIRContext *context, 304 OwningRewritePatternList &patterns) { 305 populateWithGenerated(context, &patterns); 306 // Add the return op conversion. 307 patterns.insert<IntegerOpConversion<AddIOp, spirv::IAddOp>, 308 IntegerOpConversion<MulIOp, spirv::IMulOp>, LoadOpConversion, 309 ReturnToSPIRVConversion, StoreOpConversion>(context); 310 } 311 } // namespace mlir