github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp (about)

     1  //===- ConvertStandardToSPIRV.cpp - Standard to SPIR-V dialect conversion--===//
     2  //
     3  // Copyright 2019 The MLIR Authors.
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //   http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  // =============================================================================
    17  //
    18  // This file implements a pass to convert MLIR standard and builtin dialects
    19  // into the SPIR-V dialect.
    20  //
    21  //===----------------------------------------------------------------------===//
    22  #include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h"
    23  #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
    24  #include "mlir/Dialect/SPIRV/SPIRVOps.h"
    25  #include "mlir/Dialect/StandardOps/Ops.h"
    26  
    27  using namespace mlir;
    28  
    29  //===----------------------------------------------------------------------===//
    30  // Type Conversion
    31  //===----------------------------------------------------------------------===//
    32  
    33  SPIRVBasicTypeConverter::SPIRVBasicTypeConverter(MLIRContext *context)
    34      : spirvDialect(context->getRegisteredDialect<spirv::SPIRVDialect>()) {}
    35  
    36  Type SPIRVBasicTypeConverter::convertType(Type t) {
    37    // Check if the type is SPIR-V supported. If so return the type.
    38    if (spirvDialect->isValidSPIRVType(t)) {
    39      return t;
    40    }
    41  
    42    if (auto indexType = t.dyn_cast<IndexType>()) {
    43      // Return I32 for index types.
    44      return IntegerType::get(32, t.getContext());
    45    }
    46  
    47    if (auto memRefType = t.dyn_cast<MemRefType>()) {
    48      if (memRefType.hasStaticShape()) {
    49        // Convert MemrefType to a multi-dimensional spv.array if size is known.
    50        auto elementType = memRefType.getElementType();
    51        for (auto size : reverse(memRefType.getShape())) {
    52          elementType = spirv::ArrayType::get(elementType, size);
    53        }
    54        // TODO(ravishankarm) : For now hard-coding this to be StorageBuffer. Need
    55        // to support other Storage Classes.
    56        return spirv::PointerType::get(elementType,
    57                                       spirv::StorageClass::StorageBuffer);
    58      }
    59    }
    60    return Type();
    61  }
    62  
    63  //===----------------------------------------------------------------------===//
    64  // Entry Function signature Conversion
    65  //===----------------------------------------------------------------------===//
    66  
    67  LogicalResult
    68  SPIRVTypeConverter::convertSignatureArg(unsigned inputNo, Type type,
    69                                          SignatureConversion &result) {
    70    // Try to convert the given input type.
    71    auto convertedType = basicTypeConverter->convertType(type);
    72    // TODO(ravishankarm) : Vulkan spec requires these to be a
    73    // spirv::StructType. This is not a SPIR-V requirement, so just making this a
    74    // pointer type for now.
    75    if (!convertedType)
    76      return failure();
    77    // For arguments to entry functions, convert the type into a pointer type if
    78    // it is already not one, unless the original type was an index type.
    79    // TODO(ravishankarm): For arguments that are of index type, keep the
    80    // arguments as the scalar converted type, i.e. i32. These are still not
    81    // handled effectively. These are potentially best handled as specialization
    82    // constants.
    83    if (!convertedType.isa<spirv::PointerType>() && !type.isa<IndexType>()) {
    84      // TODO(ravishankarm) : For now hard-coding this to be StorageBuffer. Need
    85      // to support other Storage classes.
    86      convertedType = spirv::PointerType::get(convertedType,
    87                                              spirv::StorageClass::StorageBuffer);
    88    }
    89  
    90    // Add the new inputs.
    91    result.addInputs(inputNo, convertedType);
    92    return success();
    93  }
    94  
    95  static LogicalResult lowerFunctionImpl(
    96      FuncOp funcOp, ArrayRef<Value *> operands,
    97      ConversionPatternRewriter &rewriter, TypeConverter *typeConverter,
    98      TypeConverter::SignatureConversion &signatureConverter, FuncOp &newFuncOp) {
    99    auto fnType = funcOp.getType();
   100  
   101    if (fnType.getNumResults()) {
   102      return funcOp.emitError("SPIR-V dialect only supports functions with no "
   103                              "return values right now");
   104    }
   105  
   106    for (auto &argType : enumerate(fnType.getInputs())) {
   107      // Get the type of the argument
   108      if (failed(typeConverter->convertSignatureArg(
   109              argType.index(), argType.value(), signatureConverter))) {
   110        return funcOp.emitError("unable to convert argument type ")
   111               << argType.value() << " to SPIR-V type";
   112      }
   113    }
   114  
   115    // Create a new function with an updated signature.
   116    newFuncOp = rewriter.cloneWithoutRegions(funcOp);
   117    rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
   118                                newFuncOp.end());
   119    newFuncOp.setType(FunctionType::get(signatureConverter.getConvertedTypes(),
   120                                        llvm::None, funcOp.getContext()));
   121  
   122    // Tell the rewriter to convert the region signature.
   123    rewriter.applySignatureConversion(&newFuncOp.getBody(), signatureConverter);
   124    rewriter.replaceOp(funcOp.getOperation(), llvm::None);
   125    return success();
   126  }
   127  
   128  namespace mlir {
   129  LogicalResult lowerFunction(FuncOp funcOp, ArrayRef<Value *> operands,
   130                              SPIRVTypeConverter *typeConverter,
   131                              ConversionPatternRewriter &rewriter,
   132                              FuncOp &newFuncOp) {
   133    auto fnType = funcOp.getType();
   134    TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
   135    return lowerFunctionImpl(funcOp, operands, rewriter,
   136                             typeConverter->getBasicTypeConverter(),
   137                             signatureConverter, newFuncOp);
   138  }
   139  
   140  LogicalResult lowerAsEntryFunction(FuncOp funcOp, ArrayRef<Value *> operands,
   141                                     SPIRVTypeConverter *typeConverter,
   142                                     ConversionPatternRewriter &rewriter,
   143                                     FuncOp &newFuncOp) {
   144    auto fnType = funcOp.getType();
   145    TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
   146    if (failed(lowerFunctionImpl(funcOp, operands, rewriter, typeConverter,
   147                                 signatureConverter, newFuncOp))) {
   148      return failure();
   149    }
   150    // Create spv.globalVariable ops for each of the arguments. These need to be
   151    // bound by the runtime. For now use descriptor_set 0, and arg number as the
   152    // binding number.
   153    auto module = funcOp.getParentOfType<spirv::ModuleOp>();
   154    if (!module) {
   155      return funcOp.emitError("expected op to be within a spv.module");
   156    }
   157    auto ip = rewriter.saveInsertionPoint();
   158    rewriter.setInsertionPointToStart(&module.getBlock());
   159    SmallVector<Attribute, 4> interface;
   160    for (auto &convertedArgType :
   161         llvm::enumerate(signatureConverter.getConvertedTypes())) {
   162      // TODO(ravishankarm) : The arguments to the converted function are either
   163      // spirv::PointerType or i32 type, the latter due to conversion of index
   164      // type to i32. Eventually entry function should be of signature
   165      // void(void). Arguments converted to spirv::PointerType, will be made
   166      // variables and those converted to i32 will be made specialization
   167      // constants. Latter is not implemented.
   168      if (!convertedArgType.value().isa<spirv::PointerType>()) {
   169        continue;
   170      }
   171      std::string varName = funcOp.getName().str() + "_arg_" +
   172                            std::to_string(convertedArgType.index());
   173      auto variableOp = rewriter.create<spirv::GlobalVariableOp>(
   174          funcOp.getLoc(), rewriter.getTypeAttr(convertedArgType.value()),
   175          rewriter.getStringAttr(varName), nullptr);
   176      variableOp.setAttr("descriptor_set", rewriter.getI32IntegerAttr(0));
   177      variableOp.setAttr("binding",
   178                         rewriter.getI32IntegerAttr(convertedArgType.index()));
   179      interface.push_back(rewriter.getSymbolRefAttr(variableOp.sym_name()));
   180    }
   181    // Create an entry point instruction for this function.
   182    // TODO(ravishankarm) : Add execution mode for the entry function
   183    rewriter.setInsertionPoint(&(module.getBlock().back()));
   184    rewriter.create<spirv::EntryPointOp>(
   185        funcOp.getLoc(),
   186        rewriter.getI32IntegerAttr(
   187            static_cast<int32_t>(spirv::ExecutionModel::GLCompute)),
   188        rewriter.getSymbolRefAttr(newFuncOp.getName()),
   189        rewriter.getArrayAttr(interface));
   190    rewriter.restoreInsertionPoint(ip);
   191    return success();
   192  }
   193  } // namespace mlir
   194  
   195  //===----------------------------------------------------------------------===//
   196  // Operation conversion
   197  //===----------------------------------------------------------------------===//
   198  
   199  namespace {
   200  
   201  /// Convert integer binary operations to SPIR-V operations. Cannot use tablegen
   202  /// for this. If the integer operation is on variables of IndexType, the type of
   203  /// the return value of the replacement operation differs from that of the
   204  /// replaced operation. This is not handled in tablegen-based pattern
   205  /// specification.
   206  template <typename StdOp, typename SPIRVOp>
   207  class IntegerOpConversion final : public ConversionPattern {
   208  public:
   209    IntegerOpConversion(MLIRContext *context)
   210        : ConversionPattern(StdOp::getOperationName(), 1, context) {}
   211  
   212    PatternMatchResult
   213    matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
   214                    ConversionPatternRewriter &rewriter) const override {
   215      rewriter.template replaceOpWithNewOp<SPIRVOp>(
   216          op, operands[0]->getType(), operands, ArrayRef<NamedAttribute>());
   217      return this->matchSuccess();
   218    }
   219  };
   220  
   221  /// Convert load -> spv.LoadOp. The operands of the replaced operation are of
   222  /// IndexType while that of the replacement operation are of type i32. This is
   223  /// not suppored in tablegen based pattern specification.
   224  // TODO(ravishankarm) : These could potentially be templated on the operation
   225  // being converted, since the same logic should work for linalg.load.
   226  class LoadOpConversion final : public ConversionPattern {
   227  public:
   228    LoadOpConversion(MLIRContext *context)
   229        : ConversionPattern(LoadOp::getOperationName(), 1, context) {}
   230  
   231    PatternMatchResult
   232    matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
   233                    ConversionPatternRewriter &rewriter) const override {
   234      LoadOpOperandAdaptor loadOperands(operands);
   235      auto basePtr = loadOperands.memref();
   236      auto ptrType = basePtr->getType().dyn_cast<spirv::PointerType>();
   237      if (!ptrType) {
   238        return matchFailure();
   239      }
   240      auto loadPtr = rewriter.create<spirv::AccessChainOp>(
   241          op->getLoc(), basePtr, loadOperands.indices());
   242      auto loadPtrType = loadPtr.getType().cast<spirv::PointerType>();
   243      rewriter.replaceOpWithNewOp<spirv::LoadOp>(
   244          op, loadPtrType.getPointeeType(), loadPtr, /*memory_access =*/nullptr,
   245          /*alignment =*/nullptr);
   246      return matchSuccess();
   247    }
   248  };
   249  
   250  /// Convert return -> spv.Return.
   251  class ReturnToSPIRVConversion : public ConversionPattern {
   252  public:
   253    ReturnToSPIRVConversion(MLIRContext *context)
   254        : ConversionPattern(ReturnOp::getOperationName(), 1, context) {}
   255    virtual PatternMatchResult
   256    matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
   257                    ConversionPatternRewriter &rewriter) const override {
   258      if (op->getNumOperands()) {
   259        return matchFailure();
   260      }
   261      rewriter.replaceOpWithNewOp<spirv::ReturnOp>(op);
   262      return matchSuccess();
   263    }
   264  };
   265  
   266  /// Convert store -> spv.StoreOp. The operands of the replaced operation are of
   267  /// IndexType while that of the replacement operation are of type i32. This is
   268  /// not suppored in tablegen based pattern specification.
   269  // TODO(ravishankarm) : These could potentially be templated on the operation
   270  // being converted, since the same logic should work for linalg.store.
   271  class StoreOpConversion final : public ConversionPattern {
   272  public:
   273    StoreOpConversion(MLIRContext *context)
   274        : ConversionPattern(StoreOp::getOperationName(), 1, context) {}
   275  
   276    PatternMatchResult
   277    matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
   278                    ConversionPatternRewriter &rewriter) const override {
   279      StoreOpOperandAdaptor storeOperands(operands);
   280      auto value = storeOperands.value();
   281      auto basePtr = storeOperands.memref();
   282      auto ptrType = basePtr->getType().dyn_cast<spirv::PointerType>();
   283      if (!ptrType) {
   284        return matchFailure();
   285      }
   286      auto storePtr = rewriter.create<spirv::AccessChainOp>(
   287          op->getLoc(), basePtr, storeOperands.indices());
   288      rewriter.replaceOpWithNewOp<spirv::StoreOp>(op, storePtr, value,
   289                                                  /*memory_access =*/nullptr,
   290                                                  /*alignment =*/nullptr);
   291      return matchSuccess();
   292    }
   293  };
   294  
   295  } // namespace
   296  
namespace {
/// Import the tablegen-generated Standard Ops to SPIR-V patterns. The include
/// sits inside an anonymous namespace so the generated definitions get
/// internal linkage in this translation unit.
// NOTE(review): presumably this generated file also defines
// populateWithGenerated(), which populateStandardToSPIRVPatterns calls —
// confirm against StandardToSPIRV.cpp.inc.
#include "StandardToSPIRV.cpp.inc"
} // namespace
   301  
   302  namespace mlir {
   303  void populateStandardToSPIRVPatterns(MLIRContext *context,
   304                                       OwningRewritePatternList &patterns) {
   305    populateWithGenerated(context, &patterns);
   306    // Add the return op conversion.
   307    patterns.insert<IntegerOpConversion<AddIOp, spirv::IAddOp>,
   308                    IntegerOpConversion<MulIOp, spirv::IMulOp>, LoadOpConversion,
   309                    ReturnToSPIRVConversion, StoreOpConversion>(context);
   310  }
   311  } // namespace mlir