github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/TableGen/Pattern.cpp (about)

     1  //===- Pattern.cpp - Pattern wrapper class --------------------------------===//
     2  //
     3  // Copyright 2019 The MLIR Authors.
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //   http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  // =============================================================================
    17  //
    18  // Pattern wrapper class to simplify using TableGen Record defining a MLIR
    19  // Pattern.
    20  //
    21  //===----------------------------------------------------------------------===//
    22  
    23  #include "mlir/TableGen/Pattern.h"
    24  #include "llvm/ADT/Twine.h"
    25  #include "llvm/Support/FormatVariadic.h"
    26  #include "llvm/TableGen/Error.h"
    27  #include "llvm/TableGen/Record.h"
    28  
    29  using namespace mlir;
    30  
    31  using llvm::formatv;
    32  using mlir::tblgen::Operator;
    33  
    34  //===----------------------------------------------------------------------===//
    35  // DagLeaf
    36  //===----------------------------------------------------------------------===//
    37  
    38  bool tblgen::DagLeaf::isUnspecified() const {
    39    return dyn_cast_or_null<llvm::UnsetInit>(def);
    40  }
    41  
    42  bool tblgen::DagLeaf::isOperandMatcher() const {
    43    // Operand matchers specify a type constraint.
    44    return isSubClassOf("TypeConstraint");
    45  }
    46  
    47  bool tblgen::DagLeaf::isAttrMatcher() const {
    48    // Attribute matchers specify an attribute constraint.
    49    return isSubClassOf("AttrConstraint");
    50  }
    51  
    52  bool tblgen::DagLeaf::isNativeCodeCall() const {
    53    return isSubClassOf("NativeCodeCall");
    54  }
    55  
    56  bool tblgen::DagLeaf::isConstantAttr() const {
    57    return isSubClassOf("ConstantAttr");
    58  }
    59  
    60  bool tblgen::DagLeaf::isEnumAttrCase() const {
    61    return isSubClassOf("EnumAttrCaseInfo");
    62  }
    63  
    64  tblgen::Constraint tblgen::DagLeaf::getAsConstraint() const {
    65    assert((isOperandMatcher() || isAttrMatcher()) &&
    66           "the DAG leaf must be operand or attribute");
    67    return Constraint(cast<llvm::DefInit>(def)->getDef());
    68  }
    69  
    70  tblgen::ConstantAttr tblgen::DagLeaf::getAsConstantAttr() const {
    71    assert(isConstantAttr() && "the DAG leaf must be constant attribute");
    72    return ConstantAttr(cast<llvm::DefInit>(def));
    73  }
    74  
    75  tblgen::EnumAttrCase tblgen::DagLeaf::getAsEnumAttrCase() const {
    76    assert(isEnumAttrCase() && "the DAG leaf must be an enum attribute case");
    77    return EnumAttrCase(cast<llvm::DefInit>(def));
    78  }
    79  
    80  std::string tblgen::DagLeaf::getConditionTemplate() const {
    81    return getAsConstraint().getConditionTemplate();
    82  }
    83  
    84  llvm::StringRef tblgen::DagLeaf::getNativeCodeTemplate() const {
    85    assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
    86    return cast<llvm::DefInit>(def)->getDef()->getValueAsString("expression");
    87  }
    88  
    89  bool tblgen::DagLeaf::isSubClassOf(StringRef superclass) const {
    90    if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(def))
    91      return defInit->getDef()->isSubClassOf(superclass);
    92    return false;
    93  }
    94  
    95  //===----------------------------------------------------------------------===//
    96  // DagNode
    97  //===----------------------------------------------------------------------===//
    98  
    99  bool tblgen::DagNode::isNativeCodeCall() const {
   100    if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(node->getOperator()))
   101      return defInit->getDef()->isSubClassOf("NativeCodeCall");
   102    return false;
   103  }
   104  
   105  bool tblgen::DagNode::isOperation() const {
   106    return !(isNativeCodeCall() || isReplaceWithValue());
   107  }
   108  
   109  llvm::StringRef tblgen::DagNode::getNativeCodeTemplate() const {
   110    assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
   111    return cast<llvm::DefInit>(node->getOperator())
   112        ->getDef()
   113        ->getValueAsString("expression");
   114  }
   115  
   116  llvm::StringRef tblgen::DagNode::getSymbol() const {
   117    return node->getNameStr();
   118  }
   119  
   120  Operator &tblgen::DagNode::getDialectOp(RecordOperatorMap *mapper) const {
   121    llvm::Record *opDef = cast<llvm::DefInit>(node->getOperator())->getDef();
   122    auto it = mapper->find(opDef);
   123    if (it != mapper->end())
   124      return *it->second;
   125    return *mapper->try_emplace(opDef, std::make_unique<Operator>(opDef))
   126                .first->second;
   127  }
   128  
   129  int tblgen::DagNode::getNumOps() const {
   130    int count = isReplaceWithValue() ? 0 : 1;
   131    for (int i = 0, e = getNumArgs(); i != e; ++i) {
   132      if (auto child = getArgAsNestedDag(i))
   133        count += child.getNumOps();
   134    }
   135    return count;
   136  }
   137  
   138  int tblgen::DagNode::getNumArgs() const { return node->getNumArgs(); }
   139  
   140  bool tblgen::DagNode::isNestedDagArg(unsigned index) const {
   141    return isa<llvm::DagInit>(node->getArg(index));
   142  }
   143  
   144  tblgen::DagNode tblgen::DagNode::getArgAsNestedDag(unsigned index) const {
   145    return DagNode(dyn_cast_or_null<llvm::DagInit>(node->getArg(index)));
   146  }
   147  
   148  tblgen::DagLeaf tblgen::DagNode::getArgAsLeaf(unsigned index) const {
   149    assert(!isNestedDagArg(index));
   150    return DagLeaf(node->getArg(index));
   151  }
   152  
   153  StringRef tblgen::DagNode::getArgName(unsigned index) const {
   154    return node->getArgNameStr(index);
   155  }
   156  
   157  bool tblgen::DagNode::isReplaceWithValue() const {
   158    auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
   159    return dagOpDef->getName() == "replaceWithValue";
   160  }
   161  
   162  //===----------------------------------------------------------------------===//
   163  // SymbolInfoMap
   164  //===----------------------------------------------------------------------===//
   165  
   166  StringRef tblgen::SymbolInfoMap::getValuePackName(StringRef symbol,
   167                                                    int *index) {
   168    StringRef name, indexStr;
   169    int idx = -1;
   170    std::tie(name, indexStr) = symbol.rsplit("__");
   171  
   172    if (indexStr.consumeInteger(10, idx)) {
   173      // The second part is not an index; we return the whole symbol as-is.
   174      return symbol;
   175    }
   176    if (index) {
   177      *index = idx;
   178    }
   179    return name;
   180  }
   181  
   182  tblgen::SymbolInfoMap::SymbolInfo::SymbolInfo(const Operator *op,
   183                                                SymbolInfo::Kind kind,
   184                                                Optional<int> index)
   185      : op(op), kind(kind), argIndex(index) {}
   186  
   187  int tblgen::SymbolInfoMap::SymbolInfo::getStaticValueCount() const {
   188    switch (kind) {
   189    case Kind::Attr:
   190    case Kind::Operand:
   191    case Kind::Value:
   192      return 1;
   193    case Kind::Result:
   194      return op->getNumResults();
   195    }
   196    llvm_unreachable("unknown kind");
   197  }
   198  
   199  std::string
   200  tblgen::SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
   201    switch (kind) {
   202    case Kind::Attr: {
   203      auto type =
   204          op->getArg(*argIndex).get<NamedAttribute *>()->attr.getStorageType();
   205      return formatv("{0} {1};\n", type, name);
   206    }
   207    case Kind::Operand: {
   208      // Use operand range for captured operands (to support potential variadic
   209      // operands).
   210      return formatv("Operation::operand_range {0}(op0->getOperands());\n", name);
   211    }
   212    case Kind::Value: {
   213      return formatv("ArrayRef<Value *> {0};\n", name);
   214    }
   215    case Kind::Result: {
   216      // Use the op itself for captured results.
   217      return formatv("{0} {1};\n", op->getQualCppClassName(), name);
   218    }
   219    }
   220    llvm_unreachable("unknown kind");
   221  }
   222  
   223  std::string tblgen::SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
   224      StringRef name, int index, const char *fmt, const char *separator) const {
   225    switch (kind) {
   226    case Kind::Attr: {
   227      assert(index < 0);
   228      return formatv(fmt, name);
   229    }
   230    case Kind::Operand: {
   231      assert(index < 0);
   232      auto *operand = op->getArg(*argIndex).get<NamedTypeConstraint *>();
   233      // If this operand is variadic, then return a range. Otherwise, return the
   234      // value itself.
   235      if (operand->isVariadic()) {
   236        return formatv(fmt, name);
   237      }
   238      return formatv(fmt, formatv("(*{0}.begin())", name));
   239    }
   240    case Kind::Result: {
   241      // If `index` is greater than zero, then we are referencing a specific
   242      // result of a multi-result op. The result can still be variadic.
   243      if (index >= 0) {
   244        std::string v = formatv("{0}.getODSResults({1})", name, index);
   245        if (!op->getResult(index).isVariadic())
   246          v = formatv("(*{0}.begin())", v);
   247        return formatv(fmt, v);
   248      }
   249  
   250      // We are referencing all results of the multi-result op. A specific result
   251      // can either be a value or a range. Then join them with `separator`.
   252      SmallVector<std::string, 4> values;
   253      values.reserve(op->getNumResults());
   254  
   255      for (int i = 0, e = op->getNumResults(); i < e; ++i) {
   256        std::string v = formatv("{0}.getODSResults({1})", name, i);
   257        if (!op->getResult(i).isVariadic()) {
   258          v = formatv("(*{0}.begin())", v);
   259        }
   260        values.push_back(formatv(fmt, v));
   261      }
   262      return llvm::join(values, separator);
   263    }
   264    case Kind::Value: {
   265      assert(index < 0);
   266      assert(op == nullptr);
   267      return formatv(fmt, name);
   268    }
   269    }
   270  }
   271  
   272  std::string tblgen::SymbolInfoMap::SymbolInfo::getAllRangeUse(
   273      StringRef name, int index, const char *fmt, const char *separator) const {
   274    switch (kind) {
   275    case Kind::Attr:
   276    case Kind::Operand: {
   277      assert(index < 0 && "only allowed for symbol bound to result");
   278      return formatv(fmt, name);
   279    }
   280    case Kind::Result: {
   281      if (index >= 0) {
   282        return formatv(fmt, formatv("{0}.getODSResults({1})", name, index));
   283      }
   284  
   285      // We are referencing all results of the multi-result op. Each result should
   286      // have a value range, and then join them with `separator`.
   287      SmallVector<std::string, 4> values;
   288      values.reserve(op->getNumResults());
   289  
   290      for (int i = 0, e = op->getNumResults(); i < e; ++i) {
   291        values.push_back(
   292            formatv(fmt, formatv("{0}.getODSResults({1})", name, i)));
   293      }
   294      return llvm::join(values, separator);
   295    }
   296    case Kind::Value: {
   297      assert(index < 0 && "only allowed for symbol bound to result");
   298      assert(op == nullptr);
   299      return formatv(fmt, formatv("{{{0}}", name));
   300    }
   301    }
   302    llvm_unreachable("unknown kind");
   303  }
   304  
   305  bool tblgen::SymbolInfoMap::bindOpArgument(StringRef symbol, const Operator &op,
   306                                             int argIndex) {
   307    StringRef name = getValuePackName(symbol);
   308    if (name != symbol) {
   309      auto error = formatv(
   310          "symbol '{0}' with trailing index cannot bind to op argument", symbol);
   311      PrintFatalError(loc, error);
   312    }
   313  
   314    auto symInfo = op.getArg(argIndex).is<NamedAttribute *>()
   315                       ? SymbolInfo::getAttr(&op, argIndex)
   316                       : SymbolInfo::getOperand(&op, argIndex);
   317  
   318    return symbolInfoMap.insert({symbol, symInfo}).second;
   319  }
   320  
   321  bool tblgen::SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) {
   322    StringRef name = getValuePackName(symbol);
   323    return symbolInfoMap.insert({name, SymbolInfo::getResult(&op)}).second;
   324  }
   325  
   326  bool tblgen::SymbolInfoMap::bindValue(StringRef symbol) {
   327    return symbolInfoMap.insert({symbol, SymbolInfo::getValue()}).second;
   328  }
   329  
   330  bool tblgen::SymbolInfoMap::contains(StringRef symbol) const {
   331    return find(symbol) != symbolInfoMap.end();
   332  }
   333  
   334  tblgen::SymbolInfoMap::const_iterator
   335  tblgen::SymbolInfoMap::find(StringRef key) const {
   336    StringRef name = getValuePackName(key);
   337    return symbolInfoMap.find(name);
   338  }
   339  
   340  int tblgen::SymbolInfoMap::getStaticValueCount(StringRef symbol) const {
   341    StringRef name = getValuePackName(symbol);
   342    if (name != symbol) {
   343      // If there is a trailing index inside symbol, it references just one
   344      // static value.
   345      return 1;
   346    }
   347    // Otherwise, find how many it represents by querying the symbol's info.
   348    return find(name)->getValue().getStaticValueCount();
   349  }
   350  
   351  std::string
   352  tblgen::SymbolInfoMap::getValueAndRangeUse(StringRef symbol, const char *fmt,
   353                                             const char *separator) const {
   354    int index = -1;
   355    StringRef name = getValuePackName(symbol, &index);
   356  
   357    auto it = symbolInfoMap.find(name);
   358    if (it == symbolInfoMap.end()) {
   359      auto error = formatv("referencing unbound symbol '{0}'", symbol);
   360      PrintFatalError(loc, error);
   361    }
   362  
   363    return it->getValue().getValueAndRangeUse(name, index, fmt, separator);
   364  }
   365  
   366  std::string tblgen::SymbolInfoMap::getAllRangeUse(StringRef symbol,
   367                                                    const char *fmt,
   368                                                    const char *separator) const {
   369    int index = -1;
   370    StringRef name = getValuePackName(symbol, &index);
   371  
   372    auto it = symbolInfoMap.find(name);
   373    if (it == symbolInfoMap.end()) {
   374      auto error = formatv("referencing unbound symbol '{0}'", symbol);
   375      PrintFatalError(loc, error);
   376    }
   377  
   378    return it->getValue().getAllRangeUse(name, index, fmt, separator);
   379  }
   380  
   381  //===----------------------------------------------------------------------===//
   382  // Pattern
