github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp (about) 1 //===- ConvertSimQuant.cpp - Converts simulated quant ops------------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 18 #include "mlir/Dialect/QuantOps/FakeQuantSupport.h" 19 #include "mlir/Dialect/QuantOps/Passes.h" 20 #include "mlir/Dialect/QuantOps/QuantOps.h" 21 #include "mlir/Dialect/QuantOps/UniformSupport.h" 22 #include "mlir/IR/Attributes.h" 23 #include "mlir/IR/PatternMatch.h" 24 #include "mlir/IR/StandardTypes.h" 25 #include "mlir/Pass/Pass.h" 26 27 using namespace mlir; 28 using namespace mlir::quant; 29 30 namespace { 31 32 class ConvertSimulatedQuantPass 33 : public FunctionPass<ConvertSimulatedQuantPass> { 34 public: 35 void runOnFunction() override; 36 }; 37 38 } // end anonymous namespace 39 40 /// Rewrites ConstFakeQuant into a qbarrier/dbarrier pair. 41 class ConstFakeQuantRewrite : public RewritePattern { 42 public: 43 bool *hadFailure; 44 45 ConstFakeQuantRewrite(MLIRContext *context, bool *hadFailure) 46 : RewritePattern(ConstFakeQuant::getOperationName(), 1, context), 47 hadFailure(hadFailure) {} 48 49 PatternMatchResult matchAndRewrite(Operation *op, 50 PatternRewriter &rewriter) const override { 51 // TODO: If this pattern comes up more frequently, consider adding core 52 // support for failable rewrites. 53 if (failableRewrite(op, rewriter)) { 54 *hadFailure = true; 55 return matchFailure(); 56 } 57 58 return matchSuccess(); 59 } 60 61 bool failableRewrite(Operation *op, PatternRewriter &rewriter) const { 62 auto fqOp = cast<ConstFakeQuant>(op); 63 64 auto converter = 65 ExpressedToUniformQuantizedConverter::forInputType(fqOp.getType()); 66 if (!converter) { 67 return (op->emitError("unsupported quantized type conversion"), true); 68 } 69 70 UniformQuantizedType uniformElementType = fakeQuantAttrsToType( 71 fqOp.getLoc(), fqOp.num_bits().getSExtValue(), 72 fqOp.min().convertToFloat(), fqOp.max().convertToFloat(), 73 fqOp.narrow_range(), converter.expressedType, fqOp.is_signed()); 74 75 if (!uniformElementType) { 76 // Note that the fakeQuantAttrsToType will have emitted the error. 77 return true; 78 } 79 80 Type quantizedType = converter.convert(uniformElementType); 81 assert(quantizedType && 82 "Converter accepted a type that it did not convert"); 83 84 // TODO: Map to a qbarrier with an attribute like [Forced] to signal that 85 // this is a forced/hard-coded constraint. 86 auto qbarrier = rewriter.create<QuantizeCastOp>(op->getLoc(), quantizedType, 87 fqOp.inputs()); 88 rewriter.replaceOpWithNewOp<DequantizeCastOp>(op, converter.inputType, 89 qbarrier.getResult()); 90 91 return false; 92 } 93 }; 94 95 void ConvertSimulatedQuantPass::runOnFunction() { 96 bool hadFailure = false; 97 OwningRewritePatternList patterns; 98 auto func = getFunction(); 99 auto *context = &getContext(); 100 patterns.insert<ConstFakeQuantRewrite>(context, &hadFailure); 101 applyPatternsGreedily(func, patterns); 102 if (hadFailure) 103 signalPassFailure(); 104 } 105 106 std::unique_ptr<FunctionPassBase> 107 mlir::quant::createConvertSimulatedQuantPass() { 108 return std::make_unique<ConvertSimulatedQuantPass>(); 109 } 110 111 static PassRegistration<ConvertSimulatedQuantPass> 112 pass("quant-convert-simulated-quantization", 113 "Converts training-time simulated quantization ops to corresponding " 114 "quantize/dequantize casts.");