// Source: github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/tools/mlir-tblgen/RewriterGen.cpp
//===- RewriterGen.cpp - MLIR pattern rewriter generator ------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// RewriterGen uses pattern rewrite definitions to generate rewriter matchers.
//
//===----------------------------------------------------------------------===//

#include "mlir/Support/STLExtras.h"
#include "mlir/TableGen/Attribute.h"
#include "mlir/TableGen/Format.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Operator.h"
#include "mlir/TableGen/Pattern.h"
#include "mlir/TableGen/Predicate.h"
#include "mlir/TableGen/Type.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatAdapters.h"
#include "llvm/Support/PrettyStackTrace.h"
#include "llvm/Support/Signals.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Main.h"
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"

using namespace llvm;
using namespace mlir;
using namespace mlir::tblgen;

namespace llvm {
// Lets formatv() print a Pattern::IdentifierLine as "<first>:<second>". This
// is what renders each entry of the "Generated from:" comment written by
// PatternEmitter::emit().
template <> struct format_provider<mlir::tblgen::Pattern::IdentifierLine> {
  static void format(const mlir::tblgen::Pattern::IdentifierLine &v,
                     raw_ostream &os, StringRef style) {
    os << v.first << ":" << v.second;
  }
};
} // end namespace llvm

//===----------------------------------------------------------------------===//
// PatternEmitter
//===----------------------------------------------------------------------===//

namespace {
// Emits the C++ source of one RewritePattern struct for one TableGen `Pattern`
// record. An instance is built per record (see emitRewriters()) and writes
// everything to the `raw_ostream` it was constructed with. Malformed patterns
// are reported through PrintFatalError().
class PatternEmitter {
public:
  PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os);

  // Emits the mlir::RewritePattern struct named `rewriteName`.
  void emit(StringRef rewriteName);

private:
  // Emits the code for matching ops: the nested op match for `tree` followed
  // by the checks for the pattern's additional constraints.
  void emitMatchLogic(DagNode tree);

  // Emits the code for rewriting ops: creating the result values and replacing
  // the matched root op with them.
  void emitRewriteLogic();

  //===--------------------------------------------------------------------===//
  // Match utilities
  //===--------------------------------------------------------------------===//

  // Emits C++ statements for matching the op constrained by the given DAG
  // `tree`. `depth` is the nesting level of `tree` in the source pattern; it
  // determines both the emitted indentation and the `op<depth>` /
  // `castedOp<depth>` variable names.
  void emitOpMatch(DagNode tree, int depth);

  // Emits C++ statements for matching the `index`-th argument of the given DAG
  // `tree` as an operand.
  void emitOperandMatch(DagNode tree, int index, int depth, int indent);

  // Emits C++ statements for matching the `index`-th argument of the given DAG
  // `tree` as an attribute.
  void emitAttributeMatch(DagNode tree, int index, int depth, int indent);

  //===--------------------------------------------------------------------===//
  // Rewrite utilities
  //===--------------------------------------------------------------------===//

  // The entry point for handling a result pattern rooted at `resultTree`. This
  // method dispatches to concrete handlers according to `resultTree`'s kind and
  // returns a symbol representing the whole value pack. Callers are expected to
  // further resolve the symbol according to the specific use case.
  //
  // `depth` is the nesting level of `resultTree`; 0 means top-level result
  // pattern. For top-level result pattern, `resultIndex` indicates which result
  // of the matched root op this pattern is intended to replace, which can be
  // used to deduce the result type of the op generated from this result
  // pattern.
  std::string handleResultPattern(DagNode resultTree, int resultIndex,
                                  int depth);

  // Returns the C++ expression, built from the native code template of
  // `resultTree`, that computes the replacement value via native C++ code.
  std::string handleReplaceWithNativeCodeCall(DagNode resultTree);

  // Returns the C++ expression referencing the old value serving as the
  // replacement.
  std::string handleReplaceWithValue(DagNode tree);

  // Emits the C++ statement to build a new op out of the given DAG `tree` and
  // returns the variable name that this op is assigned to. If the root op in
  // DAG `tree` has a specified name, the created op will be assigned to a
  // variable of the given name. Otherwise, a unique name will be used as the
  // result value name.
  std::string handleOpCreation(DagNode tree, int resultIndex, int depth);

  // Returns the C++ expression to construct a constant attribute of the given
  // `value` for the given attribute kind `attr`.
  std::string handleConstantAttr(Attribute attr, StringRef value);

  // Returns the C++ expression to build an argument from the given DAG `leaf`.
  // `patArgName` is used to bind the argument to the source pattern.
  std::string handleOpArgument(DagLeaf leaf, StringRef patArgName);

