github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/TableGen/Pattern.cpp (about) 1 //===- Pattern.cpp - Pattern wrapper class --------------------------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 // 18 // Pattern wrapper class to simplify using TableGen Record defining a MLIR 19 // Pattern. 20 // 21 //===----------------------------------------------------------------------===// 22 23 #include "mlir/TableGen/Pattern.h" 24 #include "llvm/ADT/Twine.h" 25 #include "llvm/Support/FormatVariadic.h" 26 #include "llvm/TableGen/Error.h" 27 #include "llvm/TableGen/Record.h" 28 29 using namespace mlir; 30 31 using llvm::formatv; 32 using mlir::tblgen::Operator; 33 34 //===----------------------------------------------------------------------===// 35 // DagLeaf 36 //===----------------------------------------------------------------------===// 37 38 bool tblgen::DagLeaf::isUnspecified() const { 39 return dyn_cast_or_null<llvm::UnsetInit>(def); 40 } 41 42 bool tblgen::DagLeaf::isOperandMatcher() const { 43 // Operand matchers specify a type constraint. 44 return isSubClassOf("TypeConstraint"); 45 } 46 47 bool tblgen::DagLeaf::isAttrMatcher() const { 48 // Attribute matchers specify an attribute constraint. 49 return isSubClassOf("AttrConstraint"); 50 } 51 52 bool tblgen::DagLeaf::isNativeCodeCall() const { 53 return isSubClassOf("NativeCodeCall"); 54 } 55 56 bool tblgen::DagLeaf::isConstantAttr() const { 57 return isSubClassOf("ConstantAttr"); 58 } 59 60 bool tblgen::DagLeaf::isEnumAttrCase() const { 61 return isSubClassOf("EnumAttrCaseInfo"); 62 } 63 64 tblgen::Constraint tblgen::DagLeaf::getAsConstraint() const { 65 assert((isOperandMatcher() || isAttrMatcher()) && 66 "the DAG leaf must be operand or attribute"); 67 return Constraint(cast<llvm::DefInit>(def)->getDef()); 68 } 69 70 tblgen::ConstantAttr tblgen::DagLeaf::getAsConstantAttr() const { 71 assert(isConstantAttr() && "the DAG leaf must be constant attribute"); 72 return ConstantAttr(cast<llvm::DefInit>(def)); 73 } 74 75 tblgen::EnumAttrCase tblgen::DagLeaf::getAsEnumAttrCase() const { 76 assert(isEnumAttrCase() && "the DAG leaf must be an enum attribute case"); 77 return EnumAttrCase(cast<llvm::DefInit>(def)); 78 } 79 80 std::string tblgen::DagLeaf::getConditionTemplate() const { 81 return getAsConstraint().getConditionTemplate(); 82 } 83 84 llvm::StringRef tblgen::DagLeaf::getNativeCodeTemplate() const { 85 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall"); 86 return cast<llvm::DefInit>(def)->getDef()->getValueAsString("expression"); 87 } 88 89 bool tblgen::DagLeaf::isSubClassOf(StringRef superclass) const { 90 if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(def)) 91 return defInit->getDef()->isSubClassOf(superclass); 92 return false; 93 } 94 95 //===----------------------------------------------------------------------===// 96 // DagNode 97 //===----------------------------------------------------------------------===// 98 99 bool tblgen::DagNode::isNativeCodeCall() const { 100 if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(node->getOperator())) 101 return defInit->getDef()->isSubClassOf("NativeCodeCall"); 102 return false; 103 } 104 105 bool tblgen::DagNode::isOperation() const { 106 return !(isNativeCodeCall() || isReplaceWithValue()); 107 } 108 109 llvm::StringRef tblgen::DagNode::getNativeCodeTemplate() const { 110 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall"); 111 return cast<llvm::DefInit>(node->getOperator()) 112 ->getDef() 113 ->getValueAsString("expression"); 114 } 115 116 llvm::StringRef tblgen::DagNode::getSymbol() const { 117 return node->getNameStr(); 118 } 119 120 Operator &tblgen::DagNode::getDialectOp(RecordOperatorMap *mapper) const { 121 llvm::Record *opDef = cast<llvm::DefInit>(node->getOperator())->getDef(); 122 auto it = mapper->find(opDef); 123 if (it != mapper->end()) 124 return *it->second; 125 return *mapper->try_emplace(opDef, std::make_unique<Operator>(opDef)) 126 .first->second; 127 } 128 129 int tblgen::DagNode::getNumOps() const { 130 int count = isReplaceWithValue() ? 0 : 1; 131 for (int i = 0, e = getNumArgs(); i != e; ++i) { 132 if (auto child = getArgAsNestedDag(i)) 133 count += child.getNumOps(); 134 } 135 return count; 136 } 137 138 int tblgen::DagNode::getNumArgs() const { return node->getNumArgs(); } 139 140 bool tblgen::DagNode::isNestedDagArg(unsigned index) const { 141 return isa<llvm::DagInit>(node->getArg(index)); 142 } 143 144 tblgen::DagNode tblgen::DagNode::getArgAsNestedDag(unsigned index) const { 145 return DagNode(dyn_cast_or_null<llvm::DagInit>(node->getArg(index))); 146 } 147 148 tblgen::DagLeaf tblgen::DagNode::getArgAsLeaf(unsigned index) const { 149 assert(!isNestedDagArg(index)); 150 return DagLeaf(node->getArg(index)); 151 } 152 153 StringRef tblgen::DagNode::getArgName(unsigned index) const { 154 return node->getArgNameStr(index); 155 } 156 157 bool tblgen::DagNode::isReplaceWithValue() const { 158 auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef(); 159 return dagOpDef->getName() == "replaceWithValue"; 160 } 161 162 //===----------------------------------------------------------------------===// 163 // SymbolInfoMap 164 //===----------------------------------------------------------------------===// 165 166 StringRef tblgen::SymbolInfoMap::getValuePackName(StringRef symbol, 167 int *index) { 168 StringRef name, indexStr; 169 int idx = -1; 170 std::tie(name, indexStr) = symbol.rsplit("__"); 171 172 if (indexStr.consumeInteger(10, idx)) { 173 // The second part is not an index; we return the whole symbol as-is. 174 return symbol; 175 } 176 if (index) { 177 *index = idx; 178 } 179 return name; 180 } 181 182 tblgen::SymbolInfoMap::SymbolInfo::SymbolInfo(const Operator *op, 183 SymbolInfo::Kind kind, 184 Optional<int> index) 185 : op(op), kind(kind), argIndex(index) {} 186 187 int tblgen::SymbolInfoMap::SymbolInfo::getStaticValueCount() const { 188 switch (kind) { 189 case Kind::Attr: 190 case Kind::Operand: 191 case Kind::Value: 192 return 1; 193 case Kind::Result: 194 return op->getNumResults(); 195 } 196 llvm_unreachable("unknown kind"); 197 } 198 199 std::string 200 tblgen::SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const { 201 switch (kind) { 202 case Kind::Attr: { 203 auto type = 204 op->getArg(*argIndex).get<NamedAttribute *>()->attr.getStorageType(); 205 return formatv("{0} {1};\n", type, name); 206 } 207 case Kind::Operand: { 208 // Use operand range for captured operands (to support potential variadic 209 // operands). 210 return formatv("Operation::operand_range {0}(op0->getOperands());\n", name); 211 } 212 case Kind::Value: { 213 return formatv("ArrayRef<Value *> {0};\n", name); 214 } 215 case Kind::Result: { 216 // Use the op itself for captured results. 217 return formatv("{0} {1};\n", op->getQualCppClassName(), name); 218 } 219 } 220 llvm_unreachable("unknown kind"); 221 } 222 223 std::string tblgen::SymbolInfoMap::SymbolInfo::getValueAndRangeUse( 224 StringRef name, int index, const char *fmt, const char *separator) const { 225 switch (kind) { 226 case Kind::Attr: { 227 assert(index < 0); 228 return formatv(fmt, name); 229 } 230 case Kind::Operand: { 231 assert(index < 0); 232 auto *operand = op->getArg(*argIndex).get<NamedTypeConstraint *>(); 233 // If this operand is variadic, then return a range. Otherwise, return the 234 // value itself. 235 if (operand->isVariadic()) { 236 return formatv(fmt, name); 237 } 238 return formatv(fmt, formatv("(*{0}.begin())", name)); 239 } 240 case Kind::Result: { 241 // If `index` is greater than zero, then we are referencing a specific 242 // result of a multi-result op. The result can still be variadic. 243 if (index >= 0) { 244 std::string v = formatv("{0}.getODSResults({1})", name, index); 245 if (!op->getResult(index).isVariadic()) 246 v = formatv("(*{0}.begin())", v); 247 return formatv(fmt, v); 248 } 249 250 // We are referencing all results of the multi-result op. A specific result 251 // can either be a value or a range. Then join them with `separator`. 252 SmallVector<std::string, 4> values; 253 values.reserve(op->getNumResults()); 254 255 for (int i = 0, e = op->getNumResults(); i < e; ++i) { 256 std::string v = formatv("{0}.getODSResults({1})", name, i); 257 if (!op->getResult(i).isVariadic()) { 258 v = formatv("(*{0}.begin())", v); 259 } 260 values.push_back(formatv(fmt, v)); 261 } 262 return llvm::join(values, separator); 263 } 264 case Kind::Value: { 265 assert(index < 0); 266 assert(op == nullptr); 267 return formatv(fmt, name); 268 } 269 } 270 } 271 272 std::string tblgen::SymbolInfoMap::SymbolInfo::getAllRangeUse( 273 StringRef name, int index, const char *fmt, const char *separator) const { 274 switch (kind) { 275 case Kind::Attr: 276 case Kind::Operand: { 277 assert(index < 0 && "only allowed for symbol bound to result"); 278 return formatv(fmt, name); 279 } 280 case Kind::Result: { 281 if (index >= 0) { 282 return formatv(fmt, formatv("{0}.getODSResults({1})", name, index)); 283 } 284 285 // We are referencing all results of the multi-result op. Each result should 286 // have a value range, and then join them with `separator`. 287 SmallVector<std::string, 4> values; 288 values.reserve(op->getNumResults()); 289 290 for (int i = 0, e = op->getNumResults(); i < e; ++i) { 291 values.push_back( 292 formatv(fmt, formatv("{0}.getODSResults({1})", name, i))); 293 } 294 return llvm::join(values, separator); 295 } 296 case Kind::Value: { 297 assert(index < 0 && "only allowed for symbol bound to result"); 298 assert(op == nullptr); 299 return formatv(fmt, formatv("{{{0}}", name)); 300 } 301 } 302 llvm_unreachable("unknown kind"); 303 } 304 305 bool tblgen::SymbolInfoMap::bindOpArgument(StringRef symbol, const Operator &op, 306 int argIndex) { 307 StringRef name = getValuePackName(symbol); 308 if (name != symbol) { 309 auto error = formatv( 310 "symbol '{0}' with trailing index cannot bind to op argument", symbol); 311 PrintFatalError(loc, error); 312 } 313 314 auto symInfo = op.getArg(argIndex).is<NamedAttribute *>() 315 ? SymbolInfo::getAttr(&op, argIndex) 316 : SymbolInfo::getOperand(&op, argIndex); 317 318 return symbolInfoMap.insert({symbol, symInfo}).second; 319 } 320 321 bool tblgen::SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) { 322 StringRef name = getValuePackName(symbol); 323 return symbolInfoMap.insert({name, SymbolInfo::getResult(&op)}).second; 324 } 325 326 bool tblgen::SymbolInfoMap::bindValue(StringRef symbol) { 327 return symbolInfoMap.insert({symbol, SymbolInfo::getValue()}).second; 328 } 329 330 bool tblgen::SymbolInfoMap::contains(StringRef symbol) const { 331 return find(symbol) != symbolInfoMap.end(); 332 } 333 334 tblgen::SymbolInfoMap::const_iterator 335 tblgen::SymbolInfoMap::find(StringRef key) const { 336 StringRef name = getValuePackName(key); 337 return symbolInfoMap.find(name); 338 } 339 340 int tblgen::SymbolInfoMap::getStaticValueCount(StringRef symbol) const { 341 StringRef name = getValuePackName(symbol); 342 if (name != symbol) { 343 // If there is a trailing index inside symbol, it references just one 344 // static value. 345 return 1; 346 } 347 // Otherwise, find how many it represents by querying the symbol's info. 348 return find(name)->getValue().getStaticValueCount(); 349 } 350 351 std::string 352 tblgen::SymbolInfoMap::getValueAndRangeUse(StringRef symbol, const char *fmt, 353 const char *separator) const { 354 int index = -1; 355 StringRef name = getValuePackName(symbol, &index); 356 357 auto it = symbolInfoMap.find(name); 358 if (it == symbolInfoMap.end()) { 359 auto error = formatv("referencing unbound symbol '{0}'", symbol); 360 PrintFatalError(loc, error); 361 } 362 363 return it->getValue().getValueAndRangeUse(name, index, fmt, separator); 364 } 365 366 std::string tblgen::SymbolInfoMap::getAllRangeUse(StringRef symbol, 367 const char *fmt, 368 const char *separator) const { 369 int index = -1; 370 StringRef name = getValuePackName(symbol, &index); 371 372 auto it = symbolInfoMap.find(name); 373 if (it == symbolInfoMap.end()) { 374 auto error = formatv("referencing unbound symbol '{0}'", symbol); 375 PrintFatalError(loc, error); 376 } 377 378 return it->getValue().getAllRangeUse(name, index, fmt, separator); 379 } 380 381 //===----------------------------------------------------------------------===// 382 // Pattern 383 //==----------------------------------------------------------------------===// 384 385 tblgen::Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper) 386 : def(*def), recordOpMap(mapper) {} 387 388 tblgen::DagNode tblgen::Pattern::getSourcePattern() const { 389 return tblgen::DagNode(def.getValueAsDag("sourcePattern")); 390 } 391 392 int tblgen::Pattern::getNumResultPatterns() const { 393 auto *results = def.getValueAsListInit("resultPatterns"); 394 return results->size(); 395 } 396 397 tblgen::DagNode tblgen::Pattern::getResultPattern(unsigned index) const { 398 auto *results = def.getValueAsListInit("resultPatterns"); 399 return tblgen::DagNode(cast<llvm::DagInit>(results->getElement(index))); 400 } 401 402 void tblgen::Pattern::collectSourcePatternBoundSymbols( 403 tblgen::SymbolInfoMap &infoMap) { 404 collectBoundSymbols(getSourcePattern(), infoMap, /*isSrcPattern=*/true); 405 } 406 407 void tblgen::Pattern::collectResultPatternBoundSymbols( 408 tblgen::SymbolInfoMap &infoMap) { 409 for (int i = 0, e = getNumResultPatterns(); i < e; ++i) { 410 auto pattern = getResultPattern(i); 411 collectBoundSymbols(pattern, infoMap, /*isSrcPattern=*/false); 412 } 413 } 414 415 const tblgen::Operator &tblgen::Pattern::getSourceRootOp() { 416 return getSourcePattern().getDialectOp(recordOpMap); 417 } 418 419 tblgen::Operator &tblgen::Pattern::getDialectOp(DagNode node) { 420 return node.getDialectOp(recordOpMap); 421 } 422 423 std::vector<tblgen::AppliedConstraint> tblgen::Pattern::getConstraints() const { 424 auto *listInit = def.getValueAsListInit("constraints"); 425 std::vector<tblgen::AppliedConstraint> ret; 426 ret.reserve(listInit->size()); 427 428 for (auto it : *listInit) { 429 auto *dagInit = dyn_cast<llvm::DagInit>(it); 430 if (!dagInit) 431 PrintFatalError(def.getLoc(), "all elemements in Pattern multi-entity " 432 "constraints should be DAG nodes"); 433 434 std::vector<std::string> entities; 435 entities.reserve(dagInit->arg_size()); 436 for (auto *argName : dagInit->getArgNames()) 437 entities.push_back(argName->getValue()); 438 439 ret.emplace_back(cast<llvm::DefInit>(dagInit->getOperator())->getDef(), 440 dagInit->getNameStr(), std::move(entities)); 441 } 442 return ret; 443 } 444 445 int tblgen::Pattern::getBenefit() const { 446 // The initial benefit value is a heuristic with number of ops in the source 447 // pattern. 448 int initBenefit = getSourcePattern().getNumOps(); 449 llvm::DagInit *delta = def.getValueAsDag("benefitDelta"); 450 if (delta->getNumArgs() != 1 || !isa<llvm::IntInit>(delta->getArg(0))) { 451 PrintFatalError(def.getLoc(), 452 "The 'addBenefit' takes and only takes one integer value"); 453 } 454 return initBenefit + dyn_cast<llvm::IntInit>(delta->getArg(0))->getValue(); 455 } 456 457 std::vector<tblgen::Pattern::IdentifierLine> 458 tblgen::Pattern::getLocation() const { 459 std::vector<std::pair<StringRef, unsigned>> result; 460 result.reserve(def.getLoc().size()); 461 for (auto loc : def.getLoc()) { 462 unsigned buf = llvm::SrcMgr.FindBufferContainingLoc(loc); 463 assert(buf && "invalid source location"); 464 result.emplace_back( 465 llvm::SrcMgr.getBufferInfo(buf).Buffer->getBufferIdentifier(), 466 llvm::SrcMgr.getLineAndColumn(loc, buf).first); 467 } 468 return result; 469 } 470 471 void tblgen::Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap, 472 bool isSrcPattern) { 473 auto treeName = tree.getSymbol(); 474 if (!tree.isOperation()) { 475 if (!treeName.empty()) { 476 PrintFatalError( 477 def.getLoc(), 478 formatv("binding symbol '{0}' to non-operation unsupported right now", 479 treeName)); 480 } 481 return; 482 } 483 484 auto &op = getDialectOp(tree); 485 auto numOpArgs = op.getNumArgs(); 486 auto numTreeArgs = tree.getNumArgs(); 487 488 if (numOpArgs != numTreeArgs) { 489 auto err = formatv("op '{0}' argument number mismatch: " 490 "{1} in pattern vs. {2} in definition", 491 op.getOperationName(), numTreeArgs, numOpArgs); 492 PrintFatalError(def.getLoc(), err); 493 } 494 495 // The name attached to the DAG node's operator is for representing the 496 // results generated from this op. It should be remembered as bound results. 497 if (!treeName.empty()) { 498 if (!infoMap.bindOpResult(treeName, op)) 499 PrintFatalError(def.getLoc(), 500 formatv("symbol '{0}' bound more than once", treeName)); 501 } 502 503 for (int i = 0; i != numTreeArgs; ++i) { 504 if (auto treeArg = tree.getArgAsNestedDag(i)) { 505 // This DAG node argument is a DAG node itself. Go inside recursively. 506 collectBoundSymbols(treeArg, infoMap, isSrcPattern); 507 } else if (isSrcPattern) { 508 // We can only bind symbols to op arguments in source pattern. Those 509 // symbols are referenced in result patterns. 510 auto treeArgName = tree.getArgName(i); 511 if (!treeArgName.empty()) { 512 if (!infoMap.bindOpArgument(treeArgName, op, i)) { 513 auto err = formatv("symbol '{0}' bound more than once", treeArgName); 514 PrintFatalError(def.getLoc(), err); 515 } 516 } 517 } 518 } 519 }