github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp (about)

     1  //===- FakeQuantSupport.cpp - Support utilities for FakeQuant ops ---------===//
     2  //
     3  // Copyright 2019 The MLIR Authors.
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //   http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  // =============================================================================
    17  
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h"
#include "mlir/Dialect/QuantOps/QuantTypes.h"

#include <cassert>
#include <cmath>
#include <limits>
    20  
    21  using namespace mlir;
    22  using namespace mlir::quant;
    23  
    24  UniformQuantizedType
    25  mlir::quant::fakeQuantAttrsToType(Location loc, unsigned numBits, double rmin,
    26                                    double rmax, bool narrowRange,
    27                                    Type expressedType, bool isSigned) {
    28    MLIRContext *ctx = expressedType.getContext();
    29    Type storageType;
    30    unsigned flags;
    31    int64_t qmin;
    32    int64_t qmax;
    33  
    34    // Hard-coded type mapping from TFLite.
    35    if (numBits <= 8) {
    36      storageType = IntegerType::get(8, ctx);
    37      if (isSigned) {
    38        flags = QuantizationFlags::Signed;
    39        qmin = -128;
    40        qmax = 127;
    41      } else {
    42        flags = 0;
    43        qmin = 0;
    44        qmax = 255;
    45      }
    46    } else if (numBits <= 16) {
    47      storageType = IntegerType::get(16, ctx);
    48      if (isSigned) {
    49        flags = QuantizationFlags::Signed;
    50        qmin = -32768;
    51        qmax = 32767;
    52      } else {
    53        flags = 0;
    54        qmin = 0;
    55        qmax = 65535;
    56      }
    57    } else {
    58      emitError(loc, "unsupported FakeQuant number of bits: ") << numBits;
    59      return nullptr;
    60    }
    61  
    62    // Handle narrowRange.
    63    if (narrowRange) {
    64      qmin += 1;
    65    }
    66  
    67    // Range must straddle zero.
    68    if (rmin > 0.0 || rmax < 0.0) {
    69      return (emitError(loc, "FakeQuant range must straddle zero: [")
    70                  << rmin << "," << rmax << "]",
    71              nullptr);
    72    }
    73  
    74    // Special case where min/max is close enough. The tensor contents are all
    75    // 0.0s, so the scale is set to 1.0 and the tensor can be quantized to zero
    76    // points and dequantized to 0.0.
    77    if (std::fabs(rmax - rmin) < std::numeric_limits<double>::epsilon()) {
    78      return UniformQuantizedType::getChecked(flags, storageType, expressedType,
    79                                              1.0, qmin, qmin, qmax, loc);
    80    }
    81  
    82    // Determine the scale.
    83    const double qminDouble = qmin;
    84    const double qmaxDouble = qmax;
    85    const double scale = (rmax - rmin) / (qmaxDouble - qminDouble);
    86  
    87    // Zero point computation.
    88    // In float, solve the affine equation for any known pair
    89    // (real value, corresponding quantized value), of which, two such pairs
    90    // are known: (rmin, qmin), (rmax, qmax).
    91    // The arithmetic error on the zero point computed from either pair will be
    92    // roughly machine_epsilon * (sum of absolute values of terms).
    93    // Use the variant that adds the smaller error.
    94    const double zeroPointFromMin = qminDouble - rmin / scale;
    95    const double zeroPointFromMinError =
    96        std::abs(qminDouble) + std::abs(rmin / scale);
    97    const double zeroPointFromMax = qmaxDouble - rmax / scale;
    98    const double zeroPointFromMaxError =
    99        std::abs(qmaxDouble) + std::abs(rmax / scale);
   100  
   101    const double zeroPointDouble = (zeroPointFromMinError < zeroPointFromMaxError)
   102                                       ? zeroPointFromMin
   103                                       : zeroPointFromMax;
   104  
   105    // Now nudge the zero point to be an integer.
   106    int64_t nudgedZeroPoint = 0;
   107    if (zeroPointDouble < qminDouble) {
   108      nudgedZeroPoint = qmin;
   109    } else if (zeroPointDouble > qmaxDouble) {
   110      nudgedZeroPoint = qmax;
   111    } else {
   112      nudgedZeroPoint = round(zeroPointDouble);
   113    }
   114  
   115    // By construction, the nudged zero point should always be in range.
   116    assert(nudgedZeroPoint >= qmin);
   117    assert(nudgedZeroPoint <= qmax);
   118  
   119    return UniformQuantizedType::getChecked(flags, storageType, expressedType,
   120                                            scale, nudgedZeroPoint, qmin, qmax,
   121                                            loc);
   122  }