github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp (about)

     1  //===- LowerToLLVMDialect.cpp - conversion from Linalg to LLVM dialect ----===//
     2  //
     3  // Copyright 2019 The MLIR Authors.
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //   http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  // =============================================================================
    17  
    18  #include "mlir/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.h"
    19  #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
    20  #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
    21  #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
    22  #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
    23  #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
    24  #include "mlir/Dialect/Linalg/Passes.h"
    25  #include "mlir/Dialect/Linalg/Utils/Intrinsics.h"
    26  #include "mlir/EDSC/Builders.h"
    27  #include "mlir/EDSC/Intrinsics.h"
    28  #include "mlir/IR/AffineExpr.h"
    29  #include "mlir/IR/AffineMap.h"
    30  #include "mlir/IR/Attributes.h"
    31  #include "mlir/IR/Builders.h"
    32  #include "mlir/IR/MLIRContext.h"
    33  #include "mlir/IR/Module.h"
    34  #include "mlir/IR/Operation.h"
    35  #include "mlir/IR/PatternMatch.h"
    36  #include "mlir/IR/StandardTypes.h"
    37  #include "mlir/IR/Types.h"
    38  #include "mlir/Pass/Pass.h"
    39  #include "mlir/Pass/PassManager.h"
    40  #include "mlir/Support/LogicalResult.h"
    41  #include "mlir/Transforms/DialectConversion.h"
    42  #include "mlir/Transforms/LowerAffine.h"
    43  #include "mlir/Transforms/Passes.h"
    44  
    45  #include "llvm/ADT/SetVector.h"
    46  #include "llvm/IR/DerivedTypes.h"
    47  #include "llvm/IR/Module.h"
    48  #include "llvm/IR/Type.h"
    49  #include "llvm/Support/Allocator.h"
    50  #include "llvm/Support/ErrorHandling.h"
    51  
    52  using namespace mlir;
    53  using namespace mlir::edsc;
    54  using namespace mlir::edsc::intrinsics;
    55  using namespace mlir::LLVM;
    56  using namespace mlir::linalg;
    57  using namespace mlir::linalg::intrinsics;
    58  
// Short EDSC builder aliases for the LLVM-dialect and standard ops emitted
// by the lowerings below. ValueBuilder aliases are used where the single
// result value is consumed directly; OperationBuilder aliases are used where
// the operation itself is needed (calls, stores, returns).
using add = ValueBuilder<mlir::LLVM::AddOp>;
using addi = ValueBuilder<mlir::AddIOp>;
using bitcast = ValueBuilder<mlir::LLVM::BitcastOp>;
using cmpi = ValueBuilder<mlir::CmpIOp>;
using constant = ValueBuilder<mlir::LLVM::ConstantOp>;
using extractvalue = ValueBuilder<mlir::LLVM::ExtractValueOp>;
using gep = ValueBuilder<mlir::LLVM::GEPOp>;
using insertvalue = ValueBuilder<mlir::LLVM::InsertValueOp>;
using llvm_call = OperationBuilder<mlir::LLVM::CallOp>;
using llvm_icmp = ValueBuilder<LLVM::ICmpOp>;
using llvm_load = ValueBuilder<LLVM::LoadOp>;
using llvm_store = OperationBuilder<LLVM::StoreOp>;
using llvm_select = ValueBuilder<LLVM::SelectOp>;
using mul = ValueBuilder<mlir::LLVM::MulOp>;
using ptrtoint = ValueBuilder<mlir::LLVM::PtrToIntOp>;
using sub = ValueBuilder<mlir::LLVM::SubOp>;
using undef = ValueBuilder<mlir::LLVM::UndefOp>;
using urem = ValueBuilder<mlir::LLVM::URemOp>;
using llvm_alloca = ValueBuilder<LLVM::AllocaOp>;
using llvm_return = OperationBuilder<LLVM::ReturnOp>;
    79  
    80  template <typename T>
    81  static LLVMType getPtrToElementType(T containerType,
    82                                      LLVMTypeConverter &lowering) {
    83    return lowering.convertType(containerType.getElementType())
    84        .template cast<LLVMType>()
    85        .getPointerTo();
    86  }
    87  
    88  // Convert the given type to the LLVM IR Dialect type.  The following
    89  // conversions are supported:
    90  //   - an Index type is converted into an LLVM integer type with pointer
    91  //     bitwidth (analogous to intptr_t in C);
    92  //   - an Integer type is converted into an LLVM integer type of the same width;
    93  //   - an F32 type is converted into an LLVM float type
    94  //   - a Buffer, Range or View is converted into an LLVM structure type
    95  //     containing the respective dynamic values.
    96  static Type convertLinalgType(Type t, LLVMTypeConverter &lowering) {
    97    auto *context = t.getContext();
    98    auto int64Ty = lowering.convertType(IntegerType::get(64, context))
    99                       .cast<LLVM::LLVMType>();
   100  
   101    // A buffer descriptor contains the pointer to a flat region of storage and
   102    // the size of the region.
   103    //
   104    // template <typename Elem, size_t Rank>
   105    // struct {
   106    //   void *baseAlloc;
   107    //   Elem *ptr;
   108    //   int64_t size;
   109    // };
   110    if (auto bufferType = t.dyn_cast<BufferType>()) {
   111      auto voidPtrTy = LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo();
   112      auto ptrTy = getPtrToElementType(bufferType, lowering);
   113      return LLVMType::getStructTy(voidPtrTy, ptrTy, int64Ty);
   114    }
   115  
   116    // Range descriptor contains the range bounds and the step as 64-bit integers.
   117    //
   118    // struct {
   119    //   int64_t min;
   120    //   int64_t max;
   121    //   int64_t step;
   122    // };
   123    if (t.isa<RangeType>())
   124      return LLVMType::getStructTy(int64Ty, int64Ty, int64Ty);
   125  
   126    // A linalg.view type converts to a *pointer to* a view descriptor. The view
   127    // descriptor contains the pointer to the data buffer, followed by a 64-bit
   128    // integer containing the distance between the beginning of the buffer and the
   129    // first element to be accessed through the view, followed by two arrays, each
   130    // containing as many 64-bit integers as the rank of the View. The first array
   131    // represents the size, in number of original elements, of the view along the
   132    // given dimension.  When taking the view, the size is the difference between
   133    // the upper and the lower bound of the range. The second array represents the
   134    // "stride" (in tensor abstraction sense), i.e. the number of consecutive
   135    // elements of the underlying buffer that separate two consecutive elements
   136    // addressable through the view along the given dimension.  When taking the
   137    // view, the strides are constructed as products of the original sizes along
   138    // the trailing dimensions, multiplied by the view step.  For example, a view
   139    // of a MxN memref with ranges {0:M:1}, {0:N:1}, i.e. the view of a complete
   140    // memref, will have strides N and 1.  A view with ranges {0:M:2}, {0:N:3}
   141    // will have strides 2*N and 3.
   142    //
   143    // template <typename Elem, size_t Rank>
   144    // struct {
   145    //   Elem *ptr;
   146    //   int64_t offset;
   147    //   int64_t sizes[Rank];
   148    //   int64_t strides[Rank];
   149    // } *;
   150    if (auto viewType = t.dyn_cast<ViewType>()) {
   151      auto ptrTy = getPtrToElementType(viewType, lowering);
   152      auto arrayTy = LLVMType::getArrayTy(int64Ty, viewType.getRank());
   153      return LLVMType::getStructTy(ptrTy, int64Ty, arrayTy, arrayTy)
   154          .getPointerTo();
   155    }
   156  
   157    return Type();
   158  }
   159  
