github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/IR/AsmPrinter.cpp (about)

     1  //===- AsmPrinter.cpp - MLIR Assembly Printer Implementation --------------===//
     2  //
     3  // Copyright 2019 The MLIR Authors.
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //   http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  // =============================================================================
    17  //
    18  // This file implements the MLIR AsmPrinter class, which is used to implement
    19  // the various print() methods on the core IR objects.
    20  //
    21  //===----------------------------------------------------------------------===//
    22  
    23  #include "mlir/IR/AffineExpr.h"
    24  #include "mlir/IR/AffineMap.h"
    25  #include "mlir/IR/Attributes.h"
    26  #include "mlir/IR/Dialect.h"
    27  #include "mlir/IR/Function.h"
    28  #include "mlir/IR/IntegerSet.h"
    29  #include "mlir/IR/MLIRContext.h"
    30  #include "mlir/IR/Module.h"
    31  #include "mlir/IR/OpImplementation.h"
    32  #include "mlir/IR/Operation.h"
    33  #include "mlir/IR/StandardTypes.h"
    34  #include "mlir/Support/STLExtras.h"
    35  #include "llvm/ADT/APFloat.h"
    36  #include "llvm/ADT/DenseMap.h"
    37  #include "llvm/ADT/MapVector.h"
    38  #include "llvm/ADT/STLExtras.h"
    39  #include "llvm/ADT/ScopedHashTable.h"
    40  #include "llvm/ADT/SetVector.h"
    41  #include "llvm/ADT/SmallString.h"
    42  #include "llvm/ADT/StringExtras.h"
    43  #include "llvm/ADT/StringSet.h"
    44  #include "llvm/Support/CommandLine.h"
    45  #include "llvm/Support/Regex.h"
    46  using namespace mlir;
    47  
    48  void Identifier::print(raw_ostream &os) const { os << str(); }
    49  
    50  void Identifier::dump() const { print(llvm::errs()); }
    51  
    52  void OperationName::print(raw_ostream &os) const { os << getStringRef(); }
    53  
    54  void OperationName::dump() const { print(llvm::errs()); }
    55  
    56  OpAsmPrinter::~OpAsmPrinter() {}
    57  
    58  //===----------------------------------------------------------------------===//
    59  // ModuleState
    60  //===----------------------------------------------------------------------===//
    61  
    62  // TODO(riverriddle) Rethink this flag when we have a pass that can remove debug
    63  // info or when we have a system for printer flags.
    64  static llvm::cl::opt<bool>
    65      shouldPrintDebugInfoOpt("mlir-print-debuginfo",
    66                              llvm::cl::desc("Print debug info in MLIR output"),
    67                              llvm::cl::init(false));
    68  
    69  static llvm::cl::opt<bool> printPrettyDebugInfo(
    70      "mlir-pretty-debuginfo",
    71      llvm::cl::desc("Print pretty debug info in MLIR output"),
    72      llvm::cl::init(false));
    73  
    74  // Use the generic op output form in the operation printer even if the custom
    75  // form is defined.
    76  static llvm::cl::opt<bool>
    77      printGenericOpForm("mlir-print-op-generic",
    78                         llvm::cl::desc("Print the generic op form"),
    79                         llvm::cl::init(false), llvm::cl::Hidden);
    80  
    81  namespace {
    82  /// A special index constant used for non-kind attribute aliases.
    83  static constexpr int kNonAttrKindAlias = -1;
    84  
    85  class ModuleState {
    86  public:
    87    explicit ModuleState(MLIRContext *context) : interfaces(context) {}
    88    void initialize(Operation *op);
    89  
    90    Twine getAttributeAlias(Attribute attr) const {
    91      auto alias = attrToAlias.find(attr);
    92      if (alias == attrToAlias.end())
    93        return Twine();
    94  
    95      // Return the alias for this attribute, along with the index if this was
    96      // generated by a kind alias.
    97      int kindIndex = alias->second.second;
    98      return alias->second.first +
    99             (kindIndex == kNonAttrKindAlias ? Twine() : Twine(kindIndex));
   100    }
   101  
   102    void printAttributeAliases(raw_ostream &os) const {
   103      auto printAlias = [&](StringRef alias, Attribute attr, int index) {
   104        os << '#' << alias;
   105        if (index != kNonAttrKindAlias)
   106          os << index;
   107        os << " = " << attr << '\n';
   108      };
   109  
   110      // Print all of the attribute kind aliases.
   111      for (auto &kindAlias : attrKindToAlias) {
   112        for (unsigned i = 0, e = kindAlias.second.second.size(); i != e; ++i)
   113          printAlias(kindAlias.second.first, kindAlias.second.second[i], i);
   114        os << "\n";
   115      }
   116  
   117      // In a second pass print all of the remaining attribute aliases that aren't
   118      // kind aliases.
   119      for (Attribute attr : usedAttributes) {
   120        auto alias = attrToAlias.find(attr);
   121        if (alias != attrToAlias.end() &&
   122            alias->second.second == kNonAttrKindAlias)
   123          printAlias(alias->second.first, attr, alias->second.second);
   124      }
   125    }
   126  
   127    StringRef getTypeAlias(Type ty) const { return typeToAlias.lookup(ty); }
   128  
   129    void printTypeAliases(raw_ostream &os) const {
   130      for (Type type : usedTypes) {
   131        auto alias = typeToAlias.find(type);
   132        if (alias != typeToAlias.end())
   133          os << '!' << alias->second << " = type " << type << '\n';
   134      }
   135    }
   136  
   137    /// Get an instance of the OpAsmDialectInterface for the given dialect, or
   138    /// null if one wasn't registered.
   139    const OpAsmDialectInterface *getOpAsmInterface(Dialect *dialect) {
   140      return interfaces.getInterfaceFor(dialect);
   141    }
   142  
   143  private:
   144    void recordAttributeReference(Attribute attr) {
   145      // Don't recheck attributes that have already been seen or those that
   146      // already have an alias.
   147      if (!usedAttributes.insert(attr) || attrToAlias.count(attr))
   148        return;
   149  
   150      // If this attribute kind has an alias, then record one for this attribute.
   151      auto alias = attrKindToAlias.find(static_cast<unsigned>(attr.getKind()));
   152      if (alias == attrKindToAlias.end())
   153        return;
   154      std::pair<StringRef, int> attrAlias(alias->second.first,
   155                                          alias->second.second.size());
   156      attrToAlias.insert({attr, attrAlias});
   157      alias->second.second.push_back(attr);
   158    }
   159  
   160    void recordTypeReference(Type ty) { usedTypes.insert(ty); }
   161  
   162    // Visit functions.
   163    void visitOperation(Operation *op);
   164    void visitType(Type type);
   165    void visitAttribute(Attribute attr);
   166  
   167    // Initialize symbol aliases.
   168    void initializeSymbolAliases();
   169  
   170    /// Set of attributes known to be used within the module.
   171    llvm::SetVector<Attribute> usedAttributes;
   172  
   173    /// Mapping between attribute and a pair comprised of a base alias name and a
   174    /// count suffix. If the suffix is set to -1, it is not displayed.
   175    llvm::MapVector<Attribute, std::pair<StringRef, int>> attrToAlias;
   176  
   177    /// Mapping between attribute kind and a pair comprised of a base alias name
   178    /// and a unique list of attributes belonging to this kind sorted by location
   179    /// seen in the module.
   180    llvm::MapVector<unsigned, std::pair<StringRef, std::vector<Attribute>>>
   181        attrKindToAlias;
   182  
   183    /// Set of types known to be used within the module.
   184    llvm::SetVector<Type> usedTypes;
   185  
   186    /// A mapping between a type and a given alias.
   187    DenseMap<Type, StringRef> typeToAlias;
   188  
   189    /// Collection of OpAsm interfaces implemented in the context.
   190    DialectInterfaceCollection<OpAsmDialectInterface> interfaces;
   191  };
   192  } // end anonymous namespace
   193  
   194  // TODO Support visiting other types/operations when implemented.
   195  void ModuleState::visitType(Type type) {
   196    recordTypeReference(type);
   197    if (auto funcType = type.dyn_cast<FunctionType>()) {
   198      // Visit input and result types for functions.
   199      for (auto input : funcType.getInputs())
   200        visitType(input);
   201      for (auto result : funcType.getResults())
   202        visitType(result);
   203      return;
   204    }
   205    if (auto memref = type.dyn_cast<MemRefType>()) {
   206      // Visit affine maps in memref type.
   207      for (auto map : memref.getAffineMaps())
   208        recordAttributeReference(AffineMapAttr::get(map));
   209    }
   210    if (auto shapedType = type.dyn_cast<ShapedType>()) {
   211      visitType(shapedType.getElementType());
   212    }
   213  }
   214  
   215  void ModuleState::visitAttribute(Attribute attr) {
   216    recordAttributeReference(attr);
   217    if (auto arrayAttr = attr.dyn_cast<ArrayAttr>()) {
   218      for (auto elt : arrayAttr.getValue())
   219        visitAttribute(elt);
   220    } else if (auto typeAttr = attr.dyn_cast<TypeAttr>()) {
   221      visitType(typeAttr.getValue());
   222    }
   223  }
   224  
   225  void ModuleState::visitOperation(Operation *op) {
   226    // Visit all the types used in the operation.
   227    for (auto type : op->getOperandTypes())
   228      visitType(type);
   229    for (auto type : op->getResultTypes())
   230      visitType(type);
   231    for (auto &region : op->getRegions())
   232      for (auto &block : region)
   233        for (auto *arg : block.getArguments())
   234          visitType(arg->getType());
   235  
   236    // Visit each of the attributes.
   237    for (auto elt : op->getAttrs())
   238      visitAttribute(elt.second);
   239  }
   240  
   241  // Utility to generate a function to register a symbol alias.
   242  static bool canRegisterAlias(StringRef name, llvm::StringSet<> &usedAliases) {
   243    assert(!name.empty() && "expected alias name to be non-empty");
   244    // TODO(riverriddle) Assert that the provided alias name can be lexed as
   245    // an identifier.
   246  
   247    // Check that the alias doesn't contain a '.' character and the name is not
   248    // already in use.
   249    return !name.contains('.') && usedAliases.insert(name).second;
   250  }
   251  
   252  void ModuleState::initializeSymbolAliases() {
   253    // Track the identifiers in use for each symbol so that the same identifier
   254    // isn't used twice.
   255    llvm::StringSet<> usedAliases;
   256  
   257    // Collect the set of aliases from each dialect.
   258    SmallVector<std::pair<unsigned, StringRef>, 8> attributeKindAliases;
   259    SmallVector<std::pair<Attribute, StringRef>, 8> attributeAliases;
   260    SmallVector<std::pair<Type, StringRef>, 16> typeAliases;
   261  
   262    // AffineMap/Integer set have specific kind aliases.
   263    attributeKindAliases.emplace_back(StandardAttributes::AffineMap, "map");
   264    attributeKindAliases.emplace_back(StandardAttributes::IntegerSet, "set");
   265  
   266    for (auto &interface : interfaces) {
   267      interface.getAttributeKindAliases(attributeKindAliases);
   268      interface.getAttributeAliases(attributeAliases);
   269      interface.getTypeAliases(typeAliases);
   270    }
   271  
   272    // Setup the attribute kind aliases.
   273    StringRef alias;
   274    unsigned attrKind;
   275    for (auto &attrAliasPair : attributeKindAliases) {
   276      std::tie(attrKind, alias) = attrAliasPair;
   277      assert(!alias.empty() && "expected non-empty alias string");
   278      if (!usedAliases.count(alias) && !alias.contains('.'))
   279        attrKindToAlias.insert({attrKind, {alias, {}}});
   280    }
   281  
   282    // Clear the set of used identifiers so that the attribute kind aliases are
   283    // just a prefix and not the full alias, i.e. there may be some overlap.
   284    usedAliases.clear();
   285  
   286    // Register the attribute aliases.
   287    // Create a regex for the attribute kind alias names, these have a prefix with
   288    // a counter appended to the end. We prevent normal aliases from having these
   289    // names to avoid collisions.
   290    llvm::Regex reservedAttrNames("[0-9]+$");
   291  
   292    // Attribute value aliases.
   293    Attribute attr;
   294    for (auto &attrAliasPair : attributeAliases) {
   295      std::tie(attr, alias) = attrAliasPair;
   296      if (!reservedAttrNames.match(alias) && canRegisterAlias(alias, usedAliases))
   297        attrToAlias.insert({attr, {alias, kNonAttrKindAlias}});
   298    }
   299  
   300    // Clear the set of used identifiers as types can have the same identifiers as
   301    // affine structures.
   302    usedAliases.clear();
   303  
   304    // Type aliases.
   305    for (auto &typeAliasPair : typeAliases)
   306      if (canRegisterAlias(typeAliasPair.second, usedAliases))
   307        typeToAlias.insert(typeAliasPair);
   308  }
   309  
   310  void ModuleState::initialize(Operation *op) {
   311    // Initialize the symbol aliases.
   312    initializeSymbolAliases();
   313  
   314    // Visit each of the nested operations.
   315    op->walk([&](Operation *op) { visitOperation(op); });
   316  }
   317  
   318  //===----------------------------------------------------------------------===//
   319  // ModulePrinter
   320  //===----------------------------------------------------------------------===//
   321  
   322  namespace {
   323  class ModulePrinter {
   324  public:
   325    ModulePrinter(raw_ostream &os, ModuleState *state = nullptr)
   326        : os(os), state(state) {}
   327    explicit ModulePrinter(ModulePrinter &printer)
   328        : os(printer.os), state(printer.state) {}
   329  
   330    template <typename Container, typename UnaryFunctor>
   331    inline void interleaveComma(const Container &c, UnaryFunctor each_fn) const {
   332      interleave(c.begin(), c.end(), each_fn, [&]() { os << ", "; });
   333    }
   334  
   335    void print(ModuleOp module);
   336  
   337    /// Print the given attribute. If 'mayElideType' is true, some attributes are
   338    /// printed without the type when the type matches the default used in the
   339    /// parser (for example i64 is the default for integer attributes).
   340    void printAttribute(Attribute attr, bool mayElideType = false);
   341  
   342    void printType(Type type);
   343    void printLocation(LocationAttr loc);
   344  
   345    void printAffineMap(AffineMap map);
   346    void printAffineExpr(
   347        AffineExpr expr,
   348        llvm::function_ref<void(unsigned, bool)> printValueName = nullptr);
   349    void printAffineConstraint(AffineExpr expr, bool isEq);
   350    void printIntegerSet(IntegerSet set);
   351  
   352  protected:
   353    void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
   354                               ArrayRef<StringRef> elidedAttrs = {});
   355    void printTrailingLocation(Location loc);
   356    void printLocationInternal(LocationAttr loc, bool pretty = false);
   357    void printDenseElementsAttr(DenseElementsAttr attr);
   358  
   359    /// This enum is used to represent the binding stength of the enclosing
   360    /// context that an AffineExprStorage is being printed in, so we can
   361    /// intelligently produce parens.
   362    enum class BindingStrength {
   363      Weak,   // + and -
   364      Strong, // All other binary operators.
   365    };
   366    void printAffineExprInternal(
   367        AffineExpr expr, BindingStrength enclosingTightness,
   368        llvm::function_ref<void(unsigned, bool)> printValueName = nullptr);
   369  
   370    /// The output stream for the printer.
   371    raw_ostream &os;
   372  
   373    /// An optional printer state for the module.
   374    ModuleState *state;
   375  };
   376  } // end anonymous namespace
   377  
   378  void ModulePrinter::printTrailingLocation(Location loc) {
   379    // Check to see if we are printing debug information.
   380    if (!shouldPrintDebugInfoOpt)
   381      return;
   382  
   383    os << " ";
   384    printLocation(loc);
   385  }
   386  
   387  void ModulePrinter::printLocationInternal(LocationAttr loc, bool pretty) {
   388    switch (loc.getKind()) {
   389    case StandardAttributes::UnknownLocation:
   390      if (pretty)
   391        os << "[unknown]";
   392      else
   393        os << "unknown";
   394      break;
   395    case StandardAttributes::FileLineColLocation: {
   396      auto fileLoc = loc.cast<FileLineColLoc>();
   397      auto mayQuote = pretty ? "" : "\"";
   398      os << mayQuote << fileLoc.getFilename() << mayQuote << ':'
   399         << fileLoc.getLine() << ':' << fileLoc.getColumn();
   400      break;
   401    }
   402    case StandardAttributes::NameLocation: {
   403      auto nameLoc = loc.cast<NameLoc>();
   404      os << '\"' << nameLoc.getName() << '\"';
   405  
   406      // Print the child if it isn't unknown.
   407      auto childLoc = nameLoc.getChildLoc();
   408      if (!childLoc.isa<UnknownLoc>()) {
   409        os << '(';
   410        printLocationInternal(childLoc, pretty);
   411        os << ')';
   412      }
   413      break;
   414    }
   415    case StandardAttributes::CallSiteLocation: {
   416      auto callLocation = loc.cast<CallSiteLoc>();
   417      auto caller = callLocation.getCaller();
   418      auto callee = callLocation.getCallee();
   419      if (!pretty)
   420        os << "callsite(";
   421      printLocationInternal(callee, pretty);
   422      if (pretty) {
   423        if (callee.isa<NameLoc>()) {
   424          if (caller.isa<FileLineColLoc>()) {
   425            os << " at ";
   426          } else {
   427            os << "\n at ";
   428          }
   429        } else {
   430          os << "\n at ";
   431        }
   432      } else {
   433        os << " at ";
   434      }
   435      printLocationInternal(caller, pretty);
   436      if (!pretty)
   437        os << ")";
   438      break;
   439    }
   440    case StandardAttributes::FusedLocation: {
   441      auto fusedLoc = loc.cast<FusedLoc>();
   442      if (!pretty)
   443        os << "fused";
   444      if (auto metadata = fusedLoc.getMetadata())
   445        os << '<' << metadata << '>';
   446      os << '[';
   447      interleave(
   448          fusedLoc.getLocations(),
   449          [&](Location loc) { printLocationInternal(loc, pretty); },
   450          [&]() { os << ", "; });
   451      os << ']';
   452      break;
   453    }
   454    }
   455  }
   456  
   457  /// Print a floating point value in a way that the parser will be able to
   458  /// round-trip losslessly.
   459  static void printFloatValue(const APFloat &apValue, raw_ostream &os) {
   460    // We would like to output the FP constant value in exponential notation,
   461    // but we cannot do this if doing so will lose precision.  Check here to
   462    // make sure that we only output it in exponential format if we can parse
   463    // the value back and get the same value.
   464    bool isInf = apValue.isInfinity();
   465    bool isNaN = apValue.isNaN();
   466    if (!isInf && !isNaN) {
   467      SmallString<128> strValue;
   468      apValue.toString(strValue, 6, 0, false);
   469  
   470      // Check to make sure that the stringized number is not some string like
   471      // "Inf" or NaN, that atof will accept, but the lexer will not.  Check
   472      // that the string matches the "[-+]?[0-9]" regex.
   473      assert(((strValue[0] >= '0' && strValue[0] <= '9') ||
   474              ((strValue[0] == '-' || strValue[0] == '+') &&
   475               (strValue[1] >= '0' && strValue[1] <= '9'))) &&
   476             "[-+]?[0-9] regex does not match!");
   477  
   478      // Parse back the stringized version and check that the value is equal
   479      // (i.e., there is no precision loss). If it is not, use the default format
   480      // of APFloat instead of the exponential notation.
   481      if (!APFloat(apValue.getSemantics(), strValue).bitwiseIsEqual(apValue)) {
   482        strValue.clear();
   483        apValue.toString(strValue);
   484      }
   485      os << strValue;
   486      return;
   487    }
   488  
   489    // Print special values in hexadecimal format.  The sign bit should be
   490    // included in the literal.
   491    SmallVector<char, 16> str;
   492    APInt apInt = apValue.bitcastToAPInt();
   493    apInt.toString(str, /*Radix=*/16, /*Signed=*/false,
   494                   /*formatAsCLiteral=*/true);
   495    os << str;
   496  }
   497  
   498  void ModulePrinter::printLocation(LocationAttr loc) {
   499    if (printPrettyDebugInfo) {
   500      printLocationInternal(loc, /*pretty=*/true);
   501    } else {
   502      os << "loc(";
   503      printLocationInternal(loc);
   504      os << ')';
   505    }
   506  }
   507  
   508  /// Returns if the given dialect symbol data is simple enough to print in the
   509  /// pretty form, i.e. without the enclosing "".
   510  static bool isDialectSymbolSimpleEnoughForPrettyForm(StringRef symName) {
   511    // The name must start with an identifier.
   512    if (symName.empty() || !isalpha(symName.front()))
   513      return false;
   514  
   515    // Ignore all the characters that are valid in an identifier in the symbol
   516    // name.
   517    symName =
   518        symName.drop_while([](char c) { return llvm::isAlnum(c) || c == '.'; });
   519    if (symName.empty())
   520      return true;
   521  
   522    // If we got to an unexpected character, then it must be a <>.  Check those
   523    // recursively.
   524    if (symName.front() != '<' || symName.back() != '>')
   525      return false;
   526  
   527    SmallVector<char, 8> nestedPunctuation;
   528    do {
   529      // If we ran out of characters, then we had a punctuation mismatch.
   530      if (symName.empty())
   531        return false;
   532  
   533      auto c = symName.front();
   534      symName = symName.drop_front();
   535  
   536      switch (c) {
   537      // We never allow null characters. This is an EOF indicator for the lexer
   538      // which we could handle, but isn't important for any known dialect.
   539      case '\0':
   540        return false;
   541      case '<':
   542      case '[':
   543      case '(':
   544      case '{':
   545        nestedPunctuation.push_back(c);
   546        continue;
   547      case '-':
   548        // Treat `->` as a special token.
   549        if (!symName.empty() && symName.front() == '>') {
   550          symName = symName.drop_front();
   551          continue;
   552        }
   553        break;
   554      // Reject types with mismatched brackets.
   555      case '>':
   556        if (nestedPunctuation.pop_back_val() != '<')
   557          return false;
   558        break;
   559      case ']':
   560        if (nestedPunctuation.pop_back_val() != '[')
   561          return false;
   562        break;
   563      case ')':
   564        if (nestedPunctuation.pop_back_val() != '(')
   565          return false;
   566        break;
   567      case '}':
   568        if (nestedPunctuation.pop_back_val() != '{')
   569          return false;
   570        break;
   571      default:
   572        continue;
   573      }
   574  
   575      // We're done when the punctuation is fully matched.
   576    } while (!nestedPunctuation.empty());
   577  
   578    // If there were extra characters, then we failed.
   579    return symName.empty();
   580  }
   581  
   582  /// Print the given dialect symbol to the stream.
   583  static void printDialectSymbol(raw_ostream &os, StringRef symPrefix,
   584                                 StringRef dialectName, StringRef symString) {
   585    os << symPrefix << dialectName;
   586  
   587    // If this symbol name is simple enough, print it directly in pretty form,
   588    // otherwise, we print it as an escaped string.
   589    if (isDialectSymbolSimpleEnoughForPrettyForm(symString)) {
   590      os << '.' << symString;
   591      return;
   592    }
   593  
   594    // TODO: escape the symbol name, it could contain " characters.
   595    os << "<\"" << symString << "\">";
   596  }
   597  
   598  void ModulePrinter::printAttribute(Attribute attr, bool mayElideType) {
   599    if (!attr) {
   600      os << "<<NULL ATTRIBUTE>>";
   601      return;
   602    }
   603  
   604    // Check for an alias for this attribute.
   605    if (state) {
   606      Twine alias = state->getAttributeAlias(attr);
   607      if (!alias.isTriviallyEmpty()) {
   608        os << '#' << alias;
   609        return;
   610      }
   611    }
   612  
   613    switch (attr.getKind()) {
   614    default: {
   615      auto &dialect = attr.getDialect();
   616  
   617      // Ask the dialect to serialize the attribute to a string.
   618      std::string attrName;
   619      {
   620        llvm::raw_string_ostream attrNameStr(attrName);
   621        dialect.printAttribute(attr, attrNameStr);
   622      }
   623  
   624      printDialectSymbol(os, "#", dialect.getNamespace(), attrName);
   625      break;
   626    }
   627    case StandardAttributes::Opaque: {
   628      auto opaqueAttr = attr.cast<OpaqueAttr>();
   629      printDialectSymbol(os, "#", opaqueAttr.getDialectNamespace(),
   630                         opaqueAttr.getAttrData());
   631      break;
   632    }
   633    case StandardAttributes::Unit:
   634      os << "unit";
   635      break;
   636    case StandardAttributes::Bool:
   637      os << (attr.cast<BoolAttr>().getValue() ? "true" : "false");
   638  
   639      // BoolAttr always elides the type.
   640      return;
   641    case StandardAttributes::Dictionary:
   642      os << '{';
   643      interleaveComma(attr.cast<DictionaryAttr>().getValue(),
   644                      [&](NamedAttribute attr) {
   645                        os << attr.first << " = ";
   646                        printAttribute(attr.second);
   647                      });
   648      os << '}';
   649      break;
   650    case StandardAttributes::Integer: {
   651      auto intAttr = attr.cast<IntegerAttr>();
   652      // Print all integer attributes as signed unless i1.
   653      bool isSigned = intAttr.getType().isIndex() ||
   654                      intAttr.getType().getIntOrFloatBitWidth() != 1;
   655      intAttr.getValue().print(os, isSigned);
   656  
   657      // IntegerAttr elides the type if I64.
   658      if (mayElideType && intAttr.getType().isInteger(64))
   659        return;
   660      break;
   661    }
   662    case StandardAttributes::Float: {
   663      auto floatAttr = attr.cast<FloatAttr>();
   664      printFloatValue(floatAttr.getValue(), os);
   665  
   666      // FloatAttr elides the type if F64.
   667      if (mayElideType && floatAttr.getType().isF64())
   668        return;
   669      break;
   670    }
   671    case StandardAttributes::String:
   672      os << '"';
   673      printEscapedString(attr.cast<StringAttr>().getValue(), os);
   674      os << '"';
   675      break;
   676    case StandardAttributes::Array:
   677      os << '[';
   678      interleaveComma(attr.cast<ArrayAttr>().getValue(), [&](Attribute attr) {
   679        printAttribute(attr, /*mayElideType=*/true);
   680      });
   681      os << ']';
   682      break;
   683    case StandardAttributes::AffineMap:
   684      attr.cast<AffineMapAttr>().getValue().print(os);
   685  
   686      // AffineMap always elides the type.
   687      return;
   688    case StandardAttributes::IntegerSet:
   689      attr.cast<IntegerSetAttr>().getValue().print(os);
   690      break;
   691    case StandardAttributes::Type:
   692      printType(attr.cast<TypeAttr>().getValue());
   693      break;
   694    case StandardAttributes::SymbolRef:
   695      os << '@' << attr.cast<SymbolRefAttr>().getValue();
   696      break;
   697    case StandardAttributes::OpaqueElements: {
   698      auto eltsAttr = attr.cast<OpaqueElementsAttr>();
   699      os << "opaque<\"" << eltsAttr.getDialect()->getNamespace() << "\", ";
   700      os << '"' << "0x" << llvm::toHex(eltsAttr.getValue()) << "\">";
   701      break;
   702    }
   703    case StandardAttributes::DenseElements: {
   704      auto eltsAttr = attr.cast<DenseElementsAttr>();
   705      os << "dense<";
   706      printDenseElementsAttr(eltsAttr);
   707      os << '>';
   708      break;
   709    }
   710    case StandardAttributes::SparseElements: {
   711      auto elementsAttr = attr.cast<SparseElementsAttr>();
   712      os << "sparse<";
   713      printDenseElementsAttr(elementsAttr.getIndices());
   714      os << ", ";
   715      printDenseElementsAttr(elementsAttr.getValues());
   716      os << '>';
   717      break;
   718    }
   719  
   720    // Location attributes.
   721    case StandardAttributes::CallSiteLocation:
   722    case StandardAttributes::FileLineColLocation:
   723    case StandardAttributes::FusedLocation:
   724    case StandardAttributes::NameLocation:
   725    case StandardAttributes::UnknownLocation:
   726      printLocation(attr.cast<LocationAttr>());
   727      break;
   728    }
   729  
   730    // Print the type if it isn't a 'none' type.
   731    auto attrType = attr.getType();
   732    if (!attrType.isa<NoneType>()) {
   733      os << " : ";
   734      printType(attrType);
   735    }
   736  }
   737  
   738  /// Print the integer element of the given DenseElementsAttr at 'index'.
   739  static void printDenseIntElement(DenseElementsAttr attr, raw_ostream &os,
   740                                   unsigned index) {
   741    APInt value = *std::next(attr.int_value_begin(), index);
   742    if (value.getBitWidth() == 1)
   743      os << (value.getBoolValue() ? "true" : "false");
   744    else
   745      value.print(os, /*isSigned=*/true);
   746  }
   747  
   748  /// Print the float element of the given DenseElementsAttr at 'index'.
   749  static void printDenseFloatElement(DenseElementsAttr attr, raw_ostream &os,
   750                                     unsigned index) {
   751    APFloat value = *std::next(attr.float_value_begin(), index);
   752    printFloatValue(value, os);
   753  }
   754  
