github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Quantizer/Support/Statistics.cpp (about)

     1  //===- Statistics.cpp - Collects statistics over tensors ------------------===//
     2  //
     3  // Copyright 2019 The MLIR Authors.
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //   http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  // =============================================================================
    17  
    18  #include "mlir/Quantizer/Support/Statistics.h"
    19  
    20  #include "mlir/IR/Attributes.h"
    21  #include "mlir/IR/StandardTypes.h"
    22  #include "llvm/Support/raw_ostream.h"
    23  
    24  using namespace mlir;
    25  using namespace mlir::quantizer;
    26  
    27  //===----------------------------------------------------------------------===//
    28  // AttributeTensorStatistics implementation
    29  //===----------------------------------------------------------------------===//
    30  
    31  static void
    32  collectElementsStatisticsDim(ElementsAttr attr, unsigned numElements,
    33                               ArrayRef<int64_t> shape,
    34                               llvm::SmallVectorImpl<uint64_t> &indices,
    35                               uint64_t dim, TensorAxisStatistics &statistics) {
    36    // Recursive terminating condition.
    37    if (dim >= shape.size())
    38      return;
    39  
    40    if (dim < (shape.size() - 1)) {
    41      // Recurse past dim.
    42      for (uint64_t i = 0, s = shape[dim]; i < s; ++i) {
    43        indices[dim] = i;
    44        collectElementsStatisticsDim(attr, numElements, shape, indices, dim + 1,
    45                                     statistics);
    46      }
    47      return;
    48    }
    49  
    50    // Collection dim.
    51    for (uint64_t i = 0, s = shape[dim]; i < s; ++i) {
    52      indices[dim] = i;
    53      double value = attr.getValue<FloatAttr>(indices).getValueAsDouble();
    54      statistics.minValue = std::min(statistics.minValue, value);
    55      statistics.maxValue = std::max(statistics.maxValue, value);
    56      statistics.mean += value / numElements;
    57      // TODO: Calculate a running variance.
    58    }
    59  }
    60  
    61  static bool getElementsStatistics(ElementsAttr attr,
    62                                    TensorAxisStatistics &statistics) {
    63    statistics.clear();
    64    statistics.minValue = std::numeric_limits<double>::infinity();
    65    statistics.maxValue = -std::numeric_limits<double>::infinity();
    66  
    67    ShapedType sType = attr.getType();
    68    if (!sType.hasStaticShape())
    69      return false;
    70    Type elementTy = sType.getElementType();
    71    if (!elementTy.isa<FloatType>())
    72      return false;
    73  
    74    llvm::SmallVector<uint64_t, 4> indices;
    75    indices.resize(sType.getRank());
    76    ArrayRef<int64_t> shape = sType.getShape();
    77  
    78    auto numElements = sType.getNumElements();
    79    collectElementsStatisticsDim(attr, numElements, shape, indices, 0,
    80                                 statistics);
    81    statistics.sampleSize = numElements;
    82  
    83    return true;
    84  }
    85  
    86  bool AttributeTensorStatistics::get(TensorAxisStatistics &stats) const {
    87    if (FloatAttr floatAttr = attr.dyn_cast<FloatAttr>()) {
    88      double value = floatAttr.getValueAsDouble();
    89      stats = TensorAxisStatistics(1, value, value, value, 0);
    90      return true;
    91    } else if (auto eltAttr = attr.dyn_cast<ElementsAttr>()) {
    92      return getElementsStatistics(eltAttr, stats);
    93    }
    94    return false;
    95  }
    96  
    97  namespace mlir {
    98  namespace quantizer {
    99  
   100  llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
   101                                const TensorAxisStatistics &stats) {
   102    os << "STATS[sampleSize=" << stats.sampleSize << ", min=" << stats.minValue
   103       << ", maxValue=" << stats.maxValue << ", mean=" << stats.mean
   104       << ", variance=" << stats.variance << "]";
   105    return os;
   106  }
   107  
   108  } // end namespace quantizer
   109  } // end namespace mlir