//===----------------------------------------------------------------------===//
   384  
   385  tblgen::Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper)
   386      : def(*def), recordOpMap(mapper) {}
   387  
   388  tblgen::DagNode tblgen::Pattern::getSourcePattern() const {
   389    return tblgen::DagNode(def.getValueAsDag("sourcePattern"));
   390  }
   391  
   392  int tblgen::Pattern::getNumResultPatterns() const {
   393    auto *results = def.getValueAsListInit("resultPatterns");
   394    return results->size();
   395  }
   396  
   397  tblgen::DagNode tblgen::Pattern::getResultPattern(unsigned index) const {
   398    auto *results = def.getValueAsListInit("resultPatterns");
   399    return tblgen::DagNode(cast<llvm::DagInit>(results->getElement(index)));
   400  }
   401  
   402  void tblgen::Pattern::collectSourcePatternBoundSymbols(
   403      tblgen::SymbolInfoMap &infoMap) {
   404    collectBoundSymbols(getSourcePattern(), infoMap, /*isSrcPattern=*/true);
   405  }
   406  
   407  void tblgen::Pattern::collectResultPatternBoundSymbols(
   408      tblgen::SymbolInfoMap &infoMap) {
   409    for (int i = 0, e = getNumResultPatterns(); i < e; ++i) {
   410      auto pattern = getResultPattern(i);
   411      collectBoundSymbols(pattern, infoMap, /*isSrcPattern=*/false);
   412    }
   413  }
   414  
   415  const tblgen::Operator &tblgen::Pattern::getSourceRootOp() {
   416    return getSourcePattern().getDialectOp(recordOpMap);
   417  }
   418  
   419  tblgen::Operator &tblgen::Pattern::getDialectOp(DagNode node) {
   420    return node.getDialectOp(recordOpMap);
   421  }
   422  
   423  std::vector<tblgen::AppliedConstraint> tblgen::Pattern::getConstraints() const {
   424    auto *listInit = def.getValueAsListInit("constraints");
   425    std::vector<tblgen::AppliedConstraint> ret;
   426    ret.reserve(listInit->size());
   427  
   428    for (auto it : *listInit) {
   429      auto *dagInit = dyn_cast<llvm::DagInit>(it);
   430      if (!dagInit)
   431        PrintFatalError(def.getLoc(), "all elemements in Pattern multi-entity "
   432                                      "constraints should be DAG nodes");
   433  
   434      std::vector<std::string> entities;
   435      entities.reserve(dagInit->arg_size());
   436      for (auto *argName : dagInit->getArgNames())
   437        entities.push_back(argName->getValue());
   438  
   439      ret.emplace_back(cast<llvm::DefInit>(dagInit->getOperator())->getDef(),
   440                       dagInit->getNameStr(), std::move(entities));
   441    }
   442    return ret;
   443  }
   444  
   445  int tblgen::Pattern::getBenefit() const {
   446    // The initial benefit value is a heuristic with number of ops in the source
   447    // pattern.
   448    int initBenefit = getSourcePattern().getNumOps();
   449    llvm::DagInit *delta = def.getValueAsDag("benefitDelta");
   450    if (delta->getNumArgs() != 1 || !isa<llvm::IntInit>(delta->getArg(0))) {
   451      PrintFatalError(def.getLoc(),
   452                      "The 'addBenefit' takes and only takes one integer value");
   453    }
   454    return initBenefit + dyn_cast<llvm::IntInit>(delta->getArg(0))->getValue();
   455  }
   456  
   457  std::vector<tblgen::Pattern::IdentifierLine>
   458  tblgen::Pattern::getLocation() const {
   459    std::vector<std::pair<StringRef, unsigned>> result;
   460    result.reserve(def.getLoc().size());
   461    for (auto loc : def.getLoc()) {
   462      unsigned buf = llvm::SrcMgr.FindBufferContainingLoc(loc);
   463      assert(buf && "invalid source location");
   464      result.emplace_back(
   465          llvm::SrcMgr.getBufferInfo(buf).Buffer->getBufferIdentifier(),
   466          llvm::SrcMgr.getLineAndColumn(loc, buf).first);
   467    }
   468    return result;
   469  }
   470  
   471  void tblgen::Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
   472                                            bool isSrcPattern) {
   473    auto treeName = tree.getSymbol();
   474    if (!tree.isOperation()) {
   475      if (!treeName.empty()) {
   476        PrintFatalError(
   477            def.getLoc(),
   478            formatv("binding symbol '{0}' to non-operation unsupported right now",
   479                    treeName));
   480      }
   481      return;
   482    }
   483  
   484    auto &op = getDialectOp(tree);
   485    auto numOpArgs = op.getNumArgs();
   486    auto numTreeArgs = tree.getNumArgs();
   487  
   488    if (numOpArgs != numTreeArgs) {
   489      auto err = formatv("op '{0}' argument number mismatch: "
   490                         "{1} in pattern vs. {2} in definition",
   491                         op.getOperationName(), numTreeArgs, numOpArgs);
   492      PrintFatalError(def.getLoc(), err);
   493    }
   494  
   495    // The name attached to the DAG node's operator is for representing the
   496    // results generated from this op. It should be remembered as bound results.
   497    if (!treeName.empty()) {
   498      if (!infoMap.bindOpResult(treeName, op))
   499        PrintFatalError(def.getLoc(),
   500                        formatv("symbol '{0}' bound more than once", treeName));
   501    }
   502  
   503    for (int i = 0; i != numTreeArgs; ++i) {
   504      if (auto treeArg = tree.getArgAsNestedDag(i)) {
   505        // This DAG node argument is a DAG node itself. Go inside recursively.
   506        collectBoundSymbols(treeArg, infoMap, isSrcPattern);
   507      } else if (isSrcPattern) {
   508        // We can only bind symbols to op arguments in source pattern. Those
   509        // symbols are referenced in result patterns.
   510        auto treeArgName = tree.getArgName(i);
   511        if (!treeArgName.empty()) {
   512          if (!infoMap.bindOpArgument(treeArgName, op, i)) {
   513            auto err = formatv("symbol '{0}' bound more than once", treeArgName);
   514            PrintFatalError(def.getLoc(), err);
   515          }
   516        }
   517      }
   518    }
   519  }