github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp (about) 1 //===- ConvertStandardToLLVM.cpp - Standard to LLVM dialect conversion-----===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 // 18 // This file implements a pass to convert MLIR standard and builtin dialects 19 // into the LLVM IR dialect. 
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/Functional.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/Utils.h"

#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Type.h"

using namespace mlir;

// Construct the type converter from the LLVM IR dialect registered in `ctx`.
// The dialect must already be registered (asserted below); the converter keeps
// a handle to the dialect's llvm::Module so later queries (e.g. the pointer
// bit width used to lower `index`) can consult its data layout.
LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx)
    : llvmDialect(ctx->getRegisteredDialect<LLVM::LLVMDialect>()) {
  assert(llvmDialect && "LLVM IR dialect is not registered");
  module = &llvmDialect->getLLVMModule();
}

// Get the LLVM context of the dialect's LLVM module.
llvm::LLVMContext &LLVMTypeConverter::getLLVMContext() {
  return module->getContext();
}

// Extract an LLVM IR type from the LLVM IR dialect type.
// Extract the wrapped LLVM IR type from an LLVM dialect type. A null input
// propagates as a null result; a non-null, non-LLVM type emits an error at an
// unknown location and also yields a null result.
LLVM::LLVMType LLVMTypeConverter::unwrap(Type type) {
  if (!type)
    return nullptr;
  auto *mlirContext = type.getContext();
  auto wrappedLLVMType = type.dyn_cast<LLVM::LLVMType>();
  if (!wrappedLLVMType)
    emitError(UnknownLoc::get(mlirContext),
              "conversion resulted in a non-LLVM type");
  return wrappedLLVMType;
}

// The lowering of `index`: an LLVM integer whose bit width equals the pointer
// size in the module's data layout.
LLVM::LLVMType LLVMTypeConverter::getIndexType() {
  return LLVM::LLVMType::getIntNTy(
      llvmDialect, module->getDataLayout().getPointerSizeInBits());
}

Type LLVMTypeConverter::convertIndexType(IndexType type) {
  return getIndexType();
}

// Integer types convert to LLVM integers of the same bit width.
Type LLVMTypeConverter::convertIntegerType(IntegerType type) {
  return LLVM::LLVMType::getIntNTy(llvmDialect, type.getWidth());
}

// Float types convert to the LLVM float type of the same bit width. BF16 has
// no counterpart here and produces an error and a null type.
Type LLVMTypeConverter::convertFloatType(FloatType type) {
  switch (type.getKind()) {
  case mlir::StandardTypes::F32:
    return LLVM::LLVMType::getFloatTy(llvmDialect);
  case mlir::StandardTypes::F64:
    return LLVM::LLVMType::getDoubleTy(llvmDialect);
  case mlir::StandardTypes::F16:
    return LLVM::LLVMType::getHalfTy(llvmDialect);
  case mlir::StandardTypes::BF16: {
    auto *mlirContext = llvmDialect->getContext();
    return emitError(UnknownLoc::get(mlirContext), "unsupported type: BF16"),
           Type();
  }
  default:
    llvm_unreachable("non-float type in convertFloatType");
  }
}

// Function types are converted to LLVM Function types by recursively converting
// argument and result types. If MLIR Function has zero results, the LLVM
// Function has one VoidType result. If MLIR Function has more than one result,
// they are packed into an LLVM StructType in their order of appearance.
Type LLVMTypeConverter::convertFunctionType(FunctionType type) {
  // Convert argument types one by one and check for errors.
  SmallVector<LLVM::LLVMType, 8> argTypes;
  for (auto t : type.getInputs()) {
    auto converted = convertType(t);
    if (!converted)
      return {};
    argTypes.push_back(unwrap(converted));
  }

  // If the function does not return anything, create the void result type;
  // if it returns one element, convert it; otherwise pack the result types
  // into a struct.
  LLVM::LLVMType resultType =
      type.getNumResults() == 0
          ? LLVM::LLVMType::getVoidTy(llvmDialect)
          : unwrap(packFunctionResults(type.getResults()));
  if (!resultType)
    return {};
  // Note: functions lower to *pointers* to LLVM function types.
  return LLVM::LLVMType::getFunctionTy(resultType, argTypes, /*isVarArg=*/false)
      .getPointerTo();
}

// Convert a MemRef to an LLVM type. If the memref is statically-shaped, then
// we return a pointer to the converted element type. Otherwise we return an
// LLVM structure type, where the first element of the structure type is a
// pointer to the elemental type of the MemRef and the following N elements are
// values of the Index type, one for each of N dynamic dimensions of the MemRef.
Type LLVMTypeConverter::convertMemRefType(MemRefType type) {
  LLVM::LLVMType elementType = unwrap(convertType(type.getElementType()));
  if (!elementType)
    return {};
  auto ptrType = elementType.getPointerTo();

  unsigned numDynamicSizes = type.getNumDynamicDims();
  // If memref is statically-shaped we return the underlying pointer type.
  if (numDynamicSizes == 0)
    return ptrType;

  // One extra slot for the data pointer, plus one index per dynamic dimension.
  SmallVector<LLVM::LLVMType, 8> types(numDynamicSizes + 1, getIndexType());
  types.front() = ptrType;

  return LLVM::LLVMType::getStructTy(llvmDialect, types);
}

// Convert an n-D vector type to an LLVM vector type via (n-1)-D array type when
// n > 1.
// For example, `vector<4 x f32>` converts to `!llvm<"<4 x float>">` and
// `vector<4 x 8 x 16 x f32>` converts to `!llvm<"[4 x [8 x <16 x float>]]">`.
Type LLVMTypeConverter::convertVectorType(VectorType type) {
  auto elementType = unwrap(convertType(type.getElementType()));
  if (!elementType)
    return {};
  // Build the innermost (1-D) LLVM vector from the last dimension, then wrap
  // it in array types for the remaining dimensions, innermost first.
  auto vectorType =
      LLVM::LLVMType::getVectorTy(elementType, type.getShape().back());
  auto shape = type.getShape();
  for (int i = shape.size() - 2; i >= 0; --i)
    vectorType = LLVM::LLVMType::getArrayTy(vectorType, shape[i]);
  return vectorType;
}

// Dispatch based on the actual type. Return null type on error.
Type LLVMTypeConverter::convertStandardType(Type type) {
  if (auto funcType = type.dyn_cast<FunctionType>())
    return convertFunctionType(funcType);
  if (auto intType = type.dyn_cast<IntegerType>())
    return convertIntegerType(intType);
  if (auto floatType = type.dyn_cast<FloatType>())
    return convertFloatType(floatType);
  if (auto indexType = type.dyn_cast<IndexType>())
    return convertIndexType(indexType);
  if (auto memRefType = type.dyn_cast<MemRefType>())
    return convertMemRefType(memRefType);
  if (auto vectorType = type.dyn_cast<VectorType>())
    return convertVectorType(vectorType);
  // Types already in the LLVM dialect pass through unchanged.
  if (auto llvmType = type.dyn_cast<LLVM::LLVMType>())
    return llvmType;

  return {};
}

