github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp (about)

     1  //===- OpInterfacesGen.cpp - MLIR op interface utility generator ----------===//
     2  //
     3  // Copyright 2019 The MLIR Authors.
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //   http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  // =============================================================================
    17  //
    18  // OpInterfacesGen generates definitions for operation interfaces.
    19  //
    20  //===----------------------------------------------------------------------===//
    21  
    22  #include "mlir/Support/STLExtras.h"
    23  #include "mlir/TableGen/GenInfo.h"
    24  #include "llvm/ADT/SmallVector.h"
    25  #include "llvm/ADT/StringExtras.h"
    26  #include "llvm/Support/FormatVariadic.h"
    27  #include "llvm/Support/raw_ostream.h"
    28  #include "llvm/TableGen/Error.h"
    29  #include "llvm/TableGen/Record.h"
    30  #include "llvm/TableGen/TableGenBackend.h"
    31  
    32  using namespace llvm;
    33  using namespace mlir;
    34  
    35  namespace {
    36  // This struct represents a single method argument.
    37  struct MethodArgument {
    38    StringRef type, name;
    39  };
    40  
    41  // Wrapper class around a single interface method.
    42  class OpInterfaceMethod {
    43  public:
    44    explicit OpInterfaceMethod(const llvm::Record *def) : def(def) {
    45      llvm::DagInit *args = def->getValueAsDag("arguments");
    46      for (unsigned i = 0, e = args->getNumArgs(); i != e; ++i) {
    47        arguments.push_back(
    48            {llvm::cast<llvm::StringInit>(args->getArg(i))->getValue(),
    49             args->getArgNameStr(i)});
    50      }
    51    }
    52  
    53    // Return the return type of this method.
    54    StringRef getReturnType() const {
    55      return def->getValueAsString("returnType");
    56    }
    57  
    58    // Return the name of this method.
    59    StringRef getName() const { return def->getValueAsString("name"); }
    60  
    61    // Return if this method is static.
    62    bool isStatic() const { return def->isSubClassOf("StaticInterfaceMethod"); }
    63  
    64    // Return the body for this method if it has one.
    65    llvm::Optional<StringRef> getBody() const {
    66      auto value = def->getValueAsString("body");
    67      return value.empty() ? llvm::Optional<StringRef>() : value;
    68    }
    69  
    70    // Arguments.
    71    ArrayRef<MethodArgument> getArguments() const { return arguments; }
    72    bool arg_empty() const { return arguments.empty(); }
    73  
    74  protected:
    75    // The TableGen definition of this method.
    76    const llvm::Record *def;
    77  
    78    // The arguments of this method.
    79    SmallVector<MethodArgument, 2> arguments;
    80  };
    81  
    82  // Wrapper class with helper methods for accessing OpInterfaces defined in
    83  // TableGen.
    84  class OpInterface {
    85  public:
    86    explicit OpInterface(const llvm::Record *def) : def(def) {
    87      auto *listInit = dyn_cast<llvm::ListInit>(def->getValueInit("methods"));
    88      for (llvm::Init *init : listInit->getValues())
    89        methods.emplace_back(cast<llvm::DefInit>(init)->getDef());
    90    }
    91  
    92    // Return the name of this interface.
    93    StringRef getName() const { return def->getValueAsString("cppClassName"); }
    94  
    95    // Return the methods of this interface.
    96    ArrayRef<OpInterfaceMethod> getMethods() const { return methods; }
    97  
    98  protected:
    99    // The TableGen definition of this interface.
   100    const llvm::Record *def;
   101  
   102    // The methods of this interface.
   103    SmallVector<OpInterfaceMethod, 8> methods;
   104  };
   105  } // end anonymous namespace
   106  
   107  // Emit the method name and argument list for the given method. If
   108  // 'addOperationArg' is true, then an Operation* argument is added to the
   109  // beginning of the argument list.
   110  static void emitMethodNameAndArgs(const OpInterfaceMethod &method,
   111                                    raw_ostream &os, bool addOperationArg) {
   112    os << method.getName() << '(';
   113    if (addOperationArg)
   114      os << "Operation *tablegen_opaque_op" << (method.arg_empty() ? "" : ", ");
   115    interleaveComma(method.getArguments(), os, [&](const MethodArgument &arg) {
   116      os << arg.type << " " << arg.name;
   117    });
   118    os << ')';
   119  }
   120  
   121  static void emitInterfaceDef(const Record &interfaceDef, raw_ostream &os) {
   122    OpInterface interface(&interfaceDef);
   123    StringRef interfaceName = interface.getName();
   124  
   125    // Insert the method definitions.
   126    for (auto &method : interface.getMethods()) {
   127      os << method.getReturnType() << " " << interfaceName << "::";
   128      emitMethodNameAndArgs(method, os, /*addOperationArg=*/false);
   129  
   130      // Forward to the method on the concrete operation type.
   131      os << " {\n      return getImpl()->" << method.getName() << '(';
   132      if (!method.isStatic())
   133        os << "getOperation()" << (method.arg_empty() ? "" : ", ");
   134      interleaveComma(method.getArguments(), os,
   135                      [&](const MethodArgument &arg) { os << arg.name; });
   136      os << ");\n  }\n";
   137    }
   138  }
   139  
   140  static bool emitInterfaceDefs(const RecordKeeper &recordKeeper,
   141                                raw_ostream &os) {
   142    llvm::emitSourceFileHeader("Operation Interface Definitions", os);
   143  
   144    auto defs = recordKeeper.getAllDerivedDefinitions("OpInterface");
   145    for (const auto *def : defs)
   146      emitInterfaceDef(*def, os);
   147    return false;
   148  }
   149  
   150  static void emitConceptDecl(OpInterface &interface, raw_ostream &os) {
   151    os << "  class Concept {\n"
   152       << "  public:\n"
   153       << "    virtual ~Concept() = default;\n";
   154  
   155    // Insert each of the pure virtual concept methods.
   156    for (auto &method : interface.getMethods()) {
   157      os << "    virtual " << method.getReturnType() << " ";
   158      emitMethodNameAndArgs(method, os, /*addOperationArg=*/!method.isStatic());
   159      os << " = 0;\n";
   160    }
   161    os << "  };\n";
   162  }
   163  
   164  static void emitModelDecl(OpInterface &interface, raw_ostream &os) {
   165    os << "  template<typename ConcreteOp>\n";
   166    os << "  class Model : public Concept {\npublic:\n";
   167  
   168    // Insert each of the virtual method overrides.
   169    for (auto &method : interface.getMethods()) {
   170      os << "    " << method.getReturnType() << " ";
   171      emitMethodNameAndArgs(method, os, /*addOperationArg=*/!method.isStatic());
   172      os << " final {\n";
   173  
   174      // Provide a definition of the concrete op if this is non static.
   175      if (!method.isStatic()) {
   176        os << "      auto op = llvm::cast<ConcreteOp>(tablegen_opaque_op);\n"
   177           << "      (void)op;\n";
   178      }
   179  
   180      // Check for a provided body to the function.
   181      if (auto body = method.getBody()) {
   182        os << body << "\n    }\n";
   183        continue;
   184      }
   185  
   186      // Forward to the method on the concrete operation type.
   187      os << "      return " << (method.isStatic() ? "ConcreteOp::" : "op.");
   188  
   189      // Add the arguments to the call.
   190      os << method.getName() << '(';
   191      interleaveComma(method.getArguments(), os,
   192                      [&](const MethodArgument &arg) { os << arg.name; });
   193      os << ");\n    }\n";
   194    }
   195    os << "  };\n";
   196  }
   197  
   198  static void emitInterfaceDecl(const Record &interfaceDef, raw_ostream &os) {
   199    OpInterface interface(&interfaceDef);
   200    StringRef interfaceName = interface.getName();
   201    auto interfaceTraitsName = (interfaceName + "InterfaceTraits").str();
   202  
   203    // Emit the traits struct containing the concept and model declarations.
   204    os << "namespace detail {\n"
   205       << "struct " << interfaceTraitsName << " {\n";
   206    emitConceptDecl(interface, os);
   207    emitModelDecl(interface, os);
   208    os << "};\n} // end namespace detail\n";
   209  
   210    // Emit the main interface class declaration.
   211    os << llvm::formatv("class {0} : public OpInterface<{1}, detail::{2}> {\n"
   212                        "public:\n"
   213                        "  using OpInterface<{1}, detail::{2}>::OpInterface;\n",
   214                        interfaceName, interfaceName, interfaceTraitsName);
   215  
   216    // Insert the method declarations.
   217    for (auto &method : interface.getMethods()) {
   218      os << "  " << method.getReturnType() << " ";
   219      emitMethodNameAndArgs(method, os, /*addOperationArg=*/false);
   220      os << ";\n";
   221    }
   222    os << "};\n";
   223  }
   224  
   225  static bool emitInterfaceDecls(const RecordKeeper &recordKeeper,
   226                                 raw_ostream &os) {
   227    llvm::emitSourceFileHeader("Operation Interface Declarations", os);
   228  
   229    auto defs = recordKeeper.getAllDerivedDefinitions("OpInterface");
   230    for (const auto *def : defs)
   231      emitInterfaceDecl(*def, os);
   232    return false;
   233  }
   234  
   235  // Registers the operation interface generator to mlir-tblgen.
   236  static mlir::GenRegistration
   237      genInterfaceDecls("gen-op-interface-decls",
   238                        "Generate op interface declarations",
   239                        [](const RecordKeeper &records, raw_ostream &os) {
   240                          return emitInterfaceDecls(records, os);
   241                        });
   242  
   243  // Registers the operation interface generator to mlir-tblgen.
   244  static mlir::GenRegistration
   245      genInterfaceDefs("gen-op-interface-defs",
   246                       "Generate op interface definitions",
   247                       [](const RecordKeeper &records, raw_ostream &os) {
   248                         return emitInterfaceDefs(records, os);
   249                       });