/// Print the element list of a dense elements attribute.
///
/// Splat attributes print a single element. Attributes with zero elements
/// print 'rank' opening brackets followed by 'rank' closing brackets.
/// Otherwise every element is printed in order, separated by ", ", and nested
/// in brackets that mirror the attribute's shape.
void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr) {
  auto type = attr.getType();
  auto shape = type.getShape();
  auto rank = type.getRank();

  // The function used to print elements of this attribute.
  // Integer element types use the integer printer; all other element types are
  // printed as floats.
  auto printEltFn = type.getElementType().isa<IntegerType>()
                        ? printDenseIntElement
                        : printDenseFloatElement;

  // Special case for 0-d and splat tensors.
  if (attr.isSplat()) {
    printEltFn(attr, os, 0);
    return;
  }

  // Special case for degenerate tensors.
  // With no elements we only emit the bracket nesting, e.g. "[[]]" for rank 2.
  auto numElements = type.getNumElements();
  if (numElements == 0) {
    for (int i = 0; i < rank; ++i)
      os << '[';
    for (int i = 0; i < rank; ++i)
      os << ']';
    return;
  }

  // We use a mixed-radix counter to iterate through the shape. When we bump a
  // non-least-significant digit, we emit a close bracket. When we next emit an
  // element we re-open all closed brackets.

  // The mixed-radix counter, with radices in 'shape'.
  SmallVector<unsigned, 4> counter(rank, 0);
  // The number of brackets that have been opened and not closed.
  unsigned openBrackets = 0;

  // Advance the counter by one element, closing one bracket for every
  // dimension (other than the outermost) that rolls over.
  auto bumpCounter = [&]() {
    // Bump the least significant digit.
    ++counter[rank - 1];
    // Iterate backwards bubbling back the increment.
    for (unsigned i = rank - 1; i > 0; --i)
      if (counter[i] >= shape[i]) {
        // Index 'i' is rolled over. Bump (i-1) and close a bracket.
        counter[i] = 0;
        ++counter[i - 1];
        --openBrackets;
        os << ']';
      }
  };

  for (unsigned idx = 0, e = numElements; idx != e; ++idx) {
    if (idx != 0)
      os << ", ";
    // Re-open any brackets closed by the previous bump so that exactly 'rank'
    // brackets are open before the element is printed.
    while (openBrackets++ < rank)
      os << '[';
    openBrackets = rank;
    printEltFn(attr, os, idx);
    bumpCounter();
  }
  // Close whatever brackets remain open after the last element.
  while (openBrackets-- > 0)
    os << ']';
}
   816  
   817  void ModulePrinter::printType(Type type) {
   818    // Check for an alias for this type.
   819    if (state) {
   820      StringRef alias = state->getTypeAlias(type);
   821      if (!alias.empty()) {
   822        os << '!' << alias;
   823        return;
   824      }
   825    }
   826  
   827    switch (type.getKind()) {
   828    default: {
   829      auto &dialect = type.getDialect();
   830  
   831      // Ask the dialect to serialize the type to a string.
   832      std::string typeName;
   833      {
   834        llvm::raw_string_ostream typeNameStr(typeName);
   835        dialect.printType(type, typeNameStr);
   836      }
   837  
   838      printDialectSymbol(os, "!", dialect.getNamespace(), typeName);
   839      return;
   840    }
   841    case Type::Kind::Opaque: {
   842      auto opaqueTy = type.cast<OpaqueType>();
   843      printDialectSymbol(os, "!", opaqueTy.getDialectNamespace(),
   844                         opaqueTy.getTypeData());
   845      return;
   846    }
   847    case StandardTypes::Index:
   848      os << "index";
   849      return;
   850    case StandardTypes::BF16:
   851      os << "bf16";
   852      return;
   853    case StandardTypes::F16:
   854      os << "f16";
   855      return;
   856    case StandardTypes::F32:
   857      os << "f32";
   858      return;
   859    case StandardTypes::F64:
   860      os << "f64";
   861      return;
   862  
   863    case StandardTypes::Integer: {
   864      auto integer = type.cast<IntegerType>();
   865      os << 'i' << integer.getWidth();
   866      return;
   867    }
   868    case Type::Kind::Function: {
   869      auto func = type.cast<FunctionType>();
   870      os << '(';
   871      interleaveComma(func.getInputs(), [&](Type type) { printType(type); });
   872      os << ") -> ";
   873      auto results = func.getResults();
   874      if (results.size() == 1 && !results[0].isa<FunctionType>())
   875        os << results[0];
   876      else {
   877        os << '(';
   878        interleaveComma(results, [&](Type type) { printType(type); });
   879        os << ')';
   880      }
   881      return;
   882    }
   883    case StandardTypes::Vector: {
   884      auto v = type.cast<VectorType>();
   885      os << "vector<";
   886      for (auto dim : v.getShape())
   887        os << dim << 'x';
   888      os << v.getElementType() << '>';
   889      return;
   890    }
   891    case StandardTypes::RankedTensor: {
   892      auto v = type.cast<RankedTensorType>();
   893      os << "tensor<";
   894      for (auto dim : v.getShape()) {
   895        if (dim < 0)
   896          os << '?';
   897        else
   898          os << dim;
   899        os << 'x';
   900      }
   901      os << v.getElementType() << '>';
   902      return;
   903    }
   904    case StandardTypes::UnrankedTensor: {
   905      auto v = type.cast<UnrankedTensorType>();
   906      os << "tensor<*x";
   907      printType(v.getElementType());
   908      os << '>';
   909      return;
   910    }
   911    case StandardTypes::MemRef: {
   912      auto v = type.cast<MemRefType>();
   913      os << "memref<";
   914      for (auto dim : v.getShape()) {
   915        if (dim < 0)
   916          os << '?';
   917        else
   918          os << dim;
   919        os << 'x';
   920      }
   921      printType(v.getElementType());
   922      for (auto map : v.getAffineMaps()) {
   923        os << ", ";
   924        printAttribute(AffineMapAttr::get(map));
   925      }
   926      // Only print the memory space if it is the non-default one.
   927      if (v.getMemorySpace())
   928        os << ", " << v.getMemorySpace();
   929      os << '>';
   930      return;
   931    }
   932    case StandardTypes::Complex:
   933      os << "complex<";
   934      printType(type.cast<ComplexType>().getElementType());
   935      os << '>';
   936      return;
   937    case StandardTypes::Tuple: {
   938      auto tuple = type.cast<TupleType>();
   939      os << "tuple<";
   940      interleaveComma(tuple.getTypes(), [&](Type type) { printType(type); });
   941      os << '>';
   942      return;
   943    }
   944    case StandardTypes::None:
   945      os << "none";
   946      return;
   947    }
   948  }
   949  
   950  //===----------------------------------------------------------------------===//
   951  // Affine expressions and maps
   952  //===----------------------------------------------------------------------===//
   953  
   954  void ModulePrinter::printAffineExpr(
   955      AffineExpr expr, llvm::function_ref<void(unsigned, bool)> printValueName) {
   956    printAffineExprInternal(expr, BindingStrength::Weak, printValueName);
   957  }
   958  
   959  void ModulePrinter::printAffineExprInternal(
   960      AffineExpr expr, BindingStrength enclosingTightness,
   961      llvm::function_ref<void(unsigned, bool)> printValueName) {
   962    const char *binopSpelling = nullptr;
   963    switch (expr.getKind()) {
   964    case AffineExprKind::SymbolId: {
   965      unsigned pos = expr.cast<AffineSymbolExpr>().getPosition();
   966      if (printValueName)
   967        printValueName(pos, /*isSymbol=*/true);
   968      else
   969        os << 's' << pos;
   970      return;
   971    }
   972    case AffineExprKind::DimId: {
   973      unsigned pos = expr.cast<AffineDimExpr>().getPosition();
   974      if (printValueName)
   975        printValueName(pos, /*isSymbol=*/false);
   976      else
   977        os << 'd' << pos;
   978      return;
   979    }
   980    case AffineExprKind::Constant:
   981      os << expr.cast<AffineConstantExpr>().getValue();
   982      return;
   983    case AffineExprKind::Add:
   984      binopSpelling = " + ";
   985      break;
   986    case AffineExprKind::Mul:
   987      binopSpelling = " * ";
   988      break;
   989    case AffineExprKind::FloorDiv:
   990      binopSpelling = " floordiv ";
   991      break;
   992    case AffineExprKind::CeilDiv:
   993      binopSpelling = " ceildiv ";
   994      break;
   995    case AffineExprKind::Mod:
   996      binopSpelling = " mod ";
   997      break;
   998    }
   999  
  1000    auto binOp = expr.cast<AffineBinaryOpExpr>();
  1001    AffineExpr lhsExpr = binOp.getLHS();
  1002    AffineExpr rhsExpr = binOp.getRHS();
  1003  
  1004    // Handle tightly binding binary operators.
  1005    if (binOp.getKind() != AffineExprKind::Add) {
  1006      if (enclosingTightness == BindingStrength::Strong)
  1007        os << '(';
  1008  
  1009      // Pretty print multiplication with -1.
  1010      auto rhsConst = rhsExpr.dyn_cast<AffineConstantExpr>();
  1011      if (rhsConst && rhsConst.getValue() == -1) {
  1012        os << "-";
  1013        printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName);
  1014        return;
  1015      }
  1016  
  1017      printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName);
  1018  
  1019      os << binopSpelling;
  1020      printAffineExprInternal(rhsExpr, BindingStrength::Strong, printValueName);
  1021  
  1022      if (enclosingTightness == BindingStrength::Strong)
  1023        os << ')';
  1024      return;
  1025    }
  1026  
  1027    // Print out special "pretty" forms for add.
  1028    if (enclosingTightness == BindingStrength::Strong)
  1029      os << '(';
  1030  
  1031    // Pretty print addition to a product that has a negative operand as a
  1032    // subtraction.
  1033    if (auto rhs = rhsExpr.dyn_cast<AffineBinaryOpExpr>()) {
  1034      if (rhs.getKind() == AffineExprKind::Mul) {
  1035        AffineExpr rrhsExpr = rhs.getRHS();
  1036        if (auto rrhs = rrhsExpr.dyn_cast<AffineConstantExpr>()) {
  1037          if (rrhs.getValue() == -1) {
  1038            printAffineExprInternal(lhsExpr, BindingStrength::Weak,
  1039                                    printValueName);
  1040            os << " - ";
  1041            if (rhs.getLHS().getKind() == AffineExprKind::Add) {
  1042              printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong,
  1043                                      printValueName);
  1044            } else {
  1045              printAffineExprInternal(rhs.getLHS(), BindingStrength::Weak,
  1046                                      printValueName);
  1047            }
  1048  
  1049            if (enclosingTightness == BindingStrength::Strong)
  1050              os << ')';
  1051            return;
  1052          }
  1053  
  1054          if (rrhs.getValue() < -1) {
  1055            printAffineExprInternal(lhsExpr, BindingStrength::Weak,
  1056                                    printValueName);
  1057            os << " - ";
  1058            printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong,
  1059                                    printValueName);
  1060            os << " * " << -rrhs.getValue();
  1061            if (enclosingTightness == BindingStrength::Strong)
  1062              os << ')';
  1063            return;
  1064          }
  1065        }
  1066      }
  1067    }
  1068  
  1069    // Pretty print addition to a negative number as a subtraction.
  1070    if (auto rhsConst = rhsExpr.dyn_cast<AffineConstantExpr>()) {
  1071      if (rhsConst.getValue() < 0) {
  1072        printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName);
  1073        os << " - " << -rhsConst.getValue();
  1074        if (enclosingTightness == BindingStrength::Strong)
  1075          os << ')';
  1076        return;
  1077      }
  1078    }
  1079  
  1080    printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName);
  1081  
  1082    os << " + ";
  1083    printAffineExprInternal(rhsExpr, BindingStrength::Weak, printValueName);
  1084  
  1085    if (enclosingTightness == BindingStrength::Strong)
  1086      os << ')';
  1087  }
  1088  
  1089  void ModulePrinter::printAffineConstraint(AffineExpr expr, bool isEq) {
  1090    printAffineExprInternal(expr, BindingStrength::Weak);
  1091    isEq ? os << " == 0" : os << " >= 0";
  1092  }
  1093  
  1094  void ModulePrinter::printAffineMap(AffineMap map) {
  1095    // Dimension identifiers.
  1096    os << '(';
  1097    for (int i = 0; i < (int)map.getNumDims() - 1; ++i)
  1098      os << 'd' << i << ", ";
  1099    if (map.getNumDims() >= 1)
  1100      os << 'd' << map.getNumDims() - 1;
  1101    os << ')';
  1102  
  1103    // Symbolic identifiers.
  1104    if (map.getNumSymbols() != 0) {
  1105      os << '[';
  1106      for (unsigned i = 0; i < map.getNumSymbols() - 1; ++i)
  1107        os << 's' << i << ", ";
  1108      if (map.getNumSymbols() >= 1)
  1109        os << 's' << map.getNumSymbols() - 1;
  1110      os << ']';
  1111    }
  1112  
  1113    // Result affine expressions.
  1114    os << " -> (";
  1115    interleaveComma(map.getResults(),
  1116                    [&](AffineExpr expr) { printAffineExpr(expr); });
  1117    os << ')';
  1118  }
  1119  
  1120  void ModulePrinter::printIntegerSet(IntegerSet set) {
  1121    // Dimension identifiers.
  1122    os << '(';
  1123    for (unsigned i = 1; i < set.getNumDims(); ++i)
  1124      os << 'd' << i - 1 << ", ";
  1125    if (set.getNumDims() >= 1)
  1126      os << 'd' << set.getNumDims() - 1;
  1127    os << ')';
  1128  
  1129    // Symbolic identifiers.
  1130    if (set.getNumSymbols() != 0) {
  1131      os << '[';
  1132      for (unsigned i = 0; i < set.getNumSymbols() - 1; ++i)
  1133        os << 's' << i << ", ";
  1134      if (set.getNumSymbols() >= 1)
  1135        os << 's' << set.getNumSymbols() - 1;
  1136      os << ']';
  1137    }
  1138  
  1139    // Print constraints.
  1140    os << " : (";
  1141    int numConstraints = set.getNumConstraints();
  1142    for (int i = 1; i < numConstraints; ++i) {
  1143      printAffineConstraint(set.getConstraint(i - 1), set.isEq(i - 1));
  1144      os << ", ";
  1145    }
  1146    if (numConstraints >= 1)
  1147      printAffineConstraint(set.getConstraint(numConstraints - 1),
  1148                            set.isEq(numConstraints - 1));
  1149    os << ')';
  1150  }
  1151  
  1152  //===----------------------------------------------------------------------===//
  1153  // Operation printing
  1154  //===----------------------------------------------------------------------===//
  1155  
  1156  void ModulePrinter::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
  1157                                            ArrayRef<StringRef> elidedAttrs) {
  1158    // If there are no attributes, then there is nothing to be done.
  1159    if (attrs.empty())
  1160      return;
  1161  
  1162    // Filter out any attributes that shouldn't be included.
  1163    SmallVector<NamedAttribute, 8> filteredAttrs;
  1164    for (auto attr : attrs) {
  1165      // If the caller has requested that this attribute be ignored, then drop it.
  1166      if (llvm::any_of(elidedAttrs,
  1167                       [&](StringRef elided) { return attr.first.is(elided); }))
  1168        continue;
  1169  
  1170      // Otherwise add it to our filteredAttrs list.
  1171      filteredAttrs.push_back(attr);
  1172    }
  1173  
  1174    // If there are no attributes left to print after filtering, then we're done.
  1175    if (filteredAttrs.empty())
  1176      return;
  1177  
  1178    // Otherwise, print them all out in braces.
  1179    os << " {";
  1180    interleaveComma(filteredAttrs, [&](NamedAttribute attr) {
  1181      os << attr.first;
  1182  
  1183      // Pretty printing elides the attribute value for unit attributes.
  1184      if (attr.second.isa<UnitAttr>())
  1185        return;
  1186  
  1187      os << " = ";
  1188      printAttribute(attr.second);
  1189    });
  1190    os << '}';
  1191  }
  1192  
  1193  namespace {
  1194  
  1195  // OperationPrinter contains common functionality for printing operations.
  1196  class OperationPrinter : public ModulePrinter, private OpAsmPrinter {
  1197  public:
  1198    OperationPrinter(Operation *op, ModulePrinter &other);
  1199    OperationPrinter(Region *region, ModulePrinter &other);
  1200  
  1201    // Methods to print operations.
  1202    void print(Operation *op);
  1203    void print(Block *block, bool printBlockArgs = true,
  1204               bool printBlockTerminator = true);
  1205  
  1206    void printOperation(Operation *op);
  1207    void printGenericOp(Operation *op) override;
  1208  
  1209    // Implement OpAsmPrinter.
  1210    raw_ostream &getStream() const override { return os; }
  1211    void printType(Type type) override { ModulePrinter::printType(type); }
  1212    void printAttribute(Attribute attr) override {
  1213      ModulePrinter::printAttribute(attr);
  1214    }
  1215    void printOperand(Value *value) override { printValueID(value); }
  1216  
  1217    void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
  1218                               ArrayRef<StringRef> elidedAttrs = {}) override {
  1219      return ModulePrinter::printOptionalAttrDict(attrs, elidedAttrs);
  1220    };
  1221  
  1222    enum { nameSentinel = ~0U };
  1223  
  1224    void printBlockName(Block *block) {
  1225      auto id = getBlockID(block);
  1226      if (id != ~0U)
  1227        os << "^bb" << id;
  1228      else
  1229        os << "^INVALIDBLOCK";
  1230    }
  1231  
  1232    unsigned getBlockID(Block *block) {
  1233      auto it = blockIDs.find(block);
  1234      return it != blockIDs.end() ? it->second : ~0U;
  1235    }
  1236  
  1237    void printSuccessorAndUseList(Operation *term, unsigned index) override;
  1238  
  1239    /// Print a region.
  1240    void printRegion(Region &blocks, bool printEntryBlockArgs,
  1241                     bool printBlockTerminators) override {
  1242      os << " {\n";
  1243      if (!blocks.empty()) {
  1244        auto *entryBlock = &blocks.front();
  1245        print(entryBlock,
  1246              printEntryBlockArgs && entryBlock->getNumArguments() != 0,
  1247              printBlockTerminators);
  1248        for (auto &b : llvm::drop_begin(blocks.getBlocks(), 1))
  1249          print(&b);
  1250      }
  1251      os.indent(currentIndent) << "}";
  1252    }
  1253  
  1254    /// Renumber the arguments for the specified region to the same names as the
  1255    /// SSA values in namesToUse.  This may only be used for IsolatedFromAbove
  1256    /// operations.  If any entry in namesToUse is null, the corresponding
  1257    /// argument name is left alone.
  1258    void shadowRegionArgs(Region &region, ArrayRef<Value *> namesToUse) override;
  1259  
  1260    void printAffineMapOfSSAIds(AffineMapAttr mapAttr,
  1261                                ArrayRef<Value *> operands) override {
  1262      AffineMap map = mapAttr.getValue();
  1263      unsigned numDims = map.getNumDims();
  1264      auto printValueName = [&](unsigned pos, bool isSymbol) {
  1265        unsigned index = isSymbol ? numDims + pos : pos;
  1266        assert(index < operands.size());
  1267        if (isSymbol)
  1268          os << "symbol(";
  1269        printValueID(operands[index]);
  1270        if (isSymbol)
  1271          os << ')';
  1272      };
  1273  
  1274      interleaveComma(map.getResults(), [&](AffineExpr expr) {
  1275        printAffineExpr(expr, printValueName);
  1276      });
  1277    }
  1278  
  1279    // Number of spaces used for indenting nested operations.
  1280    const static unsigned indentWidth = 2;
  1281  
  1282  protected:
  1283    void numberValueID(Value *value);
  1284    void numberValuesInRegion(Region &region);
  1285    void numberValuesInBlock(Block &block);
  1286    void printValueID(Value *value, bool printResultNo = true) const {
  1287      printValueIDImpl(value, printResultNo, os);
  1288    }
  1289  
  1290  private:
  1291    void printValueIDImpl(Value *value, bool printResultNo,
  1292                          raw_ostream &stream) const;
  1293  
  1294    /// Uniques the given value name within the printer. If the given name
  1295    /// conflicts, it is automatically renamed.
  1296    StringRef uniqueValueName(StringRef name);
  1297  
  1298    /// This is the value ID for each SSA value. If this returns ~0, then the
  1299    /// valueID has an entry in valueNames.
  1300    DenseMap<Value *, unsigned> valueIDs;
  1301    DenseMap<Value *, StringRef> valueNames;
  1302  
  1303    /// This is the block ID for each block in the current.
  1304    DenseMap<Block *, unsigned> blockIDs;
  1305  
  1306    /// This keeps track of all of the non-numeric names that are in flight,
  1307    /// allowing us to check for duplicates.
  1308    /// Note: the value of the map is unused.
  1309    llvm::ScopedHashTable<StringRef, char> usedNames;
  1310    llvm::BumpPtrAllocator usedNameAllocator;
  1311  
  1312    // This is the current indentation level for nested structures.
  1313    unsigned currentIndent = 0;
  1314  
  1315    /// This is the next value ID to assign in numbering.
  1316    unsigned nextValueID = 0;
  1317    /// This is the next ID to assign to a region entry block argument.
  1318    unsigned nextArgumentID = 0;
  1319    /// This is the next ID to assign when a name conflict is detected.
  1320    unsigned nextConflictID = 0;
  1321  };
  1322  } // end anonymous namespace
  1323  
  1324  OperationPrinter::OperationPrinter(Operation *op, ModulePrinter &other)
  1325      : ModulePrinter(other) {
  1326    if (op->getNumResults() != 0)
  1327      numberValueID(op->getResult(0));
  1328    for (auto &region : op->getRegions())
  1329      numberValuesInRegion(region);
  1330  }
  1331  
  1332  OperationPrinter::OperationPrinter(Region *region, ModulePrinter &other)
  1333      : ModulePrinter(other) {
  1334    numberValuesInRegion(*region);
  1335  }
  1336  
  1337  /// Number all of the SSA values in the specified region.
  1338  void OperationPrinter::numberValuesInRegion(Region &region) {
  1339    // Save the current value ids to allow for numbering values in sibling regions
  1340    // the same.
  1341    unsigned curValueID = nextValueID;
  1342    unsigned curArgumentID = nextArgumentID;
  1343    unsigned curConflictID = nextConflictID;
  1344  
  1345    // Push a new used names scope.
  1346    llvm::ScopedHashTable<StringRef, char>::ScopeTy usedNamesScope(usedNames);
  1347  
  1348    // Number the values within this region in a breadth-first order.
  1349    unsigned nextBlockID = 0;
  1350    for (auto &block : region) {
  1351      // Each block gets a unique ID, and all of the operations within it get
  1352      // numbered as well.
  1353      blockIDs[&block] = nextBlockID++;
  1354      numberValuesInBlock(block);
  1355    }
  1356  
  1357    // After that we traverse the nested regions.
  1358    // TODO: Rework this loop to not use recursion.
  1359    for (auto &block : region) {
  1360      for (auto &op : block)
  1361        for (auto &nestedRegion : op.getRegions())
  1362          numberValuesInRegion(nestedRegion);
  1363    }
  1364  
  1365    // Restore the original value ids.
  1366    nextValueID = curValueID;
  1367    nextArgumentID = curArgumentID;
  1368    nextConflictID = curConflictID;
  1369  }
  1370  
  1371  /// Number all of the SSA values in the specified block, without traversing
  1372  /// nested regions.
  1373  void OperationPrinter::numberValuesInBlock(Block &block) {
  1374    // Number the block arguments.
  1375    for (auto *arg : block.getArguments())
  1376      numberValueID(arg);
  1377  
  1378    // We number operation that have results, and we only number the first result.
  1379    for (auto &op : block)
  1380      if (op.getNumResults() != 0)
  1381        numberValueID(op.getResult(0));
  1382  }
  1383  
