github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp

     1  //===- ConvertStandardToLLVM.cpp - Standard to LLVM dialect conversion ---===//
     2  //
     3  // Copyright 2019 The MLIR Authors.
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //   http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  // =============================================================================
    17  //
    18  // This file implements a pass to convert MLIR standard and builtin dialects
    19  // into the LLVM IR dialect.
    20  //
    21  //===----------------------------------------------------------------------===//
    22  
    23  #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
    24  #include "mlir/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.h"
    25  #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
    26  #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
    27  #include "mlir/Dialect/StandardOps/Ops.h"
    28  #include "mlir/IR/Builders.h"
    29  #include "mlir/IR/MLIRContext.h"
    30  #include "mlir/IR/Module.h"
    31  #include "mlir/IR/PatternMatch.h"
    32  #include "mlir/Pass/Pass.h"
    33  #include "mlir/Support/Functional.h"
    34  #include "mlir/Transforms/DialectConversion.h"
    35  #include "mlir/Transforms/Passes.h"
    36  #include "mlir/Transforms/Utils.h"
    37  
    38  #include "llvm/IR/DerivedTypes.h"
    39  #include "llvm/IR/IRBuilder.h"
    40  #include "llvm/IR/Type.h"
    41  
    42  using namespace mlir;
    43  
    44  LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx)
    45      : llvmDialect(ctx->getRegisteredDialect<LLVM::LLVMDialect>()) {
    46    assert(llvmDialect && "LLVM IR dialect is not registered");
    47    module = &llvmDialect->getLLVMModule();
    48  }
    49  
    50  // Get the LLVM context.
    51  llvm::LLVMContext &LLVMTypeConverter::getLLVMContext() {
    52    return module->getContext();
    53  }
    54  
    55  // Extract an LLVM IR type from the LLVM IR dialect type.
    56  LLVM::LLVMType LLVMTypeConverter::unwrap(Type type) {
    57    if (!type)
    58      return nullptr;
    59    auto *mlirContext = type.getContext();
    60    auto wrappedLLVMType = type.dyn_cast<LLVM::LLVMType>();
    61    if (!wrappedLLVMType)
    62      emitError(UnknownLoc::get(mlirContext),
    63                "conversion resulted in a non-LLVM type");
    64    return wrappedLLVMType;
    65  }
    66  
    67  LLVM::LLVMType LLVMTypeConverter::getIndexType() {
    68    return LLVM::LLVMType::getIntNTy(
    69        llvmDialect, module->getDataLayout().getPointerSizeInBits());
    70  }
    71  
    72  Type LLVMTypeConverter::convertIndexType(IndexType type) {
    73    return getIndexType();
    74  }
    75  
    76  Type LLVMTypeConverter::convertIntegerType(IntegerType type) {
    77    return LLVM::LLVMType::getIntNTy(llvmDialect, type.getWidth());
    78  }
    79  
    80  Type LLVMTypeConverter::convertFloatType(FloatType type) {
    81    switch (type.getKind()) {
    82    case mlir::StandardTypes::F32:
    83      return LLVM::LLVMType::getFloatTy(llvmDialect);
    84    case mlir::StandardTypes::F64:
    85      return LLVM::LLVMType::getDoubleTy(llvmDialect);
    86    case mlir::StandardTypes::F16:
    87      return LLVM::LLVMType::getHalfTy(llvmDialect);
    88    case mlir::StandardTypes::BF16: {
    89      auto *mlirContext = llvmDialect->getContext();
    90      return emitError(UnknownLoc::get(mlirContext), "unsupported type: BF16"),
    91             Type();
    92    }
    93    default:
    94      llvm_unreachable("non-float type in convertFloatType");
    95    }
    96  }
    97  
    98  // Function types are converted to LLVM Function types by recursively converting
    99  // argument and result types.  If the MLIR Function has zero results, the LLVM
   100  // Function has one VoidType result.  If the MLIR Function has more than one
   101  // result, they are packed into an LLVM StructType in their order of appearance.
   102  Type LLVMTypeConverter::convertFunctionType(FunctionType type) {
   103    // Convert argument types one by one and check for errors.
   104    SmallVector<LLVM::LLVMType, 8> argTypes;
   105    for (auto t : type.getInputs()) {
   106      auto converted = convertType(t);
   107      if (!converted)
   108        return {};
   109      argTypes.push_back(unwrap(converted));
   110    }
   111  
   112    // If the function does not return anything, create the void result type;
   113    // if it returns one element, convert it; otherwise pack the result types
   114    // into a struct.
   115    LLVM::LLVMType resultType =
   116        type.getNumResults() == 0
   117            ? LLVM::LLVMType::getVoidTy(llvmDialect)
   118            : unwrap(packFunctionResults(type.getResults()));
   119    if (!resultType)
   120      return {};
   121    return LLVM::LLVMType::getFunctionTy(resultType, argTypes, /*isVarArg=*/false)
   122        .getPointerTo();
   123  }
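
        // Illustrative sketch of the conversion above (not normative; the printed
        // LLVM dialect syntax may differ in this version of MLIR):
        //   (i32) -> ()          ~>  pointer to  void (i32)
        //   (i32) -> (f32)       ~>  pointer to  float (i32)
        //   (i32) -> (f32, i64)  ~>  pointer to  { float, i64 } (i32)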
   124  
   125  // Convert a MemRef to an LLVM type. If the memref is statically-shaped, then
   126  // we return a pointer to the converted element type. Otherwise we return an
   127  // LLVM structure type, where the first element of the structure type is a
   128  // pointer to the element type of the MemRef and the following N elements are
   129  // values of the Index type, one for each of the N dynamic dimensions.
   130  Type LLVMTypeConverter::convertMemRefType(MemRefType type) {
   131    LLVM::LLVMType elementType = unwrap(convertType(type.getElementType()));
   132    if (!elementType)
   133      return {};
   134    auto ptrType = elementType.getPointerTo();
   135  
   136    // Count the dynamic dimensions that need slots in the descriptor.
   137    unsigned numDynamicSizes = type.getNumDynamicDims();
   138    // If memref is statically-shaped we return the underlying pointer type.
   139    if (numDynamicSizes == 0)
   140      return ptrType;
   141  
   142    SmallVector<LLVM::LLVMType, 8> types(numDynamicSizes + 1, getIndexType());
   143    types.front() = ptrType;
   144  
   145    return LLVM::LLVMType::getStructTy(llvmDialect, types);
   146  }
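
        // Illustrative examples for the memref conversion above, assuming a 64-bit
        // index type (a sketch, not exact dialect syntax):
        //   memref<8x4xf32>    ~>  float*                (fully static shape)
        //   memref<?x4x?xf32>  ~>  { float*, i64, i64 }  (two dynamic sizes)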
   147  
   148  // Convert an n-D vector type to an LLVM vector type via (n-1)-D array type when
   149  // n > 1.
   150  // For example, `vector<4 x f32>` converts to `!llvm<"<4 x float>">` and
   151  // `vector<4 x 8 x 16 x f32>` converts to `!llvm<"[4 x [8 x <16 x float>]]">`.
   152  Type LLVMTypeConverter::convertVectorType(VectorType type) {
   153    auto elementType = unwrap(convertType(type.getElementType()));
   154    if (!elementType)
   155      return {};
   156    auto vectorType =
   157        LLVM::LLVMType::getVectorTy(elementType, type.getShape().back());
   158    auto shape = type.getShape();
   159    for (int i = shape.size() - 2; i >= 0; --i)
   160      vectorType = LLVM::LLVMType::getArrayTy(vectorType, shape[i]);
   161    return vectorType;
   162  }
   163  
   164  // Dispatch based on the actual type.  Return null type on error.
   165  Type LLVMTypeConverter::convertStandardType(Type type) {
   166    if (auto funcType = type.dyn_cast<FunctionType>())
   167      return convertFunctionType(funcType);
   168    if (auto intType = type.dyn_cast<IntegerType>())
   169      return convertIntegerType(intType);
   170    if (auto floatType = type.dyn_cast<FloatType>())
   171      return convertFloatType(floatType);
   172    if (auto indexType = type.dyn_cast<IndexType>())
   173      return convertIndexType(indexType);
   174    if (auto memRefType = type.dyn_cast<MemRefType>())
   175      return convertMemRefType(memRefType);
   176    if (auto vectorType = type.dyn_cast<VectorType>())
   177      return convertVectorType(vectorType);
   178    if (auto llvmType = type.dyn_cast<LLVM::LLVMType>())
   179      return llvmType;
   180  
   181    return {};
   182  }
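
        // Minimal usage sketch of the converter API above, where `ctx` stands for an
        // available MLIRContext* (illustrative only, not part of the pass):
        //   LLVMTypeConverter converter(ctx);
        //   Type llvmIndex = converter.convertType(IndexType::get(ctx));
        //   Type llvmFloat = converter.convertType(FloatType::getF32(ctx));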
   183  
   184  // Convert the element type of the memref `t` to an LLVM type using `lowering`,
   185  // get an LLVM pointer type pointing to the converted element type, wrap it
   186  // into the MLIR LLVM dialect type and return it.
   187  static Type getMemRefElementPtrType(MemRefType t, LLVMTypeConverter &lowering) {
   188    auto elementType = t.getElementType();
   189    auto converted = lowering.convertType(elementType);
   190    if (!converted)
   191      return {};
   192    return converted.cast<LLVM::LLVMType>().getPointerTo();
   193  }
   194  
   195  LLVMOpLowering::LLVMOpLowering(StringRef rootOpName, MLIRContext *context,
   196                                 LLVMTypeConverter &lowering_,
   197                                 PatternBenefit benefit)
   198      : ConversionPattern(rootOpName, benefit, context), lowering(lowering_) {}
   199  
   200  namespace {
   201  // Base class for Standard to LLVM IR op conversions.  Matches the Op type
   202  // provided as a template argument.  Carries a reference to the LLVM dialect in
   203  // case it is necessary for rewriters.
   204  template <typename SourceOp>
   205  class LLVMLegalizationPattern : public LLVMOpLowering {
   206  public:
   207    // Construct a conversion pattern.
   208    explicit LLVMLegalizationPattern(LLVM::LLVMDialect &dialect_,
   209                                     LLVMTypeConverter &lowering_)
   210        : LLVMOpLowering(SourceOp::getOperationName(), dialect_.getContext(),
   211                         lowering_),
   212          dialect(dialect_) {}
   213  
   214    // Get the LLVM IR dialect.
   215    LLVM::LLVMDialect &getDialect() const { return dialect; }
   216    // Get the LLVM context.
   217    llvm::LLVMContext &getContext() const { return dialect.getLLVMContext(); }
   218    // Get the LLVM module in which the types are constructed.
   219    llvm::Module &getModule() const { return dialect.getLLVMModule(); }
   220  
   221    // Get the MLIR type wrapping the LLVM integer type whose bit width is defined
   222    // by the pointer size used in the LLVM module.
   223    LLVM::LLVMType getIndexType() const {
   224      return LLVM::LLVMType::getIntNTy(
   225          &dialect, getModule().getDataLayout().getPointerSizeInBits());
   226    }
   227  
   228    // Get the MLIR type wrapping the LLVM i8* type.
   229    LLVM::LLVMType getVoidPtrType() const {
   230      return LLVM::LLVMType::getInt8PtrTy(&dialect);
   231    }
   232  
   233    // Create an LLVM IR pseudo-operation defining the given index constant.
   234    Value *createIndexConstant(ConversionPatternRewriter &builder, Location loc,
   235                               uint64_t value) const {
   236      auto attr = builder.getIntegerAttr(builder.getIndexType(), value);
   237      return builder.create<LLVM::ConstantOp>(loc, getIndexType(), attr);
   238    }
   239  
   240    // Get the array attribute named "position" containing the given list of
   241    // integers as integer attribute elements.
   242    static ArrayAttr getIntegerArrayAttr(ConversionPatternRewriter &builder,
   243                                         ArrayRef<int64_t> values) {
   244      SmallVector<Attribute, 4> attrs;
   245      attrs.reserve(values.size());
   246      for (int64_t pos : values)
   247        attrs.push_back(builder.getIntegerAttr(builder.getIndexType(), pos));
   248      return builder.getArrayAttr(attrs);
   249    }
   250  
   251    // Extract raw data pointer value from a value representing a memref.
   252    static Value *extractMemRefElementPtr(ConversionPatternRewriter &builder,
   253                                          Location loc,
   254                                          Value *convertedMemRefValue,
   255                                          Type elementTypePtr,
   256                                          bool hasStaticShape) {
   258      if (hasStaticShape)
   259        return convertedMemRefValue;
   260      return builder.create<LLVM::ExtractValueOp>(
   261          loc, elementTypePtr, convertedMemRefValue,
   262          getIntegerArrayAttr(builder, 0));
   265    }
   266  
   267  protected:
   268    LLVM::LLVMDialect &dialect;
   269  };
   270  
   271  struct FuncOpConversion : public LLVMLegalizationPattern<FuncOp> {
   272    using LLVMLegalizationPattern<FuncOp>::LLVMLegalizationPattern;
   273  
   274    PatternMatchResult
   275    matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
   276                    ConversionPatternRewriter &rewriter) const override {
   277      auto funcOp = cast<FuncOp>(op);
   278      FunctionType type = funcOp.getType();
   279  
   280      // Convert the original function arguments.
   281      TypeConverter::SignatureConversion result(type.getNumInputs());
   282      for (unsigned i = 0, e = type.getNumInputs(); i != e; ++i)
   283        if (failed(lowering.convertSignatureArg(i, type.getInput(i), result)))
   284          return matchFailure();
   285  
   286      // Pack the result types into a struct.
   287      Type packedResult;
   288      if (type.getNumResults() != 0) {
   289        if (!(packedResult = lowering.packFunctionResults(type.getResults())))
   290          return matchFailure();
   291      }
   292  
   293      // Create a new function with an updated signature.
   294      auto newFuncOp = rewriter.cloneWithoutRegions(funcOp);
   295      rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
   296                                  newFuncOp.end());
   297      newFuncOp.setType(FunctionType::get(
   298          result.getConvertedTypes(),
   299          packedResult ? ArrayRef<Type>(packedResult) : llvm::None,
   300          funcOp.getContext()));
   301  
   302      // Tell the rewriter to convert the region signature.
   303      rewriter.applySignatureConversion(&newFuncOp.getBody(), result);
   304      rewriter.replaceOp(op, llvm::None);
   305      return matchSuccess();
   306    }
   307  };
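
        // Illustrative effect of the signature conversion above (a sketch; exact
        // dialect syntax may differ): a function
        //   func @f(%a: i32, %b: i32) -> (i32, i32)
        // is rewritten so that its converted signature returns a single packed value
        // of LLVM struct type containing the two i32 results.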
   308  
   309  // Basic lowering implementation for one-to-one rewriting from Standard Ops to
   310  // LLVM Dialect Ops.
   311  template <typename SourceOp, typename TargetOp>
   312  struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
   313    using LLVMLegalizationPattern<SourceOp>::LLVMLegalizationPattern;
   314    using Super = OneToOneLLVMOpLowering<SourceOp, TargetOp>;
   315  
   316    // Convert the type of the result to an LLVM type, pass operands as is,
   317    // preserve attributes.
   318    PatternMatchResult
   319    matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
   320                    ConversionPatternRewriter &rewriter) const override {
   321      unsigned numResults = op->getNumResults();
   322  
   323      Type packedType;
   324      if (numResults != 0) {
   325        packedType = this->lowering.packFunctionResults(
   326            llvm::to_vector<4>(op->getResultTypes()));
   327        assert(packedType && "type conversion failed, such operation should not "
   328                             "have been matched");
   329      }
   330  
   331      auto newOp = rewriter.create<TargetOp>(op->getLoc(), packedType, operands,
   332                                             op->getAttrs());
   333  
   334      // If the operation produced 0 or 1 result, replace it and return early.
   335      if (numResults == 0)
   336        return rewriter.replaceOp(op, llvm::None), this->matchSuccess();
   337      if (numResults == 1)
   338        return rewriter.replaceOp(op, newOp.getOperation()->getResult(0)),
   339               this->matchSuccess();
   340  
   341      // Otherwise, the op has been converted to an operation producing a structure.
   342      // Extract individual results from the structure and return them as a list.
   343      SmallVector<Value *, 4> results;
   344      results.reserve(numResults);
   345      for (unsigned i = 0; i < numResults; ++i) {
   346        auto type = this->lowering.convertType(op->getResult(i)->getType());
   347        results.push_back(rewriter.create<LLVM::ExtractValueOp>(
   348            op->getLoc(), type, newOp.getOperation()->getResult(0),
   349            rewriter.getIndexArrayAttr(i)));
   350      }
   351      rewriter.replaceOp(op, results);
   352      return this->matchSuccess();
   353    }
   354  };
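
        // Sketch of the multi-result case above (illustrative): a standard call
        // producing two results is rewritten into a single LLVM::CallOp returning a
        // struct, followed by two LLVM::ExtractValueOp instances reading positions
        // [0] and [1] of that struct.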
   355  
   356  // Express `linearIndex` in terms of coordinates of `basis`.
   357  // Returns the empty vector when linearIndex is out of the range [0, P) where
   358  // P is the product of all the basis coordinates.
   359  //
   360  // Prerequisites:
   361  //   Basis is an array of nonnegative integers (signed type inherited from
   362  //   vector shape type).
   363  static SmallVector<int64_t, 4> getCoordinates(ArrayRef<int64_t> basis,
   364                                                unsigned linearIndex) {
   365    SmallVector<int64_t, 4> res;
   366    res.reserve(basis.size());
   367    for (unsigned basisElement : llvm::reverse(basis)) {
   368      res.push_back(linearIndex % basisElement);
   369      linearIndex = linearIndex / basisElement;
   370    }
   371    if (linearIndex > 0)
   372      return {};
   373    std::reverse(res.begin(), res.end());
   374    return res;
   375  }
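
        // Worked example for getCoordinates (illustrative): with basis = [4, 8] and
        // linearIndex = 13, the reversed basis yields 13 % 8 = 5, 13 / 8 = 1, then
        // 1 % 4 = 1, 1 / 4 = 0, so the result is [1, 5]; indeed 1 * 8 + 5 == 13.
        // With linearIndex = 32 (== 4 * 8) the leftover quotient is nonzero and the
        // empty vector is returned.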
   376  
   377  // Basic lowering implementation for rewriting from Standard Ops to LLVM Dialect
   378  // Ops for binary ops with one result. This supports higher-dimensional vector
   379  // types.
   380  template <typename SourceOp, typename TargetOp>
   381  struct BinaryOpLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
   382    using LLVMLegalizationPattern<SourceOp>::LLVMLegalizationPattern;
   383    using Super = BinaryOpLLVMOpLowering<SourceOp, TargetOp>;
   384  
   385    // Convert the type of the result to an LLVM type, pass operands as is,
   386    // preserve attributes.
   387    PatternMatchResult
   388    matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
   389                    ConversionPatternRewriter &rewriter) const override {
   390      static_assert(
   391          std::is_base_of<OpTrait::NOperands<2>::Impl<SourceOp>, SourceOp>::value,
   392          "expected binary op");
   393      static_assert(
   394          std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
   395          "expected single result op");
   396      static_assert(std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>,
   397                                    SourceOp>::value,
   398                    "expected op with the same operand and result types");
   399  
   400      auto loc = op->getLoc();
   401      auto llvmArrayTy = operands[0]->getType().cast<LLVM::LLVMType>();
   402  
   403      if (!llvmArrayTy.isArrayTy()) {
   404        auto newOp = rewriter.create<TargetOp>(
   405            op->getLoc(), operands[0]->getType(), operands, op->getAttrs());
   406        rewriter.replaceOp(op, newOp.getResult());
   407        return this->matchSuccess();
   408      }
   409  
   410      // Unroll iterated array type until we hit a non-array type.
   411      auto llvmTy = llvmArrayTy;
   412      SmallVector<int64_t, 4> arraySizes;
   413      while (llvmTy.isArrayTy()) {
   414        arraySizes.push_back(llvmTy.getArrayNumElements());
   415        llvmTy = llvmTy.getArrayElementType();
   416      }
   417      assert(llvmTy.isVectorTy() && "unexpected binary op over non-vector type");
   418      auto llvmVectorTy = llvmTy;
   419  
   420      // Iteratively extract position coordinates with basis `arraySizes` from a
   421      // `linearIndex` that is incremented at each step.  This terminates when
   422      // `linearIndex` exceeds the range specified by `arraySizes`.
   423      // This has the effect of fully unrolling the dimensions of the n-D array
   424      // type, getting to the underlying vector element.
   425      Value *desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayTy);
   426      unsigned ub = 1;
   427      for (auto s : arraySizes)
   428        ub *= s;
   429      for (unsigned linearIndex = 0; linearIndex < ub; ++linearIndex) {
   430        auto coords = getCoordinates(arraySizes, linearIndex);
   431        // Linear index is out of bounds, we are done.
   432        if (coords.empty())
   433          break;
   434  
   435        auto position = rewriter.getIndexArrayAttr(coords);
   436  
   437        // For this unrolled `position` corresponding to the `linearIndex`^th
   438      // element, extract the operand vectors.
   439        Value *extractedLHS = rewriter.create<LLVM::ExtractValueOp>(
   440            loc, llvmVectorTy, operands[0], position);
   441        Value *extractedRHS = rewriter.create<LLVM::ExtractValueOp>(
   442            loc, llvmVectorTy, operands[1], position);
   443        Value *newVal = rewriter.create<TargetOp>(
   444            loc, llvmVectorTy, ArrayRef<Value *>{extractedLHS, extractedRHS},
   445            op->getAttrs());
   446        desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmArrayTy, desc,
   447                                                    newVal, position);
   448      }
   449      rewriter.replaceOp(op, desc);
   450      return this->matchSuccess();
   451    }
   452  };
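
        // Sketch of the unrolling above (illustrative): `addf` on two operands of
        // type vector<4x8xf32>, converted to [4 x <8 x float>], is rewritten into
        // four steps, each extracting position [i] from both operands, creating one
        // LLVM::FAddOp on <8 x float>, and inserting the result back at position [i].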
   453  
   454  // Specific lowerings.
   455  // FIXME: this should be tablegen'ed.
   456  struct AddIOpLowering : public BinaryOpLLVMOpLowering<AddIOp, LLVM::AddOp> {
   457    using Super::Super;
   458  };
   459  struct SubIOpLowering : public BinaryOpLLVMOpLowering<SubIOp, LLVM::SubOp> {
   460    using Super::Super;
   461  };
   462  struct MulIOpLowering : public BinaryOpLLVMOpLowering<MulIOp, LLVM::MulOp> {
   463    using Super::Super;
   464  };
   465  struct DivISOpLowering : public BinaryOpLLVMOpLowering<DivISOp, LLVM::SDivOp> {
   466    using Super::Super;
   467  };
   468  struct DivIUOpLowering : public BinaryOpLLVMOpLowering<DivIUOp, LLVM::UDivOp> {
   469    using Super::Super;
   470  };
   471  struct RemISOpLowering : public BinaryOpLLVMOpLowering<RemISOp, LLVM::SRemOp> {
   472    using Super::Super;
   473  };
   474  struct RemIUOpLowering : public BinaryOpLLVMOpLowering<RemIUOp, LLVM::URemOp> {
   475    using Super::Super;
   476  };
   477  struct AndOpLowering : public BinaryOpLLVMOpLowering<AndOp, LLVM::AndOp> {
   478    using Super::Super;
   479  };
   480  struct OrOpLowering : public BinaryOpLLVMOpLowering<OrOp, LLVM::OrOp> {
   481    using Super::Super;
   482  };
   483  struct XOrOpLowering : public BinaryOpLLVMOpLowering<XOrOp, LLVM::XOrOp> {
   484    using Super::Super;
   485  };
   486  struct AddFOpLowering : public BinaryOpLLVMOpLowering<AddFOp, LLVM::FAddOp> {
   487    using Super::Super;
   488  };
   489  struct SubFOpLowering : public BinaryOpLLVMOpLowering<SubFOp, LLVM::FSubOp> {
   490    using Super::Super;
   491  };
   492  struct MulFOpLowering : public BinaryOpLLVMOpLowering<MulFOp, LLVM::FMulOp> {
   493    using Super::Super;
   494  };
   495  struct DivFOpLowering : public BinaryOpLLVMOpLowering<DivFOp, LLVM::FDivOp> {
   496    using Super::Super;
   497  };
   498  struct RemFOpLowering : public BinaryOpLLVMOpLowering<RemFOp, LLVM::FRemOp> {
   499    using Super::Super;
   500  };
   501  struct SelectOpLowering
   502      : public OneToOneLLVMOpLowering<SelectOp, LLVM::SelectOp> {
   503    using Super::Super;
   504  };
   505  struct CallOpLowering : public OneToOneLLVMOpLowering<CallOp, LLVM::CallOp> {
   506    using Super::Super;
   507  };
   508  struct CallIndirectOpLowering
   509      : public OneToOneLLVMOpLowering<CallIndirectOp, LLVM::CallOp> {
   510    using Super::Super;
   511  };
   512  struct ConstLLVMOpLowering
   513      : public OneToOneLLVMOpLowering<ConstantOp, LLVM::ConstantOp> {
   514    using Super::Super;
   515  };
   516  
   517  // Check if the MemRefType `type` is supported by the lowering. We currently do
   518  // not support memrefs with affine maps or non-default memory spaces.
   519  static bool isSupportedMemRefType(MemRefType type) {
   520    if (!type.getAffineMaps().empty())
   521      return false;
   522    if (type.getMemorySpace() != 0)
   523      return false;
   524    return true;
   525  }
   526  
   527  // An `alloc` is converted into a definition of a memref descriptor value and
   528  // a call to `malloc` to allocate the underlying data buffer.  The memref
   529  // descriptor is of the LLVM structure type where the first element is a pointer
   530  // to the (typed) data buffer, and the remaining elements serve to store
   531  // dynamic sizes of the memref using LLVM-converted `index` type.
   532  struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
   533    using LLVMLegalizationPattern<AllocOp>::LLVMLegalizationPattern;
   534  
   535    PatternMatchResult match(Operation *op) const override {
   536      MemRefType type = cast<AllocOp>(op).getType();
   537      return isSupportedMemRefType(type) ? matchSuccess() : matchFailure();
   538    }
   539  
   540    void rewrite(Operation *op, ArrayRef<Value *> operands,
   541                 ConversionPatternRewriter &rewriter) const override {
   542      auto allocOp = cast<AllocOp>(op);
   543      MemRefType type = allocOp.getType();
   544  
   545      // Get actual sizes of the memref as values: static sizes are constant
   546      // values and dynamic sizes are passed to 'alloc' as operands.  In case of
   547      // zero-dimensional memref, assume a scalar (size 1).
   548      SmallVector<Value *, 4> sizes;
   549      auto numOperands = allocOp.getNumOperands();
   550      sizes.reserve(numOperands);
   551      unsigned i = 0;
   552      for (int64_t s : type.getShape())
   553        sizes.push_back(s == -1 ? operands[i++]
   554                                : createIndexConstant(rewriter, op->getLoc(), s));
   555      if (sizes.empty())
   556        sizes.push_back(createIndexConstant(rewriter, op->getLoc(), 1));
   557  
   558      // Compute the total number of memref elements.
   559      Value *cumulativeSize = sizes.front();
   560      for (unsigned i = 1, e = sizes.size(); i < e; ++i)
   561        cumulativeSize = rewriter.create<LLVM::MulOp>(
   562            op->getLoc(), getIndexType(),
   563            ArrayRef<Value *>{cumulativeSize, sizes[i]});
   564  
   565      // Compute the total amount of bytes to allocate.
   566      auto elementType = type.getElementType();
   567      assert((elementType.isIntOrFloat() || elementType.isa<VectorType>()) &&
   568             "invalid memref element type");
   569      uint64_t elementSize = 0;
   570      if (auto vectorType = elementType.dyn_cast<VectorType>())
   571        elementSize = vectorType.getNumElements() *
   572                      llvm::divideCeil(vectorType.getElementTypeBitWidth(), 8);
   573      else
   574        elementSize = llvm::divideCeil(elementType.getIntOrFloatBitWidth(), 8);
   575      cumulativeSize = rewriter.create<LLVM::MulOp>(
   576          op->getLoc(), getIndexType(),
   577          ArrayRef<Value *>{
   578              cumulativeSize,
   579              createIndexConstant(rewriter, op->getLoc(), elementSize)});
   580  
   581      // Insert the `malloc` declaration if it is not already present.
   582      auto module = op->getParentOfType<ModuleOp>();
   583      FuncOp mallocFunc = module.lookupSymbol<FuncOp>("malloc");
   584      if (!mallocFunc) {
   585        auto mallocType =
   586            rewriter.getFunctionType(getIndexType(), getVoidPtrType());
   587        mallocFunc =
   588            FuncOp::create(rewriter.getUnknownLoc(), "malloc", mallocType);
   589        module.push_back(mallocFunc);
   590      }
   591  
   592      // Allocate the underlying buffer and store a pointer to it in the MemRef
   593      // descriptor.
   594      Value *allocated =
   595          rewriter
   596              .create<LLVM::CallOp>(op->getLoc(), getVoidPtrType(),
   597                                    rewriter.getSymbolRefAttr(mallocFunc),
   598                                    cumulativeSize)
   599              .getResult(0);
   600      auto structElementType = lowering.convertType(elementType);
   601      auto elementPtrType =
   602          structElementType.cast<LLVM::LLVMType>().getPointerTo();
   603      allocated = rewriter.create<LLVM::BitcastOp>(op->getLoc(), elementPtrType,
   604                                                   ArrayRef<Value *>(allocated));
   605  
   606      // For statically-shaped memrefs, the raw buffer pointer is the result.
   607      if (numOperands == 0)
   608        return rewriter.replaceOp(op, allocated);
   609  
   610      // Create the MemRef descriptor.
   611      auto structType = lowering.convertType(type);
   612      Value *memRefDescriptor = rewriter.create<LLVM::UndefOp>(
   613          op->getLoc(), structType, ArrayRef<Value *>{});
   614  
   615      memRefDescriptor = rewriter.create<LLVM::InsertValueOp>(
   616          op->getLoc(), structType, memRefDescriptor, allocated,
   617          rewriter.getIndexArrayAttr(0));
   618  
   619      // Store dynamically allocated sizes in the descriptor.  Dynamic sizes are
   620      // passed in as operands.
   621      for (auto indexedSize : llvm::enumerate(operands)) {
   622        memRefDescriptor = rewriter.create<LLVM::InsertValueOp>(
   623            op->getLoc(), structType, memRefDescriptor, indexedSize.value(),
   624            rewriter.getIndexArrayAttr(1 + indexedSize.index()));
   625      }
   626  
   627      // Return the final value of the descriptor.
   628      rewriter.replaceOp(op, memRefDescriptor);
   629    }
   630  };
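
        // Rough shape of the lowering above for `%m = alloc(%n) : memref<?x4xf32>`
        // (an illustrative sketch): multiply %n by 4 and by the element size in
        // bytes to obtain the allocation size, call `malloc`, bitcast the returned
        // i8* to float*, then build the { float*, i64 } descriptor with one UndefOp
        // and two InsertValueOps (buffer pointer at position 0, %n at position 1).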
   631  
   632  // A `dealloc` is converted into a call to `free` on the underlying data buffer.
   633  // Since the memref descriptor is an SSA value, there is no need to clean it
   634  // up in any way.
   635  struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
   636    using LLVMLegalizationPattern<DeallocOp>::LLVMLegalizationPattern;
   637  
   638    PatternMatchResult
   639    matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
   640                    ConversionPatternRewriter &rewriter) const override {
   641      assert(operands.size() == 1 && "dealloc takes one operand");
   642      OperandAdaptor<DeallocOp> transformed(operands);
   643  
   644      // Insert the `free` declaration if it is not already present.
   645      FuncOp freeFunc =
   646          op->getParentOfType<ModuleOp>().lookupSymbol<FuncOp>("free");
   647      if (!freeFunc) {
   648        auto freeType = rewriter.getFunctionType(getVoidPtrType(), {});
   649        freeFunc = FuncOp::create(rewriter.getUnknownLoc(), "free", freeType);
   650        op->getParentOfType<ModuleOp>().push_back(freeFunc);
   651      }
   652  
   653      auto type = transformed.memref()->getType().cast<LLVM::LLVMType>();
   654      auto hasStaticShape = type.isPointerTy();
   655      Type elementPtrType = hasStaticShape ? type : type.getStructElementType(0);
   656      Value *bufferPtr =
   657          extractMemRefElementPtr(rewriter, op->getLoc(), transformed.memref(),
   658                                  elementPtrType, hasStaticShape);
   659      Value *casted = rewriter.create<LLVM::BitcastOp>(
   660          op->getLoc(), getVoidPtrType(), bufferPtr);
   661      rewriter.replaceOpWithNewOp<LLVM::CallOp>(
   662          op, ArrayRef<Type>(), rewriter.getSymbolRefAttr(freeFunc), casted);
   663      return matchSuccess();
   664    }
   665  };
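
        // Rough shape of the lowering above (illustrative): for a dynamically-shaped
        // memref, the buffer pointer is first extracted from position [0] of the
        // descriptor; it is then bitcast to i8* and passed to a call to `free`.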
   666  
   667  struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
   668    using LLVMLegalizationPattern<MemRefCastOp>::LLVMLegalizationPattern;
   669  
   670    PatternMatchResult match(Operation *op) const override {
   671      auto memRefCastOp = cast<MemRefCastOp>(op);
   672      MemRefType sourceType =
   673          memRefCastOp.getOperand()->getType().cast<MemRefType>();
   674      MemRefType targetType = memRefCastOp.getType();
   675      return (isSupportedMemRefType(targetType) &&
   676              isSupportedMemRefType(sourceType))
   677                 ? matchSuccess()
   678                 : matchFailure();
   679    }
   680  
   681    void rewrite(Operation *op, ArrayRef<Value *> operands,
   682                 ConversionPatternRewriter &rewriter) const override {
   683      auto memRefCastOp = cast<MemRefCastOp>(op);
   684      OperandAdaptor<MemRefCastOp> transformed(operands);
   685      auto targetType = memRefCastOp.getType();
   686      auto sourceType = memRefCastOp.getOperand()->getType().cast<MemRefType>();
   687  
   688      // Copy the data buffer pointer.
   689      auto elementTypePtr = getMemRefElementPtrType(targetType, lowering);
   690      Value *buffer =
   691          extractMemRefElementPtr(rewriter, op->getLoc(), transformed.source(),
   692                                  elementTypePtr, sourceType.hasStaticShape());
   693      // Statically-shaped target memrefs lower to the raw buffer pointer directly.
   694      if (targetType.hasStaticShape())
   695        return rewriter.replaceOp(op, buffer);
   696  
   697      // Otherwise, the target type is a dynamically-shaped memref: create a new
   698      // MemRef descriptor and copy the buffer pointer into its first position.
   699      auto structType = lowering.convertType(targetType);
   700      Value *newDescriptor = rewriter.create<LLVM::UndefOp>(
   701          op->getLoc(), structType, ArrayRef<Value *>{});
   702      newDescriptor = rewriter.create<LLVM::InsertValueOp>(
   703          op->getLoc(), structType, newDescriptor, buffer,
   704          rewriter.getIndexArrayAttr(0));
   705  
   706      // Fill in the dynamic sizes of the new descriptor.  If the size was
   707      // dynamic, copy it from the old descriptor.  If the size was static, insert
   708      // the constant.  Note that the positions of dynamic sizes in the
   709      // descriptors start from 1 (the buffer pointer is at position zero).
   710      int64_t sourceDynamicDimIdx = 1;
   711      int64_t targetDynamicDimIdx = 1;
   712      for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
   713        // Ignore new static sizes (they will be known from the type).  If the
   714        // size was dynamic, update the index of dynamic types.
   715        if (targetType.getShape()[i] != -1) {
   716          if (sourceType.getShape()[i] == -1)
   717            ++sourceDynamicDimIdx;
   718          continue;
   719        }
   720  
   721        auto sourceSize = sourceType.getShape()[i];
   722        Value *size =
   723            sourceSize == -1
   724                ? rewriter.create<LLVM::ExtractValueOp>(
   725                      op->getLoc(), getIndexType(),
   726                      transformed.source(), // NB: dynamic memref
   727                      rewriter.getIndexArrayAttr(sourceDynamicDimIdx++))
   728                : createIndexConstant(rewriter, op->getLoc(), sourceSize);
   729        newDescriptor = rewriter.create<LLVM::InsertValueOp>(
   730            op->getLoc(), structType, newDescriptor, size,
   731            rewriter.getIndexArrayAttr(targetDynamicDimIdx++));
   732      }
   733      assert(sourceDynamicDimIdx - 1 == sourceType.getNumDynamicDims() &&
   734             "source dynamic dimensions were not processed");
   735      assert(targetDynamicDimIdx - 1 == targetType.getNumDynamicDims() &&
   736             "target dynamic dimensions were not set up");
   737  
   738      rewriter.replaceOp(op, newDescriptor);
   739    }
   740  };
   741  
   742  // A `dim` is converted to a constant for static sizes and to an access to the
   743  // size stored in the memref descriptor for dynamic sizes.
   744  struct DimOpLowering : public LLVMLegalizationPattern<DimOp> {
   745    using LLVMLegalizationPattern<DimOp>::LLVMLegalizationPattern;
   746  
   747    PatternMatchResult match(Operation *op) const override {
   748      auto dimOp = cast<DimOp>(op);
   749      MemRefType type = dimOp.getOperand()->getType().cast<MemRefType>();
   750      return isSupportedMemRefType(type) ? matchSuccess() : matchFailure();
   751    }
   752  
   753    void rewrite(Operation *op, ArrayRef<Value *> operands,
   754                 ConversionPatternRewriter &rewriter) const override {
   755      auto dimOp = cast<DimOp>(op);
   756      OperandAdaptor<DimOp> transformed(operands);
   757      MemRefType type = dimOp.getOperand()->getType().cast<MemRefType>();
   758  
   759      auto shape = type.getShape();
   760      uint64_t index = dimOp.getIndex();
   761      // Extract dynamic size from the memref descriptor and define static size
   762      // as a constant.
   763      if (shape[index] == -1) {
   764        // Find the position of the dynamic dimension in the list of dynamic sizes
   765        // by counting the number of preceding dynamic dimensions.  Start from 1
   766        // because the buffer pointer is at position zero.
   767        int64_t position = 1;
   768        for (uint64_t i = 0; i < index; ++i) {
   769          if (shape[i] == -1)
   770            ++position;
   771        }
   772        rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(
   773            op, getIndexType(), transformed.memrefOrTensor(),
   774            rewriter.getIndexArrayAttr(position));
   775      } else {
   776        rewriter.replaceOp(
   777            op, createIndexConstant(rewriter, op->getLoc(), shape[index]));
   778      }
   779    }
   780  };
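
        // Illustrative examples for the `dim` lowering above, given
        // %m : memref<4x?x?xf32>:
        //   dim %m, 0  ~>  index constant 4
        //   dim %m, 2  ~>  extractvalue at position [2] of the descriptor (the buffer
        //                  pointer is at position [0], the first dynamic size at [1]).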
   781  
   782  // Common base for load and store operations on MemRefs.  Restricts the match
   783  // to supported MemRef types.  Provides functionality to emit code accessing a
   784  // specific element of the underlying data buffer.
   785  template <typename Derived>
   786  struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> {
   787    using LLVMLegalizationPattern<Derived>::LLVMLegalizationPattern;
   788    using Base = LoadStoreOpLowering<Derived>;
   789  
   790    PatternMatchResult match(Operation *op) const override {
   791      MemRefType type = cast<Derived>(op).getMemRefType();
   792      return isSupportedMemRefType(type) ? this->matchSuccess()
   793                                         : this->matchFailure();
   794    }
   795  
   796    // Given subscript indices and array sizes in row-major order,
   797    //   i_n, i_{n-1}, ..., i_1
   798    //   s_n, s_{n-1}, ..., s_1
   799    // obtain a value that corresponds to the linearized subscript
   800    //   \sum_k i_k * \prod_{j=1}^{k-1} s_j
   801    // by accumulating the running linearized value.
   802    // Note that `indices` and `allocSizes` are passed in the same order as they
   803    // appear in load/store operations and memref type declarations.
   804    Value *linearizeSubscripts(ConversionPatternRewriter &builder, Location loc,
   805                               ArrayRef<Value *> indices,
   806                               ArrayRef<Value *> allocSizes) const {
   807      assert(indices.size() == allocSizes.size() &&
   808             "mismatching number of indices and allocation sizes");
   809      assert(!indices.empty() && "cannot linearize a 0-dimensional access");
   810  
   811      Value *linearized = indices.front();
   812      for (int i = 1, nSizes = allocSizes.size(); i < nSizes; ++i) {
   813        linearized = builder.create<LLVM::MulOp>(
   814            loc, this->getIndexType(),
   815            ArrayRef<Value *>{linearized, allocSizes[i]});
   816        linearized = builder.create<LLVM::AddOp>(
   817            loc, this->getIndexType(), ArrayRef<Value *>{linearized, indices[i]});
   818      }
   819      return linearized;
   820    }
   821  
   822    // Given the MemRef type, a descriptor and a list of indices, extract the data
   823    // buffer pointer from the descriptor, convert multi-dimensional subscripts
   824    // into a linearized index (using dynamic size data from the descriptor if
   825    // necessary) and get the pointer to the buffer element identified by the
   826    // indices.
   827    Value *getElementPtr(Location loc, Type elementTypePtr,
   828                         ArrayRef<int64_t> shape, Value *memRefDescriptor,
   829                         ArrayRef<Value *> indices,
   830                         ConversionPatternRewriter &rewriter) const {
   831      // Get the list of MemRef sizes.  Static sizes are defined as constants.
   832      // Dynamic sizes are extracted from the MemRef descriptor, where they start
   833      // from the position 1 (the buffer is at position 0).
   834      SmallVector<Value *, 4> sizes;
   835      unsigned dynamicSizeIdx = 1;
   836      for (int64_t s : shape) {
   837        if (s == -1) {
   838          Value *size = rewriter.create<LLVM::ExtractValueOp>(
   839              loc, this->getIndexType(), memRefDescriptor,
   840              rewriter.getIndexArrayAttr(dynamicSizeIdx++));
   841          sizes.push_back(size);
   842        } else {
   843          sizes.push_back(this->createIndexConstant(rewriter, loc, s));
   844        }
   845      }
   846  
   847      // The second and subsequent operands are access subscripts.  Obtain the
   848      // linearized address in the buffer.
   849      Value *subscript = linearizeSubscripts(rewriter, loc, indices, sizes);
   850  
   851      Value *dataPtr = rewriter.create<LLVM::ExtractValueOp>(
   852          loc, elementTypePtr, memRefDescriptor, rewriter.getIndexArrayAttr(0));
   853      return rewriter.create<LLVM::GEPOp>(loc, elementTypePtr,
   854                                          ArrayRef<Value *>{dataPtr, subscript},
   855                                          ArrayRef<NamedAttribute>{});
   856    }
   857    // This is a getElementPtr variant, where the value is a direct raw pointer.
   858    // If a shape is empty, we are dealing with a zero-dimensional memref. Return
   859    // the pointer unmodified in this case.  Otherwise, linearize subscripts to
   860    // obtain the offset with respect to the base pointer.  Use this offset to
   861    // compute and return the element pointer.
   862    Value *getRawElementPtr(Location loc, Type elementTypePtr,
   863                            ArrayRef<int64_t> shape, Value *rawDataPtr,
   864                            ArrayRef<Value *> indices,
   865                            ConversionPatternRewriter &rewriter) const {
   866      if (shape.empty())
   867        return rawDataPtr;
   868  
   869      SmallVector<Value *, 4> sizes;
   870      for (int64_t s : shape) {
   871        sizes.push_back(this->createIndexConstant(rewriter, loc, s));
   872      }
   873  
   874      Value *subscript = linearizeSubscripts(rewriter, loc, indices, sizes);
   875      return rewriter.create<LLVM::GEPOp>(
   876          loc, elementTypePtr, ArrayRef<Value *>{rawDataPtr, subscript},
   877          ArrayRef<NamedAttribute>{});
   878    }
   879  
   880    Value *getDataPtr(Location loc, MemRefType type, Value *dataPtr,
   881                      ArrayRef<Value *> indices,
   882                      ConversionPatternRewriter &rewriter,
   883                      llvm::Module &module) const {
   884      auto ptrType = getMemRefElementPtrType(type, this->lowering);
   885      auto shape = type.getShape();
   886      if (type.hasStaticShape()) {
   887        // NB: If memref was statically-shaped, dataPtr is pointer to raw data.
   888        return getRawElementPtr(loc, ptrType, shape, dataPtr, indices, rewriter);
   889      }
   890      return getElementPtr(loc, ptrType, shape, dataPtr, indices, rewriter);
   891    }
   892  };
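
        // Worked example for the address computation above (illustrative): for a
        // memref<?x?xf32> with sizes (%M, %N) and subscripts (%i, %j), the linearized
        // index is %i * %N + %j; the element pointer is then a GEP of that index over
        // the buffer pointer extracted from descriptor position [0].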
   893  
   894  // Load operation is lowered to obtaining a pointer to the indexed element
   895  // and loading it.
   896  struct LoadOpLowering : public LoadStoreOpLowering<LoadOp> {
   897    using Base::Base;
   898  
   899    PatternMatchResult
   900    matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
   901                    ConversionPatternRewriter &rewriter) const override {
   902      auto loadOp = cast<LoadOp>(op);
   903      OperandAdaptor<LoadOp> transformed(operands);
   904      auto type = loadOp.getMemRefType();
   905  
   906      Value *dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(),
   907                                  transformed.indices(), rewriter, getModule());
   908      rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, dataPtr);
   909      return matchSuccess();
   910    }
   911  };
   912  
   913  // The store operation is lowered to obtaining a pointer to the indexed element
   914  // and storing the given value into it.
   915  struct StoreOpLowering : public LoadStoreOpLowering<StoreOp> {
   916    using Base::Base;
   917  
   918    PatternMatchResult
   919    matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
   920                    ConversionPatternRewriter &rewriter) const override {
   921      auto type = cast<StoreOp>(op).getMemRefType();
   922      OperandAdaptor<StoreOp> transformed(operands);
   923  
   924      Value *dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(),
   925                                  transformed.indices(), rewriter, getModule());
   926      rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, transformed.value(),
   927                                                 dataPtr);
   928      return matchSuccess();
   929    }
   930  };
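
        // Illustrative end-to-end sketch for the two patterns above: `load %m[%i]` on
        // a statically-shaped memref becomes a GEP on the raw buffer pointer followed
        // by an LLVM::LoadOp, and `store %v, %m[%i]` produces the same address
        // computation followed by an LLVM::StoreOp.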
   931  
   932  // The lowering of index_cast becomes an integer conversion since index becomes
   933  // an integer.  If the bit width of the source and target integer types is the
   934  // same, just erase the cast.  If the target type is wider, sign-extend the
   935  // value, otherwise truncate it.
   936  struct IndexCastOpLowering : public LLVMLegalizationPattern<IndexCastOp> {
   937    using LLVMLegalizationPattern<IndexCastOp>::LLVMLegalizationPattern;
   938  
   939    PatternMatchResult
   940    matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
   941                    ConversionPatternRewriter &rewriter) const override {
   942      IndexCastOpOperandAdaptor transformed(operands);
   943      auto indexCastOp = cast<IndexCastOp>(op);
   944  
   945      auto targetType =
   946          this->lowering.convertType(indexCastOp.getResult()->getType())
   947              .cast<LLVM::LLVMType>();
   948      auto sourceType = transformed.in()->getType().cast<LLVM::LLVMType>();
   949      unsigned targetBits = targetType.getUnderlyingType()->getIntegerBitWidth();
   950      unsigned sourceBits = sourceType.getUnderlyingType()->getIntegerBitWidth();
   951  
   952      if (targetBits == sourceBits)
   953        rewriter.replaceOp(op, transformed.in());
   954      else if (targetBits < sourceBits)
   955        rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, targetType,
   956                                                   transformed.in());
   957      else
   958        rewriter.replaceOpWithNewOp<LLVM::SExtOp>(op, targetType,
   959                                                  transformed.in());
   960      return matchSuccess();
   961    }
   962  };
   963  
   964  // Convert std.cmp predicate into the LLVM dialect CmpPredicate.  The two
   965  // enums share the numerical values so just cast.
   966  template <typename LLVMPredType, typename StdPredType>
   967  static LLVMPredType convertCmpPredicate(StdPredType pred) {
   968    return static_cast<LLVMPredType>(pred);
   969  }
   970  
   971  struct CmpIOpLowering : public LLVMLegalizationPattern<CmpIOp> {
   972    using LLVMLegalizationPattern<CmpIOp>::LLVMLegalizationPattern;
   973  
   974    PatternMatchResult
   975    matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
   976                    ConversionPatternRewriter &rewriter) const override {
   977      auto cmpiOp = cast<CmpIOp>(op);
   978      CmpIOpOperandAdaptor transformed(operands);
   979  
   980      rewriter.replaceOpWithNewOp<LLVM::ICmpOp>(
   981          op, lowering.convertType(cmpiOp.getResult()->getType()),
   982          rewriter.getI64IntegerAttr(static_cast<int64_t>(
   983              convertCmpPredicate<LLVM::ICmpPredicate>(cmpiOp.getPredicate()))),
   984          transformed.lhs(), transformed.rhs());
   985  
   986      return matchSuccess();
   987    }
   988  };
   989  
   990  struct CmpFOpLowering : public LLVMLegalizationPattern<CmpFOp> {
   991    using LLVMLegalizationPattern<CmpFOp>::LLVMLegalizationPattern;
   992  
   993    PatternMatchResult
   994    matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
   995                    ConversionPatternRewriter &rewriter) const override {
   996      auto cmpfOp = cast<CmpFOp>(op);
   997      CmpFOpOperandAdaptor transformed(operands);
   998  
   999      rewriter.replaceOpWithNewOp<LLVM::FCmpOp>(
  1000          op, lowering.convertType(cmpfOp.getResult()->getType()),
  1001          rewriter.getI64IntegerAttr(static_cast<int64_t>(
  1002              convertCmpPredicate<LLVM::FCmpPredicate>(cmpfOp.getPredicate()))),
  1003          transformed.lhs(), transformed.rhs());
  1004  
  1005      return matchSuccess();
  1006    }
  1007  };
  1008  
  1009  struct SIToFPLowering
  1010      : public OneToOneLLVMOpLowering<SIToFPOp, LLVM::SIToFPOp> {
  1011    using Super::Super;
  1012  };
  1013  
  1014  // Base class for lowering terminator operations with successors to the LLVM dialect.
  1015  template <typename SourceOp, typename TargetOp>
  1016  struct OneToOneLLVMTerminatorLowering
  1017      : public LLVMLegalizationPattern<SourceOp> {
  1018    using LLVMLegalizationPattern<SourceOp>::LLVMLegalizationPattern;
  1019    using Super = OneToOneLLVMTerminatorLowering<SourceOp, TargetOp>;
  1020  
  1021    PatternMatchResult
  1022    matchAndRewrite(Operation *op, ArrayRef<Value *> properOperands,
  1023                    ArrayRef<Block *> destinations,
  1024                    ArrayRef<ArrayRef<Value *>> operands,
  1025                    ConversionPatternRewriter &rewriter) const override {
  1026      rewriter.replaceOpWithNewOp<TargetOp>(op, properOperands, destinations,
  1027                                            operands, op->getAttrs());
  1028      return this->matchSuccess();
  1029    }
  1030  };
  1031  
  1032  // Special lowering pattern for `ReturnOps`.  Unlike all other operations,
  1033  // `ReturnOp` interacts with the function signature and must have as many
  1034  // operands as the function has return values.  Because in LLVM IR, functions
  1035  // can only return 0 or 1 value, we pack multiple values into a structure type.
  1036  // Emit an `UndefOp` followed by `InsertValueOp`s to create such a structure
  1037  // if necessary before returning it.
  1038  struct ReturnOpLowering : public LLVMLegalizationPattern<ReturnOp> {
  1039    using LLVMLegalizationPattern<ReturnOp>::LLVMLegalizationPattern;
  1040  
  1041    PatternMatchResult
  1042    matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
  1043                    ConversionPatternRewriter &rewriter) const override {
  1044      unsigned numArguments = op->getNumOperands();
  1045  
  1046      // If ReturnOp has 0 or 1 operand, create it and return immediately.
  1047      if (numArguments == 0) {
  1048        rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
  1049            op, llvm::ArrayRef<Value *>(), llvm::ArrayRef<Block *>(),
  1050            llvm::ArrayRef<llvm::ArrayRef<Value *>>(), op->getAttrs());
  1051        return matchSuccess();
  1052      }
  1053      if (numArguments == 1) {
  1054        rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
  1055            op, llvm::ArrayRef<Value *>(operands.front()),
  1056            llvm::ArrayRef<Block *>(), llvm::ArrayRef<llvm::ArrayRef<Value *>>(),
  1057            op->getAttrs());
  1058        return matchSuccess();
  1059      }
  1060  
  1061      // Otherwise, we need to pack the arguments into an LLVM struct type before
  1062      // returning.
  1063      auto packedType =
  1064          lowering.packFunctionResults(llvm::to_vector<4>(op->getOperandTypes()));
  1065  
  1066      Value *packed = rewriter.create<LLVM::UndefOp>(op->getLoc(), packedType);
  1067      for (unsigned i = 0; i < numArguments; ++i) {
  1068        packed = rewriter.create<LLVM::InsertValueOp>(
  1069            op->getLoc(), packedType, packed, operands[i],
  1070            rewriter.getIndexArrayAttr(i));
  1071      }
  1072      rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
  1073          op, llvm::makeArrayRef(packed), llvm::ArrayRef<Block *>(),
  1074          llvm::ArrayRef<llvm::ArrayRef<Value *>>(), op->getAttrs());
  1075      return matchSuccess();
  1076    }
  1077  };
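
        // Illustrative sketch of the multi-result case above:
        //   return %a, %b : f32, i64
        // becomes an UndefOp of the packed type { float, i64 }, two InsertValueOps
        // filling positions [0] and [1], and a single LLVM return of the packed value.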
  1078  
  1079  // FIXME: this should be tablegen'ed as well.
  1080  struct BranchOpLowering
  1081      : public OneToOneLLVMTerminatorLowering<BranchOp, LLVM::BrOp> {
  1082    using Super::Super;
  1083  };
  1084  struct CondBranchOpLowering
  1085      : public OneToOneLLVMTerminatorLowering<CondBranchOp, LLVM::CondBrOp> {
  1086    using Super::Super;
  1087  };
  1088  
  1089  } // namespace
  1090  
  1091  static void ensureDistinctSuccessors(Block &bb) {
  1092    auto *terminator = bb.getTerminator();
  1093  
  1094    // Find repeated successors with arguments.
  1095    llvm::SmallDenseMap<Block *, llvm::SmallVector<int, 4>> successorPositions;
  1096    for (int i = 0, e = terminator->getNumSuccessors(); i < e; ++i) {
  1097      Block *successor = terminator->getSuccessor(i);
  1098      // Blocks with no arguments are safe even if they appear multiple times
  1099      // because they don't need PHI nodes.
  1100      if (successor->getNumArguments() == 0)
  1101        continue;
  1102      successorPositions[successor].push_back(i);
  1103    }
  1104  
  1105    // If a successor appears in the terminator two or more times, create a new
  1106    // dummy block that unconditionally branches to the original
  1107    // destination, and retarget the terminator to branch to this new block.
  1108    // There is no need to pass arguments to the dummy block because it will be
  1109    // dominated by the original block and can therefore use any values defined in
  1110    // the original block.
  1111    for (const auto &successor : successorPositions) {
  1112      const auto &positions = successor.second;
  1113      // Start from the second occurrence of a block in the successor list.
  1114      for (auto position = std::next(positions.begin()), end = positions.end();
  1115           position != end; ++position) {
  1116        auto *dummyBlock = new Block();
  1117        bb.getParent()->push_back(dummyBlock);
  1118        auto builder = OpBuilder(dummyBlock);
  1119        SmallVector<Value *, 8> operands(
  1120            terminator->getSuccessorOperands(*position));
  1121        builder.create<BranchOp>(terminator->getLoc(), successor.first, operands);
  1122        terminator->setSuccessor(dummyBlock, *position);
  1123        for (int i = 0, e = terminator->getNumSuccessorOperands(*position); i < e;
  1124             ++i)
  1125          terminator->eraseSuccessorOperand(*position, i);
  1126      }
  1127    }
  1128  }
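
        // Illustrative sketch: a terminator of the form
        //   cond_br %c, ^bb1(%x : i32), ^bb1(%y : i32)
        // is rewritten so that the second occurrence targets a fresh block that only
        // branches to ^bb1(%y), making the successors distinct and keeping later PHI
        // node creation in LLVM IR unambiguous.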
  1129  
  1130  void mlir::LLVM::ensureDistinctSuccessors(ModuleOp m) {
  1131    for (auto f : m.getOps<FuncOp>()) {
  1132      for (auto &bb : f.getBlocks()) {
  1133        ::ensureDistinctSuccessors(bb);
  1134      }
  1135    }
  1136  }
  1137  
  1138  /// Collect a set of patterns to convert from the Standard dialect to LLVM.
  1139  void mlir::populateStdToLLVMConversionPatterns(
  1140      LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
  1141    // FIXME: this should be tablegen'ed
  1142    patterns.insert<
  1143        AddFOpLowering, AddIOpLowering, AndOpLowering, AllocOpLowering,
  1144        BranchOpLowering, CallIndirectOpLowering, CallOpLowering, CmpIOpLowering,
  1145        CmpFOpLowering, CondBranchOpLowering, ConstLLVMOpLowering,
  1146        DeallocOpLowering, DimOpLowering, DivISOpLowering, DivIUOpLowering,
  1147        DivFOpLowering, FuncOpConversion, IndexCastOpLowering, LoadOpLowering,
  1148        MemRefCastOpLowering, MulFOpLowering, MulIOpLowering, OrOpLowering,
  1149        RemISOpLowering, RemIUOpLowering, RemFOpLowering, ReturnOpLowering,
  1150        SelectOpLowering, SIToFPLowering, StoreOpLowering, SubFOpLowering,
  1151        SubIOpLowering, XOrOpLowering>(*converter.getDialect(), converter);
  1152  }
  1153  
  1154  // Convert types using the stored LLVM IR module.
  1155  Type LLVMTypeConverter::convertType(Type t) { return convertStandardType(t); }
  1156  
  1157  // Create an LLVM IR structure type if there is more than one result.
  1158  Type LLVMTypeConverter::packFunctionResults(ArrayRef<Type> types) {
  1159    assert(!types.empty() && "expected a non-empty list of types");
  1160  
  1161    if (types.size() == 1)
  1162      return convertType(types.front());
  1163  
  1164    SmallVector<LLVM::LLVMType, 8> resultTypes;
  1165    resultTypes.reserve(types.size());
  1166    for (auto t : types) {
  1167      auto converted = convertType(t).dyn_cast<LLVM::LLVMType>();
  1168      if (!converted)
  1169        return {};
  1170      resultTypes.push_back(converted);
  1171    }
  1172  
  1173    return LLVM::LLVMType::getStructTy(llvmDialect, resultTypes);
  1174  }
  1175  
  1176  /// Create an instance of LLVMTypeConverter in the given context.
  1177  static std::unique_ptr<LLVMTypeConverter>
  1178  makeStandardToLLVMTypeConverter(MLIRContext *context) {
  1179    return std::make_unique<LLVMTypeConverter>(context);
  1180  }
  1181  
  1182  namespace {
  1183  /// A pass converting MLIR operations into the LLVM IR dialect.
  1184  struct LLVMLoweringPass : public ModulePass<LLVMLoweringPass> {
  1185    // By default, the patterns are those converting Standard operations to the
  1186    // LLVM IR dialect.
  1187    explicit LLVMLoweringPass(
  1188        LLVMPatternListFiller patternListFiller =
  1189            populateStdToLLVMConversionPatterns,
  1190        LLVMTypeConverterMaker converterBuilder = makeStandardToLLVMTypeConverter)
  1191        : patternListFiller(patternListFiller),
  1192          typeConverterMaker(converterBuilder) {}
  1193  
  1194    // Run the dialect converter on the module.
  1195    void runOnModule() override {
  1196      if (!typeConverterMaker || !patternListFiller)
  1197        return signalPassFailure();
  1198  
  1199      ModuleOp m = getModule();
  1200      LLVM::ensureDistinctSuccessors(m);
  1201      std::unique_ptr<LLVMTypeConverter> typeConverter =
  1202          typeConverterMaker(&getContext());
  1203      if (!typeConverter)
  1204        return signalPassFailure();
  1205  
  1206      OwningRewritePatternList patterns;
  1207      populateLoopToStdConversionPatterns(patterns, m.getContext());
  1208      patternListFiller(*typeConverter, patterns);
  1209  
  1210      ConversionTarget target(getContext());
  1211      target.addLegalDialect<LLVM::LLVMDialect>();
  1212      target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
  1213        return typeConverter->isSignatureLegal(op.getType());
  1214      });
  1215      if (failed(applyPartialConversion(m, target, patterns, &*typeConverter)))
  1216        signalPassFailure();
  1217    }
  1218  
  1219    // Callback for creating a list of patterns.  It is called every time in
  1220    // runOnModule since applyPartialConversion consumes the list.
  1221    LLVMPatternListFiller patternListFiller;
  1222  
  1223    // Callback for creating an instance of type converter.  The converter
  1224    // constructor needs an MLIRContext, which is not available until runOnModule.
  1225    LLVMTypeConverterMaker typeConverterMaker;
  1226  };
  1227  } // end namespace
  1228  
  1229  std::unique_ptr<ModulePassBase> mlir::createConvertToLLVMIRPass() {
  1230    return std::make_unique<LLVMLoweringPass>();
  1231  }
  1232  
  1233  std::unique_ptr<ModulePassBase>
  1234  mlir::createConvertToLLVMIRPass(LLVMPatternListFiller patternListFiller,
  1235                                  LLVMTypeConverterMaker typeConverterMaker) {
  1236    return std::make_unique<LLVMLoweringPass>(patternListFiller,
  1237                                              typeConverterMaker);
  1238  }
  1239  
  1240  static PassRegistration<LLVMLoweringPass>
  1241      pass("lower-to-llvm", "Convert all functions to the LLVM IR dialect");