// Convert the element type of the memref `t` to to an LLVM type using
// `lowering`, get a pointer LLVM type pointing to the converted `t`, wrap it
// into the MLIR LLVM dialect type and return.
187 static Type getMemRefElementPtrType(MemRefType t, LLVMTypeConverter &lowering) { 188 auto elementType = t.getElementType(); 189 auto converted = lowering.convertType(elementType); 190 if (!converted) 191 return {}; 192 return converted.cast<LLVM::LLVMType>().getPointerTo(); 193 } 194 195 LLVMOpLowering::LLVMOpLowering(StringRef rootOpName, MLIRContext *context, 196 LLVMTypeConverter &lowering_, 197 PatternBenefit benefit) 198 : ConversionPattern(rootOpName, benefit, context), lowering(lowering_) {} 199 200 namespace { 201 // Base class for Standard to LLVM IR op conversions. Matches the Op type 202 // provided as template argument. Carries a reference to the LLVM dialect in 203 // case it is necessary for rewriters. 204 template <typename SourceOp> 205 class LLVMLegalizationPattern : public LLVMOpLowering { 206 public: 207 // Construct a conversion pattern. 208 explicit LLVMLegalizationPattern(LLVM::LLVMDialect &dialect_, 209 LLVMTypeConverter &lowering_) 210 : LLVMOpLowering(SourceOp::getOperationName(), dialect_.getContext(), 211 lowering_), 212 dialect(dialect_) {} 213 214 // Get the LLVM IR dialect. 215 LLVM::LLVMDialect &getDialect() const { return dialect; } 216 // Get the LLVM context. 217 llvm::LLVMContext &getContext() const { return dialect.getLLVMContext(); } 218 // Get the LLVM module in which the types are constructed. 219 llvm::Module &getModule() const { return dialect.getLLVMModule(); } 220 221 // Get the MLIR type wrapping the LLVM integer type whose bit width is defined 222 // by the pointer size used in the LLVM module. 223 LLVM::LLVMType getIndexType() const { 224 return LLVM::LLVMType::getIntNTy( 225 &dialect, getModule().getDataLayout().getPointerSizeInBits()); 226 } 227 228 // Get the MLIR type wrapping the LLVM i8* type. 229 LLVM::LLVMType getVoidPtrType() const { 230 return LLVM::LLVMType::getInt8PtrTy(&dialect); 231 } 232 233 // Create an LLVM IR pseudo-operation defining the given index constant. 
234 Value *createIndexConstant(ConversionPatternRewriter &builder, Location loc, 235 uint64_t value) const { 236 auto attr = builder.getIntegerAttr(builder.getIndexType(), value); 237 return builder.create<LLVM::ConstantOp>(loc, getIndexType(), attr); 238 } 239 240 // Get the array attribute named "position" containing the given list of 241 // integers as integer attribute elements. 242 static ArrayAttr getIntegerArrayAttr(ConversionPatternRewriter &builder, 243 ArrayRef<int64_t> values) { 244 SmallVector<Attribute, 4> attrs; 245 attrs.reserve(values.size()); 246 for (int64_t pos : values) 247 attrs.push_back(builder.getIntegerAttr(builder.getIndexType(), pos)); 248 return builder.getArrayAttr(attrs); 249 } 250 251 // Extract raw data pointer value from a value representing a memref. 252 static Value *extractMemRefElementPtr(ConversionPatternRewriter &builder, 253 Location loc, 254 Value *convertedMemRefValue, 255 Type elementTypePtr, 256 bool hasStaticShape) { 257 Value *buffer; 258 if (hasStaticShape) 259 return convertedMemRefValue; 260 else 261 return builder.create<LLVM::ExtractValueOp>( 262 loc, elementTypePtr, convertedMemRefValue, 263 getIntegerArrayAttr(builder, 0)); 264 return buffer; 265 } 266 267 protected: 268 LLVM::LLVMDialect &dialect; 269 }; 270 271 struct FuncOpConversion : public LLVMLegalizationPattern<FuncOp> { 272 using LLVMLegalizationPattern<FuncOp>::LLVMLegalizationPattern; 273 274 PatternMatchResult 275 matchAndRewrite(Operation *op, ArrayRef<Value *> operands, 276 ConversionPatternRewriter &rewriter) const override { 277 auto funcOp = cast<FuncOp>(op); 278 FunctionType type = funcOp.getType(); 279 280 // Convert the original function arguments. 281 TypeConverter::SignatureConversion result(type.getNumInputs()); 282 for (unsigned i = 0, e = type.getNumInputs(); i != e; ++i) 283 if (failed(lowering.convertSignatureArg(i, type.getInput(i), result))) 284 return matchFailure(); 285 286 // Pack the result types into a struct. 
287 Type packedResult; 288 if (type.getNumResults() != 0) { 289 if (!(packedResult = lowering.packFunctionResults(type.getResults()))) 290 return matchFailure(); 291 } 292 293 // Create a new function with an updated signature. 294 auto newFuncOp = rewriter.cloneWithoutRegions(funcOp); 295 rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), 296 newFuncOp.end()); 297 newFuncOp.setType(FunctionType::get( 298 result.getConvertedTypes(), 299 packedResult ? ArrayRef<Type>(packedResult) : llvm::None, 300 funcOp.getContext())); 301 302 // Tell the rewriter to convert the region signature. 303 rewriter.applySignatureConversion(&newFuncOp.getBody(), result); 304 rewriter.replaceOp(op, llvm::None); 305 return matchSuccess(); 306 } 307 }; 308 309 // Basic lowering implementation for one-to-one rewriting from Standard Ops to 310 // LLVM Dialect Ops. 311 template <typename SourceOp, typename TargetOp> 312 struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> { 313 using LLVMLegalizationPattern<SourceOp>::LLVMLegalizationPattern; 314 using Super = OneToOneLLVMOpLowering<SourceOp, TargetOp>; 315 316 // Convert the type of the result to an LLVM type, pass operands as is, 317 // preserve attributes. 318 PatternMatchResult 319 matchAndRewrite(Operation *op, ArrayRef<Value *> operands, 320 ConversionPatternRewriter &rewriter) const override { 321 unsigned numResults = op->getNumResults(); 322 323 Type packedType; 324 if (numResults != 0) { 325 packedType = this->lowering.packFunctionResults( 326 llvm::to_vector<4>(op->getResultTypes())); 327 assert(packedType && "type conversion failed, such operation should not " 328 "have been matched"); 329 } 330 331 auto newOp = rewriter.create<TargetOp>(op->getLoc(), packedType, operands, 332 op->getAttrs()); 333 334 // If the operation produced 0 or 1 result, return them immediately. 
335 if (numResults == 0) 336 return rewriter.replaceOp(op, llvm::None), this->matchSuccess(); 337 if (numResults == 1) 338 return rewriter.replaceOp(op, newOp.getOperation()->getResult(0)), 339 this->matchSuccess(); 340 341 // Otherwise, it had been converted to an operation producing a structure. 342 // Extract individual results from the structure and return them as list. 343 SmallVector<Value *, 4> results; 344 results.reserve(numResults); 345 for (unsigned i = 0; i < numResults; ++i) { 346 auto type = this->lowering.convertType(op->getResult(i)->getType()); 347 results.push_back(rewriter.create<LLVM::ExtractValueOp>( 348 op->getLoc(), type, newOp.getOperation()->getResult(0), 349 rewriter.getIndexArrayAttr(i))); 350 } 351 rewriter.replaceOp(op, results); 352 return this->matchSuccess(); 353 } 354 }; 355 356 // Express `linearIndex` in terms of coordinates of `basis`. 357 // Returns the empty vector when linearIndex is out of the range [0, P] where 358 // P is the product of all the basis coordinates. 359 // 360 // Prerequisites: 361 // Basis is an array of nonnegative integers (signed type inherited from 362 // vector shape type). 363 static SmallVector<int64_t, 4> getCoordinates(ArrayRef<int64_t> basis, 364 unsigned linearIndex) { 365 SmallVector<int64_t, 4> res; 366 res.reserve(basis.size()); 367 for (unsigned basisElement : llvm::reverse(basis)) { 368 res.push_back(linearIndex % basisElement); 369 linearIndex = linearIndex / basisElement; 370 } 371 if (linearIndex > 0) 372 return {}; 373 std::reverse(res.begin(), res.end()); 374 return res; 375 } 376 377 // Basic lowering implementation for rewriting from Standard Ops to LLVM Dialect 378 // Ops for binary ops with one result. This supports higher-dimensional vector 379 // types. 
380 template <typename SourceOp, typename TargetOp> 381 struct BinaryOpLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> { 382 using LLVMLegalizationPattern<SourceOp>::LLVMLegalizationPattern; 383 using Super = BinaryOpLLVMOpLowering<SourceOp, TargetOp>; 384 385 // Convert the type of the result to an LLVM type, pass operands as is, 386 // preserve attributes. 387 PatternMatchResult 388 matchAndRewrite(Operation *op, ArrayRef<Value *> operands, 389 ConversionPatternRewriter &rewriter) const override { 390 static_assert( 391 std::is_base_of<OpTrait::NOperands<2>::Impl<SourceOp>, SourceOp>::value, 392 "expected binary op"); 393 static_assert( 394 std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value, 395 "expected single result op"); 396 static_assert(std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>, 397 SourceOp>::value, 398 "expected single result op"); 399 400 auto loc = op->getLoc(); 401 auto llvmArrayTy = operands[0]->getType().cast<LLVM::LLVMType>(); 402 403 if (!llvmArrayTy.isArrayTy()) { 404 auto newOp = rewriter.create<TargetOp>( 405 op->getLoc(), operands[0]->getType(), operands, op->getAttrs()); 406 rewriter.replaceOp(op, newOp.getResult()); 407 return this->matchSuccess(); 408 } 409 410 // Unroll iterated array type until we hit a non-array type. 411 auto llvmTy = llvmArrayTy; 412 SmallVector<int64_t, 4> arraySizes; 413 while (llvmTy.isArrayTy()) { 414 arraySizes.push_back(llvmTy.getArrayNumElements()); 415 llvmTy = llvmTy.getArrayElementType(); 416 } 417 assert(llvmTy.isVectorTy() && "unexpected binary op over non-vector type"); 418 auto llvmVectorTy = llvmTy; 419 420 // Iteratively extract a position coordinates with basis `arraySize` from a 421 // `linearIndex` that is incremented at each step. This terminates when 422 // `linearIndex` exceeds the range specified by `arraySize`. 423 // This has the effect of fully unrolling the dimensions of the n-D array 424 // type, getting to the underlying vector element. 
425 Value *desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayTy); 426 unsigned ub = 1; 427 for (auto s : arraySizes) 428 ub *= s; 429 for (unsigned linearIndex = 0; linearIndex < ub; ++linearIndex) { 430 auto coords = getCoordinates(arraySizes, linearIndex); 431 // Linear index is out of bounds, we are done. 432 if (coords.empty()) 433 break; 434 435 auto position = rewriter.getIndexArrayAttr(coords); 436 437 // For this unrolled `position` corresponding to the `linearIndex`^th 438 // element, extract operand vectors 439 Value *extractedLHS = rewriter.create<LLVM::ExtractValueOp>( 440 loc, llvmVectorTy, operands[0], position); 441 Value *extractedRHS = rewriter.create<LLVM::ExtractValueOp>( 442 loc, llvmVectorTy, operands[1], position); 443 Value *newVal = rewriter.create<TargetOp>( 444 loc, llvmVectorTy, ArrayRef<Value *>{extractedLHS, extractedRHS}, 445 op->getAttrs()); 446 desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmArrayTy, desc, 447 newVal, position); 448 } 449 rewriter.replaceOp(op, desc); 450 return this->matchSuccess(); 451 } 452 }; 453 454 // Specific lowerings. 455 // FIXME: this should be tablegen'ed. 
// Elementwise integer/float binary ops: one LLVM op per (unrolled) vector.
struct AddIOpLowering : public BinaryOpLLVMOpLowering<AddIOp, LLVM::AddOp> {
  using Super::Super;
};
struct SubIOpLowering : public BinaryOpLLVMOpLowering<SubIOp, LLVM::SubOp> {
  using Super::Super;
};
struct MulIOpLowering : public BinaryOpLLVMOpLowering<MulIOp, LLVM::MulOp> {
  using Super::Super;
};
struct DivISOpLowering : public BinaryOpLLVMOpLowering<DivISOp, LLVM::SDivOp> {
  using Super::Super;
};
struct DivIUOpLowering : public BinaryOpLLVMOpLowering<DivIUOp, LLVM::UDivOp> {
  using Super::Super;
};
struct RemISOpLowering : public BinaryOpLLVMOpLowering<RemISOp, LLVM::SRemOp> {
  using Super::Super;
};
struct RemIUOpLowering : public BinaryOpLLVMOpLowering<RemIUOp, LLVM::URemOp> {
  using Super::Super;
};
struct AndOpLowering : public BinaryOpLLVMOpLowering<AndOp, LLVM::AndOp> {
  using Super::Super;
};
struct OrOpLowering : public BinaryOpLLVMOpLowering<OrOp, LLVM::OrOp> {
  using Super::Super;
};
struct XOrOpLowering : public BinaryOpLLVMOpLowering<XOrOp, LLVM::XOrOp> {
  using Super::Super;
};
struct AddFOpLowering : public BinaryOpLLVMOpLowering<AddFOp, LLVM::FAddOp> {
  using Super::Super;
};
struct SubFOpLowering : public BinaryOpLLVMOpLowering<SubFOp, LLVM::FSubOp> {
  using Super::Super;
};
struct MulFOpLowering : public BinaryOpLLVMOpLowering<MulFOp, LLVM::FMulOp> {
  using Super::Super;
};
struct DivFOpLowering : public BinaryOpLLVMOpLowering<DivFOp, LLVM::FDivOp> {
  using Super::Super;
};
struct RemFOpLowering : public BinaryOpLLVMOpLowering<RemFOp, LLVM::FRemOp> {
  using Super::Super;
};
// One-to-one op rewrites; multiple results are packed into a struct.
struct SelectOpLowering
    : public OneToOneLLVMOpLowering<SelectOp, LLVM::SelectOp> {
  using Super::Super;
};
struct CallOpLowering : public OneToOneLLVMOpLowering<CallOp, LLVM::CallOp> {
  using Super::Super;
};
struct CallIndirectOpLowering
    : public OneToOneLLVMOpLowering<CallIndirectOp, LLVM::CallOp> {
  using Super::Super;
};
struct ConstLLVMOpLowering
    : public OneToOneLLVMOpLowering<ConstantOp, LLVM::ConstantOp> {
  using Super::Super;
};

// Check if the MemRefType `type` is supported by the lowering. We currently do
// not support memrefs with affine maps and non-default memory spaces.
static bool isSupportedMemRefType(MemRefType type) {
  if (!type.getAffineMaps().empty())
    return false;
  if (type.getMemorySpace() != 0)
    return false;
  return true;
}

// An `alloc` is converted into a definition of a memref descriptor value and
// a call to `malloc` to allocate the underlying data buffer. The memref
// descriptor is of the LLVM structure type where the first element is a pointer
// to the (typed) data buffer, and the remaining elements serve to store
// dynamic sizes of the memref using LLVM-converted `index` type.
struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
  using LLVMLegalizationPattern<AllocOp>::LLVMLegalizationPattern;

  // Only memrefs without affine maps in the default memory space are matched.
  PatternMatchResult match(Operation *op) const override {
    MemRefType type = cast<AllocOp>(op).getType();
    return isSupportedMemRefType(type) ? matchSuccess() : matchFailure();
  }

  void rewrite(Operation *op, ArrayRef<Value *> operands,
               ConversionPatternRewriter &rewriter) const override {
    auto allocOp = cast<AllocOp>(op);
    MemRefType type = allocOp.getType();

    // Get actual sizes of the memref as values: static sizes are constant
    // values and dynamic sizes are passed to 'alloc' as operands. In case of
    // zero-dimensional memref, assume a scalar (size 1).
    SmallVector<Value *, 4> sizes;
    auto numOperands = allocOp.getNumOperands();
    sizes.reserve(numOperands);
    unsigned i = 0; // next dynamic-size operand to consume
    for (int64_t s : type.getShape())
      sizes.push_back(s == -1 ? operands[i++]
                              : createIndexConstant(rewriter, op->getLoc(), s));
    if (sizes.empty())
      sizes.push_back(createIndexConstant(rewriter, op->getLoc(), 1));

    // Compute the total number of memref elements.
    Value *cumulativeSize = sizes.front();
    for (unsigned i = 1, e = sizes.size(); i < e; ++i)
      cumulativeSize = rewriter.create<LLVM::MulOp>(
          op->getLoc(), getIndexType(),
          ArrayRef<Value *>{cumulativeSize, sizes[i]});

    // Compute the total amount of bytes to allocate: element count times the
    // element size in bytes (bit width rounded up to a whole byte).
    auto elementType = type.getElementType();
    assert((elementType.isIntOrFloat() || elementType.isa<VectorType>()) &&
           "invalid memref element type");
    uint64_t elementSize = 0;
    if (auto vectorType = elementType.dyn_cast<VectorType>())
      elementSize = vectorType.getNumElements() *
                    llvm::divideCeil(vectorType.getElementTypeBitWidth(), 8);
    else
      elementSize = llvm::divideCeil(elementType.getIntOrFloatBitWidth(), 8);
    cumulativeSize = rewriter.create<LLVM::MulOp>(
        op->getLoc(), getIndexType(),
        ArrayRef<Value *>{
            cumulativeSize,
            createIndexConstant(rewriter, op->getLoc(), elementSize)});

    // Insert the `malloc` declaration if it is not already present.
    auto module = op->getParentOfType<ModuleOp>();
    FuncOp mallocFunc = module.lookupSymbol<FuncOp>("malloc");
    if (!mallocFunc) {
      auto mallocType =
          rewriter.getFunctionType(getIndexType(), getVoidPtrType());
      mallocFunc =
          FuncOp::create(rewriter.getUnknownLoc(), "malloc", mallocType);
      module.push_back(mallocFunc);
    }

    // Allocate the underlying buffer and store a pointer to it in the MemRef
    // descriptor. malloc returns i8*, so bitcast to the typed element pointer.
    Value *allocated =
        rewriter
            .create<LLVM::CallOp>(op->getLoc(), getVoidPtrType(),
                                  rewriter.getSymbolRefAttr(mallocFunc),
                                  cumulativeSize)
            .getResult(0);
    auto structElementType = lowering.convertType(elementType);
    auto elementPtrType =
        structElementType.cast<LLVM::LLVMType>().getPointerTo();
    allocated = rewriter.create<LLVM::BitcastOp>(op->getLoc(), elementPtrType,
                                                 ArrayRef<Value *>(allocated));

    // Deal with static memrefs: they lower to the bare data pointer.
    if (numOperands == 0)
      return rewriter.replaceOp(op, allocated);

    // Create the MemRef descriptor.
    auto structType = lowering.convertType(type);
    Value *memRefDescriptor = rewriter.create<LLVM::UndefOp>(
        op->getLoc(), structType, ArrayRef<Value *>{});

    // Field 0 of the descriptor is the data pointer.
    memRefDescriptor = rewriter.create<LLVM::InsertValueOp>(
        op->getLoc(), structType, memRefDescriptor, allocated,
        rewriter.getIndexArrayAttr(0));

    // Store dynamically allocated sizes in the descriptor. Dynamic sizes are
    // passed in as operands.
    for (auto indexedSize : llvm::enumerate(operands)) {
      memRefDescriptor = rewriter.create<LLVM::InsertValueOp>(
          op->getLoc(), structType, memRefDescriptor, indexedSize.value(),
          rewriter.getIndexArrayAttr(1 + indexedSize.index()));
    }

    // Return the final value of the descriptor.
    rewriter.replaceOp(op, memRefDescriptor);
  }
};

// A `dealloc` is converted into a call to `free` on the underlying data buffer.
// The memref descriptor being an SSA value, there is no need to clean it up
// in any way.
struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
  using LLVMLegalizationPattern<DeallocOp>::LLVMLegalizationPattern;

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    assert(operands.size() == 1 && "dealloc takes one operand");
    OperandAdaptor<DeallocOp> transformed(operands);

    // Insert the `free` declaration if it is not already present.
    FuncOp freeFunc =
        op->getParentOfType<ModuleOp>().lookupSymbol<FuncOp>("free");
    if (!freeFunc) {
      auto freeType = rewriter.getFunctionType(getVoidPtrType(), {});
      freeFunc = FuncOp::create(rewriter.getUnknownLoc(), "free", freeType);
      op->getParentOfType<ModuleOp>().push_back(freeFunc);
    }

    // A statically-shaped memref lowered to a bare pointer; a dynamic one to a
    // descriptor struct whose field 0 is the data pointer.
    auto type = transformed.memref()->getType().cast<LLVM::LLVMType>();
    auto hasStaticShape = type.isPointerTy();
    Type elementPtrType = hasStaticShape ? type : type.getStructElementType(0);
    Value *bufferPtr =
        extractMemRefElementPtr(rewriter, op->getLoc(), transformed.memref(),
                                elementPtrType, hasStaticShape);
    // `free` expects i8*.
    Value *casted = rewriter.create<LLVM::BitcastOp>(
        op->getLoc(), getVoidPtrType(), bufferPtr);
    rewriter.replaceOpWithNewOp<LLVM::CallOp>(
        op, ArrayRef<Type>(), rewriter.getSymbolRefAttr(freeFunc), casted);
    return matchSuccess();
  }
};

