github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp (about) 1 //===- ConvertConst.cpp - Quantizes constant ops --------------------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 18 #include "mlir/Dialect/QuantOps/Passes.h" 19 #include "mlir/Dialect/QuantOps/QuantOps.h" 20 #include "mlir/Dialect/QuantOps/QuantizeUtils.h" 21 #include "mlir/Dialect/QuantOps/UniformSupport.h" 22 #include "mlir/Dialect/StandardOps/Ops.h" 23 #include "mlir/IR/Attributes.h" 24 #include "mlir/IR/Matchers.h" 25 #include "mlir/IR/PatternMatch.h" 26 #include "mlir/IR/StandardTypes.h" 27 #include "mlir/Pass/Pass.h" 28 29 using namespace mlir; 30 using namespace mlir::quant; 31 32 namespace { 33 34 class ConvertConstPass : public FunctionPass<ConvertConstPass> { 35 public: 36 void runOnFunction() override; 37 }; 38 39 struct QuantizedConstRewrite : public OpRewritePattern<QuantizeCastOp> { 40 using OpRewritePattern<QuantizeCastOp>::OpRewritePattern; 41 42 PatternMatchResult matchAndRewrite(QuantizeCastOp qbarrier, 43 PatternRewriter &rewriter) const override; 44 }; 45 46 } // end anonymous namespace 47 48 /// Matches a [constant] -> [qbarrier] where the qbarrier results type is 49 /// quantized and the operand type is quantizable. 50 51 PatternMatchResult 52 QuantizedConstRewrite::matchAndRewrite(QuantizeCastOp qbarrier, 53 PatternRewriter &rewriter) const { 54 Attribute value; 55 56 // Is the operand a constant? 57 if (!matchPattern(qbarrier.arg(), m_Constant(&value))) { 58 return matchFailure(); 59 } 60 61 // Does the qbarrier convert to a quantized type. This will not be true 62 // if a quantized type has not yet been chosen or if the cast to an equivalent 63 // storage type is not supported. 64 Type qbarrierResultType = qbarrier.getResult()->getType(); 65 QuantizedType quantizedElementType = 66 QuantizedType::getQuantizedElementType(qbarrierResultType); 67 if (!quantizedElementType) { 68 return matchFailure(); 69 } 70 if (!QuantizedType::castToStorageType(qbarrierResultType)) { 71 return matchFailure(); 72 } 73 74 // Is the operand type compatible with the expressed type of the quantized 75 // type? This will not be true if the qbarrier is superfluous (converts 76 // from and to a quantized type). 77 if (!quantizedElementType.isCompatibleExpressedType( 78 qbarrier.arg()->getType())) { 79 return matchFailure(); 80 } 81 82 // Is the constant value a type expressed in a way that we support? 83 if (!value.isa<FloatAttr>() && !value.isa<DenseElementsAttr>() && 84 !value.isa<SparseElementsAttr>()) { 85 return matchFailure(); 86 } 87 88 Type newConstValueType; 89 auto newConstValue = 90 quantizeAttr(value, quantizedElementType, newConstValueType); 91 if (!newConstValue) { 92 return matchFailure(); 93 } 94 95 // When creating the new const op, use a fused location that combines the 96 // original const and the qbarrier that led to the quantization. 97 auto fusedLoc = FusedLoc::get( 98 {qbarrier.arg()->getDefiningOp()->getLoc(), qbarrier.getLoc()}, 99 rewriter.getContext()); 100 auto newConstOp = 101 rewriter.create<ConstantOp>(fusedLoc, newConstValueType, newConstValue); 102 rewriter.replaceOpWithNewOp<StorageCastOp>({qbarrier.arg()}, qbarrier, 103 qbarrier.getType(), newConstOp); 104 return matchSuccess(); 105 } 106 107 void ConvertConstPass::runOnFunction() { 108 OwningRewritePatternList patterns; 109 auto func = getFunction(); 110 auto *context = &getContext(); 111 patterns.insert<QuantizedConstRewrite>(context); 112 applyPatternsGreedily(func, patterns); 113 } 114 115 std::unique_ptr<FunctionPassBase> mlir::quant::createConvertConstPass() { 116 return std::make_unique<ConvertConstPass>(); 117 } 118 119 static PassRegistration<ConvertConstPass> 120 pass("quant-convert-const", 121 "Converts constants followed by qbarrier to actual quantized values");