github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp (about) 1 //===- QuantOps.cpp - Quantization Type and Ops Implementation --*- C++ -*-===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 18 #include "mlir/Dialect/QuantOps/QuantOps.h" 19 #include "TypeDetail.h" 20 21 #include "mlir/Dialect/QuantOps/QuantTypes.h" 22 #include "mlir/IR/MLIRContext.h" 23 #include "mlir/IR/Matchers.h" 24 #include "mlir/IR/PatternMatch.h" 25 #include "mlir/IR/StandardTypes.h" 26 #include "llvm/ADT/StringRef.h" 27 #include "llvm/ADT/Twine.h" 28 #include "llvm/Support/MathExtras.h" 29 30 using namespace mlir; 31 using namespace mlir::quant; 32 using namespace mlir::quant::detail; 33 34 #define GET_OP_CLASSES 35 #include "mlir/Dialect/QuantOps/QuantOps.cpp.inc" 36 37 namespace { 38 39 /// Matches x -> [scast -> scast] -> y, replacing the second scast with the 40 /// value of x if the casts invert each other. 41 class RemoveRedundantStorageCastsRewrite 42 : public OpRewritePattern<StorageCastOp> { 43 public: 44 using OpRewritePattern<StorageCastOp>::OpRewritePattern; 45 46 PatternMatchResult matchAndRewrite(StorageCastOp op, 47 PatternRewriter &rewriter) const override { 48 if (!matchPattern(op.arg(), m_Op<StorageCastOp>())) 49 return matchFailure(); 50 auto srcScastOp = cast<StorageCastOp>(op.arg()->getDefiningOp()); 51 if (srcScastOp.arg()->getType() != op.getType()) 52 return matchFailure(); 53 54 rewriter.replaceOp(op, srcScastOp.arg()); 55 return matchSuccess(); 56 } 57 }; 58 59 } // end anonymous namespace 60 61 void StorageCastOp::getCanonicalizationPatterns( 62 OwningRewritePatternList &patterns, MLIRContext *context) { 63 patterns.insert<RemoveRedundantStorageCastsRewrite>(context); 64 } 65 66 QuantizationDialect::QuantizationDialect(MLIRContext *context) 67 : Dialect(/*name=*/"quant", context) { 68 addTypes<AnyQuantizedType, UniformQuantizedType, 69 UniformQuantizedPerAxisType>(); 70 addOperations< 71 #define GET_OP_LIST 72 #include "mlir/Dialect/QuantOps/QuantOps.cpp.inc" 73 >(); 74 }