github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Dialect/FxpMathOps/Transforms/UniformKernelUtils.h (about) 1 //===- UniformKernelUtils.h - Utilities for lowering uniform math - C++ -*-===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 18 #ifndef MLIR_FXPMATH_UNIFORM_KERNEL_UTILS_H_ 19 #define MLIR_FXPMATH_UNIFORM_KERNEL_UTILS_H_ 20 21 #include "mlir/Dialect/QuantOps/QuantOps.h" 22 #include "mlir/Dialect/QuantOps/QuantTypes.h" 23 #include "mlir/Dialect/QuantOps/UniformSupport.h" 24 #include "mlir/IR/Operation.h" 25 26 #include <cmath> 27 28 namespace mlir { 29 namespace fxpmath { 30 namespace detail { 31 32 inline quant::UniformQuantizedType getUniformElementType(Type t) { 33 return quant::QuantizedType::getQuantizedElementType(t) 34 .dyn_cast_or_null<quant::UniformQuantizedType>(); 35 } 36 37 inline bool hasStorageBitWidth(quant::QuantizedType t, 38 llvm::ArrayRef<unsigned> checkWidths) { 39 unsigned w = t.getStorageType().getIntOrFloatBitWidth(); 40 for (unsigned checkWidth : checkWidths) { 41 if (w == checkWidth) 42 return true; 43 } 44 return false; 45 } 46 47 /// Computes the log2(x), rounded to an integral value. Returns whether 'x' can 48 /// be considered an exact integral value. 49 template <typename F> bool integralLog2(F x, int &log2Result) { 50 const F xLog2 = std::log(x) * (1.0 / std::log(2.0)); 51 const F xLog2Rounded = std::round(xLog2); 52 const F xLog2Frac = xLog2 - xLog2Rounded; 53 log2Result = static_cast<int>(xLog2Rounded); 54 // Allow small comparison slop below the level that would make a difference 55 // for 2^16 levels. 56 return std::abs(xLog2Frac) < 1e-6; 57 } 58 59 /// Helper class for operating on binary operations where all operands 60 /// and the result are a UniformQuantizedType. 61 struct UniformBinaryOpInfo { 62 UniformBinaryOpInfo(Operation *op, Value *lhs, Value *rhs, 63 Optional<APFloat> clampMin, Optional<APFloat> clampMax) 64 : op(op), lhs(lhs), rhs(rhs), clampMin(clampMin), clampMax(clampMax), 65 lhsType(getUniformElementType(lhs->getType())), 66 rhsType(getUniformElementType(rhs->getType())), 67 resultType(getUniformElementType(*op->result_type_begin())), 68 lhsStorageType(quant::QuantizedType::castToStorageType(lhs->getType())), 69 rhsStorageType(quant::QuantizedType::castToStorageType(rhs->getType())), 70 resultStorageType( 71 quant::QuantizedType::castToStorageType(*op->result_type_begin())) { 72 } 73 74 /// Returns whether this info is valid (all types defined, etc). 75 bool isValid() const { 76 return lhsType && rhsType && resultType && lhsStorageType && 77 rhsStorageType && resultStorageType; 78 } 79 80 /// Gets the final quantized result type of the result. 81 Type getQuantizedResultType() const { return *op->result_type_begin(); } 82 83 /// Returns whether the storage type of all operands is identical. 84 bool isSameStorageType() const { 85 return lhsType.getStorageType() == rhsType.getStorageType() && 86 lhsType.getStorageType() == resultType.getStorageType(); 87 } 88 89 /// Returns whether all operands and result are considered fixedpoint power 90 /// of two, setting the lhs, rhs, and result log2 scale references. 91 bool isFixedPointPOT(int &lhsLog2Scale, int &rhsLog2Scale, 92 int &resultLog2Scale) const { 93 if (!lhsType.isFixedPoint() || !rhsType.isFixedPoint() || 94 !resultType.isFixedPoint()) { 95 return false; 96 } 97 98 if (!integralLog2(lhsType.getScale(), lhsLog2Scale) || 99 !integralLog2(rhsType.getScale(), rhsLog2Scale) || 100 !integralLog2(resultType.getScale(), resultLog2Scale)) { 101 return false; 102 } 103 104 return true; 105 } 106 107 /// Gets the result integer clamp range given the result quantized type 108 // and any explicit clamp provided as attributes. 109 std::pair<IntegerAttr, IntegerAttr> getClampMinMax(IntegerType ty) const { 110 int64_t typeMin = resultType.getStorageTypeMin(); 111 int64_t typeMax = resultType.getStorageTypeMax(); 112 113 if (clampMin || clampMax) { 114 quant::UniformQuantizedValueConverter conv(resultType); 115 if (clampMin) { 116 typeMin = std::max(typeMin, conv.quantizeFloatToInt64(*clampMin)); 117 } 118 if (clampMax) { 119 typeMax = std::min(typeMax, conv.quantizeFloatToInt64(*clampMax)); 120 } 121 } 122 123 // The quantized, integral ops expect clamps as 32bit ints. 124 return { 125 IntegerAttr::get(ty, typeMin), 126 IntegerAttr::get(ty, typeMax), 127 }; 128 } 129 130 Operation *op; 131 Value *lhs; 132 Value *rhs; 133 Optional<APFloat> clampMin; 134 Optional<APFloat> clampMax; 135 136 // Element UniformQuantizedType for operands/result. 137 quant::UniformQuantizedType lhsType; 138 quant::UniformQuantizedType rhsType; 139 quant::UniformQuantizedType resultType; 140 141 // Full storage-based types. 142 Type lhsStorageType; 143 Type rhsStorageType; 144 Type resultStorageType; 145 }; 146 147 /// Derives a quantized multiplier and shift from a real valued multiplier 148 /// less than 1. 149 struct QuantizedMultiplierSmallerThanOneExp { 150 QuantizedMultiplierSmallerThanOneExp(double realMultiplier) { 151 assert(realMultiplier < 1.0); 152 assert(realMultiplier > 0.0); 153 154 const double q = std::frexp(realMultiplier, &exponent); 155 auto qFixed = static_cast<int64_t>(std::round(q * (1ll << 31))); 156 assert(qFixed <= (1ll << 31)); 157 if (qFixed == (1ll << 31)) { 158 qFixed /= 2; 159 ++exponent; 160 } 161 assert(qFixed <= std::numeric_limits<int32_t>::max()); 162 multiplier = static_cast<int32_t>(qFixed); 163 } 164 165 int32_t multiplier; 166 int exponent; 167 }; 168 169 /// Casts an integer or floating point based shaped type to a new element type. 170 inline Type castElementType(Type t, Type newElementType) { 171 if (auto st = t.dyn_cast<ShapedType>()) { 172 switch (st.getKind()) { 173 case StandardTypes::Kind::Vector: 174 return VectorType::get(st.getShape(), newElementType); 175 case StandardTypes::Kind::RankedTensor: 176 return RankedTensorType::get(st.getShape(), newElementType); 177 case StandardTypes::Kind::UnrankedTensor: 178 return UnrankedTensorType::get(newElementType); 179 case StandardTypes::Kind::MemRef: 180 return MemRefType::get(st.getShape(), newElementType, 181 st.cast<MemRefType>().getAffineMaps()); 182 } 183 } 184 assert(t.isIntOrFloat()); 185 return newElementType; 186 } 187 188 /// Creates an IntegerAttr with a type that matches the shape of 't' (which can 189 /// be a scalar primitive or a shaped type). 190 inline Attribute broadcastScalarConstIntValue(Type t, int64_t value) { 191 if (auto st = t.dyn_cast<ShapedType>()) { 192 assert(st.getElementType().isa<IntegerType>()); 193 return DenseElementsAttr::get(st, 194 IntegerAttr::get(st.getElementType(), value)); 195 } 196 197 auto integerType = t.cast<IntegerType>(); 198 assert(t.isa<IntegerType>() && "integer broadcast must be of integer type"); 199 return IntegerAttr::get(integerType, value); 200 } 201 202 /// Given an APFloat, converts it to the float semantics that matches the 203 /// given FloatType, silently ignoring inexact conversions. 204 inline APFloat convertFloatToType(FloatType ft, APFloat value) { 205 bool losesInfo; 206 auto status = value.convert(ft.getFloatSemantics(), 207 APFloat::rmNearestTiesToEven, &losesInfo); 208 (void)status; // unused in opt mode 209 assert((status & (APFloat::opDivByZero | APFloat::opInvalidOp)) == 0 && 210 "could not convert to float const"); 211 return value; 212 } 213 214 /// Creates a FloatAttr with a type that matches the shape of 't' (which can be 215 /// a scalar primitive or a shaped type). 216 inline Attribute broadcastScalarConstFloatValue(Type t, APFloat value) { 217 if (auto st = t.dyn_cast<ShapedType>()) { 218 FloatType floatElementType = st.getElementType().dyn_cast<FloatType>(); 219 assert(floatElementType && 220 "float broadcast element type must be float like"); 221 APFloat apValue = convertFloatToType(floatElementType, value); 222 return DenseElementsAttr::get(st, 223 FloatAttr::get(st.getElementType(), apValue)); 224 } else { 225 auto floatType = t.dyn_cast<FloatType>(); 226 assert(floatType && "float broadcast must be of float type"); 227 APFloat apValue = convertFloatToType(floatType, value); 228 return FloatAttr::get(floatType, apValue); 229 } 230 } 231 232 } // namespace detail 233 } // namespace fxpmath 234 } // namespace mlir 235 236 #endif // MLIR_FXPMATH_UNIFORM_KERNEL_UTILS_H_