  //===--------------------------------------------------------------------===//
  // General utilities
  //===--------------------------------------------------------------------===//

  // Collects all of the operations within the given dag tree.
  void collectOps(DagNode tree, llvm::SmallPtrSetImpl<const Operator *> &ops);

  // Returns a unique symbol for a local variable of the given `op`.
  std::string getUniqueSymbol(const Operator *op);

  //===--------------------------------------------------------------------===//
  // Symbol utilities
  //===--------------------------------------------------------------------===//

  // Returns how many static values the given DAG `node` correspond to.
  int getNodeValueCount(DagNode node);

private:
  // Pattern instantiation location followed by the location of multiclass
  // prototypes used. This is intended to be used as a whole to
  // PrintFatalError() on errors.
  ArrayRef<llvm::SMLoc> loc;

  // Op's TableGen Record to wrapper object.
  RecordOperatorMap *opMap;

  // Handy wrapper for pattern being emitted.
  Pattern pattern;

  // Map for all bound symbols' info.
  SymbolInfoMap symbolInfoMap;

  // The next unused ID for newly created values.
  unsigned nextValueId;

  // Stream that all generated C++ code is written to.
  raw_ostream &os;

  // Format contexts containing placeholder substitutations.
  FmtContext fmtCtx;

  // Number of ops processed. Pre-incremented by emitOpMatch() to pick the
  // `tblgen_ops` slot of each nested matched op; slot 0 is the root op.
  int opCounter = 0;
};
} // end anonymous namespace

// Wraps the record `pat` and binds the builder placeholder of the format
// context to the `rewriter` variable that appears in the generated
// matchAndRewrite().
PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper,
                               raw_ostream &os)
    : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper),
      symbolInfoMap(pat->getLoc()), nextValueId(0), os(os) {
  fmtCtx.withBuilder("rewriter");
}

// Expands the constant builder template of `attr` with `value`. Attributes
// that are not const-buildable are a fatal error.
std::string PatternEmitter::handleConstantAttr(Attribute attr,
                                               StringRef value) {
  if (!attr.isConstBuildable())
    PrintFatalError(loc, "Attribute " + attr.getAttrDefName() +
                             " does not have the 'constBuilderCall' field");

  // TODO(jpienaar): Verify the constants here
  return tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, value);
}

// Helper function to match patterns. Emits the cast of `op<depth>` to the
// concrete op class, then walks the arguments of `tree`: nested DAGs recurse
// with `depth + 1`, leaves are dispatched to the operand/attribute matchers.
void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
  Operator &op = tree.getDialectOp(opMap);

  // Generated statements start at column 4 and shift by 2 per nesting level.
  int indent = 4 + 2 * depth;
  os.indent(indent) << formatv(
      "auto castedOp{0} = dyn_cast_or_null<{1}>(op{0}); (void)castedOp{0};\n",
      depth, op.getQualCppClassName());
  // Skip the null check at depth 0 as the pattern rewriter already does the
  // matching of the root op.
  if (depth != 0) {
    // Skip if there is no defining operation (e.g., arguments to function).
    os.indent(indent) << formatv("if (!castedOp{0}) return matchFailure();\n",
                                 depth);
  }
  if (tree.getNumArgs() != op.getNumArgs()) {
    PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in "
                                 "pattern vs. {2} in definition",
                                 op.getOperationName(), tree.getNumArgs(),
                                 op.getNumArgs()));
  }

  // If the op is bound to a symbol, assign the casted op to that variable.
  auto name = tree.getSymbol();
  if (!name.empty())
    os.indent(indent) << formatv("{0} = castedOp{1};\n", name, depth);

  for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
    auto opArg = op.getArg(i);

    // Handle nested DAG construct first
    if (DagNode argTree = tree.getArgAsNestedDag(i)) {
      if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) {
        if (operand->isVariadic()) {
          auto error = formatv("use nested DAG construct to match op {0}'s "
                               "variadic operand #{1} unsupported now",
                               op.getOperationName(), i);
          PrintFatalError(loc, error);
        }
      }
      // Open a new scope in the generated code for the nested op's variables.
      os.indent(indent) << "{\n";

      os.indent(indent + 2) << formatv(
          "auto *op{0} = "
          "(*castedOp{1}.getODSOperands({2}).begin())->getDefiningOp();\n",
          depth + 1, depth, i);
      emitOpMatch(argTree, depth + 1);
      // Record the nested op only after its own match code has been emitted.
      os.indent(indent + 2)
          << formatv("tblgen_ops[{0}] = op{1};\n", ++opCounter, depth + 1);
      os.indent(indent) << "}\n";
      continue;
    }

    // Next handle DAG leaf: operand or attribute
    if (opArg.is<NamedTypeConstraint *>()) {
      emitOperandMatch(tree, i, depth, indent);
    } else if (opArg.is<NamedAttribute *>()) {
      emitAttributeMatch(tree, i, depth, indent);
    } else {
      PrintFatalError(loc, "unhandled case when matching op");
    }
  }
}