/// Assign a printable identity to 'value'.
///
/// If the dialect's OpAsmInterface supplies a result name, that name is used.
/// Otherwise an argument of a region's entry block is named 'argN'. In both
/// cases the name is uniqued and stored in 'valueNames', and 'valueIDs'
/// records nameSentinel. Any other value gets the next sequential number in
/// 'valueIDs'.
void OperationPrinter::numberValueID(Value *value) {
  assert(!valueIDs.count(value) && "Value numbered multiple times");

  SmallString<32> specialNameBuffer;
  llvm::raw_svector_ostream specialName(specialNameBuffer);

  // Check to see if this value requested a special name.
  // The dialect is only consulted when module state is available and the
  // value has a defining operation.
  auto *op = value->getDefiningOp();
  if (state && op) {
    if (auto *interface = state->getOpAsmInterface(op->getDialect()))
      interface->getOpResultName(op, specialName);
  }

  if (specialNameBuffer.empty()) {
    switch (value->getKind()) {
    case Value::Kind::BlockArgument:
      // If this is an argument to the entry block of a region, give it an 'arg'
      // name.
      if (auto *block = cast<BlockArgument>(value)->getOwner()) {
        auto *parentRegion = block->getParent();
        if (parentRegion && block == &parentRegion->front()) {
          specialName << "arg" << nextArgumentID++;
          // Leave the switch and fall into the named-value path below.
          break;
        }
      }
      // Otherwise number it normally.
      valueIDs[value] = nextValueID++;
      return;
    case Value::Kind::OpResult:
      // This is an uninteresting result, give it a boring number and be
      // done with it.
      valueIDs[value] = nextValueID++;
      return;
    }
  }

  // Ok, this value had an interesting name.  Remember it with a sentinel.
  valueIDs[value] = nameSentinel;
  valueNames[value] = uniqueValueName(specialName.str());
}
  1424  
  1425  /// Uniques the given value name within the printer. If the given name
  1426  /// conflicts, it is automatically renamed.
  1427  StringRef OperationPrinter::uniqueValueName(StringRef name) {
  1428    // Check to see if this name is already unique.
  1429    if (!usedNames.count(name)) {
  1430      name = name.copy(usedNameAllocator);
  1431    } else {
  1432      // Otherwise, we had a conflict - probe until we find a unique name. This
  1433      // is guaranteed to terminate (and usually in a single iteration) because it
  1434      // generates new names by incrementing nextConflictID.
  1435      SmallString<64> probeName(name);
  1436      probeName.push_back('_');
  1437      while (1) {
  1438        probeName.resize(name.size() + 1);
  1439        probeName += llvm::utostr(nextConflictID++);
  1440        if (!usedNames.count(probeName)) {
  1441          name = StringRef(probeName).copy(usedNameAllocator);
  1442          break;
  1443        }
  1444      }
  1445    }
  1446  
  1447    usedNames.insert(name, char());
  1448    return name;
  1449  }
  1450  
  1451  void OperationPrinter::print(Block *block, bool printBlockArgs,
  1452                               bool printBlockTerminator) {
  1453    // Print the block label and argument list if requested.
  1454    if (printBlockArgs) {
  1455      os.indent(currentIndent);
  1456      printBlockName(block);
  1457  
  1458      // Print the argument list if non-empty.
  1459      if (!block->args_empty()) {
  1460        os << '(';
  1461        interleaveComma(block->getArguments(), [&](BlockArgument *arg) {
  1462          printValueID(arg);
  1463          os << ": ";
  1464          printType(arg->getType());
  1465        });
  1466        os << ')';
  1467      }
  1468      os << ':';
  1469  
  1470      // Print out some context information about the predecessors of this block.
  1471      if (!block->getParent()) {
  1472        os << "\t// block is not in a region!";
  1473      } else if (block->hasNoPredecessors()) {
  1474        os << "\t// no predecessors";
  1475      } else if (auto *pred = block->getSinglePredecessor()) {
  1476        os << "\t// pred: ";
  1477        printBlockName(pred);
  1478      } else {
  1479        // We want to print the predecessors in increasing numeric order, not in
  1480        // whatever order the use-list is in, so gather and sort them.
  1481        SmallVector<std::pair<unsigned, Block *>, 4> predIDs;
  1482        for (auto *pred : block->getPredecessors())
  1483          predIDs.push_back({getBlockID(pred), pred});
  1484        llvm::array_pod_sort(predIDs.begin(), predIDs.end());
  1485  
  1486        os << "\t// " << predIDs.size() << " preds: ";
  1487  
  1488        interleaveComma(predIDs, [&](std::pair<unsigned, Block *> pred) {
  1489          printBlockName(pred.second);
  1490        });
  1491      }
  1492      os << '\n';
  1493    }
  1494  
  1495    currentIndent += indentWidth;
  1496    auto range = llvm::make_range(
  1497        block->getOperations().begin(),
  1498        std::prev(block->getOperations().end(), printBlockTerminator ? 0 : 1));
  1499    for (auto &op : range) {
  1500      print(&op);
  1501      os << '\n';
  1502    }
  1503    currentIndent -= indentWidth;
  1504  }
  1505  
  1506  void OperationPrinter::print(Operation *op) {
  1507    os.indent(currentIndent);
  1508    printOperation(op);
  1509    printTrailingLocation(op->getLoc());
  1510  }
  1511  
  1512  void OperationPrinter::printValueIDImpl(Value *value, bool printResultNo,
  1513                                          raw_ostream &stream) const {
  1514    if (!value) {
  1515      stream << "<<NULL>>";
  1516      return;
  1517    }
  1518  
  1519    int resultNo = -1;
  1520    auto lookupValue = value;
  1521  
  1522    // If this is a reference to the result of a multi-result operation or
  1523    // operation, print out the # identifier and make sure to map our lookup
  1524    // to the first result of the operation.
  1525    if (auto *result = dyn_cast<OpResult>(value)) {
  1526      if (result->getOwner()->getNumResults() != 1) {
  1527        resultNo = result->getResultNumber();
  1528        lookupValue = result->getOwner()->getResult(0);
  1529      }
  1530    }
  1531  
  1532    auto it = valueIDs.find(lookupValue);
  1533    if (it == valueIDs.end()) {
  1534      stream << "<<INVALID SSA VALUE>>";
  1535      return;
  1536    }
  1537  
  1538    stream << '%';
  1539    if (it->second != nameSentinel) {
  1540      stream << it->second;
  1541    } else {
  1542      auto nameIt = valueNames.find(lookupValue);
  1543      assert(nameIt != valueNames.end() && "Didn't have a name entry?");
  1544      stream << nameIt->second;
  1545    }
  1546  
  1547    if (resultNo != -1 && printResultNo)
  1548      stream << '#' << resultNo;
  1549  }
  1550  
  1551  /// Renumber the arguments for the specified region to the same names as the
  1552  /// SSA values in namesToUse.  This may only be used for IsolatedFromAbove
  1553  /// operations.  If any entry in namesToUse is null, the corresponding
  1554  /// argument name is left alone.
  1555  void OperationPrinter::shadowRegionArgs(Region &region,
  1556                                          ArrayRef<Value *> namesToUse) {
  1557    assert(!region.empty() && "cannot shadow arguments of an empty region");
  1558    assert(region.front().getNumArguments() == namesToUse.size() &&
  1559           "incorrect number of names passed in");
  1560    assert(region.getParentOp()->isKnownIsolatedFromAbove() &&
  1561           "only KnownIsolatedFromAbove ops can shadow names");
  1562  
  1563    SmallVector<char, 16> nameStr;
  1564    for (unsigned i = 0, e = namesToUse.size(); i != e; ++i) {
  1565      auto *nameToUse = namesToUse[i];
  1566      if (nameToUse == nullptr)
  1567        continue;
  1568  
  1569      auto *nameToReplace = region.front().getArgument(i);
  1570  
  1571      nameStr.clear();
  1572      llvm::raw_svector_ostream nameStream(nameStr);
  1573      printValueIDImpl(nameToUse, /*printResultNo=*/true, nameStream);
  1574  
  1575      // Entry block arguments should already have a pretty "arg" name.
  1576      assert(valueIDs[nameToReplace] == nameSentinel);
  1577  
  1578      // Use the name without the leading %.
  1579      auto name = StringRef(nameStream.str()).drop_front();
  1580  
  1581      // Overwrite the name.
  1582      valueNames[nameToReplace] = name.copy(usedNameAllocator);
  1583    }
  1584  }
  1585  
  1586  void OperationPrinter::printOperation(Operation *op) {
  1587    if (size_t numResults = op->getNumResults()) {
  1588      printValueID(op->getResult(0), /*printResultNo=*/false);
  1589      if (numResults > 1)
  1590        os << ':' << numResults;
  1591      os << " = ";
  1592    }
  1593  
  1594    // TODO(riverriddle): FuncOp cannot be round-tripped currently, as
  1595    // FunctionType cannot be used in a TypeAttr.
  1596    if (printGenericOpForm && !isa<FuncOp>(op))
  1597      return printGenericOp(op);
  1598  
  1599    // Check to see if this is a known operation.  If so, use the registered
  1600    // custom printer hook.
  1601    if (auto *opInfo = op->getAbstractOperation()) {
  1602      opInfo->printAssembly(op, this);
  1603      return;
  1604    }
  1605  
  1606    // Otherwise print with the generic assembly form.
  1607    printGenericOp(op);
  1608  }
  1609  
  1610  void OperationPrinter::printGenericOp(Operation *op) {
  1611    os << '"';
  1612    printEscapedString(op->getName().getStringRef(), os);
  1613    os << "\"(";
  1614  
  1615    // Get the list of operands that are not successor operands.
  1616    unsigned totalNumSuccessorOperands = 0;
  1617    unsigned numSuccessors = op->getNumSuccessors();
  1618    for (unsigned i = 0; i < numSuccessors; ++i)
  1619      totalNumSuccessorOperands += op->getNumSuccessorOperands(i);
  1620    unsigned numProperOperands = op->getNumOperands() - totalNumSuccessorOperands;
  1621    SmallVector<Value *, 8> properOperands(
  1622        op->operand_begin(), std::next(op->operand_begin(), numProperOperands));
  1623  
  1624    interleaveComma(properOperands, [&](Value *value) { printValueID(value); });
  1625  
  1626    os << ')';
  1627  
  1628    // For terminators, print the list of successors and their operands.
  1629    if (numSuccessors != 0) {
  1630      os << '[';
  1631      for (unsigned i = 0; i < numSuccessors; ++i) {
  1632        if (i != 0)
  1633          os << ", ";
  1634        printSuccessorAndUseList(op, i);
  1635      }
  1636      os << ']';
  1637    }
  1638  
  1639    // Print regions.
  1640    if (op->getNumRegions() != 0) {
  1641      os << " (";
  1642      interleaveComma(op->getRegions(), [&](Region &region) {
  1643        printRegion(region, /*printEntryBlockArgs=*/true,
  1644                    /*printBlockTerminators=*/true);
  1645      });
  1646      os << ')';
  1647    }
  1648  
  1649    auto attrs = op->getAttrs();
  1650    printOptionalAttrDict(attrs);
  1651  
  1652    // Print the type signature of the operation.
  1653    os << " : ";
  1654    printFunctionalType(op);
  1655  }
  1656  
  1657  void OperationPrinter::printSuccessorAndUseList(Operation *term,
  1658                                                  unsigned index) {
  1659    printBlockName(term->getSuccessor(index));
  1660  
  1661    auto succOperands = term->getSuccessorOperands(index);
  1662    if (succOperands.begin() == succOperands.end())
  1663      return;
  1664  
  1665    os << '(';
  1666    interleaveComma(succOperands,
  1667                    [this](Value *operand) { printValueID(operand); });
  1668    os << " : ";
  1669    interleaveComma(succOperands,
  1670                    [this](Value *operand) { printType(operand->getType()); });
  1671    os << ')';
  1672  }
  1673  
  1674  void ModulePrinter::print(ModuleOp module) {
  1675    // Output the aliases at the top level.
  1676    if (state) {
  1677      state->printAttributeAliases(os);
  1678      state->printTypeAliases(os);
  1679    }
  1680  
  1681    // Print the module.
  1682    OperationPrinter(module, *this).print(module);
  1683    os << '\n';
  1684  }
  1685  
  1686  //===----------------------------------------------------------------------===//
  1687  // print and dump methods
  1688  //===----------------------------------------------------------------------===//
  1689  
  1690  void Attribute::print(raw_ostream &os) const {
  1691    ModulePrinter(os).printAttribute(*this);
  1692  }
  1693  
  1694  void Attribute::dump() const {
  1695    print(llvm::errs());
  1696    llvm::errs() << "\n";
  1697  }
  1698  
  1699  void Type::print(raw_ostream &os) { ModulePrinter(os).printType(*this); }
  1700  
  1701  void Type::dump() { print(llvm::errs()); }
  1702  
  1703  void AffineMap::dump() const {
  1704    print(llvm::errs());
  1705    llvm::errs() << "\n";
  1706  }
  1707  
  1708  void IntegerSet::dump() const {
  1709    print(llvm::errs());
  1710    llvm::errs() << "\n";
  1711  }
  1712  
  1713  void AffineExpr::print(raw_ostream &os) const {
  1714    if (expr == nullptr) {
  1715      os << "null affine expr";
  1716      return;
  1717    }
  1718    ModulePrinter(os).printAffineExpr(*this);
  1719  }
  1720  
  1721  void AffineExpr::dump() const {
  1722    print(llvm::errs());
  1723    llvm::errs() << "\n";
  1724  }
  1725  
  1726  void AffineMap::print(raw_ostream &os) const {
  1727    if (map == nullptr) {
  1728      os << "null affine map";
  1729      return;
  1730    }
  1731    ModulePrinter(os).printAffineMap(*this);
  1732  }
  1733  
  1734  void IntegerSet::print(raw_ostream &os) const {
  1735    ModulePrinter(os).printIntegerSet(*this);
  1736  }
  1737  
  1738  void Value::print(raw_ostream &os) {
  1739    switch (getKind()) {
  1740    case Value::Kind::BlockArgument:
  1741      // TODO: Improve this.
  1742      os << "<block argument>\n";
  1743      return;
  1744    case Value::Kind::OpResult:
  1745      return getDefiningOp()->print(os);
  1746    }
  1747  }
  1748  
  1749  void Value::dump() { print(llvm::errs()); }
  1750  
  1751  void Operation::print(raw_ostream &os) {
  1752    // Handle top-level operations.
  1753    if (!getParent()) {
  1754      ModulePrinter modulePrinter(os);
  1755      OperationPrinter(this, modulePrinter).print(this);
  1756      return;
  1757    }
  1758  
  1759    auto region = getParentRegion();
  1760    if (!region) {
  1761      os << "<<UNLINKED INSTRUCTION>>\n";
  1762      return;
  1763    }
  1764  
  1765    // Get the top-level region.
  1766    while (auto *nextRegion = region->getParentRegion())
  1767      region = nextRegion;
  1768  
  1769    ModuleState state(getContext());
  1770    ModulePrinter modulePrinter(os, &state);
  1771    OperationPrinter(region, modulePrinter).print(this);
  1772  }
  1773  
  1774  void Operation::dump() {
  1775    print(llvm::errs());
  1776    llvm::errs() << "\n";
  1777  }
  1778  
  1779  void Block::print(raw_ostream &os) {
  1780    auto region = getParent();
  1781    if (!region) {
  1782      os << "<<UNLINKED BLOCK>>\n";
  1783      return;
  1784    }
  1785  
  1786    // Get the top-level region.
  1787    while (auto *nextRegion = region->getParentRegion())
  1788      region = nextRegion;
  1789  
  1790    ModuleState state(region->getContext());
  1791    ModulePrinter modulePrinter(os, &state);
  1792    OperationPrinter(region, modulePrinter).print(this);
  1793  }
  1794  
  1795  void Block::dump() { print(llvm::errs()); }
  1796  
  1797  /// Print out the name of the block without printing its body.
  1798  void Block::printAsOperand(raw_ostream &os, bool printType) {
  1799    auto region = getParent();
  1800    if (!region) {
  1801      os << "<<UNLINKED BLOCK>>\n";
  1802      return;
  1803    }
  1804  
  1805    // Get the top-level region.
  1806    while (auto *nextRegion = region->getParentRegion())
  1807      region = nextRegion;
  1808  
  1809    ModulePrinter modulePrinter(os);
  1810    OperationPrinter(region, modulePrinter).printBlockName(this);
  1811  }
  1812  
  1813  void ModuleOp::print(raw_ostream &os) {
  1814    ModuleState state(getContext());
  1815    state.initialize(*this);
  1816    ModulePrinter(os, &state).print(*this);
  1817  }
  1818  
  1819  void ModuleOp::dump() { print(llvm::errs()); }