github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/tools/mlir-tblgen/EnumsGen.cpp (about)

     1  //===- EnumsGen.cpp - MLIR enum utility generator -------------------------===//
     2  //
     3  // Copyright 2019 The MLIR Authors.
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //   http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  // =============================================================================
    17  //
    18  // EnumsGen generates common utility functions for enums.
    19  //
    20  //===----------------------------------------------------------------------===//
    21  
    22  #include "mlir/TableGen/Attribute.h"
    23  #include "mlir/TableGen/GenInfo.h"
    24  #include "llvm/ADT/SmallVector.h"
    25  #include "llvm/ADT/StringExtras.h"
    26  #include "llvm/Support/FormatVariadic.h"
    27  #include "llvm/Support/raw_ostream.h"
    28  #include "llvm/TableGen/Error.h"
    29  #include "llvm/TableGen/Record.h"
    30  #include "llvm/TableGen/TableGenBackend.h"
    31  
    32  using llvm::formatv;
    33  using llvm::isDigit;
    34  using llvm::raw_ostream;
    35  using llvm::Record;
    36  using llvm::RecordKeeper;
    37  using llvm::StringRef;
    38  using mlir::tblgen::EnumAttr;
    39  using mlir::tblgen::EnumAttrCase;
    40  
    41  static std::string makeIdentifier(StringRef str) {
    42    if (!str.empty() && isDigit(static_cast<unsigned char>(str.front()))) {
    43      std::string newStr = std::string("_") + str.str();
    44      return newStr;
    45    }
    46    return str.str();
    47  }
    48  
    49  static void emitEnumClass(const Record &enumDef, StringRef enumName,
    50                            StringRef underlyingType, StringRef description,
    51                            const std::vector<EnumAttrCase> &enumerants,
    52                            raw_ostream &os) {
    53    os << "// " << description << "\n";
    54    os << "enum class " << enumName;
    55  
    56    if (!underlyingType.empty())
    57      os << " : " << underlyingType;
    58    os << " {\n";
    59  
    60    for (const auto &enumerant : enumerants) {
    61      auto symbol = makeIdentifier(enumerant.getSymbol());
    62      auto value = enumerant.getValue();
    63      if (value >= 0) {
    64        os << formatv("  {0} = {1},\n", symbol, value);
    65      } else {
    66        os << formatv("  {0},\n", symbol);
    67      }
    68    }
    69    os << "};\n\n";
    70  }
    71  
    72  static void emitDenseMapInfo(StringRef enumName, std::string underlyingType,
    73                               StringRef cppNamespace, raw_ostream &os) {
    74    std::string qualName = formatv("{0}::{1}", cppNamespace, enumName);
    75    if (underlyingType.empty())
    76      underlyingType = formatv("std::underlying_type<{0}>::type", qualName);
    77  
    78    const char *const mapInfo = R"(
    79  namespace llvm {
    80  template<> struct DenseMapInfo<{0}> {{
    81    using StorageInfo = llvm::DenseMapInfo<{1}>;
    82  
    83    static inline {0} getEmptyKey() {{
    84      return static_cast<{0}>(StorageInfo::getEmptyKey());
    85    }
    86  
    87    static inline {0} getTombstoneKey() {{
    88      return static_cast<{0}>(StorageInfo::getTombstoneKey());
    89    }
    90  
    91    static unsigned getHashValue(const {0} &val) {{
    92      return StorageInfo::getHashValue(static_cast<{1}>(val));
    93    }
    94  
    95    static bool isEqual(const {0} &lhs, const {0} &rhs) {{
    96      return lhs == rhs;
    97    }
    98  };
    99  })";
   100    os << formatv(mapInfo, qualName, underlyingType);
   101    os << "\n\n";
   102  }
   103  
   104  static void emitMaxValueFn(const Record &enumDef, raw_ostream &os) {
   105    EnumAttr enumAttr(enumDef);
   106    StringRef maxEnumValFnName = enumAttr.getMaxEnumValFnName();
   107    auto enumerants = enumAttr.getAllCases();
   108  
   109    unsigned maxEnumVal = 0;
   110    for (const auto &enumerant : enumerants) {
   111      int64_t value = enumerant.getValue();
   112      // Avoid generating the max value function if there is an enumerant without
   113      // explicit value.
   114      if (value < 0)
   115        return;
   116  
   117      maxEnumVal = std::max(maxEnumVal, static_cast<unsigned>(value));
   118    }
   119  
   120    // Emit the function to return the max enum value
   121    os << formatv("inline constexpr unsigned {0}() {{\n", maxEnumValFnName);
   122    os << formatv("  return {0};\n", maxEnumVal);
   123    os << "}\n\n";
   124  }
   125  
   126  static void emitSymToStrFn(const Record &enumDef, raw_ostream &os) {
   127    EnumAttr enumAttr(enumDef);
   128    StringRef enumName = enumAttr.getEnumClassName();
   129    StringRef symToStrFnName = enumAttr.getSymbolToStringFnName();
   130    auto enumerants = enumAttr.getAllCases();
   131  
   132    os << formatv("llvm::StringRef {1}({0} val) {{\n", enumName, symToStrFnName);
   133    os << "  switch (val) {\n";
   134    for (const auto &enumerant : enumerants) {
   135      auto symbol = enumerant.getSymbol();
   136      os << formatv("    case {0}::{1}: return \"{2}\";\n", enumName,
   137                    makeIdentifier(symbol), symbol);
   138    }
   139    os << "  }\n";
   140    os << "  return \"\";\n";
   141    os << "}\n\n";
   142  }
   143  
   144  static void emitStrToSymFn(const Record &enumDef, raw_ostream &os) {
   145    EnumAttr enumAttr(enumDef);
   146    StringRef enumName = enumAttr.getEnumClassName();
   147    StringRef strToSymFnName = enumAttr.getStringToSymbolFnName();
   148    auto enumerants = enumAttr.getAllCases();
   149  
   150    os << formatv("llvm::Optional<{0}> {1}(llvm::StringRef str) {{\n", enumName,
   151                  strToSymFnName);
   152    os << formatv("  return llvm::StringSwitch<llvm::Optional<{0}>>(str)\n",
   153                  enumName);
   154    for (const auto &enumerant : enumerants) {
   155      auto symbol = enumerant.getSymbol();
   156      os << formatv("      .Case(\"{1}\", {0}::{2})\n", enumName, symbol,
   157                    makeIdentifier(symbol));
   158    }
   159    os << "      .Default(llvm::None);\n";
   160    os << "}\n";
   161  }
   162  
   163  static void emitUnderlyingToSymFn(const Record &enumDef, raw_ostream &os) {
   164    EnumAttr enumAttr(enumDef);
   165    StringRef enumName = enumAttr.getEnumClassName();
   166    std::string underlyingType = enumAttr.getUnderlyingType();
   167    StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName();
   168    auto enumerants = enumAttr.getAllCases();
   169  
   170    // Avoid generating the underlying value to symbol conversion function if
   171    // there is an enumerant without explicit value.
   172    if (llvm::any_of(enumerants, [](EnumAttrCase enumerant) {
   173          return enumerant.getValue() < 0;
   174        }))
   175      return;
   176  
   177    os << formatv("llvm::Optional<{0}> {1}({2} value) {{\n", enumName,
   178                  underlyingToSymFnName,
   179                  underlyingType.empty() ? std::string("unsigned")
   180                                         : underlyingType)
   181       << "  switch (value) {\n";
   182    for (const auto &enumerant : enumerants) {
   183      auto symbol = enumerant.getSymbol();
   184      auto value = enumerant.getValue();
   185      os << formatv("  case {0}: return {1}::{2};\n", value, enumName,
   186                    makeIdentifier(symbol));
   187    }
   188    os << "  default: return llvm::None;\n"
   189       << "  }\n"
   190       << "}\n\n";
   191  }
   192  
   193  static void emitEnumDecl(const Record &enumDef, raw_ostream &os) {
   194    EnumAttr enumAttr(enumDef);
   195    StringRef enumName = enumAttr.getEnumClassName();
   196    StringRef cppNamespace = enumAttr.getCppNamespace();
   197    std::string underlyingType = enumAttr.getUnderlyingType();
   198    StringRef description = enumAttr.getDescription();
   199    StringRef strToSymFnName = enumAttr.getStringToSymbolFnName();
   200    StringRef symToStrFnName = enumAttr.getSymbolToStringFnName();
   201    StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName();
   202    auto enumerants = enumAttr.getAllCases();
   203  
   204    llvm::SmallVector<StringRef, 2> namespaces;
   205    llvm::SplitString(cppNamespace, namespaces, "::");
   206  
   207    for (auto ns : namespaces)
   208      os << "namespace " << ns << " {\n";
   209  
   210    // Emit the enum class definition
   211    emitEnumClass(enumDef, enumName, underlyingType, description, enumerants, os);
   212  
   213    // Emit coversion function declarations
   214    if (llvm::all_of(enumerants, [](EnumAttrCase enumerant) {
   215          return enumerant.getValue() >= 0;
   216        })) {
   217      os << formatv(
   218          "llvm::Optional<{0}> {1}({2});\n", enumName, underlyingToSymFnName,
   219          underlyingType.empty() ? std::string("unsigned") : underlyingType);
   220    }
   221    os << formatv("llvm::StringRef {1}({0});\n", enumName, symToStrFnName);
   222    os << formatv("llvm::Optional<{0}> {1}(llvm::StringRef);\n", enumName,
   223                  strToSymFnName);
   224  
   225    emitMaxValueFn(enumDef, os);
   226  
   227    for (auto ns : llvm::reverse(namespaces))
   228      os << "} // namespace " << ns << "\n";
   229  
   230    // Emit DenseMapInfo for this enum class
   231    emitDenseMapInfo(enumName, underlyingType, cppNamespace, os);
   232  }
   233  
   234  static bool emitEnumDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
   235    llvm::emitSourceFileHeader("Enum Utility Declarations", os);
   236  
   237    auto defs = recordKeeper.getAllDerivedDefinitions("EnumAttrInfo");
   238    for (const auto *def : defs)
   239      emitEnumDecl(*def, os);
   240  
   241    return false;
   242  }
   243  
   244  static void emitEnumDef(const Record &enumDef, raw_ostream &os) {
   245    EnumAttr enumAttr(enumDef);
   246    StringRef cppNamespace = enumAttr.getCppNamespace();
   247  
   248    llvm::SmallVector<StringRef, 2> namespaces;
   249    llvm::SplitString(cppNamespace, namespaces, "::");
   250  
   251    for (auto ns : namespaces)
   252      os << "namespace " << ns << " {\n";
   253  
   254    emitSymToStrFn(enumDef, os);
   255    emitStrToSymFn(enumDef, os);
   256    emitUnderlyingToSymFn(enumDef, os);
   257  
   258    for (auto ns : llvm::reverse(namespaces))
   259      os << "} // namespace " << ns << "\n";
   260    os << "\n";
   261  }
   262  
   263  static bool emitEnumDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
   264    llvm::emitSourceFileHeader("Enum Utility Definitions", os);
   265  
   266    auto defs = recordKeeper.getAllDerivedDefinitions("EnumAttrInfo");
   267    for (const auto *def : defs)
   268      emitEnumDef(*def, os);
   269  
   270    return false;
   271  }
   272  
   273  // Registers the enum utility generator to mlir-tblgen.
   274  static mlir::GenRegistration
   275      genEnumDecls("gen-enum-decls", "Generate enum utility declarations",
   276                   [](const RecordKeeper &records, raw_ostream &os) {
   277                     return emitEnumDecls(records, os);
   278                   });
   279  
   280  // Registers the enum utility generator to mlir-tblgen.
   281  static mlir::GenRegistration
   282      genEnumDefs("gen-enum-defs", "Generate enum utility definitions",
   283                  [](const RecordKeeper &records, raw_ostream &os) {
   284                    return emitEnumDefs(records, os);
   285                  });