// A `memref_cast` is lowered by copying the data pointer and rebuilding the
// size portion of the descriptor for the target type.
struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
  using LLVMLegalizationPattern<MemRefCastOp>::LLVMLegalizationPattern;

  // Both source and target memref types must be supported by the lowering.
  PatternMatchResult match(Operation *op) const override {
    auto memRefCastOp = cast<MemRefCastOp>(op);
    MemRefType sourceType =
        memRefCastOp.getOperand()->getType().cast<MemRefType>();
    MemRefType targetType = memRefCastOp.getType();
    return (isSupportedMemRefType(targetType) &&
            isSupportedMemRefType(sourceType))
               ? matchSuccess()
               : matchFailure();
  }

  void rewrite(Operation *op, ArrayRef<Value *> operands,
               ConversionPatternRewriter &rewriter) const override {
    auto memRefCastOp = cast<MemRefCastOp>(op);
    OperandAdaptor<MemRefCastOp> transformed(operands);
    auto targetType = memRefCastOp.getType();
    auto sourceType = memRefCastOp.getOperand()->getType().cast<MemRefType>();

    // Copy the data buffer pointer.
    auto elementTypePtr = getMemRefElementPtrType(targetType, lowering);
    Value *buffer =
        extractMemRefElementPtr(rewriter, op->getLoc(), transformed.source(),
                                elementTypePtr, sourceType.hasStaticShape());
    // Account for static memrefs as target types: they are just the pointer.
    if (targetType.hasStaticShape())
      return rewriter.replaceOp(op, buffer);

    // Create the new MemRef descriptor.
    auto structType = lowering.convertType(targetType);
    Value *newDescriptor = rewriter.create<LLVM::UndefOp>(
        op->getLoc(), structType, ArrayRef<Value *>{});
    // Otherwise target type is dynamic memref, so create a proper descriptor.
    newDescriptor = rewriter.create<LLVM::InsertValueOp>(
        op->getLoc(), structType, newDescriptor, buffer,
        rewriter.getIndexArrayAttr(0));

    // Fill in the dynamic sizes of the new descriptor. If the size was
    // dynamic, copy it from the old descriptor. If the size was static, insert
    // the constant. Note that the positions of dynamic sizes in the
    // descriptors start from 1 (the buffer pointer is at position zero).
    int64_t sourceDynamicDimIdx = 1;
    int64_t targetDynamicDimIdx = 1;
    for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
      // Ignore new static sizes (they will be known from the type). If the
      // size was dynamic, update the index of dynamic types.
      if (targetType.getShape()[i] != -1) {
        if (sourceType.getShape()[i] == -1)
          ++sourceDynamicDimIdx;
        continue;
      }

      auto sourceSize = sourceType.getShape()[i];
      Value *size =
          sourceSize == -1
              ? rewriter.create<LLVM::ExtractValueOp>(
                    op->getLoc(), getIndexType(),
                    transformed.source(), // NB: dynamic memref
                    rewriter.getIndexArrayAttr(sourceDynamicDimIdx++))
              : createIndexConstant(rewriter, op->getLoc(), sourceSize);
      newDescriptor = rewriter.create<LLVM::InsertValueOp>(
          op->getLoc(), structType, newDescriptor, size,
          rewriter.getIndexArrayAttr(targetDynamicDimIdx++));
    }
    assert(sourceDynamicDimIdx - 1 == sourceType.getNumDynamicDims() &&
           "source dynamic dimensions were not processed");
    assert(targetDynamicDimIdx - 1 == targetType.getNumDynamicDims() &&
           "target dynamic dimensions were not set up");

    rewriter.replaceOp(op, newDescriptor);
  }
};

