github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp (about) 1 //===- FakeQuantSupport.cpp - Support utilities for FakeQuant ops ---------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 18 #include "mlir/Dialect/QuantOps/FakeQuantSupport.h" 19 #include "mlir/Dialect/QuantOps/QuantTypes.h" 20 21 using namespace mlir; 22 using namespace mlir::quant; 23 24 UniformQuantizedType 25 mlir::quant::fakeQuantAttrsToType(Location loc, unsigned numBits, double rmin, 26 double rmax, bool narrowRange, 27 Type expressedType, bool isSigned) { 28 MLIRContext *ctx = expressedType.getContext(); 29 Type storageType; 30 unsigned flags; 31 int64_t qmin; 32 int64_t qmax; 33 34 // Hard-coded type mapping from TFLite. 35 if (numBits <= 8) { 36 storageType = IntegerType::get(8, ctx); 37 if (isSigned) { 38 flags = QuantizationFlags::Signed; 39 qmin = -128; 40 qmax = 127; 41 } else { 42 flags = 0; 43 qmin = 0; 44 qmax = 255; 45 } 46 } else if (numBits <= 16) { 47 storageType = IntegerType::get(16, ctx); 48 if (isSigned) { 49 flags = QuantizationFlags::Signed; 50 qmin = -32768; 51 qmax = 32767; 52 } else { 53 flags = 0; 54 qmin = 0; 55 qmax = 65535; 56 } 57 } else { 58 emitError(loc, "unsupported FakeQuant number of bits: ") << numBits; 59 return nullptr; 60 } 61 62 // Handle narrowRange. 63 if (narrowRange) { 64 qmin += 1; 65 } 66 67 // Range must straddle zero. 68 if (rmin > 0.0 || rmax < 0.0) { 69 return (emitError(loc, "FakeQuant range must straddle zero: [") 70 << rmin << "," << rmax << "]", 71 nullptr); 72 } 73 74 // Special case where min/max is close enough. The tensor contents are all 75 // 0.0s, so the scale is set to 1.0 and the tensor can be quantized to zero 76 // points and dequantized to 0.0. 77 if (std::fabs(rmax - rmin) < std::numeric_limits<double>::epsilon()) { 78 return UniformQuantizedType::getChecked(flags, storageType, expressedType, 79 1.0, qmin, qmin, qmax, loc); 80 } 81 82 // Determine the scale. 83 const double qminDouble = qmin; 84 const double qmaxDouble = qmax; 85 const double scale = (rmax - rmin) / (qmaxDouble - qminDouble); 86 87 // Zero point computation. 88 // In float, solve the affine equation for any known pair 89 // (real value, corresponding quantized value), of which, two such pairs 90 // are known: (rmin, qmin), (rmax, qmax). 91 // The arithmetic error on the zero point computed from either pair will be 92 // roughly machine_epsilon * (sum of absolute values of terms). 93 // Use the variant that adds the smaller error. 94 const double zeroPointFromMin = qminDouble - rmin / scale; 95 const double zeroPointFromMinError = 96 std::abs(qminDouble) + std::abs(rmin / scale); 97 const double zeroPointFromMax = qmaxDouble - rmax / scale; 98 const double zeroPointFromMaxError = 99 std::abs(qmaxDouble) + std::abs(rmax / scale); 100 101 const double zeroPointDouble = (zeroPointFromMinError < zeroPointFromMaxError) 102 ? zeroPointFromMin 103 : zeroPointFromMax; 104 105 // Now nudge the zero point to be an integer. 106 int64_t nudgedZeroPoint = 0; 107 if (zeroPointDouble < qminDouble) { 108 nudgedZeroPoint = qmin; 109 } else if (zeroPointDouble > qmaxDouble) { 110 nudgedZeroPoint = qmax; 111 } else { 112 nudgedZeroPoint = round(zeroPointDouble); 113 } 114 115 // By construction, the nudged zero point should always be in range. 116 assert(nudgedZeroPoint >= qmin); 117 assert(nudgedZeroPoint <= qmax); 118 119 return UniformQuantizedType::getChecked(flags, storageType, expressedType, 120 scale, nudgedZeroPoint, qmin, qmax, 121 loc); 122 }