github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp (about) 1 //===- OpInterfacesGen.cpp - MLIR op interface utility generator ----------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 // 18 // OpInterfacesGen generates definitions for operation interfaces. 19 // 20 //===----------------------------------------------------------------------===// 21 22 #include "mlir/Support/STLExtras.h" 23 #include "mlir/TableGen/GenInfo.h" 24 #include "llvm/ADT/SmallVector.h" 25 #include "llvm/ADT/StringExtras.h" 26 #include "llvm/Support/FormatVariadic.h" 27 #include "llvm/Support/raw_ostream.h" 28 #include "llvm/TableGen/Error.h" 29 #include "llvm/TableGen/Record.h" 30 #include "llvm/TableGen/TableGenBackend.h" 31 32 using namespace llvm; 33 using namespace mlir; 34 35 namespace { 36 // This struct represents a single method argument. 37 struct MethodArgument { 38 StringRef type, name; 39 }; 40 41 // Wrapper class around a single interface method. 42 class OpInterfaceMethod { 43 public: 44 explicit OpInterfaceMethod(const llvm::Record *def) : def(def) { 45 llvm::DagInit *args = def->getValueAsDag("arguments"); 46 for (unsigned i = 0, e = args->getNumArgs(); i != e; ++i) { 47 arguments.push_back( 48 {llvm::cast<llvm::StringInit>(args->getArg(i))->getValue(), 49 args->getArgNameStr(i)}); 50 } 51 } 52 53 // Return the return type of this method. 54 StringRef getReturnType() const { 55 return def->getValueAsString("returnType"); 56 } 57 58 // Return the name of this method. 59 StringRef getName() const { return def->getValueAsString("name"); } 60 61 // Return if this method is static. 62 bool isStatic() const { return def->isSubClassOf("StaticInterfaceMethod"); } 63 64 // Return the body for this method if it has one. 65 llvm::Optional<StringRef> getBody() const { 66 auto value = def->getValueAsString("body"); 67 return value.empty() ? llvm::Optional<StringRef>() : value; 68 } 69 70 // Arguments. 71 ArrayRef<MethodArgument> getArguments() const { return arguments; } 72 bool arg_empty() const { return arguments.empty(); } 73 74 protected: 75 // The TableGen definition of this method. 76 const llvm::Record *def; 77 78 // The arguments of this method. 79 SmallVector<MethodArgument, 2> arguments; 80 }; 81 82 // Wrapper class with helper methods for accessing OpInterfaces defined in 83 // TableGen. 84 class OpInterface { 85 public: 86 explicit OpInterface(const llvm::Record *def) : def(def) { 87 auto *listInit = dyn_cast<llvm::ListInit>(def->getValueInit("methods")); 88 for (llvm::Init *init : listInit->getValues()) 89 methods.emplace_back(cast<llvm::DefInit>(init)->getDef()); 90 } 91 92 // Return the name of this interface. 93 StringRef getName() const { return def->getValueAsString("cppClassName"); } 94 95 // Return the methods of this interface. 96 ArrayRef<OpInterfaceMethod> getMethods() const { return methods; } 97 98 protected: 99 // The TableGen definition of this interface. 100 const llvm::Record *def; 101 102 // The methods of this interface. 103 SmallVector<OpInterfaceMethod, 8> methods; 104 }; 105 } // end anonymous namespace 106 107 // Emit the method name and argument list for the given method. If 108 // 'addOperationArg' is true, then an Operation* argument is added to the 109 // beginning of the argument list. 110 static void emitMethodNameAndArgs(const OpInterfaceMethod &method, 111 raw_ostream &os, bool addOperationArg) { 112 os << method.getName() << '('; 113 if (addOperationArg) 114 os << "Operation *tablegen_opaque_op" << (method.arg_empty() ? "" : ", "); 115 interleaveComma(method.getArguments(), os, [&](const MethodArgument &arg) { 116 os << arg.type << " " << arg.name; 117 }); 118 os << ')'; 119 } 120 121 static void emitInterfaceDef(const Record &interfaceDef, raw_ostream &os) { 122 OpInterface interface(&interfaceDef); 123 StringRef interfaceName = interface.getName(); 124 125 // Insert the method definitions. 126 for (auto &method : interface.getMethods()) { 127 os << method.getReturnType() << " " << interfaceName << "::"; 128 emitMethodNameAndArgs(method, os, /*addOperationArg=*/false); 129 130 // Forward to the method on the concrete operation type. 131 os << " {\n return getImpl()->" << method.getName() << '('; 132 if (!method.isStatic()) 133 os << "getOperation()" << (method.arg_empty() ? "" : ", "); 134 interleaveComma(method.getArguments(), os, 135 [&](const MethodArgument &arg) { os << arg.name; }); 136 os << ");\n }\n"; 137 } 138 } 139 140 static bool emitInterfaceDefs(const RecordKeeper &recordKeeper, 141 raw_ostream &os) { 142 llvm::emitSourceFileHeader("Operation Interface Definitions", os); 143 144 auto defs = recordKeeper.getAllDerivedDefinitions("OpInterface"); 145 for (const auto *def : defs) 146 emitInterfaceDef(*def, os); 147 return false; 148 } 149 150 static void emitConceptDecl(OpInterface &interface, raw_ostream &os) { 151 os << " class Concept {\n" 152 << " public:\n" 153 << " virtual ~Concept() = default;\n"; 154 155 // Insert each of the pure virtual concept methods. 156 for (auto &method : interface.getMethods()) { 157 os << " virtual " << method.getReturnType() << " "; 158 emitMethodNameAndArgs(method, os, /*addOperationArg=*/!method.isStatic()); 159 os << " = 0;\n"; 160 } 161 os << " };\n"; 162 } 163 164 static void emitModelDecl(OpInterface &interface, raw_ostream &os) { 165 os << " template<typename ConcreteOp>\n"; 166 os << " class Model : public Concept {\npublic:\n"; 167 168 // Insert each of the virtual method overrides. 169 for (auto &method : interface.getMethods()) { 170 os << " " << method.getReturnType() << " "; 171 emitMethodNameAndArgs(method, os, /*addOperationArg=*/!method.isStatic()); 172 os << " final {\n"; 173 174 // Provide a definition of the concrete op if this is non static. 175 if (!method.isStatic()) { 176 os << " auto op = llvm::cast<ConcreteOp>(tablegen_opaque_op);\n" 177 << " (void)op;\n"; 178 } 179 180 // Check for a provided body to the function. 181 if (auto body = method.getBody()) { 182 os << body << "\n }\n"; 183 continue; 184 } 185 186 // Forward to the method on the concrete operation type. 187 os << " return " << (method.isStatic() ? "ConcreteOp::" : "op."); 188 189 // Add the arguments to the call. 190 os << method.getName() << '('; 191 interleaveComma(method.getArguments(), os, 192 [&](const MethodArgument &arg) { os << arg.name; }); 193 os << ");\n }\n"; 194 } 195 os << " };\n"; 196 } 197 198 static void emitInterfaceDecl(const Record &interfaceDef, raw_ostream &os) { 199 OpInterface interface(&interfaceDef); 200 StringRef interfaceName = interface.getName(); 201 auto interfaceTraitsName = (interfaceName + "InterfaceTraits").str(); 202 203 // Emit the traits struct containing the concept and model declarations. 204 os << "namespace detail {\n" 205 << "struct " << interfaceTraitsName << " {\n"; 206 emitConceptDecl(interface, os); 207 emitModelDecl(interface, os); 208 os << "};\n} // end namespace detail\n"; 209 210 // Emit the main interface class declaration. 211 os << llvm::formatv("class {0} : public OpInterface<{1}, detail::{2}> {\n" 212 "public:\n" 213 " using OpInterface<{1}, detail::{2}>::OpInterface;\n", 214 interfaceName, interfaceName, interfaceTraitsName); 215 216 // Insert the method declarations. 217 for (auto &method : interface.getMethods()) { 218 os << " " << method.getReturnType() << " "; 219 emitMethodNameAndArgs(method, os, /*addOperationArg=*/false); 220 os << ";\n"; 221 } 222 os << "};\n"; 223 } 224 225 static bool emitInterfaceDecls(const RecordKeeper &recordKeeper, 226 raw_ostream &os) { 227 llvm::emitSourceFileHeader("Operation Interface Declarations", os); 228 229 auto defs = recordKeeper.getAllDerivedDefinitions("OpInterface"); 230 for (const auto *def : defs) 231 emitInterfaceDecl(*def, os); 232 return false; 233 } 234 235 // Registers the operation interface generator to mlir-tblgen. 236 static mlir::GenRegistration 237 genInterfaceDecls("gen-op-interface-decls", 238 "Generate op interface declarations", 239 [](const RecordKeeper &records, raw_ostream &os) { 240 return emitInterfaceDecls(records, os); 241 }); 242 243 // Registers the operation interface generator to mlir-tblgen. 244 static mlir::GenRegistration 245 genInterfaceDefs("gen-op-interface-defs", 246 "Generate op interface definitions", 247 [](const RecordKeeper &records, raw_ostream &os) { 248 return emitInterfaceDefs(records, os); 249 });