github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Quantizer/Transforms/AddDefaultStatsTestPass.cpp (about)

     1  //===- AddDefaultStatsTestPass.cpp - Testing pass to add default stats ----===//
     2  //
     3  // Copyright 2019 The MLIR Authors.
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //   http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  // =============================================================================
    17  //
    18  // This file defines a testing pass to add default statistics nodes to every
    19  // quantization eligible op. Useful for unit testing.
    20  //
    21  //===----------------------------------------------------------------------===//
    22  
    23  #include "mlir/Dialect/QuantOps/QuantOps.h"
    24  #include "mlir/Dialect/QuantOps/QuantTypes.h"
    25  #include "mlir/IR/Attributes.h"
    26  #include "mlir/IR/Builders.h"
    27  #include "mlir/Quantizer/Configurations/FxpMathConfig.h"
    28  #include "mlir/Quantizer/Support/Configuration.h"
    29  #include "mlir/Quantizer/Support/ConstraintAnalysisGraph.h"
    30  #include "mlir/Quantizer/Support/ConstraintAnalysisGraphTraits.h"
    31  #include "mlir/Quantizer/Transforms/Passes.h"
    32  #include "mlir/Support/LogicalResult.h"
    33  #include "llvm/Support/GraphWriter.h"
    34  #include "llvm/Support/raw_ostream.h"
    35  
    36  using namespace mlir;
    37  using namespace mlir::quantizer;
    38  using namespace mlir::quant;
    39  
    40  namespace {
    41  
    42  class AddDefaultStatsPass : public FunctionPass<AddDefaultStatsPass> {
    43  public:
    44    AddDefaultStatsPass() = default;
    45    AddDefaultStatsPass(SolverContext &solverContext,
    46                        const TargetConfiguration &config)
    47        : explicitSolverContext(&solverContext), explicitConfig(&config) {}
    48  
    49    void runOnFunction() override;
    50    void runWithConfig(SolverContext &solverContext,
    51                       const TargetConfiguration &config);
    52  
    53  private:
    54    SolverContext *explicitSolverContext = nullptr;
    55    const TargetConfiguration *explicitConfig = nullptr;
    56  };
    57  
    58  } // end anonymous namespace
    59  
    60  void AddDefaultStatsPass::runOnFunction() {
    61    if (explicitSolverContext && explicitConfig) {
    62      // If explicitly constructed with a config and context.
    63      runWithConfig(*explicitSolverContext, *explicitConfig);
    64      return;
    65    }
    66    // For global pass registration, use defaults.
    67    SolverContext solverContext(*getFunction().getContext());
    68    auto config = FxpMathTargetConfig::create(solverContext);
    69    runWithConfig(solverContext, *config);
    70  }
    71  
    72  void AddDefaultStatsPass::runWithConfig(SolverContext &solverContext,
    73                                          const TargetConfiguration &config) {
    74    auto func = getFunction();
    75  
    76    // Insert stats for each argument.
    77    for (auto *arg : func.getArguments()) {
    78      if (!config.isHandledType(arg->getType()))
    79        continue;
    80      OpBuilder b(func.getBody());
    81      APFloat minValue(-1.0f);
    82      APFloat maxValue(1.0f);
    83      ElementsAttr layerStats = DenseFPElementsAttr::get(
    84          b.getTensorType({2}, b.getF32Type()), {minValue, maxValue});
    85      auto statsOp =
    86          b.create<StatisticsOp>(func.getLoc(), arg, layerStats, nullptr);
    87      arg->replaceAllUsesWith(statsOp);
    88  
    89      // StatsOp contained a use to 'arg' so make sure to reset it after replacing
    90      // all of the uses of 'arg'.
    91      statsOp.getOperation()->replaceUsesOfWith(statsOp, arg);
    92    }
    93  
    94    // Walk the ops and insert stats.
    95    func.walk([&](Operation *op) {
    96      if (!config.isRequireStatsOp(op)) {
    97        return;
    98      }
    99      assert(op->getNumResults() == 1);
   100  
   101      auto originalResult = op->getResult(0);
   102      if (!config.isHandledType(originalResult->getType()))
   103        return;
   104  
   105      OpBuilder b(op->getBlock(), ++op->getIterator());
   106  
   107      APFloat minValue(-1.0f);
   108      APFloat maxValue(1.0f);
   109      ElementsAttr layerStats = DenseFPElementsAttr::get(
   110          b.getTensorType({2}, b.getF32Type()), {minValue, maxValue});
   111      auto statsOp = b.create<StatisticsOp>(op->getLoc(), op->getResult(0),
   112                                            layerStats, nullptr);
   113      originalResult->replaceAllUsesWith(statsOp);
   114  
   115      // StatsOp contained a use to 'op' so make sure to reset it after replacing
   116      // all of the uses of 'op'.
   117      statsOp.getOperation()->replaceUsesOfWith(statsOp, originalResult);
   118    });
   119  }
   120  
   121  std::unique_ptr<FunctionPassBase> mlir::quantizer::createAddDefaultStatsPass() {
   122    return std::make_unique<AddDefaultStatsPass>();
   123  }
   124  
   125  static PassRegistration<AddDefaultStatsPass> pass(
   126      "quantizer-add-default-stats-test",
   127      "Adds default (dummy) statistics to all ops that can benefit from "
   128      "runtime statistics. This is meant to help in early stage bootstrapping.");