github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp (about)

     1  //===- QuantizeUtils.cpp - Support utilities for quantization -------------===//
     2  //
     3  // Copyright 2019 The MLIR Authors.
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //   http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  // =============================================================================
    17  
    18  #include "mlir/Dialect/QuantOps/QuantizeUtils.h"
    19  #include "mlir/Dialect/QuantOps/UniformSupport.h"
    20  #include "mlir/IR/Attributes.h"
    21  #include "mlir/IR/StandardTypes.h"
    22  
    23  namespace mlir {
    24  namespace quant {
    25  /// Converts a possible primitive, real expressed value attribute to a
    26  /// corresponding storage attribute (typically FloatAttr -> IntegerAttr).
    27  /// quantizedElementType is the QuantizedType that describes the expressed
    28  /// origValue.
    29  /// Returns a converter Attribute or nullptr if conversion is not possible.
    30  static Attribute convertPrimitiveValueAttr(
    31      Attribute origRealValue, QuantizedType quantizedElementType,
    32      const UniformQuantizedValueConverter &converter, Type &outConvertedType) {
    33    if (origRealValue.isa<FloatAttr>()) {
    34      FloatAttr floatAttr = origRealValue.cast<FloatAttr>();
    35      outConvertedType = quantizedElementType.getStorageType();
    36      return IntegerAttr::get(quantizedElementType.getStorageType(),
    37                              converter.quantizeFloatToInt(floatAttr.getValue()));
    38    }
    39  
    40    return nullptr;
    41  }
    42  
    43  /// Converts a real expressed DenseFPElementsAttr to a corresponding
    44  /// DenseElementsAttr (typically DenseIntElementsAttr) containing quantized
    45  /// storage values assuming the given quantizedElementType and converter.
    46  static DenseElementsAttr
    47  convertDenseFPElementsAttr(DenseFPElementsAttr realFPElementsAttr,
    48                             QuantizedType quantizedElementType,
    49                             const UniformQuantizedValueConverter &converter) {
    50    // Convert to corresponding quantized value attributes.
    51    SmallVector<APInt, 8> quantValues;
    52    if (realFPElementsAttr.isSplat()) {
    53      quantValues.push_back(
    54          converter.quantizeFloatToInt(*realFPElementsAttr.begin()));
    55    } else {
    56      quantValues.reserve(realFPElementsAttr.getNumElements());
    57      for (APFloat realVal : realFPElementsAttr) {
    58        quantValues.push_back(converter.quantizeFloatToInt(realVal));
    59      }
    60    }
    61  
    62    // Cast from an expressed-type-based type to storage-type-based type,
    63    // preserving the dense shape (i.e. tensor<4xf32> -> tensor<4xi8>).
    64    ShapedType newDenseType =
    65        quantizedElementType
    66            .castExpressedToStorageType(realFPElementsAttr.getType())
    67            .dyn_cast_or_null<ShapedType>();
    68    if (!newDenseType) {
    69      return nullptr;
    70    }
    71    return DenseIntElementsAttr::get(newDenseType, quantValues);
    72  }
    73  
    74  /// Converts a real expressed SplatElementsAttr to a corresponding
    75  /// SplatElementsAttr containing quantized storage values assuming the given
    76  /// quantizedElementType and converter.
    77  static SparseElementsAttr
    78  convertSparseElementsAttr(SparseElementsAttr realSparseAttr,
    79                            QuantizedType quantizedElementType,
    80                            const UniformQuantizedValueConverter &converter) {
    81    DenseElementsAttr realDenseAttr = realSparseAttr.getValues();
    82    if (!realDenseAttr.isa<DenseFPElementsAttr>()) {
    83      return nullptr;
    84    }
    85    DenseElementsAttr quantDenseAttr =
    86        convertDenseFPElementsAttr(realDenseAttr.cast<DenseFPElementsAttr>(),
    87                                   quantizedElementType, converter);
    88    if (!quantDenseAttr) {
    89      return nullptr;
    90    }
    91  
    92    // Cast from an expressed-type-based type to storage-type-based type,
    93    // preserving the sparse shape (i.e. tensor<4xf32> -> tensor<4xi8>).
    94    ShapedType newSparseType =
    95        quantizedElementType.castExpressedToStorageType(realSparseAttr.getType())
    96            .dyn_cast_or_null<ShapedType>();
    97    if (!newSparseType) {
    98      return nullptr;
    99    }
   100    return SparseElementsAttr::get(newSparseType, realSparseAttr.getIndices(),
   101                                   quantDenseAttr);
   102  }
   103  
   104  /// Converts a real expressed Attribute to a corresponding Attribute containing
   105  /// quantized storage values assuming the given uniform quantizedElementType and
   106  /// converter.
   107  Attribute quantizeAttrUniform(Attribute realValue,
   108                                UniformQuantizedType quantizedElementType,
   109                                const UniformQuantizedValueConverter &converter,
   110                                Type &outConvertedType) {
   111    // Fork to handle different variants of constants supported.
   112    if (realValue.isa<DenseFPElementsAttr>()) {
   113      // Dense tensor or vector constant.
   114      auto converted = convertDenseFPElementsAttr(
   115          realValue.cast<DenseFPElementsAttr>(), quantizedElementType, converter);
   116      outConvertedType = converted.getType();
   117      return converted;
   118    } else if (realValue.isa<SparseElementsAttr>()) {
   119      // Sparse tensor or vector constant.
   120      auto converted = convertSparseElementsAttr(
   121          realValue.cast<SparseElementsAttr>(), quantizedElementType, converter);
   122      outConvertedType = converted.getType();
   123      return converted;
   124    } else {
   125      // Nothing else matched: try to convert a primitive.
   126      return convertPrimitiveValueAttr(realValue, quantizedElementType, converter,
   127                                       outConvertedType);
   128    }
   129  }
   130  
   131  /// Convert an attribute from a type based on
   132  /// quantizedElementType.getExpressedType() to one based on
   133  /// quantizedElementType.getStorageType().
   134  /// Returns nullptr if the conversion is not supported.
   135  /// On success, stores the converted type in outConvertedType.
   136  Attribute quantizeAttr(Attribute realValue, QuantizedType quantizedElementType,
   137                         Type &outConvertedType) {
   138    // Hard-coded to just support UniformQuantizedType. This will need to
   139    // be generalized when there is more than one.
   140    auto uniformQuantizedType =
   141        quantizedElementType.dyn_cast<UniformQuantizedType>();
   142    if (!uniformQuantizedType) {
   143      return nullptr;
   144    }
   145    UniformQuantizedValueConverter converter(uniformQuantizedType);
   146    return quantizeAttrUniform(realValue, uniformQuantizedType, converter,
   147                               outConvertedType);
   148  }
   149  
   150  } // namespace quant
   151  } // namespace mlir