// Field indices into the LLVM descriptor struct types built by
// convertLinalgType above.
// Buffer descriptor: { void *baseAlloc; Elem *ptr; int64_t size; }.
static constexpr int kBasePtrPosInBuffer = 0;
static constexpr int kPtrPosInBuffer = 1;
static constexpr int kSizePosInBuffer = 2;
// View descriptor: { Elem *ptr; int64_t offset; int64_t sizes[Rank];
// int64_t strides[Rank]; }.
static constexpr int kPtrPosInView = 0;
static constexpr int kOffsetPosInView = 1;
static constexpr int kSizePosInView = 2;
static constexpr int kStridePosInView = 3;
   167  
   168  // Create an array attribute containing integer attributes with values provided
   169  // in `position`.
   170  static ArrayAttr positionAttr(Builder &builder, ArrayRef<int> position) {
   171    SmallVector<Attribute, 4> attrs;
   172    attrs.reserve(position.size());
   173    for (auto p : position)
   174      attrs.push_back(builder.getI64IntegerAttr(p));
   175    return builder.getArrayAttr(attrs);
   176  }
   177  
namespace {
/// Factor out the common information for all view conversions:
///   1. common types in (standard and LLVM dialects)
///   2. `pos` method
///   3. op of the FuncOp alloca'ed value and descriptor.
class BaseViewConversionHelper {
public:
  // Allocates a view descriptor in the entry block of the FuncOp enclosing
  // `op` and loads it. The rewriter's insertion point is left unchanged on
  // exit (the guard below restores it).
  BaseViewConversionHelper(Operation *op, ViewType viewType,
                           ConversionPatternRewriter &rewriter,
                           LLVMTypeConverter &lowering)
      : indexType(rewriter.getIndexType()), viewType(viewType),
        elementTy(getPtrToElementType(viewType, lowering)),
        int64Ty(
            lowering.convertType(rewriter.getIntegerType(64)).cast<LLVMType>()),
        viewDescriptorPtrTy(
            convertLinalgType(viewType, lowering).cast<LLVMType>()),
        rewriter(rewriter) {

    // Emit the alloca at the start of the function's entry block; the guard
    // restores the caller's insertion point when it goes out of scope.
    OpBuilder::InsertionGuard insertGuard(rewriter);
    rewriter.setInsertionPointToStart(
        &op->getParentOfType<FuncOp>().getBlocks().front());

    edsc::ScopedContext context(rewriter, op->getLoc());
    // Alloca array size (a single descriptor).
    one = constant(int64Ty, IntegerAttr::get(indexType, 1));
    // Alloca with proper alignment.
    allocatedDesc = llvm_alloca(viewDescriptorPtrTy, one, /*alignment=*/8);
    // Load the alloca'ed descriptor.
    desc = llvm_load(allocatedDesc);
  }

  // Shorthand building an I64 array attribute from `values`.
  ArrayAttr pos(ArrayRef<int> values) const {
    return positionAttr(rewriter, values);
  };

  IndexType indexType;
  ViewType viewType;
  // Pointer-to-element, i64, and view-descriptor LLVM types for `viewType`.
  LLVMType elementTy, int64Ty, viewDescriptorPtrTy;
  ConversionPatternRewriter &rewriter;
  // `allocatedDesc` is the alloca'ed pointer; `desc` is the value loaded
  // from it (uninitialized until the conversion fills it in).
  Value *one, *allocatedDesc, *desc;
};
} // namespace
   219  