// Emits the match code for the `index`-th argument of `tree` treated as an
// operand: an optional type-constraint check, then the capture of the operand
// range into the variable named by the pattern (if any).
void PatternEmitter::emitOperandMatch(DagNode tree, int index, int depth,
                                      int indent) {
  Operator &op = tree.getDialectOp(opMap);
  auto *operand = op.getArg(index).get<NamedTypeConstraint *>();
  auto matcher = tree.getArgAsLeaf(index);

  // If a constraint is specified, we need to generate C++ statements to
  // check the constraint.
  if (!matcher.isUnspecified()) {
    if (!matcher.isOperandMatcher()) {
      PrintFatalError(
          loc, formatv("the {1}-th argument of op '{0}' should be an operand",
                       op.getOperationName(), index + 1));
    }

    // Only need to verify if the matcher's type is different from the one
    // of op definition.
    if (operand->constraint != matcher.getAsConstraint()) {
      if (operand->isVariadic()) {
        auto error = formatv(
            "further constrain op {0}'s variadic operand #{1} unsupported now",
            op.getOperationName(), index);
        PrintFatalError(loc, error);
      }
      // The constraint is evaluated against the type of the first value of
      // the operand's range.
      auto self =
          formatv("(*castedOp{0}.getODSOperands({1}).begin())->getType()",
                  depth, index);
      os.indent(indent) << "if (!("
                        << tgfmt(matcher.getConditionTemplate(),
                                 &fmtCtx.withSelf(self))
                        << ")) return matchFailure();\n";
    }
  }

  // Capture the value
  auto name = tree.getArgName(index);
  if (!name.empty()) {
    os.indent(indent) << formatv("{0} = castedOp{1}.getODSOperands({2});\n",
                                 name, depth, index);
  }
}

// Emits the match code for the `index`-th argument of `tree` treated as an
// attribute. The generated code lives in its own `{ ... }` scope so that the
// `tblgen_attr` local can be redeclared for every attribute.
void PatternEmitter::emitAttributeMatch(DagNode tree, int index, int depth,
                                        int indent) {
  Operator &op = tree.getDialectOp(opMap);
  auto *namedAttr = op.getArg(index).get<NamedAttribute *>();
  const auto &attr = namedAttr->attr;

  os.indent(indent) << "{\n";
  indent += 2;
  os.indent(indent) << formatv(
      "auto tblgen_attr = op{0}->getAttrOfType<{1}>(\"{2}\");\n", depth,
      attr.getStorageType(), namedAttr->name);

  // TODO(antiagainst): This should use getter method to avoid duplication.
  if (attr.hasDefaultValueInitializer()) {
    // A missing attribute with a default value is materialized from that
    // default instead of failing the match.
    os.indent(indent) << "if (!tblgen_attr) tblgen_attr = "
                      << tgfmt(attr.getConstBuilderTemplate(), &fmtCtx,
                               attr.getDefaultValueInitializer())
                      << ";\n";
  } else if (attr.isOptional()) {
    // For a missing attribute that is optional according to definition, we
    // should just capture a mlir::Attribute() to signal the missing state.
    // That is precisely what getAttr() returns on missing attributes.
  } else {
    os.indent(indent) << "if (!tblgen_attr) return matchFailure();\n";
  }

  auto matcher = tree.getArgAsLeaf(index);
  if (!matcher.isUnspecified()) {
    if (!matcher.isAttrMatcher()) {
      PrintFatalError(
          loc, formatv("the {1}-th argument of op '{0}' should be an attribute",
                       op.getOperationName(), index + 1));
    }

    // If a constraint is specified, we need to generate C++ statements to
    // check the constraint.
    os.indent(indent) << "if (!("
                      << tgfmt(matcher.getConditionTemplate(),
                               &fmtCtx.withSelf("tblgen_attr"))
                      << ")) return matchFailure();\n";
  }

  // Capture the value
  auto name = tree.getArgName(index);
  if (!name.empty()) {
    os.indent(indent) << formatv("{0} = tblgen_attr;\n", name);
  }

  indent -= 2;
  os.indent(indent) << "}\n";
}

