github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp (about)

     1  //===- OpDefinitionsGen.cpp - MLIR op definitions generator ---------------===//
     2  //
     3  // Copyright 2019 The MLIR Authors.
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //   http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  // =============================================================================
    17  //
    18  // OpDefinitionsGen uses the description of operations to generate C++
    19  // definitions for ops.
    20  //
    21  //===----------------------------------------------------------------------===//
    22  
    23  #include "mlir/Support/STLExtras.h"
    24  #include "mlir/TableGen/Format.h"
    25  #include "mlir/TableGen/GenInfo.h"
    26  #include "mlir/TableGen/OpTrait.h"
    27  #include "mlir/TableGen/Operator.h"
    28  #include "llvm/ADT/StringExtras.h"
    29  #include "llvm/Support/Signals.h"
    30  #include "llvm/TableGen/Error.h"
    31  #include "llvm/TableGen/Record.h"
    32  #include "llvm/TableGen/TableGenBackend.h"
    33  
    34  using namespace llvm;
    35  using namespace mlir;
    36  using namespace mlir::tblgen;
    37  
    38  static const char *const tblgenNamePrefix = "tblgen_";
    39  static const char *const generatedArgName = "tblgen_arg";
    40  static const char *const builderOpState = "tblgen_state";
    41  
    42  // The logic to calculate the dynamic value range for an static operand/result
    43  // of an op with variadic operands/results. Note that this logic is not for
    44  // general use; it assumes all variadic operands/results must have the same
    45  // number of values.
    46  //
    47  // {0}: The list of whether each static operand/result is variadic.
    48  // {1}: The total number of non-variadic operands/results.
    49  // {2}: The total number of variadic operands/results.
    50  // {3}: The total number of dynamic values.
    51  // {4}: The begin iterator of the dynamic values.
    52  // {5}: "operand" or "result"
    53  const char *valueRangeCalcCode = R"(
    54    bool isVariadic[] = {{{0}};
    55    int prevVariadicCount = 0;
    56    for (unsigned i = 0; i < index; ++i)
    57      if (isVariadic[i]) ++prevVariadicCount;
    58  
    59    // Calculate how many dynamic values a static variadic {5} corresponds to.
    60    // This assumes all static variadic {5}s have the same dynamic value count.
    61    int variadicSize = ({3} - {1}) / {2};
    62    // `index` passed in as the parameter is the static index which counts each
    63    // {5} (variadic or not) as size 1. So here for each previous static variadic
    64    // {5}, we need to offset by (variadicSize - 1) to get where the dynamic
    65    // value pack for this static {5} starts.
    66    int offset = index + (variadicSize - 1) * prevVariadicCount;
    67    int size = isVariadic[index] ? variadicSize : 1;
    68  
    69    return {{std::next({4}, offset), std::next({4}, offset + size)};
    70  )";
    71  
    72  static const char *const opCommentHeader = R"(
    73  //===----------------------------------------------------------------------===//
    74  // {0} {1}
    75  //===----------------------------------------------------------------------===//
    76  
    77  )";
    78  
    79  //===----------------------------------------------------------------------===//
    80  // Utility structs and functions
    81  //===----------------------------------------------------------------------===//
    82  
    83  // Returns whether the record has a value of the given name that can be returned
    84  // via getValueAsString.
    85  static inline bool hasStringAttribute(const Record &record,
    86                                        StringRef fieldName) {
    87    auto valueInit = record.getValueInit(fieldName);
    88    return isa<CodeInit>(valueInit) || isa<StringInit>(valueInit);
    89  }
    90  
    91  static std::string getArgumentName(const Operator &op, int index) {
    92    const auto &operand = op.getOperand(index);
    93    if (!operand.name.empty())
    94      return operand.name;
    95    else
    96      return formatv("{0}_{1}", generatedArgName, index);
    97  }
    98  
    99  namespace {
   100  // Simple RAII helper for defining ifdef-undef-endif scopes.
   101  class IfDefScope {
   102  public:
   103    IfDefScope(StringRef name, raw_ostream &os) : name(name), os(os) {
   104      os << "#ifdef " << name << "\n"
   105         << "#undef " << name << "\n\n";
   106    }
   107  
   108    ~IfDefScope() { os << "\n#endif  // " << name << "\n\n"; }
   109  
   110  private:
   111    StringRef name;
   112    raw_ostream &os;
   113  };
   114  } // end anonymous namespace
   115  
   116  //===----------------------------------------------------------------------===//
   117  // Classes for C++ code emission
   118  //===----------------------------------------------------------------------===//
   119  
   120  // We emit the op declaration and definition into separate files: *Ops.h.inc
   121  // and *Ops.cpp.inc. The former is to be included in the dialect *Ops.h and
   122  // the latter for dialect *Ops.cpp. This way provides a cleaner interface.
   123  //
   124  // In order to do this split, we need to track method signature and
   125  // implementation logic separately. Signature information is used for both
   126  // declaration and definition, while implementation logic is only for
   127  // definition. So we have the following classes for C++ code emission.
   128  
   129  namespace {
   130  // Class for holding the signature of an op's method for C++ code emission
   131  class OpMethodSignature {
   132  public:
   133    OpMethodSignature(StringRef retType, StringRef name, StringRef params);
   134  
   135    // Writes the signature as a method declaration to the given `os`.
   136    void writeDeclTo(raw_ostream &os) const;
   137    // Writes the signature as the start of a method definition to the given `os`.
   138    // `namePrefix` is the prefix to be prepended to the method name (typically
   139    // namespaces for qualifying the method definition).
   140    void writeDefTo(raw_ostream &os, StringRef namePrefix) const;
   141  
   142  private:
   143    // Returns true if the given C++ `type` ends with '&' or '*', or is empty.
   144    static bool elideSpaceAfterType(StringRef type);
   145  
   146    std::string returnType;
   147    std::string methodName;
   148    std::string parameters;
   149  };
   150  
   151  // Class for holding the body of an op's method for C++ code emission
   152  class OpMethodBody {
   153  public:
   154    explicit OpMethodBody(bool declOnly);
   155  
   156    OpMethodBody &operator<<(Twine content);
   157    OpMethodBody &operator<<(int content);
   158    OpMethodBody &operator<<(const FmtObjectBase &content);
   159  
   160    void writeTo(raw_ostream &os) const;
   161  
   162  private:
   163    // Whether this class should record method body.
   164    bool isEffective;
   165    std::string body;
   166  };
   167  
   168  // Class for holding an op's method for C++ code emission
   169  class OpMethod {
   170  public:
   171    // Properties (qualifiers) of class methods. Bitfield is used here to help
   172    // querying properties.
   173    enum Property {
   174      MP_None = 0x0,
   175      MP_Static = 0x1,      // Static method
   176      MP_Constructor = 0x2, // Constructor
   177      MP_Private = 0x4,     // Private method
   178    };
   179  
   180    OpMethod(StringRef retType, StringRef name, StringRef params,
   181             Property property, bool declOnly);
   182  
   183    OpMethodBody &body();
   184  
   185    // Returns true if this is a static method.
   186    bool isStatic() const;
   187  
   188    // Returns true if this is a private method.
   189    bool isPrivate() const;
   190  
   191    // Writes the method as a declaration to the given `os`.
   192    void writeDeclTo(raw_ostream &os) const;
   193    // Writes the method as a definition to the given `os`. `namePrefix` is the
   194    // prefix to be prepended to the method name (typically namespaces for
   195    // qualifying the method definition).
   196    void writeDefTo(raw_ostream &os, StringRef namePrefix) const;
   197  
   198  private:
   199    Property properties;
   200    // Whether this method only contains a declaration.
   201    bool isDeclOnly;
   202    OpMethodSignature methodSignature;
   203    OpMethodBody methodBody;
   204  };
   205  
   206  // A class used to emit C++ classes from Tablegen.  Contains a list of public
   207  // methods and a list of private fields to be emitted.
   208  class Class {
   209  public:
   210    explicit Class(StringRef name);
   211  
   212    // Creates a new method in this class.
   213    OpMethod &newMethod(StringRef retType, StringRef name, StringRef params = "",
   214                        OpMethod::Property = OpMethod::MP_None,
   215                        bool declOnly = false);
   216  
   217    OpMethod &newConstructor(StringRef params = "", bool declOnly = false);
   218  
   219    // Creates a new field in this class.
   220    void newField(StringRef type, StringRef name, StringRef defaultValue = "");
   221  
   222    // Writes this op's class as a declaration to the given `os`.
   223    void writeDeclTo(raw_ostream &os) const;
   224    // Writes the method definitions in this op's class to the given `os`.
   225    void writeDefTo(raw_ostream &os) const;
   226  
   227    // Returns the C++ class name of the op.
   228    StringRef getClassName() const { return className; }
   229  
   230  protected:
   231    std::string className;
   232    SmallVector<OpMethod, 8> methods;
   233    SmallVector<std::string, 4> fields;
   234  };
   235  
   236  // Class for holding an op for C++ code emission
   237  class OpClass : public Class {
   238  public:
   239    explicit OpClass(StringRef name, StringRef extraClassDeclaration = "");
   240  
   241    // Adds an op trait.
   242    void addTrait(Twine trait);
   243  
   244    // Writes this op's class as a declaration to the given `os`.  Redefines
   245    // Class::writeDeclTo to also emit traits and extra class declarations.
   246    void writeDeclTo(raw_ostream &os) const;
   247  
   248  private:
   249    StringRef extraClassDeclaration;
   250    SmallVector<std::string, 4> traits;
   251  };
   252  } // end anonymous namespace
   253  
   254  OpMethodSignature::OpMethodSignature(StringRef retType, StringRef name,
   255                                       StringRef params)
   256      : returnType(retType), methodName(name), parameters(params) {}
   257  
   258  void OpMethodSignature::writeDeclTo(raw_ostream &os) const {
   259    os << returnType << (elideSpaceAfterType(returnType) ? "" : " ") << methodName
   260       << "(" << parameters << ")";
   261  }
   262  
   263  void OpMethodSignature::writeDefTo(raw_ostream &os,
   264                                     StringRef namePrefix) const {
   265    // We need to remove the default values for parameters in method definition.
   266    // TODO(antiagainst): We are using '=' and ',' as delimiters for parameter
   267    // initializers. This is incorrect for initializer list with more than one
   268    // element. Change to a more robust approach.
   269    auto removeParamDefaultValue = [](StringRef params) {
   270      std::string result;
   271      std::pair<StringRef, StringRef> parts;
   272      while (!params.empty()) {
   273        parts = params.split("=");
   274        result.append(result.empty() ? "" : ", ");
   275        result.append(parts.first);
   276        params = parts.second.split(",").second;
   277      }
   278      return result;
   279    };
   280  
   281    os << returnType << (elideSpaceAfterType(returnType) ? "" : " ") << namePrefix
   282       << (namePrefix.empty() ? "" : "::") << methodName << "("
   283       << removeParamDefaultValue(parameters) << ")";
   284  }
   285  
   286  bool OpMethodSignature::elideSpaceAfterType(StringRef type) {
   287    return type.empty() || type.endswith("&") || type.endswith("*");
   288  }
   289  
   290  OpMethodBody::OpMethodBody(bool declOnly) : isEffective(!declOnly) {}
   291  
   292  OpMethodBody &OpMethodBody::operator<<(Twine content) {
   293    if (isEffective)
   294      body.append(content.str());
   295    return *this;
   296  }
   297  
   298  OpMethodBody &OpMethodBody::operator<<(int content) {
   299    if (isEffective)
   300      body.append(std::to_string(content));
   301    return *this;
   302  }
   303  
   304  OpMethodBody &OpMethodBody::operator<<(const FmtObjectBase &content) {
   305    if (isEffective)
   306      body.append(content.str());
   307    return *this;
   308  }
   309  
   310  void OpMethodBody::writeTo(raw_ostream &os) const {
   311    auto bodyRef = StringRef(body).drop_while([](char c) { return c == '\n'; });
   312    os << bodyRef;
   313    if (bodyRef.empty() || bodyRef.back() != '\n')
   314      os << "\n";
   315  }
   316  
   317  OpMethod::OpMethod(StringRef retType, StringRef name, StringRef params,
   318                     OpMethod::Property property, bool declOnly)
   319      : properties(property), isDeclOnly(declOnly),
   320        methodSignature(retType, name, params), methodBody(declOnly) {}
   321  
   322  OpMethodBody &OpMethod::body() { return methodBody; }
   323  
   324  bool OpMethod::isStatic() const { return properties & MP_Static; }
   325  
   326  bool OpMethod::isPrivate() const { return properties & MP_Private; }
   327  
   328  void OpMethod::writeDeclTo(raw_ostream &os) const {
   329    os.indent(2);
   330    if (isStatic())
   331      os << "static ";
   332    methodSignature.writeDeclTo(os);
   333    os << ";";
   334  }
   335  
   336  void OpMethod::writeDefTo(raw_ostream &os, StringRef namePrefix) const {
   337    if (isDeclOnly)
   338      return;
   339  
   340    methodSignature.writeDefTo(os, namePrefix);
   341    os << " {\n";
   342    methodBody.writeTo(os);
   343    os << "}";
   344  }
   345  
   346  Class::Class(StringRef name) : className(name) {}
   347  
   348  OpMethod &Class::newMethod(StringRef retType, StringRef name, StringRef params,
   349                             OpMethod::Property property, bool declOnly) {
   350    methods.emplace_back(retType, name, params, property, declOnly);
   351    return methods.back();
   352  }
   353  
   354  OpMethod &Class::newConstructor(StringRef params, bool declOnly) {
   355    return newMethod("", getClassName(), params, OpMethod::MP_Constructor,
   356                     declOnly);
   357  }
   358  
   359  void Class::newField(StringRef type, StringRef name, StringRef defaultValue) {
   360    std::string varName = formatv("{0} {1}", type, name).str();
   361    std::string field = defaultValue.empty()
   362                            ? varName
   363                            : formatv("{0} = {1}", varName, defaultValue).str();
   364    fields.push_back(std::move(field));
   365  }
   366  
   367  void Class::writeDeclTo(raw_ostream &os) const {
   368    bool hasPrivateMethod = false;
   369    os << "class " << className << " {\n";
   370    os << "public:\n";
   371    for (const auto &method : methods) {
   372      if (!method.isPrivate()) {
   373        method.writeDeclTo(os);
   374        os << '\n';
   375      } else {
   376        hasPrivateMethod = true;
   377      }
   378    }
   379    os << '\n';
   380    os << "private:\n";
   381    if (hasPrivateMethod) {
   382      for (const auto &method : methods) {
   383        if (method.isPrivate()) {
   384          method.writeDeclTo(os);
   385          os << '\n';
   386        }
   387      }
   388      os << '\n';
   389    }
   390    for (const auto &field : fields)
   391      os.indent(2) << field << ";\n";
   392    os << "};\n";
   393  }
   394  
   395  void Class::writeDefTo(raw_ostream &os) const {
   396    for (const auto &method : methods) {
   397      method.writeDefTo(os, className);
   398      os << "\n\n";
   399    }
   400  }
   401  
   402  OpClass::OpClass(StringRef name, StringRef extraClassDeclaration)
   403      : Class(name), extraClassDeclaration(extraClassDeclaration) {}
   404  
   405  // Adds the given trait to this op.
   406  void OpClass::addTrait(Twine trait) { traits.push_back(trait.str()); }
   407  
   408  void OpClass::writeDeclTo(raw_ostream &os) const {
   409    os << "class " << className << " : public Op<" << className;
   410    for (const auto &trait : traits)
   411      os << ", " << trait;
   412    os << "> {\npublic:\n";
   413    os << "  using Op::Op;\n";
   414    os << "  using OperandAdaptor = " << className << "OperandAdaptor;\n";
   415  
   416    bool hasPrivateMethod = false;
   417    for (const auto &method : methods) {
   418      if (!method.isPrivate()) {
   419        method.writeDeclTo(os);
   420        os << "\n";
   421      } else {
   422        hasPrivateMethod = true;
   423      }
   424    }
   425  
   426    // TODO: Add line control markers to make errors easier to debug.
   427    if (!extraClassDeclaration.empty())
   428      os << extraClassDeclaration << "\n";
   429  
   430    if (hasPrivateMethod) {
   431      os << '\n';
   432      os << "private:\n";
   433      for (const auto &method : methods) {
   434        if (method.isPrivate()) {
   435          method.writeDeclTo(os);
   436          os << "\n";
   437        }
   438      }
   439    }
   440  
   441    os << "};\n";
   442  }
   443  
   444  //===----------------------------------------------------------------------===//
   445  // Op emitter
   446  //===----------------------------------------------------------------------===//
   447  
   448  namespace {
   449  // Helper class to emit a record into the given output stream.
   450  class OpEmitter {
   451  public:
   452    static void emitDecl(const Operator &op, raw_ostream &os);
   453    static void emitDef(const Operator &op, raw_ostream &os);
   454  
   455  private:
   456    OpEmitter(const Operator &op);
   457  
   458    void emitDecl(raw_ostream &os);
   459    void emitDef(raw_ostream &os);
   460  
   461    // Generates the `getOperationName` method for this op.
   462    void genOpNameGetter();
   463  
   464    // Generates getters for the attributes.
   465    void genAttrGetters();
   466  
   467    // Generates getters for named operands.
   468    void genNamedOperandGetters();
   469  
   470    // Generates getters for named results.
   471    void genNamedResultGetters();
   472  
   473    // Generates getters for named regions.
   474    void genNamedRegionGetters();
   475  
   476    // Generates builder methods for the operation.
   477    void genBuilder();
   478  
   479    // Generates the build() method that takes each result-type/operand/attribute
   480    // as a stand-alone parameter. This build() method also requires specifying
   481    // result types for all results.
   482    void genSeparateParamBuilder();
   483  
   484    // Generates the build() method that takes a single parameter for all the
   485    // result types and a separate parameter for each operand/attribute.
   486    void genCollectiveTypeParamBuilder();
   487  
   488    // Generates the build() method that takes each operand/attribute as a
   489    // stand-alone parameter. This build() method uses first operand's type
   490    // as all result's types.
   491    void genUseOperandAsResultTypeBuilder();
   492  
   493    // Generates the build() method that takes each operand/attribute as a
   494    // stand-alone parameter. This build() method uses first attribute's type
   495    // as all result's types.
   496    void genUseAttrAsResultTypeBuilder();
   497  
   498    // Generates the build() method that takes all result types collectively as
   499    // one parameter. Similarly for operands and attributes.
   500    void genCollectiveParamBuilder();
   501  
   502    enum class TypeParamKind { None, Separate, Collective };
   503  
   504    // Builds the parameter list for build() method of this op. This method writes
   505    // to `paramList` the comma-separated parameter list. If `includeResultTypes`
   506    // is true then `paramList` will also contain the parameters for all results
   507    // and `resultTypeNames` will be populated with the parameter name for each
   508    // result type.
   509    void buildParamList(std::string &paramList,
   510                        SmallVectorImpl<std::string> &resultTypeNames,
   511                        TypeParamKind kind);
   512  
   513    // Adds op arguments and regions into operation state for build() methods.
   514    void genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body);
   515  
   516    // Generates canonicalizer declaration for the operation.
   517    void genCanonicalizerDecls();
   518  
   519    // Generates the folder declaration for the operation.
   520    void genFolderDecls();
   521  
   522    // Generates the parser for the operation.
   523    void genParser();
   524  
   525    // Generates the printer for the operation.
   526    void genPrinter();
   527  
   528    // Generates verify method for the operation.
   529    void genVerifier();
   530  
   531    // Generates verify statements for operands and results in the operation.
   532    // The generated code will be attached to `body`.
   533    void genOperandResultVerifier(OpMethodBody &body,
   534                                  Operator::value_range values,
   535                                  StringRef valueKind);
   536  
   537    // Generates verify statements for regions in the operation.
   538    // The generated code will be attached to `body`.
   539    void genRegionVerifier(OpMethodBody &body);
   540  
   541    // Generates the traits used by the object.
   542    void genTraits();
   543  
   544  private:
   545    // The TableGen record for this op.
   546    // TODO(antiagainst,zinenko): OpEmitter should not have a Record directly,
   547    // it should rather go through the Operator for better abstraction.
   548    const Record &def;
   549  
   550    // The wrapper operator class for querying information from this op.
   551    Operator op;
   552  
   553    // The C++ code builder for this op
   554    OpClass opClass;
   555  
   556    // The format context for verification code generation.
   557    FmtContext verifyCtx;
   558  };
   559  } // end anonymous namespace
   560  
   561  OpEmitter::OpEmitter(const Operator &op)
   562      : def(op.getDef()), op(op),
   563        opClass(op.getCppClassName(), op.getExtraClassDeclaration()) {
   564    verifyCtx.withOp("(*this->getOperation())");
   565  
   566    genTraits();
   567    // Generate C++ code for various op methods. The order here determines the
   568    // methods in the generated file.
   569    genOpNameGetter();
   570    genNamedOperandGetters();
   571    genNamedResultGetters();
   572    genNamedRegionGetters();
   573    genAttrGetters();
   574    genBuilder();
   575    genParser();
   576    genPrinter();
   577    genVerifier();
   578    genCanonicalizerDecls();
   579    genFolderDecls();
   580  }
   581  
   582  void OpEmitter::emitDecl(const Operator &op, raw_ostream &os) {
   583    OpEmitter(op).emitDecl(os);
   584  }
   585  
   586  void OpEmitter::emitDef(const Operator &op, raw_ostream &os) {
   587    OpEmitter(op).emitDef(os);
   588  }
   589  
   590  void OpEmitter::emitDecl(raw_ostream &os) { opClass.writeDeclTo(os); }
   591  
   592  void OpEmitter::emitDef(raw_ostream &os) { opClass.writeDefTo(os); }
   593  
   594  void OpEmitter::genAttrGetters() {
   595    FmtContext fctx;
   596    fctx.withBuilder("mlir::Builder(this->getContext())");
   597    for (auto &namedAttr : op.getAttributes()) {
   598      const auto &name = namedAttr.name;
   599      const auto &attr = namedAttr.attr;
   600  
   601      auto &method = opClass.newMethod(attr.getReturnType(), name);
   602      auto &body = method.body();
   603  
   604      // Emit the derived attribute body.
   605      if (attr.isDerivedAttr()) {
   606        body << "  " << attr.getDerivedCodeBody() << "\n";
   607        continue;
   608      }
   609  
   610      // Emit normal emitter.
   611  
   612      // Return the queried attribute with the correct return type.
   613      auto attrVal =
   614          (attr.hasDefaultValueInitializer() || attr.isOptional())
   615              ? formatv("this->getAttr(\"{0}\").dyn_cast_or_null<{1}>()", name,
   616                        attr.getStorageType())
   617              : formatv("this->getAttr(\"{0}\").cast<{1}>()", name,
   618                        attr.getStorageType());
   619      body << "  auto attr = " << attrVal << ";\n";
   620      if (attr.hasDefaultValueInitializer()) {
   621        // Returns the default value if not set.
   622        // TODO: this is inefficient, we are recreating the attribute for every
   623        // call. This should be set instead.
   624        std::string defaultValue = tgfmt(attr.getConstBuilderTemplate(), &fctx,
   625                                         attr.getDefaultValueInitializer());
   626        body << "    if (!attr)\n      return "
   627             << tgfmt(attr.getConvertFromStorageCall(),
   628                      &fctx.withSelf(defaultValue))
   629             << ";\n";
   630      }
   631      body << "  return "
   632           << tgfmt(attr.getConvertFromStorageCall(), &fctx.withSelf("attr"))
   633           << ";\n";
   634    }
   635  }
   636  
   637  // Generates the named operand getter methods for the given Operator `op` and
   638  // puts them in `opClass`.  Uses `rangeType` as the return type of getters that
   639  // return a range of operands (individual operands are `Value *` and each
   640  // element in the range must also be `Value *`); use `rangeBeginCall` to get an
   641  // iterator to the beginning of the operand range; use `rangeSizeCall` to obtain
   642  // the number of operands. `getOperandCallPattern` contains the code necessary
   643  // to obtain a single operand whose position will be substituted instead of
   644  // "{0}" marker in the pattern.  Note that the pattern should work for any kind
   645  // of ops, in particular for one-operand ops that may not have the
   646  // `getOperand(unsigned)` method.
   647  static void generateNamedOperandGetters(const Operator &op, Class &opClass,
   648                                          StringRef rangeType,
   649                                          StringRef rangeBeginCall,
   650                                          StringRef rangeSizeCall,
   651                                          StringRef getOperandCallPattern) {
   652    const int numOperands = op.getNumOperands();
   653    const int numVariadicOperands = op.getNumVariadicOperands();
   654    const int numNormalOperands = numOperands - numVariadicOperands;
   655  
   656    if (numVariadicOperands > 1 &&
   657        !op.hasTrait("OpTrait::SameVariadicOperandSize")) {
   658      PrintFatalError(op.getLoc(), "op has multiple variadic operands but no "
   659                                   "specification over their sizes");
   660    }
   661  
   662    // First emit a "sink" getter method upon which we layer all nicer named
   663    // getter methods.
   664    auto &m = opClass.newMethod(rangeType, "getODSOperands", "unsigned index");
   665  
   666    if (numVariadicOperands == 0) {
   667      // We still need to match the return type, which is a range.
   668      m.body() << "return {std::next(" << rangeBeginCall << ", index), std::next("
   669               << rangeBeginCall << ", index + 1)};";
   670    } else {
   671      // Because the op can have arbitrarily interleaved variadic and non-variadic
   672      // operands, we need to embed a list in the "sink" getter method for
   673      // calculation at run-time.
   674      llvm::SmallVector<StringRef, 4> isVariadic;
   675      isVariadic.reserve(numOperands);
   676      for (int i = 0; i < numOperands; ++i) {
   677        isVariadic.push_back(llvm::toStringRef(op.getOperand(i).isVariadic()));
   678      }
   679      std::string isVariadicList = llvm::join(isVariadic, ", ");
   680  
   681      m.body() << formatv(valueRangeCalcCode, isVariadicList, numNormalOperands,
   682                          numVariadicOperands, rangeSizeCall, rangeBeginCall,
   683                          "operand");
   684    }
   685  
   686    // Then we emit nicer named getter methods by redirecting to the "sink" getter
   687    // method.
   688  
   689    for (int i = 0; i != numOperands; ++i) {
   690      const auto &operand = op.getOperand(i);
   691      if (operand.name.empty())
   692        continue;
   693  
   694      if (operand.isVariadic()) {
   695        auto &m = opClass.newMethod(rangeType, operand.name);
   696        m.body() << "return getODSOperands(" << i << ");";
   697      } else {
   698        auto &m = opClass.newMethod("Value *", operand.name);
   699        m.body() << "return *getODSOperands(" << i << ").begin();";
   700      }
   701    }
   702  }
   703  
   704  void OpEmitter::genNamedOperandGetters() {
   705    generateNamedOperandGetters(
   706        op, opClass, /*rangeType=*/"Operation::operand_range",
   707        /*rangeBeginCall=*/"getOperation()->operand_begin()",
   708        /*rangeSizeCall=*/"getOperation()->getNumOperands()",
   709        /*getOperandCallPattern=*/"getOperation()->getOperand({0})");
   710  }
   711  
   712  void OpEmitter::genNamedResultGetters() {
   713    const int numResults = op.getNumResults();
   714    const int numVariadicResults = op.getNumVariadicResults();
   715    const int numNormalResults = numResults - numVariadicResults;
   716  
   717    // If we have more than one variadic results, we need more complicated logic
   718    // to calculate the value range for each result.
   719  
   720    if (numVariadicResults > 1 &&
   721        !op.hasTrait("OpTrait::SameVariadicResultSize")) {
   722      PrintFatalError(op.getLoc(), "op has multiple variadic results but no "
   723                                   "specification over their sizes");
   724    }
   725  
   726    auto &m = opClass.newMethod("Operation::result_range", "getODSResults",
   727                                "unsigned index");
   728  
   729    if (numVariadicResults == 0) {
   730      m.body() << "return {std::next(getOperation()->result_begin(), index), "
   731                  "std::next(getOperation()->result_begin(), index + 1)};";
   732    } else {
   733      llvm::SmallVector<StringRef, 4> isVariadic;
   734      isVariadic.reserve(numResults);
   735      for (int i = 0; i < numResults; ++i) {
   736        isVariadic.push_back(llvm::toStringRef(op.getResult(i).isVariadic()));
   737      }
   738      std::string isVariadicList = llvm::join(isVariadic, ", ");
   739  
   740      m.body() << formatv(valueRangeCalcCode, isVariadicList, numNormalResults,
   741                          numVariadicResults, "getOperation()->getNumResults()",
   742                          "getOperation()->result_begin()", "result");
   743    }
   744  
   745    for (int i = 0; i != numResults; ++i) {
   746      const auto &result = op.getResult(i);
   747      if (result.name.empty())
   748        continue;
   749  
   750      if (result.isVariadic()) {
   751        auto &m = opClass.newMethod("Operation::result_range", result.name);
   752        m.body() << "return getODSResults(" << i << ");";
   753      } else {
   754        auto &m = opClass.newMethod("Value *", result.name);
   755        m.body() << "return *getODSResults(" << i << ").begin();";
   756      }
   757    }
   758  }
   759  
   760  void OpEmitter::genNamedRegionGetters() {
   761    unsigned numRegions = op.getNumRegions();
   762    for (unsigned i = 0; i < numRegions; ++i) {
   763      const auto &region = op.getRegion(i);
   764      if (!region.name.empty()) {
   765        auto &m = opClass.newMethod("Region &", region.name);
   766        m.body() << formatv("return this->getOperation()->getRegion({0});", i);
   767      }
   768    }
   769  }
   770  
   771  void OpEmitter::genSeparateParamBuilder() {
   772    std::string paramList;
   773    llvm::SmallVector<std::string, 4> resultNames;
   774    buildParamList(paramList, resultNames, TypeParamKind::Separate);
   775  
   776    auto &m = opClass.newMethod("void", "build", paramList, OpMethod::MP_Static);
   777    genCodeForAddingArgAndRegionForBuilder(m.body());
   778  
   779    // Push all result types to the operation state
   780    for (int i = 0, e = op.getNumResults(); i < e; ++i) {
   781      m.body() << "  " << builderOpState << "->addTypes(" << resultNames[i]
   782               << ");\n";
   783    }
   784  }
   785  
   786  void OpEmitter::genCollectiveTypeParamBuilder() {
   787    auto numResults = op.getNumResults();
   788  
   789    // If this op has no results, then just skip generating this builder.
   790    // Otherwise we are generating the same signature as the separate-parameter
   791    // builder.
   792    if (numResults == 0)
   793      return;
   794  
   795    // Similarly for ops with one single variadic result, which will also have one
   796    // `ArrayRef<Type>` parameter for the result type.
   797    if (numResults == 1 && op.getResult(0).isVariadic())
   798      return;
   799  
   800    std::string paramList;
   801    llvm::SmallVector<std::string, 4> resultNames;
   802    buildParamList(paramList, resultNames, TypeParamKind::Collective);
   803  
   804    auto &m = opClass.newMethod("void", "build", paramList, OpMethod::MP_Static);
   805    genCodeForAddingArgAndRegionForBuilder(m.body());
   806  
   807    // Push all result types to the operation state
   808    m.body() << formatv("  {0}->addTypes(resultTypes);\n", builderOpState);
   809  }
   810  
   811  void OpEmitter::genUseOperandAsResultTypeBuilder() {
   812    std::string paramList;
   813    llvm::SmallVector<std::string, 4> resultNames;
   814    buildParamList(paramList, resultNames, TypeParamKind::None);
   815  
   816    auto &m = opClass.newMethod("void", "build", paramList, OpMethod::MP_Static);
   817    genCodeForAddingArgAndRegionForBuilder(m.body());
   818  
   819    auto numResults = op.getNumResults();
   820    if (numResults == 0)
   821      return;
   822  
   823    // Push all result types to the operation state
   824    const char *index = op.getOperand(0).isVariadic() ? ".front()" : "";
   825    std::string resultType =
   826        formatv("{0}{1}->getType()", getArgumentName(op, 0), index).str();
   827    m.body() << "  " << builderOpState << "->addTypes({" << resultType;
   828    for (int i = 1; i != numResults; ++i)
   829      m.body() << ", " << resultType;
   830    m.body() << "});\n\n";
   831  }
   832  
   833  void OpEmitter::genUseAttrAsResultTypeBuilder() {
   834    std::string paramList;
   835    llvm::SmallVector<std::string, 4> resultNames;
   836    buildParamList(paramList, resultNames, TypeParamKind::None);
   837  
   838    auto &m = opClass.newMethod("void", "build", paramList, OpMethod::MP_Static);
   839    genCodeForAddingArgAndRegionForBuilder(m.body());
   840  
   841    auto numResults = op.getNumResults();
   842    if (numResults == 0)
   843      return;
   844  
   845    // Push all result types to the operation state
   846    std::string resultType;
   847    const auto &namedAttr = op.getAttribute(0);
   848    if (namedAttr.attr.isTypeAttr()) {
   849      resultType = formatv("{0}.getValue()", namedAttr.name);
   850    } else {
   851      resultType = formatv("{0}.getType()", namedAttr.name);
   852    }
   853    m.body() << "  " << builderOpState << "->addTypes({" << resultType;
   854    for (int i = 1; i != numResults; ++i)
   855      m.body() << ", " << resultType;
   856    m.body() << "});\n\n";
   857  }
   858  
   859  void OpEmitter::genBuilder() {
   860    // Handle custom builders if provided.
   861    // TODO(antiagainst): Create wrapper class for OpBuilder to hide the native
   862    // TableGen API calls here.
   863    {
   864      auto *listInit = dyn_cast_or_null<ListInit>(def.getValueInit("builders"));
   865      if (listInit) {
   866        for (Init *init : listInit->getValues()) {
   867          Record *builderDef = cast<DefInit>(init)->getDef();
   868          StringRef params = builderDef->getValueAsString("params");
   869          StringRef body = builderDef->getValueAsString("body");
   870          bool hasBody = !body.empty();
   871  
   872          auto &method =
   873              opClass.newMethod("void", "build", params, OpMethod::MP_Static,
   874                                /*declOnly=*/!hasBody);
   875          if (hasBody)
   876            method.body() << body;
   877        }
   878      }
   879      if (op.skipDefaultBuilders()) {
   880        if (!listInit || listInit->empty())
   881          PrintFatalError(
   882              op.getLoc(),
   883              "default builders are skipped and no custom builders provided");
   884        return;
   885      }
   886    }
   887  
   888    // Generate default builders that requires all result type, operands, and
   889    // attributes as parameters.
   890  
   891    // We generate three builders here:
   892    // 1. one having a stand-alone parameter for each result type / operand /
   893    //    attribute, and
   894    genSeparateParamBuilder();
   895    // 2. one having a stand-alone parameter for each operand / attribute and
   896    //    an aggregrated parameter for all result types, and
   897    genCollectiveTypeParamBuilder();
   898    // 3. one having an aggregated parameter for all result types / operands /
   899    //    attributes, and
   900    genCollectiveParamBuilder();
   901    // 4. one having a stand-alone prameter for each operand and attribute,
   902    //    use the first operand or attribute's type as all result types
   903    // to facilitate different call patterns.
   904    if (op.getNumVariadicResults() == 0) {
   905      if (op.hasTrait("OpTrait::SameOperandsAndResultType"))
   906        genUseOperandAsResultTypeBuilder();
   907      if (op.hasTrait("OpTrait::FirstAttrDerivedResultType"))
   908        genUseAttrAsResultTypeBuilder();
   909    }
   910  }
   911  
   912  void OpEmitter::genCollectiveParamBuilder() {
   913    int numResults = op.getNumResults();
   914    int numVariadicResults = op.getNumVariadicResults();
   915    int numNonVariadicResults = numResults - numVariadicResults;
   916  
   917    int numOperands = op.getNumOperands();
   918    int numVariadicOperands = op.getNumVariadicOperands();
   919    int numNonVariadicOperands = numOperands - numVariadicOperands;
   920    // Signature
   921    std::string params =
   922        std::string("Builder *, OperationState *") + builderOpState +
   923        ", ArrayRef<Type> resultTypes, ArrayRef<Value *> operands, "
   924        "ArrayRef<NamedAttribute> attributes";
   925    auto &m = opClass.newMethod("void", "build", params, OpMethod::MP_Static);
   926    auto &body = m.body();
   927  
   928    // Result types
   929    if (numVariadicResults == 0 || numNonVariadicResults != 0)
   930      body << "  assert(resultTypes.size()"
   931           << (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults
   932           << "u && \"mismatched number of return types\");\n";
   933    body << "  " << builderOpState << "->addTypes(resultTypes);\n";
   934  
   935    // Operands
   936    if (numVariadicOperands == 0 || numNonVariadicOperands != 0)
   937      body << "  assert(operands.size()"
   938           << (numVariadicOperands != 0 ? " >= " : " == ")
   939           << numNonVariadicOperands
   940           << "u && \"mismatched number of parameters\");\n";
   941    body << "  " << builderOpState << "->addOperands(operands);\n\n";
   942  
   943    // Attributes
   944    body << "  for (const auto& pair : attributes)\n"
   945         << "    " << builderOpState
   946         << "->addAttribute(pair.first, pair.second);\n";
   947  
   948    // Create the correct number of regions
   949    if (int numRegions = op.getNumRegions()) {
   950      for (int i = 0; i < numRegions; ++i)
   951        m.body() << "  (void)" << builderOpState << "->addRegion();\n";
   952    }
   953  }
   954  
   955  void OpEmitter::buildParamList(std::string &paramList,
   956                                 SmallVectorImpl<std::string> &resultTypeNames,
   957                                 TypeParamKind kind) {
   958    resultTypeNames.clear();
   959    auto numResults = op.getNumResults();
   960    resultTypeNames.reserve(numResults);
   961  
   962    paramList = "Builder *, OperationState *";
   963    paramList.append(builderOpState);
   964  
   965    switch (kind) {
   966    case TypeParamKind::None:
   967      break;
   968    case TypeParamKind::Separate: {
   969      // Add parameters for all return types
   970      for (int i = 0; i < numResults; ++i) {
   971        const auto &result = op.getResult(i);
   972        std::string resultName = result.name;
   973        if (resultName.empty())
   974          resultName = formatv("resultType{0}", i);
   975  
   976        paramList.append(result.isVariadic() ? ", ArrayRef<Type> " : ", Type ");
   977        paramList.append(resultName);
   978  
   979        resultTypeNames.emplace_back(std::move(resultName));
   980      }
   981    } break;
   982    case TypeParamKind::Collective: {
   983      paramList.append(", ArrayRef<Type> resultTypes");
   984      resultTypeNames.push_back("resultTypes");
   985    } break;
   986    }
   987  
   988    int numOperands = 0;
   989    int numAttrs = 0;
   990  
   991    // Add parameters for all arguments (operands and attributes).
   992    for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
   993      auto argument = op.getArg(i);
   994      if (argument.is<tblgen::NamedTypeConstraint *>()) {
   995        const auto &operand = op.getOperand(numOperands);
   996        paramList.append(operand.isVariadic() ? ", ArrayRef<Value *> "
   997                                              : ", Value *");
   998        paramList.append(getArgumentName(op, numOperands));
   999        ++numOperands;
  1000      } else {
  1001        // TODO(antiagainst): Support default initializer for attributes
  1002        const auto &namedAttr = op.getAttribute(numAttrs);
  1003        const auto &attr = namedAttr.attr;
  1004        paramList.append(", ");
  1005        if (attr.isOptional())
  1006          paramList.append("/*optional*/");
  1007        paramList.append(attr.getStorageType());
  1008        paramList.append(" ");
  1009        paramList.append(namedAttr.name);
  1010        ++numAttrs;
  1011      }
  1012    }
  1013  
  1014    if (numOperands + numAttrs != op.getNumArgs())
  1015      PrintFatalError("op arguments must be either operands or attributes");
  1016  }
  1017  
  1018  void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body) {
  1019    // Push all operands to the result
  1020    for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
  1021      body << "  " << builderOpState << "->addOperands(" << getArgumentName(op, i)
  1022           << ");\n";
  1023    }
  1024  
  1025    // Push all attributes to the result
  1026    for (const auto &namedAttr : op.getAttributes()) {
  1027      if (!namedAttr.attr.isDerivedAttr()) {
  1028        bool emitNotNullCheck = namedAttr.attr.isOptional();
  1029        if (emitNotNullCheck) {
  1030          body << formatv("  if ({0}) ", namedAttr.name) << "{\n";
  1031        }
  1032        body << formatv("  {0}->addAttribute(\"{1}\", {1});\n", builderOpState,
  1033                        namedAttr.name);
  1034        if (emitNotNullCheck) {
  1035          body << "  }\n";
  1036        }
  1037      }
  1038    }
  1039  
  1040    // Create the correct number of regions
  1041    if (int numRegions = op.getNumRegions()) {
  1042      for (int i = 0; i < numRegions; ++i)
  1043        body << "  (void)" << builderOpState << "->addRegion();\n";
  1044    }
  1045  }
  1046  
  1047  void OpEmitter::genCanonicalizerDecls() {
  1048    if (!def.getValueAsBit("hasCanonicalizer"))
  1049      return;
  1050  
  1051    const char *const params =
  1052        "OwningRewritePatternList &results, MLIRContext *context";
  1053    opClass.newMethod("void", "getCanonicalizationPatterns", params,
  1054                      OpMethod::MP_Static, /*declOnly=*/true);
  1055  }
  1056  
  1057  void OpEmitter::genFolderDecls() {
  1058    bool hasSingleResult = op.getNumResults() == 1;
  1059  
  1060    if (def.getValueAsBit("hasFolder")) {
  1061      if (hasSingleResult) {
  1062        const char *const params = "ArrayRef<Attribute> operands";
  1063        opClass.newMethod("OpFoldResult", "fold", params, OpMethod::MP_None,
  1064                          /*declOnly=*/true);
  1065      } else {
  1066        const char *const params = "ArrayRef<Attribute> operands, "
  1067                                   "SmallVectorImpl<OpFoldResult> &results";
  1068        opClass.newMethod("LogicalResult", "fold", params, OpMethod::MP_None,
  1069                          /*declOnly=*/true);
  1070      }
  1071    }
  1072  }
  1073  
  1074  void OpEmitter::genParser() {
  1075    if (!hasStringAttribute(def, "parser"))
  1076      return;
  1077  
  1078    auto &method = opClass.newMethod(
  1079        "ParseResult", "parse", "OpAsmParser *parser, OperationState *result",
  1080        OpMethod::MP_Static);
  1081    FmtContext fctx;
  1082    fctx.addSubst("cppClass", opClass.getClassName());
  1083    auto parser = def.getValueAsString("parser").ltrim().rtrim(" \t\v\f\r");
  1084    method.body() << "  " << tgfmt(parser, &fctx);
  1085  }
  1086  
  1087  void OpEmitter::genPrinter() {
  1088    auto valueInit = def.getValueInit("printer");
  1089    CodeInit *codeInit = dyn_cast<CodeInit>(valueInit);
  1090    if (!codeInit)
  1091      return;
  1092  
  1093    auto &method = opClass.newMethod("void", "print", "OpAsmPrinter *p");
  1094    FmtContext fctx;
  1095    fctx.addSubst("cppClass", opClass.getClassName());
  1096    auto printer = codeInit->getValue().ltrim().rtrim(" \t\v\f\r");
  1097    method.body() << "  " << tgfmt(printer, &fctx);
  1098  }
  1099  
  1100  void OpEmitter::genVerifier() {
  1101    auto valueInit = def.getValueInit("verifier");
  1102    CodeInit *codeInit = dyn_cast<CodeInit>(valueInit);
  1103    bool hasCustomVerify = codeInit && !codeInit->getValue().empty();
  1104  
  1105    auto &method = opClass.newMethod("LogicalResult", "verify", /*params=*/"");
  1106    auto &body = method.body();
  1107  
  1108    // Populate substitutions for attributes and named operands and results.
  1109    for (const auto &namedAttr : op.getAttributes())
  1110      verifyCtx.addSubst(namedAttr.name,
  1111                         formatv("this->getAttr(\"{0}\")", namedAttr.name));
  1112    for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
  1113      auto &value = op.getOperand(i);
  1114      // Skip from from first variadic operands for now. Else getOperand index
  1115      // used below doesn't match.
  1116      if (value.isVariadic())
  1117        break;
  1118      if (!value.name.empty())
  1119        verifyCtx.addSubst(
  1120            value.name, formatv("(*this->getOperation()->getOperand({0}))", i));
  1121    }
  1122    for (int i = 0, e = op.getNumResults(); i < e; ++i) {
  1123      auto &value = op.getResult(i);
  1124      // Skip from from first variadic results for now. Else getResult index used
  1125      // below doesn't match.
  1126      if (value.isVariadic())
  1127        break;
  1128      if (!value.name.empty())
  1129        verifyCtx.addSubst(value.name,
  1130                           formatv("(*this->getOperation()->getResult({0}))", i));
  1131    }
  1132  
  1133    // Verify the attributes have the correct type.
  1134    for (const auto &namedAttr : op.getAttributes()) {
  1135      const auto &attr = namedAttr.attr;
  1136      if (attr.isDerivedAttr())
  1137        continue;
  1138  
  1139      auto attrName = namedAttr.name;
  1140      // Prefix with `tblgen_` to avoid hiding the attribute accessor.
  1141      auto varName = tblgenNamePrefix + attrName;
  1142      body << formatv("  auto {0} = this->getAttr(\"{1}\");\n", varName,
  1143                      attrName);
  1144  
  1145      bool allowMissingAttr =
  1146          attr.hasDefaultValueInitializer() || attr.isOptional();
  1147      if (allowMissingAttr) {
  1148        // If the attribute has a default value, then only verify the predicate if
  1149        // set. This does effectively assume that the default value is valid.
  1150        // TODO: verify the debug value is valid (perhaps in debug mode only).
  1151        body << "  if (" << varName << ") {\n";
  1152      } else {
  1153        body << "  if (!" << varName
  1154             << ") return emitOpError(\"requires attribute '" << attrName
  1155             << "'\");\n  {\n";
  1156      }
  1157  
  1158      auto attrPred = attr.getPredicate();
  1159      if (!attrPred.isNull()) {
  1160        body << tgfmt(
  1161            "    if (!($0)) return emitOpError(\"attribute '$1' "
  1162            "failed to satisfy constraint: $2\");\n",
  1163            /*ctx=*/nullptr,
  1164            tgfmt(attrPred.getCondition(), &verifyCtx.withSelf(varName)),
  1165            attrName, attr.getDescription());
  1166      }
  1167  
  1168      body << "  }\n";
  1169    }
  1170  
  1171    genOperandResultVerifier(body, op.getOperands(), "operand");
  1172    genOperandResultVerifier(body, op.getResults(), "result");
  1173  
  1174    for (auto &trait : op.getTraits()) {
  1175      if (auto t = dyn_cast<tblgen::PredOpTrait>(&trait)) {
  1176        body << tgfmt("  if (!($0)) {\n    "
  1177                      "return emitOpError(\"failed to verify that $1\");\n  }\n",
  1178                      &verifyCtx, tgfmt(t->getPredTemplate(), &verifyCtx),
  1179                      t->getDescription());
  1180      }
  1181    }
  1182  
  1183    genRegionVerifier(body);
  1184  
  1185    if (hasCustomVerify) {
  1186      FmtContext fctx;
  1187      fctx.addSubst("cppClass", opClass.getClassName());
  1188      auto printer = codeInit->getValue().ltrim().rtrim(" \t\v\f\r");
  1189      body << "  " << tgfmt(printer, &fctx);
  1190    } else {
  1191      body << "  return mlir::success();\n";
  1192    }
  1193  }
  1194  
  1195  void OpEmitter::genOperandResultVerifier(OpMethodBody &body,
  1196                                           Operator::value_range values,
  1197                                           StringRef valueKind) {
  1198    FmtContext fctx;
  1199  
  1200    body << "  {\n";
  1201    body << "    unsigned index = 0; (void)index;\n";
  1202  
  1203    for (auto staticValue : llvm::enumerate(values)) {
  1204      if (!staticValue.value().hasPredicate())
  1205        continue;
  1206  
  1207      // Emit a loop to check all the dynamic values in the pack.
  1208      body << formatv("    for (Value *v : getODS{0}{1}s({2})) {{\n",
  1209                      // Capitalize the first letter to match the function name
  1210                      valueKind.substr(0, 1).upper(), valueKind.substr(1),
  1211                      staticValue.index());
  1212  
  1213      auto constraint = staticValue.value().constraint;
  1214  
  1215      body << "      (void)v;\n"
  1216           << "      if (!("
  1217           << tgfmt(constraint.getConditionTemplate(),
  1218                    &fctx.withSelf("v->getType()"))
  1219           << ")) {\n"
  1220           << formatv("        return emitOpError(\"{0} #\") << index "
  1221                      "<< \" must be {1}\";\n",
  1222                      valueKind, constraint.getDescription())
  1223           << "      }\n" // if
  1224           << "      ++index;\n"
  1225           << "    }\n"; // for
  1226    }
  1227  
  1228    body << "  }\n";
  1229  }
  1230  
  1231  void OpEmitter::genRegionVerifier(OpMethodBody &body) {
  1232    unsigned numRegions = op.getNumRegions();
  1233  
  1234    // Verify this op has the correct number of regions
  1235    body << formatv(
  1236        "  if (this->getOperation()->getNumRegions() != {0}) {\n    "
  1237        "return emitOpError(\"has incorrect number of regions: expected {0} but "
  1238        "found \") << this->getOperation()->getNumRegions();\n  }\n",
  1239        numRegions);
  1240  
  1241    for (unsigned i = 0; i < numRegions; ++i) {
  1242      const auto &region = op.getRegion(i);
  1243  
  1244      std::string name = formatv("#{0}", i);
  1245      if (!region.name.empty()) {
  1246        name += formatv(" ('{0}')", region.name);
  1247      }
  1248  
  1249      auto getRegion = formatv("this->getOperation()->getRegion({0})", i).str();
  1250      auto constraint = tgfmt(region.constraint.getConditionTemplate(),
  1251                              &verifyCtx.withSelf(getRegion))
  1252                            .str();
  1253  
  1254      body << formatv("  if (!({0})) {\n    "
  1255                      "return emitOpError(\"region {1} failed to verify "
  1256                      "constraint: {2}\");\n  }\n",
  1257                      constraint, name, region.constraint.getDescription());
  1258    }
  1259  }
  1260  
  1261  void OpEmitter::genTraits() {
  1262    int numResults = op.getNumResults();
  1263    int numVariadicResults = op.getNumVariadicResults();
  1264  
  1265    // Add return size trait.
  1266    if (numVariadicResults != 0) {
  1267      if (numResults == numVariadicResults)
  1268        opClass.addTrait("OpTrait::VariadicResults");
  1269      else
  1270        opClass.addTrait("OpTrait::AtLeastNResults<" +
  1271                         Twine(numResults - numVariadicResults) + ">::Impl");
  1272    } else {
  1273      switch (numResults) {
  1274      case 0:
  1275        opClass.addTrait("OpTrait::ZeroResult");
  1276        break;
  1277      case 1:
  1278        opClass.addTrait("OpTrait::OneResult");
  1279        break;
  1280      default:
  1281        opClass.addTrait("OpTrait::NResults<" + Twine(numResults) + ">::Impl");
  1282        break;
  1283      }
  1284    }
  1285  
  1286    for (const auto &trait : op.getTraits()) {
  1287      if (auto opTrait = dyn_cast<tblgen::NativeOpTrait>(&trait))
  1288        opClass.addTrait(opTrait->getTrait());
  1289    }
  1290  
  1291    // Add variadic size trait and normal op traits.
  1292    int numOperands = op.getNumOperands();
  1293    int numVariadicOperands = op.getNumVariadicOperands();
  1294  
  1295    // Add operand size trait.
  1296    if (numVariadicOperands != 0) {
  1297      if (numOperands == numVariadicOperands)
  1298        opClass.addTrait("OpTrait::VariadicOperands");
  1299      else
  1300        opClass.addTrait("OpTrait::AtLeastNOperands<" +
  1301                         Twine(numOperands - numVariadicOperands) + ">::Impl");
  1302    } else {
  1303      switch (numOperands) {
  1304      case 0:
  1305        opClass.addTrait("OpTrait::ZeroOperands");
  1306        break;
  1307      case 1:
  1308        opClass.addTrait("OpTrait::OneOperand");
  1309        break;
  1310      default:
  1311        opClass.addTrait("OpTrait::NOperands<" + Twine(numOperands) + ">::Impl");
  1312        break;
  1313      }
  1314    }
  1315  }
  1316  
  1317  void OpEmitter::genOpNameGetter() {
  1318    auto &method = opClass.newMethod("StringRef", "getOperationName",
  1319                                     /*params=*/"", OpMethod::MP_Static);
  1320    method.body() << "  return \"" << op.getOperationName() << "\";\n";
  1321  }
  1322  
  1323  //===----------------------------------------------------------------------===//
  1324  // OpOperandAdaptor emitter
  1325  //===----------------------------------------------------------------------===//
  1326  
  1327  namespace {
  1328  // Helper class to emit Op operand adaptors to an output stream.  Operand
  1329  // adaptors are wrappers around ArrayRef<Value *> that provide named operand
  1330  // getters identical to those defined in the Op.
  1331  class OpOperandAdaptorEmitter {
  1332  public:
  1333    static void emitDecl(const Operator &op, raw_ostream &os);
  1334    static void emitDef(const Operator &op, raw_ostream &os);
  1335  
  1336  private:
  1337    explicit OpOperandAdaptorEmitter(const Operator &op);
  1338  
  1339    Class adapterClass;
  1340  };
  1341  } // end namespace
  1342  
  1343  OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(const Operator &op)
  1344      : adapterClass(op.getCppClassName().str() + "OperandAdaptor") {
  1345    adapterClass.newField("ArrayRef<Value *>", "tblgen_operands");
  1346    auto &constructor = adapterClass.newConstructor("ArrayRef<Value *> values");
  1347    constructor.body() << "  tblgen_operands = values;\n";
  1348  
  1349    generateNamedOperandGetters(op, adapterClass,
  1350                                /*rangeType=*/"ArrayRef<Value *>",
  1351                                /*rangeBeginCall=*/"tblgen_operands.begin()",
  1352                                /*rangeSizeCall=*/"tblgen_operands.size()",
  1353                                /*getOperandCallPattern=*/"tblgen_operands[{0}]");
  1354  }
  1355  
  1356  void OpOperandAdaptorEmitter::emitDecl(const Operator &op, raw_ostream &os) {
  1357    OpOperandAdaptorEmitter(op).adapterClass.writeDeclTo(os);
  1358  }
  1359  
  1360  void OpOperandAdaptorEmitter::emitDef(const Operator &op, raw_ostream &os) {
  1361    OpOperandAdaptorEmitter(op).adapterClass.writeDefTo(os);
  1362  }
  1363  
  1364  // Emits the opcode enum and op classes.
  1365  static void emitOpClasses(const std::vector<Record *> &defs, raw_ostream &os,
  1366                            bool emitDecl) {
  1367    IfDefScope scope("GET_OP_CLASSES", os);
  1368    // First emit forward declaration for each class, this allows them to refer
  1369    // to each others in traits for example.
  1370    if (emitDecl) {
  1371      for (auto *def : defs) {
  1372        Operator op(*def);
  1373        os << "class " << op.getCppClassName() << ";\n";
  1374      }
  1375    }
  1376    for (auto *def : defs) {
  1377      Operator op(*def);
  1378      if (emitDecl) {
  1379        os << formatv(opCommentHeader, op.getQualCppClassName(), "declarations");
  1380        OpOperandAdaptorEmitter::emitDecl(op, os);
  1381        OpEmitter::emitDecl(op, os);
  1382      } else {
  1383        os << formatv(opCommentHeader, op.getQualCppClassName(), "definitions");
  1384        OpOperandAdaptorEmitter::emitDef(op, os);
  1385        OpEmitter::emitDef(op, os);
  1386      }
  1387    }
  1388  }
  1389  
  1390  // Emits a comma-separated list of the ops.
  1391  static void emitOpList(const std::vector<Record *> &defs, raw_ostream &os) {
  1392    IfDefScope scope("GET_OP_LIST", os);
  1393  
  1394    interleave(
  1395        // TODO: We are constructing the Operator wrapper instance just for
  1396        // getting it's qualified class name here. Reduce the overhead by having a
  1397        // lightweight version of Operator class just for that purpose.
  1398        defs, [&os](Record *def) { os << Operator(def).getQualCppClassName(); },
  1399        [&os]() { os << ",\n"; });
  1400  }
  1401  
  1402  static bool emitOpDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
  1403    emitSourceFileHeader("Op Declarations", os);
  1404  
  1405    const auto &defs = recordKeeper.getAllDerivedDefinitions("Op");
  1406    emitOpClasses(defs, os, /*emitDecl=*/true);
  1407  
  1408    return false;
  1409  }
  1410  
  1411  static bool emitOpDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
  1412    emitSourceFileHeader("Op Definitions", os);
  1413  
  1414    const auto &defs = recordKeeper.getAllDerivedDefinitions("Op");
  1415    emitOpList(defs, os);
  1416    emitOpClasses(defs, os, /*emitDecl=*/false);
  1417  
  1418    return false;
  1419  }
  1420  
  1421  static mlir::GenRegistration
  1422      genOpDecls("gen-op-decls", "Generate op declarations",
  1423                 [](const RecordKeeper &records, raw_ostream &os) {
  1424                   return emitOpDecls(records, os);
  1425                 });
  1426  
  1427  static mlir::GenRegistration genOpDefs("gen-op-defs", "Generate op definitions",
  1428                                         [](const RecordKeeper &records,
  1429                                            raw_ostream &os) {
  1430                                           return emitOpDefs(records, os);
  1431                                         });