// A `dim` is converted to a constant for static sizes and to an access to the
// size stored in the memref descriptor for dynamic sizes.
struct DimOpLowering : public LLVMLegalizationPattern<DimOp> {
  using LLVMLegalizationPattern<DimOp>::LLVMLegalizationPattern;

  PatternMatchResult match(Operation *op) const override {
    auto dimOp = cast<DimOp>(op);
    MemRefType type = dimOp.getOperand()->getType().cast<MemRefType>();
    return isSupportedMemRefType(type) ? matchSuccess() : matchFailure();
  }

  void rewrite(Operation *op, ArrayRef<Value *> operands,
               ConversionPatternRewriter &rewriter) const override {
    auto dimOp = cast<DimOp>(op);
    OperandAdaptor<DimOp> transformed(operands);
    MemRefType type = dimOp.getOperand()->getType().cast<MemRefType>();

    auto shape = type.getShape();
    uint64_t index = dimOp.getIndex();
    // Extract dynamic size from the memref descriptor and define static size
    // as a constant.
    if (shape[index] == -1) {
      // Find the position of the dynamic dimension in the list of dynamic sizes
      // by counting the number of preceding dynamic dimensions. Start from 1
      // because the buffer pointer is at position zero.
      int64_t position = 1;
      for (uint64_t i = 0; i < index; ++i) {
        if (shape[i] == -1)
          ++position;
      }
      rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(
          op, getIndexType(), transformed.memrefOrTensor(),
          rewriter.getIndexArrayAttr(position));
    } else {
      rewriter.replaceOp(
          op, createIndexConstant(rewriter, op->getLoc(), shape[index]));
    }
  }
};

