github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Quantizer/Support/ConstraintAnalysisGraph.cpp (about) 1 //===- ConstraintAnalysisGraph.cpp - Graphs type for constraints ----------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 18 #include "mlir/Quantizer/Support/ConstraintAnalysisGraph.h" 19 20 #include "mlir/IR/MLIRContext.h" 21 #include "mlir/Quantizer/Support/Configuration.h" 22 #include "llvm/Support/raw_ostream.h" 23 24 using namespace mlir; 25 using namespace mlir::quantizer; 26 27 void CAGNode::replaceIncoming(CAGNode *otherNode) { 28 if (this == otherNode) 29 return; 30 for (CAGNode *parentNode : incoming) { 31 for (CAGNode *&it : parentNode->outgoing) { 32 if (it == this) { 33 it = otherNode; 34 otherNode->incoming.push_back(parentNode); 35 } 36 } 37 } 38 incoming.clear(); 39 } 40 41 void CAGNode::addOutgoing(CAGNode *toNode) { 42 if (!llvm::is_contained(outgoing, toNode)) { 43 outgoing.push_back(toNode); 44 toNode->incoming.push_back(this); 45 } 46 } 47 48 CAGOperandAnchor::CAGOperandAnchor(Operation *op, unsigned operandIdx) 49 : CAGAnchorNode(Kind::OperandAnchor, op->getOperand(operandIdx)->getType()), 50 op(op), operandIdx(operandIdx) {} 51 52 CAGResultAnchor::CAGResultAnchor(Operation *op, unsigned resultIdx) 53 : CAGAnchorNode(Kind::ResultAnchor, op->getResult(resultIdx)->getType()), 54 resultValue(op->getResult(resultIdx)) {} 55 56 CAGSlice::CAGSlice(SolverContext &context) : context(context) {} 57 CAGSlice::~CAGSlice() { llvm::DeleteContainerPointers(allNodes); } 58 59 CAGOperandAnchor *CAGSlice::getOperandAnchor(Operation *op, 60 unsigned operandIdx) { 61 assert(operandIdx < op->getNumOperands() && "illegal operand index"); 62 63 // Dedup. 64 auto key = std::make_pair(op, operandIdx); 65 auto foundIt = operandAnchors.find(key); 66 if (foundIt != operandAnchors.end()) { 67 return foundIt->second; 68 } 69 70 // Create. 71 auto anchor = std::make_unique<CAGOperandAnchor>(op, operandIdx); 72 auto *unowned = anchor.release(); 73 unowned->nodeId = allNodes.size(); 74 allNodes.push_back(unowned); 75 operandAnchors.insert(std::make_pair(key, unowned)); 76 return unowned; 77 } 78 79 CAGResultAnchor *CAGSlice::getResultAnchor(Operation *op, unsigned resultIdx) { 80 assert(resultIdx < op->getNumResults() && "illegal result index"); 81 82 // Dedup. 83 auto key = std::make_pair(op, resultIdx); 84 auto foundIt = resultAnchors.find(key); 85 if (foundIt != resultAnchors.end()) { 86 return foundIt->second; 87 } 88 89 // Create. 90 auto anchor = std::make_unique<CAGResultAnchor>(op, resultIdx); 91 auto *unowned = anchor.release(); 92 unowned->nodeId = allNodes.size(); 93 allNodes.push_back(unowned); 94 resultAnchors.insert(std::make_pair(key, unowned)); 95 return unowned; 96 } 97 98 void CAGSlice::enumerateImpliedConnections( 99 std::function<void(CAGAnchorNode *from, CAGAnchorNode *to)> callback) { 100 // Discover peer identity pairs (i.e. implied edges from Result->Operand and 101 // Arg->Call). Use an intermediate vector so that the callback can modify. 102 std::vector<std::pair<CAGAnchorNode *, CAGAnchorNode *>> impliedPairs; 103 for (auto &resultAnchorPair : resultAnchors) { 104 CAGResultAnchor *resultAnchor = resultAnchorPair.second; 105 Value *resultValue = resultAnchor->getValue(); 106 for (auto &use : resultValue->getUses()) { 107 Operation *operandOp = use.getOwner(); 108 unsigned operandIdx = use.getOperandNumber(); 109 auto foundIt = operandAnchors.find(std::make_pair(operandOp, operandIdx)); 110 if (foundIt != operandAnchors.end()) { 111 impliedPairs.push_back(std::make_pair(resultAnchor, foundIt->second)); 112 } 113 } 114 } 115 116 // Callback for each pair. 117 for (auto &impliedPair : impliedPairs) { 118 callback(impliedPair.first, impliedPair.second); 119 } 120 } 121 122 unsigned CAGSlice::propagate(const TargetConfiguration &config) { 123 std::vector<CAGNode *> dirtyNodes; 124 dirtyNodes.reserve(allNodes.size()); 125 // Note that because iteration happens in nodeId order, there is no need 126 // to sort in order to make deterministic. If the selection method changes, 127 // a sort should be explicitly done. 128 for (CAGNode *child : *this) { 129 if (child->isDirty()) { 130 dirtyNodes.push_back(child); 131 } 132 } 133 134 if (dirtyNodes.empty()) { 135 return 0; 136 } 137 for (auto dirtyNode : dirtyNodes) { 138 dirtyNode->clearDirty(); 139 dirtyNode->propagate(context, config); 140 } 141 142 return dirtyNodes.size(); 143 } 144 145 void CAGAnchorNode::propagate(SolverContext &solverContext, 146 const TargetConfiguration &config) { 147 for (CAGNode *child : *this) { 148 child->markDirty(); 149 } 150 } 151 152 Type CAGAnchorNode::getTransformedType() { 153 if (!getUniformMetadata().selectedType) { 154 return nullptr; 155 } 156 return getUniformMetadata().selectedType.castFromExpressedType( 157 getOriginalType()); 158 } 159 160 void CAGNode::printLabel(llvm::raw_ostream &os) const { 161 os << "Node<" << static_cast<const void *>(this) << ">"; 162 } 163 164 void CAGAnchorNode::printLabel(llvm::raw_ostream &os) const { 165 getUniformMetadata().printSummary(os); 166 } 167 168 void CAGOperandAnchor::printLabel(llvm::raw_ostream &os) const { 169 os << "Operand<"; 170 op->getName().print(os); 171 os << "," << operandIdx; 172 os << ">"; 173 CAGAnchorNode::printLabel(os); 174 } 175 176 void CAGResultAnchor::printLabel(llvm::raw_ostream &os) const { 177 os << "Result<"; 178 getOp()->getName().print(os); 179 os << ">"; 180 CAGAnchorNode::printLabel(os); 181 }