// BufferAllocOp creates a new `!linalg.buffer` value.
// Lowered to a call to `malloc` (declared in the module on demand), with
// optional over-allocation and pointer rounding when an alignment attribute
// is present, followed by construction of the buffer descriptor struct
// {base alloc ptr, (aligned) data ptr, element count}.
class BufferAllocOpConversion : public LLVMOpLowering {
public:
  explicit BufferAllocOpConversion(MLIRContext *context,
                                   LLVMTypeConverter &lowering_)
      : LLVMOpLowering(BufferAllocOp::getOperationName(), context, lowering_) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto indexType = IndexType::get(op->getContext());
    auto voidPtrTy =
        LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo();
    auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
    // Insert the `malloc` declaration if it is not already present.
    auto module = op->getParentOfType<ModuleOp>();
    FuncOp mallocFunc = module.lookupSymbol<FuncOp>("malloc");
    if (!mallocFunc) {
      auto mallocType = rewriter.getFunctionType(int64Ty, voidPtrTy);
      mallocFunc =
          FuncOp::create(rewriter.getUnknownLoc(), "malloc", mallocType);
      module.push_back(mallocFunc);
    }

    // Get MLIR types for injecting element pointer.
    auto allocOp = cast<BufferAllocOp>(op);
    auto elementType = allocOp.getElementType();
    // Element size in bytes, rounding the bit width up to whole bytes.
    // NOTE(review): the scalar branch asserts unless the element type is an
    // int or float — presumably guaranteed by the op verifier; confirm.
    uint64_t elementSize = 0;
    if (auto vectorType = elementType.dyn_cast<VectorType>())
      elementSize = vectorType.getNumElements() *
                    llvm::divideCeil(vectorType.getElementTypeBitWidth(), 8);
    else
      elementSize = llvm::divideCeil(elementType.getIntOrFloatBitWidth(), 8);
    auto bufferType = allocOp.getBufferType();
    auto elementPtrType = getPtrToElementType(bufferType, lowering);
    auto bufferDescriptorTy = convertLinalgType(bufferType, lowering);

    // Emit IR for creating a new buffer descriptor with an underlying malloc.
    edsc::ScopedContext context(rewriter, op->getLoc());
    // Element count: a compile-time constant if the buffer type carries one,
    // otherwise the op's single (dynamic size) operand.
    auto constantSize = bufferType.getBufferSize();
    Value *size =
        constantSize
            ? constant(int64Ty, IntegerAttr::get(indexType, *constantSize))
                  .getValue()
            : operands[0];
    Value *allocSize =
        mul(size, constant(int64Ty, IntegerAttr::get(indexType, elementSize)));
    // When alignment is requested, over-allocate by (align - 1) bytes so an
    // aligned pointer is guaranteed to exist within the allocation.
    Value *one = nullptr, *align = nullptr;
    if (allocOp.alignment().hasValue()) {
      one = constant(int64Ty, IntegerAttr::get(indexType, 1));
      align =
          constant(int64Ty, rewriter.getIntegerAttr(
                                rewriter.getIndexType(),
                                allocOp.alignment().getValue().getSExtValue()));
      allocSize = sub(add(allocSize, align), one);
    }

    Value *allocated =
        llvm_call(voidPtrTy, rewriter.getSymbolRefAttr(mallocFunc), allocSize)
            .getOperation()
            ->getResult(0);
    Value *data = allocated;
    if (allocOp.alignment().hasValue()) {
      // Round the data pointer up to the next multiple of `align`:
      // offset = (align - (ptr % align))% align
      Value *offset =
          urem(sub(align, urem(ptrtoint(int64Ty, allocated), align)), align);
      data = gep(voidPtrTy, allocated, offset);
    }
    data = bitcast(elementPtrType, data);
    // Assemble the descriptor: keep the raw malloc'ed pointer (needed later
    // by buffer_dealloc to call `free`), the aligned data pointer, and the
    // element count.
    Value *desc = undef(bufferDescriptorTy);
    desc = insertvalue(bufferDescriptorTy, desc, allocated,
                       positionAttr(rewriter, kBasePtrPosInBuffer));
    desc = insertvalue(bufferDescriptorTy, desc, data,
                       positionAttr(rewriter, kPtrPosInBuffer));
    desc = insertvalue(bufferDescriptorTy, desc, size,
                       positionAttr(rewriter, kSizePosInBuffer));
    rewriter.replaceOp(op, desc);
    return matchSuccess();
  }
};
   300  
// BufferDeallocOp creates no value.
// Lowered to a call to `free` (declared in the module on demand) on the
// *base* allocation pointer stored in the buffer descriptor, so that an
// aligned data pointer is never handed to `free`.
class BufferDeallocOpConversion : public LLVMOpLowering {
public:
  explicit BufferDeallocOpConversion(MLIRContext *context,
                                     LLVMTypeConverter &lowering_)
      : LLVMOpLowering(BufferDeallocOp::getOperationName(), context,
                       lowering_) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto voidPtrTy =
        LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo();
    // Insert the `free` declaration if it is not already present.
    auto module = op->getParentOfType<ModuleOp>();
    FuncOp freeFunc = module.lookupSymbol<FuncOp>("free");
    if (!freeFunc) {
      auto freeType = rewriter.getFunctionType(voidPtrTy, {});
      freeFunc = FuncOp::create(rewriter.getUnknownLoc(), "free", freeType);
      module.push_back(freeFunc);
    }

