github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Quantizer/Configurations/FxpMathConfig.cpp

//===- FxpMathConfig.cpp - Reference fixed point config -------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file defines a TargetConfiguration for a reference fixed-point math
// quantization scheme based on the FxpMathOps dialect (plus a small category
// of extension ops that can be added from other dialects).
//
//===----------------------------------------------------------------------===//

#include "mlir/Quantizer/Configurations/FxpMathConfig.h"

#include "mlir/Dialect/FxpMathOps/FxpMathOps.h"
#include "mlir/Dialect/QuantOps/QuantOps.h"
#include "mlir/Dialect/QuantOps/QuantTypes.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Quantizer/Support/ConstraintAnalysisGraph.h"
#include "mlir/Quantizer/Support/Metadata.h"
#include "mlir/Quantizer/Support/Statistics.h"
#include "mlir/Quantizer/Support/UniformConstraints.h"

using namespace mlir;
using namespace mlir::quantizer;
using namespace mlir::fxpmath;
using namespace mlir::quant;
using namespace std::placeholders;

namespace {

struct FxpMathTargetConfigImpl : public FxpMathTargetConfig {
  FxpMathTargetConfigImpl(SolverContext &context)
      : FxpMathTargetConfig(context) {
    Builder b(&context.getMlirContext());
    IntegerType i8Type = b.getIntegerType(8);
    IntegerType i16Type = b.getIntegerType(16);
    IntegerType i32Type = b.getIntegerType(32);

    q8 = addCandidateType(
        AnyQuantizedType::get(QuantizationFlags::Signed, i8Type, nullptr,
                              std::numeric_limits<int8_t>::min(),
                              std::numeric_limits<int8_t>::max()),
        CandidateQuantizedType::Scheme::UniformPerLayer);
    q16 = addCandidateType(
        AnyQuantizedType::get(QuantizationFlags::Signed, i16Type, nullptr,
                              std::numeric_limits<int16_t>::min(),
                              std::numeric_limits<int16_t>::max()),
        CandidateQuantizedType::Scheme::UniformPerLayer);
    q32ExplicitFixedPoint = addCandidateType(
        AnyQuantizedType::get(QuantizationFlags::Signed, i32Type, nullptr,
                              std::numeric_limits<int32_t>::min(),
                              std::numeric_limits<int32_t>::max()),
        CandidateQuantizedType::Scheme::UniformExplicitFixedPointScale);

    // Op handlers.
    addOpHandler<ConstantOp>(
        std::bind(&FxpMathTargetConfigImpl::handleConstant, this, _1, _2));
    addOpHandler<ReturnOp>(
        std::bind(&FxpMathTargetConfigImpl::handleTerminal, this, _1, _2));
    addOpHandler<quant::StatisticsOp>(
        std::bind(&FxpMathTargetConfigImpl::handleStats, this, _1, _2));

    // FxpMathOps.
    addOpHandler<RealAddEwOp>(
        std::bind(&FxpMathTargetConfigImpl::handleAdd, this, _1, _2));
    addOpHandler<RealMulEwOp>(
        std::bind(&FxpMathTargetConfigImpl::handleMul, this, _1, _2));
    addOpHandler<RealMatMulOp>(
        std::bind(&FxpMathTargetConfigImpl::handleMatMul, this, _1, _2));
    addOpHandler<RealMatMulBiasOp>(
        std::bind(&FxpMathTargetConfigImpl::handleMatMulBias, this, _1, _2));

    // Require stats ops.
    addRequireStatsOp<RealAddEwOp>();
    addRequireStatsOp<RealSubEwOp>();
    addRequireStatsOp<RealDivEwOp>();
    addRequireStatsOp<RealMulEwOp>();
    addRequireStatsOp<RealMatMulOp>();
    addRequireStatsOp<RealMatMulBiasOp>();
  }
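  // The candidates registered above: q8 and q16 are signed 8/16-bit per-layer
  // uniform types (storage ranges [-128, 127] and [-32768, 32767]); the
  // 32-bit q32ExplicitFixedPoint candidate carries an explicitly constrained
  // fixed-point scale and is used for bias operands in the handlers below.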

  bool isHandledType(Type t) const final {
    if (t.isa<FloatType>())
      return true;
    return (t.isa<VectorType>() || t.isa<TensorType>()) &&
           t.cast<ShapedType>().getElementType().isa<FloatType>();
  }
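  // Illustration of the predicate above: scalar floats (e.g. f32) and shaped
  // float containers such as tensor<4xf32> or vector<8xf16> are handled;
  // integer-element types such as tensor<4xi32> are ignored by this config.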

  void finalizeAnchors(CAGSlice &cag) const override {
    cag.enumerateImpliedConnections(
        [&](CAGAnchorNode *from, CAGAnchorNode *to) {
          UniformConstraintsBuilder(cag).coupleAnchors(from, to);
        });
  }

  void addValueIdentityOpByName(StringRef opName) override {
    addOpHandlerByName(
        opName,
        std::bind(&FxpMathTargetConfigImpl::handleValueIdentity, this, _1, _2));
  }

  void handleValueIdentity(Operation *op, CAGSlice &cag) const {
    assert(op->getNumResults() == 1);
    if (!isHandledType(op->getResult(0)->getType()))
      return;

    auto resultNode = cag.getResultAnchor(op, 0);
    resultNode->setTypeTransformRule(
        CAGAnchorNode::TypeTransformRule::DirectStorage);

    for (unsigned opIdx = 0, e = op->getNumOperands(); opIdx < e; ++opIdx) {
      if (!isHandledType(op->getOperand(opIdx)->getType()))
        continue;
      auto operandNode = cag.getOperandAnchor(op, opIdx);
      operandNode->setTypeTransformRule(
          CAGAnchorNode::TypeTransformRule::DirectStorage);
      UniformConstraintsBuilder(cag).coupleAnchors(operandNode, resultNode);
    }
  }

  void handleConstant(Operation *op, CAGSlice &cag) const {
    if (!isHandledType(op->getResult(0)->getType()))
      return;

    auto resultNode = cag.getResultAnchor(op, 0);
    resultNode->setTypeTransformRule(
        CAGAnchorNode::TypeTransformRule::ExpressedOnly);
    Attribute valueAttr;
    if (!matchPattern(op, m_Constant(&valueAttr))) {
      return;
    }

    AttributeTensorStatistics stats(valueAttr);
    TensorAxisStatistics layerStats;
    if (!stats.get(layerStats)) {
      op->emitOpError("could not compute statistics");
      return;
    }

    UniformConstraintsBuilder(cag).applyStats(resultNode, layerStats);
  }

  void handleTerminal(Operation *op, CAGSlice &cag) const {
    if (!isHandledType(op->getOperand(0)->getType()))
      return;
    auto operandNode = cag.getOperandAnchor(op, 0);
    operandNode->setTypeTransformRule(
        CAGAnchorNode::TypeTransformRule::ExpressedOnly);
  }

  void handleStats(Operation *op, CAGSlice &cag) const {
    if (!isHandledType(op->getResult(0)->getType()))
      return;

    auto argNode = cag.getOperandAnchor(op, 0);
    auto resultNode = cag.getResultAnchor(op, 0);
    UniformConstraintsBuilder(cag).coupleAnchors(argNode, resultNode);

    TensorAxisStatistics layerStats;
    auto statsOp = cast<quant::StatisticsOp>(op);
    auto layerStatsAttr = statsOp.layerStats();
    layerStats.minValue =
        layerStatsAttr.getValue<FloatAttr>(0).getValueAsDouble();
    layerStats.maxValue =
        layerStatsAttr.getValue<FloatAttr>(1).getValueAsDouble();
    UniformConstraintsBuilder(cag).applyStats(resultNode, layerStats);
  }
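  // Sketch of the op handleStats consumes (IR shape approximate, for
  // illustration only):
  //   %1 = "quant.stats"(%0) {
  //     layerStats = dense<[-1.0, 1.0]> : tensor<2xf64>
  //   } : (tensor<8xf32>) -> tensor<8xf32>
  // The two layerStats entries are read as minValue/maxValue and applied to
  // the coupled result anchor.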

  void handleAdd(Operation *op, CAGSlice &cag) const {
    if (!isHandledType(op->getResult(0)->getType()))
      return;

    auto lhs = cag.getOperandAnchor(op, 0);
    auto rhs = cag.getOperandAnchor(op, 1);
    auto resultNode = cag.getResultAnchor(op, 0);
    // Add supports 8/16 bit math.
    llvm::SmallBitVector disableMask =
        getCandidateTypeDisabledExceptMask({q8, q16});
    lhs->getUniformMetadata().disabledCandidateTypes = disableMask;
    rhs->getUniformMetadata().disabledCandidateTypes = disableMask;
    resultNode->getUniformMetadata().disabledCandidateTypes = disableMask;
    // NOTE: We couple the add such that the scale/zeroPoint match between
    // both args and the result. This is overly constrained in that it is
    // possible to write efficient add kernels with a bit more freedom (e.g.
    // zeroPoints can vary, scales can differ by a power of two, etc.).
    // However, fully coupled yields the simplest solutions on the fast path.
    // Further efficiency can be had by constraining the zeroPoint to 0, but
    // there isn't a constraint for this yet (and there are tradeoffs).
    UniformConstraintsBuilder(cag).coupleAnchors(lhs, resultNode);
    UniformConstraintsBuilder(cag).coupleAnchors(rhs, resultNode);
    addRealMathOptionalConstraints(op, resultNode, cag);
  }
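  // Worked example of why full coupling keeps the add kernel simple: with a
  // shared scale s and zero point z on both operands and the result,
  //   real_sum = s * (qa - z) + s * (qb - z) = s * ((qa + qb - z) - z),
  // so the quantized result is just qa + qb - z (plus saturation), with no
  // per-operand rescaling.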

  void handleMul(Operation *op, CAGSlice &cag) const {
    if (!isHandledType(op->getResult(0)->getType()))
      return;

    auto lhs = cag.getOperandAnchor(op, 0);
    auto rhs = cag.getOperandAnchor(op, 1);
    auto resultNode = cag.getResultAnchor(op, 0);
    // Mul supports 8/16 bit math.
    llvm::SmallBitVector disableMask =
        getCandidateTypeDisabledExceptMask({q8, q16});
    lhs->getUniformMetadata().disabledCandidateTypes = disableMask;
    rhs->getUniformMetadata().disabledCandidateTypes = disableMask;
    resultNode->getUniformMetadata().disabledCandidateTypes = disableMask;
    addRealMathOptionalConstraints(op, resultNode, cag);
  }

  void handleMatMul(Operation *op, CAGSlice &cag) const {
    if (!isHandledType(op->getResult(0)->getType()))
      return;

    auto lhs = cag.getOperandAnchor(op, 0);
    auto rhs = cag.getOperandAnchor(op, 1);
    auto resultNode = cag.getResultAnchor(op, 0);
    // MatMul supports 8/16 bit math.
    llvm::SmallBitVector disableMask =
        getCandidateTypeDisabledExceptMask({q8, q16});
    lhs->getUniformMetadata().disabledCandidateTypes = disableMask;
    rhs->getUniformMetadata().disabledCandidateTypes = disableMask;
    resultNode->getUniformMetadata().disabledCandidateTypes = disableMask;
    addRealMathOptionalConstraints(op, resultNode, cag);
  }

  void handleMatMulBias(Operation *op, CAGSlice &cag) const {
    if (!isHandledType(op->getResult(0)->getType()))
      return;

    auto lhs = cag.getOperandAnchor(op, 0);
    auto rhs = cag.getOperandAnchor(op, 1);
    auto bias = cag.getOperandAnchor(op, 2);
    bias->getUniformMetadata().disabledCandidateTypes =
        getCandidateTypeDisabledExceptMask({q32ExplicitFixedPoint});

    auto resultNode = cag.getResultAnchor(op, 0);
    UniformConstraintsBuilder(cag).propagateExplicitScale(resultNode, bias);

    // MatMul supports 8/16 bit math.
    llvm::SmallBitVector disableMask =
        getCandidateTypeDisabledExceptMask({q8, q16});
    lhs->getUniformMetadata().disabledCandidateTypes = disableMask;
    rhs->getUniformMetadata().disabledCandidateTypes = disableMask;
    resultNode->getUniformMetadata().disabledCandidateTypes = disableMask;
    addRealMathOptionalConstraints(op, resultNode, cag);
  }
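  // Note on the bias handling above: the bias operand is limited to the
  // 32-bit explicit fixed-point candidate and receives the result's explicit
  // scale via propagateExplicitScale. In typical fixed-point matmul lowerings
  // (a general observation, not something this file specifies) the i32
  // accumulator carries scale lhsScale * rhsScale, so pinning the bias to a
  // compatible explicit scale lets it be added directly into the accumulator
  // without an extra rescale.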

  void addRealMathOptionalConstraints(Operation *op, CAGAnchorNode *anchor,
                                      CAGSlice &cag) const {
    // TODO: It would be nice if these all extended some base trait instead
    // of requiring name lookup.
    auto clampMinAttr = op->getAttrOfType<FloatAttr>("clamp_min");
    auto clampMaxAttr = op->getAttrOfType<FloatAttr>("clamp_max");

    if (clampMinAttr || clampMaxAttr) {
      auto nan = APFloat::getQNaN(APFloat::IEEEdouble());
      auto clampMin = clampMinAttr ? clampMinAttr.getValue() : nan;
      auto clampMax = clampMaxAttr ? clampMaxAttr.getValue() : nan;
      UniformConstraintsBuilder(cag).clamp(anchor, clampMin, clampMax);
    }
  }

  unsigned q8;
  unsigned q16;
  unsigned q32ExplicitFixedPoint;
};

} // anonymous namespace

std::unique_ptr<FxpMathTargetConfig>
FxpMathTargetConfig::create(SolverContext &context) {
  return std::make_unique<FxpMathTargetConfigImpl>(context);
}
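
// Usage sketch (illustrative only; the SolverContext construction shown here
// is an assumption about the surrounding Quantizer APIs, not defined in this
// file, and the op name is hypothetical):
//
//   MLIRContext mlirContext;
//   SolverContext solverContext(mlirContext);
//   std::unique_ptr<FxpMathTargetConfig> config =
//       FxpMathTargetConfig::create(solverContext);
//   // Treat a shape-only op from another dialect as a value identity.
//   config->addValueIdentityOpByName("some_dialect.reshape");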