// Source: github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Quantizer/Transforms/AddDefaultStatsTestPass.cpp
//===- AddDefaultStatsTestPass.cpp - Testing pass to add default stats ----===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file defines a testing pass to add default statistics nodes to every
// quantization-eligible op. Useful for unit testing.
20 // 21 //===----------------------------------------------------------------------===// 22 23 #include "mlir/Dialect/QuantOps/QuantOps.h" 24 #include "mlir/Dialect/QuantOps/QuantTypes.h" 25 #include "mlir/IR/Attributes.h" 26 #include "mlir/IR/Builders.h" 27 #include "mlir/Quantizer/Configurations/FxpMathConfig.h" 28 #include "mlir/Quantizer/Support/Configuration.h" 29 #include "mlir/Quantizer/Support/ConstraintAnalysisGraph.h" 30 #include "mlir/Quantizer/Support/ConstraintAnalysisGraphTraits.h" 31 #include "mlir/Quantizer/Transforms/Passes.h" 32 #include "mlir/Support/LogicalResult.h" 33 #include "llvm/Support/GraphWriter.h" 34 #include "llvm/Support/raw_ostream.h" 35 36 using namespace mlir; 37 using namespace mlir::quantizer; 38 using namespace mlir::quant; 39 40 namespace { 41 42 class AddDefaultStatsPass : public FunctionPass<AddDefaultStatsPass> { 43 public: 44 AddDefaultStatsPass() = default; 45 AddDefaultStatsPass(SolverContext &solverContext, 46 const TargetConfiguration &config) 47 : explicitSolverContext(&solverContext), explicitConfig(&config) {} 48 49 void runOnFunction() override; 50 void runWithConfig(SolverContext &solverContext, 51 const TargetConfiguration &config); 52 53 private: 54 SolverContext *explicitSolverContext = nullptr; 55 const TargetConfiguration *explicitConfig = nullptr; 56 }; 57 58 } // end anonymous namespace 59 60 void AddDefaultStatsPass::runOnFunction() { 61 if (explicitSolverContext && explicitConfig) { 62 // If explicitly constructed with a config and context. 63 runWithConfig(*explicitSolverContext, *explicitConfig); 64 return; 65 } 66 // For global pass registration, use defaults. 
67 SolverContext solverContext(*getFunction().getContext()); 68 auto config = FxpMathTargetConfig::create(solverContext); 69 runWithConfig(solverContext, *config); 70 } 71 72 void AddDefaultStatsPass::runWithConfig(SolverContext &solverContext, 73 const TargetConfiguration &config) { 74 auto func = getFunction(); 75 76 // Insert stats for each argument. 77 for (auto *arg : func.getArguments()) { 78 if (!config.isHandledType(arg->getType())) 79 continue; 80 OpBuilder b(func.getBody()); 81 APFloat minValue(-1.0f); 82 APFloat maxValue(1.0f); 83 ElementsAttr layerStats = DenseFPElementsAttr::get( 84 b.getTensorType({2}, b.getF32Type()), {minValue, maxValue}); 85 auto statsOp = 86 b.create<StatisticsOp>(func.getLoc(), arg, layerStats, nullptr); 87 arg->replaceAllUsesWith(statsOp); 88 89 // StatsOp contained a use to 'arg' so make sure to reset it after replacing 90 // all of the uses of 'arg'. 91 statsOp.getOperation()->replaceUsesOfWith(statsOp, arg); 92 } 93 94 // Walk the ops and insert stats. 95 func.walk([&](Operation *op) { 96 if (!config.isRequireStatsOp(op)) { 97 return; 98 } 99 assert(op->getNumResults() == 1); 100 101 auto originalResult = op->getResult(0); 102 if (!config.isHandledType(originalResult->getType())) 103 return; 104 105 OpBuilder b(op->getBlock(), ++op->getIterator()); 106 107 APFloat minValue(-1.0f); 108 APFloat maxValue(1.0f); 109 ElementsAttr layerStats = DenseFPElementsAttr::get( 110 b.getTensorType({2}, b.getF32Type()), {minValue, maxValue}); 111 auto statsOp = b.create<StatisticsOp>(op->getLoc(), op->getResult(0), 112 layerStats, nullptr); 113 originalResult->replaceAllUsesWith(statsOp); 114 115 // StatsOp contained a use to 'op' so make sure to reset it after replacing 116 // all of the uses of 'op'. 
117 statsOp.getOperation()->replaceUsesOfWith(statsOp, originalResult); 118 }); 119 } 120 121 std::unique_ptr<FunctionPassBase> mlir::quantizer::createAddDefaultStatsPass() { 122 return std::make_unique<AddDefaultStatsPass>(); 123 } 124 125 static PassRegistration<AddDefaultStatsPass> pass( 126 "quantizer-add-default-stats-test", 127 "Adds default (dummy) statistics to all ops that can benefit from " 128 "runtime statistics. This is meant to help in early stage bootstrapping.");