github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp (about)

//===- SPIRVUtilsGen.cpp - SPIR-V utility definitions generator -----------===//
     2  //
     3  // Copyright 2019 The MLIR Authors.
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //   http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  // =============================================================================
    17  //
// SPIRVUtilsGen generates common utility functions for SPIR-V ops: the
// (de)serialization functions with their dispatchers, and op/enum utilities.
    20  //
    21  //===----------------------------------------------------------------------===//
    22  
    23  #include "mlir/Support/StringExtras.h"
    24  #include "mlir/TableGen/Attribute.h"
    25  #include "mlir/TableGen/GenInfo.h"
    26  #include "mlir/TableGen/Operator.h"
    27  #include "llvm/ADT/Sequence.h"
    28  #include "llvm/ADT/SmallVector.h"
    29  #include "llvm/ADT/StringExtras.h"
    30  #include "llvm/ADT/StringRef.h"
    31  #include "llvm/Support/FormatVariadic.h"
    32  #include "llvm/Support/raw_ostream.h"
    33  #include "llvm/TableGen/Error.h"
    34  #include "llvm/TableGen/Record.h"
    35  #include "llvm/TableGen/TableGenBackend.h"
    36  
    37  using llvm::ArrayRef;
    38  using llvm::formatv;
    39  using llvm::raw_ostream;
    40  using llvm::raw_string_ostream;
    41  using llvm::Record;
    42  using llvm::RecordKeeper;
    43  using llvm::SMLoc;
    44  using llvm::StringRef;
    45  using llvm::Twine;
    46  using mlir::tblgen::Attribute;
    47  using mlir::tblgen::EnumAttr;
    48  using mlir::tblgen::NamedAttribute;
    49  using mlir::tblgen::NamedTypeConstraint;
    50  using mlir::tblgen::Operator;
    51  
    52  // Writes the following function to `os`:
    53  //   inline uint32_t getOpcode(<op-class-name>) { return <opcode>; }
    54  static void emitGetOpcodeFunction(const Record *record, Operator const &op,
    55                                    raw_ostream &os) {
    56    os << formatv("template <> constexpr inline ::mlir::spirv::Opcode "
    57                  "getOpcode<{0}>()",
    58                  op.getQualCppClassName())
    59       << " {\n  "
    60       << formatv("return ::mlir::spirv::Opcode::{0};\n}\n",
    61                  record->getValueAsString("spirvOpName"));
    62  }
    63  
    64  static void declareOpcodeFn(raw_ostream &os) {
    65    os << "template <typename OpClass> inline constexpr ::mlir::spirv::Opcode "
    66          "getOpcode();\n";
    67  }
    68  
    69  static void emitAttributeSerialization(const Attribute &attr,
    70                                         ArrayRef<SMLoc> loc, llvm::StringRef op,
    71                                         llvm::StringRef operandList,
    72                                         llvm::StringRef attrName,
    73                                         raw_ostream &os) {
    74    os << "    auto attr = " << op << ".getAttr(\"" << attrName << "\");\n";
    75    os << "    if (attr) {\n";
    76    if (attr.getAttrDefName() == "I32ArrayAttr") {
    77      // Serialize all the elements of the array
    78      os << "      for (auto attrElem : attr.cast<ArrayAttr>()) {\n";
    79      os << "        " << operandList
    80         << ".push_back(static_cast<uint32_t>(attrElem.cast<IntegerAttr>()."
    81            "getValue().getZExtValue()));\n";
    82      os << "      }\n";
    83    } else if (attr.isEnumAttr() || attr.getAttrDefName() == "I32Attr") {
    84      os << "      " << operandList
    85         << ".push_back(static_cast<uint32_t>(attr.cast<IntegerAttr>().getValue()"
    86            ".getZExtValue()));\n";
    87    } else {
    88      PrintFatalError(
    89          loc,
    90          llvm::Twine(
    91              "unhandled attribute type in SPIR-V serialization generation : '") +
    92              attr.getAttrDefName() + llvm::Twine("'"));
    93    }
    94    os << "    }\n";
    95  }
    96  
    97  static void emitSerializationFunction(const Record *attrClass,
    98                                        const Record *record, const Operator &op,
    99                                        raw_ostream &os) {
   100    // If the record has 'autogenSerialization' set to 0, nothing to do
   101    if (!record->getValueAsBit("autogenSerialization")) {
   102      return;
   103    }
   104    os << formatv("template <> LogicalResult\nSerializer::processOp<{0}>(\n"
   105                  "  {0} op)",
   106                  op.getQualCppClassName())
   107       << " {\n";
   108    os << "  SmallVector<uint32_t, 4> operands;\n";
   109    os << "  SmallVector<StringRef, 2> elidedAttrs;\n";
   110  
   111    // Serialize result information
   112    if (op.getNumResults() == 1) {
   113      os << "  uint32_t resultTypeID = 0;\n";
   114      os << "  if (failed(processType(op.getLoc(), op.getType(), resultTypeID))) "
   115            "{\n";
   116      os << "    return failure();\n";
   117      os << "  }\n";
   118      os << "  operands.push_back(resultTypeID);\n";
   119      // Create an SSA result <id> for the op
   120      os << "  auto resultID = getNextID();\n";
   121      os << "  valueIDMap[op.getResult()] = resultID;\n";
   122      os << "  operands.push_back(resultID);\n";
   123    } else if (op.getNumResults() != 0) {
   124      PrintFatalError(record->getLoc(), "SPIR-V ops can only zero or one result");
   125    }
   126  
   127    // Process arguments
   128    auto operandNum = 0;
   129    for (unsigned i = 0, e = op.getNumArgs(); i < e; ++i) {
   130      auto argument = op.getArg(i);
   131      os << "  {\n";
   132      if (argument.is<NamedTypeConstraint *>()) {
   133        os << "    for (auto arg : op.getODSOperands(" << operandNum << ")) {\n";
   134        os << "      auto argID = findValueID(arg);\n";
   135        os << "      if (!argID) {\n";
   136        os << "        emitError(op.getLoc(), \"operand " << operandNum
   137           << " has a use before def\");\n";
   138        os << "      }\n";
   139        os << "      operands.push_back(argID);\n";
   140        os << "    }\n";
   141        operandNum++;
   142      } else {
   143        auto attr = argument.get<NamedAttribute *>();
   144        emitAttributeSerialization(
   145            (attr->attr.isOptional() ? attr->attr.getBaseAttr() : attr->attr),
   146            record->getLoc(), "op", "operands", attr->name, os);
   147        os << "    elidedAttrs.push_back(\"" << attr->name << "\");\n";
   148      }
   149      os << "  }\n";
   150    }
   151  
   152    os << formatv("  encodeInstructionInto("
   153                  "functions, spirv::getOpcode<{0}>(), operands);\n",
   154                  op.getQualCppClassName());
   155  
   156    if (op.getNumResults() == 1) {
   157      // All non-argument attributes translated into OpDecorate instruction
   158      os << "  for (auto attr : op.getAttrs()) {\n";
   159      os << "    if (llvm::any_of(elidedAttrs, [&](StringRef elided) { return "
   160            "attr.first.is(elided); })) {\n";
   161      os << "      continue;\n";
   162      os << "    }\n";
   163      os << "    if (failed(processDecoration(op.getLoc(), resultID, attr))) {\n";
   164      os << "      return failure();";
   165      os << "    }\n";
   166      os << "  }\n";
   167    }
   168  
   169    os << "  return success();\n";
   170    os << "}\n\n";
   171  }
   172  
   173  static void initDispatchSerializationFn(raw_ostream &os) {
   174    os << "LogicalResult Serializer::dispatchToAutogenSerialization(Operation "
   175          "*op) {\n ";
   176  }
   177  
   178  static void emitSerializationDispatch(const Operator &op, raw_ostream &os) {
   179    os << formatv(" if (isa<{0}>(op)) ", op.getQualCppClassName()) << "{\n";
   180    os << "    ";
   181    os << formatv("return processOp<{0}>(cast<{0}>(op));\n",
   182                  op.getQualCppClassName());
   183    os << "  } else";
   184  }
   185  
   186  static void finalizeDispatchSerializationFn(raw_ostream &os) {
   187    os << " {\n";
   188    os << "    return op->emitError(\"unhandled operation serialization\");\n";
   189    os << "  }\n";
   190    os << "  return success();\n";
   191    os << "}\n\n";
   192  }
   193  
   194  static void emitAttributeDeserialization(
   195      const Attribute &attr, ArrayRef<SMLoc> loc, llvm::StringRef attrList,
   196      llvm::StringRef attrName, llvm::StringRef operandsList,
   197      llvm::StringRef wordIndex, llvm::StringRef wordCount, raw_ostream &os) {
   198    if (attr.getAttrDefName() == "I32ArrayAttr") {
   199      os << "    SmallVector<Attribute, 4> attrListElems;\n";
   200      os << "    while (" << wordIndex << " < " << wordCount << ") {\n";
   201      os << "      attrListElems.push_back(opBuilder.getI32IntegerAttr("
   202         << operandsList << "[" << wordIndex << "++]));\n";
   203      os << "    }\n";
   204      os << "    " << attrList << ".push_back(opBuilder.getNamedAttr(\""
   205         << attrName << "\", opBuilder.getArrayAttr(attrListElems)));\n";
   206    } else if (attr.isEnumAttr() || attr.getAttrDefName() == "I32Attr") {
   207      os << "    " << attrList << ".push_back(opBuilder.getNamedAttr(\""
   208         << attrName << "\", opBuilder.getI32IntegerAttr(" << operandsList << "["
   209         << wordIndex << "++])));\n";
   210    } else {
   211      PrintFatalError(
   212          loc, llvm::Twine(
   213                   "unhandled attribute type in deserialization generation : '") +
   214                   attr.getAttrDefName() + llvm::Twine("'"));
   215    }
   216  }
   217  
   218  static void emitDeserializationFunction(const Record *attrClass,
   219                                          const Record *record,
   220                                          const Operator &op, raw_ostream &os) {
   221    // If the record has 'autogenSerialization' set to 0, nothing to do
   222    if (!record->getValueAsBit("autogenSerialization")) {
   223      return;
   224    }
   225    os << formatv("template <> "
   226                  "LogicalResult\nDeserializer::processOp<{0}>(ArrayRef<"
   227                  "uint32_t> words)",
   228                  op.getQualCppClassName());
   229    os << " {\n";
   230    os << "  SmallVector<Type, 1> resultTypes;\n";
   231    os << "  size_t wordIndex = 0; (void)wordIndex;\n";
   232  
   233    // Deserialize result information if it exists
   234    bool hasResult = false;
   235    if (op.getNumResults() == 1) {
   236      os << "  {\n";
   237      os << "    if (wordIndex >= words.size()) {\n";
   238      os << "      "
   239         << formatv("return emitError(unknownLoc, \"expected result type <id> "
   240                    "while deserializing {0}\");\n",
   241                    op.getQualCppClassName());
   242      os << "    }\n";
   243      os << "    auto ty = getType(words[wordIndex]);\n";
   244      os << "    if (!ty) {\n";
   245      os << "      return emitError(unknownLoc, \"unknown type result <id> : "
   246            "\") << words[wordIndex];\n";
   247      os << "    }\n";
   248      os << "    resultTypes.push_back(ty);\n";
   249      os << "    wordIndex++;\n";
   250      os << "  }\n";
   251      os << "  if (wordIndex >= words.size()) {\n";
   252      os << "    "
   253         << formatv("return emitError(unknownLoc, \"expected result <id> while "
   254                    "deserializing {0}\");\n",
   255                    op.getQualCppClassName());
   256      os << "  }\n";
   257      os << "  uint32_t valueID = words[wordIndex++];\n";
   258      hasResult = true;
   259    } else if (op.getNumResults() != 0) {
   260      PrintFatalError(record->getLoc(),
   261                      "SPIR-V ops can have only zero or one result");
   262    }
   263  
   264    // Process operands/attributes
   265    os << "  SmallVector<Value *, 4> operands;\n";
   266    os << "  SmallVector<NamedAttribute, 4> attributes;\n";
   267    unsigned operandNum = 0;
   268    for (unsigned i = 0, e = op.getNumArgs(); i < e; ++i) {
   269      auto argument = op.getArg(i);
   270      if (auto valueArg = argument.dyn_cast<NamedTypeConstraint *>()) {
   271        if (valueArg->isVariadic()) {
   272          if (i != e - 1) {
   273            PrintFatalError(record->getLoc(),
   274                            "SPIR-V ops can have Variadic<..> argument only if "
   275                            "it's the last argument");
   276          }
   277          os << "  for (; wordIndex < words.size(); ++wordIndex)";
   278        } else {
   279          os << "  if (wordIndex < words.size())";
   280        }
   281        os << " {\n";
   282        os << "    auto arg = getValue(words[wordIndex]);\n";
   283        os << "    if (!arg) {\n";
   284        os << "      return emitError(unknownLoc, \"unknown result <id> : \") << "
   285              "words[wordIndex];\n";
   286        os << "    }\n";
   287        os << "    operands.push_back(arg);\n";
   288        if (!valueArg->isVariadic()) {
   289          os << "    wordIndex++;\n";
   290        }
   291        operandNum++;
   292        os << "  }\n";
   293      } else {
   294        os << "  if (wordIndex < words.size()) {\n";
   295        auto attr = argument.get<NamedAttribute *>();
   296        emitAttributeDeserialization(
   297            (attr->attr.isOptional() ? attr->attr.getBaseAttr() : attr->attr),
   298            record->getLoc(), "attributes", attr->name, "words", "wordIndex",
   299            "words.size()", os);
   300        os << "  }\n";
   301      }
   302    }
   303  
   304    os << "  if (wordIndex != words.size()) {\n";
   305    os << "    return emitError(unknownLoc, \"found more operands than expected "
   306          "when deserializing "
   307       << op.getQualCppClassName()
   308       << ", only \") << wordIndex << \" of \" << words.size() << \" "
   309          "processed\";\n";
   310    os << "  }\n\n";
   311  
   312    // Import decorations parsed
   313    if (op.getNumResults() == 1) {
   314      os << "  if (decorations.count(valueID)) {\n"
   315         << "    auto attrs = decorations[valueID].getAttrs();\n"
   316         << "    attributes.append(attrs.begin(), attrs.end());\n"
   317         << "  }\n";
   318    }
   319  
   320    os << formatv("  auto op = opBuilder.create<{0}>(unknownLoc, resultTypes, "
   321                  "operands, attributes); (void)op;\n",
   322                  op.getQualCppClassName());
   323    if (hasResult) {
   324      os << "  valueMap[valueID] = op.getResult();\n\n";
   325    }
   326  
   327    os << "  return success();\n";
   328    os << "}\n\n";
   329  }
   330  
   331  static void initDispatchDeserializationFn(raw_ostream &os) {
   332    os << "LogicalResult "
   333          "Deserializer::dispatchToAutogenDeserialization(spirv::Opcode "
   334          "opcode, ArrayRef<uint32_t> words) {\n";
   335    os << "  switch (opcode) {\n";
   336  }
   337  
   338  static void emitDeserializationDispatch(const Operator &op, const Record *def,
   339                                          raw_ostream &os) {
   340    os << formatv("  case spirv::Opcode::{0}:\n",
   341                  def->getValueAsString("spirvOpName"));
   342    os << formatv("    return processOp<{0}>(words);\n",
   343                  op.getQualCppClassName());
   344  }
   345  
   346  static void finalizeDispatchDeserializationFn(raw_ostream &os) {
   347    os << "  default:\n";
   348    os << "    ;\n";
   349    os << "  }\n";
   350    os << "  return emitError(unknownLoc, \"unhandled deserialization of \") << "
   351          "spirv::stringifyOpcode(opcode);\n";
   352    os << "}\n";
   353  }
   354  
   355  static bool emitSerializationFns(const RecordKeeper &recordKeeper,
   356                                   raw_ostream &os) {
   357    llvm::emitSourceFileHeader("SPIR-V Serialization Utilities/Functions", os);
   358  
   359    std::string dSerFnString, dDesFnString, serFnString, deserFnString,
   360        utilsString;
   361    raw_string_ostream dSerFn(dSerFnString), dDesFn(dDesFnString),
   362        serFn(serFnString), deserFn(deserFnString), utils(utilsString);
   363    auto attrClass = recordKeeper.getClass("Attr");
   364  
   365    declareOpcodeFn(utils);
   366    initDispatchSerializationFn(dSerFn);
   367    initDispatchDeserializationFn(dDesFn);
   368    auto defs = recordKeeper.getAllDerivedDefinitions("SPV_Op");
   369    for (const auto *def : defs) {
   370      if (!def->getValueAsBit("hasOpcode")) {
   371        continue;
   372      }
   373      Operator op(def);
   374      emitGetOpcodeFunction(def, op, utils);
   375      emitSerializationFunction(attrClass, def, op, serFn);
   376      emitSerializationDispatch(op, dSerFn);
   377      emitDeserializationFunction(attrClass, def, op, deserFn);
   378      emitDeserializationDispatch(op, def, dDesFn);
   379    }
   380    finalizeDispatchSerializationFn(dSerFn);
   381    finalizeDispatchDeserializationFn(dDesFn);
   382  
   383    os << "#ifdef GET_SPIRV_SERIALIZATION_UTILS\n";
   384    os << utils.str();
   385    os << "#endif // GET_SPIRV_SERIALIZATION_UTILS\n\n";
   386  
   387    os << "#ifdef GET_SERIALIZATION_FNS\n\n";
   388    os << serFn.str();
   389    os << dSerFn.str();
   390    os << "#endif // GET_SERIALIZATION_FNS\n\n";
   391  
   392    os << "#ifdef GET_DESERIALIZATION_FNS\n\n";
   393    os << deserFn.str();
   394    os << dDesFn.str();
   395    os << "#endif // GET_DESERIALIZATION_FNS\n\n";
   396  
   397    return false;
   398  }
   399  
   400  static void emitEnumGetAttrNameFnDecl(raw_ostream &os) {
   401    os << formatv("template <typename EnumClass> inline constexpr StringRef "
   402                  "attributeName();\n");
   403  }
   404  
   405  static void emitEnumGetSymbolizeFnDecl(raw_ostream &os) {
   406    os << "template <typename EnumClass> using SymbolizeFnTy = "
   407          "llvm::Optional<EnumClass> (*)(StringRef);\n";
   408    os << "template <typename EnumClass> inline constexpr "
   409          "SymbolizeFnTy<EnumClass> symbolizeEnum();\n";
   410  }
   411  
   412  static void emitEnumGetAttrNameFnDefn(const EnumAttr &enumAttr,
   413                                        raw_ostream &os) {
   414    auto enumName = enumAttr.getEnumClassName();
   415    os << formatv("template <> inline StringRef attributeName<{0}>()", enumName)
   416       << " {\n";
   417    os << "  "
   418       << formatv("static constexpr const char attrName[] = \"{0}\";\n",
   419                  mlir::convertToSnakeCase(enumName));
   420    os << "  return attrName;\n";
   421    os << "}\n";
   422  }
   423  
   424  static void emitEnumGetSymbolizeFnDefn(const EnumAttr &enumAttr,
   425                                         raw_ostream &os) {
   426    auto enumName = enumAttr.getEnumClassName();
   427    auto strToSymFnName = enumAttr.getStringToSymbolFnName();
   428    os << formatv("template <> inline SymbolizeFnTy<{0}> symbolizeEnum<{0}>()",
   429                  enumName)
   430       << " {\n";
   431    os << "  return " << strToSymFnName << ";\n";
   432    os << "}\n";
   433  }
   434  
   435  static bool emitOpUtils(const RecordKeeper &recordKeeper, raw_ostream &os) {
   436    llvm::emitSourceFileHeader("SPIR-V Op Utilites", os);
   437  
   438    auto defs = recordKeeper.getAllDerivedDefinitions("I32EnumAttr");
   439    os << "#ifndef SPIRV_OP_UTILS_H_\n";
   440    os << "#define SPIRV_OP_UTILS_H_\n";
   441    emitEnumGetAttrNameFnDecl(os);
   442    emitEnumGetSymbolizeFnDecl(os);
   443    for (const auto *def : defs) {
   444      EnumAttr enumAttr(*def);
   445      emitEnumGetAttrNameFnDefn(enumAttr, os);
   446      emitEnumGetSymbolizeFnDefn(enumAttr, os);
   447    }
   448    os << "#endif // SPIRV_OP_UTILS_H\n";
   449    return false;
   450  }
   451  
   452  // Registers the enum utility generator to mlir-tblgen.
   453  static mlir::GenRegistration genSerialization(
   454      "gen-spirv-serialization",
   455      "Generate SPIR-V (de)serialization utilities and functions",
   456      [](const RecordKeeper &records, raw_ostream &os) {
   457        return emitSerializationFns(records, os);
   458      });
   459  
   460  static mlir::GenRegistration
   461      genOpUtils("gen-spirv-op-utils",
   462                 "Generate SPIR-V operation utility definitions",
   463                 [](const RecordKeeper &records, raw_ostream &os) {
   464                   return emitOpUtils(records, os);
   465                 });