github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/tools/mlir-tblgen/StructsGen.cpp (about)

     1  //===- StructsGen.cpp - MLIR struct utility generator ---------------------===//
     2  //
     3  // Copyright 2019 The MLIR Authors.
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //   http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  // =============================================================================
    17  //
    18  // StructsGen generates common utility functions for grouping attributes into a
    19  // set of structured data.
    20  //
    21  //===----------------------------------------------------------------------===//
    22  
    23  #include "mlir/TableGen/Attribute.h"
    24  #include "mlir/TableGen/Format.h"
    25  #include "mlir/TableGen/GenInfo.h"
    26  #include "mlir/TableGen/Operator.h"
    27  #include "llvm/ADT/SmallVector.h"
    28  #include "llvm/ADT/StringExtras.h"
    29  #include "llvm/Support/FormatVariadic.h"
    30  #include "llvm/Support/raw_ostream.h"
    31  #include "llvm/TableGen/Error.h"
    32  #include "llvm/TableGen/Record.h"
    33  #include "llvm/TableGen/TableGenBackend.h"
    34  
    35  using llvm::raw_ostream;
    36  using llvm::Record;
    37  using llvm::RecordKeeper;
    38  using llvm::StringRef;
    39  using mlir::tblgen::StructAttr;
    40  
    41  static void
    42  emitStructClass(const Record &structDef, StringRef structName,
    43                  llvm::ArrayRef<mlir::tblgen::StructFieldAttr> fields,
    44                  StringRef description, raw_ostream &os) {
    45    const char *structInfo = R"(
    46  // {0}
    47  class {1} : public mlir::DictionaryAttr)";
    48    const char *structInfoEnd = R"( {
    49  public:
    50    using DictionaryAttr::DictionaryAttr;
    51    static bool classof(mlir::Attribute attr);
    52  )";
    53    os << formatv(structInfo, description, structName) << structInfoEnd;
    54  
    55    // Declares a constructor function for the tablegen structure.
    56    //   TblgenStruct::get(MLIRContext context, Type1 Field1, Type2 Field2, ...);
    57    const char *getInfoDecl = "  static {0} get(\n";
    58    const char *getInfoDeclArg = "      {0} {1},\n";
    59    const char *getInfoDeclEnd = "      mlir::MLIRContext* context);\n\n";
    60  
    61    os << llvm::formatv(getInfoDecl, structName);
    62  
    63    for (auto field : fields) {
    64      auto name = field.getName();
    65      auto type = field.getType();
    66      auto storage = type.getStorageType();
    67      os << llvm::formatv(getInfoDeclArg, storage, name);
    68    }
    69    os << getInfoDeclEnd;
    70  
    71    // Declares an accessor for the fields owned by the tablegen structure.
    72    //   namespace::storage TblgenStruct::field1() const;
    73    const char *fieldInfo = R"(  {0} {1}() const;
    74  )";
    75    for (const auto field : fields) {
    76      auto name = field.getName();
    77      auto type = field.getType();
    78      auto storage = type.getStorageType();
    79      os << formatv(fieldInfo, storage, name);
    80    }
    81  
    82    os << "};\n\n";
    83  }
    84  
    85  static void emitStructDecl(const Record &structDef, raw_ostream &os) {
    86    StructAttr structAttr(&structDef);
    87    StringRef structName = structAttr.getStructClassName();
    88    StringRef cppNamespace = structAttr.getCppNamespace();
    89    StringRef description = structAttr.getDescription();
    90    auto fields = structAttr.getAllFields();
    91  
    92    // Wrap in the appropriate namespace.
    93    llvm::SmallVector<StringRef, 2> namespaces;
    94    llvm::SplitString(cppNamespace, namespaces, "::");
    95  
    96    for (auto ns : namespaces)
    97      os << "namespace " << ns << " {\n";
    98  
    99    // Emit the struct class definition
   100    emitStructClass(structDef, structName, fields, description, os);
   101  
   102    // Close the declared namespace.
   103    for (auto ns : namespaces)
   104      os << "} // namespace " << ns << "\n";
   105  }
   106  
   107  static bool emitStructDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
   108    llvm::emitSourceFileHeader("Struct Utility Declarations", os);
   109  
   110    auto defs = recordKeeper.getAllDerivedDefinitions("StructAttr");
   111    for (const auto *def : defs) {
   112      emitStructDecl(*def, os);
   113    }
   114  
   115    return false;
   116  }
   117  
   118  static void emitFactoryDef(llvm::StringRef structName,
   119                             llvm::ArrayRef<mlir::tblgen::StructFieldAttr> fields,
   120                             raw_ostream &os) {
   121    const char *getInfoDecl = "{0} {0}::get(\n";
   122    const char *getInfoDeclArg = "    {0} {1},\n";
   123    const char *getInfoDeclEnd = "    mlir::MLIRContext* context) {";
   124  
   125    os << llvm::formatv(getInfoDecl, structName);
   126  
   127    for (auto field : fields) {
   128      auto name = field.getName();
   129      auto type = field.getType();
   130      auto storage = type.getStorageType();
   131      os << llvm::formatv(getInfoDeclArg, storage, name);
   132    }
   133    os << getInfoDeclEnd;
   134  
   135    const char *fieldStart = R"(
   136    llvm::SmallVector<mlir::NamedAttribute, {0}> fields;
   137  )";
   138    os << llvm::formatv(fieldStart, fields.size());
   139  
   140    const char *getFieldInfo = R"(
   141    assert({0});
   142    auto {0}_id = mlir::Identifier::get("{0}", context);
   143    fields.emplace_back({0}_id, {0});
   144  )";
   145  
   146    for (auto field : fields) {
   147      os << llvm::formatv(getFieldInfo, field.getName());
   148    }
   149  
   150    const char *getEndInfo = R"(
   151    Attribute dict = mlir::DictionaryAttr::get(fields, context);
   152    return dict.dyn_cast<{0}>();
   153  }
   154  )";
   155    os << llvm::formatv(getEndInfo, structName);
   156  }
   157  
   158  static void emitClassofDef(llvm::StringRef structName,
   159                             llvm::ArrayRef<mlir::tblgen::StructFieldAttr> fields,
   160                             raw_ostream &os) {
   161    const char *classofInfo = R"(
   162  bool {0}::classof(mlir::Attribute attr))";
   163  
   164    const char *classofInfoHeader = R"(
   165     auto derived = attr.dyn_cast<mlir::DictionaryAttr>();
   166     if (!derived)
   167       return false;
   168     if (derived.size() != {0})
   169       return false;
   170  )";
   171  
   172    os << llvm::formatv(classofInfo, structName) << " {";
   173    os << llvm::formatv(classofInfoHeader, fields.size());
   174  
   175    const char *classofArgInfo = R"(
   176    auto {0} = derived.get("{0}");
   177    if (!{0} || !{0}.isa<{1}>())
   178      return false;
   179  )";
   180    for (auto field : fields) {
   181      auto name = field.getName();
   182      auto type = field.getType();
   183      auto storage = type.getStorageType();
   184      os << llvm::formatv(classofArgInfo, name, storage);
   185    }
   186  
   187    const char *classofEndInfo = R"(
   188    return true;
   189  }
   190  )";
   191    os << classofEndInfo;
   192  }
   193  
   194  static void
   195  emitAccessorDef(llvm::StringRef structName,
   196                  llvm::ArrayRef<mlir::tblgen::StructFieldAttr> fields,
   197                  raw_ostream &os) {
   198    const char *fieldInfo = R"(
   199  {0} {2}::{1}() const {
   200    auto derived = this->cast<mlir::DictionaryAttr>();
   201    auto {1} = derived.get("{1}");
   202    assert({1} && "attribute not found.");
   203    assert({1}.isa<{0}>() && "incorrect Attribute type found.");
   204    return {1}.cast<{0}>();
   205  }
   206  )";
   207    for (auto field : fields) {
   208      auto name = field.getName();
   209      auto type = field.getType();
   210      auto storage = type.getStorageType();
   211      os << llvm::formatv(fieldInfo, storage, name, structName);
   212    }
   213  }
   214  
   215  static void emitStructDef(const Record &structDef, raw_ostream &os) {
   216    StructAttr structAttr(&structDef);
   217    StringRef cppNamespace = structAttr.getCppNamespace();
   218    StringRef structName = structAttr.getStructClassName();
   219    mlir::tblgen::FmtContext ctx;
   220    auto fields = structAttr.getAllFields();
   221  
   222    llvm::SmallVector<StringRef, 2> namespaces;
   223    llvm::SplitString(cppNamespace, namespaces, "::");
   224  
   225    for (auto ns : namespaces)
   226      os << "namespace " << ns << " {\n";
   227  
   228    emitFactoryDef(structName, fields, os);
   229    emitClassofDef(structName, fields, os);
   230    emitAccessorDef(structName, fields, os);
   231  
   232    for (auto ns : llvm::reverse(namespaces))
   233      os << "} // namespace " << ns << "\n";
   234  }
   235  
   236  static bool emitStructDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
   237    llvm::emitSourceFileHeader("Struct Utility Definitions", os);
   238  
   239    auto defs = recordKeeper.getAllDerivedDefinitions("StructAttr");
   240    for (const auto *def : defs)
   241      emitStructDef(*def, os);
   242  
   243    return false;
   244  }
   245  
   246  // Registers the struct utility generator to mlir-tblgen.
   247  static mlir::GenRegistration
   248      genStructDecls("gen-struct-attr-decls",
   249                     "Generate struct utility declarations",
   250                     [](const RecordKeeper &records, raw_ostream &os) {
   251                       return emitStructDecls(records, os);
   252                     });
   253  
   254  // Registers the struct utility generator to mlir-tblgen.
   255  static mlir::GenRegistration
   256      genStructDefs("gen-struct-attr-defs", "Generate struct utility definitions",
   257                    [](const RecordKeeper &records, raw_ostream &os) {
   258                      return emitStructDefs(records, os);
   259                    });