github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp (about)

     1  //===- QuantOps.cpp - Quantization Type and Ops Implementation --*- C++ -*-===//
     2  //
     3  // Copyright 2019 The MLIR Authors.
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //   http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  // =============================================================================
    17  
    18  #include "mlir/Dialect/QuantOps/QuantOps.h"
    19  #include "TypeDetail.h"
    20  
    21  #include "mlir/Dialect/QuantOps/QuantTypes.h"
    22  #include "mlir/IR/MLIRContext.h"
    23  #include "mlir/IR/Matchers.h"
    24  #include "mlir/IR/PatternMatch.h"
    25  #include "mlir/IR/StandardTypes.h"
    26  #include "llvm/ADT/StringRef.h"
    27  #include "llvm/ADT/Twine.h"
    28  #include "llvm/Support/MathExtras.h"
    29  
    30  using namespace mlir;
    31  using namespace mlir::quant;
    32  using namespace mlir::quant::detail;
    33  
    34  #define GET_OP_CLASSES
    35  #include "mlir/Dialect/QuantOps/QuantOps.cpp.inc"
    36  
    37  namespace {
    38  
    39  /// Matches x -> [scast -> scast] -> y, replacing the second scast with the
    40  /// value of x if the casts invert each other.
    41  class RemoveRedundantStorageCastsRewrite
    42      : public OpRewritePattern<StorageCastOp> {
    43  public:
    44    using OpRewritePattern<StorageCastOp>::OpRewritePattern;
    45  
    46    PatternMatchResult matchAndRewrite(StorageCastOp op,
    47                                       PatternRewriter &rewriter) const override {
    48      if (!matchPattern(op.arg(), m_Op<StorageCastOp>()))
    49        return matchFailure();
    50      auto srcScastOp = cast<StorageCastOp>(op.arg()->getDefiningOp());
    51      if (srcScastOp.arg()->getType() != op.getType())
    52        return matchFailure();
    53  
    54      rewriter.replaceOp(op, srcScastOp.arg());
    55      return matchSuccess();
    56    }
    57  };
    58  
    59  } // end anonymous namespace
    60  
    61  void StorageCastOp::getCanonicalizationPatterns(
    62      OwningRewritePatternList &patterns, MLIRContext *context) {
    63    patterns.insert<RemoveRedundantStorageCastsRewrite>(context);
    64  }
    65  
    66  QuantizationDialect::QuantizationDialect(MLIRContext *context)
    67      : Dialect(/*name=*/"quant", context) {
    68    addTypes<AnyQuantizedType, UniformQuantizedType,
    69             UniformQuantizedPerAxisType>();
    70    addOperations<
    71  #define GET_OP_LIST
    72  #include "mlir/Dialect/QuantOps/QuantOps.cpp.inc"
    73        >();
    74  }