    // Emit MLIR for buffer_dealloc.
    BufferDeallocOpOperandAdaptor adaptor(operands);
    edsc::ScopedContext context(rewriter, op->getLoc());
    // Free the pointer originally returned by malloc, not the aligned one.
    Value *base = extractvalue(voidPtrTy, adaptor.buffer(),
                               positionAttr(rewriter, kBasePtrPosInBuffer));
    llvm_call(ArrayRef<Type>(), rewriter.getSymbolRefAttr(freeFunc), base);
    // The op has no results to replace.
    rewriter.replaceOp(op, llvm::None);
    return matchSuccess();
  }
};
   333  
   334  // BufferSizeOp creates a new `index` value.
   335  class BufferSizeOpConversion : public LLVMOpLowering {
   336  public:
   337    BufferSizeOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
   338        : LLVMOpLowering(BufferSizeOp::getOperationName(), context, lowering_) {}
   339  
   340    PatternMatchResult
   341    matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
   342                    ConversionPatternRewriter &rewriter) const override {
   343      auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
   344      edsc::ScopedContext context(rewriter, op->getLoc());
   345      BufferSizeOpOperandAdaptor adaptor(operands);
   346      rewriter.replaceOp(
   347          op, {extractvalue(int64Ty, adaptor.buffer(),
   348                            positionAttr(rewriter, kSizePosInBuffer))});
   349      return matchSuccess();
   350    }
   351  };
   352  
   353  // DimOp creates a new `index` value.
   354  class DimOpConversion : public LLVMOpLowering {
   355  public:
   356    explicit DimOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
   357        : LLVMOpLowering(linalg::DimOp::getOperationName(), context, lowering_) {}
   358  
   359    PatternMatchResult
   360    matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
   361                    ConversionPatternRewriter &rewriter) const override {
   362      auto dimOp = cast<linalg::DimOp>(op);
   363      auto indexTy = lowering.convertType(rewriter.getIndexType());
   364      edsc::ScopedContext context(rewriter, op->getLoc());
   365      auto pos = positionAttr(
   366          rewriter, {kSizePosInView, static_cast<int>(dimOp.getIndex())});
   367      linalg::DimOpOperandAdaptor adaptor(operands);
   368      Value *viewDescriptor = llvm_load(adaptor.view());
   369      rewriter.replaceOp(op, {extractvalue(indexTy, viewDescriptor, pos)});
   370      return matchSuccess();
   371    }
   372  };
   373  
namespace {
// Common functionality for Linalg LoadOp and StoreOp conversion to the
// LLVM IR Dialect.
template <typename Op> class LoadStoreOpConversion : public LLVMOpLowering {
public:
  explicit LoadStoreOpConversion(MLIRContext *context,
                                 LLVMTypeConverter &lowering_)
      : LLVMOpLowering(Op::getOperationName(), context, lowering_) {}
  using Base = LoadStoreOpConversion<Op>;

  // Compute the pointer to an element of the buffer underlying the view given
  // current view indices.  Use the base offset and strides stored in the view
  // descriptor to emit IR iteratively computing the actual offset, followed by
  // a getelementptr. This must be called under an edsc::ScopedContext.
  Value *obtainDataPtr(Operation *op, Value *viewDescriptorPtr,
                       ArrayRef<Value *> indices,
                       ConversionPatternRewriter &rewriter) const {
    auto loadOp = cast<Op>(op);
    auto elementTy = getPtrToElementType(loadOp.getViewType(), lowering);
    auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
    // Shorthand building an I64 array attribute from `values`.
    auto pos = [&rewriter](ArrayRef<int> values) {
      return positionAttr(rewriter, values);
    };

    // Linearize subscripts as:
    //   base_offset + SUM_i index_i * stride_i.
    Value *viewDescriptor = llvm_load(viewDescriptorPtr);
    Value *base = extractvalue(elementTy, viewDescriptor, pos(kPtrPosInView));
    Value *offset =
        extractvalue(int64Ty, viewDescriptor, pos(kOffsetPosInView));
    for (int i = 0, e = loadOp.getRank(); i < e; ++i) {
      Value *stride =
          extractvalue(int64Ty, viewDescriptor, pos({kStridePosInView, i}));
      Value *additionalOffset = mul(indices[i], stride);
      offset = add(offset, additionalOffset);
    }
    // Final address: data pointer advanced by the linearized element offset.
    return gep(elementTy, base, offset);
  }
};
} // namespace
   414  
   415  // A load is converted into the actual address computation, getelementptr and
   416  // an LLVM IR load.
   417  class LoadOpConversion : public LoadStoreOpConversion<linalg::LoadOp> {
   418    using Base::Base;
   419    PatternMatchResult
   420    matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
   421                    ConversionPatternRewriter &rewriter) const override {
   422      edsc::ScopedContext edscContext(rewriter, op->getLoc());
   423      auto elementTy = lowering.convertType(*op->result_type_begin());
   424      linalg::LoadOpOperandAdaptor adaptor(operands);
   425      auto ptr = obtainDataPtr(op, adaptor.view(), adaptor.indices(), rewriter);
   426      rewriter.replaceOp(op, {llvm_load(elementTy, ptr)});
   427      return matchSuccess();
   428    }
   429  };
   430  
// RangeOp creates a new range descriptor.
// Lowered to an LLVM {i64 min, i64 max, i64 step} struct built up with
// insertvalue from the converted operands.
class RangeOpConversion : public LLVMOpLowering {
public:
  explicit RangeOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
      : LLVMOpLowering(RangeOp::getOperationName(), context, lowering_) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto rangeOp = cast<RangeOp>(op);
    // LLVM struct type for the range (see convertLinalgType).
    auto rangeDescriptorTy =
        convertLinalgType(rangeOp.getResult()->getType(), lowering);

    edsc::ScopedContext context(rewriter, op->getLoc());

    // Fill in an aggregate value of the descriptor.
    RangeOpOperandAdaptor adaptor(operands);
    Value *desc = undef(rangeDescriptorTy);
    desc = insertvalue(desc, adaptor.min(), positionAttr(rewriter, 0));
    desc = insertvalue(desc, adaptor.max(), positionAttr(rewriter, 1));
    desc = insertvalue(desc, adaptor.step(), positionAttr(rewriter, 2));
    rewriter.replaceOp(op, desc);
    return matchSuccess();
  }
};
   456  
