// Source path: github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
//===- SPIRVUtilsGen.cpp - SPIR-V utility generator -----------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// SPIRVUtilsGen generates common utility functions for SPIR-V serialization
// and deserialization, as well as SPIR-V op/enum utility definitions.
20 // 21 //===----------------------------------------------------------------------===// 22 23 #include "mlir/Support/StringExtras.h" 24 #include "mlir/TableGen/Attribute.h" 25 #include "mlir/TableGen/GenInfo.h" 26 #include "mlir/TableGen/Operator.h" 27 #include "llvm/ADT/Sequence.h" 28 #include "llvm/ADT/SmallVector.h" 29 #include "llvm/ADT/StringExtras.h" 30 #include "llvm/ADT/StringRef.h" 31 #include "llvm/Support/FormatVariadic.h" 32 #include "llvm/Support/raw_ostream.h" 33 #include "llvm/TableGen/Error.h" 34 #include "llvm/TableGen/Record.h" 35 #include "llvm/TableGen/TableGenBackend.h" 36 37 using llvm::ArrayRef; 38 using llvm::formatv; 39 using llvm::raw_ostream; 40 using llvm::raw_string_ostream; 41 using llvm::Record; 42 using llvm::RecordKeeper; 43 using llvm::SMLoc; 44 using llvm::StringRef; 45 using llvm::Twine; 46 using mlir::tblgen::Attribute; 47 using mlir::tblgen::EnumAttr; 48 using mlir::tblgen::NamedAttribute; 49 using mlir::tblgen::NamedTypeConstraint; 50 using mlir::tblgen::Operator; 51 52 // Writes the following function to `os`: 53 // inline uint32_t getOpcode(<op-class-name>) { return <opcode>; } 54 static void emitGetOpcodeFunction(const Record *record, Operator const &op, 55 raw_ostream &os) { 56 os << formatv("template <> constexpr inline ::mlir::spirv::Opcode " 57 "getOpcode<{0}>()", 58 op.getQualCppClassName()) 59 << " {\n " 60 << formatv("return ::mlir::spirv::Opcode::{0};\n}\n", 61 record->getValueAsString("spirvOpName")); 62 } 63 64 static void declareOpcodeFn(raw_ostream &os) { 65 os << "template <typename OpClass> inline constexpr ::mlir::spirv::Opcode " 66 "getOpcode();\n"; 67 } 68 69 static void emitAttributeSerialization(const Attribute &attr, 70 ArrayRef<SMLoc> loc, llvm::StringRef op, 71 llvm::StringRef operandList, 72 llvm::StringRef attrName, 73 raw_ostream &os) { 74 os << " auto attr = " << op << ".getAttr(\"" << attrName << "\");\n"; 75 os << " if (attr) {\n"; 76 if (attr.getAttrDefName() == "I32ArrayAttr") { 77 // 
Serialize all the elements of the array 78 os << " for (auto attrElem : attr.cast<ArrayAttr>()) {\n"; 79 os << " " << operandList 80 << ".push_back(static_cast<uint32_t>(attrElem.cast<IntegerAttr>()." 81 "getValue().getZExtValue()));\n"; 82 os << " }\n"; 83 } else if (attr.isEnumAttr() || attr.getAttrDefName() == "I32Attr") { 84 os << " " << operandList 85 << ".push_back(static_cast<uint32_t>(attr.cast<IntegerAttr>().getValue()" 86 ".getZExtValue()));\n"; 87 } else { 88 PrintFatalError( 89 loc, 90 llvm::Twine( 91 "unhandled attribute type in SPIR-V serialization generation : '") + 92 attr.getAttrDefName() + llvm::Twine("'")); 93 } 94 os << " }\n"; 95 } 96 97 static void emitSerializationFunction(const Record *attrClass, 98 const Record *record, const Operator &op, 99 raw_ostream &os) { 100 // If the record has 'autogenSerialization' set to 0, nothing to do 101 if (!record->getValueAsBit("autogenSerialization")) { 102 return; 103 } 104 os << formatv("template <> LogicalResult\nSerializer::processOp<{0}>(\n" 105 " {0} op)", 106 op.getQualCppClassName()) 107 << " {\n"; 108 os << " SmallVector<uint32_t, 4> operands;\n"; 109 os << " SmallVector<StringRef, 2> elidedAttrs;\n"; 110 111 // Serialize result information 112 if (op.getNumResults() == 1) { 113 os << " uint32_t resultTypeID = 0;\n"; 114 os << " if (failed(processType(op.getLoc(), op.getType(), resultTypeID))) " 115 "{\n"; 116 os << " return failure();\n"; 117 os << " }\n"; 118 os << " operands.push_back(resultTypeID);\n"; 119 // Create an SSA result <id> for the op 120 os << " auto resultID = getNextID();\n"; 121 os << " valueIDMap[op.getResult()] = resultID;\n"; 122 os << " operands.push_back(resultID);\n"; 123 } else if (op.getNumResults() != 0) { 124 PrintFatalError(record->getLoc(), "SPIR-V ops can only zero or one result"); 125 } 126 127 // Process arguments 128 auto operandNum = 0; 129 for (unsigned i = 0, e = op.getNumArgs(); i < e; ++i) { 130 auto argument = op.getArg(i); 131 os << " {\n"; 132 if 
(argument.is<NamedTypeConstraint *>()) { 133 os << " for (auto arg : op.getODSOperands(" << operandNum << ")) {\n"; 134 os << " auto argID = findValueID(arg);\n"; 135 os << " if (!argID) {\n"; 136 os << " emitError(op.getLoc(), \"operand " << operandNum 137 << " has a use before def\");\n"; 138 os << " }\n"; 139 os << " operands.push_back(argID);\n"; 140 os << " }\n"; 141 operandNum++; 142 } else { 143 auto attr = argument.get<NamedAttribute *>(); 144 emitAttributeSerialization( 145 (attr->attr.isOptional() ? attr->attr.getBaseAttr() : attr->attr), 146 record->getLoc(), "op", "operands", attr->name, os); 147 os << " elidedAttrs.push_back(\"" << attr->name << "\");\n"; 148 } 149 os << " }\n"; 150 } 151 152 os << formatv(" encodeInstructionInto(" 153 "functions, spirv::getOpcode<{0}>(), operands);\n", 154 op.getQualCppClassName()); 155 156 if (op.getNumResults() == 1) { 157 // All non-argument attributes translated into OpDecorate instruction 158 os << " for (auto attr : op.getAttrs()) {\n"; 159 os << " if (llvm::any_of(elidedAttrs, [&](StringRef elided) { return " 160 "attr.first.is(elided); })) {\n"; 161 os << " continue;\n"; 162 os << " }\n"; 163 os << " if (failed(processDecoration(op.getLoc(), resultID, attr))) {\n"; 164 os << " return failure();"; 165 os << " }\n"; 166 os << " }\n"; 167 } 168 169 os << " return success();\n"; 170 os << "}\n\n"; 171 } 172 173 static void initDispatchSerializationFn(raw_ostream &os) { 174 os << "LogicalResult Serializer::dispatchToAutogenSerialization(Operation " 175 "*op) {\n "; 176 } 177 178 static void emitSerializationDispatch(const Operator &op, raw_ostream &os) { 179 os << formatv(" if (isa<{0}>(op)) ", op.getQualCppClassName()) << "{\n"; 180 os << " "; 181 os << formatv("return processOp<{0}>(cast<{0}>(op));\n", 182 op.getQualCppClassName()); 183 os << " } else"; 184 } 185 186 static void finalizeDispatchSerializationFn(raw_ostream &os) { 187 os << " {\n"; 188 os << " return op->emitError(\"unhandled operation 
serialization\");\n"; 189 os << " }\n"; 190 os << " return success();\n"; 191 os << "}\n\n"; 192 } 193 194 static void emitAttributeDeserialization( 195 const Attribute &attr, ArrayRef<SMLoc> loc, llvm::StringRef attrList, 196 llvm::StringRef attrName, llvm::StringRef operandsList, 197 llvm::StringRef wordIndex, llvm::StringRef wordCount, raw_ostream &os) { 198 if (attr.getAttrDefName() == "I32ArrayAttr") { 199 os << " SmallVector<Attribute, 4> attrListElems;\n"; 200 os << " while (" << wordIndex << " < " << wordCount << ") {\n"; 201 os << " attrListElems.push_back(opBuilder.getI32IntegerAttr(" 202 << operandsList << "[" << wordIndex << "++]));\n"; 203 os << " }\n"; 204 os << " " << attrList << ".push_back(opBuilder.getNamedAttr(\"" 205 << attrName << "\", opBuilder.getArrayAttr(attrListElems)));\n"; 206 } else if (attr.isEnumAttr() || attr.getAttrDefName() == "I32Attr") { 207 os << " " << attrList << ".push_back(opBuilder.getNamedAttr(\"" 208 << attrName << "\", opBuilder.getI32IntegerAttr(" << operandsList << "[" 209 << wordIndex << "++])));\n"; 210 } else { 211 PrintFatalError( 212 loc, llvm::Twine( 213 "unhandled attribute type in deserialization generation : '") + 214 attr.getAttrDefName() + llvm::Twine("'")); 215 } 216 } 217 218 static void emitDeserializationFunction(const Record *attrClass, 219 const Record *record, 220 const Operator &op, raw_ostream &os) { 221 // If the record has 'autogenSerialization' set to 0, nothing to do 222 if (!record->getValueAsBit("autogenSerialization")) { 223 return; 224 } 225 os << formatv("template <> " 226 "LogicalResult\nDeserializer::processOp<{0}>(ArrayRef<" 227 "uint32_t> words)", 228 op.getQualCppClassName()); 229 os << " {\n"; 230 os << " SmallVector<Type, 1> resultTypes;\n"; 231 os << " size_t wordIndex = 0; (void)wordIndex;\n"; 232 233 // Deserialize result information if it exists 234 bool hasResult = false; 235 if (op.getNumResults() == 1) { 236 os << " {\n"; 237 os << " if (wordIndex >= words.size()) {\n"; 238 os 
<< " " 239 << formatv("return emitError(unknownLoc, \"expected result type <id> " 240 "while deserializing {0}\");\n", 241 op.getQualCppClassName()); 242 os << " }\n"; 243 os << " auto ty = getType(words[wordIndex]);\n"; 244 os << " if (!ty) {\n"; 245 os << " return emitError(unknownLoc, \"unknown type result <id> : " 246 "\") << words[wordIndex];\n"; 247 os << " }\n"; 248 os << " resultTypes.push_back(ty);\n"; 249 os << " wordIndex++;\n"; 250 os << " }\n"; 251 os << " if (wordIndex >= words.size()) {\n"; 252 os << " " 253 << formatv("return emitError(unknownLoc, \"expected result <id> while " 254 "deserializing {0}\");\n", 255 op.getQualCppClassName()); 256 os << " }\n"; 257 os << " uint32_t valueID = words[wordIndex++];\n"; 258 hasResult = true; 259 } else if (op.getNumResults() != 0) { 260 PrintFatalError(record->getLoc(), 261 "SPIR-V ops can have only zero or one result"); 262 } 263 264 // Process operands/attributes 265 os << " SmallVector<Value *, 4> operands;\n"; 266 os << " SmallVector<NamedAttribute, 4> attributes;\n"; 267 unsigned operandNum = 0; 268 for (unsigned i = 0, e = op.getNumArgs(); i < e; ++i) { 269 auto argument = op.getArg(i); 270 if (auto valueArg = argument.dyn_cast<NamedTypeConstraint *>()) { 271 if (valueArg->isVariadic()) { 272 if (i != e - 1) { 273 PrintFatalError(record->getLoc(), 274 "SPIR-V ops can have Variadic<..> argument only if " 275 "it's the last argument"); 276 } 277 os << " for (; wordIndex < words.size(); ++wordIndex)"; 278 } else { 279 os << " if (wordIndex < words.size())"; 280 } 281 os << " {\n"; 282 os << " auto arg = getValue(words[wordIndex]);\n"; 283 os << " if (!arg) {\n"; 284 os << " return emitError(unknownLoc, \"unknown result <id> : \") << " 285 "words[wordIndex];\n"; 286 os << " }\n"; 287 os << " operands.push_back(arg);\n"; 288 if (!valueArg->isVariadic()) { 289 os << " wordIndex++;\n"; 290 } 291 operandNum++; 292 os << " }\n"; 293 } else { 294 os << " if (wordIndex < words.size()) {\n"; 295 auto attr = 
argument.get<NamedAttribute *>(); 296 emitAttributeDeserialization( 297 (attr->attr.isOptional() ? attr->attr.getBaseAttr() : attr->attr), 298 record->getLoc(), "attributes", attr->name, "words", "wordIndex", 299 "words.size()", os); 300 os << " }\n"; 301 } 302 } 303 304 os << " if (wordIndex != words.size()) {\n"; 305 os << " return emitError(unknownLoc, \"found more operands than expected " 306 "when deserializing " 307 << op.getQualCppClassName() 308 << ", only \") << wordIndex << \" of \" << words.size() << \" " 309 "processed\";\n"; 310 os << " }\n\n"; 311 312 // Import decorations parsed 313 if (op.getNumResults() == 1) { 314 os << " if (decorations.count(valueID)) {\n" 315 << " auto attrs = decorations[valueID].getAttrs();\n" 316 << " attributes.append(attrs.begin(), attrs.end());\n" 317 << " }\n"; 318 } 319 320 os << formatv(" auto op = opBuilder.create<{0}>(unknownLoc, resultTypes, " 321 "operands, attributes); (void)op;\n", 322 op.getQualCppClassName()); 323 if (hasResult) { 324 os << " valueMap[valueID] = op.getResult();\n\n"; 325 } 326 327 os << " return success();\n"; 328 os << "}\n\n"; 329 } 330 331 static void initDispatchDeserializationFn(raw_ostream &os) { 332 os << "LogicalResult " 333 "Deserializer::dispatchToAutogenDeserialization(spirv::Opcode " 334 "opcode, ArrayRef<uint32_t> words) {\n"; 335 os << " switch (opcode) {\n"; 336 } 337 338 static void emitDeserializationDispatch(const Operator &op, const Record *def, 339 raw_ostream &os) { 340 os << formatv(" case spirv::Opcode::{0}:\n", 341 def->getValueAsString("spirvOpName")); 342 os << formatv(" return processOp<{0}>(words);\n", 343 op.getQualCppClassName()); 344 } 345 346 static void finalizeDispatchDeserializationFn(raw_ostream &os) { 347 os << " default:\n"; 348 os << " ;\n"; 349 os << " }\n"; 350 os << " return emitError(unknownLoc, \"unhandled deserialization of \") << " 351 "spirv::stringifyOpcode(opcode);\n"; 352 os << "}\n"; 353 } 354 355 static bool emitSerializationFns(const 
RecordKeeper &recordKeeper, 356 raw_ostream &os) { 357 llvm::emitSourceFileHeader("SPIR-V Serialization Utilities/Functions", os); 358 359 std::string dSerFnString, dDesFnString, serFnString, deserFnString, 360 utilsString; 361 raw_string_ostream dSerFn(dSerFnString), dDesFn(dDesFnString), 362 serFn(serFnString), deserFn(deserFnString), utils(utilsString); 363 auto attrClass = recordKeeper.getClass("Attr"); 364 365 declareOpcodeFn(utils); 366 initDispatchSerializationFn(dSerFn); 367 initDispatchDeserializationFn(dDesFn); 368 auto defs = recordKeeper.getAllDerivedDefinitions("SPV_Op"); 369 for (const auto *def : defs) { 370 if (!def->getValueAsBit("hasOpcode")) { 371 continue; 372 } 373 Operator op(def); 374 emitGetOpcodeFunction(def, op, utils); 375 emitSerializationFunction(attrClass, def, op, serFn); 376 emitSerializationDispatch(op, dSerFn); 377 emitDeserializationFunction(attrClass, def, op, deserFn); 378 emitDeserializationDispatch(op, def, dDesFn); 379 } 380 finalizeDispatchSerializationFn(dSerFn); 381 finalizeDispatchDeserializationFn(dDesFn); 382 383 os << "#ifdef GET_SPIRV_SERIALIZATION_UTILS\n"; 384 os << utils.str(); 385 os << "#endif // GET_SPIRV_SERIALIZATION_UTILS\n\n"; 386 387 os << "#ifdef GET_SERIALIZATION_FNS\n\n"; 388 os << serFn.str(); 389 os << dSerFn.str(); 390 os << "#endif // GET_SERIALIZATION_FNS\n\n"; 391 392 os << "#ifdef GET_DESERIALIZATION_FNS\n\n"; 393 os << deserFn.str(); 394 os << dDesFn.str(); 395 os << "#endif // GET_DESERIALIZATION_FNS\n\n"; 396 397 return false; 398 } 399 400 static void emitEnumGetAttrNameFnDecl(raw_ostream &os) { 401 os << formatv("template <typename EnumClass> inline constexpr StringRef " 402 "attributeName();\n"); 403 } 404 405 static void emitEnumGetSymbolizeFnDecl(raw_ostream &os) { 406 os << "template <typename EnumClass> using SymbolizeFnTy = " 407 "llvm::Optional<EnumClass> (*)(StringRef);\n"; 408 os << "template <typename EnumClass> inline constexpr " 409 "SymbolizeFnTy<EnumClass> symbolizeEnum();\n"; 
410 } 411 412 static void emitEnumGetAttrNameFnDefn(const EnumAttr &enumAttr, 413 raw_ostream &os) { 414 auto enumName = enumAttr.getEnumClassName(); 415 os << formatv("template <> inline StringRef attributeName<{0}>()", enumName) 416 << " {\n"; 417 os << " " 418 << formatv("static constexpr const char attrName[] = \"{0}\";\n", 419 mlir::convertToSnakeCase(enumName)); 420 os << " return attrName;\n"; 421 os << "}\n"; 422 } 423 424 static void emitEnumGetSymbolizeFnDefn(const EnumAttr &enumAttr, 425 raw_ostream &os) { 426 auto enumName = enumAttr.getEnumClassName(); 427 auto strToSymFnName = enumAttr.getStringToSymbolFnName(); 428 os << formatv("template <> inline SymbolizeFnTy<{0}> symbolizeEnum<{0}>()", 429 enumName) 430 << " {\n"; 431 os << " return " << strToSymFnName << ";\n"; 432 os << "}\n"; 433 } 434 435 static bool emitOpUtils(const RecordKeeper &recordKeeper, raw_ostream &os) { 436 llvm::emitSourceFileHeader("SPIR-V Op Utilites", os); 437 438 auto defs = recordKeeper.getAllDerivedDefinitions("I32EnumAttr"); 439 os << "#ifndef SPIRV_OP_UTILS_H_\n"; 440 os << "#define SPIRV_OP_UTILS_H_\n"; 441 emitEnumGetAttrNameFnDecl(os); 442 emitEnumGetSymbolizeFnDecl(os); 443 for (const auto *def : defs) { 444 EnumAttr enumAttr(*def); 445 emitEnumGetAttrNameFnDefn(enumAttr, os); 446 emitEnumGetSymbolizeFnDefn(enumAttr, os); 447 } 448 os << "#endif // SPIRV_OP_UTILS_H\n"; 449 return false; 450 } 451 452 // Registers the enum utility generator to mlir-tblgen. 453 static mlir::GenRegistration genSerialization( 454 "gen-spirv-serialization", 455 "Generate SPIR-V (de)serialization utilities and functions", 456 [](const RecordKeeper &records, raw_ostream &os) { 457 return emitSerializationFns(records, os); 458 }); 459 460 static mlir::GenRegistration 461 genOpUtils("gen-spirv-op-utils", 462 "Generate SPIR-V operation utility definitions", 463 [](const RecordKeeper &records, raw_ostream &os) { 464 return emitOpUtils(records, os); 465 });