github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp (about)

     1  //===- ConvertConst.cpp - Quantizes constant ops --------------------------===//
     2  //
     3  // Copyright 2019 The MLIR Authors.
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //   http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  // =============================================================================
    17  
    18  #include "mlir/Dialect/QuantOps/Passes.h"
    19  #include "mlir/Dialect/QuantOps/QuantOps.h"
    20  #include "mlir/Dialect/QuantOps/QuantizeUtils.h"
    21  #include "mlir/Dialect/QuantOps/UniformSupport.h"
    22  #include "mlir/Dialect/StandardOps/Ops.h"
    23  #include "mlir/IR/Attributes.h"
    24  #include "mlir/IR/Matchers.h"
    25  #include "mlir/IR/PatternMatch.h"
    26  #include "mlir/IR/StandardTypes.h"
    27  #include "mlir/Pass/Pass.h"
    28  
    29  using namespace mlir;
    30  using namespace mlir::quant;
    31  
    32  namespace {
    33  
    34  class ConvertConstPass : public FunctionPass<ConvertConstPass> {
    35  public:
    36    void runOnFunction() override;
    37  };
    38  
    39  struct QuantizedConstRewrite : public OpRewritePattern<QuantizeCastOp> {
    40    using OpRewritePattern<QuantizeCastOp>::OpRewritePattern;
    41  
    42    PatternMatchResult matchAndRewrite(QuantizeCastOp qbarrier,
    43                                       PatternRewriter &rewriter) const override;
    44  };
    45  
    46  } // end anonymous namespace
    47  
    48  /// Matches a [constant] -> [qbarrier] where the qbarrier results type is
    49  /// quantized and the operand type is quantizable.
    50  
    51  PatternMatchResult
    52  QuantizedConstRewrite::matchAndRewrite(QuantizeCastOp qbarrier,
    53                                         PatternRewriter &rewriter) const {
    54    Attribute value;
    55  
    56    // Is the operand a constant?
    57    if (!matchPattern(qbarrier.arg(), m_Constant(&value))) {
    58      return matchFailure();
    59    }
    60  
    61    // Does the qbarrier convert to a quantized type. This will not be true
    62    // if a quantized type has not yet been chosen or if the cast to an equivalent
    63    // storage type is not supported.
    64    Type qbarrierResultType = qbarrier.getResult()->getType();
    65    QuantizedType quantizedElementType =
    66        QuantizedType::getQuantizedElementType(qbarrierResultType);
    67    if (!quantizedElementType) {
    68      return matchFailure();
    69    }
    70    if (!QuantizedType::castToStorageType(qbarrierResultType)) {
    71      return matchFailure();
    72    }
    73  
    74    // Is the operand type compatible with the expressed type of the quantized
    75    // type? This will not be true if the qbarrier is superfluous (converts
    76    // from and to a quantized type).
    77    if (!quantizedElementType.isCompatibleExpressedType(
    78            qbarrier.arg()->getType())) {
    79      return matchFailure();
    80    }
    81  
    82    // Is the constant value a type expressed in a way that we support?
    83    if (!value.isa<FloatAttr>() && !value.isa<DenseElementsAttr>() &&
    84        !value.isa<SparseElementsAttr>()) {
    85      return matchFailure();
    86    }
    87  
    88    Type newConstValueType;
    89    auto newConstValue =
    90        quantizeAttr(value, quantizedElementType, newConstValueType);
    91    if (!newConstValue) {
    92      return matchFailure();
    93    }
    94  
    95    // When creating the new const op, use a fused location that combines the
    96    // original const and the qbarrier that led to the quantization.
    97    auto fusedLoc = FusedLoc::get(
    98        {qbarrier.arg()->getDefiningOp()->getLoc(), qbarrier.getLoc()},
    99        rewriter.getContext());
   100    auto newConstOp =
   101        rewriter.create<ConstantOp>(fusedLoc, newConstValueType, newConstValue);
   102    rewriter.replaceOpWithNewOp<StorageCastOp>({qbarrier.arg()}, qbarrier,
   103                                               qbarrier.getType(), newConstOp);
   104    return matchSuccess();
   105  }
   106  
   107  void ConvertConstPass::runOnFunction() {
   108    OwningRewritePatternList patterns;
   109    auto func = getFunction();
   110    auto *context = &getContext();
   111    patterns.insert<QuantizedConstRewrite>(context);
   112    applyPatternsGreedily(func, patterns);
   113  }
   114  
   115  std::unique_ptr<FunctionPassBase> mlir::quant::createConvertConstPass() {
   116    return std::make_unique<ConvertConstPass>();
   117  }
   118  
   119  static PassRegistration<ConvertConstPass>
   120      pass("quant-convert-const",
   121           "Converts constants followed by qbarrier to actual quantized values");