github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Quantizer/Support/ConstraintAnalysisGraph.cpp (about)

     1  //===- ConstraintAnalysisGraph.cpp - Graphs type for constraints ----------===//
     2  //
     3  // Copyright 2019 The MLIR Authors.
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //   http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  // =============================================================================
    17  
#include "mlir/Quantizer/Support/ConstraintAnalysisGraph.h"

#include "mlir/IR/MLIRContext.h"
#include "mlir/Quantizer/Support/Configuration.h"
#include "llvm/Support/raw_ostream.h"

#include <algorithm>

using namespace mlir;
using namespace mlir::quantizer;
    26  
    27  void CAGNode::replaceIncoming(CAGNode *otherNode) {
    28    if (this == otherNode)
    29      return;
    30    for (CAGNode *parentNode : incoming) {
    31      for (CAGNode *&it : parentNode->outgoing) {
    32        if (it == this) {
    33          it = otherNode;
    34          otherNode->incoming.push_back(parentNode);
    35        }
    36      }
    37    }
    38    incoming.clear();
    39  }
    40  
    41  void CAGNode::addOutgoing(CAGNode *toNode) {
    42    if (!llvm::is_contained(outgoing, toNode)) {
    43      outgoing.push_back(toNode);
    44      toNode->incoming.push_back(this);
    45    }
    46  }
    47  
    48  CAGOperandAnchor::CAGOperandAnchor(Operation *op, unsigned operandIdx)
    49      : CAGAnchorNode(Kind::OperandAnchor, op->getOperand(operandIdx)->getType()),
    50        op(op), operandIdx(operandIdx) {}
    51  
    52  CAGResultAnchor::CAGResultAnchor(Operation *op, unsigned resultIdx)
    53      : CAGAnchorNode(Kind::ResultAnchor, op->getResult(resultIdx)->getType()),
    54        resultValue(op->getResult(resultIdx)) {}
    55  
    56  CAGSlice::CAGSlice(SolverContext &context) : context(context) {}
    57  CAGSlice::~CAGSlice() { llvm::DeleteContainerPointers(allNodes); }
    58  
    59  CAGOperandAnchor *CAGSlice::getOperandAnchor(Operation *op,
    60                                               unsigned operandIdx) {
    61    assert(operandIdx < op->getNumOperands() && "illegal operand index");
    62  
    63    // Dedup.
    64    auto key = std::make_pair(op, operandIdx);
    65    auto foundIt = operandAnchors.find(key);
    66    if (foundIt != operandAnchors.end()) {
    67      return foundIt->second;
    68    }
    69  
    70    // Create.
    71    auto anchor = std::make_unique<CAGOperandAnchor>(op, operandIdx);
    72    auto *unowned = anchor.release();
    73    unowned->nodeId = allNodes.size();
    74    allNodes.push_back(unowned);
    75    operandAnchors.insert(std::make_pair(key, unowned));
    76    return unowned;
    77  }
    78  
    79  CAGResultAnchor *CAGSlice::getResultAnchor(Operation *op, unsigned resultIdx) {
    80    assert(resultIdx < op->getNumResults() && "illegal result index");
    81  
    82    // Dedup.
    83    auto key = std::make_pair(op, resultIdx);
    84    auto foundIt = resultAnchors.find(key);
    85    if (foundIt != resultAnchors.end()) {
    86      return foundIt->second;
    87    }
    88  
    89    // Create.
    90    auto anchor = std::make_unique<CAGResultAnchor>(op, resultIdx);
    91    auto *unowned = anchor.release();
    92    unowned->nodeId = allNodes.size();
    93    allNodes.push_back(unowned);
    94    resultAnchors.insert(std::make_pair(key, unowned));
    95    return unowned;
    96  }
    97  
    98  void CAGSlice::enumerateImpliedConnections(
    99      std::function<void(CAGAnchorNode *from, CAGAnchorNode *to)> callback) {
   100    // Discover peer identity pairs (i.e. implied edges from Result->Operand and
   101    // Arg->Call). Use an intermediate vector so that the callback can modify.
   102    std::vector<std::pair<CAGAnchorNode *, CAGAnchorNode *>> impliedPairs;
   103    for (auto &resultAnchorPair : resultAnchors) {
   104      CAGResultAnchor *resultAnchor = resultAnchorPair.second;
   105      Value *resultValue = resultAnchor->getValue();
   106      for (auto &use : resultValue->getUses()) {
   107        Operation *operandOp = use.getOwner();
   108        unsigned operandIdx = use.getOperandNumber();
   109        auto foundIt = operandAnchors.find(std::make_pair(operandOp, operandIdx));
   110        if (foundIt != operandAnchors.end()) {
   111          impliedPairs.push_back(std::make_pair(resultAnchor, foundIt->second));
   112        }
   113      }
   114    }
   115  
   116    // Callback for each pair.
   117    for (auto &impliedPair : impliedPairs) {
   118      callback(impliedPair.first, impliedPair.second);
   119    }
   120  }
   121  
   122  unsigned CAGSlice::propagate(const TargetConfiguration &config) {
   123    std::vector<CAGNode *> dirtyNodes;
   124    dirtyNodes.reserve(allNodes.size());
   125    // Note that because iteration happens in nodeId order, there is no need
   126    // to sort in order to make deterministic. If the selection method changes,
   127    // a sort should be explicitly done.
   128    for (CAGNode *child : *this) {
   129      if (child->isDirty()) {
   130        dirtyNodes.push_back(child);
   131      }
   132    }
   133  
   134    if (dirtyNodes.empty()) {
   135      return 0;
   136    }
   137    for (auto dirtyNode : dirtyNodes) {
   138      dirtyNode->clearDirty();
   139      dirtyNode->propagate(context, config);
   140    }
   141  
   142    return dirtyNodes.size();
   143  }
   144  
   145  void CAGAnchorNode::propagate(SolverContext &solverContext,
   146                                const TargetConfiguration &config) {
   147    for (CAGNode *child : *this) {
   148      child->markDirty();
   149    }
   150  }
   151  
   152  Type CAGAnchorNode::getTransformedType() {
   153    if (!getUniformMetadata().selectedType) {
   154      return nullptr;
   155    }
   156    return getUniformMetadata().selectedType.castFromExpressedType(
   157        getOriginalType());
   158  }
   159  
   160  void CAGNode::printLabel(llvm::raw_ostream &os) const {
   161    os << "Node<" << static_cast<const void *>(this) << ">";
   162  }
   163  
   164  void CAGAnchorNode::printLabel(llvm::raw_ostream &os) const {
   165    getUniformMetadata().printSummary(os);
   166  }
   167  
   168  void CAGOperandAnchor::printLabel(llvm::raw_ostream &os) const {
   169    os << "Operand<";
   170    op->getName().print(os);
   171    os << "," << operandIdx;
   172    os << ">";
   173    CAGAnchorNode::printLabel(os);
   174  }
   175  
   176  void CAGResultAnchor::printLabel(llvm::raw_ostream &os) const {
   177    os << "Result<";
   178    getOp()->getName().print(os);
   179    os << ">";
   180    CAGAnchorNode::printLabel(os);
   181  }