github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp (about) 1 //===- OpDefinitionsGen.cpp - MLIR op definitions generator ---------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 // 18 // OpDefinitionsGen uses the description of operations to generate C++ 19 // definitions for ops. 20 // 21 //===----------------------------------------------------------------------===// 22 23 #include "mlir/Support/STLExtras.h" 24 #include "mlir/TableGen/Format.h" 25 #include "mlir/TableGen/GenInfo.h" 26 #include "mlir/TableGen/OpTrait.h" 27 #include "mlir/TableGen/Operator.h" 28 #include "llvm/ADT/StringExtras.h" 29 #include "llvm/Support/Signals.h" 30 #include "llvm/TableGen/Error.h" 31 #include "llvm/TableGen/Record.h" 32 #include "llvm/TableGen/TableGenBackend.h" 33 34 using namespace llvm; 35 using namespace mlir; 36 using namespace mlir::tblgen; 37 38 static const char *const tblgenNamePrefix = "tblgen_"; 39 static const char *const generatedArgName = "tblgen_arg"; 40 static const char *const builderOpState = "tblgen_state"; 41 42 // The logic to calculate the dynamic value range for an static operand/result 43 // of an op with variadic operands/results. Note that this logic is not for 44 // general use; it assumes all variadic operands/results must have the same 45 // number of values. 46 // 47 // {0}: The list of whether each static operand/result is variadic. 48 // {1}: The total number of non-variadic operands/results. 49 // {2}: The total number of variadic operands/results. 50 // {3}: The total number of dynamic values. 51 // {4}: The begin iterator of the dynamic values. 52 // {5}: "operand" or "result" 53 const char *valueRangeCalcCode = R"( 54 bool isVariadic[] = {{{0}}; 55 int prevVariadicCount = 0; 56 for (unsigned i = 0; i < index; ++i) 57 if (isVariadic[i]) ++prevVariadicCount; 58 59 // Calculate how many dynamic values a static variadic {5} corresponds to. 60 // This assumes all static variadic {5}s have the same dynamic value count. 61 int variadicSize = ({3} - {1}) / {2}; 62 // `index` passed in as the parameter is the static index which counts each 63 // {5} (variadic or not) as size 1. So here for each previous static variadic 64 // {5}, we need to offset by (variadicSize - 1) to get where the dynamic 65 // value pack for this static {5} starts. 66 int offset = index + (variadicSize - 1) * prevVariadicCount; 67 int size = isVariadic[index] ? variadicSize : 1; 68 69 return {{std::next({4}, offset), std::next({4}, offset + size)}; 70 )"; 71 72 static const char *const opCommentHeader = R"( 73 //===----------------------------------------------------------------------===// 74 // {0} {1} 75 //===----------------------------------------------------------------------===// 76 77 )"; 78 79 //===----------------------------------------------------------------------===// 80 // Utility structs and functions 81 //===----------------------------------------------------------------------===// 82 83 // Returns whether the record has a value of the given name that can be returned 84 // via getValueAsString. 85 static inline bool hasStringAttribute(const Record &record, 86 StringRef fieldName) { 87 auto valueInit = record.getValueInit(fieldName); 88 return isa<CodeInit>(valueInit) || isa<StringInit>(valueInit); 89 } 90 91 static std::string getArgumentName(const Operator &op, int index) { 92 const auto &operand = op.getOperand(index); 93 if (!operand.name.empty()) 94 return operand.name; 95 else 96 return formatv("{0}_{1}", generatedArgName, index); 97 } 98 99 namespace { 100 // Simple RAII helper for defining ifdef-undef-endif scopes. 101 class IfDefScope { 102 public: 103 IfDefScope(StringRef name, raw_ostream &os) : name(name), os(os) { 104 os << "#ifdef " << name << "\n" 105 << "#undef " << name << "\n\n"; 106 } 107 108 ~IfDefScope() { os << "\n#endif // " << name << "\n\n"; } 109 110 private: 111 StringRef name; 112 raw_ostream &os; 113 }; 114 } // end anonymous namespace 115 116 //===----------------------------------------------------------------------===// 117 // Classes for C++ code emission 118 //===----------------------------------------------------------------------===// 119 120 // We emit the op declaration and definition into separate files: *Ops.h.inc 121 // and *Ops.cpp.inc. The former is to be included in the dialect *Ops.h and 122 // the latter for dialect *Ops.cpp. This way provides a cleaner interface. 123 // 124 // In order to do this split, we need to track method signature and 125 // implementation logic separately. Signature information is used for both 126 // declaration and definition, while implementation logic is only for 127 // definition. So we have the following classes for C++ code emission. 128 129 namespace { 130 // Class for holding the signature of an op's method for C++ code emission 131 class OpMethodSignature { 132 public: 133 OpMethodSignature(StringRef retType, StringRef name, StringRef params); 134 135 // Writes the signature as a method declaration to the given `os`. 136 void writeDeclTo(raw_ostream &os) const; 137 // Writes the signature as the start of a method definition to the given `os`. 138 // `namePrefix` is the prefix to be prepended to the method name (typically 139 // namespaces for qualifying the method definition). 140 void writeDefTo(raw_ostream &os, StringRef namePrefix) const; 141 142 private: 143 // Returns true if the given C++ `type` ends with '&' or '*', or is empty. 144 static bool elideSpaceAfterType(StringRef type); 145 146 std::string returnType; 147 std::string methodName; 148 std::string parameters; 149 }; 150 151 // Class for holding the body of an op's method for C++ code emission 152 class OpMethodBody { 153 public: 154 explicit OpMethodBody(bool declOnly); 155 156 OpMethodBody &operator<<(Twine content); 157 OpMethodBody &operator<<(int content); 158 OpMethodBody &operator<<(const FmtObjectBase &content); 159 160 void writeTo(raw_ostream &os) const; 161 162 private: 163 // Whether this class should record method body. 164 bool isEffective; 165 std::string body; 166 }; 167 168 // Class for holding an op's method for C++ code emission 169 class OpMethod { 170 public: 171 // Properties (qualifiers) of class methods. Bitfield is used here to help 172 // querying properties. 173 enum Property { 174 MP_None = 0x0, 175 MP_Static = 0x1, // Static method 176 MP_Constructor = 0x2, // Constructor 177 MP_Private = 0x4, // Private method 178 }; 179 180 OpMethod(StringRef retType, StringRef name, StringRef params, 181 Property property, bool declOnly); 182 183 OpMethodBody &body(); 184 185 // Returns true if this is a static method. 186 bool isStatic() const; 187 188 // Returns true if this is a private method. 189 bool isPrivate() const; 190 191 // Writes the method as a declaration to the given `os`. 192 void writeDeclTo(raw_ostream &os) const; 193 // Writes the method as a definition to the given `os`. `namePrefix` is the 194 // prefix to be prepended to the method name (typically namespaces for 195 // qualifying the method definition). 196 void writeDefTo(raw_ostream &os, StringRef namePrefix) const; 197 198 private: 199 Property properties; 200 // Whether this method only contains a declaration. 201 bool isDeclOnly; 202 OpMethodSignature methodSignature; 203 OpMethodBody methodBody; 204 }; 205 206 // A class used to emit C++ classes from Tablegen. Contains a list of public 207 // methods and a list of private fields to be emitted. 208 class Class { 209 public: 210 explicit Class(StringRef name); 211 212 // Creates a new method in this class. 213 OpMethod &newMethod(StringRef retType, StringRef name, StringRef params = "", 214 OpMethod::Property = OpMethod::MP_None, 215 bool declOnly = false); 216 217 OpMethod &newConstructor(StringRef params = "", bool declOnly = false); 218 219 // Creates a new field in this class. 220 void newField(StringRef type, StringRef name, StringRef defaultValue = ""); 221 222 // Writes this op's class as a declaration to the given `os`. 223 void writeDeclTo(raw_ostream &os) const; 224 // Writes the method definitions in this op's class to the given `os`. 225 void writeDefTo(raw_ostream &os) const; 226 227 // Returns the C++ class name of the op. 228 StringRef getClassName() const { return className; } 229 230 protected: 231 std::string className; 232 SmallVector<OpMethod, 8> methods; 233 SmallVector<std::string, 4> fields; 234 }; 235 236 // Class for holding an op for C++ code emission 237 class OpClass : public Class { 238 public: 239 explicit OpClass(StringRef name, StringRef extraClassDeclaration = ""); 240 241 // Adds an op trait. 242 void addTrait(Twine trait); 243 244 // Writes this op's class as a declaration to the given `os`. Redefines 245 // Class::writeDeclTo to also emit traits and extra class declarations. 246 void writeDeclTo(raw_ostream &os) const; 247 248 private: 249 StringRef extraClassDeclaration; 250 SmallVector<std::string, 4> traits; 251 }; 252 } // end anonymous namespace 253 254 OpMethodSignature::OpMethodSignature(StringRef retType, StringRef name, 255 StringRef params) 256 : returnType(retType), methodName(name), parameters(params) {} 257 258 void OpMethodSignature::writeDeclTo(raw_ostream &os) const { 259 os << returnType << (elideSpaceAfterType(returnType) ? "" : " ") << methodName 260 << "(" << parameters << ")"; 261 } 262 263 void OpMethodSignature::writeDefTo(raw_ostream &os, 264 StringRef namePrefix) const { 265 // We need to remove the default values for parameters in method definition. 266 // TODO(antiagainst): We are using '=' and ',' as delimiters for parameter 267 // initializers. This is incorrect for initializer list with more than one 268 // element. Change to a more robust approach. 269 auto removeParamDefaultValue = [](StringRef params) { 270 std::string result; 271 std::pair<StringRef, StringRef> parts; 272 while (!params.empty()) { 273 parts = params.split("="); 274 result.append(result.empty() ? "" : ", "); 275 result.append(parts.first); 276 params = parts.second.split(",").second; 277 } 278 return result; 279 }; 280 281 os << returnType << (elideSpaceAfterType(returnType) ? "" : " ") << namePrefix 282 << (namePrefix.empty() ? "" : "::") << methodName << "(" 283 << removeParamDefaultValue(parameters) << ")"; 284 } 285 286 bool OpMethodSignature::elideSpaceAfterType(StringRef type) { 287 return type.empty() || type.endswith("&") || type.endswith("*"); 288 } 289 290 OpMethodBody::OpMethodBody(bool declOnly) : isEffective(!declOnly) {} 291 292 OpMethodBody &OpMethodBody::operator<<(Twine content) { 293 if (isEffective) 294 body.append(content.str()); 295 return *this; 296 } 297 298 OpMethodBody &OpMethodBody::operator<<(int content) { 299 if (isEffective) 300 body.append(std::to_string(content)); 301 return *this; 302 } 303 304 OpMethodBody &OpMethodBody::operator<<(const FmtObjectBase &content) { 305 if (isEffective) 306 body.append(content.str()); 307 return *this; 308 } 309 310 void OpMethodBody::writeTo(raw_ostream &os) const { 311 auto bodyRef = StringRef(body).drop_while([](char c) { return c == '\n'; }); 312 os << bodyRef; 313 if (bodyRef.empty() || bodyRef.back() != '\n') 314 os << "\n"; 315 } 316 317 OpMethod::OpMethod(StringRef retType, StringRef name, StringRef params, 318 OpMethod::Property property, bool declOnly) 319 : properties(property), isDeclOnly(declOnly), 320 methodSignature(retType, name, params), methodBody(declOnly) {} 321 322 OpMethodBody &OpMethod::body() { return methodBody; } 323 324 bool OpMethod::isStatic() const { return properties & MP_Static; } 325 326 bool OpMethod::isPrivate() const { return properties & MP_Private; } 327 328 void OpMethod::writeDeclTo(raw_ostream &os) const { 329 os.indent(2); 330 if (isStatic()) 331 os << "static "; 332 methodSignature.writeDeclTo(os); 333 os << ";"; 334 } 335 336 void OpMethod::writeDefTo(raw_ostream &os, StringRef namePrefix) const { 337 if (isDeclOnly) 338 return; 339 340 methodSignature.writeDefTo(os, namePrefix); 341 os << " {\n"; 342 methodBody.writeTo(os); 343 os << "}"; 344 } 345 346 Class::Class(StringRef name) : className(name) {} 347 348 OpMethod &Class::newMethod(StringRef retType, StringRef name, StringRef params, 349 OpMethod::Property property, bool declOnly) { 350 methods.emplace_back(retType, name, params, property, declOnly); 351 return methods.back(); 352 } 353 354 OpMethod &Class::newConstructor(StringRef params, bool declOnly) { 355 return newMethod("", getClassName(), params, OpMethod::MP_Constructor, 356 declOnly); 357 } 358 359 void Class::newField(StringRef type, StringRef name, StringRef defaultValue) { 360 std::string varName = formatv("{0} {1}", type, name).str(); 361 std::string field = defaultValue.empty() 362 ? varName 363 : formatv("{0} = {1}", varName, defaultValue).str(); 364 fields.push_back(std::move(field)); 365 } 366 367 void Class::writeDeclTo(raw_ostream &os) const { 368 bool hasPrivateMethod = false; 369 os << "class " << className << " {\n"; 370 os << "public:\n"; 371 for (const auto &method : methods) { 372 if (!method.isPrivate()) { 373 method.writeDeclTo(os); 374 os << '\n'; 375 } else { 376 hasPrivateMethod = true; 377 } 378 } 379 os << '\n'; 380 os << "private:\n"; 381 if (hasPrivateMethod) { 382 for (const auto &method : methods) { 383 if (method.isPrivate()) { 384 method.writeDeclTo(os); 385 os << '\n'; 386 } 387 } 388 os << '\n'; 389 } 390 for (const auto &field : fields) 391 os.indent(2) << field << ";\n"; 392 os << "};\n"; 393 } 394 395 void Class::writeDefTo(raw_ostream &os) const { 396 for (const auto &method : methods) { 397 method.writeDefTo(os, className); 398 os << "\n\n"; 399 } 400 } 401 402 OpClass::OpClass(StringRef name, StringRef extraClassDeclaration) 403 : Class(name), extraClassDeclaration(extraClassDeclaration) {} 404 405 // Adds the given trait to this op. 406 void OpClass::addTrait(Twine trait) { traits.push_back(trait.str()); } 407 408 void OpClass::writeDeclTo(raw_ostream &os) const { 409 os << "class " << className << " : public Op<" << className; 410 for (const auto &trait : traits) 411 os << ", " << trait; 412 os << "> {\npublic:\n"; 413 os << " using Op::Op;\n"; 414 os << " using OperandAdaptor = " << className << "OperandAdaptor;\n"; 415 416 bool hasPrivateMethod = false; 417 for (const auto &method : methods) { 418 if (!method.isPrivate()) { 419 method.writeDeclTo(os); 420 os << "\n"; 421 } else { 422 hasPrivateMethod = true; 423 } 424 } 425 426 // TODO: Add line control markers to make errors easier to debug. 427 if (!extraClassDeclaration.empty()) 428 os << extraClassDeclaration << "\n"; 429 430 if (hasPrivateMethod) { 431 os << '\n'; 432 os << "private:\n"; 433 for (const auto &method : methods) { 434 if (method.isPrivate()) { 435 method.writeDeclTo(os); 436 os << "\n"; 437 } 438 } 439 } 440 441 os << "};\n"; 442 } 443 444 //===----------------------------------------------------------------------===// 445 // Op emitter 446 //===----------------------------------------------------------------------===// 447 448 namespace { 449 // Helper class to emit a record into the given output stream. 450 class OpEmitter { 451 public: 452 static void emitDecl(const Operator &op, raw_ostream &os); 453 static void emitDef(const Operator &op, raw_ostream &os); 454 455 private: 456 OpEmitter(const Operator &op); 457 458 void emitDecl(raw_ostream &os); 459 void emitDef(raw_ostream &os); 460 461 // Generates the `getOperationName` method for this op. 462 void genOpNameGetter(); 463 464 // Generates getters for the attributes. 465 void genAttrGetters(); 466 467 // Generates getters for named operands. 468 void genNamedOperandGetters(); 469 470 // Generates getters for named results. 471 void genNamedResultGetters(); 472 473 // Generates getters for named regions. 474 void genNamedRegionGetters(); 475 476 // Generates builder methods for the operation. 477 void genBuilder(); 478 479 // Generates the build() method that takes each result-type/operand/attribute 480 // as a stand-alone parameter. This build() method also requires specifying 481 // result types for all results. 482 void genSeparateParamBuilder(); 483 484 // Generates the build() method that takes a single parameter for all the 485 // result types and a separate parameter for each operand/attribute. 486 void genCollectiveTypeParamBuilder(); 487 488 // Generates the build() method that takes each operand/attribute as a 489 // stand-alone parameter. This build() method uses first operand's type 490 // as all result's types. 491 void genUseOperandAsResultTypeBuilder(); 492 493 // Generates the build() method that takes each operand/attribute as a 494 // stand-alone parameter. This build() method uses first attribute's type 495 // as all result's types. 496 void genUseAttrAsResultTypeBuilder(); 497 498 // Generates the build() method that takes all result types collectively as 499 // one parameter. Similarly for operands and attributes. 500 void genCollectiveParamBuilder(); 501 502 enum class TypeParamKind { None, Separate, Collective }; 503 504 // Builds the parameter list for build() method of this op. This method writes 505 // to `paramList` the comma-separated parameter list. If `includeResultTypes` 506 // is true then `paramList` will also contain the parameters for all results 507 // and `resultTypeNames` will be populated with the parameter name for each 508 // result type. 509 void buildParamList(std::string ¶mList, 510 SmallVectorImpl<std::string> &resultTypeNames, 511 TypeParamKind kind); 512 513 // Adds op arguments and regions into operation state for build() methods. 514 void genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body); 515 516 // Generates canonicalizer declaration for the operation. 517 void genCanonicalizerDecls(); 518 519 // Generates the folder declaration for the operation. 520 void genFolderDecls(); 521 522 // Generates the parser for the operation. 523 void genParser(); 524 525 // Generates the printer for the operation. 526 void genPrinter(); 527 528 // Generates verify method for the operation. 529 void genVerifier(); 530 531 // Generates verify statements for operands and results in the operation. 532 // The generated code will be attached to `body`. 533 void genOperandResultVerifier(OpMethodBody &body, 534 Operator::value_range values, 535 StringRef valueKind); 536 537 // Generates verify statements for regions in the operation. 538 // The generated code will be attached to `body`. 539 void genRegionVerifier(OpMethodBody &body); 540 541 // Generates the traits used by the object. 542 void genTraits(); 543 544 private: 545 // The TableGen record for this op. 546 // TODO(antiagainst,zinenko): OpEmitter should not have a Record directly, 547 // it should rather go through the Operator for better abstraction. 548 const Record &def; 549 550 // The wrapper operator class for querying information from this op. 551 Operator op; 552 553 // The C++ code builder for this op 554 OpClass opClass; 555 556 // The format context for verification code generation. 557 FmtContext verifyCtx; 558 }; 559 } // end anonymous namespace 560 561 OpEmitter::OpEmitter(const Operator &op) 562 : def(op.getDef()), op(op), 563 opClass(op.getCppClassName(), op.getExtraClassDeclaration()) { 564 verifyCtx.withOp("(*this->getOperation())"); 565 566 genTraits(); 567 // Generate C++ code for various op methods. The order here determines the 568 // methods in the generated file. 569 genOpNameGetter(); 570 genNamedOperandGetters(); 571 genNamedResultGetters(); 572 genNamedRegionGetters(); 573 genAttrGetters(); 574 genBuilder(); 575 genParser(); 576 genPrinter(); 577 genVerifier(); 578 genCanonicalizerDecls(); 579 genFolderDecls(); 580 } 581 582 void OpEmitter::emitDecl(const Operator &op, raw_ostream &os) { 583 OpEmitter(op).emitDecl(os); 584 } 585 586 void OpEmitter::emitDef(const Operator &op, raw_ostream &os) { 587 OpEmitter(op).emitDef(os); 588 } 589 590 void OpEmitter::emitDecl(raw_ostream &os) { opClass.writeDeclTo(os); } 591 592 void OpEmitter::emitDef(raw_ostream &os) { opClass.writeDefTo(os); } 593 594 void OpEmitter::genAttrGetters() { 595 FmtContext fctx; 596 fctx.withBuilder("mlir::Builder(this->getContext())"); 597 for (auto &namedAttr : op.getAttributes()) { 598 const auto &name = namedAttr.name; 599 const auto &attr = namedAttr.attr; 600 601 auto &method = opClass.newMethod(attr.getReturnType(), name); 602 auto &body = method.body(); 603 604 // Emit the derived attribute body. 605 if (attr.isDerivedAttr()) { 606 body << " " << attr.getDerivedCodeBody() << "\n"; 607 continue; 608 } 609 610 // Emit normal emitter. 611 612 // Return the queried attribute with the correct return type. 613 auto attrVal = 614 (attr.hasDefaultValueInitializer() || attr.isOptional()) 615 ? formatv("this->getAttr(\"{0}\").dyn_cast_or_null<{1}>()", name, 616 attr.getStorageType()) 617 : formatv("this->getAttr(\"{0}\").cast<{1}>()", name, 618 attr.getStorageType()); 619 body << " auto attr = " << attrVal << ";\n"; 620 if (attr.hasDefaultValueInitializer()) { 621 // Returns the default value if not set. 622 // TODO: this is inefficient, we are recreating the attribute for every 623 // call. This should be set instead. 624 std::string defaultValue = tgfmt(attr.getConstBuilderTemplate(), &fctx, 625 attr.getDefaultValueInitializer()); 626 body << " if (!attr)\n return " 627 << tgfmt(attr.getConvertFromStorageCall(), 628 &fctx.withSelf(defaultValue)) 629 << ";\n"; 630 } 631 body << " return " 632 << tgfmt(attr.getConvertFromStorageCall(), &fctx.withSelf("attr")) 633 << ";\n"; 634 } 635 } 636 637 // Generates the named operand getter methods for the given Operator `op` and 638 // puts them in `opClass`. Uses `rangeType` as the return type of getters that 639 // return a range of operands (individual operands are `Value *` and each 640 // element in the range must also be `Value *`); use `rangeBeginCall` to get an 641 // iterator to the beginning of the operand range; use `rangeSizeCall` to obtain 642 // the number of operands. `getOperandCallPattern` contains the code necessary 643 // to obtain a single operand whose position will be substituted instead of 644 // "{0}" marker in the pattern. Note that the pattern should work for any kind 645 // of ops, in particular for one-operand ops that may not have the 646 // `getOperand(unsigned)` method. 647 static void generateNamedOperandGetters(const Operator &op, Class &opClass, 648 StringRef rangeType, 649 StringRef rangeBeginCall, 650 StringRef rangeSizeCall, 651 StringRef getOperandCallPattern) { 652 const int numOperands = op.getNumOperands(); 653 const int numVariadicOperands = op.getNumVariadicOperands(); 654 const int numNormalOperands = numOperands - numVariadicOperands; 655 656 if (numVariadicOperands > 1 && 657 !op.hasTrait("OpTrait::SameVariadicOperandSize")) { 658 PrintFatalError(op.getLoc(), "op has multiple variadic operands but no " 659 "specification over their sizes"); 660 } 661 662 // First emit a "sink" getter method upon which we layer all nicer named 663 // getter methods. 664 auto &m = opClass.newMethod(rangeType, "getODSOperands", "unsigned index"); 665 666 if (numVariadicOperands == 0) { 667 // We still need to match the return type, which is a range. 668 m.body() << "return {std::next(" << rangeBeginCall << ", index), std::next(" 669 << rangeBeginCall << ", index + 1)};"; 670 } else { 671 // Because the op can have arbitrarily interleaved variadic and non-variadic 672 // operands, we need to embed a list in the "sink" getter method for 673 // calculation at run-time. 674 llvm::SmallVector<StringRef, 4> isVariadic; 675 isVariadic.reserve(numOperands); 676 for (int i = 0; i < numOperands; ++i) { 677 isVariadic.push_back(llvm::toStringRef(op.getOperand(i).isVariadic())); 678 } 679 std::string isVariadicList = llvm::join(isVariadic, ", "); 680 681 m.body() << formatv(valueRangeCalcCode, isVariadicList, numNormalOperands, 682 numVariadicOperands, rangeSizeCall, rangeBeginCall, 683 "operand"); 684 } 685 686 // Then we emit nicer named getter methods by redirecting to the "sink" getter 687 // method. 688 689 for (int i = 0; i != numOperands; ++i) { 690 const auto &operand = op.getOperand(i); 691 if (operand.name.empty()) 692 continue; 693 694 if (operand.isVariadic()) { 695 auto &m = opClass.newMethod(rangeType, operand.name); 696 m.body() << "return getODSOperands(" << i << ");"; 697 } else { 698 auto &m = opClass.newMethod("Value *", operand.name); 699 m.body() << "return *getODSOperands(" << i << ").begin();"; 700 } 701 } 702 } 703 704 void OpEmitter::genNamedOperandGetters() { 705 generateNamedOperandGetters( 706 op, opClass, /*rangeType=*/"Operation::operand_range", 707 /*rangeBeginCall=*/"getOperation()->operand_begin()", 708 /*rangeSizeCall=*/"getOperation()->getNumOperands()", 709 /*getOperandCallPattern=*/"getOperation()->getOperand({0})"); 710 } 711 712 void OpEmitter::genNamedResultGetters() { 713 const int numResults = op.getNumResults(); 714 const int numVariadicResults = op.getNumVariadicResults(); 715 const int numNormalResults = numResults - numVariadicResults; 716 717 // If we have more than one variadic results, we need more complicated logic 718 // to calculate the value range for each result. 719 720 if (numVariadicResults > 1 && 721 !op.hasTrait("OpTrait::SameVariadicResultSize")) { 722 PrintFatalError(op.getLoc(), "op has multiple variadic results but no " 723 "specification over their sizes"); 724 } 725 726 auto &m = opClass.newMethod("Operation::result_range", "getODSResults", 727 "unsigned index"); 728 729 if (numVariadicResults == 0) { 730 m.body() << "return {std::next(getOperation()->result_begin(), index), " 731 "std::next(getOperation()->result_begin(), index + 1)};"; 732 } else { 733 llvm::SmallVector<StringRef, 4> isVariadic; 734 isVariadic.reserve(numResults); 735 for (int i = 0; i < numResults; ++i) { 736 isVariadic.push_back(llvm::toStringRef(op.getResult(i).isVariadic())); 737 } 738 std::string isVariadicList = llvm::join(isVariadic, ", "); 739 740 m.body() << formatv(valueRangeCalcCode, isVariadicList, numNormalResults, 741 numVariadicResults, "getOperation()->getNumResults()", 742 "getOperation()->result_begin()", "result"); 743 } 744 745 for (int i = 0; i != numResults; ++i) { 746 const auto &result = op.getResult(i); 747 if (result.name.empty()) 748 continue; 749 750 if (result.isVariadic()) { 751 auto &m = opClass.newMethod("Operation::result_range", result.name); 752 m.body() << "return getODSResults(" << i << ");"; 753 } else { 754 auto &m = opClass.newMethod("Value *", result.name); 755 m.body() << "return *getODSResults(" << i << ").begin();"; 756 } 757 } 758 } 759 760 void OpEmitter::genNamedRegionGetters() { 761 unsigned numRegions = op.getNumRegions(); 762 for (unsigned i = 0; i < numRegions; ++i) { 763 const auto ®ion = op.getRegion(i); 764 if (!region.name.empty()) { 765 auto &m = opClass.newMethod("Region &", region.name); 766 m.body() << formatv("return this->getOperation()->getRegion({0});", i); 767 } 768 } 769 } 770 771 void OpEmitter::genSeparateParamBuilder() { 772 std::string paramList; 773 llvm::SmallVector<std::string, 4> resultNames; 774 buildParamList(paramList, resultNames, TypeParamKind::Separate); 775 776 auto &m = opClass.newMethod("void", "build", paramList, OpMethod::MP_Static); 777 genCodeForAddingArgAndRegionForBuilder(m.body()); 778 779 // Push all result types to the operation state 780 for (int i = 0, e = op.getNumResults(); i < e; ++i) { 781 m.body() << " " << builderOpState << "->addTypes(" << resultNames[i] 782 << ");\n"; 783 } 784 } 785 786 void OpEmitter::genCollectiveTypeParamBuilder() { 787 auto numResults = op.getNumResults(); 788 789 // If this op has no results, then just skip generating this builder. 790 // Otherwise we are generating the same signature as the separate-parameter 791 // builder. 792 if (numResults == 0) 793 return; 794 795 // Similarly for ops with one single variadic result, which will also have one 796 // `ArrayRef<Type>` parameter for the result type. 797 if (numResults == 1 && op.getResult(0).isVariadic()) 798 return; 799 800 std::string paramList; 801 llvm::SmallVector<std::string, 4> resultNames; 802 buildParamList(paramList, resultNames, TypeParamKind::Collective); 803 804 auto &m = opClass.newMethod("void", "build", paramList, OpMethod::MP_Static); 805 genCodeForAddingArgAndRegionForBuilder(m.body()); 806 807 // Push all result types to the operation state 808 m.body() << formatv(" {0}->addTypes(resultTypes);\n", builderOpState); 809 } 810 811 void OpEmitter::genUseOperandAsResultTypeBuilder() { 812 std::string paramList; 813 llvm::SmallVector<std::string, 4> resultNames; 814 buildParamList(paramList, resultNames, TypeParamKind::None); 815 816 auto &m = opClass.newMethod("void", "build", paramList, OpMethod::MP_Static); 817 genCodeForAddingArgAndRegionForBuilder(m.body()); 818 819 auto numResults = op.getNumResults(); 820 if (numResults == 0) 821 return; 822 823 // Push all result types to the operation state 824 const char *index = op.getOperand(0).isVariadic() ? ".front()" : ""; 825 std::string resultType = 826 formatv("{0}{1}->getType()", getArgumentName(op, 0), index).str(); 827 m.body() << " " << builderOpState << "->addTypes({" << resultType; 828 for (int i = 1; i != numResults; ++i) 829 m.body() << ", " << resultType; 830 m.body() << "});\n\n"; 831 } 832 833 void OpEmitter::genUseAttrAsResultTypeBuilder() { 834 std::string paramList; 835 llvm::SmallVector<std::string, 4> resultNames; 836 buildParamList(paramList, resultNames, TypeParamKind::None); 837 838 auto &m = opClass.newMethod("void", "build", paramList, OpMethod::MP_Static); 839 genCodeForAddingArgAndRegionForBuilder(m.body()); 840 841 auto numResults = op.getNumResults(); 842 if (numResults == 0) 843 return; 844 845 // Push all result types to the operation state 846 std::string resultType; 847 const auto &namedAttr = op.getAttribute(0); 848 if (namedAttr.attr.isTypeAttr()) { 849 resultType = formatv("{0}.getValue()", namedAttr.name); 850 } else { 851 resultType = formatv("{0}.getType()", namedAttr.name); 852 } 853 m.body() << " " << builderOpState << "->addTypes({" << resultType; 854 for (int i = 1; i != numResults; ++i) 855 m.body() << ", " << resultType; 856 m.body() << "});\n\n"; 857 } 858 859 void OpEmitter::genBuilder() { 860 // Handle custom builders if provided. 861 // TODO(antiagainst): Create wrapper class for OpBuilder to hide the native 862 // TableGen API calls here. 863 { 864 auto *listInit = dyn_cast_or_null<ListInit>(def.getValueInit("builders")); 865 if (listInit) { 866 for (Init *init : listInit->getValues()) { 867 Record *builderDef = cast<DefInit>(init)->getDef(); 868 StringRef params = builderDef->getValueAsString("params"); 869 StringRef body = builderDef->getValueAsString("body"); 870 bool hasBody = !body.empty(); 871 872 auto &method = 873 opClass.newMethod("void", "build", params, OpMethod::MP_Static, 874 /*declOnly=*/!hasBody); 875 if (hasBody) 876 method.body() << body; 877 } 878 } 879 if (op.skipDefaultBuilders()) { 880 if (!listInit || listInit->empty()) 881 PrintFatalError( 882 op.getLoc(), 883 "default builders are skipped and no custom builders provided"); 884 return; 885 } 886 } 887 888 // Generate default builders that requires all result type, operands, and 889 // attributes as parameters. 890 891 // We generate three builders here: 892 // 1. one having a stand-alone parameter for each result type / operand / 893 // attribute, and 894 genSeparateParamBuilder(); 895 // 2. one having a stand-alone parameter for each operand / attribute and 896 // an aggregrated parameter for all result types, and 897 genCollectiveTypeParamBuilder(); 898 // 3. one having an aggregated parameter for all result types / operands / 899 // attributes, and 900 genCollectiveParamBuilder(); 901 // 4. one having a stand-alone prameter for each operand and attribute, 902 // use the first operand or attribute's type as all result types 903 // to facilitate different call patterns. 904 if (op.getNumVariadicResults() == 0) { 905 if (op.hasTrait("OpTrait::SameOperandsAndResultType")) 906 genUseOperandAsResultTypeBuilder(); 907 if (op.hasTrait("OpTrait::FirstAttrDerivedResultType")) 908 genUseAttrAsResultTypeBuilder(); 909 } 910 } 911 912 void OpEmitter::genCollectiveParamBuilder() { 913 int numResults = op.getNumResults(); 914 int numVariadicResults = op.getNumVariadicResults(); 915 int numNonVariadicResults = numResults - numVariadicResults; 916 917 int numOperands = op.getNumOperands(); 918 int numVariadicOperands = op.getNumVariadicOperands(); 919 int numNonVariadicOperands = numOperands - numVariadicOperands; 920 // Signature 921 std::string params = 922 std::string("Builder *, OperationState *") + builderOpState + 923 ", ArrayRef<Type> resultTypes, ArrayRef<Value *> operands, " 924 "ArrayRef<NamedAttribute> attributes"; 925 auto &m = opClass.newMethod("void", "build", params, OpMethod::MP_Static); 926 auto &body = m.body(); 927 928 // Result types 929 if (numVariadicResults == 0 || numNonVariadicResults != 0) 930 body << " assert(resultTypes.size()" 931 << (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults 932 << "u && \"mismatched number of return types\");\n"; 933 body << " " << builderOpState << "->addTypes(resultTypes);\n"; 934 935 // Operands 936 if (numVariadicOperands == 0 || numNonVariadicOperands != 0) 937 body << " assert(operands.size()" 938 << (numVariadicOperands != 0 ? " >= " : " == ") 939 << numNonVariadicOperands 940 << "u && \"mismatched number of parameters\");\n"; 941 body << " " << builderOpState << "->addOperands(operands);\n\n"; 942 943 // Attributes 944 body << " for (const auto& pair : attributes)\n" 945 << " " << builderOpState 946 << "->addAttribute(pair.first, pair.second);\n"; 947 948 // Create the correct number of regions 949 if (int numRegions = op.getNumRegions()) { 950 for (int i = 0; i < numRegions; ++i) 951 m.body() << " (void)" << builderOpState << "->addRegion();\n"; 952 } 953 } 954 955 void OpEmitter::buildParamList(std::string ¶mList, 956 SmallVectorImpl<std::string> &resultTypeNames, 957 TypeParamKind kind) { 958 resultTypeNames.clear(); 959 auto numResults = op.getNumResults(); 960 resultTypeNames.reserve(numResults); 961 962 paramList = "Builder *, OperationState *"; 963 paramList.append(builderOpState); 964 965 switch (kind) { 966 case TypeParamKind::None: 967 break; 968 case TypeParamKind::Separate: { 969 // Add parameters for all return types 970 for (int i = 0; i < numResults; ++i) { 971 const auto &result = op.getResult(i); 972 std::string resultName = result.name; 973 if (resultName.empty()) 974 resultName = formatv("resultType{0}", i); 975 976 paramList.append(result.isVariadic() ? ", ArrayRef<Type> " : ", Type "); 977 paramList.append(resultName); 978 979 resultTypeNames.emplace_back(std::move(resultName)); 980 } 981 } break; 982 case TypeParamKind::Collective: { 983 paramList.append(", ArrayRef<Type> resultTypes"); 984 resultTypeNames.push_back("resultTypes"); 985 } break; 986 } 987 988 int numOperands = 0; 989 int numAttrs = 0; 990 991 // Add parameters for all arguments (operands and attributes). 992 for (int i = 0, e = op.getNumArgs(); i < e; ++i) { 993 auto argument = op.getArg(i); 994 if (argument.is<tblgen::NamedTypeConstraint *>()) { 995 const auto &operand = op.getOperand(numOperands); 996 paramList.append(operand.isVariadic() ? ", ArrayRef<Value *> " 997 : ", Value *"); 998 paramList.append(getArgumentName(op, numOperands)); 999 ++numOperands; 1000 } else { 1001 // TODO(antiagainst): Support default initializer for attributes 1002 const auto &namedAttr = op.getAttribute(numAttrs); 1003 const auto &attr = namedAttr.attr; 1004 paramList.append(", "); 1005 if (attr.isOptional()) 1006 paramList.append("/*optional*/"); 1007 paramList.append(attr.getStorageType()); 1008 paramList.append(" "); 1009 paramList.append(namedAttr.name); 1010 ++numAttrs; 1011 } 1012 } 1013 1014 if (numOperands + numAttrs != op.getNumArgs()) 1015 PrintFatalError("op arguments must be either operands or attributes"); 1016 } 1017 1018 void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body) { 1019 // Push all operands to the result 1020 for (int i = 0, e = op.getNumOperands(); i < e; ++i) { 1021 body << " " << builderOpState << "->addOperands(" << getArgumentName(op, i) 1022 << ");\n"; 1023 } 1024 1025 // Push all attributes to the result 1026 for (const auto &namedAttr : op.getAttributes()) { 1027 if (!namedAttr.attr.isDerivedAttr()) { 1028 bool emitNotNullCheck = namedAttr.attr.isOptional(); 1029 if (emitNotNullCheck) { 1030 body << formatv(" if ({0}) ", namedAttr.name) << "{\n"; 1031 } 1032 body << formatv(" {0}->addAttribute(\"{1}\", {1});\n", builderOpState, 1033 namedAttr.name); 1034 if (emitNotNullCheck) { 1035 body << " }\n"; 1036 } 1037 } 1038 } 1039 1040 // Create the correct number of regions 1041 if (int numRegions = op.getNumRegions()) { 1042 for (int i = 0; i < numRegions; ++i) 1043 body << " (void)" << builderOpState << "->addRegion();\n"; 1044 } 1045 } 1046 1047 void OpEmitter::genCanonicalizerDecls() { 1048 if (!def.getValueAsBit("hasCanonicalizer")) 1049 return; 1050 1051 const char *const params = 1052 "OwningRewritePatternList &results, MLIRContext *context"; 1053 opClass.newMethod("void", "getCanonicalizationPatterns", params, 1054 OpMethod::MP_Static, /*declOnly=*/true); 1055 } 1056 1057 void OpEmitter::genFolderDecls() { 1058 bool hasSingleResult = op.getNumResults() == 1; 1059 1060 if (def.getValueAsBit("hasFolder")) { 1061 if (hasSingleResult) { 1062 const char *const params = "ArrayRef<Attribute> operands"; 1063 opClass.newMethod("OpFoldResult", "fold", params, OpMethod::MP_None, 1064 /*declOnly=*/true); 1065 } else { 1066 const char *const params = "ArrayRef<Attribute> operands, " 1067 "SmallVectorImpl<OpFoldResult> &results"; 1068 opClass.newMethod("LogicalResult", "fold", params, OpMethod::MP_None, 1069 /*declOnly=*/true); 1070 } 1071 } 1072 } 1073 1074 void OpEmitter::genParser() { 1075 if (!hasStringAttribute(def, "parser")) 1076 return; 1077 1078 auto &method = opClass.newMethod( 1079 "ParseResult", "parse", "OpAsmParser *parser, OperationState *result", 1080 OpMethod::MP_Static); 1081 FmtContext fctx; 1082 fctx.addSubst("cppClass", opClass.getClassName()); 1083 auto parser = def.getValueAsString("parser").ltrim().rtrim(" \t\v\f\r"); 1084 method.body() << " " << tgfmt(parser, &fctx); 1085 } 1086 1087 void OpEmitter::genPrinter() { 1088 auto valueInit = def.getValueInit("printer"); 1089 CodeInit *codeInit = dyn_cast<CodeInit>(valueInit); 1090 if (!codeInit) 1091 return; 1092 1093 auto &method = opClass.newMethod("void", "print", "OpAsmPrinter *p"); 1094 FmtContext fctx; 1095 fctx.addSubst("cppClass", opClass.getClassName()); 1096 auto printer = codeInit->getValue().ltrim().rtrim(" \t\v\f\r"); 1097 method.body() << " " << tgfmt(printer, &fctx); 1098 } 1099 1100 void OpEmitter::genVerifier() { 1101 auto valueInit = def.getValueInit("verifier"); 1102 CodeInit *codeInit = dyn_cast<CodeInit>(valueInit); 1103 bool hasCustomVerify = codeInit && !codeInit->getValue().empty(); 1104 1105 auto &method = opClass.newMethod("LogicalResult", "verify", /*params=*/""); 1106 auto &body = method.body(); 1107 1108 // Populate substitutions for attributes and named operands and results. 1109 for (const auto &namedAttr : op.getAttributes()) 1110 verifyCtx.addSubst(namedAttr.name, 1111 formatv("this->getAttr(\"{0}\")", namedAttr.name)); 1112 for (int i = 0, e = op.getNumOperands(); i < e; ++i) { 1113 auto &value = op.getOperand(i); 1114 // Skip from from first variadic operands for now. Else getOperand index 1115 // used below doesn't match. 1116 if (value.isVariadic()) 1117 break; 1118 if (!value.name.empty()) 1119 verifyCtx.addSubst( 1120 value.name, formatv("(*this->getOperation()->getOperand({0}))", i)); 1121 } 1122 for (int i = 0, e = op.getNumResults(); i < e; ++i) { 1123 auto &value = op.getResult(i); 1124 // Skip from from first variadic results for now. Else getResult index used 1125 // below doesn't match. 1126 if (value.isVariadic()) 1127 break; 1128 if (!value.name.empty()) 1129 verifyCtx.addSubst(value.name, 1130 formatv("(*this->getOperation()->getResult({0}))", i)); 1131 } 1132 1133 // Verify the attributes have the correct type. 1134 for (const auto &namedAttr : op.getAttributes()) { 1135 const auto &attr = namedAttr.attr; 1136 if (attr.isDerivedAttr()) 1137 continue; 1138 1139 auto attrName = namedAttr.name; 1140 // Prefix with `tblgen_` to avoid hiding the attribute accessor. 1141 auto varName = tblgenNamePrefix + attrName; 1142 body << formatv(" auto {0} = this->getAttr(\"{1}\");\n", varName, 1143 attrName); 1144 1145 bool allowMissingAttr = 1146 attr.hasDefaultValueInitializer() || attr.isOptional(); 1147 if (allowMissingAttr) { 1148 // If the attribute has a default value, then only verify the predicate if 1149 // set. This does effectively assume that the default value is valid. 1150 // TODO: verify the debug value is valid (perhaps in debug mode only). 1151 body << " if (" << varName << ") {\n"; 1152 } else { 1153 body << " if (!" << varName 1154 << ") return emitOpError(\"requires attribute '" << attrName 1155 << "'\");\n {\n"; 1156 } 1157 1158 auto attrPred = attr.getPredicate(); 1159 if (!attrPred.isNull()) { 1160 body << tgfmt( 1161 " if (!($0)) return emitOpError(\"attribute '$1' " 1162 "failed to satisfy constraint: $2\");\n", 1163 /*ctx=*/nullptr, 1164 tgfmt(attrPred.getCondition(), &verifyCtx.withSelf(varName)), 1165 attrName, attr.getDescription()); 1166 } 1167 1168 body << " }\n"; 1169 } 1170 1171 genOperandResultVerifier(body, op.getOperands(), "operand"); 1172 genOperandResultVerifier(body, op.getResults(), "result"); 1173 1174 for (auto &trait : op.getTraits()) { 1175 if (auto t = dyn_cast<tblgen::PredOpTrait>(&trait)) { 1176 body << tgfmt(" if (!($0)) {\n " 1177 "return emitOpError(\"failed to verify that $1\");\n }\n", 1178 &verifyCtx, tgfmt(t->getPredTemplate(), &verifyCtx), 1179 t->getDescription()); 1180 } 1181 } 1182 1183 genRegionVerifier(body); 1184 1185 if (hasCustomVerify) { 1186 FmtContext fctx; 1187 fctx.addSubst("cppClass", opClass.getClassName()); 1188 auto printer = codeInit->getValue().ltrim().rtrim(" \t\v\f\r"); 1189 body << " " << tgfmt(printer, &fctx); 1190 } else { 1191 body << " return mlir::success();\n"; 1192 } 1193 } 1194 1195 void OpEmitter::genOperandResultVerifier(OpMethodBody &body, 1196 Operator::value_range values, 1197 StringRef valueKind) { 1198 FmtContext fctx; 1199 1200 body << " {\n"; 1201 body << " unsigned index = 0; (void)index;\n"; 1202 1203 for (auto staticValue : llvm::enumerate(values)) { 1204 if (!staticValue.value().hasPredicate()) 1205 continue; 1206 1207 // Emit a loop to check all the dynamic values in the pack. 1208 body << formatv(" for (Value *v : getODS{0}{1}s({2})) {{\n", 1209 // Capitalize the first letter to match the function name 1210 valueKind.substr(0, 1).upper(), valueKind.substr(1), 1211 staticValue.index()); 1212 1213 auto constraint = staticValue.value().constraint; 1214 1215 body << " (void)v;\n" 1216 << " if (!(" 1217 << tgfmt(constraint.getConditionTemplate(), 1218 &fctx.withSelf("v->getType()")) 1219 << ")) {\n" 1220 << formatv(" return emitOpError(\"{0} #\") << index " 1221 "<< \" must be {1}\";\n", 1222 valueKind, constraint.getDescription()) 1223 << " }\n" // if 1224 << " ++index;\n" 1225 << " }\n"; // for 1226 } 1227 1228 body << " }\n"; 1229 } 1230 1231 void OpEmitter::genRegionVerifier(OpMethodBody &body) { 1232 unsigned numRegions = op.getNumRegions(); 1233 1234 // Verify this op has the correct number of regions 1235 body << formatv( 1236 " if (this->getOperation()->getNumRegions() != {0}) {\n " 1237 "return emitOpError(\"has incorrect number of regions: expected {0} but " 1238 "found \") << this->getOperation()->getNumRegions();\n }\n", 1239 numRegions); 1240 1241 for (unsigned i = 0; i < numRegions; ++i) { 1242 const auto ®ion = op.getRegion(i); 1243 1244 std::string name = formatv("#{0}", i); 1245 if (!region.name.empty()) { 1246 name += formatv(" ('{0}')", region.name); 1247 } 1248 1249 auto getRegion = formatv("this->getOperation()->getRegion({0})", i).str(); 1250 auto constraint = tgfmt(region.constraint.getConditionTemplate(), 1251 &verifyCtx.withSelf(getRegion)) 1252 .str(); 1253 1254 body << formatv(" if (!({0})) {\n " 1255 "return emitOpError(\"region {1} failed to verify " 1256 "constraint: {2}\");\n }\n", 1257 constraint, name, region.constraint.getDescription()); 1258 } 1259 } 1260 1261 void OpEmitter::genTraits() { 1262 int numResults = op.getNumResults(); 1263 int numVariadicResults = op.getNumVariadicResults(); 1264 1265 // Add return size trait. 1266 if (numVariadicResults != 0) { 1267 if (numResults == numVariadicResults) 1268 opClass.addTrait("OpTrait::VariadicResults"); 1269 else 1270 opClass.addTrait("OpTrait::AtLeastNResults<" + 1271 Twine(numResults - numVariadicResults) + ">::Impl"); 1272 } else { 1273 switch (numResults) { 1274 case 0: 1275 opClass.addTrait("OpTrait::ZeroResult"); 1276 break; 1277 case 1: 1278 opClass.addTrait("OpTrait::OneResult"); 1279 break; 1280 default: 1281 opClass.addTrait("OpTrait::NResults<" + Twine(numResults) + ">::Impl"); 1282 break; 1283 } 1284 } 1285 1286 for (const auto &trait : op.getTraits()) { 1287 if (auto opTrait = dyn_cast<tblgen::NativeOpTrait>(&trait)) 1288 opClass.addTrait(opTrait->getTrait()); 1289 } 1290 1291 // Add variadic size trait and normal op traits. 1292 int numOperands = op.getNumOperands(); 1293 int numVariadicOperands = op.getNumVariadicOperands(); 1294 1295 // Add operand size trait. 1296 if (numVariadicOperands != 0) { 1297 if (numOperands == numVariadicOperands) 1298 opClass.addTrait("OpTrait::VariadicOperands"); 1299 else 1300 opClass.addTrait("OpTrait::AtLeastNOperands<" + 1301 Twine(numOperands - numVariadicOperands) + ">::Impl"); 1302 } else { 1303 switch (numOperands) { 1304 case 0: 1305 opClass.addTrait("OpTrait::ZeroOperands"); 1306 break; 1307 case 1: 1308 opClass.addTrait("OpTrait::OneOperand"); 1309 break; 1310 default: 1311 opClass.addTrait("OpTrait::NOperands<" + Twine(numOperands) + ">::Impl"); 1312 break; 1313 } 1314 } 1315 } 1316 1317 void OpEmitter::genOpNameGetter() { 1318 auto &method = opClass.newMethod("StringRef", "getOperationName", 1319 /*params=*/"", OpMethod::MP_Static); 1320 method.body() << " return \"" << op.getOperationName() << "\";\n"; 1321 } 1322 1323 //===----------------------------------------------------------------------===// 1324 // OpOperandAdaptor emitter 1325 //===----------------------------------------------------------------------===// 1326 1327 namespace { 1328 // Helper class to emit Op operand adaptors to an output stream. Operand 1329 // adaptors are wrappers around ArrayRef<Value *> that provide named operand 1330 // getters identical to those defined in the Op. 1331 class OpOperandAdaptorEmitter { 1332 public: 1333 static void emitDecl(const Operator &op, raw_ostream &os); 1334 static void emitDef(const Operator &op, raw_ostream &os); 1335 1336 private: 1337 explicit OpOperandAdaptorEmitter(const Operator &op); 1338 1339 Class adapterClass; 1340 }; 1341 } // end namespace 1342 1343 OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(const Operator &op) 1344 : adapterClass(op.getCppClassName().str() + "OperandAdaptor") { 1345 adapterClass.newField("ArrayRef<Value *>", "tblgen_operands"); 1346 auto &constructor = adapterClass.newConstructor("ArrayRef<Value *> values"); 1347 constructor.body() << " tblgen_operands = values;\n"; 1348 1349 generateNamedOperandGetters(op, adapterClass, 1350 /*rangeType=*/"ArrayRef<Value *>", 1351 /*rangeBeginCall=*/"tblgen_operands.begin()", 1352 /*rangeSizeCall=*/"tblgen_operands.size()", 1353 /*getOperandCallPattern=*/"tblgen_operands[{0}]"); 1354 } 1355 1356 void OpOperandAdaptorEmitter::emitDecl(const Operator &op, raw_ostream &os) { 1357 OpOperandAdaptorEmitter(op).adapterClass.writeDeclTo(os); 1358 } 1359 1360 void OpOperandAdaptorEmitter::emitDef(const Operator &op, raw_ostream &os) { 1361 OpOperandAdaptorEmitter(op).adapterClass.writeDefTo(os); 1362 } 1363 1364 // Emits the opcode enum and op classes. 1365 static void emitOpClasses(const std::vector<Record *> &defs, raw_ostream &os, 1366 bool emitDecl) { 1367 IfDefScope scope("GET_OP_CLASSES", os); 1368 // First emit forward declaration for each class, this allows them to refer 1369 // to each others in traits for example. 1370 if (emitDecl) { 1371 for (auto *def : defs) { 1372 Operator op(*def); 1373 os << "class " << op.getCppClassName() << ";\n"; 1374 } 1375 } 1376 for (auto *def : defs) { 1377 Operator op(*def); 1378 if (emitDecl) { 1379 os << formatv(opCommentHeader, op.getQualCppClassName(), "declarations"); 1380 OpOperandAdaptorEmitter::emitDecl(op, os); 1381 OpEmitter::emitDecl(op, os); 1382 } else { 1383 os << formatv(opCommentHeader, op.getQualCppClassName(), "definitions"); 1384 OpOperandAdaptorEmitter::emitDef(op, os); 1385 OpEmitter::emitDef(op, os); 1386 } 1387 } 1388 } 1389 1390 // Emits a comma-separated list of the ops. 1391 static void emitOpList(const std::vector<Record *> &defs, raw_ostream &os) { 1392 IfDefScope scope("GET_OP_LIST", os); 1393 1394 interleave( 1395 // TODO: We are constructing the Operator wrapper instance just for 1396 // getting it's qualified class name here. Reduce the overhead by having a 1397 // lightweight version of Operator class just for that purpose. 1398 defs, [&os](Record *def) { os << Operator(def).getQualCppClassName(); }, 1399 [&os]() { os << ",\n"; }); 1400 } 1401 1402 static bool emitOpDecls(const RecordKeeper &recordKeeper, raw_ostream &os) { 1403 emitSourceFileHeader("Op Declarations", os); 1404 1405 const auto &defs = recordKeeper.getAllDerivedDefinitions("Op"); 1406 emitOpClasses(defs, os, /*emitDecl=*/true); 1407 1408 return false; 1409 } 1410 1411 static bool emitOpDefs(const RecordKeeper &recordKeeper, raw_ostream &os) { 1412 emitSourceFileHeader("Op Definitions", os); 1413 1414 const auto &defs = recordKeeper.getAllDerivedDefinitions("Op"); 1415 emitOpList(defs, os); 1416 emitOpClasses(defs, os, /*emitDecl=*/false); 1417 1418 return false; 1419 } 1420 1421 static mlir::GenRegistration 1422 genOpDecls("gen-op-decls", "Generate op declarations", 1423 [](const RecordKeeper &records, raw_ostream &os) { 1424 return emitOpDecls(records, os); 1425 }); 1426 1427 static mlir::GenRegistration genOpDefs("gen-op-defs", "Generate op definitions", 1428 [](const RecordKeeper &records, 1429 raw_ostream &os) { 1430 return emitOpDefs(records, os); 1431 });