// Common base for load and store operations on MemRefs. Restricts the match
// to supported MemRef types. Provides functionality to emit code accessing a
// specific element of the underlying data buffer.
template <typename Derived>
struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> {
  using LLVMLegalizationPattern<Derived>::LLVMLegalizationPattern;
  using Base = LoadStoreOpLowering<Derived>;

  PatternMatchResult match(Operation *op) const override {
    MemRefType type = cast<Derived>(op).getMemRefType();
    return isSupportedMemRefType(type) ? this->matchSuccess()
                                       : this->matchFailure();
  }

  // Given subscript indices and array sizes in row-major order,
  //   i_n, i_{n-1}, ..., i_1
  //   s_n, s_{n-1}, ..., s_1
  // obtain a value that corresponds to the linearized subscript
  //   \sum_k i_k * \prod_{j=1}^{k-1} s_j
  // by accumulating the running linearized value.
  // Note that `indices` and `allocSizes` are passed in the same order as they
  // appear in load/store operations and memref type declarations.
  Value *linearizeSubscripts(ConversionPatternRewriter &builder, Location loc,
                             ArrayRef<Value *> indices,
                             ArrayRef<Value *> allocSizes) const {
    assert(indices.size() == allocSizes.size() &&
           "mismatching number of indices and allocation sizes");
    assert(!indices.empty() && "cannot linearize a 0-dimensional access");

    Value *linearized = indices.front();
    for (int i = 1, nSizes = allocSizes.size(); i < nSizes; ++i) {
      // linearized = linearized * s_i + i_i
      linearized = builder.create<LLVM::MulOp>(
          loc, this->getIndexType(),
          ArrayRef<Value *>{linearized, allocSizes[i]});
      linearized = builder.create<LLVM::AddOp>(
          loc, this->getIndexType(), ArrayRef<Value *>{linearized, indices[i]});
    }
    return linearized;
  }

  // Given the MemRef type, a descriptor and a list of indices, extract the data
  // buffer pointer from the descriptor, convert multi-dimensional subscripts
  // into a linearized index (using dynamic size data from the descriptor if
  // necessary) and get the pointer to the buffer element identified by the
  // indices.
  Value *getElementPtr(Location loc, Type elementTypePtr,
                       ArrayRef<int64_t> shape, Value *memRefDescriptor,
                       ArrayRef<Value *> indices,
                       ConversionPatternRewriter &rewriter) const {
    // Get the list of MemRef sizes. Static sizes are defined as constants.
    // Dynamic sizes are extracted from the MemRef descriptor, where they start
    // from the position 1 (the buffer is at position 0).
    SmallVector<Value *, 4> sizes;
    unsigned dynamicSizeIdx = 1;
    for (int64_t s : shape) {
      if (s == -1) {
        Value *size = rewriter.create<LLVM::ExtractValueOp>(
            loc, this->getIndexType(), memRefDescriptor,
            rewriter.getIndexArrayAttr(dynamicSizeIdx++));
        sizes.push_back(size);
      } else {
        sizes.push_back(this->createIndexConstant(rewriter, loc, s));
      }
    }

    // The second and subsequent operands are access subscripts. Obtain the
    // linearized address in the buffer.
    Value *subscript = linearizeSubscripts(rewriter, loc, indices, sizes);

    Value *dataPtr = rewriter.create<LLVM::ExtractValueOp>(
        loc, elementTypePtr, memRefDescriptor, rewriter.getIndexArrayAttr(0));
    return rewriter.create<LLVM::GEPOp>(loc, elementTypePtr,
                                        ArrayRef<Value *>{dataPtr, subscript},
                                        ArrayRef<NamedAttribute>{});
  }
  // This is a getElementPtr variant, where the value is a direct raw pointer.
  // If a shape is empty, we are dealing with a zero-dimensional memref. Return
  // the pointer unmodified in this case. Otherwise, linearize subscripts to
  // obtain the offset with respect to the base pointer. Use this offset to
  // compute and return the element pointer.
  Value *getRawElementPtr(Location loc, Type elementTypePtr,
                          ArrayRef<int64_t> shape, Value *rawDataPtr,
                          ArrayRef<Value *> indices,
                          ConversionPatternRewriter &rewriter) const {
    if (shape.empty())
      return rawDataPtr;

    // Static shape: all sizes are compile-time constants.
    SmallVector<Value *, 4> sizes;
    for (int64_t s : shape) {
      sizes.push_back(this->createIndexConstant(rewriter, loc, s));
    }

    Value *subscript = linearizeSubscripts(rewriter, loc, indices, sizes);
    return rewriter.create<LLVM::GEPOp>(
        loc, elementTypePtr, ArrayRef<Value *>{rawDataPtr, subscript},
        ArrayRef<NamedAttribute>{});
  }

  // Dispatch between the raw-pointer and descriptor-based element address
  // computation depending on whether the memref is statically shaped.
  // NOTE(review): `module` is unused in this function; kept for callers.
  Value *getDataPtr(Location loc, MemRefType type, Value *dataPtr,
                    ArrayRef<Value *> indices,
                    ConversionPatternRewriter &rewriter,
                    llvm::Module &module) const {
    auto ptrType = getMemRefElementPtrType(type, this->lowering);
    auto shape = type.getShape();
    if (type.hasStaticShape()) {
      // NB: If memref was statically-shaped, dataPtr is pointer to raw data.
      return getRawElementPtr(loc, ptrType, shape, dataPtr, indices, rewriter);
    }
    return getElementPtr(loc, ptrType, shape, dataPtr, indices, rewriter);
  }
};

