github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Dialect/QuantOps/Utils/UniformSupport.cpp (about)

     1  //===- UniformSupport.cpp - Support utilities for uniform quant -----------===//
     2  //
     3  // Copyright 2019 The MLIR Authors.
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //   http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  // =============================================================================
    17  
    18  #include "mlir/Dialect/QuantOps/UniformSupport.h"
    19  #include "mlir/IR/StandardTypes.h"
    20  
    21  using namespace mlir;
    22  using namespace mlir::quant;
    23  
    24  static bool isQuantizablePrimitiveType(Type inputType) {
    25    return inputType.isa<FloatType>();
    26  }
    27  
    28  const ExpressedToUniformQuantizedConverter
    29  ExpressedToUniformQuantizedConverter::forInputType(Type inputType) {
    30    switch (inputType.getKind()) {
    31    default:
    32      if (isQuantizablePrimitiveType(inputType)) {
    33        // Supported primitive type (which just is the expressed type).
    34        return ExpressedToUniformQuantizedConverter{inputType, inputType};
    35      }
    36      // Unsupported.
    37      return ExpressedToUniformQuantizedConverter{inputType, nullptr};
    38    case StandardTypes::RankedTensor:
    39    case StandardTypes::UnrankedTensor:
    40    case StandardTypes::Vector: {
    41      Type elementType = inputType.cast<ShapedType>().getElementType();
    42      if (!isQuantizablePrimitiveType(elementType)) {
    43        // Unsupported.
    44        return ExpressedToUniformQuantizedConverter{inputType, nullptr};
    45      }
    46      return ExpressedToUniformQuantizedConverter{
    47          inputType, inputType.cast<ShapedType>().getElementType()};
    48    }
    49    }
    50  }
    51  
    52  Type ExpressedToUniformQuantizedConverter::convert(
    53      UniformQuantizedType elementalType) const {
    54    assert(expressedType && "convert() on unsupported conversion");
    55  
    56    switch (inputType.getKind()) {
    57    default:
    58      if (isQuantizablePrimitiveType(elementalType)) {
    59        // For primitives, just use the new elemental type.
    60        return elementalType;
    61      }
    62      // Unsupported.
    63      return nullptr;
    64    case StandardTypes::RankedTensor:
    65      return RankedTensorType::get(inputType.cast<RankedTensorType>().getShape(),
    66                                   elementalType);
    67    case StandardTypes::UnrankedTensor:
    68      return UnrankedTensorType::get(elementalType);
    69    case StandardTypes::Vector:
    70      return VectorType::get(inputType.cast<VectorType>().getShape(),
    71                             elementalType);
    72    }
    73  }