// Emits the match of the source pattern `tree` followed by one
// `if (!(...)) return matchFailure();` statement per additional constraint
// attached to the pattern.
void PatternEmitter::emitMatchLogic(DagNode tree) {
  emitOpMatch(tree, 0);

  for (auto &appliedConstraint : pattern.getConstraints()) {
    auto &constraint = appliedConstraint.constraint;
    auto &entities = appliedConstraint.entities;

    auto condition = constraint.getConditionTemplate();
    auto cmd = "if (!({0})) return matchFailure();\n";

    if (isa<TypeConstraint>(constraint)) {
      // A type constraint is checked against the type of its first entity.
      auto self = formatv("({0}->getType())",
                          symbolInfoMap.getValueAndRangeUse(entities.front()));
      os.indent(4) << formatv(cmd,
                              tgfmt(condition, &fmtCtx.withSelf(self.str())));
    } else if (isa<AttrConstraint>(constraint)) {
      PrintFatalError(
          loc, "cannot use AttrConstraint in Pattern multi-entity constraints");
    } else {
      // TODO(b/138794486): replace formatv arguments with the exact specified
      // args.
      if (entities.size() > 4) {
        PrintFatalError(loc, "only support up to 4-entity constraints now");
      }
      SmallVector<std::string, 4> names;
      int i = 0;
      for (int e = entities.size(); i < e; ++i)
        names.push_back(symbolInfoMap.getValueAndRangeUse(entities[i]));
      std::string self = appliedConstraint.self;
      if (!self.empty())
        self = symbolInfoMap.getValueAndRangeUse(self);
      // Pad to exactly four names because tgfmt() below is always called with
      // four positional arguments.
      for (; i < 4; ++i)
        names.push_back("<unused>");
      os.indent(4) << formatv(cmd,
                              tgfmt(condition, &fmtCtx.withSelf(self), names[0],
                                    names[1], names[2], names[3]));
    }
  }
}

// Recursively inserts into `ops` the Operator of every node of `tree` that is
// an operation.
void PatternEmitter::collectOps(DagNode tree,
                                llvm::SmallPtrSetImpl<const Operator *> &ops) {
  // Check if this tree is an operation.
  if (tree.isOperation())
    ops.insert(&tree.getDialectOp(opMap));

  // Recurse the arguments of the tree.
  for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i)
    if (auto child = tree.getArgAsNestedDag(i))
      collectOps(child, ops);
}

