github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/TableGen/Operator.cpp (about)

     1  //===- Operator.cpp - Operator class --------------------------------------===//
     2  //
     3  // Copyright 2019 The MLIR Authors.
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //   http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  // =============================================================================
    17  //
    18  // Operator wrapper to simplify using TableGen Record defining a MLIR Op.
    19  //
    20  //===----------------------------------------------------------------------===//
    21  
    22  #include "mlir/TableGen/Operator.h"
    23  #include "mlir/TableGen/OpTrait.h"
    24  #include "mlir/TableGen/Predicate.h"
    25  #include "mlir/TableGen/Type.h"
    26  #include "llvm/Support/FormatVariadic.h"
    27  #include "llvm/TableGen/Error.h"
    28  #include "llvm/TableGen/Record.h"
    29  
    30  using namespace mlir;
    31  
    32  using llvm::DagInit;
    33  using llvm::DefInit;
    34  using llvm::Record;
    35  
    36  tblgen::Operator::Operator(const llvm::Record &def)
    37      : dialect(def.getValueAsDef("opDialect")), def(def) {
    38    // The first `_` in the op's TableGen def name is treated as separating the
    39    // dialect prefix and the op class name. The dialect prefix will be ignored if
    40    // not empty. Otherwise, if def name starts with a `_`, the `_` is considered
    41    // as part of the class name.
    42    StringRef prefix;
    43    std::tie(prefix, cppClassName) = def.getName().split('_');
    44    if (prefix.empty()) {
    45      // Class name with a leading underscore and without dialect prefix
    46      cppClassName = def.getName();
    47    } else if (cppClassName.empty()) {
    48      // Class name without dialect prefix
    49      cppClassName = prefix;
    50    }
    51  
    52    populateOpStructure();
    53  }
    54  
    55  std::string tblgen::Operator::getOperationName() const {
    56    auto prefix = dialect.getName();
    57    auto opName = def.getValueAsString("opName");
    58    if (prefix.empty())
    59      return opName;
    60    return llvm::formatv("{0}.{1}", prefix, opName);
    61  }
    62  
    63  StringRef tblgen::Operator::getDialectName() const { return dialect.getName(); }
    64  
    65  StringRef tblgen::Operator::getCppClassName() const { return cppClassName; }
    66  
    67  std::string tblgen::Operator::getQualCppClassName() const {
    68    auto prefix = dialect.getCppNamespace();
    69    if (prefix.empty())
    70      return cppClassName;
    71    return llvm::formatv("{0}::{1}", prefix, cppClassName);
    72  }
    73  
    74  int tblgen::Operator::getNumResults() const {
    75    DagInit *results = def.getValueAsDag("results");
    76    return results->getNumArgs();
    77  }
    78  
    79  StringRef tblgen::Operator::getExtraClassDeclaration() const {
    80    constexpr auto attr = "extraClassDeclaration";
    81    if (def.isValueUnset(attr))
    82      return {};
    83    return def.getValueAsString(attr);
    84  }
    85  
    86  const llvm::Record &tblgen::Operator::getDef() const { return def; }
    87  
    88  bool tblgen::Operator::isVariadic() const {
    89    return getNumVariadicOperands() != 0 || getNumVariadicResults() != 0;
    90  }
    91  
    92  bool tblgen::Operator::skipDefaultBuilders() const {
    93    return def.getValueAsBit("skipDefaultBuilders");
    94  }
    95  
    96  auto tblgen::Operator::result_begin() -> value_iterator {
    97    return results.begin();
    98  }
    99  
   100  auto tblgen::Operator::result_end() -> value_iterator { return results.end(); }
   101  
   102  auto tblgen::Operator::getResults() -> value_range {
   103    return {result_begin(), result_end()};
   104  }
   105  
   106  tblgen::TypeConstraint
   107  tblgen::Operator::getResultTypeConstraint(int index) const {
   108    DagInit *results = def.getValueAsDag("results");
   109    return TypeConstraint(cast<DefInit>(results->getArg(index)));
   110  }
   111  
   112  StringRef tblgen::Operator::getResultName(int index) const {
   113    DagInit *results = def.getValueAsDag("results");
   114    return results->getArgNameStr(index);
   115  }
   116  
   117  unsigned tblgen::Operator::getNumVariadicResults() const {
   118    return std::count_if(
   119        results.begin(), results.end(),
   120        [](const NamedTypeConstraint &c) { return c.constraint.isVariadic(); });
   121  }
   122  
   123  unsigned tblgen::Operator::getNumVariadicOperands() const {
   124    return std::count_if(
   125        operands.begin(), operands.end(),
   126        [](const NamedTypeConstraint &c) { return c.constraint.isVariadic(); });
   127  }
   128  
   129  StringRef tblgen::Operator::getArgName(int index) const {
   130    DagInit *argumentValues = def.getValueAsDag("arguments");
   131    return argumentValues->getArgName(index)->getValue();
   132  }
   133  
   134  bool tblgen::Operator::hasTrait(StringRef trait) const {
   135    for (auto t : getTraits()) {
   136      if (auto opTrait = dyn_cast<tblgen::NativeOpTrait>(&t)) {
   137        if (opTrait->getTrait() == trait)
   138          return true;
   139      } else if (auto opTrait = dyn_cast<tblgen::InternalOpTrait>(&t)) {
   140        if (opTrait->getTrait() == trait)
   141          return true;
   142      }
   143    }
   144    return false;
   145  }
   146  
   147  unsigned tblgen::Operator::getNumRegions() const { return regions.size(); }
   148  
   149  const tblgen::NamedRegion &tblgen::Operator::getRegion(unsigned index) const {
   150    return regions[index];
   151  }
   152  
   153  auto tblgen::Operator::trait_begin() const -> const_trait_iterator {
   154    return traits.begin();
   155  }
   156  auto tblgen::Operator::trait_end() const -> const_trait_iterator {
   157    return traits.end();
   158  }
   159  auto tblgen::Operator::getTraits() const
   160      -> llvm::iterator_range<const_trait_iterator> {
   161    return {trait_begin(), trait_end()};
   162  }
   163  
   164  auto tblgen::Operator::attribute_begin() const -> attribute_iterator {
   165    return attributes.begin();
   166  }
   167  auto tblgen::Operator::attribute_end() const -> attribute_iterator {
   168    return attributes.end();
   169  }
   170  auto tblgen::Operator::getAttributes() const
   171      -> llvm::iterator_range<attribute_iterator> {
   172    return {attribute_begin(), attribute_end()};
   173  }
   174  
   175  auto tblgen::Operator::operand_begin() -> value_iterator {
   176    return operands.begin();
   177  }
   178  auto tblgen::Operator::operand_end() -> value_iterator {
   179    return operands.end();
   180  }
   181  auto tblgen::Operator::getOperands() -> value_range {
   182    return {operand_begin(), operand_end()};
   183  }
   184  
   185  auto tblgen::Operator::getArg(int index) const -> Argument {
   186    return arguments[index];
   187  }
   188  
   189  void tblgen::Operator::populateOpStructure() {
   190    auto &recordKeeper = def.getRecords();
   191    auto typeConstraintClass = recordKeeper.getClass("TypeConstraint");
   192    auto attrClass = recordKeeper.getClass("Attr");
   193    auto derivedAttrClass = recordKeeper.getClass("DerivedAttr");
   194    numNativeAttributes = 0;
   195  
   196    // The argument ordering is operands, native attributes, derived
   197    // attributes.
   198    DagInit *argumentValues = def.getValueAsDag("arguments");
   199    unsigned i = 0;
   200    // Handle operands and native attributes.
   201    for (unsigned e = argumentValues->getNumArgs(); i != e; ++i) {
   202      auto arg = argumentValues->getArg(i);
   203      auto givenName = argumentValues->getArgNameStr(i);
   204      auto argDefInit = dyn_cast<DefInit>(arg);
   205      if (!argDefInit)
   206        PrintFatalError(def.getLoc(),
   207                        Twine("undefined type for argument #") + Twine(i));
   208      Record *argDef = argDefInit->getDef();
   209  
   210      if (argDef->isSubClassOf(typeConstraintClass)) {
   211        operands.push_back(
   212            NamedTypeConstraint{givenName, TypeConstraint(argDefInit)});
   213        arguments.emplace_back(&operands.back());
   214      } else if (argDef->isSubClassOf(attrClass)) {
   215        if (givenName.empty())
   216          PrintFatalError(argDef->getLoc(), "attributes must be named");
   217        if (argDef->isSubClassOf(derivedAttrClass))
   218          PrintFatalError(argDef->getLoc(),
   219                          "derived attributes not allowed in argument list");
   220        attributes.push_back({givenName, Attribute(argDef)});
   221        arguments.emplace_back(&attributes.back());
   222        ++numNativeAttributes;
   223      } else {
   224        PrintFatalError(def.getLoc(), "unexpected def type; only defs deriving "
   225                                      "from TypeConstraint or Attr are allowed");
   226      }
   227    }
   228  
   229    // Handle derived attributes.
   230    for (const auto &val : def.getValues()) {
   231      if (auto *record = dyn_cast<llvm::RecordRecTy>(val.getType())) {
   232        if (!record->isSubClassOf(attrClass))
   233          continue;
   234        if (!record->isSubClassOf(derivedAttrClass))
   235          PrintFatalError(def.getLoc(),
   236                          "unexpected Attr where only DerivedAttr is allowed");
   237  
   238        if (record->getClasses().size() != 1) {
   239          PrintFatalError(
   240              def.getLoc(),
   241              "unsupported attribute modelling, only single class expected");
   242        }
   243        attributes.push_back(
   244            {cast<llvm::StringInit>(val.getNameInit())->getValue(),
   245             Attribute(cast<DefInit>(val.getValue()))});
   246      }
   247    }
   248  
   249    auto *resultsDag = def.getValueAsDag("results");
   250    auto *outsOp = dyn_cast<DefInit>(resultsDag->getOperator());
   251    if (!outsOp || outsOp->getDef()->getName() != "outs") {
   252      PrintFatalError(def.getLoc(), "'results' must have 'outs' directive");
   253    }
   254  
   255    // Handle results.
   256    for (unsigned i = 0, e = resultsDag->getNumArgs(); i < e; ++i) {
   257      auto name = resultsDag->getArgNameStr(i);
   258      auto *resultDef = dyn_cast<DefInit>(resultsDag->getArg(i));
   259      if (!resultDef) {
   260        PrintFatalError(def.getLoc(),
   261                        Twine("undefined type for result #") + Twine(i));
   262      }
   263      results.push_back({name, TypeConstraint(resultDef)});
   264    }
   265  
   266    auto traitListInit = def.getValueAsListInit("traits");
   267    if (!traitListInit)
   268      return;
   269    traits.reserve(traitListInit->size());
   270    for (auto traitInit : *traitListInit)
   271      traits.push_back(OpTrait::create(traitInit));
   272  
   273    // Handle regions
   274    auto *regionsDag = def.getValueAsDag("regions");
   275    auto *regionsOp = dyn_cast<DefInit>(regionsDag->getOperator());
   276    if (!regionsOp || regionsOp->getDef()->getName() != "region") {
   277      PrintFatalError(def.getLoc(), "'regions' must have 'region' directive");
   278    }
   279  
   280    for (unsigned i = 0, e = regionsDag->getNumArgs(); i < e; ++i) {
   281      auto name = regionsDag->getArgNameStr(i);
   282      auto *regionInit = dyn_cast<DefInit>(regionsDag->getArg(i));
   283      if (!regionInit) {
   284        PrintFatalError(def.getLoc(),
   285                        Twine("undefined kind for region #") + Twine(i));
   286      }
   287      regions.push_back({name, Region(regionInit->getDef())});
   288    }
   289  }
   290  
   291  ArrayRef<llvm::SMLoc> tblgen::Operator::getLoc() const { return def.getLoc(); }
   292  
   293  bool tblgen::Operator::hasDescription() const {
   294    return def.getValue("description") != nullptr;
   295  }
   296  
   297  StringRef tblgen::Operator::getDescription() const {
   298    return def.getValueAsString("description");
   299  }
   300  
   301  bool tblgen::Operator::hasSummary() const {
   302    return def.getValue("summary") != nullptr;
   303  }
   304  
   305  StringRef tblgen::Operator::getSummary() const {
   306    return def.getValueAsString("summary");
   307  }