github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp

//===- LowerUniformRealMath.cpp -------------------------------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

#include "UniformKernelUtils.h"

#include "mlir/Dialect/FxpMathOps/FxpMathOps.h"
#include "mlir/Dialect/FxpMathOps/Passes.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;
using namespace mlir::fxpmath;
using namespace mlir::fxpmath::detail;
using namespace mlir::quant;

namespace {

struct LowerUniformRealMathPass
    : public FunctionPass<LowerUniformRealMathPass> {
  void runOnFunction() override;
};

struct LowerUniformCastsPass : public FunctionPass<LowerUniformCastsPass> {
  void runOnFunction() override;
};

} // end anonymous namespace

//===----------------------------------------------------------------------===//
// Dequantize
//===----------------------------------------------------------------------===//

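// For orientation: uniform dequantization applies the affine map
//   real = scale * (storage - zeroPoint)
// elementwise. A scalar reference sketch (illustrative only; the helper name
// and the i8 storage width are assumptions, not part of this file):
//
//   #include <cstdint>
//   static float referenceDequantize(int8_t q, float scale, int32_t zeroPoint) {
//     // Widen before subtracting the zero point, then scale into the
//     // expressed (real) domain.
//     return scale * static_cast<float>(static_cast<int32_t>(q) - zeroPoint);
//   }
//
// E.g. with scale = 0.5 and zeroPoint = 10, a stored value of 14 dequantizes
// to 0.5 * (14 - 10) = 2.0. The rewrite below emits the same steps as IR:
// storage cast, integer widening, zero-point offset, float conversion, and a
// multiply by scale.
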
static Value *emitUniformPerLayerDequantize(Location loc, Value *input,
                                            UniformQuantizedType elementType,
                                            PatternRewriter &rewriter) {
  // Pre-conditions.
  if (!elementType.isSigned()) {
    // TODO: Support unsigned storage type.
    emitWarning(loc, "unimplemented: dequantize unsigned uniform");
    return nullptr;
  }

  Type storageType = elementType.castToStorageType(input->getType());
  Type realType = elementType.castToExpressedType(input->getType());
  Type intermediateType =
      castElementType(storageType, IntegerType::get(32, rewriter.getContext()));
  assert(storageType && "cannot cast to storage type");
  assert(realType && "cannot cast to expressed type");

  // Cast to storage type.
  input = rewriter.create<StorageCastOp>(loc, storageType, input);

  // Promote to intermediate type.
  input = rewriter.create<ConvertISOp>(loc, intermediateType, input);

  // Apply zero-point offset.
  if (elementType.getZeroPoint() != 0) {
    Value *negZeroPointConst = rewriter.create<ConstantOp>(
        loc, broadcastScalarConstIntValue(intermediateType,
                                          -elementType.getZeroPoint()));
    input = rewriter.create<AddIOp>(loc, input, negZeroPointConst);
  }

  // Convert to float.
  input = rewriter.create<ConvertISToFOp>(loc, realType, input);

  // Multiply by scale.
  Value *scaleConst = rewriter.create<ConstantOp>(
      loc, broadcastScalarConstFloatValue(realType,
                                          APFloat(elementType.getScale())));
  return rewriter.create<MulFOp>(loc, input, scaleConst);
}

static Value *
emitUniformPerAxisDequantize(Location loc, Value *input,
                             UniformQuantizedPerAxisType elementType,
                             PatternRewriter &rewriter) {
  // TODO: Support per-axis dequantize.
  rewriter.getContext()->getDiagEngine().emit(loc, DiagnosticSeverity::Warning)
      << "unimplemented: per-axis uniform dequantization";
  return nullptr;
}

static Value *emitDequantize(Location loc, Value *input,
                             PatternRewriter &rewriter) {
  Type inputType = input->getType();
  QuantizedType qElementType =
      QuantizedType::getQuantizedElementType(inputType);
  if (auto perLayerElementType =
          qElementType.dyn_cast_or_null<UniformQuantizedType>()) {
    return emitUniformPerLayerDequantize(loc, input, perLayerElementType,
                                         rewriter);
  } else if (auto perAxisElementType =
                 qElementType.dyn_cast_or_null<UniformQuantizedPerAxisType>()) {
    return emitUniformPerAxisDequantize(loc, input, perAxisElementType,
                                        rewriter);
  } else {
    return nullptr;
  }
}

namespace {

struct UniformDequantizePattern : public OpRewritePattern<DequantizeCastOp> {
  using OpRewritePattern<DequantizeCastOp>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(DequantizeCastOp op,
                                     PatternRewriter &rewriter) const override {
    Type inputType = op.arg()->getType();
    Type outputType = op.getResult()->getType();

    QuantizedType inputElementType =
        QuantizedType::getQuantizedElementType(inputType);
    Type expressedOutputType = inputElementType.castToExpressedType(inputType);
    if (expressedOutputType != outputType) {
      // Not a valid uniform cast.
      return matchFailure();
    }

    Value *dequantizedValue = emitDequantize(op.getLoc(), op.arg(), rewriter);
    if (!dequantizedValue) {
      return matchFailure();
    }

    rewriter.replaceOp(op, dequantizedValue);
    return matchSuccess();
  }
};

} // end anonymous namespace

//===----------------------------------------------------------------------===//
// Elementwise add
//===----------------------------------------------------------------------===//

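// Derivation for the signed isomorphic case below: when lhs, rhs, and result
// all share one uniform type (scale s, zero point z), the real-valued identity
//   s * (q_out - z) = s * (q_lhs - z) + s * (q_rhs - z)
// simplifies to
//   q_out = q_lhs + q_rhs - z,
// which is exactly the sequence emitted: widen, add, apply the -z offset,
// clamp, and narrow back. E.g. with z = 3, q_lhs = 10, q_rhs = 5:
// q_out = 10 + 5 - 3 = 12.
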
static LogicalResult
tryRewriteAffineAddEwIsomorphicSigned(const UniformBinaryOpInfo &info,
                                      PatternRewriter &rewriter) {
  if (!info.resultType.isSigned() || info.lhsType != info.resultType ||
      info.rhsType != info.resultType) {
    return failure();
  }

  // Choose a byte-aligned intermediate width big enough to perform the
  // calculation without overflow.
  // TODO: This should probably be made just big enough to avoid overflow and
  // leave the downstream tooling to decide how to align that to machine
  // word sizes.
  unsigned intermediateWidth =
      info.resultType.getStorageTypeIntegralWidth() <= 8 ? 16 : 32;
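  // For example, two i8 operands each lie in [-128, 127], so their sum lies
  // in [-256, 254] and, even after the zero-point offset applied below, still
  // fits comfortably in the i16 intermediate chosen here.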
  IntegerType intermediateElementType =
      IntegerType::get(intermediateWidth, rewriter.getContext());
  Type intermediateType =
      castElementType(info.resultStorageType, intermediateElementType);

  // Cast operands to storage type.
  Value *lhsValue = rewriter
                        .create<StorageCastOp>(info.op->getLoc(),
                                               info.lhsStorageType, info.lhs)
                        .getResult();
  Value *rhsValue = rewriter
                        .create<StorageCastOp>(info.op->getLoc(),
                                               info.rhsStorageType, info.rhs)
                        .getResult();

  // Cast to the intermediate-sized type.
  lhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType,
                                          lhsValue);
  rhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType,
                                          rhsValue);

  // Add.
  Value *resultValue =
      rewriter.create<AddIOp>(info.op->getLoc(), lhsValue, rhsValue);

  // Zero point offset adjustment.
  // result = (lhs - zp) + (rhs - zp) + zp
  // zpOffset = -zp
  int zpOffset = -1 * info.resultType.getZeroPoint();
  if (zpOffset != 0) {
    Value *zpOffsetConst = rewriter.create<ConstantOp>(
        info.op->getLoc(),
        broadcastScalarConstIntValue(intermediateType, zpOffset));
    resultValue =
        rewriter.create<AddIOp>(info.op->getLoc(), resultValue, zpOffsetConst);
  }

  // Clamp.
  auto clampMinMax = info.getClampMinMax(intermediateElementType);
  resultValue = rewriter.create<ClampISOp>(
      info.op->getLoc(), resultValue, clampMinMax.first, clampMinMax.second);

  // Convert back to original type.
  resultValue = rewriter.create<ConvertISOp>(
      info.op->getLoc(), info.resultStorageType, resultValue);

  // Cast back for new result.
  rewriter.replaceOpWithNewOp<StorageCastOp>(
      info.op, info.getQuantizedResultType(), resultValue);

  return success();
}

//===----------------------------------------------------------------------===//
// Elementwise mul
//===----------------------------------------------------------------------===//

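// Derivation for the signed case below: in the real domain,
//   s_out * (q_out - z_out) = s_lhs * (q_lhs - z_lhs) * s_rhs * (q_rhs - z_rhs)
// so
//   q_out = M * (q_lhs - z_lhs) * (q_rhs - z_rhs) + z_out,
// where M = (s_lhs * s_rhs) / s_out is the real-valued output multiplier
// computed below. Only M <= 1.0 is handled, applied as a fixed-point multiply
// followed by a rounding right shift.
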
static LogicalResult
tryRewriteAffineMulEwSigned(const UniformBinaryOpInfo &info,
                            PatternRewriter &rewriter) {
  if (!info.resultType.isSigned()) {
    return failure();
  }

  double outputMultiplierReal = info.lhsType.getScale() *
                                info.rhsType.getScale() /
                                info.resultType.getScale();
  if (outputMultiplierReal > 1.0) {
    info.op->emitWarning(
        "unimplemented: cannot multiply with multiplier > 1.0");
    return failure();
  }

  // TODO: Choose an appropriate intermediate width for muls > 8 bits to
  // avoid overflow.
  unsigned intermediateWidth = 32;
  IntegerType intermediateElementType =
      IntegerType::get(intermediateWidth, rewriter.getContext());
  Type intermediateType =
      castElementType(info.resultStorageType, intermediateElementType);

  // Cast operands to storage type.
  Value *lhsValue = rewriter
                        .create<StorageCastOp>(info.op->getLoc(),
                                               info.lhsStorageType, info.lhs)
                        .getResult();
  Value *rhsValue = rewriter
                        .create<StorageCastOp>(info.op->getLoc(),
                                               info.rhsStorageType, info.rhs)
                        .getResult();

  // Cast to the intermediate-sized type.
  lhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType,
                                          lhsValue);
  rhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType,
                                          rhsValue);

  // Apply argument zero points.
  if (info.lhsType.getZeroPoint() != 0) {
    Value *zpOffsetConst = rewriter.create<ConstantOp>(
        info.op->getLoc(), broadcastScalarConstIntValue(
                               intermediateType, -info.lhsType.getZeroPoint()));
    lhsValue =
        rewriter.create<AddIOp>(info.op->getLoc(), lhsValue, zpOffsetConst);
  }

  if (info.rhsType.getZeroPoint() != 0) {
    Value *zpOffsetConst = rewriter.create<ConstantOp>(
        info.op->getLoc(), broadcastScalarConstIntValue(
                               intermediateType, -info.rhsType.getZeroPoint()));
    rhsValue =
        rewriter.create<AddIOp>(info.op->getLoc(), rhsValue, zpOffsetConst);
  }

  // Mul.
  Value *resultValue =
      rewriter.create<MulIOp>(info.op->getLoc(), lhsValue, rhsValue);

  // Scale output.
  QuantizedMultiplierSmallerThanOneExp outputMultiplier(outputMultiplierReal);
  resultValue = rewriter.create<VecScalarSaturatingRoundingDoublingHighMulISOp>(
      info.op->getLoc(), resultValue,
      IntegerAttr::get(intermediateElementType, outputMultiplier.multiplier));
  resultValue = rewriter.create<RoundingDivideByPotISOp>(
      info.op->getLoc(), resultValue,
      IntegerAttr::get(intermediateElementType, -outputMultiplier.exponent));
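  // QuantizedMultiplierSmallerThanOneExp (from UniformKernelUtils.h) encodes
  // the real multiplier M as M ~= (multiplier / 2^31) * 2^exponent, with the
  // mantissa held as a Q31 fixed-point i32. A self-contained sketch of one
  // way to compute such a decomposition (illustrative; std::frexp-based, not
  // necessarily the exact algorithm used by the helper):
  //
  //   #include <cmath>
  //   #include <cstdint>
  //   #include <utility>
  //   // Returns {m0, exponent} with M ~= (m0 / 2^31) * 2^exponent.
  //   static std::pair<int32_t, int> decomposeMultiplier(double M) {
  //     int exponent;
  //     double mantissa = std::frexp(M, &exponent); // mantissa in [0.5, 1)
  //     auto m0 = static_cast<int64_t>(std::round(mantissa * (1ll << 31)));
  //     if (m0 == (1ll << 31)) { m0 /= 2; ++exponent; } // rounding hit 1.0
  //     return {static_cast<int32_t>(m0), exponent};
  //   }
  //
  // With this encoding, the two ops above compute approximately
  // (value * multiplier) >> (31 - exponent) with rounding, i.e. value * M.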

  // Zero point offset adjustment.
  if (info.resultType.getZeroPoint() != 0) {
    Value *zpOffsetConst = rewriter.create<ConstantOp>(
        info.op->getLoc(),
        broadcastScalarConstIntValue(intermediateType,
                                     info.resultType.getZeroPoint()));
    resultValue =
        rewriter.create<AddIOp>(info.op->getLoc(), resultValue, zpOffsetConst);
  }

  // Clamp.
  auto clampMinMax = info.getClampMinMax(intermediateElementType);
  resultValue = rewriter.create<ClampISOp>(
      info.op->getLoc(), resultValue, clampMinMax.first, clampMinMax.second);

  // Convert back to original type.
  resultValue = rewriter.create<ConvertISOp>(
      info.op->getLoc(), info.resultStorageType, resultValue);

  // Cast back for new result.
  rewriter.replaceOpWithNewOp<StorageCastOp>(
      info.op, info.getQuantizedResultType(), resultValue);

  return success();
}

namespace {

struct UniformRealAddEwPattern : public OpRewritePattern<RealAddEwOp> {
  using OpRewritePattern<RealAddEwOp>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(RealAddEwOp op,
                                     PatternRewriter &rewriter) const override {
    const UniformBinaryOpInfo info(op, op.lhs(), op.rhs(), op.clamp_min(),
                                   op.clamp_max());
    if (!info.isValid()) {
      return matchFailure();
    }

    // Try all of the permutations we support.
    if (succeeded(tryRewriteAffineAddEwIsomorphicSigned(info, rewriter))) {
      return matchSuccess();
    }

    return matchFailure();
  }
};

struct UniformRealMulEwPattern : public OpRewritePattern<RealMulEwOp> {
  using OpRewritePattern<RealMulEwOp>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(RealMulEwOp op,
                                     PatternRewriter &rewriter) const override {
    const UniformBinaryOpInfo info(op, op.lhs(), op.rhs(), op.clamp_min(),
                                   op.clamp_max());
    if (!info.isValid()) {
      return matchFailure();
    }

    // Try all of the permutations we support.
    if (succeeded(tryRewriteAffineMulEwSigned(info, rewriter))) {
      return matchSuccess();
    }

    return matchFailure();
  }
};

} // end anonymous namespace

//===----------------------------------------------------------------------===//
// LowerUniformRealMath pass
//===----------------------------------------------------------------------===//

void LowerUniformRealMathPass::runOnFunction() {
  auto fn = getFunction();
  OwningRewritePatternList patterns;
  auto *context = &getContext();
  patterns.insert<UniformRealAddEwPattern, UniformRealMulEwPattern>(context);
  applyPatternsGreedily(fn, patterns);
}

FunctionPassBase *mlir::fxpmath::createLowerUniformRealMathPass() {
  return new LowerUniformRealMathPass();
}

static PassRegistration<LowerUniformRealMathPass> lowerUniformRealMathPass(
    "fxpmath-lower-uniform-real-math",
    "Lowers uniform-quantized real math ops to integer arithmetic.");
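
// With the registration above, the pass can be driven by its flag from an
// mlir-opt-style tool. Hypothetical invocation (the exact driver binary
// depends on how this tree is built):
//
//   mlir-opt -fxpmath-lower-uniform-real-math input.mlir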

//===----------------------------------------------------------------------===//
// LowerUniformCasts pass
//===----------------------------------------------------------------------===//

void LowerUniformCastsPass::runOnFunction() {
  auto fn = getFunction();
  OwningRewritePatternList patterns;
  auto *context = &getContext();
  patterns.insert<UniformDequantizePattern>(context);
  applyPatternsGreedily(fn, patterns);
}

FunctionPassBase *mlir::fxpmath::createLowerUniformCastsPass() {
  return new LowerUniformCastsPass();
}

static PassRegistration<LowerUniformCastsPass>
    lowerUniformCastsPass("fxpmath-lower-uniform-casts",
                          "Lowers uniform-quantized casts.");