// Emits the complete RewritePattern struct named `rewriteName`: a provenance
// comment, the constructor (root op name, names of the ops the rewrite may
// generate, benefit), and matchAndRewrite() with its match and rewrite parts.
void PatternEmitter::emit(StringRef rewriteName) {
  // Get the DAG tree for the source pattern.
  DagNode sourceTree = pattern.getSourcePattern();

  const Operator &rootOp = pattern.getSourceRootOp();
  auto rootName = rootOp.getOperationName();

  // Collect the set of result operations.
  llvm::SmallPtrSet<const Operator *, 4> resultOps;
  for (unsigned i = 0, e = pattern.getNumResultPatterns(); i != e; ++i)
    collectOps(pattern.getResultPattern(i), resultOps);

  // Emit RewritePattern for Pattern. The locations are printed in reverse
  // order, joined by " instantiating".
  auto locs = pattern.getLocation();
  os << formatv("/* Generated from:\n\t{0:$[ instantiating\n\t]}\n*/\n",
                make_range(locs.rbegin(), locs.rend()));
  // NOTE(review): the line breaks/indentation inside this raw string were
  // reconstructed; they only affect the layout of the generated code. Confirm
  // against the upstream file.
  os << formatv(R"(struct {0} : public RewritePattern {
  {0}(MLIRContext *context)
      : RewritePattern("{1}", {{)",
                rewriteName, rootName);
  interleaveComma(resultOps, os, [&](const Operator *op) {
    os << '"' << op->getOperationName() << '"';
  });
  os << formatv(R"(}, {0}, context) {{})", pattern.getBenefit()) << "\n";

  // Emit matchAndRewrite() function.
  // NOTE(review): interior whitespace of this raw string was reconstructed as
  // well; confirm against the upstream file.
  os << R"(
  PatternMatchResult matchAndRewrite(Operation *op0,
                                     PatternRewriter &rewriter) const override {
)";

  // Register all symbols bound in the source pattern.
  pattern.collectSourcePatternBoundSymbols(symbolInfoMap);

  os.indent(4) << "// Variables for capturing values and attributes used for "
                  "creating ops\n";
  // Create local variables for storing the arguments and results bound
  // to symbols.
  for (const auto &symbolInfoPair : symbolInfoMap) {
    StringRef symbol = symbolInfoPair.getKey();
    auto &info = symbolInfoPair.getValue();
    os.indent(4) << info.getVarDecl(symbol);
  }
  // TODO(jpienaar): capture ops with consistent numbering so that it can be
  // reused for fused loc.
  os.indent(4) << formatv("Operation *tblgen_ops[{0}];\n\n",
                          pattern.getSourcePattern().getNumOps());

  os.indent(4) << "// Match\n";
  os.indent(4) << "tblgen_ops[0] = op0;\n";
  emitMatchLogic(sourceTree);
  os << "\n";

  os.indent(4) << "// Rewrite\n";
  emitRewriteLogic();

  os.indent(4) << "return matchSuccess();\n";
  os << " };\n";
  os << "};\n";
}

// Emits the rewrite half of matchAndRewrite(): computes which result patterns
// replace which results of the matched root op, emits the creation code of
// every result pattern, and finally emits the replaceOp() call.
void PatternEmitter::emitRewriteLogic() {
  const Operator &rootOp = pattern.getSourceRootOp();
  int numExpectedResults = rootOp.getNumResults();
  int numResultPatterns = pattern.getNumResultPatterns();

  // First register all symbols bound to ops generated in result patterns.
  pattern.collectResultPatternBoundSymbols(symbolInfoMap);

  // Only the last N static values generated are used to replace the matched
  // root N-result op. We need to calculate the starting index (of the results
  // of the matched op) each result pattern is to replace.
  SmallVector<int, 4> offsets(numResultPatterns + 1, numExpectedResults);
  // If we don't need to replace any value at all, set the replacement starting
  // index as the number of result patterns so we skip all of them when trying
  // to replace the matched op's results.
  int replStartIndex = numExpectedResults == 0 ? numResultPatterns : -1;
  // Walk the result patterns from last to first, subtracting the number of
  // values each one produces. A negative offset marks a pattern whose values
  // are not used for the replacement.
  for (int i = numResultPatterns - 1; i >= 0; --i) {
    auto numValues = getNodeValueCount(pattern.getResultPattern(i));
    offsets[i] = offsets[i + 1] - numValues;
    if (offsets[i] == 0) {
      if (replStartIndex == -1)
        replStartIndex = i;
    } else if (offsets[i] < 0 && offsets[i + 1] > 0) {
      // This pattern straddles the boundary between auxiliary values and
      // replacement values.
      auto error = formatv(
          "cannot use the same multi-result op '{0}' to generate both "
          "auxiliary values and values to be used for replacing the matched op",
          pattern.getResultPattern(i).getSymbol());
      PrintFatalError(loc, error);
    }
  }

  // All result patterns together produce fewer values than the root op has
  // results.
  if (offsets.front() > 0) {
    const char error[] = "no enough values generated to replace the matched op";
    PrintFatalError(loc, error);
  }

  os.indent(4) << "SmallVector<Type, 4> tblgen_types; (void)tblgen_types;\n";
  // The location used for created ops fuses the locations of all matched ops.
  os.indent(4) << "auto loc = rewriter.getFusedLoc({";
  for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) {
    os << (i ? ", " : "") << "tblgen_ops[" << i << "]->getLoc()";
  }
  os << "}); (void)loc;\n";

  // Process each result pattern and record the result symbol.
  llvm::SmallVector<std::string, 2> resultValues;
  for (int i = 0; i < numResultPatterns; ++i) {
    DagNode resultTree = pattern.getResultPattern(i);
    resultValues.push_back(handleResultPattern(resultTree, offsets[i], 0));
  }

  os.indent(4) << "SmallVector<Value *, 4> tblgen_values;";
  // Only use the last portion for replacing the matched root op's results.
  auto range = llvm::makeArrayRef(resultValues).drop_front(replStartIndex);
  for (const auto &val : range) {
    os.indent(4) << "\n";
    // Resolve each symbol for all range use so that we can loop over them.
    os << symbolInfoMap.getAllRangeUse(
        val, " for (auto *v : {0}) tblgen_values.push_back(v);", "\n");
  }
  os.indent(4) << "\n";
  os.indent(4) << "rewriter.replaceOp(op0, tblgen_values);\n";
}

// Builds a name of the form `tblgen_<OpClass>_<N>`, where N comes from a
// per-emitter counter that is incremented on every call.
std::string PatternEmitter::getUniqueSymbol(const Operator *op) {
  return formatv("tblgen_{0}_{1}", op->getCppClassName(), nextValueId++);
}

// Dispatches on the kind of `resultTree` (native code call, replace-with-value
// directive, or op creation) and returns the symbol standing for its values.
// Symbols created here for native code calls and for ops without an explicit
// name are registered in `symbolInfoMap`.
std::string PatternEmitter::handleResultPattern(DagNode resultTree,
                                                int resultIndex, int depth) {
  if (resultTree.isNativeCodeCall()) {
    auto symbol = handleReplaceWithNativeCodeCall(resultTree);
    symbolInfoMap.bindValue(symbol);
    return symbol;
  }

  if (resultTree.isReplaceWithValue()) {
    return handleReplaceWithValue(resultTree);
  }

  // Normal op creation.
  auto symbol = handleOpCreation(resultTree, resultIndex, depth);
  if (resultTree.getSymbol().empty()) {
    // This is an op not explicitly bound to a symbol in the rewrite rule.
    // Register the auto-generated symbol for it.
    symbolInfoMap.bindOpResult(symbol, pattern.getDialectOp(resultTree));
  }
  return symbol;
}

// Validates a replaceWithValue directive (exactly one argument, no bound
// symbol) and returns the name of that single argument.
std::string PatternEmitter::handleReplaceWithValue(DagNode tree) {
  assert(tree.isReplaceWithValue());

  if (tree.getNumArgs() != 1) {
    PrintFatalError(
        loc, "replaceWithValue directive must take exactly one argument");
  }

  if (!tree.getSymbol().empty()) {
    PrintFatalError(loc, "cannot bind symbol to replaceWithValue");
  }

  return tree.getArgName(0);
}

// Returns the C++ expression for one leaf argument of a result pattern:
// a constant attribute, an enum attribute case, the value bound to
// `patArgName`, or a native code call applied to that bound value. Any other
// leaf kind is a fatal error.
std::string PatternEmitter::handleOpArgument(DagLeaf leaf,
                                             StringRef patArgName) {
  if (leaf.isConstantAttr()) {
    auto constAttr = leaf.getAsConstantAttr();
    return handleConstantAttr(constAttr.getAttribute(),
                              constAttr.getConstantValue());
  }
  if (leaf.isEnumAttrCase()) {
    auto enumCase = leaf.getAsEnumAttrCase();
    if (enumCase.isStrCase())
      return handleConstantAttr(enumCase, enumCase.getSymbol());
    // This is an enum case backed by an IntegerAttr. We need to get its value
    // to build the constant.
    std::string val = std::to_string(enumCase.getValue());
    return handleConstantAttr(enumCase, val);
  }

  auto argName = symbolInfoMap.getValueAndRangeUse(patArgName);
  if (leaf.isUnspecified() || leaf.isOperandMatcher()) {
    return argName;
  }
  if (leaf.isNativeCodeCall()) {
    return tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(argName));
  }
  PrintFatalError(loc, "unhandled case when rewriting op");
}

