github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Quantizer/Support/UniformSolvers.cpp (about)

     1  //===- UniformSolvers.cpp - Uniform type solver algorithms ----------------===//
     2  //
     3  // Copyright 2019 The MLIR Authors.
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //   http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  // =============================================================================
    17  
    18  #include "mlir/Quantizer/Support/UniformSolvers.h"
    19  
    20  #include "llvm/Support/raw_ostream.h"
    21  
    22  #include <cmath>
    23  
    24  using namespace mlir;
    25  using namespace mlir::quantizer;
    26  
    27  bool UniformParamsFromMinMaxSolver::compute() {
    28    // Compute adjMin, adjMax, clamping to ensure that they straddle zero.
    29    if (boundingMin > 0 && boundingMax >= boundingMin) {
    30      // Lop-sided to the positive.
    31      adjMin = 0;
    32      adjMax = boundingMax;
    33    } else if (boundingMax < 0 && boundingMin <= boundingMax) {
    34      // Lop-sided to the negative.
    35      adjMin = boundingMin;
    36      adjMax = 0;
    37    } else if (boundingMin <= 0 && boundingMax >= 0) {
    38      adjMin = boundingMin;
    39      adjMax = boundingMax;
    40    } else {
    41      // Illegal bounds.
    42      return satisfied = false;
    43    }
    44  
    45    const double origMinAdj = adjMin;
    46    const double origMaxAdj = adjMax;
    47    const double numLevelsDouble = storageParams.numLevels;
    48  
    49    struct fns {
    50      static std::pair<double, double>
    51      computeMinMax(double boundingMin, double numLevels, double delta) {
    52        double adjMin = delta * std::floor(boundingMin / delta);
    53        return std::make_pair(adjMin, adjMin + numLevels * delta);
    54      }
    55      static double overshoot(double boundingMin, double boundingMax,
    56                              double numLevels, double delta) {
    57        auto adjMinMax = computeMinMax(boundingMin, numLevels, delta);
    58        double maxOvershoot = adjMinMax.second - boundingMax;
    59        double minOvershoot = boundingMin - adjMinMax.first;
    60        // If undershooting on the min or max end, return that because it is
    61        // to be unconditionally avoided. Otherwise return the end with the
    62        // greateast magnitude of overshoot.
    63        if (maxOvershoot < 0)
    64          return maxOvershoot;
    65        if (minOvershoot < 0)
    66          return minOvershoot;
    67        return std::max(maxOvershoot, minOvershoot);
    68      }
    69    };
    70  
    71    // Bisect to find a suitable delta, starting with bounds of deltaInit
    72    // and deltaMax.
    73    double deltaInit = (adjMax - adjMin) / numLevelsDouble;
    74    double deltaMax =
    75        ((numLevelsDouble * deltaInit) + 2 * deltaInit) / numLevelsDouble;
    76    double deltaMid;
    77    double prevDeltaMid = 0.0;
    78    for (stepCount = 0; stepCount < 60; ++stepCount) {
    79      deltaMid = (deltaInit + deltaMax) / 2.0;
    80      auto fInit =
    81          fns::overshoot(origMinAdj, origMaxAdj, numLevelsDouble, deltaInit);
    82      auto fMid =
    83          fns::overshoot(origMinAdj, origMaxAdj, numLevelsDouble, deltaMid);
    84      if (fMid == 0 || (fMid > 0 && std::fabs(deltaMid - prevDeltaMid) < 1e-15)) {
    85        // Solution found (or step size is infinitessimal and an overshoot).
    86        // Empirically, this seems to terminate around 30-50 steps or so.
    87        // This will find a zero point for exactly representable ranges and
    88        // will terminate on a small step size for inexact, biasing towards
    89        // overshooting.
    90        delta = deltaMid;
    91        break;
    92      }
    93      bool signMid = fMid > 0;
    94      bool signInit = fInit > 0;
    95      if (signMid == signInit) {
    96        deltaInit = deltaMid;
    97      } else {
    98        deltaMax = deltaMid;
    99      }
   100      prevDeltaMid = deltaMid;
   101    }
   102    delta = deltaMid;
   103  
   104    // Recalculate adjMin/adjMax based on new delta.
   105    auto adjMinMax = fns::computeMinMax(origMinAdj, numLevelsDouble, delta);
   106    adjMin = adjMinMax.first;
   107    adjMax = adjMinMax.second;
   108  
   109    satisfied = false;
   110    zp = 0;
   111  
   112    if (!std::isnan(delta) && !std::isnan(adjMin) && !std::isnan(adjMax)) {
   113      satisfied = true;
   114      // Finally, scale and zeroPoint. Since it casts to integer, only valid
   115      // if the inputs are valid.
   116      zp = std::round(storageParams.minValue - adjMin / delta);
   117    }
   118  
   119    return satisfied;
   120  }
   121  
   122  int64_t UniformParamsFromMinMaxSolver::quantize(double x) const {
   123    int64_t xq = std::round(x / delta + zp);
   124    return std::max<int64_t>(0, std::min<int64_t>(storageParams.numLevels, xq));
   125  }
   126  
   127  double UniformParamsFromMinMaxSolver::dequantize(int64_t xq) const {
   128    return (xq - zp) * delta;
   129  }
   130  
   131  namespace mlir {
   132  namespace quantizer {
   133  
   134  llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
   135                                const UniformStorageParams &p) {
   136    os << "UniformStorageParams{" << p.numLevels << ", " << p.minValue << "}";
   137    return os;
   138  }
   139  
   140  llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
   141                                const UniformParamsFromMinMaxSolver &s) {
   142    os << "UniformParamsFromMinMaxSolver(" << s.getStepCount() << "){";
   143    os << "(" << s.getBoundingMin() << ":" << s.getBoundingMax() << ") -> ";
   144    if (!s.isSatisfied()) {
   145      os << "unsat}";
   146      return os;
   147    }
   148  
   149    os << "(" << s.getAdjMin() << ":" << s.getAdjMax() << ")";
   150    os << ", scale = " << s.getScale();
   151    os << ", zp = " << s.getZp();
   152    os << "}";
   153  
   154    return os;
   155  }
   156  
   157  } // end namespace quantizer
   158  } // end namespace mlir