github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp

//===- LowerToLLVMDialect.cpp - conversion from Linalg to LLVM dialect ----===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

#include "mlir/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Utils/Intrinsics.h"
#include "mlir/EDSC/Builders.h"
#include "mlir/EDSC/Intrinsics.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Types.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/LowerAffine.h"
#include "mlir/Transforms/Passes.h"

#include "llvm/ADT/SetVector.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/ErrorHandling.h"

using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace mlir::LLVM;
using namespace mlir::linalg;
using namespace mlir::linalg::intrinsics;

using add = ValueBuilder<mlir::LLVM::AddOp>;
using addi = ValueBuilder<mlir::AddIOp>;
using bitcast = ValueBuilder<mlir::LLVM::BitcastOp>;
using cmpi = ValueBuilder<mlir::CmpIOp>;
using constant = ValueBuilder<mlir::LLVM::ConstantOp>;
using extractvalue = ValueBuilder<mlir::LLVM::ExtractValueOp>;
using gep = ValueBuilder<mlir::LLVM::GEPOp>;
using insertvalue = ValueBuilder<mlir::LLVM::InsertValueOp>;
using llvm_call = OperationBuilder<mlir::LLVM::CallOp>;
using llvm_icmp = ValueBuilder<LLVM::ICmpOp>;
using llvm_load = ValueBuilder<LLVM::LoadOp>;
using llvm_store = OperationBuilder<LLVM::StoreOp>;
using llvm_select = ValueBuilder<LLVM::SelectOp>;
using mul = ValueBuilder<mlir::LLVM::MulOp>;
using ptrtoint = ValueBuilder<mlir::LLVM::PtrToIntOp>;
using sub = ValueBuilder<mlir::LLVM::SubOp>;
using undef = ValueBuilder<mlir::LLVM::UndefOp>;
using urem = ValueBuilder<mlir::LLVM::URemOp>;
using llvm_alloca = ValueBuilder<LLVM::AllocaOp>;
using llvm_return = OperationBuilder<LLVM::ReturnOp>;
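
// Each ValueBuilder/OperationBuilder alias above builds the corresponding op
// at the current EDSC insertion point; a ValueBuilder additionally yields the
// op's result value. Under an edsc::ScopedContext, descriptor manipulation
// below therefore reads like plain expressions, e.g. (illustrative):
//   offset = add(offset, mul(indices[i], stride));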

template <typename T>
static LLVMType getPtrToElementType(T containerType,
                                    LLVMTypeConverter &lowering) {
  return lowering.convertType(containerType.getElementType())
      .template cast<LLVMType>()
      .getPointerTo();
}

// Convert the given type to the LLVM IR Dialect type. The following
// conversions are supported:
//   - an Index type is converted into an LLVM integer type with pointer
//     bitwidth (analogous to intptr_t in C);
//   - an Integer type is converted into an LLVM integer type of the same
//     width;
//   - an F32 type is converted into an LLVM float type;
//   - a Buffer, Range or View is converted into an LLVM structure type
//     containing the respective dynamic values.
static Type convertLinalgType(Type t, LLVMTypeConverter &lowering) {
  auto *context = t.getContext();
  auto int64Ty = lowering.convertType(IntegerType::get(64, context))
                     .cast<LLVM::LLVMType>();

  // A buffer descriptor contains the pointer to a flat region of storage and
  // the size of the region.
  //
  // template <typename Elem, size_t Rank>
  // struct {
  //   void *baseAlloc;
  //   Elem *ptr;
  //   int64_t size;
  // };
  if (auto bufferType = t.dyn_cast<BufferType>()) {
    auto voidPtrTy = LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo();
    auto ptrTy = getPtrToElementType(bufferType, lowering);
    return LLVMType::getStructTy(voidPtrTy, ptrTy, int64Ty);
  }

  // A range descriptor contains the range bounds and the step as 64-bit
  // integers.
  //
  // struct {
  //   int64_t min;
  //   int64_t max;
  //   int64_t step;
  // };
  if (t.isa<RangeType>())
    return LLVMType::getStructTy(int64Ty, int64Ty, int64Ty);

  // A linalg.view type converts to a *pointer to* a view descriptor. The view
  // descriptor contains the pointer to the data buffer, followed by a 64-bit
  // integer containing the distance between the beginning of the buffer and
  // the first element to be accessed through the view, followed by two
  // arrays, each containing as many 64-bit integers as the rank of the View.
  // The first array represents the size, in number of original elements, of
  // the view along the given dimension. When taking the view, the size is the
  // difference between the upper and the lower bound of the range. The second
  // array represents the "stride" (in tensor abstraction sense), i.e. the
  // number of consecutive elements of the underlying buffer that separate two
  // consecutive elements addressable through the view along the given
  // dimension. When taking the view, the strides are constructed as products
  // of the original sizes along the trailing dimensions, multiplied by the
  // view step. For example, a view of a MxN memref with ranges {0:M:1},
  // {0:N:1}, i.e. the view of a complete memref, will have strides N and 1. A
  // view with ranges {0:M:2}, {0:N:3} will have strides 2*N and 3.
  //
  // template <typename Elem, size_t Rank>
  // struct {
  //   Elem *ptr;
  //   int64_t offset;
  //   int64_t sizes[Rank];
  //   int64_t strides[Rank];
  // } *;
  if (auto viewType = t.dyn_cast<ViewType>()) {
    auto ptrTy = getPtrToElementType(viewType, lowering);
    auto arrayTy = LLVMType::getArrayTy(int64Ty, viewType.getRank());
    return LLVMType::getStructTy(ptrTy, int64Ty, arrayTy, arrayTy)
        .getPointerTo();
  }

  return Type();
}

static constexpr int kBasePtrPosInBuffer = 0;
static constexpr int kPtrPosInBuffer = 1;
static constexpr int kSizePosInBuffer = 2;
static constexpr int kPtrPosInView = 0;
static constexpr int kOffsetPosInView = 1;
static constexpr int kSizePosInView = 2;
static constexpr int kStridePosInView = 3;

// Create an array attribute containing integer attributes with values provided
// in `position`.
static ArrayAttr positionAttr(Builder &builder, ArrayRef<int> position) {
  SmallVector<Attribute, 4> attrs;
  attrs.reserve(position.size());
  for (auto p : position)
    attrs.push_back(builder.getI64IntegerAttr(p));
  return builder.getArrayAttr(attrs);
}

namespace {
/// Factor out the common information for all view conversions:
///   1. common types (in the standard and LLVM dialects);
///   2. the `pos` method;
///   3. the descriptor alloca'ed at the FuncOp entry (`allocatedDesc`) and
///      the descriptor value loaded from it (`desc`).
class BaseViewConversionHelper {
public:
  BaseViewConversionHelper(Operation *op, ViewType viewType,
                           ConversionPatternRewriter &rewriter,
                           LLVMTypeConverter &lowering)
      : indexType(rewriter.getIndexType()), viewType(viewType),
        elementTy(getPtrToElementType(viewType, lowering)),
        int64Ty(
            lowering.convertType(rewriter.getIntegerType(64)).cast<LLVMType>()),
        viewDescriptorPtrTy(
            convertLinalgType(viewType, lowering).cast<LLVMType>()),
        rewriter(rewriter) {

    OpBuilder::InsertionGuard insertGuard(rewriter);
    rewriter.setInsertionPointToStart(
        &op->getParentOfType<FuncOp>().getBlocks().front());

    edsc::ScopedContext context(rewriter, op->getLoc());
    one = constant(int64Ty, IntegerAttr::get(indexType, 1));
    // Alloca with proper alignment.
    allocatedDesc = llvm_alloca(viewDescriptorPtrTy, one, /*alignment=*/8);
    // Load the alloca'ed descriptor.
    desc = llvm_load(allocatedDesc);
  }

  ArrayAttr pos(ArrayRef<int> values) const {
    return positionAttr(rewriter, values);
  }

  IndexType indexType;
  ViewType viewType;
  LLVMType elementTy, int64Ty, viewDescriptorPtrTy;
  ConversionPatternRewriter &rewriter;
  Value *one, *allocatedDesc, *desc;
};
} // namespace

// BufferAllocOp creates a new `!linalg.buffer` value.
class BufferAllocOpConversion : public LLVMOpLowering {
public:
  explicit BufferAllocOpConversion(MLIRContext *context,
                                   LLVMTypeConverter &lowering_)
      : LLVMOpLowering(BufferAllocOp::getOperationName(), context, lowering_) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto indexType = IndexType::get(op->getContext());
    auto voidPtrTy =
        LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo();
    auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
    // Insert the `malloc` declaration if it is not already present.
    auto module = op->getParentOfType<ModuleOp>();
    FuncOp mallocFunc = module.lookupSymbol<FuncOp>("malloc");
    if (!mallocFunc) {
      auto mallocType = rewriter.getFunctionType(int64Ty, voidPtrTy);
      mallocFunc =
          FuncOp::create(rewriter.getUnknownLoc(), "malloc", mallocType);
      module.push_back(mallocFunc);
    }

    // Get MLIR types for injecting the element pointer.
    auto allocOp = cast<BufferAllocOp>(op);
    auto elementType = allocOp.getElementType();
    uint64_t elementSize = 0;
    if (auto vectorType = elementType.dyn_cast<VectorType>())
      elementSize = vectorType.getNumElements() *
                    llvm::divideCeil(vectorType.getElementTypeBitWidth(), 8);
    else
      elementSize = llvm::divideCeil(elementType.getIntOrFloatBitWidth(), 8);
    auto bufferType = allocOp.getBufferType();
    auto elementPtrType = getPtrToElementType(bufferType, lowering);
    auto bufferDescriptorTy = convertLinalgType(bufferType, lowering);

    // Emit IR for creating a new buffer descriptor with an underlying malloc.
    edsc::ScopedContext context(rewriter, op->getLoc());
    auto constantSize = bufferType.getBufferSize();
    Value *size =
        constantSize
            ? constant(int64Ty, IntegerAttr::get(indexType, *constantSize))
                  .getValue()
            : operands[0];
    Value *allocSize =
        mul(size, constant(int64Ty, IntegerAttr::get(indexType, elementSize)));
    Value *one = nullptr, *align = nullptr;
    if (allocOp.alignment().hasValue()) {
      one = constant(int64Ty, IntegerAttr::get(indexType, 1));
      align = constant(int64Ty,
                       rewriter.getIntegerAttr(
                           rewriter.getIndexType(),
                           allocOp.alignment().getValue().getSExtValue()));
      // Over-allocate by `align - 1` bytes so the payload can be aligned.
      allocSize = sub(add(allocSize, align), one);
    }

    Value *allocated =
        llvm_call(voidPtrTy, rewriter.getSymbolRefAttr(mallocFunc), allocSize)
            .getOperation()
            ->getResult(0);
    Value *data = allocated;
    if (allocOp.alignment().hasValue()) {
      // offset = (align - (ptr % align)) % align
      // E.g., ptr == 0x1003 and align == 8 give offset == 5; the outer mod
      // makes an already-aligned ptr yield offset == 0 rather than align.
      Value *offset =
          urem(sub(align, urem(ptrtoint(int64Ty, allocated), align)), align);
      data = gep(voidPtrTy, allocated, offset);
    }
    data = bitcast(elementPtrType, data);
    Value *desc = undef(bufferDescriptorTy);
    desc = insertvalue(bufferDescriptorTy, desc, allocated,
                       positionAttr(rewriter, kBasePtrPosInBuffer));
    desc = insertvalue(bufferDescriptorTy, desc, data,
                       positionAttr(rewriter, kPtrPosInBuffer));
    desc = insertvalue(bufferDescriptorTy, desc, size,
                       positionAttr(rewriter, kSizePosInBuffer));
    rewriter.replaceOp(op, desc);
    return matchSuccess();
  }
};

// BufferDeallocOp creates no value.
class BufferDeallocOpConversion : public LLVMOpLowering {
public:
  explicit BufferDeallocOpConversion(MLIRContext *context,
                                     LLVMTypeConverter &lowering_)
      : LLVMOpLowering(BufferDeallocOp::getOperationName(), context,
                       lowering_) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto voidPtrTy =
        LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo();
    // Insert the `free` declaration if it is not already present.
    auto module = op->getParentOfType<ModuleOp>();
    FuncOp freeFunc = module.lookupSymbol<FuncOp>("free");
    if (!freeFunc) {
      auto freeType = rewriter.getFunctionType(voidPtrTy, {});
      freeFunc = FuncOp::create(rewriter.getUnknownLoc(), "free", freeType);
      module.push_back(freeFunc);
    }

    // Emit MLIR for buffer_dealloc.
    BufferDeallocOpOperandAdaptor adaptor(operands);
    edsc::ScopedContext context(rewriter, op->getLoc());
    Value *base = extractvalue(voidPtrTy, adaptor.buffer(),
                               positionAttr(rewriter, kBasePtrPosInBuffer));
    llvm_call(ArrayRef<Type>(), rewriter.getSymbolRefAttr(freeFunc), base);
    rewriter.replaceOp(op, llvm::None);
    return matchSuccess();
  }
};

// BufferSizeOp creates a new `index` value.
class BufferSizeOpConversion : public LLVMOpLowering {
public:
  BufferSizeOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
      : LLVMOpLowering(BufferSizeOp::getOperationName(), context, lowering_) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
    edsc::ScopedContext context(rewriter, op->getLoc());
    BufferSizeOpOperandAdaptor adaptor(operands);
    rewriter.replaceOp(
        op, {extractvalue(int64Ty, adaptor.buffer(),
                          positionAttr(rewriter, kSizePosInBuffer))});
    return matchSuccess();
  }
};

// DimOp creates a new `index` value.
class DimOpConversion : public LLVMOpLowering {
public:
  explicit DimOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
      : LLVMOpLowering(linalg::DimOp::getOperationName(), context, lowering_) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto dimOp = cast<linalg::DimOp>(op);
    auto indexTy = lowering.convertType(rewriter.getIndexType());
    edsc::ScopedContext context(rewriter, op->getLoc());
    auto pos = positionAttr(
        rewriter, {kSizePosInView, static_cast<int>(dimOp.getIndex())});
    linalg::DimOpOperandAdaptor adaptor(operands);
    Value *viewDescriptor = llvm_load(adaptor.view());
    rewriter.replaceOp(op, {extractvalue(indexTy, viewDescriptor, pos)});
    return matchSuccess();
  }
};

namespace {
// Common functionality for Linalg LoadOp and StoreOp conversion to the
// LLVM IR Dialect.
template <typename Op> class LoadStoreOpConversion : public LLVMOpLowering {
public:
  explicit LoadStoreOpConversion(MLIRContext *context,
                                 LLVMTypeConverter &lowering_)
      : LLVMOpLowering(Op::getOperationName(), context, lowering_) {}
  using Base = LoadStoreOpConversion<Op>;

  // Compute the pointer to an element of the buffer underlying the view,
  // given the current view indices. Use the base offset and strides stored in
  // the view descriptor to emit IR iteratively computing the actual offset,
  // followed by a getelementptr. This must be called under an
  // edsc::ScopedContext.
  Value *obtainDataPtr(Operation *op, Value *viewDescriptorPtr,
                       ArrayRef<Value *> indices,
                       ConversionPatternRewriter &rewriter) const {
    auto loadOp = cast<Op>(op);
    auto elementTy = getPtrToElementType(loadOp.getViewType(), lowering);
    auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
    auto pos = [&rewriter](ArrayRef<int> values) {
      return positionAttr(rewriter, values);
    };

    // Linearize subscripts as:
    //   base_offset + SUM_i index_i * stride_i.
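    // For example (illustrative): for a rank-2 view with strides {N, 1},
    // indices (i, j) linearize to base_offset + i * N + j.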
    Value *viewDescriptor = llvm_load(viewDescriptorPtr);
    Value *base = extractvalue(elementTy, viewDescriptor, pos(kPtrPosInView));
    Value *offset =
        extractvalue(int64Ty, viewDescriptor, pos(kOffsetPosInView));
    for (int i = 0, e = loadOp.getRank(); i < e; ++i) {
      Value *stride =
          extractvalue(int64Ty, viewDescriptor, pos({kStridePosInView, i}));
      Value *additionalOffset = mul(indices[i], stride);
      offset = add(offset, additionalOffset);
    }
    return gep(elementTy, base, offset);
  }
};
} // namespace

// A load is converted into the actual address computation, a getelementptr
// and an LLVM IR load.
class LoadOpConversion : public LoadStoreOpConversion<linalg::LoadOp> {
  using Base::Base;
  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    edsc::ScopedContext edscContext(rewriter, op->getLoc());
    auto elementTy = lowering.convertType(*op->result_type_begin());
    linalg::LoadOpOperandAdaptor adaptor(operands);
    auto ptr = obtainDataPtr(op, adaptor.view(), adaptor.indices(), rewriter);
    rewriter.replaceOp(op, {llvm_load(elementTy, ptr)});
    return matchSuccess();
  }
};

// RangeOp creates a new range descriptor.
class RangeOpConversion : public LLVMOpLowering {
public:
  explicit RangeOpConversion(MLIRContext *context,
                             LLVMTypeConverter &lowering_)
      : LLVMOpLowering(RangeOp::getOperationName(), context, lowering_) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto rangeOp = cast<RangeOp>(op);
    auto rangeDescriptorTy =
        convertLinalgType(rangeOp.getResult()->getType(), lowering);

    edsc::ScopedContext context(rewriter, op->getLoc());

    // Fill in an aggregate value of the descriptor.
    RangeOpOperandAdaptor adaptor(operands);
    Value *desc = undef(rangeDescriptorTy);
    desc = insertvalue(desc, adaptor.min(), positionAttr(rewriter, 0));
    desc = insertvalue(desc, adaptor.max(), positionAttr(rewriter, 1));
    desc = insertvalue(desc, adaptor.step(), positionAttr(rewriter, 2));
    rewriter.replaceOp(op, desc);
    return matchSuccess();
  }
};

/// Conversion pattern that transforms a linalg.slice op into:
///   1. A function entry `alloca` operation to allocate a ViewDescriptor.
///   2. A load of the ViewDescriptor from the pointer allocated in 1.
///   3. Updates to the ViewDescriptor to introduce the data ptr, offset, size
///      and stride corresponding to the region of memory within the bounds of
///      the parent view.
///   4. A store of the resulting ViewDescriptor to the alloca'ed pointer.
/// The linalg.slice op is replaced by the alloca'ed pointer.
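/// For example (illustrative), slicing a rank-2 view with a range along
/// dimension 0 and a single index along dimension 1 yields a rank-1 view:
/// dimension 1 is projected away, the offset grows by min0 * stride0 +
/// index * stride1, the size becomes min(max0, size0) - min0 (clamped below
/// at 0), and the stride becomes stride0 * step0.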
class SliceOpConversion : public LLVMOpLowering {
public:
  explicit SliceOpConversion(MLIRContext *context,
                             LLVMTypeConverter &lowering_)
      : LLVMOpLowering(SliceOp::getOperationName(), context, lowering_) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    SliceOpOperandAdaptor adaptor(operands);
    auto sliceOp = cast<SliceOp>(op);
    auto viewDescriptorPtrTy =
        convertLinalgType(sliceOp.getViewType(), lowering);
    auto viewType = sliceOp.getBaseViewType();
    auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));

    // Helper function to create an integer array attribute out of a list of
    // values.
    auto pos = [&rewriter](ArrayRef<int> values) {
      return positionAttr(rewriter, values);
    };

    edsc::ScopedContext context(rewriter, op->getLoc());
    // Declare the view descriptor and insert the data ptr *at the entry block
    // of the function*, which is the preferred location for LLVM's analyses.
    auto ip = rewriter.getInsertionPoint();
    auto ib = rewriter.getInsertionBlock();
    rewriter.setInsertionPointToStart(
        &op->getParentOfType<FuncOp>().getBlocks().front());
    Value *zero =
        constant(int64Ty, rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
    Value *one =
        constant(int64Ty, rewriter.getIntegerAttr(rewriter.getIndexType(), 1));
    // Alloca with proper alignment.
    Value *allocatedDesc =
        llvm_alloca(viewDescriptorPtrTy, one, /*alignment=*/8);
    Value *desc = llvm_load(allocatedDesc);
    rewriter.setInsertionPoint(ib, ip);

    Value *baseDesc = llvm_load(adaptor.view());

    auto ptrPos = pos(kPtrPosInView);
    auto elementTy = getPtrToElementType(sliceOp.getViewType(), lowering);
    desc = insertvalue(desc, extractvalue(elementTy, baseDesc, ptrPos), ptrPos);

    // TODO(ntv): extract sizes and emit asserts.
    SmallVector<Value *, 4> strides(viewType.getRank());
    for (int i = 0, e = viewType.getRank(); i < e; ++i) {
      strides[i] = extractvalue(int64Ty, baseDesc, pos({kStridePosInView, i}));
    }

    // Compute and insert the base offset.
    Value *baseOffset = extractvalue(int64Ty, baseDesc, pos(kOffsetPosInView));
    for (int i = 0, e = viewType.getRank(); i < e; ++i) {
      Value *indexing = adaptor.indexings()[i];
      Value *min = indexing;
      if (sliceOp.indexing(i)->getType().isa<RangeType>())
        min = extractvalue(int64Ty, indexing, pos(0));
      baseOffset = add(baseOffset, mul(min, strides[i]));
    }
    desc = insertvalue(desc, baseOffset, pos(kOffsetPosInView));

    // Compute and insert the view sizes (max - min along the range) and
    // strides. Skip the non-range operands as they will be projected away
    // from the view.
    int numNewDims = 0;
    for (auto en : llvm::enumerate(sliceOp.indexings())) {
      Value *indexing = en.value();
      if (indexing->getType().isa<RangeType>()) {
        int rank = en.index();
        Value *rangeDescriptor = adaptor.indexings()[rank];
        Value *min = extractvalue(int64Ty, rangeDescriptor, pos(0));
        Value *max = extractvalue(int64Ty, rangeDescriptor, pos(1));
        Value *step = extractvalue(int64Ty, rangeDescriptor, pos(2));
        Value *baseSize =
            extractvalue(int64Ty, baseDesc, pos({kSizePosInView, rank}));
        // Bound the upper bound by the size of the base view.
        max = llvm_select(llvm_icmp(ICmpPredicate::slt, max, baseSize), max,
                          baseSize);
        Value *size = sub(max, min);
        // Bound the size below by zero.
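        // (Illustrative: a range with min == 4 and max == 2 would otherwise
        // produce size -2; the clamp yields an empty view of size 0.)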
        size =
            llvm_select(llvm_icmp(ICmpPredicate::slt, size, zero), zero, size);
        Value *stride = mul(strides[rank], step);
        desc = insertvalue(desc, size, pos({kSizePosInView, numNewDims}));
        desc = insertvalue(desc, stride, pos({kStridePosInView, numNewDims}));
        ++numNewDims;
      }
    }

    // Store back in the alloca'ed region.
    llvm_store(desc, allocatedDesc);
    rewriter.replaceOp(op, allocatedDesc);
    return matchSuccess();
  }
};

// A store is converted into the actual address computation, a getelementptr
// and an LLVM IR store.
class StoreOpConversion : public LoadStoreOpConversion<linalg::StoreOp> {
  using Base::Base;
  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    edsc::ScopedContext edscContext(rewriter, op->getLoc());
    linalg::StoreOpOperandAdaptor adaptor(operands);
    Value *ptr = obtainDataPtr(op, adaptor.view(), adaptor.indices(), rewriter);
    llvm_store(adaptor.value(), ptr);
    rewriter.replaceOp(op, llvm::None);
    return matchSuccess();
  }
};

/// Conversion pattern that transforms a linalg.transpose op into:
///   1. A function entry `alloca` operation to allocate a ViewDescriptor.
///   2. A load of the ViewDescriptor from the pointer allocated in 1.
///   3. Updates to the ViewDescriptor to introduce the data ptr, offset, size
///      and stride. Size and stride are permutations of the original values.
///   4. A store of the resulting ViewDescriptor to the alloca'ed pointer.
/// The linalg.transpose op is replaced by the alloca'ed pointer.
class TransposeOpConversion : public LLVMOpLowering {
public:
  explicit TransposeOpConversion(MLIRContext *context,
                                 LLVMTypeConverter &lowering_)
      : LLVMOpLowering(TransposeOp::getOperationName(), context, lowering_) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // Initialize the common boilerplate and alloca at the top of the FuncOp.
    TransposeOpOperandAdaptor adaptor(operands);
    auto transposeOp = cast<TransposeOp>(op);
    BaseViewConversionHelper helper(op, transposeOp.getViewType(), rewriter,
                                    lowering);
    IndexType indexType = helper.indexType;
    ViewType viewType = helper.viewType;
    LLVMType elementTy = helper.elementTy, int64Ty = helper.int64Ty,
             viewDescriptorPtrTy = helper.viewDescriptorPtrTy;
    Value *allocatedDesc = helper.allocatedDesc, *desc = helper.desc;

    edsc::ScopedContext context(rewriter, op->getLoc());
    // Load the descriptor of the view being transposed.
    Value *baseDesc = llvm_load(adaptor.view());

    // Copy the base pointer from the old descriptor to the new one.
    ArrayAttr ptrPos = helper.pos(kPtrPosInView);
    desc = insertvalue(desc, extractvalue(elementTy, baseDesc, ptrPos), ptrPos);

    // Copy the offset from the old descriptor to the new one.
    ArrayAttr offPos = helper.pos(kOffsetPosInView);
    desc = insertvalue(desc, extractvalue(int64Ty, baseDesc, offPos), offPos);

    if (transposeOp.permutation().isIdentity()) {
      // No permutation; just store back in the alloca'ed region.
      llvm_store(desc, allocatedDesc);
      rewriter.replaceOp(op, allocatedDesc);
      return matchSuccess();
    }

    // Iterate over the dimensions and apply the size/stride permutation.
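    // For example (illustrative), for the permutation (d0, d1) -> (d1, d0),
    // the size and stride stored at source dimension 0 are written to target
    // dimension 1, and vice versa.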
    for (auto en : llvm::enumerate(transposeOp.permutation().getResults())) {
      int sourcePos = en.index();
      int targetPos = en.value().cast<AffineDimExpr>().getPosition();
      Value *size = extractvalue(int64Ty, baseDesc,
                                 helper.pos({kSizePosInView, sourcePos}));
      desc = insertvalue(desc, size, helper.pos({kSizePosInView, targetPos}));
      Value *stride = extractvalue(int64Ty, baseDesc,
                                   helper.pos({kStridePosInView, sourcePos}));
      desc =
          insertvalue(desc, stride, helper.pos({kStridePosInView, targetPos}));
    }

    // Store back in the alloca'ed region.
    llvm_store(desc, allocatedDesc);
    rewriter.replaceOp(op, allocatedDesc);
    return matchSuccess();
  }
};

/// Conversion pattern that transforms a linalg.view op into:
///   1. A function entry `alloca` operation to allocate a ViewDescriptor.
///   2. A load of the ViewDescriptor from the pointer allocated in 1.
///   3. Updates to the ViewDescriptor to introduce the data ptr, offset, size
///      and stride.
///   4. A store of the resulting ViewDescriptor to the alloca'ed pointer.
/// The linalg.view op is replaced by the alloca'ed pointer.
class ViewOpConversion : public LLVMOpLowering {
public:
  explicit ViewOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
      : LLVMOpLowering(ViewOp::getOperationName(), context, lowering_) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto viewOp = cast<ViewOp>(op);
    ViewOpOperandAdaptor adaptor(operands);
    auto viewDescriptorPtrTy =
        convertLinalgType(viewOp.getViewType(), lowering);
    auto elementTy = getPtrToElementType(viewOp.getViewType(), lowering);
    auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));

    auto pos = [&rewriter](ArrayRef<int> values) {
      return positionAttr(rewriter, values);
    };

    Value *bufferDescriptor = adaptor.buffer();
    auto bufferTy = getPtrToElementType(
        viewOp.buffer()->getType().cast<BufferType>(), lowering);

    // Declare the descriptor of the view.
    edsc::ScopedContext context(rewriter, op->getLoc());
    auto ip = rewriter.getInsertionPoint();
    auto ib = rewriter.getInsertionBlock();
    rewriter.setInsertionPointToStart(
        &op->getParentOfType<FuncOp>().getBlocks().front());
    Value *one =
        constant(int64Ty, rewriter.getIntegerAttr(rewriter.getIndexType(), 1));
    // Alloca with proper alignment.
    Value *allocatedDesc =
        llvm_alloca(viewDescriptorPtrTy, one, /*alignment=*/8);
    Value *desc = llvm_load(allocatedDesc);
    rewriter.setInsertionPoint(ib, ip);

    // Copy the buffer pointer from the old descriptor to the new one.
    Value *bufferAsViewElementType =
        bitcast(elementTy,
                extractvalue(bufferTy, bufferDescriptor, pos(kPtrPosInBuffer)));
    desc = insertvalue(desc, bufferAsViewElementType, pos(kPtrPosInView));

    // Zero base offset.
    auto indexTy = rewriter.getIndexType();
    Value *baseOffset = constant(int64Ty, IntegerAttr::get(indexTy, 0));
    desc = insertvalue(desc, baseOffset, pos(kOffsetPosInView));

    // Compute and insert the view sizes (max - min along the range) and
    // strides, iterating from the last dimension to the first.
    int numRanges = llvm::size(viewOp.ranges());
    Value *runningStride = constant(int64Ty, IntegerAttr::get(indexTy, 1));
    for (int i = numRanges - 1; i >= 0; --i) {
      // Update the stride.
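      // The stride is runningStride * step, where runningStride holds the
      // product of the `max` bounds of all trailing dimensions. Illustrative
      // example: ranges {0:M:1} and {0:N:1} yield strides N and 1.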
      Value *rangeDescriptor = operands[1 + i];
      Value *step = extractvalue(int64Ty, rangeDescriptor, pos(2));
      Value *stride = mul(runningStride, step);
      desc = insertvalue(desc, stride, pos({kStridePosInView, i}));
      // Update the size.
      Value *min = extractvalue(int64Ty, rangeDescriptor, pos(0));
      Value *max = extractvalue(int64Ty, rangeDescriptor, pos(1));
      Value *size = sub(max, min);
      desc = insertvalue(desc, size, pos({kSizePosInView, i}));
      // Update the running stride for the next dimension.
      if (i > 0)
        runningStride = mul(runningStride, max);
    }

    // Store back in the alloca'ed region.
    llvm_store(desc, allocatedDesc);
    rewriter.replaceOp(op, allocatedDesc);
    return matchSuccess();
  }
};

// Get the function declaration for the library call associated with LinalgOp.
// If it is not already present in the module, insert a declaration.
template <typename LinalgOp>
static FuncOp
getLLVMLibraryCallDeclaration(Operation *op, LLVMTypeConverter &lowering,
                              ConversionPatternRewriter &rewriter) {
  auto linalgOp = cast<LinalgOp>(op);
  auto fnName = linalgOp.getLibraryCallName();
  if (fnName.empty()) {
    op->emitWarning("No library call defined for: ") << *op;
    return FuncOp();
  }
  auto module = op->getParentOfType<ModuleOp>();
  if (auto f = module.lookupSymbol<FuncOp>(fnName))
    return f;

  // Get the function type consistent with the LLVM lowering.
  SmallVector<Type, 4> inputTypes;
  for (auto operand : op->getOperands())
    inputTypes.push_back(lowering.convertType(operand->getType()));
  assert(op->getNumResults() == 0 &&
         "Library call for linalg operation can be generated only for ops "
         "that have void return types");
  auto libFnType = FunctionType::get(inputTypes, {}, op->getContext());
  auto libFn = FuncOp::create(op->getLoc(), fnName, libFnType);
  module.push_back(libFn);
  // Return after creating the function declaration; the body, if any, will be
  // provided later (e.g. by an externally linked library).
  return libFn;
}

namespace {
// The type converter from the Linalg dialect to the LLVM IR dialect.
class LinalgTypeConverter : public LLVMTypeConverter {
  using LLVMTypeConverter::LLVMTypeConverter;

public:
  Type convertType(Type t) override {
    if (auto result = LLVMTypeConverter::convertType(t))
      return result;
    return convertLinalgType(t, *this);
  }
};
} // end anonymous namespace

// LinalgOpConversion<LinalgOp> creates a new call to the
// `LinalgOp::getLibraryCallName()` function.
// The implementation of the function can be either in the same module or in
// an externally linked library.
template <typename LinalgOp> class LinalgOpConversion : public LLVMOpLowering {
public:
  explicit LinalgOpConversion(MLIRContext *context,
                              LinalgTypeConverter &lowering_)
      : LLVMOpLowering(LinalgOp::getOperationName(), context, lowering_) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto f = getLLVMLibraryCallDeclaration<LinalgOp>(op, lowering, rewriter);
    if (!f)
      return matchFailure();

    auto fAttr = rewriter.getSymbolRefAttr(f);
    auto named = rewriter.getNamedAttr("callee", fAttr);
    rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, operands,
                                              ArrayRef<NamedAttribute>{named});
    return matchSuccess();
  }
};

/// Conversion pattern specialization for CopyOp. This kicks in when both the
/// input and output permutations are left unspecified or are the identity.
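/// Copies with non-identity permutations are not converted here; they are
/// first rewritten by CopyTransposeConversion below into transposes followed
/// by a permutation-free copy, which this pattern then converts.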
template <> class LinalgOpConversion<CopyOp> : public LLVMOpLowering {
public:
  explicit LinalgOpConversion(MLIRContext *context,
                              LinalgTypeConverter &lowering_)
      : LLVMOpLowering(CopyOp::getOperationName(), context, lowering_) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto copyOp = cast<CopyOp>(op);
    auto inputPerm = copyOp.inputPermutation();
    if (inputPerm.hasValue() && !inputPerm->isIdentity())
      return matchFailure();
    auto outputPerm = copyOp.outputPermutation();
    if (outputPerm.hasValue() && !outputPerm->isIdentity())
      return matchFailure();

    auto f = getLLVMLibraryCallDeclaration<CopyOp>(op, lowering, rewriter);
    if (!f)
      return matchFailure();

    auto fAttr = rewriter.getSymbolRefAttr(f);
    auto named = rewriter.getNamedAttr("callee", fAttr);
    rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, operands,
                                              ArrayRef<NamedAttribute>{named});
    return matchSuccess();
  }
};

/// A non-conversion rewrite pattern that converts a CopyOp with permutations
/// into a sequence of TransposeOp and a permutation-free CopyOp. This
/// interplays with TransposeOpConversion and LinalgOpConversion<CopyOp> to
/// create a path to the LLVM dialect.
class CopyTransposeConversion : public OpRewritePattern<CopyOp> {
public:
  using OpRewritePattern<CopyOp>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(CopyOp op,
                                     PatternRewriter &rewriter) const override {
    Value *in = op.input(), *out = op.output();

    // If either inputPerm or outputPerm is a non-identity, insert a
    // transpose.
    auto inputPerm = op.inputPermutation();
    if (inputPerm.hasValue() && !inputPerm->isIdentity())
      in = rewriter.create<linalg::TransposeOp>(op.getLoc(), in,
                                                AffineMapAttr::get(*inputPerm));
    auto outputPerm = op.outputPermutation();
    if (outputPerm.hasValue() && !outputPerm->isIdentity())
      out = rewriter.create<linalg::TransposeOp>(
          op.getLoc(), out, AffineMapAttr::get(*outputPerm));

    // If nothing was transposed, fail and let the conversion kick in.
    if (in == op.input() && out == op.output())
      return matchFailure();

    rewriter.replaceOpWithNewOp<CopyOp>(op, in, out);
    return matchSuccess();
  }
};

/// Populate the given list with patterns that convert from Linalg to LLVM.
static void
populateLinalgToLLVMConversionPatterns(LinalgTypeConverter &converter,
                                       OwningRewritePatternList &patterns,
                                       MLIRContext *ctx) {
  patterns.insert<CopyTransposeConversion>(ctx);
  patterns.insert<BufferAllocOpConversion, BufferDeallocOpConversion,
                  BufferSizeOpConversion, DimOpConversion,
                  LinalgOpConversion<CopyOp>, LinalgOpConversion<DotOp>,
                  LinalgOpConversion<FillOp>, LinalgOpConversion<MatmulOp>,
                  LoadOpConversion, RangeOpConversion, SliceOpConversion,
                  StoreOpConversion, TransposeOpConversion, ViewOpConversion>(
      ctx, converter);
}

namespace {
struct LowerLinalgToLLVMPass : public ModulePass<LowerLinalgToLLVMPass> {
  void runOnModule();
};
} // namespace

// This is currently written as a standalone function because the lowering to
// affine will look different from the lowering to LLVM and it is still
// unclear how everything will eventually be structured.
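// lowerLinalgSubViewOps rewrites each linalg.subview op in the function into
// an equivalent linalg.slice on the same view, with one explicitly
// materialized range per subview range.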
static void lowerLinalgSubViewOps(FuncOp &f) {
  f.walk([&](SubViewOp op) {
    OpBuilder b(op);
    ScopedContext scope(b, op.getLoc());
    auto *view = op.getView();
    SmallVector<Value *, 8> ranges;
    for (auto sliceRange : op.getRanges())
      ranges.push_back(range(sliceRange.min, sliceRange.max, sliceRange.step));
    op.replaceAllUsesWith(slice(view, ranges));
    op.erase();
  });
}

void LowerLinalgToLLVMPass::runOnModule() {
  auto module = getModule();

  for (auto f : module.getOps<FuncOp>())
    lowerLinalgSubViewOps(f);

  // Convert to the LLVM IR dialect using the converter defined above.
  OwningRewritePatternList patterns;
  LinalgTypeConverter converter(&getContext());
  populateAffineToStdConversionPatterns(patterns, &getContext());
  populateLoopToStdConversionPatterns(patterns, &getContext());
  populateStdToLLVMConversionPatterns(converter, patterns);
  populateLinalgToLLVMConversionPatterns(converter, patterns, &getContext());

  ConversionTarget target(getContext());
  target.addLegalDialect<LLVM::LLVMDialect>();
  target.addDynamicallyLegalOp<FuncOp>(
      [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
  if (failed(applyPartialConversion(module, target, patterns, &converter)))
    signalPassFailure();
}

std::unique_ptr<ModulePassBase> mlir::linalg::createLowerLinalgToLLVMPass() {
  return std::make_unique<LowerLinalgToLLVMPass>();
}

static PassRegistration<LowerLinalgToLLVMPass>
    pass("linalg-lower-to-llvm-dialect",
         "Lower the operations from the linalg dialect into the LLVM dialect");
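
// Example usage (illustrative; assumes an `mlir-opt` binary in which this
// pass is registered and a hypothetical input file `input.mlir`):
//
//   mlir-opt input.mlir -linalg-lower-to-llvm-dialect
//
// This lowers the linalg ops (together with the affine and loop constructs
// handled by the populated standard patterns) into the LLVM IR dialect.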