github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp (about)

     1  //===- ConvertSimQuant.cpp - Converts simulated quant ops------------------===//
     2  //
     3  // Copyright 2019 The MLIR Authors.
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //   http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  // =============================================================================
    17  
    18  #include "mlir/Dialect/QuantOps/FakeQuantSupport.h"
    19  #include "mlir/Dialect/QuantOps/Passes.h"
    20  #include "mlir/Dialect/QuantOps/QuantOps.h"
    21  #include "mlir/Dialect/QuantOps/UniformSupport.h"
    22  #include "mlir/IR/Attributes.h"
    23  #include "mlir/IR/PatternMatch.h"
    24  #include "mlir/IR/StandardTypes.h"
    25  #include "mlir/Pass/Pass.h"
    26  
    27  using namespace mlir;
    28  using namespace mlir::quant;
    29  
    30  namespace {
    31  
    32  class ConvertSimulatedQuantPass
    33      : public FunctionPass<ConvertSimulatedQuantPass> {
    34  public:
    35    void runOnFunction() override;
    36  };
    37  
    38  } // end anonymous namespace
    39  
    40  /// Rewrites ConstFakeQuant into a qbarrier/dbarrier pair.
    41  class ConstFakeQuantRewrite : public RewritePattern {
    42  public:
    43    bool *hadFailure;
    44  
    45    ConstFakeQuantRewrite(MLIRContext *context, bool *hadFailure)
    46        : RewritePattern(ConstFakeQuant::getOperationName(), 1, context),
    47          hadFailure(hadFailure) {}
    48  
    49    PatternMatchResult matchAndRewrite(Operation *op,
    50                                       PatternRewriter &rewriter) const override {
    51      // TODO: If this pattern comes up more frequently, consider adding core
    52      // support for failable rewrites.
    53      if (failableRewrite(op, rewriter)) {
    54        *hadFailure = true;
    55        return matchFailure();
    56      }
    57  
    58      return matchSuccess();
    59    }
    60  
    61    bool failableRewrite(Operation *op, PatternRewriter &rewriter) const {
    62      auto fqOp = cast<ConstFakeQuant>(op);
    63  
    64      auto converter =
    65          ExpressedToUniformQuantizedConverter::forInputType(fqOp.getType());
    66      if (!converter) {
    67        return (op->emitError("unsupported quantized type conversion"), true);
    68      }
    69  
    70      UniformQuantizedType uniformElementType = fakeQuantAttrsToType(
    71          fqOp.getLoc(), fqOp.num_bits().getSExtValue(),
    72          fqOp.min().convertToFloat(), fqOp.max().convertToFloat(),
    73          fqOp.narrow_range(), converter.expressedType, fqOp.is_signed());
    74  
    75      if (!uniformElementType) {
    76        // Note that the fakeQuantAttrsToType will have emitted the error.
    77        return true;
    78      }
    79  
    80      Type quantizedType = converter.convert(uniformElementType);
    81      assert(quantizedType &&
    82             "Converter accepted a type that it did not convert");
    83  
    84      // TODO: Map to a qbarrier with an attribute like [Forced] to signal that
    85      // this is a forced/hard-coded constraint.
    86      auto qbarrier = rewriter.create<QuantizeCastOp>(op->getLoc(), quantizedType,
    87                                                      fqOp.inputs());
    88      rewriter.replaceOpWithNewOp<DequantizeCastOp>(op, converter.inputType,
    89                                                    qbarrier.getResult());
    90  
    91      return false;
    92    }
    93  };
    94  
    95  void ConvertSimulatedQuantPass::runOnFunction() {
    96    bool hadFailure = false;
    97    OwningRewritePatternList patterns;
    98    auto func = getFunction();
    99    auto *context = &getContext();
   100    patterns.insert<ConstFakeQuantRewrite>(context, &hadFailure);
   101    applyPatternsGreedily(func, patterns);
   102    if (hadFailure)
   103      signalPassFailure();
   104  }
   105  
   106  std::unique_ptr<FunctionPassBase>
   107  mlir::quant::createConvertSimulatedQuantPass() {
   108    return std::make_unique<ConvertSimulatedQuantPass>();
   109  }
   110  
   111  static PassRegistration<ConvertSimulatedQuantPass>
   112      pass("quant-convert-simulated-quantization",
   113           "Converts training-time simulated quantization ops to corresponding "
   114           "quantize/dequantize casts.");