github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp (about) 1 //===- LowerToLLVMDialect.cpp - conversion from Linalg to LLVM dialect ----===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 18 #include "mlir/Conversion/VectorToLLVM/VectorToLLVM.h" 19 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" 20 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" 21 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 22 #include "mlir/Dialect/VectorOps/VectorOps.h" 23 #include "mlir/IR/Attributes.h" 24 #include "mlir/IR/Builders.h" 25 #include "mlir/IR/MLIRContext.h" 26 #include "mlir/IR/Module.h" 27 #include "mlir/IR/Operation.h" 28 #include "mlir/IR/PatternMatch.h" 29 #include "mlir/IR/StandardTypes.h" 30 #include "mlir/IR/Types.h" 31 #include "mlir/Pass/Pass.h" 32 #include "mlir/Pass/PassManager.h" 33 #include "mlir/Transforms/DialectConversion.h" 34 #include "mlir/Transforms/Passes.h" 35 36 #include "llvm/IR/DerivedTypes.h" 37 #include "llvm/IR/Module.h" 38 #include "llvm/IR/Type.h" 39 #include "llvm/Support/Allocator.h" 40 #include "llvm/Support/ErrorHandling.h" 41 42 using namespace mlir; 43 44 template <typename T> 45 static LLVM::LLVMType getPtrToElementType(T containerType, 46 LLVMTypeConverter &lowering) { 47 return lowering.convertType(containerType.getElementType()) 48 .template cast<LLVM::LLVMType>() 49 .getPointerTo(); 50 } 51 52 // Create an array attribute containing integer attributes with values provided 53 // in `position`. 54 static ArrayAttr positionAttr(Builder &builder, ArrayRef<int> position) { 55 SmallVector<Attribute, 4> attrs; 56 attrs.reserve(position.size()); 57 for (auto p : position) 58 attrs.push_back(builder.getI64IntegerAttr(p)); 59 return builder.getArrayAttr(attrs); 60 } 61 62 class ExtractElementOpConversion : public LLVMOpLowering { 63 public: 64 explicit ExtractElementOpConversion(MLIRContext *context, 65 LLVMTypeConverter &typeConverter) 66 : LLVMOpLowering(vector::ExtractElementOp::getOperationName(), context, 67 typeConverter) {} 68 69 PatternMatchResult 70 matchAndRewrite(Operation *op, ArrayRef<Value *> operands, 71 ConversionPatternRewriter &rewriter) const override { 72 auto loc = op->getLoc(); 73 auto adaptor = vector::ExtractElementOpOperandAdaptor(operands); 74 auto extractOp = cast<vector::ExtractElementOp>(op); 75 auto vectorType = extractOp.vector()->getType().cast<VectorType>(); 76 auto resultType = extractOp.getResult()->getType(); 77 auto llvmResultType = lowering.convertType(resultType); 78 79 auto positionArrayAttr = extractOp.position(); 80 // One-shot extraction of vector from array (only requires extractvalue). 81 if (resultType.isa<VectorType>()) { 82 Value *extracted = rewriter.create<LLVM::ExtractValueOp>( 83 loc, llvmResultType, adaptor.vector(), positionArrayAttr); 84 rewriter.replaceOp(op, extracted); 85 return matchSuccess(); 86 } 87 88 // Potential extraction of 1-D vector from struct. 89 auto *context = op->getContext(); 90 Value *extracted = adaptor.vector(); 91 auto positionAttrs = positionArrayAttr.getValue(); 92 auto i32Type = rewriter.getIntegerType(32); 93 if (positionAttrs.size() > 1) { 94 auto nDVectorType = vectorType; 95 auto oneDVectorType = VectorType::get(nDVectorType.getShape().take_back(), 96 nDVectorType.getElementType()); 97 auto nMinusOnePositionAttrs = 98 ArrayAttr::get(positionAttrs.drop_back(), context); 99 extracted = rewriter.create<LLVM::ExtractValueOp>( 100 loc, lowering.convertType(oneDVectorType), extracted, 101 nMinusOnePositionAttrs); 102 } 103 104 // Remaining extraction of element from 1-D LLVM vector 105 auto position = positionAttrs.back().cast<IntegerAttr>(); 106 auto constant = rewriter.create<LLVM::ConstantOp>( 107 loc, lowering.convertType(i32Type), position); 108 extracted = 109 rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant); 110 rewriter.replaceOp(op, extracted); 111 112 return matchSuccess(); 113 } 114 }; 115 116 class OuterProductOpConversion : public LLVMOpLowering { 117 public: 118 explicit OuterProductOpConversion(MLIRContext *context, 119 LLVMTypeConverter &typeConverter) 120 : LLVMOpLowering(vector::OuterProductOp::getOperationName(), context, 121 typeConverter) {} 122 123 PatternMatchResult 124 matchAndRewrite(Operation *op, ArrayRef<Value *> operands, 125 ConversionPatternRewriter &rewriter) const override { 126 auto loc = op->getLoc(); 127 auto adaptor = vector::OuterProductOpOperandAdaptor(operands); 128 auto *ctx = op->getContext(); 129 auto vLHS = adaptor.lhs()->getType().cast<LLVM::LLVMType>(); 130 auto vRHS = adaptor.rhs()->getType().cast<LLVM::LLVMType>(); 131 auto rankLHS = vLHS.getUnderlyingType()->getVectorNumElements(); 132 auto rankRHS = vRHS.getUnderlyingType()->getVectorNumElements(); 133 auto llvmArrayOfVectType = lowering.convertType( 134 cast<vector::OuterProductOp>(op).getResult()->getType()); 135 Value *desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayOfVectType); 136 Value *a = adaptor.lhs(), *b = adaptor.rhs(); 137 Value *acc = adaptor.acc().empty() ? nullptr : adaptor.acc().front(); 138 SmallVector<Value *, 8> lhs, accs; 139 lhs.reserve(rankLHS); 140 accs.reserve(rankLHS); 141 for (unsigned d = 0, e = rankLHS; d < e; ++d) { 142 // shufflevector explicitly requires i32. 143 auto attr = rewriter.getI32IntegerAttr(d); 144 SmallVector<Attribute, 4> bcastAttr(rankRHS, attr); 145 auto bcastArrayAttr = ArrayAttr::get(bcastAttr, ctx); 146 Value *aD = nullptr, *accD = nullptr; 147 // 1. Broadcast the element a[d] into vector aD. 148 aD = rewriter.create<LLVM::ShuffleVectorOp>(loc, a, a, bcastArrayAttr); 149 // 2. If acc is present, extract 1-d vector acc[d] into accD. 150 if (acc) 151 accD = rewriter.create<LLVM::ExtractValueOp>(loc, vRHS, acc, 152 positionAttr(rewriter, d)); 153 // 3. Compute aD outer b (plus accD, if relevant). 154 Value *aOuterbD = 155 accD ? rewriter.create<LLVM::fmuladd>(loc, vRHS, aD, b, accD) 156 .getResult() 157 : rewriter.create<LLVM::FMulOp>(loc, aD, b).getResult(); 158 // 4. Insert as value `d` in the descriptor. 159 desc = rewriter.create<LLVM::InsertValueOp>( 160 loc, llvmArrayOfVectType, desc, aOuterbD, positionAttr(rewriter, d)); 161 } 162 rewriter.replaceOp(op, desc); 163 return matchSuccess(); 164 } 165 }; 166 167 /// Populate the given list with patterns that convert from Vector to LLVM. 168 void mlir::populateVectorToLLVMConversionPatterns( 169 LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { 170 patterns.insert<ExtractElementOpConversion, OuterProductOpConversion>( 171 converter.getDialect()->getContext(), converter); 172 } 173 174 namespace { 175 struct LowerVectorToLLVMPass : public ModulePass<LowerVectorToLLVMPass> { 176 void runOnModule(); 177 }; 178 } // namespace 179 180 void LowerVectorToLLVMPass::runOnModule() { 181 // Convert to the LLVM IR dialect using the converter defined above. 182 OwningRewritePatternList patterns; 183 LLVMTypeConverter converter(&getContext()); 184 populateVectorToLLVMConversionPatterns(converter, patterns); 185 populateStdToLLVMConversionPatterns(converter, patterns); 186 187 ConversionTarget target(getContext()); 188 target.addLegalDialect<LLVM::LLVMDialect>(); 189 target.addDynamicallyLegalOp<FuncOp>( 190 [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); 191 if (failed( 192 applyPartialConversion(getModule(), target, patterns, &converter))) { 193 signalPassFailure(); 194 } 195 } 196 197 ModulePassBase *mlir::createLowerVectorToLLVMPass() { 198 return new LowerVectorToLLVMPass(); 199 } 200 201 static PassRegistration<LowerVectorToLLVMPass> 202 pass("vector-lower-to-llvm-dialect", 203 "Lower the operations from the vector dialect into the LLVM dialect");