github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/tools/mlir-tblgen/EnumsGen.cpp (about) 1 //===- EnumsGen.cpp - MLIR enum utility generator -------------------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 // 18 // EnumsGen generates common utility functions for enums. 19 // 20 //===----------------------------------------------------------------------===// 21 22 #include "mlir/TableGen/Attribute.h" 23 #include "mlir/TableGen/GenInfo.h" 24 #include "llvm/ADT/SmallVector.h" 25 #include "llvm/ADT/StringExtras.h" 26 #include "llvm/Support/FormatVariadic.h" 27 #include "llvm/Support/raw_ostream.h" 28 #include "llvm/TableGen/Error.h" 29 #include "llvm/TableGen/Record.h" 30 #include "llvm/TableGen/TableGenBackend.h" 31 32 using llvm::formatv; 33 using llvm::isDigit; 34 using llvm::raw_ostream; 35 using llvm::Record; 36 using llvm::RecordKeeper; 37 using llvm::StringRef; 38 using mlir::tblgen::EnumAttr; 39 using mlir::tblgen::EnumAttrCase; 40 41 static std::string makeIdentifier(StringRef str) { 42 if (!str.empty() && isDigit(static_cast<unsigned char>(str.front()))) { 43 std::string newStr = std::string("_") + str.str(); 44 return newStr; 45 } 46 return str.str(); 47 } 48 49 static void emitEnumClass(const Record &enumDef, StringRef enumName, 50 StringRef underlyingType, StringRef description, 51 const std::vector<EnumAttrCase> &enumerants, 52 raw_ostream &os) { 53 os << "// " << description << "\n"; 54 os << "enum class " << enumName; 55 56 if (!underlyingType.empty()) 57 os << " : " << underlyingType; 58 os << " {\n"; 59 60 for (const auto &enumerant : enumerants) { 61 auto symbol = makeIdentifier(enumerant.getSymbol()); 62 auto value = enumerant.getValue(); 63 if (value >= 0) { 64 os << formatv(" {0} = {1},\n", symbol, value); 65 } else { 66 os << formatv(" {0},\n", symbol); 67 } 68 } 69 os << "};\n\n"; 70 } 71 72 static void emitDenseMapInfo(StringRef enumName, std::string underlyingType, 73 StringRef cppNamespace, raw_ostream &os) { 74 std::string qualName = formatv("{0}::{1}", cppNamespace, enumName); 75 if (underlyingType.empty()) 76 underlyingType = formatv("std::underlying_type<{0}>::type", qualName); 77 78 const char *const mapInfo = R"( 79 namespace llvm { 80 template<> struct DenseMapInfo<{0}> {{ 81 using StorageInfo = llvm::DenseMapInfo<{1}>; 82 83 static inline {0} getEmptyKey() {{ 84 return static_cast<{0}>(StorageInfo::getEmptyKey()); 85 } 86 87 static inline {0} getTombstoneKey() {{ 88 return static_cast<{0}>(StorageInfo::getTombstoneKey()); 89 } 90 91 static unsigned getHashValue(const {0} &val) {{ 92 return StorageInfo::getHashValue(static_cast<{1}>(val)); 93 } 94 95 static bool isEqual(const {0} &lhs, const {0} &rhs) {{ 96 return lhs == rhs; 97 } 98 }; 99 })"; 100 os << formatv(mapInfo, qualName, underlyingType); 101 os << "\n\n"; 102 } 103 104 static void emitMaxValueFn(const Record &enumDef, raw_ostream &os) { 105 EnumAttr enumAttr(enumDef); 106 StringRef maxEnumValFnName = enumAttr.getMaxEnumValFnName(); 107 auto enumerants = enumAttr.getAllCases(); 108 109 unsigned maxEnumVal = 0; 110 for (const auto &enumerant : enumerants) { 111 int64_t value = enumerant.getValue(); 112 // Avoid generating the max value function if there is an enumerant without 113 // explicit value. 114 if (value < 0) 115 return; 116 117 maxEnumVal = std::max(maxEnumVal, static_cast<unsigned>(value)); 118 } 119 120 // Emit the function to return the max enum value 121 os << formatv("inline constexpr unsigned {0}() {{\n", maxEnumValFnName); 122 os << formatv(" return {0};\n", maxEnumVal); 123 os << "}\n\n"; 124 } 125 126 static void emitSymToStrFn(const Record &enumDef, raw_ostream &os) { 127 EnumAttr enumAttr(enumDef); 128 StringRef enumName = enumAttr.getEnumClassName(); 129 StringRef symToStrFnName = enumAttr.getSymbolToStringFnName(); 130 auto enumerants = enumAttr.getAllCases(); 131 132 os << formatv("llvm::StringRef {1}({0} val) {{\n", enumName, symToStrFnName); 133 os << " switch (val) {\n"; 134 for (const auto &enumerant : enumerants) { 135 auto symbol = enumerant.getSymbol(); 136 os << formatv(" case {0}::{1}: return \"{2}\";\n", enumName, 137 makeIdentifier(symbol), symbol); 138 } 139 os << " }\n"; 140 os << " return \"\";\n"; 141 os << "}\n\n"; 142 } 143 144 static void emitStrToSymFn(const Record &enumDef, raw_ostream &os) { 145 EnumAttr enumAttr(enumDef); 146 StringRef enumName = enumAttr.getEnumClassName(); 147 StringRef strToSymFnName = enumAttr.getStringToSymbolFnName(); 148 auto enumerants = enumAttr.getAllCases(); 149 150 os << formatv("llvm::Optional<{0}> {1}(llvm::StringRef str) {{\n", enumName, 151 strToSymFnName); 152 os << formatv(" return llvm::StringSwitch<llvm::Optional<{0}>>(str)\n", 153 enumName); 154 for (const auto &enumerant : enumerants) { 155 auto symbol = enumerant.getSymbol(); 156 os << formatv(" .Case(\"{1}\", {0}::{2})\n", enumName, symbol, 157 makeIdentifier(symbol)); 158 } 159 os << " .Default(llvm::None);\n"; 160 os << "}\n"; 161 } 162 163 static void emitUnderlyingToSymFn(const Record &enumDef, raw_ostream &os) { 164 EnumAttr enumAttr(enumDef); 165 StringRef enumName = enumAttr.getEnumClassName(); 166 std::string underlyingType = enumAttr.getUnderlyingType(); 167 StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName(); 168 auto enumerants = enumAttr.getAllCases(); 169 170 // Avoid generating the underlying value to symbol conversion function if 171 // there is an enumerant without explicit value. 172 if (llvm::any_of(enumerants, [](EnumAttrCase enumerant) { 173 return enumerant.getValue() < 0; 174 })) 175 return; 176 177 os << formatv("llvm::Optional<{0}> {1}({2} value) {{\n", enumName, 178 underlyingToSymFnName, 179 underlyingType.empty() ? std::string("unsigned") 180 : underlyingType) 181 << " switch (value) {\n"; 182 for (const auto &enumerant : enumerants) { 183 auto symbol = enumerant.getSymbol(); 184 auto value = enumerant.getValue(); 185 os << formatv(" case {0}: return {1}::{2};\n", value, enumName, 186 makeIdentifier(symbol)); 187 } 188 os << " default: return llvm::None;\n" 189 << " }\n" 190 << "}\n\n"; 191 } 192 193 static void emitEnumDecl(const Record &enumDef, raw_ostream &os) { 194 EnumAttr enumAttr(enumDef); 195 StringRef enumName = enumAttr.getEnumClassName(); 196 StringRef cppNamespace = enumAttr.getCppNamespace(); 197 std::string underlyingType = enumAttr.getUnderlyingType(); 198 StringRef description = enumAttr.getDescription(); 199 StringRef strToSymFnName = enumAttr.getStringToSymbolFnName(); 200 StringRef symToStrFnName = enumAttr.getSymbolToStringFnName(); 201 StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName(); 202 auto enumerants = enumAttr.getAllCases(); 203 204 llvm::SmallVector<StringRef, 2> namespaces; 205 llvm::SplitString(cppNamespace, namespaces, "::"); 206 207 for (auto ns : namespaces) 208 os << "namespace " << ns << " {\n"; 209 210 // Emit the enum class definition 211 emitEnumClass(enumDef, enumName, underlyingType, description, enumerants, os); 212 213 // Emit coversion function declarations 214 if (llvm::all_of(enumerants, [](EnumAttrCase enumerant) { 215 return enumerant.getValue() >= 0; 216 })) { 217 os << formatv( 218 "llvm::Optional<{0}> {1}({2});\n", enumName, underlyingToSymFnName, 219 underlyingType.empty() ? std::string("unsigned") : underlyingType); 220 } 221 os << formatv("llvm::StringRef {1}({0});\n", enumName, symToStrFnName); 222 os << formatv("llvm::Optional<{0}> {1}(llvm::StringRef);\n", enumName, 223 strToSymFnName); 224 225 emitMaxValueFn(enumDef, os); 226 227 for (auto ns : llvm::reverse(namespaces)) 228 os << "} // namespace " << ns << "\n"; 229 230 // Emit DenseMapInfo for this enum class 231 emitDenseMapInfo(enumName, underlyingType, cppNamespace, os); 232 } 233 234 static bool emitEnumDecls(const RecordKeeper &recordKeeper, raw_ostream &os) { 235 llvm::emitSourceFileHeader("Enum Utility Declarations", os); 236 237 auto defs = recordKeeper.getAllDerivedDefinitions("EnumAttrInfo"); 238 for (const auto *def : defs) 239 emitEnumDecl(*def, os); 240 241 return false; 242 } 243 244 static void emitEnumDef(const Record &enumDef, raw_ostream &os) { 245 EnumAttr enumAttr(enumDef); 246 StringRef cppNamespace = enumAttr.getCppNamespace(); 247 248 llvm::SmallVector<StringRef, 2> namespaces; 249 llvm::SplitString(cppNamespace, namespaces, "::"); 250 251 for (auto ns : namespaces) 252 os << "namespace " << ns << " {\n"; 253 254 emitSymToStrFn(enumDef, os); 255 emitStrToSymFn(enumDef, os); 256 emitUnderlyingToSymFn(enumDef, os); 257 258 for (auto ns : llvm::reverse(namespaces)) 259 os << "} // namespace " << ns << "\n"; 260 os << "\n"; 261 } 262 263 static bool emitEnumDefs(const RecordKeeper &recordKeeper, raw_ostream &os) { 264 llvm::emitSourceFileHeader("Enum Utility Definitions", os); 265 266 auto defs = recordKeeper.getAllDerivedDefinitions("EnumAttrInfo"); 267 for (const auto *def : defs) 268 emitEnumDef(*def, os); 269 270 return false; 271 } 272 273 // Registers the enum utility generator to mlir-tblgen. 274 static mlir::GenRegistration 275 genEnumDecls("gen-enum-decls", "Generate enum utility declarations", 276 [](const RecordKeeper &records, raw_ostream &os) { 277 return emitEnumDecls(records, os); 278 }); 279 280 // Registers the enum utility generator to mlir-tblgen. 281 static mlir::GenRegistration 282 genEnumDefs("gen-enum-defs", "Generate enum utility definitions", 283 [](const RecordKeeper &records, raw_ostream &os) { 284 return emitEnumDefs(records, os); 285 });