github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Quantizer/Support/UniformSolvers.cpp (about) 1 //===- UniformSolvers.cpp - Uniform type solver algorithms ----------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 18 #include "mlir/Quantizer/Support/UniformSolvers.h" 19 20 #include "llvm/Support/raw_ostream.h" 21 22 #include <cmath> 23 24 using namespace mlir; 25 using namespace mlir::quantizer; 26 27 bool UniformParamsFromMinMaxSolver::compute() { 28 // Compute adjMin, adjMax, clamping to ensure that they straddle zero. 29 if (boundingMin > 0 && boundingMax >= boundingMin) { 30 // Lop-sided to the positive. 31 adjMin = 0; 32 adjMax = boundingMax; 33 } else if (boundingMax < 0 && boundingMin <= boundingMax) { 34 // Lop-sided to the negative. 35 adjMin = boundingMin; 36 adjMax = 0; 37 } else if (boundingMin <= 0 && boundingMax >= 0) { 38 adjMin = boundingMin; 39 adjMax = boundingMax; 40 } else { 41 // Illegal bounds. 42 return satisfied = false; 43 } 44 45 const double origMinAdj = adjMin; 46 const double origMaxAdj = adjMax; 47 const double numLevelsDouble = storageParams.numLevels; 48 49 struct fns { 50 static std::pair<double, double> 51 computeMinMax(double boundingMin, double numLevels, double delta) { 52 double adjMin = delta * std::floor(boundingMin / delta); 53 return std::make_pair(adjMin, adjMin + numLevels * delta); 54 } 55 static double overshoot(double boundingMin, double boundingMax, 56 double numLevels, double delta) { 57 auto adjMinMax = computeMinMax(boundingMin, numLevels, delta); 58 double maxOvershoot = adjMinMax.second - boundingMax; 59 double minOvershoot = boundingMin - adjMinMax.first; 60 // If undershooting on the min or max end, return that because it is 61 // to be unconditionally avoided. Otherwise return the end with the 62 // greateast magnitude of overshoot. 63 if (maxOvershoot < 0) 64 return maxOvershoot; 65 if (minOvershoot < 0) 66 return minOvershoot; 67 return std::max(maxOvershoot, minOvershoot); 68 } 69 }; 70 71 // Bisect to find a suitable delta, starting with bounds of deltaInit 72 // and deltaMax. 73 double deltaInit = (adjMax - adjMin) / numLevelsDouble; 74 double deltaMax = 75 ((numLevelsDouble * deltaInit) + 2 * deltaInit) / numLevelsDouble; 76 double deltaMid; 77 double prevDeltaMid = 0.0; 78 for (stepCount = 0; stepCount < 60; ++stepCount) { 79 deltaMid = (deltaInit + deltaMax) / 2.0; 80 auto fInit = 81 fns::overshoot(origMinAdj, origMaxAdj, numLevelsDouble, deltaInit); 82 auto fMid = 83 fns::overshoot(origMinAdj, origMaxAdj, numLevelsDouble, deltaMid); 84 if (fMid == 0 || (fMid > 0 && std::fabs(deltaMid - prevDeltaMid) < 1e-15)) { 85 // Solution found (or step size is infinitessimal and an overshoot). 86 // Empirically, this seems to terminate around 30-50 steps or so. 87 // This will find a zero point for exactly representable ranges and 88 // will terminate on a small step size for inexact, biasing towards 89 // overshooting. 90 delta = deltaMid; 91 break; 92 } 93 bool signMid = fMid > 0; 94 bool signInit = fInit > 0; 95 if (signMid == signInit) { 96 deltaInit = deltaMid; 97 } else { 98 deltaMax = deltaMid; 99 } 100 prevDeltaMid = deltaMid; 101 } 102 delta = deltaMid; 103 104 // Recalculate adjMin/adjMax based on new delta. 105 auto adjMinMax = fns::computeMinMax(origMinAdj, numLevelsDouble, delta); 106 adjMin = adjMinMax.first; 107 adjMax = adjMinMax.second; 108 109 satisfied = false; 110 zp = 0; 111 112 if (!std::isnan(delta) && !std::isnan(adjMin) && !std::isnan(adjMax)) { 113 satisfied = true; 114 // Finally, scale and zeroPoint. Since it casts to integer, only valid 115 // if the inputs are valid. 116 zp = std::round(storageParams.minValue - adjMin / delta); 117 } 118 119 return satisfied; 120 } 121 122 int64_t UniformParamsFromMinMaxSolver::quantize(double x) const { 123 int64_t xq = std::round(x / delta + zp); 124 return std::max<int64_t>(0, std::min<int64_t>(storageParams.numLevels, xq)); 125 } 126 127 double UniformParamsFromMinMaxSolver::dequantize(int64_t xq) const { 128 return (xq - zp) * delta; 129 } 130 131 namespace mlir { 132 namespace quantizer { 133 134 llvm::raw_ostream &operator<<(llvm::raw_ostream &os, 135 const UniformStorageParams &p) { 136 os << "UniformStorageParams{" << p.numLevels << ", " << p.minValue << "}"; 137 return os; 138 } 139 140 llvm::raw_ostream &operator<<(llvm::raw_ostream &os, 141 const UniformParamsFromMinMaxSolver &s) { 142 os << "UniformParamsFromMinMaxSolver(" << s.getStepCount() << "){"; 143 os << "(" << s.getBoundingMin() << ":" << s.getBoundingMax() << ") -> "; 144 if (!s.isSatisfied()) { 145 os << "unsat}"; 146 return os; 147 } 148 149 os << "(" << s.getAdjMin() << ":" << s.getAdjMax() << ")"; 150 os << ", scale = " << s.getScale(); 151 os << ", zp = " << s.getZp(); 152 os << "}"; 153 154 return os; 155 } 156 157 } // end namespace quantizer 158 } // end namespace mlir