// Expands the native code template of `tree` with the expressions of its
// arguments. At most 8 arguments are supported; unused positions are passed
// as empty strings.
std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree) {
  auto fmt = tree.getNativeCodeTemplate();
  // TODO(b/138794486): replace formatv arguments with the exact specified args.
  SmallVector<std::string, 8> attrs(8);
  if (tree.getNumArgs() > 8) {
    PrintFatalError(loc, "unsupported NativeCodeCall argument numbers: " +
                             Twine(tree.getNumArgs()));
  }
  for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
    attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i));
  }
  return tgfmt(fmt, &fmtCtx, attrs[0], attrs[1], attrs[2], attrs[3], attrs[4],
               attrs[5], attrs[6], attrs[7]);
}

// Returns the number of static values `node` stands for. Used by
// emitRewriteLogic() to compute the replacement offsets.
int PatternEmitter::getNodeValueCount(DagNode node) {
  if (node.isOperation()) {
    // If the op is bound to a symbol in the rewrite rule, query its result
    // count from the symbol info map.
    auto symbol = node.getSymbol();
    if (!symbol.empty()) {
      return symbolInfoMap.getStaticValueCount(symbol);
    }
    // Otherwise this is an unbound op; we will use all its results.
    return pattern.getDialectOp(node).getNumResults();
  }
  // TODO(antiagainst): This considers all NativeCodeCall as returning one
  // value. Enhance if multi-value ones are needed.
  return 1;
}