// Load operation is lowered to obtaining a pointer to the indexed element
// and loading it.
896 struct LoadOpLowering : public LoadStoreOpLowering<LoadOp> { 897 using Base::Base; 898 899 PatternMatchResult 900 matchAndRewrite(Operation *op, ArrayRef<Value *> operands, 901 ConversionPatternRewriter &rewriter) const override { 902 auto loadOp = cast<LoadOp>(op); 903 OperandAdaptor<LoadOp> transformed(operands); 904 auto type = loadOp.getMemRefType(); 905 906 Value *dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(), 907 transformed.indices(), rewriter, getModule()); 908 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, dataPtr); 909 return matchSuccess(); 910 } 911 }; 912 913 // Store opreation is lowered to obtaining a pointer to the indexed element, 914 // and storing the given value to it. 915 struct StoreOpLowering : public LoadStoreOpLowering<StoreOp> { 916 using Base::Base; 917 918 PatternMatchResult 919 matchAndRewrite(Operation *op, ArrayRef<Value *> operands, 920 ConversionPatternRewriter &rewriter) const override { 921 auto type = cast<StoreOp>(op).getMemRefType(); 922 OperandAdaptor<StoreOp> transformed(operands); 923 924 Value *dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(), 925 transformed.indices(), rewriter, getModule()); 926 rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, transformed.value(), 927 dataPtr); 928 return matchSuccess(); 929 } 930 }; 931 932 // The lowering of index_cast becomes an integer conversion since index becomes 933 // an integer. If the bit width of the source and target integer types is the 934 // same, just erase the cast. If the target type is wider, sign-extend the 935 // value, otherwise truncate it. 
936 struct IndexCastOpLowering : public LLVMLegalizationPattern<IndexCastOp> { 937 using LLVMLegalizationPattern<IndexCastOp>::LLVMLegalizationPattern; 938 939 PatternMatchResult 940 matchAndRewrite(Operation *op, ArrayRef<Value *> operands, 941 ConversionPatternRewriter &rewriter) const override { 942 IndexCastOpOperandAdaptor transformed(operands); 943 auto indexCastOp = cast<IndexCastOp>(op); 944 945 auto targetType = 946 this->lowering.convertType(indexCastOp.getResult()->getType()) 947 .cast<LLVM::LLVMType>(); 948 auto sourceType = transformed.in()->getType().cast<LLVM::LLVMType>(); 949 unsigned targetBits = targetType.getUnderlyingType()->getIntegerBitWidth(); 950 unsigned sourceBits = sourceType.getUnderlyingType()->getIntegerBitWidth(); 951 952 if (targetBits == sourceBits) 953 rewriter.replaceOp(op, transformed.in()); 954 else if (targetBits < sourceBits) 955 rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, targetType, 956 transformed.in()); 957 else 958 rewriter.replaceOpWithNewOp<LLVM::SExtOp>(op, targetType, 959 transformed.in()); 960 return matchSuccess(); 961 } 962 }; 963 964 // Convert std.cmp predicate into the LLVM dialect CmpPredicate. The two 965 // enums share the numerical values so just cast. 
966 template <typename LLVMPredType, typename StdPredType> 967 static LLVMPredType convertCmpPredicate(StdPredType pred) { 968 return static_cast<LLVMPredType>(pred); 969 } 970 971 struct CmpIOpLowering : public LLVMLegalizationPattern<CmpIOp> { 972 using LLVMLegalizationPattern<CmpIOp>::LLVMLegalizationPattern; 973 974 PatternMatchResult 975 matchAndRewrite(Operation *op, ArrayRef<Value *> operands, 976 ConversionPatternRewriter &rewriter) const override { 977 auto cmpiOp = cast<CmpIOp>(op); 978 CmpIOpOperandAdaptor transformed(operands); 979 980 rewriter.replaceOpWithNewOp<LLVM::ICmpOp>( 981 op, lowering.convertType(cmpiOp.getResult()->getType()), 982 rewriter.getI64IntegerAttr(static_cast<int64_t>( 983 convertCmpPredicate<LLVM::ICmpPredicate>(cmpiOp.getPredicate()))), 984 transformed.lhs(), transformed.rhs()); 985 986 return matchSuccess(); 987 } 988 }; 989 990 struct CmpFOpLowering : public LLVMLegalizationPattern<CmpFOp> { 991 using LLVMLegalizationPattern<CmpFOp>::LLVMLegalizationPattern; 992 993 PatternMatchResult 994 matchAndRewrite(Operation *op, ArrayRef<Value *> operands, 995 ConversionPatternRewriter &rewriter) const override { 996 auto cmpfOp = cast<CmpFOp>(op); 997 CmpFOpOperandAdaptor transformed(operands); 998 999 rewriter.replaceOpWithNewOp<LLVM::FCmpOp>( 1000 op, lowering.convertType(cmpfOp.getResult()->getType()), 1001 rewriter.getI64IntegerAttr(static_cast<int64_t>( 1002 convertCmpPredicate<LLVM::FCmpPredicate>(cmpfOp.getPredicate()))), 1003 transformed.lhs(), transformed.rhs()); 1004 1005 return matchSuccess(); 1006 } 1007 }; 1008 1009 struct SIToFPLowering 1010 : public OneToOneLLVMOpLowering<SIToFPOp, LLVM::SIToFPOp> { 1011 using Super::Super; 1012 }; 1013 1014 // Base class for LLVM IR lowering terminator operations with successors. 
1015 template <typename SourceOp, typename TargetOp> 1016 struct OneToOneLLVMTerminatorLowering 1017 : public LLVMLegalizationPattern<SourceOp> { 1018 using LLVMLegalizationPattern<SourceOp>::LLVMLegalizationPattern; 1019 using Super = OneToOneLLVMTerminatorLowering<SourceOp, TargetOp>; 1020 1021 PatternMatchResult 1022 matchAndRewrite(Operation *op, ArrayRef<Value *> properOperands, 1023 ArrayRef<Block *> destinations, 1024 ArrayRef<ArrayRef<Value *>> operands, 1025 ConversionPatternRewriter &rewriter) const override { 1026 rewriter.replaceOpWithNewOp<TargetOp>(op, properOperands, destinations, 1027 operands, op->getAttrs()); 1028 return this->matchSuccess(); 1029 } 1030 }; 1031 1032 // Special lowering pattern for `ReturnOps`. Unlike all other operations, 1033 // `ReturnOp` interacts with the function signature and must have as many 1034 // operands as the function has return values. Because in LLVM IR, functions 1035 // can only return 0 or 1 value, we pack multiple values into a structure type. 1036 // Emit `UndefOp` followed by `InsertValueOp`s to create such structure if 1037 // necessary before returning it 1038 struct ReturnOpLowering : public LLVMLegalizationPattern<ReturnOp> { 1039 using LLVMLegalizationPattern<ReturnOp>::LLVMLegalizationPattern; 1040 1041 PatternMatchResult 1042 matchAndRewrite(Operation *op, ArrayRef<Value *> operands, 1043 ConversionPatternRewriter &rewriter) const override { 1044 unsigned numArguments = op->getNumOperands(); 1045 1046 // If ReturnOp has 0 or 1 operand, create it and return immediately. 
1047 if (numArguments == 0) { 1048 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>( 1049 op, llvm::ArrayRef<Value *>(), llvm::ArrayRef<Block *>(), 1050 llvm::ArrayRef<llvm::ArrayRef<Value *>>(), op->getAttrs()); 1051 return matchSuccess(); 1052 } 1053 if (numArguments == 1) { 1054 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>( 1055 op, llvm::ArrayRef<Value *>(operands.front()), 1056 llvm::ArrayRef<Block *>(), llvm::ArrayRef<llvm::ArrayRef<Value *>>(), 1057 op->getAttrs()); 1058 return matchSuccess(); 1059 } 1060 1061 // Otherwise, we need to pack the arguments into an LLVM struct type before 1062 // returning. 1063 auto packedType = 1064 lowering.packFunctionResults(llvm::to_vector<4>(op->getOperandTypes())); 1065 1066 Value *packed = rewriter.create<LLVM::UndefOp>(op->getLoc(), packedType); 1067 for (unsigned i = 0; i < numArguments; ++i) { 1068 packed = rewriter.create<LLVM::InsertValueOp>( 1069 op->getLoc(), packedType, packed, operands[i], 1070 rewriter.getIndexArrayAttr(i)); 1071 } 1072 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>( 1073 op, llvm::makeArrayRef(packed), llvm::ArrayRef<Block *>(), 1074 llvm::ArrayRef<llvm::ArrayRef<Value *>>(), op->getAttrs()); 1075 return matchSuccess(); 1076 } 1077 }; 1078 1079 // FIXME: this should be tablegen'ed as well. 1080 struct BranchOpLowering 1081 : public OneToOneLLVMTerminatorLowering<BranchOp, LLVM::BrOp> { 1082 using Super::Super; 1083 }; 1084 struct CondBranchOpLowering 1085 : public OneToOneLLVMTerminatorLowering<CondBranchOp, LLVM::CondBrOp> { 1086 using Super::Super; 1087 }; 1088 1089 } // namespace 1090 1091 static void ensureDistinctSuccessors(Block &bb) { 1092 auto *terminator = bb.getTerminator(); 1093 1094 // Find repeated successors with arguments. 
1095 llvm::SmallDenseMap<Block *, llvm::SmallVector<int, 4>> successorPositions; 1096 for (int i = 0, e = terminator->getNumSuccessors(); i < e; ++i) { 1097 Block *successor = terminator->getSuccessor(i); 1098 // Blocks with no arguments are safe even if they appear multiple times 1099 // because they don't need PHI nodes. 1100 if (successor->getNumArguments() == 0) 1101 continue; 1102 successorPositions[successor].push_back(i); 1103 } 1104 1105 // If a successor appears for the second or more time in the terminator, 1106 // create a new dummy block that unconditionally branches to the original 1107 // destination, and retarget the terminator to branch to this new block. 1108 // There is no need to pass arguments to the dummy block because it will be 1109 // dominated by the original block and can therefore use any values defined in 1110 // the original block. 1111 for (const auto &successor : successorPositions) { 1112 const auto &positions = successor.second; 1113 // Start from the second occurrence of a block in the successor list. 1114 for (auto position = std::next(positions.begin()), end = positions.end(); 1115 position != end; ++position) { 1116 auto *dummyBlock = new Block(); 1117 bb.getParent()->push_back(dummyBlock); 1118 auto builder = OpBuilder(dummyBlock); 1119 SmallVector<Value *, 8> operands( 1120 terminator->getSuccessorOperands(*position)); 1121 builder.create<BranchOp>(terminator->getLoc(), successor.first, operands); 1122 terminator->setSuccessor(dummyBlock, *position); 1123 for (int i = 0, e = terminator->getNumSuccessorOperands(*position); i < e; 1124 ++i) 1125 terminator->eraseSuccessorOperand(*position, i); 1126 } 1127 } 1128 } 1129 1130 void mlir::LLVM::ensureDistinctSuccessors(ModuleOp m) { 1131 for (auto f : m.getOps<FuncOp>()) { 1132 for (auto &bb : f.getBlocks()) { 1133 ::ensureDistinctSuccessors(bb); 1134 } 1135 } 1136 } 1137 1138 /// Collect a set of patterns to convert from the Standard dialect to LLVM. 
1139 void mlir::populateStdToLLVMConversionPatterns( 1140 LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { 1141 // FIXME: this should be tablegen'ed 1142 patterns.insert< 1143 AddFOpLowering, AddIOpLowering, AndOpLowering, AllocOpLowering, 1144 BranchOpLowering, CallIndirectOpLowering, CallOpLowering, CmpIOpLowering, 1145 CmpFOpLowering, CondBranchOpLowering, ConstLLVMOpLowering, 1146 DeallocOpLowering, DimOpLowering, DivISOpLowering, DivIUOpLowering, 1147 DivFOpLowering, FuncOpConversion, IndexCastOpLowering, LoadOpLowering, 1148 MemRefCastOpLowering, MulFOpLowering, MulIOpLowering, OrOpLowering, 1149 RemISOpLowering, RemIUOpLowering, RemFOpLowering, ReturnOpLowering, 1150 SelectOpLowering, SIToFPLowering, StoreOpLowering, SubFOpLowering, 1151 SubIOpLowering, XOrOpLowering>(*converter.getDialect(), converter); 1152 } 1153 1154 // Convert types using the stored LLVM IR module. 1155 Type LLVMTypeConverter::convertType(Type t) { return convertStandardType(t); } 1156 1157 // Create an LLVM IR structure type if there is more than one result. 1158 Type LLVMTypeConverter::packFunctionResults(ArrayRef<Type> types) { 1159 assert(!types.empty() && "expected non-empty list of type"); 1160 1161 if (types.size() == 1) 1162 return convertType(types.front()); 1163 1164 SmallVector<LLVM::LLVMType, 8> resultTypes; 1165 resultTypes.reserve(types.size()); 1166 for (auto t : types) { 1167 auto converted = convertType(t).dyn_cast<LLVM::LLVMType>(); 1168 if (!converted) 1169 return {}; 1170 resultTypes.push_back(converted); 1171 } 1172 1173 return LLVM::LLVMType::getStructTy(llvmDialect, resultTypes); 1174 } 1175 1176 /// Create an instance of LLVMTypeConverter in the given context. 1177 static std::unique_ptr<LLVMTypeConverter> 1178 makeStandardToLLVMTypeConverter(MLIRContext *context) { 1179 return std::make_unique<LLVMTypeConverter>(context); 1180 } 1181 1182 namespace { 1183 /// A pass converting MLIR operations into the LLVM IR dialect. 
1184 struct LLVMLoweringPass : public ModulePass<LLVMLoweringPass> { 1185 // By default, the patterns are those converting Standard operations to the 1186 // LLVMIR dialect. 1187 explicit LLVMLoweringPass( 1188 LLVMPatternListFiller patternListFiller = 1189 populateStdToLLVMConversionPatterns, 1190 LLVMTypeConverterMaker converterBuilder = makeStandardToLLVMTypeConverter) 1191 : patternListFiller(patternListFiller), 1192 typeConverterMaker(converterBuilder) {} 1193 1194 // Run the dialect converter on the module. 1195 void runOnModule() override { 1196 if (!typeConverterMaker || !patternListFiller) 1197 return signalPassFailure(); 1198 1199 ModuleOp m = getModule(); 1200 LLVM::ensureDistinctSuccessors(m); 1201 std::unique_ptr<LLVMTypeConverter> typeConverter = 1202 typeConverterMaker(&getContext()); 1203 if (!typeConverter) 1204 return signalPassFailure(); 1205 1206 OwningRewritePatternList patterns; 1207 populateLoopToStdConversionPatterns(patterns, m.getContext()); 1208 patternListFiller(*typeConverter, patterns); 1209 1210 ConversionTarget target(getContext()); 1211 target.addLegalDialect<LLVM::LLVMDialect>(); 1212 target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) { 1213 return typeConverter->isSignatureLegal(op.getType()); 1214 }); 1215 if (failed(applyPartialConversion(m, target, patterns, &*typeConverter))) 1216 signalPassFailure(); 1217 } 1218 1219 // Callback for creating a list of patterns. It is called every time in 1220 // runOnModule since applyPartialConversion consumes the list. 1221 LLVMPatternListFiller patternListFiller; 1222 1223 // Callback for creating an instance of type converter. The converter 1224 // constructor needs an MLIRContext, which is not available until runOnModule. 
1225 LLVMTypeConverterMaker typeConverterMaker; 1226 }; 1227 } // end namespace 1228 1229 std::unique_ptr<ModulePassBase> mlir::createConvertToLLVMIRPass() { 1230 return std::make_unique<LLVMLoweringPass>(); 1231 } 1232 1233 std::unique_ptr<ModulePassBase> 1234 mlir::createConvertToLLVMIRPass(LLVMPatternListFiller patternListFiller, 1235 LLVMTypeConverterMaker typeConverterMaker) { 1236 return std::make_unique<LLVMLoweringPass>(patternListFiller, 1237 typeConverterMaker); 1238 } 1239 1240 static PassRegistration<LLVMLoweringPass> 1241 pass("lower-to-llvm", "Convert all functions to the LLVM IR dialect");