github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Quantizer/Configurations/FxpMathConfig.cpp

//===- FxpMathConfig.cpp - Reference fixed point config -------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file defines a TargetConfiguration for reference fixed-point math
// quantization scheme based on the FxpMathOps (plus a small category of
// extension ops that can be added from other dialects).
//
//===----------------------------------------------------------------------===//

#include "mlir/Quantizer/Configurations/FxpMathConfig.h"

#include "mlir/Dialect/FxpMathOps/FxpMathOps.h"
#include "mlir/Dialect/QuantOps/QuantOps.h"
#include "mlir/Dialect/QuantOps/QuantTypes.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Quantizer/Support/ConstraintAnalysisGraph.h"
#include "mlir/Quantizer/Support/Metadata.h"
#include "mlir/Quantizer/Support/Statistics.h"
#include "mlir/Quantizer/Support/UniformConstraints.h"

using namespace mlir;
using namespace mlir::quantizer;
using namespace mlir::fxpmath;
using namespace mlir::quant;
using namespace std::placeholders;

namespace {

struct FxpMathTargetConfigImpl : public FxpMathTargetConfig {
  FxpMathTargetConfigImpl(SolverContext &context)
      : FxpMathTargetConfig(context) {
    Builder b(&context.getMlirContext());
    IntegerType i8Type = b.getIntegerType(8);
    IntegerType i16Type = b.getIntegerType(16);
    IntegerType i32Type = b.getIntegerType(32);

    q8 = addCandidateType(
        AnyQuantizedType::get(QuantizationFlags::Signed, i8Type, nullptr,
                              std::numeric_limits<int8_t>::min(),
                              std::numeric_limits<int8_t>::max()),
        CandidateQuantizedType::Scheme::UniformPerLayer);
    q16 = addCandidateType(
        AnyQuantizedType::get(QuantizationFlags::Signed, i16Type, nullptr,
                              std::numeric_limits<int16_t>::min(),
                              std::numeric_limits<int16_t>::max()),
        CandidateQuantizedType::Scheme::UniformPerLayer);
    q32ExplicitFixedPoint = addCandidateType(
        AnyQuantizedType::get(QuantizationFlags::Signed, i32Type, nullptr,
                              std::numeric_limits<int32_t>::min(),
                              std::numeric_limits<int32_t>::max()),
        CandidateQuantizedType::Scheme::UniformExplicitFixedPointScale);

    // Op handlers.
    addOpHandler<ConstantOp>(
        std::bind(&FxpMathTargetConfigImpl::handleConstant, this, _1, _2));
    addOpHandler<ReturnOp>(
        std::bind(&FxpMathTargetConfigImpl::handleTerminal, this, _1, _2));
    addOpHandler<quant::StatisticsOp>(
        std::bind(&FxpMathTargetConfigImpl::handleStats, this, _1, _2));

    // FxpMathOps.
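    // Each handler registered below anchors the op's operands/result in the
    // constraint analysis graph, restricts them to the candidate quantized
    // types the corresponding kernel supports, and adds coupling/clamp
    // constraints as needed.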
    addOpHandler<RealAddEwOp>(
        std::bind(&FxpMathTargetConfigImpl::handleAdd, this, _1, _2));
    addOpHandler<RealMulEwOp>(
        std::bind(&FxpMathTargetConfigImpl::handleMul, this, _1, _2));
    addOpHandler<RealMatMulOp>(
        std::bind(&FxpMathTargetConfigImpl::handleMatMul, this, _1, _2));
    addOpHandler<RealMatMulBiasOp>(
        std::bind(&FxpMathTargetConfigImpl::handleMatMulBias, this, _1, _2));

    // Require stats ops.
    addRequireStatsOp<RealAddEwOp>();
    addRequireStatsOp<RealSubEwOp>();
    addRequireStatsOp<RealDivEwOp>();
    addRequireStatsOp<RealMulEwOp>();
    addRequireStatsOp<RealMatMulOp>();
    addRequireStatsOp<RealMatMulBiasOp>();
  }

  bool isHandledType(Type t) const final {
    if (t.isa<FloatType>())
      return true;
    return (t.isa<VectorType>() || t.isa<TensorType>()) &&
           t.cast<ShapedType>().getElementType().isa<FloatType>();
  }

  void finalizeAnchors(CAGSlice &cag) const override {
    cag.enumerateImpliedConnections(
        [&](CAGAnchorNode *from, CAGAnchorNode *to) {
          UniformConstraintsBuilder(cag).coupleAnchors(from, to);
        });
  }

  void addValueIdentityOpByName(StringRef opName) override {
    addOpHandlerByName(
        opName,
        std::bind(&FxpMathTargetConfigImpl::handleValueIdentity, this, _1, _2));
  }

  void handleValueIdentity(Operation *op, CAGSlice &cag) const {
    assert(op->getNumResults() == 1);
    if (!isHandledType(op->getResult(0)->getType()))
      return;

    auto resultNode = cag.getResultAnchor(op, 0);
    resultNode->setTypeTransformRule(
        CAGAnchorNode::TypeTransformRule::DirectStorage);

    for (unsigned opIdx = 0, e = op->getNumOperands(); opIdx < e; ++opIdx) {
      if (!isHandledType(op->getOperand(opIdx)->getType()))
        continue;
      auto operandNode = cag.getOperandAnchor(op, opIdx);
      operandNode->setTypeTransformRule(
          CAGAnchorNode::TypeTransformRule::DirectStorage);
      UniformConstraintsBuilder(cag).coupleAnchors(operandNode, resultNode);
    }
  }

  void handleConstant(Operation *op, CAGSlice &cag) const {
    if (!isHandledType(op->getResult(0)->getType()))
      return;

    auto resultNode = cag.getResultAnchor(op, 0);
    resultNode->setTypeTransformRule(
        CAGAnchorNode::TypeTransformRule::ExpressedOnly);
    Attribute valueAttr;
    if (!matchPattern(op, m_Constant(&valueAttr))) {
      return;
    }

    AttributeTensorStatistics stats(valueAttr);
    TensorAxisStatistics layerStats;
    if (!stats.get(layerStats)) {
      op->emitOpError("could not compute statistics");
      return;
    }

    UniformConstraintsBuilder(cag).applyStats(resultNode, layerStats);
  }

  void handleTerminal(Operation *op, CAGSlice &cag) const {
    if (!isHandledType(op->getOperand(0)->getType()))
      return;
    auto operandNode = cag.getOperandAnchor(op, 0);
    operandNode->setTypeTransformRule(
        CAGAnchorNode::TypeTransformRule::ExpressedOnly);
  }

  void handleStats(Operation *op, CAGSlice &cag) const {
    if (!isHandledType(op->getResult(0)->getType()))
      return;

    auto argNode = cag.getOperandAnchor(op, 0);
    auto resultNode = cag.getResultAnchor(op, 0);
    UniformConstraintsBuilder(cag).coupleAnchors(argNode, resultNode);

    TensorAxisStatistics layerStats;
    auto statsOp = cast<quant::StatisticsOp>(op);
    auto layerStatsAttr = statsOp.layerStats();
    layerStats.minValue =
        layerStatsAttr.getValue<FloatAttr>(0).getValueAsDouble();
    layerStats.maxValue =
        layerStatsAttr.getValue<FloatAttr>(1).getValueAsDouble();
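    // Apply the recorded min/max layer statistics to the result anchor so the
    // quantizer can later derive uniform quantization parameters from them.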
    UniformConstraintsBuilder(cag).applyStats(resultNode, layerStats);
  }

  void handleAdd(Operation *op, CAGSlice &cag) const {
    if (!isHandledType(op->getResult(0)->getType()))
      return;

    auto lhs = cag.getOperandAnchor(op, 0);
    auto rhs = cag.getOperandAnchor(op, 1);
    auto resultNode = cag.getResultAnchor(op, 0);
    // Add supports 8/16 bit math.
    llvm::SmallBitVector disableMask =
        getCandidateTypeDisabledExceptMask({q8, q16});
    lhs->getUniformMetadata().disabledCandidateTypes = disableMask;
    rhs->getUniformMetadata().disabledCandidateTypes = disableMask;
    resultNode->getUniformMetadata().disabledCandidateTypes = disableMask;
    // NOTE: We couple the add such that the scale/zeroPoint match between
    // both args and the result. This is overly constrained in that it is
    // possible to write efficient add kernels with a bit more freedom (i.e.
    // zeroPoints can vary, scales can differ by a power of two, etc).
    // However, full coupling yields the simplest solutions on the fast path.
    // Further efficiency can be had by constraining the zeroPoint to 0, but
    // there isn't a constraint for this yet (and there are tradeoffs).
    UniformConstraintsBuilder(cag).coupleAnchors(lhs, resultNode);
    UniformConstraintsBuilder(cag).coupleAnchors(rhs, resultNode);
    addRealMathOptionalConstraints(op, resultNode, cag);
  }

  void handleMul(Operation *op, CAGSlice &cag) const {
    if (!isHandledType(op->getResult(0)->getType()))
      return;

    auto lhs = cag.getOperandAnchor(op, 0);
    auto rhs = cag.getOperandAnchor(op, 1);
    auto resultNode = cag.getResultAnchor(op, 0);
    // Mul supports 8/16 bit math.
    llvm::SmallBitVector disableMask =
        getCandidateTypeDisabledExceptMask({q8, q16});
    lhs->getUniformMetadata().disabledCandidateTypes = disableMask;
    rhs->getUniformMetadata().disabledCandidateTypes = disableMask;
    resultNode->getUniformMetadata().disabledCandidateTypes = disableMask;
    addRealMathOptionalConstraints(op, resultNode, cag);
  }

  void handleMatMul(Operation *op, CAGSlice &cag) const {
    if (!isHandledType(op->getResult(0)->getType()))
      return;

    auto lhs = cag.getOperandAnchor(op, 0);
    auto rhs = cag.getOperandAnchor(op, 1);
    auto resultNode = cag.getResultAnchor(op, 0);
    // MatMul supports 8/16 bit math.
    llvm::SmallBitVector disableMask =
        getCandidateTypeDisabledExceptMask({q8, q16});
    lhs->getUniformMetadata().disabledCandidateTypes = disableMask;
    rhs->getUniformMetadata().disabledCandidateTypes = disableMask;
    resultNode->getUniformMetadata().disabledCandidateTypes = disableMask;
    addRealMathOptionalConstraints(op, resultNode, cag);
  }

  void handleMatMulBias(Operation *op, CAGSlice &cag) const {
    if (!isHandledType(op->getResult(0)->getType()))
      return;

    auto lhs = cag.getOperandAnchor(op, 0);
    auto rhs = cag.getOperandAnchor(op, 1);
    auto bias = cag.getOperandAnchor(op, 2);
    bias->getUniformMetadata().disabledCandidateTypes =
        getCandidateTypeDisabledExceptMask({q32ExplicitFixedPoint});

    auto resultNode = cag.getResultAnchor(op, 0);
    UniformConstraintsBuilder(cag).propagateExplicitScale(resultNode, bias);

    // MatMul supports 8/16 bit math.
    llvm::SmallBitVector disableMask =
        getCandidateTypeDisabledExceptMask({q8, q16});
    lhs->getUniformMetadata().disabledCandidateTypes = disableMask;
    rhs->getUniformMetadata().disabledCandidateTypes = disableMask;
    resultNode->getUniformMetadata().disabledCandidateTypes = disableMask;
    addRealMathOptionalConstraints(op, resultNode, cag);
  }

  void addRealMathOptionalConstraints(Operation *op, CAGAnchorNode *anchor,
                                      CAGSlice &cag) const {
    // TODO: It would be nice if these all extended some base trait instead
    // of requiring name lookup.
    auto clampMinAttr = op->getAttrOfType<FloatAttr>("clamp_min");
    auto clampMaxAttr = op->getAttrOfType<FloatAttr>("clamp_max");

    if (clampMinAttr || clampMaxAttr) {
      auto nan = APFloat::getQNaN(APFloat::IEEEdouble());
      auto clampMin = clampMinAttr ? clampMinAttr.getValue() : nan;
      auto clampMax = clampMaxAttr ? clampMaxAttr.getValue() : nan;
      UniformConstraintsBuilder(cag).clamp(anchor, clampMin, clampMax);
    }
  }

  unsigned q8;
  unsigned q16;
  unsigned q32ExplicitFixedPoint;
};

} // anonymous namespace

std::unique_ptr<FxpMathTargetConfig>
FxpMathTargetConfig::create(SolverContext &context) {
  return std::make_unique<FxpMathTargetConfigImpl>(context);
}
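
// A minimal usage sketch, kept in comment form since it is not part of the
// upstream file. The `module` variable and the exact SolverContext
// constructor are assumptions; only FxpMathTargetConfig::create and the
// member functions shown in this file are taken from the source. A quantizer
// driver would typically do something like:
//
//   SolverContext solverContext(*module.getContext());
//   std::unique_ptr<FxpMathTargetConfig> config =
//       FxpMathTargetConfig::create(solverContext);
//   // The CAG construction/solving passes then consult the config, e.g.
//   // config->isHandledType(...) and config->finalizeAnchors(cagSlice).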