// Emits the code creating the op described by `tree`: nested result patterns
// first, then a local variable for the op, one local per operand, and finally
// the `rewriter.create<...>()` call with operands followed by attributes.
// Returns the symbol that refers to the result of this pattern.
std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
                                             int depth) {
  Operator &resultOp = tree.getDialectOp(opMap);
  auto numOpArgs = resultOp.getNumArgs();

  if (numOpArgs != tree.getNumArgs()) {
    PrintFatalError(loc, formatv("resultant op '{0}' argument number mismatch: "
                                 "{1} in pattern vs. {2} in definition",
                                 resultOp.getOperationName(), tree.getNumArgs(),
                                 numOpArgs));
  }

  // A map to collect all nested DAG child nodes' names, with operand index as
  // the key. This includes both bound and unbound child nodes.
  llvm::DenseMap<unsigned, std::string> childNodeNames;

  // First go through all the child nodes who are nested DAG constructs to
  // create ops for them and remember the symbol names for them, so that we can
  // use the results in the current node. This happens in a recursive manner.
  for (int i = 0, e = resultOp.getNumOperands(); i != e; ++i) {
    if (auto child = tree.getArgAsNestedDag(i)) {
      childNodeNames[i] = handleResultPattern(child, i, depth + 1);
    }
  }

  // The name of the local variable holding this op.
  std::string valuePackName;
  // The symbol for holding the result of this pattern. Note that the result of
  // this pattern is not necessarily the same as the variable created by this
  // pattern because we can use `__N` suffix to refer only a specific result if
  // the generated op is a multi-result op.
  std::string resultValue;
  if (tree.getSymbol().empty()) {
    // No symbol is explicitly bound to this op in the pattern. Generate a
    // unique name.
    valuePackName = resultValue = getUniqueSymbol(&resultOp);
  } else {
    resultValue = tree.getSymbol();
    // Strip the index to get the name for the value pack and use it to name the
    // local variable for the op.
    valuePackName = SymbolInfoMap::getValuePackName(resultValue);
  }

  // Create the local variable for this op. It is declared outside the scope
  // opened next, so it remains visible to the code emitted afterwards.
  os.indent(4) << formatv("{0} {1};\n", resultOp.getQualCppClassName(),
                          valuePackName);
  os.indent(4) << "{\n";

  // Now prepare operands used for building this op:
  // * If the operand is non-variadic, we create a `Value*` local variable.
  // * If the operand is variadic, we create a `SmallVector<Value*>` local
  //   variable.

  int argIndex = 0;   // The current index to this op's ODS argument
  int valueIndex = 0; // An index for uniquing local variable names.
  for (int e = resultOp.getNumOperands(); argIndex < e; ++argIndex) {
    const auto &operand = resultOp.getOperand(argIndex);
    std::string varName;
    if (operand.isVariadic()) {
      varName = formatv("tblgen_values_{0}", valueIndex++);
      os.indent(6) << formatv("SmallVector<Value *, 4> {0};\n", varName);
      std::string range;
      if (tree.isNestedDagArg(argIndex)) {
        range = childNodeNames[argIndex];
      } else {
        range = tree.getArgName(argIndex);
      }
      // Resolve the symbol for all range use so that we have a uniform way of
      // capturing the values.
      range = symbolInfoMap.getValueAndRangeUse(range);
      os.indent(6) << formatv("for (auto *v : {0}) {1}.push_back(v);\n", range,
                              varName);
    } else {
      varName = formatv("tblgen_value_{0}", valueIndex++);
      os.indent(6) << formatv("Value *{0} = ", varName);
      if (tree.isNestedDagArg(argIndex)) {
        os << symbolInfoMap.getValueAndRangeUse(childNodeNames[argIndex]);
      } else {
        DagLeaf leaf = tree.getArgAsLeaf(argIndex);
        auto symbol =
            symbolInfoMap.getValueAndRangeUse(tree.getArgName(argIndex));
        if (leaf.isNativeCodeCall()) {
          os << tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol));
        } else {
          os << symbol;
        }
      }
      os << ";\n";
    }

    // Update to use the newly created local variable for building the op later.
    childNodeNames[argIndex] = varName;
  }

  // Then we create the builder call.

  // Right now we don't have general type inference in MLIR. Except a few
  // special cases listed below, we need to supply types for all results
  // when building an op.
  bool isSameOperandsAndResultType =
      resultOp.hasTrait("OpTrait::SameOperandsAndResultType");
  bool isBroadcastable =
      resultOp.hasTrait("OpTrait::BroadcastableTwoOperandsOneResult");
  bool useFirstAttr = resultOp.hasTrait("OpTrait::FirstAttrDerivedResultType");
  bool usePartialResults = valuePackName != resultValue;

  if (isSameOperandsAndResultType || isBroadcastable || useFirstAttr ||
      usePartialResults || depth > 0 || resultIndex < 0) {
    // No explicit result types are passed to the builder call in these cases.
    os.indent(6) << formatv("{0} = rewriter.create<{1}>(loc", valuePackName,
                            resultOp.getQualCppClassName());
  } else {
    // If depth == 0 and resultIndex >= 0, it means we are replacing the values
    // generated from the source pattern root op. Then we can use the source
    // pattern's value types to determine the value type of the generated op
    // here.

    // We need to specify the types for all results.
    int numResults = resultOp.getNumResults();
    if (numResults != 0) {
      os.indent(6) << "tblgen_types.clear();\n";
      for (int i = 0; i < numResults; ++i) {
        os.indent(6) << formatv("for (auto *v : castedOp0.getODSResults({0})) "
                                "tblgen_types.push_back(v->getType());\n",
                                resultIndex + i);
      }
    }

    os.indent(6) << formatv("{0} = rewriter.create<{1}>(loc", valuePackName,
                            resultOp.getQualCppClassName());
    if (numResults != 0)
      os.indent(6) << ", tblgen_types";
  }

  // Add operands for the builder call. At this point `argIndex` equals the
  // number of operands, and `childNodeNames` holds the local variable created
  // above for each of them.
  for (int i = 0; i < argIndex; ++i) {
    const auto &operand = resultOp.getOperand(i);
    // Start each operand on its own line.
    (os << ",\n").indent(8);
    if (!operand.name.empty()) {
      os << "/*" << operand.name << "=*/";
    }
    os << childNodeNames[i];
    // TODO(jpienaar): verify types
  }

  // Add attributes for the builder call. The remaining arguments, from
  // `argIndex` up to `numOpArgs`, are handled as attributes.
  for (; argIndex != numOpArgs; ++argIndex) {
    // Start each attribute on its own line.
    (os << ",\n").indent(8);
    // The argument in the op definition.
    auto opArgName = resultOp.getArgName(argIndex);
    if (auto subTree = tree.getArgAsNestedDag(argIndex)) {
      if (!subTree.isNativeCodeCall())
        PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
                             "for creating attribute");
      os << formatv("/*{0}=*/{1}", opArgName,
                    handleReplaceWithNativeCodeCall(subTree));
    } else {
      auto leaf = tree.getArgAsLeaf(argIndex);
      // The argument in the result DAG pattern.
      auto patArgName = tree.getArgName(argIndex);
      if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) {
        // TODO(jpienaar): Refactor out into map to avoid recomputing these.
        auto argument = resultOp.getArg(argIndex);
        if (!argument.is<NamedAttribute *>())
          PrintFatalError(loc, Twine("expected attribute ") + Twine(argIndex));
        if (!patArgName.empty())
          os << "/*" << patArgName << "=*/";
      } else {
        os << "/*" << opArgName << "=*/";
      }
      os << handleOpArgument(leaf, patArgName);
    }
  }
  os << "\n );\n";
  os.indent(4) << "}\n";

  return resultValue;
}