/// Conversion pattern that transforms a linalg.slice op into:
///   1. A function entry `alloca` operation to allocate a ViewDescriptor.
///   2. A load of the ViewDescriptor from the pointer allocated in 1.
///   3. Updates to the ViewDescriptor to introduce the data ptr, offset, size
///      and stride corresponding to the region of memory within the bounds of
///      the parent view.
///   4. A store of the resulting ViewDescriptor to the alloca'ed pointer.
/// The linalg.slice op is replaced by the alloca'ed pointer.
class SliceOpConversion : public LLVMOpLowering {
public:
  explicit SliceOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
      : LLVMOpLowering(SliceOp::getOperationName(), context, lowering_) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    SliceOpOperandAdaptor adaptor(operands);
    auto sliceOp = cast<SliceOp>(op);
    auto viewDescriptorPtrTy =
        convertLinalgType(sliceOp.getViewType(), lowering);
    auto viewType = sliceOp.getBaseViewType();
    auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));

    // Helper function to create an integer array attribute out of a list of
    // values.
    auto pos = [&rewriter](ArrayRef<int> values) {
      return positionAttr(rewriter, values);
    };

    edsc::ScopedContext context(rewriter, op->getLoc());
    // Declare the view descriptor and insert data ptr *at the entry block of
    // the function*, which is the preferred location for LLVM's analyses.
    // Save the current insertion point so it can be restored afterwards.
    auto ip = rewriter.getInsertionPoint();
    auto ib = rewriter.getInsertionBlock();
    rewriter.setInsertionPointToStart(
        &op->getParentOfType<FuncOp>().getBlocks().front());
    Value *zero =
        constant(int64Ty, rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
    Value *one =
        constant(int64Ty, rewriter.getIntegerAttr(rewriter.getIndexType(), 1));
    // Alloca with proper alignment.
    Value *allocatedDesc =
        llvm_alloca(viewDescriptorPtrTy, one, /*alignment=*/8);
    Value *desc = llvm_load(allocatedDesc);
    // Resume emitting IR at the slice op's original position.
    rewriter.setInsertionPoint(ib, ip);

    // Load the descriptor of the view being sliced.
    Value *baseDesc = llvm_load(adaptor.view());

    // The slice aliases the base view's data: copy its data pointer.
    auto ptrPos = pos(kPtrPosInView);
    auto elementTy = getPtrToElementType(sliceOp.getViewType(), lowering);
    desc = insertvalue(desc, extractvalue(elementTy, baseDesc, ptrPos), ptrPos);

    // TODO(ntv): extract sizes and emit asserts.
    SmallVector<Value *, 4> strides(viewType.getRank());
    for (int i = 0, e = viewType.getRank(); i < e; ++i) {
      strides[i] = extractvalue(int64Ty, baseDesc, pos({kStridePosInView, i}));
    }

    // Compute and insert base offset: the base view's offset plus, for each
    // dimension, the indexing's lower bound times the stride. A range
    // indexing contributes its `min` field; a scalar index contributes its
    // value directly.
    Value *baseOffset = extractvalue(int64Ty, baseDesc, pos(kOffsetPosInView));
    for (int i = 0, e = viewType.getRank(); i < e; ++i) {
      Value *indexing = adaptor.indexings()[i];
      Value *min = indexing;
      if (sliceOp.indexing(i)->getType().isa<RangeType>())
        min = extractvalue(int64Ty, indexing, pos(0));
      baseOffset = add(baseOffset, mul(min, strides[i]));
    }
    desc = insertvalue(desc, baseOffset, pos(kOffsetPosInView));

    // Compute and insert view sizes (max - min along the range) and strides.
    // Skip the non-range operands as they will be projected away from the view.
    int numNewDims = 0;
    for (auto en : llvm::enumerate(sliceOp.indexings())) {
      Value *indexing = en.value();
      if (indexing->getType().isa<RangeType>()) {
        int rank = en.index();
        // Unpack the {min, max, step} fields of the range descriptor.
        Value *rangeDescriptor = adaptor.indexings()[rank];
        Value *min = extractvalue(int64Ty, rangeDescriptor, pos(0));
        Value *max = extractvalue(int64Ty, rangeDescriptor, pos(1));
        Value *step = extractvalue(int64Ty, rangeDescriptor, pos(2));
        Value *baseSize =
            extractvalue(int64Ty, baseDesc, pos({kSizePosInView, rank}));
        // Bound upper by base view upper bound.
        max = llvm_select(llvm_icmp(ICmpPredicate::slt, max, baseSize), max,
                          baseSize);
        Value *size = sub(max, min);
        // Bound lower by zero.
        size =
            llvm_select(llvm_icmp(ICmpPredicate::slt, size, zero), zero, size);
        // Stepping through the range scales the stride accordingly.
        Value *stride = mul(strides[rank], step);
        desc = insertvalue(desc, size, pos({kSizePosInView, numNewDims}));
        desc = insertvalue(desc, stride, pos({kStridePosInView, numNewDims}));
        ++numNewDims;
      }
    }

    // Store back in alloca'ed region.
    llvm_store(desc, allocatedDesc);
    rewriter.replaceOp(op, allocatedDesc);
    return matchSuccess();
  }
};
   559  
   560  // A store is converted into the actual address computation, getelementptr and
   561  // an LLVM IR store.
   562  class StoreOpConversion : public LoadStoreOpConversion<linalg::StoreOp> {
   563    using Base::Base;
   564    PatternMatchResult
   565    matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
   566                    ConversionPatternRewriter &rewriter) const override {
   567      edsc::ScopedContext edscContext(rewriter, op->getLoc());
   568      linalg::StoreOpOperandAdaptor adaptor(operands);
   569      Value *ptr = obtainDataPtr(op, adaptor.view(), adaptor.indices(), rewriter);
   570      llvm_store(adaptor.value(), ptr);
   571      rewriter.replaceOp(op, llvm::None);
   572      return matchSuccess();
   573    }
   574  };
   575  
   576  /// Conversion pattern that transforms a linalg.transpose op into:
   577  ///   1. A function entry `alloca` operation to allocate a ViewDescriptor.
   578  ///   2. A load of the ViewDescriptor from the pointer allocated in 1.
   579  ///   3. Updates to the ViewDescriptor to introduce the data ptr, offset, size
   580  ///      and stride. Size and stride are permutations of the original values.
   581  ///   4. A store of the resulting ViewDescriptor to the alloca'ed pointer.
   582  /// The linalg.transpose op is replaced by the alloca'ed pointer.
   583  class TransposeOpConversion : public LLVMOpLowering {
   584  public:
   585    explicit TransposeOpConversion(MLIRContext *context,
   586                                   LLVMTypeConverter &lowering_)
   587        : LLVMOpLowering(TransposeOp::getOperationName(), context, lowering_) {}
   588  
   589    PatternMatchResult
   590    matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
   591                    ConversionPatternRewriter &rewriter) const override {
   592      // Initialize the common boilerplate and alloca at the top of the FuncOp.
   593      TransposeOpOperandAdaptor adaptor(operands);
   594      auto tranposeOp = cast<TransposeOp>(op);
   595      BaseViewConversionHelper helper(op, tranposeOp.getViewType(), rewriter,
   596                                      lowering);
   597      IndexType indexType = helper.indexType;
   598      ViewType viewType = helper.viewType;
   599      LLVMType elementTy = helper.elementTy, int64Ty = helper.int64Ty,
   600               viewDescriptorPtrTy = helper.viewDescriptorPtrTy;
   601      Value *allocatedDesc = helper.allocatedDesc, *desc = helper.desc;
   602  
   603      edsc::ScopedContext context(rewriter, op->getLoc());
   604      // Load the descriptor of the view constructed by the helper.
   605      Value *baseDesc = llvm_load(adaptor.view());
   606  
   607      // Copy the base pointer from the old descriptor to the new one.
   608      ArrayAttr ptrPos = helper.pos(kPtrPosInView);
   609      desc = insertvalue(desc, extractvalue(elementTy, baseDesc, ptrPos), ptrPos);
   610  
   611      // Copy the offset pointer from the old descriptor to the new one.
   612      ArrayAttr offPos = helper.pos(kOffsetPosInView);
   613      desc = insertvalue(desc, extractvalue(int64Ty, baseDesc, offPos), offPos);
   614  
   615      if (tranposeOp.permutation().isIdentity()) {
   616        // No permutation, just store back in alloca'ed region.
   617        llvm_store(desc, allocatedDesc);
   618        return rewriter.replaceOp(op, allocatedDesc), matchSuccess();
   619      }
   620  
   621      // Iterate over the dimensions and apply size/stride permutation.
   622      for (auto en : llvm::enumerate(tranposeOp.permutation().getResults())) {
   623        int sourcePos = en.index();
   624        int targetPos = en.value().cast<AffineDimExpr>().getPosition();
   625        Value *size = extractvalue(int64Ty, baseDesc,
   626                                   helper.pos({kSizePosInView, sourcePos}));
   627        desc = insertvalue(desc, size, helper.pos({kSizePosInView, targetPos}));
   628        Value *stride = extractvalue(int64Ty, baseDesc,
   629                                     helper.pos({kStridePosInView, sourcePos}));
   630        desc =
   631            insertvalue(desc, stride, helper.pos({kStridePosInView, targetPos}));
   632      }
   633  
   634      // Store back in alloca'ed region.
   635      llvm_store(desc, allocatedDesc);
   636      rewriter.replaceOp(op, allocatedDesc);
   637      return matchSuccess();
   638    }
   639  };
   640  
   641  /// Conversion pattern that transforms a linalg.view op into:
   642  ///   1. A function entry `alloca` operation to allocate a ViewDescriptor.
   643  ///   2. A load of the ViewDescriptor from the pointer allocated in 1.
   644  ///   3. Updates to the ViewDescriptor to introduce the data ptr, offset, size
   645  ///      and stride.
   646  ///   4. A store of the resulting ViewDescriptor to the alloca'ed pointer.
   647  /// The linalg.view op is replaced by the alloca'ed pointer.
   648  class ViewOpConversion : public LLVMOpLowering {
   649  public:
   650    explicit ViewOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
   651        : LLVMOpLowering(ViewOp::getOperationName(), context, lowering_) {}
   652  
   653    PatternMatchResult
   654    matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
   655                    ConversionPatternRewriter &rewriter) const override {
   656      auto viewOp = cast<ViewOp>(op);
   657      ViewOpOperandAdaptor adaptor(operands);
   658      auto viewDescriptorPtrTy =
   659          convertLinalgType(viewOp.getViewType(), lowering);
   660      auto elementTy = getPtrToElementType(viewOp.getViewType(), lowering);
   661      auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
   662  
   663      auto pos = [&rewriter](ArrayRef<int> values) {
   664        return positionAttr(rewriter, values);
   665      };
   666  
   667      Value *bufferDescriptor = adaptor.buffer();
   668      auto bufferTy = getPtrToElementType(
   669          viewOp.buffer()->getType().cast<BufferType>(), lowering);
   670  
   671      // Declare the descriptor of the view.
   672      edsc::ScopedContext context(rewriter, op->getLoc());
   673      auto ip = rewriter.getInsertionPoint();
   674      auto ib = rewriter.getInsertionBlock();
   675      rewriter.setInsertionPointToStart(
   676          &op->getParentOfType<FuncOp>().getBlocks().front());
   677      Value *one =
   678          constant(int64Ty, rewriter.getIntegerAttr(rewriter.getIndexType(), 1));
   679      // Alloca for proper alignment.
   680      Value *allocatedDesc =
   681          llvm_alloca(viewDescriptorPtrTy, one, /*alignment=*/8);
   682      Value *desc = llvm_load(allocatedDesc);
   683      rewriter.setInsertionPoint(ib, ip);
   684  
   685      // Copy the buffer pointer from the old descriptor to the new one.
   686      Value *bufferAsViewElementType =
   687          bitcast(elementTy,
   688                  extractvalue(bufferTy, bufferDescriptor, pos(kPtrPosInBuffer)));
   689      desc = insertvalue(desc, bufferAsViewElementType, pos(kPtrPosInView));
   690  
   691      // Zero base offset.
   692      auto indexTy = rewriter.getIndexType();
   693      Value *baseOffset = constant(int64Ty, IntegerAttr::get(indexTy, 0));
   694      desc = insertvalue(desc, baseOffset, pos(kOffsetPosInView));
   695  
   696      // Compute and insert view sizes (max - min along the range).
   697      int numRanges = llvm::size(viewOp.ranges());
   698      Value *runningStride = constant(int64Ty, IntegerAttr::get(indexTy, 1));
   699      for (int i = numRanges - 1; i >= 0; --i) {
   700        // Update stride.
   701        Value *rangeDescriptor = operands[1 + i];
   702        Value *step = extractvalue(int64Ty, rangeDescriptor, pos(2));
   703        Value *stride = mul(runningStride, step);
   704        desc = insertvalue(desc, stride, pos({kStridePosInView, i}));
   705        // Update size.
   706        Value *min = extractvalue(int64Ty, rangeDescriptor, pos(0));
   707        Value *max = extractvalue(int64Ty, rangeDescriptor, pos(1));
   708        Value *size = sub(max, min);
   709        desc = insertvalue(desc, size, pos({kSizePosInView, i}));
   710        // Update stride for the next dimension.
   711        if (i > 0)
   712          runningStride = mul(runningStride, max);
   713      }
   714  
   715      // Store back in alloca'ed region.
   716      llvm_store(desc, allocatedDesc);
   717      rewriter.replaceOp(op, allocatedDesc);
   718      return matchSuccess();
   719    }
   720  };
   721  
   722  // Get function definition for the LinalgOp. If it doesn't exist, insert a
   723  // definition.
   724  template <typename LinalgOp>
   725  static FuncOp
   726  getLLVMLibraryCallDeclaration(Operation *op, LLVMTypeConverter &lowering,
   727                                ConversionPatternRewriter &rewriter) {
   728    auto linalgOp = cast<LinalgOp>(op);
   729    auto fnName = linalgOp.getLibraryCallName();
   730    if (fnName.empty()) {
   731      op->emitWarning("No library call defined for: ") << *op;
   732      return FuncOp();
   733    }
   734    auto module = op->getParentOfType<ModuleOp>();
   735    if (auto f = module.lookupSymbol<FuncOp>(fnName)) {
   736      return f;
   737    }
   738  
   739    // Get the Function type consistent with LLVM Lowering.
   740    SmallVector<Type, 4> inputTypes;
   741    for (auto operand : op->getOperands())
   742      inputTypes.push_back(lowering.convertType(operand->getType()));
   743    assert(op->getNumResults() == 0 &&
   744           "Library call for linalg operation can be generated only for ops that "
   745           "have void return types");
   746    auto libFnType = FunctionType::get(inputTypes, {}, op->getContext());
   747    auto libFn = FuncOp::create(op->getLoc(), fnName, libFnType);
   748    module.push_back(libFn);
   749    // Return after creating the function definition. The body will be created
   750    // later.
   751    return libFn;
   752  }
   753  
   754  namespace {
   755  // The conversion class from Linalg to LLVMIR.
   756  class LinalgTypeConverter : public LLVMTypeConverter {
   757    using LLVMTypeConverter::LLVMTypeConverter;
   758  
   759  public:
   760    Type convertType(Type t) override {
   761      if (auto result = LLVMTypeConverter::convertType(t))
   762        return result;
   763      return convertLinalgType(t, *this);
   764    }
   765  };
   766  } // end anonymous namespace
   767  
   768  // LinalgOpConversion<LinalgOp> creates a new call to the
   769  // `LinalgOp::getLibraryCallName()` function.
   770  // The implementation of the function can be either in the same module or in an
   771  // externally linked library.
   772  template <typename LinalgOp> class LinalgOpConversion : public LLVMOpLowering {
   773  public:
   774    explicit LinalgOpConversion(MLIRContext *context,
   775                                LinalgTypeConverter &lowering_)
   776        : LLVMOpLowering(LinalgOp::getOperationName(), context, lowering_) {}
   777  
   778    PatternMatchResult
   779    matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
   780                    ConversionPatternRewriter &rewriter) const override {
   781      auto f = getLLVMLibraryCallDeclaration<LinalgOp>(op, lowering, rewriter);
   782      if (!f)
   783        return matchFailure();
   784  
   785      auto fAttr = rewriter.getSymbolRefAttr(f);
   786      auto named = rewriter.getNamedAttr("callee", fAttr);
   787      rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, operands,
   788                                                ArrayRef<NamedAttribute>{named});
   789      return matchSuccess();
   790    }
   791  };
   792  
   793  /// Conversion pattern specialization for CopyOp. This kicks in when both input
   794  /// and output permutations are left unspecified or are the identity.
   795  template <> class LinalgOpConversion<CopyOp> : public LLVMOpLowering {
   796  public:
   797    explicit LinalgOpConversion(MLIRContext *context,
   798                                LinalgTypeConverter &lowering_)
   799        : LLVMOpLowering(CopyOp::getOperationName(), context, lowering_) {}
   800  
   801    PatternMatchResult
   802    matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
   803                    ConversionPatternRewriter &rewriter) const override {
   804      auto copyOp = cast<CopyOp>(op);
   805      auto inputPerm = copyOp.inputPermutation();
   806      if (inputPerm.hasValue() && !inputPerm->isIdentity())
   807        return matchFailure();
   808      auto outputPerm = copyOp.outputPermutation();
   809      if (outputPerm.hasValue() && !outputPerm->isIdentity())
   810        return matchFailure();
   811  
   812      auto f = getLLVMLibraryCallDeclaration<CopyOp>(op, lowering, rewriter);
   813      if (!f)
   814        return matchFailure();
   815  
   816      auto fAttr = rewriter.getSymbolRefAttr(f);
   817      auto named = rewriter.getNamedAttr("callee", fAttr);
   818      rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, operands,
   819                                                ArrayRef<NamedAttribute>{named});
   820      return matchSuccess();
   821    }
   822  };
   823  
   824  /// A non-conversion rewrite pattern kicks in to convert CopyOp with
   825  /// permutations into a sequence of TransposeOp and permutation-free CopyOp.
   826  /// This interplays together with TransposeOpConversion and
   827  /// LinalgConversion<CopyOp> to create a path to the LLVM dialect.
   828  class CopyTransposeConversion : public OpRewritePattern<CopyOp> {
   829  public:
   830    using OpRewritePattern<CopyOp>::OpRewritePattern;
   831  
   832    PatternMatchResult matchAndRewrite(CopyOp op,
   833                                       PatternRewriter &rewriter) const override {
   834      Value *in = op.input(), *out = op.output();
   835  
   836      // If either inputPerm or outputPerm are non-identities, insert transposes.
   837      auto inputPerm = op.inputPermutation();
   838      if (inputPerm.hasValue() && !inputPerm->isIdentity())
   839        in = rewriter.create<linalg::TransposeOp>(op.getLoc(), in,
   840                                                  AffineMapAttr::get(*inputPerm));
   841      auto outputPerm = op.outputPermutation();
   842      if (outputPerm.hasValue() && !outputPerm->isIdentity())
   843        out = rewriter.create<linalg::TransposeOp>(
   844            op.getLoc(), out, AffineMapAttr::get(*outputPerm));
   845  
   846      // If nothing was transposed, fail and let the conversion kick in.
   847      if (in == op.input() && out == op.output())
   848        return matchFailure();
   849  
   850      rewriter.replaceOpWithNewOp<CopyOp>(op, in, out);
   851      return matchSuccess();
   852    }
   853  };
   854  
   855  /// Populate the given list with patterns that convert from Linalg to LLVM.
   856  static void
   857  populateLinalgToLLVMConversionPatterns(LinalgTypeConverter &converter,
   858                                         OwningRewritePatternList &patterns,
   859                                         MLIRContext *ctx) {
   860    patterns.insert<CopyTransposeConversion>(ctx);
   861    patterns.insert<BufferAllocOpConversion, BufferDeallocOpConversion,
   862                    BufferSizeOpConversion, DimOpConversion,
   863                    LinalgOpConversion<CopyOp>, LinalgOpConversion<DotOp>,
   864                    LinalgOpConversion<FillOp>, LinalgOpConversion<MatmulOp>,
   865                    LoadOpConversion, RangeOpConversion, SliceOpConversion,
   866                    StoreOpConversion, TransposeOpConversion, ViewOpConversion>(
   867        ctx, converter);
   868  }
   869  
   870  namespace {
/// Module pass lowering Linalg (together with affine, loop and standard ops
/// reached through the populated patterns) to the LLVM dialect.
struct LowerLinalgToLLVMPass : public ModulePass<LowerLinalgToLLVMPass> {
  void runOnModule();
};
   874  } // namespace
   875  
   876  // This is currently written as a standalone function because the lowering to
   877  // affine will look different than lowering to LLVM and it is still unclear how
   878  // everything will be eventually structured.
   879  static void lowerLinalgSubViewOps(FuncOp &f) {
   880    f.walk([&](SubViewOp op) {
   881      OpBuilder b(op);
   882      ScopedContext scope(b, op.getLoc());
   883      auto *view = op.getView();
   884      SmallVector<Value *, 8> ranges;
   885      for (auto sliceRange : op.getRanges())
   886        ranges.push_back(range(sliceRange.min, sliceRange.max, sliceRange.step));
   887      op.replaceAllUsesWith(slice(view, ranges));
   888      op.erase();
   889    });
   890  }
   891  
   892  void LowerLinalgToLLVMPass::runOnModule() {
   893    auto module = getModule();
   894  
   895    for (auto f : module.getOps<FuncOp>())
   896      lowerLinalgSubViewOps(f);
   897  
   898    // Convert to the LLVM IR dialect using the converter defined above.
   899    OwningRewritePatternList patterns;
   900    LinalgTypeConverter converter(&getContext());
   901    populateAffineToStdConversionPatterns(patterns, &getContext());
   902    populateLoopToStdConversionPatterns(patterns, &getContext());
   903    populateStdToLLVMConversionPatterns(converter, patterns);
   904    populateLinalgToLLVMConversionPatterns(converter, patterns, &getContext());
   905  
   906    ConversionTarget target(getContext());
   907    target.addLegalDialect<LLVM::LLVMDialect>();
   908    target.addDynamicallyLegalOp<FuncOp>(
   909        [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
   910    if (failed(applyPartialConversion(module, target, patterns, &converter))) {
   911      signalPassFailure();
   912    }
   913  }
   914  
/// Public factory for the Linalg-to-LLVM lowering pass.
std::unique_ptr<ModulePassBase> mlir::linalg::createLowerLinalgToLLVMPass() {
  return std::make_unique<LowerLinalgToLLVMPass>();
}
   918  
// Static registration so the pass is available under the given command-line
// flag in opt-style tools.
static PassRegistration<LowerLinalgToLLVMPass>
    pass("linalg-lower-to-llvm-dialect",
         "Lower the operations from the linalg dialect into the LLVM dialect");