github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/tools/mlir-tblgen/RewriterGen.cpp (about)

     1  //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------------===//
     2  //
     3  // Copyright 2019 The MLIR Authors.
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //   http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  // =============================================================================
    17  //
    18  // RewriterGen uses pattern rewrite definitions to generate rewriter matchers.
    19  //
    20  //===----------------------------------------------------------------------===//
    21  
    22  #include "mlir/Support/STLExtras.h"
    23  #include "mlir/TableGen/Attribute.h"
    24  #include "mlir/TableGen/Format.h"
    25  #include "mlir/TableGen/GenInfo.h"
    26  #include "mlir/TableGen/Operator.h"
    27  #include "mlir/TableGen/Pattern.h"
    28  #include "mlir/TableGen/Predicate.h"
    29  #include "mlir/TableGen/Type.h"
    30  #include "llvm/ADT/StringExtras.h"
    31  #include "llvm/ADT/StringSet.h"
    32  #include "llvm/Support/CommandLine.h"
    33  #include "llvm/Support/FormatAdapters.h"
    34  #include "llvm/Support/PrettyStackTrace.h"
    35  #include "llvm/Support/Signals.h"
    36  #include "llvm/TableGen/Error.h"
    37  #include "llvm/TableGen/Main.h"
    38  #include "llvm/TableGen/Record.h"
    39  #include "llvm/TableGen/TableGenBackend.h"
    40  
    41  using namespace llvm;
    42  using namespace mlir;
    43  using namespace mlir::tblgen;
    44  
    45  namespace llvm {
    46  template <> struct format_provider<mlir::tblgen::Pattern::IdentifierLine> {
    47    static void format(const mlir::tblgen::Pattern::IdentifierLine &v,
    48                       raw_ostream &os, StringRef style) {
    49      os << v.first << ":" << v.second;
    50    }
    51  };
    52  } // end namespace llvm
    53  
    54  //===----------------------------------------------------------------------===//
    55  // PatternEmitter
    56  //===----------------------------------------------------------------------===//
    57  
    58  namespace {
    59  class PatternEmitter {
    60  public:
    61    PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os);
    62  
    63    // Emits the mlir::RewritePattern struct named `rewriteName`.
    64    void emit(StringRef rewriteName);
    65  
    66  private:
    67    // Emits the code for matching ops.
    68    void emitMatchLogic(DagNode tree);
    69  
    70    // Emits the code for rewriting ops.
    71    void emitRewriteLogic();
    72  
    73    //===--------------------------------------------------------------------===//
    74    // Match utilities
    75    //===--------------------------------------------------------------------===//
    76  
    77    // Emits C++ statements for matching the op constrained by the given DAG
    78    // `tree`.
    79    void emitOpMatch(DagNode tree, int depth);
    80  
    81    // Emits C++ statements for matching the `index`-th argument of the given DAG
    82    // `tree` as an operand.
    83    void emitOperandMatch(DagNode tree, int index, int depth, int indent);
    84  
    85    // Emits C++ statements for matching the `index`-th argument of the given DAG
    86    // `tree` as an attribute.
    87    void emitAttributeMatch(DagNode tree, int index, int depth, int indent);
    88  
    89    //===--------------------------------------------------------------------===//
    90    // Rewrite utilities
    91    //===--------------------------------------------------------------------===//
    92  
    93    // The entry point for handling a result pattern rooted at `resultTree`. This
    94    // method dispatches to concrete handlers according to `resultTree`'s kind and
    95    // returns a symbol representing the whole value pack. Callers are expected to
    96    // further resolve the symbol according to the specific use case.
    97    //
    98    // `depth` is the nesting level of `resultTree`; 0 means top-level result
    99    // pattern. For top-level result pattern, `resultIndex` indicates which result
   100    // of the matched root op this pattern is intended to replace, which can be
   101    // used to deduce the result type of the op generated from this result
   102    // pattern.
   103    std::string handleResultPattern(DagNode resultTree, int resultIndex,
   104                                    int depth);
   105  
   106    // Emits the C++ statement to replace the matched DAG with a value built via
   107    // calling native C++ code.
   108    std::string handleReplaceWithNativeCodeCall(DagNode resultTree);
   109  
   110    // Returns the C++ expression referencing the old value serving as the
   111    // replacement.
   112    std::string handleReplaceWithValue(DagNode tree);
   113  
   114    // Emits the C++ statement to build a new op out of the given DAG `tree` and
   115    // returns the variable name that this op is assigned to. If the root op in
   116    // DAG `tree` has a specified name, the created op will be assigned to a
   117    // variable of the given name. Otherwise, a unique name will be used as the
   118    // result value name.
   119    std::string handleOpCreation(DagNode tree, int resultIndex, int depth);
   120  
   121    // Returns the C++ expression to construct a constant attribute of the given
   122    // `value` for the given attribute kind `attr`.
   123    std::string handleConstantAttr(Attribute attr, StringRef value);
   124  
   125    // Returns the C++ expression to build an argument from the given DAG `leaf`.
   126    // `patArgName` is used to bound the argument to the source pattern.
   127    std::string handleOpArgument(DagLeaf leaf, StringRef patArgName);
   128  
   129    //===--------------------------------------------------------------------===//
   130    // General utilities
   131    //===--------------------------------------------------------------------===//
   132  
   133    // Collects all of the operations within the given dag tree.
   134    void collectOps(DagNode tree, llvm::SmallPtrSetImpl<const Operator *> &ops);
   135  
   136    // Returns a unique symbol for a local variable of the given `op`.
   137    std::string getUniqueSymbol(const Operator *op);
   138  
   139    //===--------------------------------------------------------------------===//
   140    // Symbol utilities
   141    //===--------------------------------------------------------------------===//
   142  
   143    // Returns how many static values the given DAG `node` correspond to.
   144    int getNodeValueCount(DagNode node);
   145  
   146  private:
   147    // Pattern instantiation location followed by the location of multiclass
   148    // prototypes used. This is intended to be used as a whole to
   149    // PrintFatalError() on errors.
   150    ArrayRef<llvm::SMLoc> loc;
   151  
   152    // Op's TableGen Record to wrapper object.
   153    RecordOperatorMap *opMap;
   154  
   155    // Handy wrapper for pattern being emitted.
   156    Pattern pattern;
   157  
   158    // Map for all bound symbols' info.
   159    SymbolInfoMap symbolInfoMap;
   160  
   161    // The next unused ID for newly created values.
   162    unsigned nextValueId;
   163  
   164    raw_ostream &os;
   165  
   166    // Format contexts containing placeholder substitutations.
   167    FmtContext fmtCtx;
   168  
   169    // Number of op processed.
   170    int opCounter = 0;
   171  };
   172  } // end anonymous namespace
   173  
   174  PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper,
   175                                 raw_ostream &os)
   176      : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper),
   177        symbolInfoMap(pat->getLoc()), nextValueId(0), os(os) {
   178    fmtCtx.withBuilder("rewriter");
   179  }
   180  
   181  std::string PatternEmitter::handleConstantAttr(Attribute attr,
   182                                                 StringRef value) {
   183    if (!attr.isConstBuildable())
   184      PrintFatalError(loc, "Attribute " + attr.getAttrDefName() +
   185                               " does not have the 'constBuilderCall' field");
   186  
   187    // TODO(jpienaar): Verify the constants here
   188    return tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, value);
   189  }
   190  
   191  // Helper function to match patterns.
   192  void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
   193    Operator &op = tree.getDialectOp(opMap);
   194  
   195    int indent = 4 + 2 * depth;
   196    os.indent(indent) << formatv(
   197        "auto castedOp{0} = dyn_cast_or_null<{1}>(op{0}); (void)castedOp{0};\n",
   198        depth, op.getQualCppClassName());
   199    // Skip the operand matching at depth 0 as the pattern rewriter already does.
   200    if (depth != 0) {
   201      // Skip if there is no defining operation (e.g., arguments to function).
   202      os.indent(indent) << formatv("if (!castedOp{0}) return matchFailure();\n",
   203                                   depth);
   204    }
   205    if (tree.getNumArgs() != op.getNumArgs()) {
   206      PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in "
   207                                   "pattern vs. {2} in definition",
   208                                   op.getOperationName(), tree.getNumArgs(),
   209                                   op.getNumArgs()));
   210    }
   211  
   212    // If the operand's name is set, set to that variable.
   213    auto name = tree.getSymbol();
   214    if (!name.empty())
   215      os.indent(indent) << formatv("{0} = castedOp{1};\n", name, depth);
   216  
   217    for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
   218      auto opArg = op.getArg(i);
   219  
   220      // Handle nested DAG construct first
   221      if (DagNode argTree = tree.getArgAsNestedDag(i)) {
   222        if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) {
   223          if (operand->isVariadic()) {
   224            auto error = formatv("use nested DAG construct to match op {0}'s "
   225                                 "variadic operand #{1} unsupported now",
   226                                 op.getOperationName(), i);
   227            PrintFatalError(loc, error);
   228          }
   229        }
   230        os.indent(indent) << "{\n";
   231  
   232        os.indent(indent + 2) << formatv(
   233            "auto *op{0} = "
   234            "(*castedOp{1}.getODSOperands({2}).begin())->getDefiningOp();\n",
   235            depth + 1, depth, i);
   236        emitOpMatch(argTree, depth + 1);
   237        os.indent(indent + 2)
   238            << formatv("tblgen_ops[{0}] = op{1};\n", ++opCounter, depth + 1);
   239        os.indent(indent) << "}\n";
   240        continue;
   241      }
   242  
   243      // Next handle DAG leaf: operand or attribute
   244      if (opArg.is<NamedTypeConstraint *>()) {
   245        emitOperandMatch(tree, i, depth, indent);
   246      } else if (opArg.is<NamedAttribute *>()) {
   247        emitAttributeMatch(tree, i, depth, indent);
   248      } else {
   249        PrintFatalError(loc, "unhandled case when matching op");
   250      }
   251    }
   252  }
   253  
   254  void PatternEmitter::emitOperandMatch(DagNode tree, int index, int depth,
   255                                        int indent) {
   256    Operator &op = tree.getDialectOp(opMap);
   257    auto *operand = op.getArg(index).get<NamedTypeConstraint *>();
   258    auto matcher = tree.getArgAsLeaf(index);
   259  
   260    // If a constraint is specified, we need to generate C++ statements to
   261    // check the constraint.
   262    if (!matcher.isUnspecified()) {
   263      if (!matcher.isOperandMatcher()) {
   264        PrintFatalError(
   265            loc, formatv("the {1}-th argument of op '{0}' should be an operand",
   266                         op.getOperationName(), index + 1));
   267      }
   268  
   269      // Only need to verify if the matcher's type is different from the one
   270      // of op definition.
   271      if (operand->constraint != matcher.getAsConstraint()) {
   272        if (operand->isVariadic()) {
   273          auto error = formatv(
   274              "further constrain op {0}'s variadic operand #{1} unsupported now",
   275              op.getOperationName(), index);
   276          PrintFatalError(loc, error);
   277        }
   278        auto self =
   279            formatv("(*castedOp{0}.getODSOperands({1}).begin())->getType()",
   280                    depth, index);
   281        os.indent(indent) << "if (!("
   282                          << tgfmt(matcher.getConditionTemplate(),
   283                                   &fmtCtx.withSelf(self))
   284                          << ")) return matchFailure();\n";
   285      }
   286    }
   287  
   288    // Capture the value
   289    auto name = tree.getArgName(index);
   290    if (!name.empty()) {
   291      os.indent(indent) << formatv("{0} = castedOp{1}.getODSOperands({2});\n",
   292                                   name, depth, index);
   293    }
   294  }
   295  
   296  void PatternEmitter::emitAttributeMatch(DagNode tree, int index, int depth,
   297                                          int indent) {
   298    Operator &op = tree.getDialectOp(opMap);
   299    auto *namedAttr = op.getArg(index).get<NamedAttribute *>();
   300    const auto &attr = namedAttr->attr;
   301  
   302    os.indent(indent) << "{\n";
   303    indent += 2;
   304    os.indent(indent) << formatv(
   305        "auto tblgen_attr = op{0}->getAttrOfType<{1}>(\"{2}\");\n", depth,
   306        attr.getStorageType(), namedAttr->name);
   307  
   308    // TODO(antiagainst): This should use getter method to avoid duplication.
   309    if (attr.hasDefaultValueInitializer()) {
   310      os.indent(indent) << "if (!tblgen_attr) tblgen_attr = "
   311                        << tgfmt(attr.getConstBuilderTemplate(), &fmtCtx,
   312                                 attr.getDefaultValueInitializer())
   313                        << ";\n";
   314    } else if (attr.isOptional()) {
   315      // For a missing attribute that is optional according to definition, we
   316      // should just capature a mlir::Attribute() to signal the missing state.
   317      // That is precisely what getAttr() returns on missing attributes.
   318    } else {
   319      os.indent(indent) << "if (!tblgen_attr) return matchFailure();\n";
   320    }
   321  
   322    auto matcher = tree.getArgAsLeaf(index);
   323    if (!matcher.isUnspecified()) {
   324      if (!matcher.isAttrMatcher()) {
   325        PrintFatalError(
   326            loc, formatv("the {1}-th argument of op '{0}' should be an attribute",
   327                         op.getOperationName(), index + 1));
   328      }
   329  
   330      // If a constraint is specified, we need to generate C++ statements to
   331      // check the constraint.
   332      os.indent(indent) << "if (!("
   333                        << tgfmt(matcher.getConditionTemplate(),
   334                                 &fmtCtx.withSelf("tblgen_attr"))
   335                        << ")) return matchFailure();\n";
   336    }
   337  
   338    // Capture the value
   339    auto name = tree.getArgName(index);
   340    if (!name.empty()) {
   341      os.indent(indent) << formatv("{0} = tblgen_attr;\n", name);
   342    }
   343  
   344    indent -= 2;
   345    os.indent(indent) << "}\n";
   346  }
   347  
   348  void PatternEmitter::emitMatchLogic(DagNode tree) {
   349    emitOpMatch(tree, 0);
   350  
   351    for (auto &appliedConstraint : pattern.getConstraints()) {
   352      auto &constraint = appliedConstraint.constraint;
   353      auto &entities = appliedConstraint.entities;
   354  
   355      auto condition = constraint.getConditionTemplate();
   356      auto cmd = "if (!({0})) return matchFailure();\n";
   357  
   358      if (isa<TypeConstraint>(constraint)) {
   359        auto self = formatv("({0}->getType())",
   360                            symbolInfoMap.getValueAndRangeUse(entities.front()));
   361        os.indent(4) << formatv(cmd,
   362                                tgfmt(condition, &fmtCtx.withSelf(self.str())));
   363      } else if (isa<AttrConstraint>(constraint)) {
   364        PrintFatalError(
   365            loc, "cannot use AttrConstraint in Pattern multi-entity constraints");
   366      } else {
   367        // TODO(b/138794486): replace formatv arguments with the exact specified
   368        // args.
   369        if (entities.size() > 4) {
   370          PrintFatalError(loc, "only support up to 4-entity constraints now");
   371        }
   372        SmallVector<std::string, 4> names;
   373        int i = 0;
   374        for (int e = entities.size(); i < e; ++i)
   375          names.push_back(symbolInfoMap.getValueAndRangeUse(entities[i]));
   376        std::string self = appliedConstraint.self;
   377        if (!self.empty())
   378          self = symbolInfoMap.getValueAndRangeUse(self);
   379        for (; i < 4; ++i)
   380          names.push_back("<unused>");
   381        os.indent(4) << formatv(cmd,
   382                                tgfmt(condition, &fmtCtx.withSelf(self), names[0],
   383                                      names[1], names[2], names[3]));
   384      }
   385    }
   386  }
   387  
   388  void PatternEmitter::collectOps(DagNode tree,
   389                                  llvm::SmallPtrSetImpl<const Operator *> &ops) {
   390    // Check if this tree is an operation.
   391    if (tree.isOperation())
   392      ops.insert(&tree.getDialectOp(opMap));
   393  
   394    // Recurse the arguments of the tree.
   395    for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i)
   396      if (auto child = tree.getArgAsNestedDag(i))
   397        collectOps(child, ops);
   398  }
   399  
   400  void PatternEmitter::emit(StringRef rewriteName) {
   401    // Get the DAG tree for the source pattern.
   402    DagNode sourceTree = pattern.getSourcePattern();
   403  
   404    const Operator &rootOp = pattern.getSourceRootOp();
   405    auto rootName = rootOp.getOperationName();
   406  
   407    // Collect the set of result operations.
   408    llvm::SmallPtrSet<const Operator *, 4> resultOps;
   409    for (unsigned i = 0, e = pattern.getNumResultPatterns(); i != e; ++i)
   410      collectOps(pattern.getResultPattern(i), resultOps);
   411  
   412    // Emit RewritePattern for Pattern.
   413    auto locs = pattern.getLocation();
   414    os << formatv("/* Generated from:\n\t{0:$[ instantiating\n\t]}\n*/\n",
   415                  make_range(locs.rbegin(), locs.rend()));
   416    os << formatv(R"(struct {0} : public RewritePattern {
   417    {0}(MLIRContext *context)
   418        : RewritePattern("{1}", {{)",
   419                  rewriteName, rootName);
   420    interleaveComma(resultOps, os, [&](const Operator *op) {
   421      os << '"' << op->getOperationName() << '"';
   422    });
   423    os << formatv(R"(}, {0}, context) {{})", pattern.getBenefit()) << "\n";
   424  
   425    // Emit matchAndRewrite() function.
   426    os << R"(
   427    PatternMatchResult matchAndRewrite(Operation *op0,
   428                                       PatternRewriter &rewriter) const override {
   429  )";
   430  
   431    // Register all symbols bound in the source pattern.
   432    pattern.collectSourcePatternBoundSymbols(symbolInfoMap);
   433  
   434    os.indent(4) << "// Variables for capturing values and attributes used for "
   435                    "creating ops\n";
   436    // Create local variables for storing the arguments and results bound
   437    // to symbols.
   438    for (const auto &symbolInfoPair : symbolInfoMap) {
   439      StringRef symbol = symbolInfoPair.getKey();
   440      auto &info = symbolInfoPair.getValue();
   441      os.indent(4) << info.getVarDecl(symbol);
   442    }
   443    // TODO(jpienaar): capture ops with consistent numbering so that it can be
   444    // reused for fused loc.
   445    os.indent(4) << formatv("Operation *tblgen_ops[{0}];\n\n",
   446                            pattern.getSourcePattern().getNumOps());
   447  
   448    os.indent(4) << "// Match\n";
   449    os.indent(4) << "tblgen_ops[0] = op0;\n";
   450    emitMatchLogic(sourceTree);
   451    os << "\n";
   452  
   453    os.indent(4) << "// Rewrite\n";
   454    emitRewriteLogic();
   455  
   456    os.indent(4) << "return matchSuccess();\n";
   457    os << "  };\n";
   458    os << "};\n";
   459  }
   460  
   461  void PatternEmitter::emitRewriteLogic() {
   462    const Operator &rootOp = pattern.getSourceRootOp();
   463    int numExpectedResults = rootOp.getNumResults();
   464    int numResultPatterns = pattern.getNumResultPatterns();
   465  
   466    // First register all symbols bound to ops generated in result patterns.
   467    pattern.collectResultPatternBoundSymbols(symbolInfoMap);
   468  
   469    // Only the last N static values generated are used to replace the matched
   470    // root N-result op. We need to calculate the starting index (of the results
   471    // of the matched op) each result pattern is to replace.
   472    SmallVector<int, 4> offsets(numResultPatterns + 1, numExpectedResults);
   473    // If we don't need to replace any value at all, set the replacement starting
   474    // index as the number of result patterns so we skip all of them when trying
   475    // to replace the matched op's results.
   476    int replStartIndex = numExpectedResults == 0 ? numResultPatterns : -1;
   477    for (int i = numResultPatterns - 1; i >= 0; --i) {
   478      auto numValues = getNodeValueCount(pattern.getResultPattern(i));
   479      offsets[i] = offsets[i + 1] - numValues;
   480      if (offsets[i] == 0) {
   481        if (replStartIndex == -1)
   482          replStartIndex = i;
   483      } else if (offsets[i] < 0 && offsets[i + 1] > 0) {
   484        auto error = formatv(
   485            "cannot use the same multi-result op '{0}' to generate both "
   486            "auxiliary values and values to be used for replacing the matched op",
   487            pattern.getResultPattern(i).getSymbol());
   488        PrintFatalError(loc, error);
   489      }
   490    }
   491  
   492    if (offsets.front() > 0) {
   493      const char error[] = "no enough values generated to replace the matched op";
   494      PrintFatalError(loc, error);
   495    }
   496  
   497    os.indent(4) << "SmallVector<Type, 4> tblgen_types; (void)tblgen_types;\n";
   498    os.indent(4) << "auto loc = rewriter.getFusedLoc({";
   499    for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) {
   500      os << (i ? ", " : "") << "tblgen_ops[" << i << "]->getLoc()";
   501    }
   502    os << "}); (void)loc;\n";
   503  
   504    // Process each result pattern and record the result symbol.
   505    llvm::SmallVector<std::string, 2> resultValues;
   506    for (int i = 0; i < numResultPatterns; ++i) {
   507      DagNode resultTree = pattern.getResultPattern(i);
   508      resultValues.push_back(handleResultPattern(resultTree, offsets[i], 0));
   509    }
   510  
   511    os.indent(4) << "SmallVector<Value *, 4> tblgen_values;";
   512    // Only use the last portion for replacing the matched root op's results.
   513    auto range = llvm::makeArrayRef(resultValues).drop_front(replStartIndex);
   514    for (const auto &val : range) {
   515      os.indent(4) << "\n";
   516      // Resolve each symbol for all range use so that we can loop over them.
   517      os << symbolInfoMap.getAllRangeUse(
   518          val, "    for (auto *v : {0}) tblgen_values.push_back(v);", "\n");
   519    }
   520    os.indent(4) << "\n";
   521    os.indent(4) << "rewriter.replaceOp(op0, tblgen_values);\n";
   522  }
   523  
   524  std::string PatternEmitter::getUniqueSymbol(const Operator *op) {
   525    return formatv("tblgen_{0}_{1}", op->getCppClassName(), nextValueId++);
   526  }
   527  
   528  std::string PatternEmitter::handleResultPattern(DagNode resultTree,
   529                                                  int resultIndex, int depth) {
   530    if (resultTree.isNativeCodeCall()) {
   531      auto symbol = handleReplaceWithNativeCodeCall(resultTree);
   532      symbolInfoMap.bindValue(symbol);
   533      return symbol;
   534    }
   535  
   536    if (resultTree.isReplaceWithValue()) {
   537      return handleReplaceWithValue(resultTree);
   538    }
   539  
   540    // Normal op creation.
   541    auto symbol = handleOpCreation(resultTree, resultIndex, depth);
   542    if (resultTree.getSymbol().empty()) {
   543      // This is an op not explicitly bound to a symbol in the rewrite rule.
   544      // Register the auto-generated symbol for it.
   545      symbolInfoMap.bindOpResult(symbol, pattern.getDialectOp(resultTree));
   546    }
   547    return symbol;
   548  }
   549  
   550  std::string PatternEmitter::handleReplaceWithValue(DagNode tree) {
   551    assert(tree.isReplaceWithValue());
   552  
   553    if (tree.getNumArgs() != 1) {
   554      PrintFatalError(
   555          loc, "replaceWithValue directive must take exactly one argument");
   556    }
   557  
   558    if (!tree.getSymbol().empty()) {
   559      PrintFatalError(loc, "cannot bind symbol to replaceWithValue");
   560    }
   561  
   562    return tree.getArgName(0);
   563  }
   564  
   565  std::string PatternEmitter::handleOpArgument(DagLeaf leaf,
   566                                               StringRef patArgName) {
   567    if (leaf.isConstantAttr()) {
   568      auto constAttr = leaf.getAsConstantAttr();
   569      return handleConstantAttr(constAttr.getAttribute(),
   570                                constAttr.getConstantValue());
   571    }
   572    if (leaf.isEnumAttrCase()) {
   573      auto enumCase = leaf.getAsEnumAttrCase();
   574      if (enumCase.isStrCase())
   575        return handleConstantAttr(enumCase, enumCase.getSymbol());
   576      // This is an enum case backed by an IntegerAttr. We need to get its value
   577      // to build the constant.
   578      std::string val = std::to_string(enumCase.getValue());
   579      return handleConstantAttr(enumCase, val);
   580    }
   581  
   582    auto argName = symbolInfoMap.getValueAndRangeUse(patArgName);
   583    if (leaf.isUnspecified() || leaf.isOperandMatcher()) {
   584      return argName;
   585    }
   586    if (leaf.isNativeCodeCall()) {
   587      return tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(argName));
   588    }
   589    PrintFatalError(loc, "unhandled case when rewriting op");
   590  }
   591  
   592  std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree) {
   593    auto fmt = tree.getNativeCodeTemplate();
   594    // TODO(b/138794486): replace formatv arguments with the exact specified args.
   595    SmallVector<std::string, 8> attrs(8);
   596    if (tree.getNumArgs() > 8) {
   597      PrintFatalError(loc, "unsupported NativeCodeCall argument numbers: " +
   598                               Twine(tree.getNumArgs()));
   599    }
   600    for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
   601      attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i));
   602    }
   603    return tgfmt(fmt, &fmtCtx, attrs[0], attrs[1], attrs[2], attrs[3], attrs[4],
   604                 attrs[5], attrs[6], attrs[7]);
   605  }
   606  
   607  int PatternEmitter::getNodeValueCount(DagNode node) {
   608    if (node.isOperation()) {
   609      // If the op is bound to a symbol in the rewrite rule, query its result
   610      // count from the symbol info map.
   611      auto symbol = node.getSymbol();
   612      if (!symbol.empty()) {
   613        return symbolInfoMap.getStaticValueCount(symbol);
   614      }
   615      // Otherwise this is an unbound op; we will use all its results.
   616      return pattern.getDialectOp(node).getNumResults();
   617    }
   618    // TODO(antiagainst): This considers all NativeCodeCall as returning one
   619    // value. Enhance if multi-value ones are needed.
   620    return 1;
   621  }
   622  
   623  std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
   624                                               int depth) {
   625    Operator &resultOp = tree.getDialectOp(opMap);
   626    auto numOpArgs = resultOp.getNumArgs();
   627  
   628    if (numOpArgs != tree.getNumArgs()) {
   629      PrintFatalError(loc, formatv("resultant op '{0}' argument number mismatch: "
   630                                   "{1} in pattern vs. {2} in definition",
   631                                   resultOp.getOperationName(), tree.getNumArgs(),
   632                                   numOpArgs));
   633    }
   634  
   635    // A map to collect all nested DAG child nodes' names, with operand index as
   636    // the key. This includes both bound and unbound child nodes.
   637    llvm::DenseMap<unsigned, std::string> childNodeNames;
   638  
   639    // First go through all the child nodes who are nested DAG constructs to
   640    // create ops for them and remember the symbol names for them, so that we can
   641    // use the results in the current node. This happens in a recursive manner.
   642    for (int i = 0, e = resultOp.getNumOperands(); i != e; ++i) {
   643      if (auto child = tree.getArgAsNestedDag(i)) {
   644        childNodeNames[i] = handleResultPattern(child, i, depth + 1);
   645      }
   646    }
   647  
   648    // The name of the local variable holding this op.
   649    std::string valuePackName;
   650    // The symbol for holding the result of this pattern. Note that the result of
   651    // this pattern is not necessarily the same as the variable created by this
   652    // pattern because we can use `__N` suffix to refer only a specific result if
   653    // the generated op is a multi-result op.
   654    std::string resultValue;
   655    if (tree.getSymbol().empty()) {
   656      // No symbol is explicitly bound to this op in the pattern. Generate a
   657      // unique name.
   658      valuePackName = resultValue = getUniqueSymbol(&resultOp);
   659    } else {
   660      resultValue = tree.getSymbol();
   661      // Strip the index to get the name for the value pack and use it to name the
   662      // local variable for the op.
   663      valuePackName = SymbolInfoMap::getValuePackName(resultValue);
   664    }
   665  
   666    // Create the local variable for this op.
   667    os.indent(4) << formatv("{0} {1};\n", resultOp.getQualCppClassName(),
   668                            valuePackName);
   669    os.indent(4) << "{\n";
   670  
   671    // Now prepare operands used for building this op:
   672    // * If the operand is non-variadic, we create a `Value*` local variable.
   673    // * If the operand is variadic, we create a `SmallVector<Value*>` local
   674    //   variable.
   675  
   676    int argIndex = 0;   // The current index to this op's ODS argument
   677    int valueIndex = 0; // An index for uniquing local variable names.
   678    for (int e = resultOp.getNumOperands(); argIndex < e; ++argIndex) {
   679      const auto &operand = resultOp.getOperand(argIndex);
   680      std::string varName;
   681      if (operand.isVariadic()) {
   682        varName = formatv("tblgen_values_{0}", valueIndex++);
   683        os.indent(6) << formatv("SmallVector<Value *, 4> {0};\n", varName);
   684        std::string range;
   685        if (tree.isNestedDagArg(argIndex)) {
   686          range = childNodeNames[argIndex];
   687        } else {
   688          range = tree.getArgName(argIndex);
   689        }
   690        // Resolve the symbol for all range use so that we have a uniform way of
   691        // capturing the values.
   692        range = symbolInfoMap.getValueAndRangeUse(range);
   693        os.indent(6) << formatv("for (auto *v : {0}) {1}.push_back(v);\n", range,
   694                                varName);
   695      } else {
   696        varName = formatv("tblgen_value_{0}", valueIndex++);
   697        os.indent(6) << formatv("Value *{0} = ", varName);
   698        if (tree.isNestedDagArg(argIndex)) {
   699          os << symbolInfoMap.getValueAndRangeUse(childNodeNames[argIndex]);
   700        } else {
   701          DagLeaf leaf = tree.getArgAsLeaf(argIndex);
   702          auto symbol =
   703              symbolInfoMap.getValueAndRangeUse(tree.getArgName(argIndex));
   704          if (leaf.isNativeCodeCall()) {
   705            os << tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol));
   706          } else {
   707            os << symbol;
   708          }
   709        }
   710        os << ";\n";
   711      }
   712  
   713      // Update to use the newly created local variable for building the op later.
   714      childNodeNames[argIndex] = varName;
   715    }
   716  
   717    // Then we create the builder call.
   718  
   719    // Right now we don't have general type inference in MLIR. Except a few
   720    // special cases listed below, we need to supply types for all results
   721    // when building an op.
   722    bool isSameOperandsAndResultType =
   723        resultOp.hasTrait("OpTrait::SameOperandsAndResultType");
   724    bool isBroadcastable =
   725        resultOp.hasTrait("OpTrait::BroadcastableTwoOperandsOneResult");
   726    bool useFirstAttr = resultOp.hasTrait("OpTrait::FirstAttrDerivedResultType");
   727    bool usePartialResults = valuePackName != resultValue;
   728  
   729    if (isSameOperandsAndResultType || isBroadcastable || useFirstAttr ||
   730        usePartialResults || depth > 0 || resultIndex < 0) {
   731      os.indent(6) << formatv("{0} = rewriter.create<{1}>(loc", valuePackName,
   732                              resultOp.getQualCppClassName());
   733    } else {
   734      // If depth == 0 and resultIndex >= 0, it means we are replacing the values
   735      // generated from the source pattern root op. Then we can use the source
   736      // pattern's value types to determine the value type of the generated op
   737      // here.
   738  
   739      // We need to specify the types for all results.
   740      int numResults = resultOp.getNumResults();
   741      if (numResults != 0) {
   742        os.indent(6) << "tblgen_types.clear();\n";
   743        for (int i = 0; i < numResults; ++i) {
   744          os.indent(6) << formatv("for (auto *v : castedOp0.getODSResults({0})) "
   745                                  "tblgen_types.push_back(v->getType());\n",
   746                                  resultIndex + i);
   747        }
   748      }
   749  
   750      os.indent(6) << formatv("{0} = rewriter.create<{1}>(loc", valuePackName,
   751                              resultOp.getQualCppClassName());
   752      if (numResults != 0)
   753        os.indent(6) << ", tblgen_types";
   754    }
   755  
   756    // Add operands for the builder all.
   757    for (int i = 0; i < argIndex; ++i) {
   758      const auto &operand = resultOp.getOperand(i);
   759      // Start each operand on its own line.
   760      (os << ",\n").indent(8);
   761      if (!operand.name.empty()) {
   762        os << "/*" << operand.name << "=*/";
   763      }
   764      os << childNodeNames[i];
   765      // TODO(jpienaar): verify types
   766    }
   767  
   768    // Add attributes for the builder call.
   769    for (; argIndex != numOpArgs; ++argIndex) {
   770      // Start each attribute on its own line.
   771      (os << ",\n").indent(8);
   772      // The argument in the op definition.
   773      auto opArgName = resultOp.getArgName(argIndex);
   774      if (auto subTree = tree.getArgAsNestedDag(argIndex)) {
   775        if (!subTree.isNativeCodeCall())
   776          PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
   777                               "for creating attribute");
   778        os << formatv("/*{0}=*/{1}", opArgName,
   779                      handleReplaceWithNativeCodeCall(subTree));
   780      } else {
   781        auto leaf = tree.getArgAsLeaf(argIndex);
   782        // The argument in the result DAG pattern.
   783        auto patArgName = tree.getArgName(argIndex);
   784        if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) {
   785          // TODO(jpienaar): Refactor out into map to avoid recomputing these.
   786          auto argument = resultOp.getArg(argIndex);
   787          if (!argument.is<NamedAttribute *>())
   788            PrintFatalError(loc, Twine("expected attribute ") + Twine(argIndex));
   789          if (!patArgName.empty())
   790            os << "/*" << patArgName << "=*/";
   791        } else {
   792          os << "/*" << opArgName << "=*/";
   793        }
   794        os << handleOpArgument(leaf, patArgName);
   795      }
   796    }
   797    os << "\n      );\n";
   798    os.indent(4) << "}\n";
   799  
   800    return resultValue;
   801  }
   802  
   803  static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
   804    emitSourceFileHeader("Rewriters", os);
   805  
   806    const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern");
   807    auto numPatterns = patterns.size();
   808  
   809    // We put the map here because it can be shared among multiple patterns.
   810    RecordOperatorMap recordOpMap;
   811  
   812    std::vector<std::string> rewriterNames;
   813    rewriterNames.reserve(numPatterns);
   814  
   815    std::string baseRewriterName = "GeneratedConvert";
   816    int rewriterIndex = 0;
   817  
   818    for (Record *p : patterns) {
   819      std::string name;
   820      if (p->isAnonymous()) {
   821        // If no name is provided, ensure unique rewriter names simply by
   822        // appending unique suffix.
   823        name = baseRewriterName + llvm::utostr(rewriterIndex++);
   824      } else {
   825        name = p->getName();
   826      }
   827      PatternEmitter(p, &recordOpMap, os).emit(name);
   828      rewriterNames.push_back(std::move(name));
   829    }
   830  
   831    // Emit function to add the generated matchers to the pattern list.
   832    os << "void populateWithGenerated(MLIRContext *context, "
   833       << "OwningRewritePatternList *patterns) {\n";
   834    for (const auto &name : rewriterNames) {
   835      os << "  patterns->insert<" << name << ">(context);\n";
   836    }
   837    os << "}\n";
   838  }
   839  
   840  static mlir::GenRegistration
   841      genRewriters("gen-rewriters", "Generate pattern rewriters",
   842                   [](const RecordKeeper &records, raw_ostream &os) {
   843                     emitRewriters(records, os);
   844                     return false;
   845                   });