github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h (about)

     1  //===- UniformKernelUtils.h - Utilities for lowering uniform math - C++ -*-===//
     2  //
     3  // Copyright 2019 The MLIR Authors.
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //   http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  // =============================================================================
    17  
    18  #ifndef MLIR_FXPMATH_UNIFORM_KERNEL_UTILS_H_
    19  #define MLIR_FXPMATH_UNIFORM_KERNEL_UTILS_H_
    20  
    21  #include "mlir/Dialect/QuantOps/QuantOps.h"
    22  #include "mlir/Dialect/QuantOps/QuantTypes.h"
    23  #include "mlir/Dialect/QuantOps/UniformSupport.h"
    24  #include "mlir/IR/Operation.h"
    25  
    26  #include <cmath>
    27  
    28  namespace mlir {
    29  namespace fxpmath {
    30  namespace detail {
    31  
    32  inline quant::UniformQuantizedType getUniformElementType(Type t) {
    33    return quant::QuantizedType::getQuantizedElementType(t)
    34        .dyn_cast_or_null<quant::UniformQuantizedType>();
    35  }
    36  
    37  inline bool hasStorageBitWidth(quant::QuantizedType t,
    38                                 llvm::ArrayRef<unsigned> checkWidths) {
    39    unsigned w = t.getStorageType().getIntOrFloatBitWidth();
    40    for (unsigned checkWidth : checkWidths) {
    41      if (w == checkWidth)
    42        return true;
    43    }
    44    return false;
    45  }
    46  
    47  /// Computes the log2(x), rounded to an integral value. Returns whether 'x' can
    48  /// be considered an exact integral value.
    49  template <typename F> bool integralLog2(F x, int &log2Result) {
    50    const F xLog2 = std::log(x) * (1.0 / std::log(2.0));
    51    const F xLog2Rounded = std::round(xLog2);
    52    const F xLog2Frac = xLog2 - xLog2Rounded;
    53    log2Result = static_cast<int>(xLog2Rounded);
    54    // Allow small comparison slop below the level that would make a difference
    55    // for 2^16 levels.
    56    return std::abs(xLog2Frac) < 1e-6;
    57  }
    58  
    59  /// Helper class for operating on binary operations where all operands
    60  /// and the result are a UniformQuantizedType.
    61  struct UniformBinaryOpInfo {
    62    UniformBinaryOpInfo(Operation *op, Value *lhs, Value *rhs,
    63                        Optional<APFloat> clampMin, Optional<APFloat> clampMax)
    64        : op(op), lhs(lhs), rhs(rhs), clampMin(clampMin), clampMax(clampMax),
    65          lhsType(getUniformElementType(lhs->getType())),
    66          rhsType(getUniformElementType(rhs->getType())),
    67          resultType(getUniformElementType(*op->result_type_begin())),
    68          lhsStorageType(quant::QuantizedType::castToStorageType(lhs->getType())),
    69          rhsStorageType(quant::QuantizedType::castToStorageType(rhs->getType())),
    70          resultStorageType(
    71              quant::QuantizedType::castToStorageType(*op->result_type_begin())) {
    72    }
    73  
    74    /// Returns whether this info is valid (all types defined, etc).
    75    bool isValid() const {
    76      return lhsType && rhsType && resultType && lhsStorageType &&
    77             rhsStorageType && resultStorageType;
    78    }
    79  
    80    /// Gets the final quantized result type of the result.
    81    Type getQuantizedResultType() const { return *op->result_type_begin(); }
    82  
    83    /// Returns whether the storage type of all operands is identical.
    84    bool isSameStorageType() const {
    85      return lhsType.getStorageType() == rhsType.getStorageType() &&
    86             lhsType.getStorageType() == resultType.getStorageType();
    87    }
    88  
    89    /// Returns whether all operands and result are considered fixedpoint power
    90    /// of two, setting the lhs, rhs, and result log2 scale references.
    91    bool isFixedPointPOT(int &lhsLog2Scale, int &rhsLog2Scale,
    92                         int &resultLog2Scale) const {
    93      if (!lhsType.isFixedPoint() || !rhsType.isFixedPoint() ||
    94          !resultType.isFixedPoint()) {
    95        return false;
    96      }
    97  
    98      if (!integralLog2(lhsType.getScale(), lhsLog2Scale) ||
    99          !integralLog2(rhsType.getScale(), rhsLog2Scale) ||
   100          !integralLog2(resultType.getScale(), resultLog2Scale)) {
   101        return false;
   102      }
   103  
   104      return true;
   105    }
   106  
   107    /// Gets the result integer clamp range given the result quantized type
   108    // and any explicit clamp provided as attributes.
   109    std::pair<IntegerAttr, IntegerAttr> getClampMinMax(IntegerType ty) const {
   110      int64_t typeMin = resultType.getStorageTypeMin();
   111      int64_t typeMax = resultType.getStorageTypeMax();
   112  
   113      if (clampMin || clampMax) {
   114        quant::UniformQuantizedValueConverter conv(resultType);
   115        if (clampMin) {
   116          typeMin = std::max(typeMin, conv.quantizeFloatToInt64(*clampMin));
   117        }
   118        if (clampMax) {
   119          typeMax = std::min(typeMax, conv.quantizeFloatToInt64(*clampMax));
   120        }
   121      }
   122  
   123      // The quantized, integral ops expect clamps as 32bit ints.
   124      return {
   125          IntegerAttr::get(ty, typeMin),
   126          IntegerAttr::get(ty, typeMax),
   127      };
   128    }
   129  
   130    Operation *op;
   131    Value *lhs;
   132    Value *rhs;
   133    Optional<APFloat> clampMin;
   134    Optional<APFloat> clampMax;
   135  
   136    // Element UniformQuantizedType for operands/result.
   137    quant::UniformQuantizedType lhsType;
   138    quant::UniformQuantizedType rhsType;
   139    quant::UniformQuantizedType resultType;
   140  
   141    // Full storage-based types.
   142    Type lhsStorageType;
   143    Type rhsStorageType;
   144    Type resultStorageType;
   145  };
   146  
   147  /// Derives a quantized multiplier and shift from a real valued multiplier
   148  /// less than 1.
   149  struct QuantizedMultiplierSmallerThanOneExp {
   150    QuantizedMultiplierSmallerThanOneExp(double realMultiplier) {
   151      assert(realMultiplier < 1.0);
   152      assert(realMultiplier > 0.0);
   153  
   154      const double q = std::frexp(realMultiplier, &exponent);
   155      auto qFixed = static_cast<int64_t>(std::round(q * (1ll << 31)));
   156      assert(qFixed <= (1ll << 31));
   157      if (qFixed == (1ll << 31)) {
   158        qFixed /= 2;
   159        ++exponent;
   160      }
   161      assert(qFixed <= std::numeric_limits<int32_t>::max());
   162      multiplier = static_cast<int32_t>(qFixed);
   163    }
   164  
   165    int32_t multiplier;
   166    int exponent;
   167  };
   168  
   169  /// Casts an integer or floating point based shaped type to a new element type.
   170  inline Type castElementType(Type t, Type newElementType) {
   171    if (auto st = t.dyn_cast<ShapedType>()) {
   172      switch (st.getKind()) {
   173      case StandardTypes::Kind::Vector:
   174        return VectorType::get(st.getShape(), newElementType);
   175      case StandardTypes::Kind::RankedTensor:
   176        return RankedTensorType::get(st.getShape(), newElementType);
   177      case StandardTypes::Kind::UnrankedTensor:
   178        return UnrankedTensorType::get(newElementType);
   179      case StandardTypes::Kind::MemRef:
   180        return MemRefType::get(st.getShape(), newElementType,
   181                               st.cast<MemRefType>().getAffineMaps());
   182      }
   183    }
   184    assert(t.isIntOrFloat());
   185    return newElementType;
   186  }
   187  
   188  /// Creates an IntegerAttr with a type that matches the shape of 't' (which can
   189  /// be a scalar primitive or a shaped type).
   190  inline Attribute broadcastScalarConstIntValue(Type t, int64_t value) {
   191    if (auto st = t.dyn_cast<ShapedType>()) {
   192      assert(st.getElementType().isa<IntegerType>());
   193      return DenseElementsAttr::get(st,
   194                                    IntegerAttr::get(st.getElementType(), value));
   195    }
   196  
   197    auto integerType = t.cast<IntegerType>();
   198    assert(t.isa<IntegerType>() && "integer broadcast must be of integer type");
   199    return IntegerAttr::get(integerType, value);
   200  }
   201  
   202  /// Given an APFloat, converts it to the float semantics that matches the
   203  /// given FloatType, silently ignoring inexact conversions.
   204  inline APFloat convertFloatToType(FloatType ft, APFloat value) {
   205    bool losesInfo;
   206    auto status = value.convert(ft.getFloatSemantics(),
   207                                APFloat::rmNearestTiesToEven, &losesInfo);
   208    (void)status; // unused in opt mode
   209    assert((status & (APFloat::opDivByZero | APFloat::opInvalidOp)) == 0 &&
   210           "could not convert to float const");
   211    return value;
   212  }
   213  
   214  /// Creates a FloatAttr with a type that matches the shape of 't' (which can be
   215  /// a scalar primitive or a shaped type).
   216  inline Attribute broadcastScalarConstFloatValue(Type t, APFloat value) {
   217    if (auto st = t.dyn_cast<ShapedType>()) {
   218      FloatType floatElementType = st.getElementType().dyn_cast<FloatType>();
   219      assert(floatElementType &&
   220             "float broadcast element type must be float like");
   221      APFloat apValue = convertFloatToType(floatElementType, value);
   222      return DenseElementsAttr::get(st,
   223                                    FloatAttr::get(st.getElementType(), apValue));
   224    } else {
   225      auto floatType = t.dyn_cast<FloatType>();
   226      assert(floatType && "float broadcast must be of float type");
   227      APFloat apValue = convertFloatToType(floatType, value);
   228      return FloatAttr::get(floatType, apValue);
   229    }
   230  }
   231  
   232  } // namespace detail
   233  } // namespace fxpmath
   234  } // namespace mlir
   235  
   236  #endif // MLIR_FXPMATH_UNIFORM_KERNEL_UTILS_H_