// Backend entry point: emits one RewritePattern struct per `Pattern` record,
// followed by a populateWithGenerated() function that inserts all of them
// into a pattern list.
static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
  emitSourceFileHeader("Rewriters", os);

  const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern");
  auto numPatterns = patterns.size();

  // We put the map here because it can be shared among multiple patterns.
  RecordOperatorMap recordOpMap;

  std::vector<std::string> rewriterNames;
  rewriterNames.reserve(numPatterns);

  std::string baseRewriterName = "GeneratedConvert";
  int rewriterIndex = 0;

  for (Record *p : patterns) {
    std::string name;
    if (p->isAnonymous()) {
      // If no name is provided, ensure unique rewriter names simply by
      // appending unique suffix.
      name = baseRewriterName + llvm::utostr(rewriterIndex++);
    } else {
      name = p->getName();
    }
    PatternEmitter(p, &recordOpMap, os).emit(name);
    rewriterNames.push_back(std::move(name));
  }

  // Emit function to add the generated matchers to the pattern list.
  os << "void populateWithGenerated(MLIRContext *context, "
     << "OwningRewritePatternList *patterns) {\n";
  for (const auto &name : rewriterNames) {
    os << " patterns->insert<" << name << ">(context);\n";
  }
  os << "}\n";
}

// Registers this backend under the `gen-rewriters` name. The callback returns
// false after emitting.
static mlir::GenRegistration
    genRewriters("gen-rewriters", "Generate pattern rewriters",
                 [](const RecordKeeper &records, raw_ostream &os) {
                   emitRewriters(records, os);
                   return false;
                 });