github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/TableGen/Predicate.cpp (about) 1 //===- Predicate.cpp - Predicate class ------------------------------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 // 18 // Wrapper around predicates defined in TableGen. 19 // 20 //===----------------------------------------------------------------------===// 21 22 #include "mlir/TableGen/Predicate.h" 23 #include "llvm/ADT/SetVector.h" 24 #include "llvm/ADT/SmallPtrSet.h" 25 #include "llvm/ADT/StringExtras.h" 26 #include "llvm/Support/FormatVariadic.h" 27 #include "llvm/TableGen/Error.h" 28 #include "llvm/TableGen/Record.h" 29 30 using namespace mlir; 31 32 // Construct a Predicate from a record. 33 tblgen::Pred::Pred(const llvm::Record *record) : def(record) { 34 assert(def->isSubClassOf("Pred") && 35 "must be a subclass of TableGen 'Pred' class"); 36 } 37 38 // Construct a Predicate from an initializer. 39 tblgen::Pred::Pred(const llvm::Init *init) : def(nullptr) { 40 if (const auto *defInit = dyn_cast_or_null<llvm::DefInit>(init)) 41 def = defInit->getDef(); 42 } 43 44 std::string tblgen::Pred::getCondition() const { 45 // Static dispatch to subclasses. 46 if (def->isSubClassOf("CombinedPred")) 47 return static_cast<const CombinedPred *>(this)->getConditionImpl(); 48 if (def->isSubClassOf("CPred")) 49 return static_cast<const CPred *>(this)->getConditionImpl(); 50 llvm_unreachable("Pred::getCondition must be overridden in subclasses"); 51 } 52 53 bool tblgen::Pred::isCombined() const { 54 return def && def->isSubClassOf("CombinedPred"); 55 } 56 57 ArrayRef<llvm::SMLoc> tblgen::Pred::getLoc() const { return def->getLoc(); } 58 59 tblgen::CPred::CPred(const llvm::Record *record) : Pred(record) { 60 assert(def->isSubClassOf("CPred") && 61 "must be a subclass of Tablegen 'CPred' class"); 62 } 63 64 tblgen::CPred::CPred(const llvm::Init *init) : Pred(init) { 65 assert((!def || def->isSubClassOf("CPred")) && 66 "must be a subclass of Tablegen 'CPred' class"); 67 } 68 69 // Get condition of the C Predicate. 70 std::string tblgen::CPred::getConditionImpl() const { 71 assert(!isNull() && "null predicate does not have a condition"); 72 return def->getValueAsString("predExpr"); 73 } 74 75 tblgen::CombinedPred::CombinedPred(const llvm::Record *record) : Pred(record) { 76 assert(def->isSubClassOf("CombinedPred") && 77 "must be a subclass of Tablegen 'CombinedPred' class"); 78 } 79 80 tblgen::CombinedPred::CombinedPred(const llvm::Init *init) : Pred(init) { 81 assert((!def || def->isSubClassOf("CombinedPred")) && 82 "must be a subclass of Tablegen 'CombinedPred' class"); 83 } 84 85 const llvm::Record *tblgen::CombinedPred::getCombinerDef() const { 86 assert(def->getValue("kind") && "CombinedPred must have a value 'kind'"); 87 return def->getValueAsDef("kind"); 88 } 89 90 const std::vector<llvm::Record *> tblgen::CombinedPred::getChildren() const { 91 assert(def->getValue("children") && 92 "CombinedPred must have a value 'children'"); 93 return def->getValueAsListOfDefs("children"); 94 } 95 96 namespace { 97 // Kinds of nodes in a logical predicate tree. 98 enum class PredCombinerKind { 99 Leaf, 100 And, 101 Or, 102 Not, 103 SubstLeaves, 104 Concat, 105 // Special kinds that are used in simplification. 106 False, 107 True 108 }; 109 110 // A node in a logical predicate tree. 111 struct PredNode { 112 PredCombinerKind kind; 113 const tblgen::Pred *predicate; 114 SmallVector<PredNode *, 4> children; 115 std::string expr; 116 117 // Prefix and suffix are used by ConcatPred. 118 std::string prefix; 119 std::string suffix; 120 }; 121 } // end anonymous namespace 122 123 // Get a predicate tree node kind based on the kind used in the predicate 124 // TableGen record. 125 static PredCombinerKind getPredCombinerKind(const tblgen::Pred &pred) { 126 if (!pred.isCombined()) 127 return PredCombinerKind::Leaf; 128 129 const auto &combinedPred = static_cast<const tblgen::CombinedPred &>(pred); 130 return llvm::StringSwitch<PredCombinerKind>( 131 combinedPred.getCombinerDef()->getName()) 132 .Case("PredCombinerAnd", PredCombinerKind::And) 133 .Case("PredCombinerOr", PredCombinerKind::Or) 134 .Case("PredCombinerNot", PredCombinerKind::Not) 135 .Case("PredCombinerSubstLeaves", PredCombinerKind::SubstLeaves) 136 .Case("PredCombinerConcat", PredCombinerKind::Concat); 137 } 138 139 namespace { 140 // Substitution<pattern, replacement>. 141 using Subst = std::pair<StringRef, StringRef>; 142 } // end anonymous namespace 143 144 // Build the predicate tree starting from the top-level predicate, which may 145 // have children, and perform leaf substitutions inplace. Note that after 146 // substitution, nodes are still pointing to the original TableGen record. 147 // All nodes are created within "allocator". 148 static PredNode *buildPredicateTree(const tblgen::Pred &root, 149 llvm::BumpPtrAllocator &allocator, 150 ArrayRef<Subst> substitutions) { 151 auto *rootNode = allocator.Allocate<PredNode>(); 152 new (rootNode) PredNode; 153 rootNode->kind = getPredCombinerKind(root); 154 rootNode->predicate = &root; 155 if (!root.isCombined()) { 156 rootNode->expr = root.getCondition(); 157 // Apply all parent substitutions from innermost to outermost. 158 for (const auto &subst : llvm::reverse(substitutions)) { 159 auto pos = rootNode->expr.find(subst.first); 160 while (pos != std::string::npos) { 161 rootNode->expr.replace(pos, subst.first.size(), subst.second); 162 // Skip the newly inserted substring, which itself may consider the 163 // pattern to match. 164 pos += subst.second.size(); 165 // Find the next possible match position. 166 pos = rootNode->expr.find(subst.first, pos); 167 } 168 } 169 return rootNode; 170 } 171 172 // If the current combined predicate is a leaf substitution, append it to the 173 // list before contiuing. 174 auto allSubstitutions = llvm::to_vector<4>(substitutions); 175 if (rootNode->kind == PredCombinerKind::SubstLeaves) { 176 const auto &substPred = static_cast<const tblgen::SubstLeavesPred &>(root); 177 allSubstitutions.push_back( 178 {substPred.getPattern(), substPred.getReplacement()}); 179 } 180 // If the current predicate is a ConcatPred, record the prefix and suffix. 181 else if (rootNode->kind == PredCombinerKind::Concat) { 182 const auto &concatPred = static_cast<const tblgen::ConcatPred &>(root); 183 rootNode->prefix = concatPred.getPrefix(); 184 rootNode->suffix = concatPred.getSuffix(); 185 } 186 187 // Build child subtrees. 188 auto combined = static_cast<const tblgen::CombinedPred &>(root); 189 for (const auto *record : combined.getChildren()) { 190 auto childTree = 191 buildPredicateTree(tblgen::Pred(record), allocator, allSubstitutions); 192 rootNode->children.push_back(childTree); 193 } 194 return rootNode; 195 } 196 197 // Simplify a predicate tree rooted at "node" using the predicates that are 198 // known to be true(false). For AND(OR) combined predicates, if any of the 199 // children is known to be false(true), the result is also false(true). 200 // Furthermore, for AND(OR) combined predicates, children that are known to be 201 // true(false) don't have to be checked dynamically. 202 static PredNode *propagateGroundTruth( 203 PredNode *node, const llvm::SmallPtrSetImpl<tblgen::Pred *> &knownTruePreds, 204 const llvm::SmallPtrSetImpl<tblgen::Pred *> &knownFalsePreds) { 205 // If the current predicate is known to be true or false, change the kind of 206 // the node and return immediately. 207 if (knownTruePreds.count(node->predicate) != 0) { 208 node->kind = PredCombinerKind::True; 209 node->children.clear(); 210 return node; 211 } 212 if (knownFalsePreds.count(node->predicate) != 0) { 213 node->kind = PredCombinerKind::False; 214 node->children.clear(); 215 return node; 216 } 217 218 // If the current node is a substitution, stop recursion now. 219 // The expressions in the leaves below this node were rewritten, but the nodes 220 // still point to the original predicate records. While the original 221 // predicate may be known to be true or false, it is not necessarily the case 222 // after rewriting. 223 // TODO(zinenko,jpienaar): we can support ground truth for rewritten 224 // predicates by either (a) having our own unique'ing of the predicates 225 // instead of relying on TableGen record pointers or (b) taking ground truth 226 // values optinally prefixed with a list of substitutions to apply, e.g. 227 // "predX is true by itself as well as predSubY leaf substitution had been 228 // applied to it". 229 if (node->kind == PredCombinerKind::SubstLeaves) { 230 return node; 231 } 232 233 // Otherwise, look at child nodes. 234 235 // Move child nodes into some local variable so that they can be optimized 236 // separately and re-added if necessary. 237 llvm::SmallVector<PredNode *, 4> children; 238 std::swap(node->children, children); 239 240 for (auto &child : children) { 241 // First, simplify the child. This maintains the predicate as it was. 242 auto simplifiedChild = 243 propagateGroundTruth(child, knownTruePreds, knownFalsePreds); 244 245 // Just add the child if we don't know how to simplify the current node. 246 if (node->kind != PredCombinerKind::And && 247 node->kind != PredCombinerKind::Or) { 248 node->children.push_back(simplifiedChild); 249 continue; 250 } 251 252 // Second, based on the type define which known values of child predicates 253 // immediately collapse this predicate to a known value, and which others 254 // may be safely ignored. 255 // OR(..., True, ...) = True 256 // OR(..., False, ...) = OR(..., ...) 257 // AND(..., False, ...) = False 258 // AND(..., True, ...) = AND(..., ...) 259 auto collapseKind = node->kind == PredCombinerKind::And 260 ? PredCombinerKind::False 261 : PredCombinerKind::True; 262 auto eraseKind = node->kind == PredCombinerKind::And 263 ? PredCombinerKind::True 264 : PredCombinerKind::False; 265 const auto &collapseList = 266 node->kind == PredCombinerKind::And ? knownFalsePreds : knownTruePreds; 267 const auto &eraseList = 268 node->kind == PredCombinerKind::And ? knownTruePreds : knownFalsePreds; 269 if (simplifiedChild->kind == collapseKind || 270 collapseList.count(simplifiedChild->predicate) != 0) { 271 node->kind = collapseKind; 272 node->children.clear(); 273 return node; 274 } else if (simplifiedChild->kind == eraseKind || 275 eraseList.count(simplifiedChild->predicate) != 0) { 276 continue; 277 } 278 node->children.push_back(simplifiedChild); 279 } 280 return node; 281 } 282 283 // Combine a list of predicate expressions using a binary combiner. If a list 284 // is empty, return "init". 285 static std::string combineBinary(ArrayRef<std::string> children, 286 std::string combiner, std::string init) { 287 if (children.empty()) 288 return init; 289 290 auto size = children.size(); 291 if (size == 1) 292 return children.front(); 293 294 std::string str; 295 llvm::raw_string_ostream os(str); 296 os << '(' << children.front() << ')'; 297 for (unsigned i = 1; i < size; ++i) { 298 os << ' ' << combiner << " (" << children[i] << ')'; 299 } 300 return os.str(); 301 } 302 303 // Prepend negation to the only condition in the predicate expression list. 304 static std::string combineNot(ArrayRef<std::string> children) { 305 assert(children.size() == 1 && "expected exactly one child predicate of Neg"); 306 return (Twine("!(") + children.front() + Twine(')')).str(); 307 } 308 309 // Recursively traverse the predicate tree in depth-first post-order and build 310 // the final expression. 311 static std::string getCombinedCondition(const PredNode &root) { 312 // Immediately return for non-combiner predicates that don't have children. 313 if (root.kind == PredCombinerKind::Leaf) 314 return root.expr; 315 if (root.kind == PredCombinerKind::True) 316 return "true"; 317 if (root.kind == PredCombinerKind::False) 318 return "false"; 319 320 // Recurse into children. 321 llvm::SmallVector<std::string, 4> childExpressions; 322 childExpressions.reserve(root.children.size()); 323 for (const auto &child : root.children) 324 childExpressions.push_back(getCombinedCondition(*child)); 325 326 // Combine the expressions based on the predicate node kind. 327 if (root.kind == PredCombinerKind::And) 328 return combineBinary(childExpressions, "&&", "true"); 329 if (root.kind == PredCombinerKind::Or) 330 return combineBinary(childExpressions, "||", "false"); 331 if (root.kind == PredCombinerKind::Not) 332 return combineNot(childExpressions); 333 if (root.kind == PredCombinerKind::Concat) { 334 assert(childExpressions.size() == 1 && 335 "ConcatPred should only have one child"); 336 return root.prefix + childExpressions.front() + root.suffix; 337 } 338 339 // Substitutions were applied before so just ignore them. 340 if (root.kind == PredCombinerKind::SubstLeaves) { 341 assert(childExpressions.size() == 1 && 342 "substitution predicate must have one child"); 343 return childExpressions[0]; 344 } 345 346 llvm::PrintFatalError(root.predicate->getLoc(), "unsupported predicate kind"); 347 } 348 349 std::string tblgen::CombinedPred::getConditionImpl() const { 350 llvm::BumpPtrAllocator allocator; 351 auto predicateTree = buildPredicateTree(*this, allocator, {}); 352 predicateTree = propagateGroundTruth( 353 predicateTree, 354 /*knownTruePreds=*/llvm::SmallPtrSet<tblgen::Pred *, 2>(), 355 /*knownFalsePreds=*/llvm::SmallPtrSet<tblgen::Pred *, 2>()); 356 357 return getCombinedCondition(*predicateTree); 358 } 359 360 StringRef tblgen::SubstLeavesPred::getPattern() const { 361 return def->getValueAsString("pattern"); 362 } 363 364 StringRef tblgen::SubstLeavesPred::getReplacement() const { 365 return def->getValueAsString("replacement"); 366 } 367 368 StringRef tblgen::ConcatPred::getPrefix() const { 369 return def->getValueAsString("prefix"); 370 } 371 372 StringRef tblgen::ConcatPred::getSuffix() const { 373 return def->getValueAsString("suffix"); 374 }