//===- AsmPrinter.cpp - MLIR Assembly Printer Implementation --------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements the MLIR AsmPrinter class, which is used to implement
// the various print() methods on the core IR objects.
20 // 21 //===----------------------------------------------------------------------===// 22 23 #include "mlir/IR/AffineExpr.h" 24 #include "mlir/IR/AffineMap.h" 25 #include "mlir/IR/Attributes.h" 26 #include "mlir/IR/Dialect.h" 27 #include "mlir/IR/Function.h" 28 #include "mlir/IR/IntegerSet.h" 29 #include "mlir/IR/MLIRContext.h" 30 #include "mlir/IR/Module.h" 31 #include "mlir/IR/OpImplementation.h" 32 #include "mlir/IR/Operation.h" 33 #include "mlir/IR/StandardTypes.h" 34 #include "mlir/Support/STLExtras.h" 35 #include "llvm/ADT/APFloat.h" 36 #include "llvm/ADT/DenseMap.h" 37 #include "llvm/ADT/MapVector.h" 38 #include "llvm/ADT/STLExtras.h" 39 #include "llvm/ADT/ScopedHashTable.h" 40 #include "llvm/ADT/SetVector.h" 41 #include "llvm/ADT/SmallString.h" 42 #include "llvm/ADT/StringExtras.h" 43 #include "llvm/ADT/StringSet.h" 44 #include "llvm/Support/CommandLine.h" 45 #include "llvm/Support/Regex.h" 46 using namespace mlir; 47 48 void Identifier::print(raw_ostream &os) const { os << str(); } 49 50 void Identifier::dump() const { print(llvm::errs()); } 51 52 void OperationName::print(raw_ostream &os) const { os << getStringRef(); } 53 54 void OperationName::dump() const { print(llvm::errs()); } 55 56 OpAsmPrinter::~OpAsmPrinter() {} 57 58 //===----------------------------------------------------------------------===// 59 // ModuleState 60 //===----------------------------------------------------------------------===// 61 62 // TODO(riverriddle) Rethink this flag when we have a pass that can remove debug 63 // info or when we have a system for printer flags. 
64 static llvm::cl::opt<bool> 65 shouldPrintDebugInfoOpt("mlir-print-debuginfo", 66 llvm::cl::desc("Print debug info in MLIR output"), 67 llvm::cl::init(false)); 68 69 static llvm::cl::opt<bool> printPrettyDebugInfo( 70 "mlir-pretty-debuginfo", 71 llvm::cl::desc("Print pretty debug info in MLIR output"), 72 llvm::cl::init(false)); 73 74 // Use the generic op output form in the operation printer even if the custom 75 // form is defined. 76 static llvm::cl::opt<bool> 77 printGenericOpForm("mlir-print-op-generic", 78 llvm::cl::desc("Print the generic op form"), 79 llvm::cl::init(false), llvm::cl::Hidden); 80 81 namespace { 82 /// A special index constant used for non-kind attribute aliases. 83 static constexpr int kNonAttrKindAlias = -1; 84 85 class ModuleState { 86 public: 87 explicit ModuleState(MLIRContext *context) : interfaces(context) {} 88 void initialize(Operation *op); 89 90 Twine getAttributeAlias(Attribute attr) const { 91 auto alias = attrToAlias.find(attr); 92 if (alias == attrToAlias.end()) 93 return Twine(); 94 95 // Return the alias for this attribute, along with the index if this was 96 // generated by a kind alias. 97 int kindIndex = alias->second.second; 98 return alias->second.first + 99 (kindIndex == kNonAttrKindAlias ? Twine() : Twine(kindIndex)); 100 } 101 102 void printAttributeAliases(raw_ostream &os) const { 103 auto printAlias = [&](StringRef alias, Attribute attr, int index) { 104 os << '#' << alias; 105 if (index != kNonAttrKindAlias) 106 os << index; 107 os << " = " << attr << '\n'; 108 }; 109 110 // Print all of the attribute kind aliases. 111 for (auto &kindAlias : attrKindToAlias) { 112 for (unsigned i = 0, e = kindAlias.second.second.size(); i != e; ++i) 113 printAlias(kindAlias.second.first, kindAlias.second.second[i], i); 114 os << "\n"; 115 } 116 117 // In a second pass print all of the remaining attribute aliases that aren't 118 // kind aliases. 
119 for (Attribute attr : usedAttributes) { 120 auto alias = attrToAlias.find(attr); 121 if (alias != attrToAlias.end() && 122 alias->second.second == kNonAttrKindAlias) 123 printAlias(alias->second.first, attr, alias->second.second); 124 } 125 } 126 127 StringRef getTypeAlias(Type ty) const { return typeToAlias.lookup(ty); } 128 129 void printTypeAliases(raw_ostream &os) const { 130 for (Type type : usedTypes) { 131 auto alias = typeToAlias.find(type); 132 if (alias != typeToAlias.end()) 133 os << '!' << alias->second << " = type " << type << '\n'; 134 } 135 } 136 137 /// Get an instance of the OpAsmDialectInterface for the given dialect, or 138 /// null if one wasn't registered. 139 const OpAsmDialectInterface *getOpAsmInterface(Dialect *dialect) { 140 return interfaces.getInterfaceFor(dialect); 141 } 142 143 private: 144 void recordAttributeReference(Attribute attr) { 145 // Don't recheck attributes that have already been seen or those that 146 // already have an alias. 147 if (!usedAttributes.insert(attr) || attrToAlias.count(attr)) 148 return; 149 150 // If this attribute kind has an alias, then record one for this attribute. 151 auto alias = attrKindToAlias.find(static_cast<unsigned>(attr.getKind())); 152 if (alias == attrKindToAlias.end()) 153 return; 154 std::pair<StringRef, int> attrAlias(alias->second.first, 155 alias->second.second.size()); 156 attrToAlias.insert({attr, attrAlias}); 157 alias->second.second.push_back(attr); 158 } 159 160 void recordTypeReference(Type ty) { usedTypes.insert(ty); } 161 162 // Visit functions. 163 void visitOperation(Operation *op); 164 void visitType(Type type); 165 void visitAttribute(Attribute attr); 166 167 // Initialize symbol aliases. 168 void initializeSymbolAliases(); 169 170 /// Set of attributes known to be used within the module. 171 llvm::SetVector<Attribute> usedAttributes; 172 173 /// Mapping between attribute and a pair comprised of a base alias name and a 174 /// count suffix. 
If the suffix is set to -1, it is not displayed. 175 llvm::MapVector<Attribute, std::pair<StringRef, int>> attrToAlias; 176 177 /// Mapping between attribute kind and a pair comprised of a base alias name 178 /// and a unique list of attributes belonging to this kind sorted by location 179 /// seen in the module. 180 llvm::MapVector<unsigned, std::pair<StringRef, std::vector<Attribute>>> 181 attrKindToAlias; 182 183 /// Set of types known to be used within the module. 184 llvm::SetVector<Type> usedTypes; 185 186 /// A mapping between a type and a given alias. 187 DenseMap<Type, StringRef> typeToAlias; 188 189 /// Collection of OpAsm interfaces implemented in the context. 190 DialectInterfaceCollection<OpAsmDialectInterface> interfaces; 191 }; 192 } // end anonymous namespace 193 194 // TODO Support visiting other types/operations when implemented. 195 void ModuleState::visitType(Type type) { 196 recordTypeReference(type); 197 if (auto funcType = type.dyn_cast<FunctionType>()) { 198 // Visit input and result types for functions. 199 for (auto input : funcType.getInputs()) 200 visitType(input); 201 for (auto result : funcType.getResults()) 202 visitType(result); 203 return; 204 } 205 if (auto memref = type.dyn_cast<MemRefType>()) { 206 // Visit affine maps in memref type. 207 for (auto map : memref.getAffineMaps()) 208 recordAttributeReference(AffineMapAttr::get(map)); 209 } 210 if (auto shapedType = type.dyn_cast<ShapedType>()) { 211 visitType(shapedType.getElementType()); 212 } 213 } 214 215 void ModuleState::visitAttribute(Attribute attr) { 216 recordAttributeReference(attr); 217 if (auto arrayAttr = attr.dyn_cast<ArrayAttr>()) { 218 for (auto elt : arrayAttr.getValue()) 219 visitAttribute(elt); 220 } else if (auto typeAttr = attr.dyn_cast<TypeAttr>()) { 221 visitType(typeAttr.getValue()); 222 } 223 } 224 225 void ModuleState::visitOperation(Operation *op) { 226 // Visit all the types used in the operation. 
227 for (auto type : op->getOperandTypes()) 228 visitType(type); 229 for (auto type : op->getResultTypes()) 230 visitType(type); 231 for (auto ®ion : op->getRegions()) 232 for (auto &block : region) 233 for (auto *arg : block.getArguments()) 234 visitType(arg->getType()); 235 236 // Visit each of the attributes. 237 for (auto elt : op->getAttrs()) 238 visitAttribute(elt.second); 239 } 240 241 // Utility to generate a function to register a symbol alias. 242 static bool canRegisterAlias(StringRef name, llvm::StringSet<> &usedAliases) { 243 assert(!name.empty() && "expected alias name to be non-empty"); 244 // TODO(riverriddle) Assert that the provided alias name can be lexed as 245 // an identifier. 246 247 // Check that the alias doesn't contain a '.' character and the name is not 248 // already in use. 249 return !name.contains('.') && usedAliases.insert(name).second; 250 } 251 252 void ModuleState::initializeSymbolAliases() { 253 // Track the identifiers in use for each symbol so that the same identifier 254 // isn't used twice. 255 llvm::StringSet<> usedAliases; 256 257 // Collect the set of aliases from each dialect. 258 SmallVector<std::pair<unsigned, StringRef>, 8> attributeKindAliases; 259 SmallVector<std::pair<Attribute, StringRef>, 8> attributeAliases; 260 SmallVector<std::pair<Type, StringRef>, 16> typeAliases; 261 262 // AffineMap/Integer set have specific kind aliases. 263 attributeKindAliases.emplace_back(StandardAttributes::AffineMap, "map"); 264 attributeKindAliases.emplace_back(StandardAttributes::IntegerSet, "set"); 265 266 for (auto &interface : interfaces) { 267 interface.getAttributeKindAliases(attributeKindAliases); 268 interface.getAttributeAliases(attributeAliases); 269 interface.getTypeAliases(typeAliases); 270 } 271 272 // Setup the attribute kind aliases. 
273 StringRef alias; 274 unsigned attrKind; 275 for (auto &attrAliasPair : attributeKindAliases) { 276 std::tie(attrKind, alias) = attrAliasPair; 277 assert(!alias.empty() && "expected non-empty alias string"); 278 if (!usedAliases.count(alias) && !alias.contains('.')) 279 attrKindToAlias.insert({attrKind, {alias, {}}}); 280 } 281 282 // Clear the set of used identifiers so that the attribute kind aliases are 283 // just a prefix and not the full alias, i.e. there may be some overlap. 284 usedAliases.clear(); 285 286 // Register the attribute aliases. 287 // Create a regex for the attribute kind alias names, these have a prefix with 288 // a counter appended to the end. We prevent normal aliases from having these 289 // names to avoid collisions. 290 llvm::Regex reservedAttrNames("[0-9]+$"); 291 292 // Attribute value aliases. 293 Attribute attr; 294 for (auto &attrAliasPair : attributeAliases) { 295 std::tie(attr, alias) = attrAliasPair; 296 if (!reservedAttrNames.match(alias) && canRegisterAlias(alias, usedAliases)) 297 attrToAlias.insert({attr, {alias, kNonAttrKindAlias}}); 298 } 299 300 // Clear the set of used identifiers as types can have the same identifiers as 301 // affine structures. 302 usedAliases.clear(); 303 304 // Type aliases. 305 for (auto &typeAliasPair : typeAliases) 306 if (canRegisterAlias(typeAliasPair.second, usedAliases)) 307 typeToAlias.insert(typeAliasPair); 308 } 309 310 void ModuleState::initialize(Operation *op) { 311 // Initialize the symbol aliases. 312 initializeSymbolAliases(); 313 314 // Visit each of the nested operations. 
315 op->walk([&](Operation *op) { visitOperation(op); }); 316 } 317 318 //===----------------------------------------------------------------------===// 319 // ModulePrinter 320 //===----------------------------------------------------------------------===// 321 322 namespace { 323 class ModulePrinter { 324 public: 325 ModulePrinter(raw_ostream &os, ModuleState *state = nullptr) 326 : os(os), state(state) {} 327 explicit ModulePrinter(ModulePrinter &printer) 328 : os(printer.os), state(printer.state) {} 329 330 template <typename Container, typename UnaryFunctor> 331 inline void interleaveComma(const Container &c, UnaryFunctor each_fn) const { 332 interleave(c.begin(), c.end(), each_fn, [&]() { os << ", "; }); 333 } 334 335 void print(ModuleOp module); 336 337 /// Print the given attribute. If 'mayElideType' is true, some attributes are 338 /// printed without the type when the type matches the default used in the 339 /// parser (for example i64 is the default for integer attributes). 340 void printAttribute(Attribute attr, bool mayElideType = false); 341 342 void printType(Type type); 343 void printLocation(LocationAttr loc); 344 345 void printAffineMap(AffineMap map); 346 void printAffineExpr( 347 AffineExpr expr, 348 llvm::function_ref<void(unsigned, bool)> printValueName = nullptr); 349 void printAffineConstraint(AffineExpr expr, bool isEq); 350 void printIntegerSet(IntegerSet set); 351 352 protected: 353 void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs, 354 ArrayRef<StringRef> elidedAttrs = {}); 355 void printTrailingLocation(Location loc); 356 void printLocationInternal(LocationAttr loc, bool pretty = false); 357 void printDenseElementsAttr(DenseElementsAttr attr); 358 359 /// This enum is used to represent the binding stength of the enclosing 360 /// context that an AffineExprStorage is being printed in, so we can 361 /// intelligently produce parens. 362 enum class BindingStrength { 363 Weak, // + and - 364 Strong, // All other binary operators. 
365 }; 366 void printAffineExprInternal( 367 AffineExpr expr, BindingStrength enclosingTightness, 368 llvm::function_ref<void(unsigned, bool)> printValueName = nullptr); 369 370 /// The output stream for the printer. 371 raw_ostream &os; 372 373 /// An optional printer state for the module. 374 ModuleState *state; 375 }; 376 } // end anonymous namespace 377 378 void ModulePrinter::printTrailingLocation(Location loc) { 379 // Check to see if we are printing debug information. 380 if (!shouldPrintDebugInfoOpt) 381 return; 382 383 os << " "; 384 printLocation(loc); 385 } 386 387 void ModulePrinter::printLocationInternal(LocationAttr loc, bool pretty) { 388 switch (loc.getKind()) { 389 case StandardAttributes::UnknownLocation: 390 if (pretty) 391 os << "[unknown]"; 392 else 393 os << "unknown"; 394 break; 395 case StandardAttributes::FileLineColLocation: { 396 auto fileLoc = loc.cast<FileLineColLoc>(); 397 auto mayQuote = pretty ? "" : "\""; 398 os << mayQuote << fileLoc.getFilename() << mayQuote << ':' 399 << fileLoc.getLine() << ':' << fileLoc.getColumn(); 400 break; 401 } 402 case StandardAttributes::NameLocation: { 403 auto nameLoc = loc.cast<NameLoc>(); 404 os << '\"' << nameLoc.getName() << '\"'; 405 406 // Print the child if it isn't unknown. 
407 auto childLoc = nameLoc.getChildLoc(); 408 if (!childLoc.isa<UnknownLoc>()) { 409 os << '('; 410 printLocationInternal(childLoc, pretty); 411 os << ')'; 412 } 413 break; 414 } 415 case StandardAttributes::CallSiteLocation: { 416 auto callLocation = loc.cast<CallSiteLoc>(); 417 auto caller = callLocation.getCaller(); 418 auto callee = callLocation.getCallee(); 419 if (!pretty) 420 os << "callsite("; 421 printLocationInternal(callee, pretty); 422 if (pretty) { 423 if (callee.isa<NameLoc>()) { 424 if (caller.isa<FileLineColLoc>()) { 425 os << " at "; 426 } else { 427 os << "\n at "; 428 } 429 } else { 430 os << "\n at "; 431 } 432 } else { 433 os << " at "; 434 } 435 printLocationInternal(caller, pretty); 436 if (!pretty) 437 os << ")"; 438 break; 439 } 440 case StandardAttributes::FusedLocation: { 441 auto fusedLoc = loc.cast<FusedLoc>(); 442 if (!pretty) 443 os << "fused"; 444 if (auto metadata = fusedLoc.getMetadata()) 445 os << '<' << metadata << '>'; 446 os << '['; 447 interleave( 448 fusedLoc.getLocations(), 449 [&](Location loc) { printLocationInternal(loc, pretty); }, 450 [&]() { os << ", "; }); 451 os << ']'; 452 break; 453 } 454 } 455 } 456 457 /// Print a floating point value in a way that the parser will be able to 458 /// round-trip losslessly. 459 static void printFloatValue(const APFloat &apValue, raw_ostream &os) { 460 // We would like to output the FP constant value in exponential notation, 461 // but we cannot do this if doing so will lose precision. Check here to 462 // make sure that we only output it in exponential format if we can parse 463 // the value back and get the same value. 464 bool isInf = apValue.isInfinity(); 465 bool isNaN = apValue.isNaN(); 466 if (!isInf && !isNaN) { 467 SmallString<128> strValue; 468 apValue.toString(strValue, 6, 0, false); 469 470 // Check to make sure that the stringized number is not some string like 471 // "Inf" or NaN, that atof will accept, but the lexer will not. 
Check 472 // that the string matches the "[-+]?[0-9]" regex. 473 assert(((strValue[0] >= '0' && strValue[0] <= '9') || 474 ((strValue[0] == '-' || strValue[0] == '+') && 475 (strValue[1] >= '0' && strValue[1] <= '9'))) && 476 "[-+]?[0-9] regex does not match!"); 477 478 // Parse back the stringized version and check that the value is equal 479 // (i.e., there is no precision loss). If it is not, use the default format 480 // of APFloat instead of the exponential notation. 481 if (!APFloat(apValue.getSemantics(), strValue).bitwiseIsEqual(apValue)) { 482 strValue.clear(); 483 apValue.toString(strValue); 484 } 485 os << strValue; 486 return; 487 } 488 489 // Print special values in hexadecimal format. The sign bit should be 490 // included in the literal. 491 SmallVector<char, 16> str; 492 APInt apInt = apValue.bitcastToAPInt(); 493 apInt.toString(str, /*Radix=*/16, /*Signed=*/false, 494 /*formatAsCLiteral=*/true); 495 os << str; 496 } 497 498 void ModulePrinter::printLocation(LocationAttr loc) { 499 if (printPrettyDebugInfo) { 500 printLocationInternal(loc, /*pretty=*/true); 501 } else { 502 os << "loc("; 503 printLocationInternal(loc); 504 os << ')'; 505 } 506 } 507 508 /// Returns if the given dialect symbol data is simple enough to print in the 509 /// pretty form, i.e. without the enclosing "". 510 static bool isDialectSymbolSimpleEnoughForPrettyForm(StringRef symName) { 511 // The name must start with an identifier. 512 if (symName.empty() || !isalpha(symName.front())) 513 return false; 514 515 // Ignore all the characters that are valid in an identifier in the symbol 516 // name. 517 symName = 518 symName.drop_while([](char c) { return llvm::isAlnum(c) || c == '.'; }); 519 if (symName.empty()) 520 return true; 521 522 // If we got to an unexpected character, then it must be a <>. Check those 523 // recursively. 
524 if (symName.front() != '<' || symName.back() != '>') 525 return false; 526 527 SmallVector<char, 8> nestedPunctuation; 528 do { 529 // If we ran out of characters, then we had a punctuation mismatch. 530 if (symName.empty()) 531 return false; 532 533 auto c = symName.front(); 534 symName = symName.drop_front(); 535 536 switch (c) { 537 // We never allow null characters. This is an EOF indicator for the lexer 538 // which we could handle, but isn't important for any known dialect. 539 case '\0': 540 return false; 541 case '<': 542 case '[': 543 case '(': 544 case '{': 545 nestedPunctuation.push_back(c); 546 continue; 547 case '-': 548 // Treat `->` as a special token. 549 if (!symName.empty() && symName.front() == '>') { 550 symName = symName.drop_front(); 551 continue; 552 } 553 break; 554 // Reject types with mismatched brackets. 555 case '>': 556 if (nestedPunctuation.pop_back_val() != '<') 557 return false; 558 break; 559 case ']': 560 if (nestedPunctuation.pop_back_val() != '[') 561 return false; 562 break; 563 case ')': 564 if (nestedPunctuation.pop_back_val() != '(') 565 return false; 566 break; 567 case '}': 568 if (nestedPunctuation.pop_back_val() != '{') 569 return false; 570 break; 571 default: 572 continue; 573 } 574 575 // We're done when the punctuation is fully matched. 576 } while (!nestedPunctuation.empty()); 577 578 // If there were extra characters, then we failed. 579 return symName.empty(); 580 } 581 582 /// Print the given dialect symbol to the stream. 583 static void printDialectSymbol(raw_ostream &os, StringRef symPrefix, 584 StringRef dialectName, StringRef symString) { 585 os << symPrefix << dialectName; 586 587 // If this symbol name is simple enough, print it directly in pretty form, 588 // otherwise, we print it as an escaped string. 589 if (isDialectSymbolSimpleEnoughForPrettyForm(symString)) { 590 os << '.' << symString; 591 return; 592 } 593 594 // TODO: escape the symbol name, it could contain " characters. 
595 os << "<\"" << symString << "\">"; 596 } 597 598 void ModulePrinter::printAttribute(Attribute attr, bool mayElideType) { 599 if (!attr) { 600 os << "<<NULL ATTRIBUTE>>"; 601 return; 602 } 603 604 // Check for an alias for this attribute. 605 if (state) { 606 Twine alias = state->getAttributeAlias(attr); 607 if (!alias.isTriviallyEmpty()) { 608 os << '#' << alias; 609 return; 610 } 611 } 612 613 switch (attr.getKind()) { 614 default: { 615 auto &dialect = attr.getDialect(); 616 617 // Ask the dialect to serialize the attribute to a string. 618 std::string attrName; 619 { 620 llvm::raw_string_ostream attrNameStr(attrName); 621 dialect.printAttribute(attr, attrNameStr); 622 } 623 624 printDialectSymbol(os, "#", dialect.getNamespace(), attrName); 625 break; 626 } 627 case StandardAttributes::Opaque: { 628 auto opaqueAttr = attr.cast<OpaqueAttr>(); 629 printDialectSymbol(os, "#", opaqueAttr.getDialectNamespace(), 630 opaqueAttr.getAttrData()); 631 break; 632 } 633 case StandardAttributes::Unit: 634 os << "unit"; 635 break; 636 case StandardAttributes::Bool: 637 os << (attr.cast<BoolAttr>().getValue() ? "true" : "false"); 638 639 // BoolAttr always elides the type. 640 return; 641 case StandardAttributes::Dictionary: 642 os << '{'; 643 interleaveComma(attr.cast<DictionaryAttr>().getValue(), 644 [&](NamedAttribute attr) { 645 os << attr.first << " = "; 646 printAttribute(attr.second); 647 }); 648 os << '}'; 649 break; 650 case StandardAttributes::Integer: { 651 auto intAttr = attr.cast<IntegerAttr>(); 652 // Print all integer attributes as signed unless i1. 653 bool isSigned = intAttr.getType().isIndex() || 654 intAttr.getType().getIntOrFloatBitWidth() != 1; 655 intAttr.getValue().print(os, isSigned); 656 657 // IntegerAttr elides the type if I64. 
658 if (mayElideType && intAttr.getType().isInteger(64)) 659 return; 660 break; 661 } 662 case StandardAttributes::Float: { 663 auto floatAttr = attr.cast<FloatAttr>(); 664 printFloatValue(floatAttr.getValue(), os); 665 666 // FloatAttr elides the type if F64. 667 if (mayElideType && floatAttr.getType().isF64()) 668 return; 669 break; 670 } 671 case StandardAttributes::String: 672 os << '"'; 673 printEscapedString(attr.cast<StringAttr>().getValue(), os); 674 os << '"'; 675 break; 676 case StandardAttributes::Array: 677 os << '['; 678 interleaveComma(attr.cast<ArrayAttr>().getValue(), [&](Attribute attr) { 679 printAttribute(attr, /*mayElideType=*/true); 680 }); 681 os << ']'; 682 break; 683 case StandardAttributes::AffineMap: 684 attr.cast<AffineMapAttr>().getValue().print(os); 685 686 // AffineMap always elides the type. 687 return; 688 case StandardAttributes::IntegerSet: 689 attr.cast<IntegerSetAttr>().getValue().print(os); 690 break; 691 case StandardAttributes::Type: 692 printType(attr.cast<TypeAttr>().getValue()); 693 break; 694 case StandardAttributes::SymbolRef: 695 os << '@' << attr.cast<SymbolRefAttr>().getValue(); 696 break; 697 case StandardAttributes::OpaqueElements: { 698 auto eltsAttr = attr.cast<OpaqueElementsAttr>(); 699 os << "opaque<\"" << eltsAttr.getDialect()->getNamespace() << "\", "; 700 os << '"' << "0x" << llvm::toHex(eltsAttr.getValue()) << "\">"; 701 break; 702 } 703 case StandardAttributes::DenseElements: { 704 auto eltsAttr = attr.cast<DenseElementsAttr>(); 705 os << "dense<"; 706 printDenseElementsAttr(eltsAttr); 707 os << '>'; 708 break; 709 } 710 case StandardAttributes::SparseElements: { 711 auto elementsAttr = attr.cast<SparseElementsAttr>(); 712 os << "sparse<"; 713 printDenseElementsAttr(elementsAttr.getIndices()); 714 os << ", "; 715 printDenseElementsAttr(elementsAttr.getValues()); 716 os << '>'; 717 break; 718 } 719 720 // Location attributes. 
721 case StandardAttributes::CallSiteLocation: 722 case StandardAttributes::FileLineColLocation: 723 case StandardAttributes::FusedLocation: 724 case StandardAttributes::NameLocation: 725 case StandardAttributes::UnknownLocation: 726 printLocation(attr.cast<LocationAttr>()); 727 break; 728 } 729 730 // Print the type if it isn't a 'none' type. 731 auto attrType = attr.getType(); 732 if (!attrType.isa<NoneType>()) { 733 os << " : "; 734 printType(attrType); 735 } 736 } 737 738 /// Print the integer element of the given DenseElementsAttr at 'index'. 739 static void printDenseIntElement(DenseElementsAttr attr, raw_ostream &os, 740 unsigned index) { 741 APInt value = *std::next(attr.int_value_begin(), index); 742 if (value.getBitWidth() == 1) 743 os << (value.getBoolValue() ? "true" : "false"); 744 else 745 value.print(os, /*isSigned=*/true); 746 } 747 748 /// Print the float element of the given DenseElementsAttr at 'index'. 749 static void printDenseFloatElement(DenseElementsAttr attr, raw_ostream &os, 750 unsigned index) { 751 APFloat value = *std::next(attr.float_value_begin(), index); 752 printFloatValue(value, os); 753 } 754 755 void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr) { 756 auto type = attr.getType(); 757 auto shape = type.getShape(); 758 auto rank = type.getRank(); 759 760 // The function used to print elements of this attribute. 761 auto printEltFn = type.getElementType().isa<IntegerType>() 762 ? printDenseIntElement 763 : printDenseFloatElement; 764 765 // Special case for 0-d and splat tensors. 766 if (attr.isSplat()) { 767 printEltFn(attr, os, 0); 768 return; 769 } 770 771 // Special case for degenerate tensors. 772 auto numElements = type.getNumElements(); 773 if (numElements == 0) { 774 for (int i = 0; i < rank; ++i) 775 os << '['; 776 for (int i = 0; i < rank; ++i) 777 os << ']'; 778 return; 779 } 780 781 // We use a mixed-radix counter to iterate through the shape. 
When we bump a 782 // non-least-significant digit, we emit a close bracket. When we next emit an 783 // element we re-open all closed brackets. 784 785 // The mixed-radix counter, with radices in 'shape'. 786 SmallVector<unsigned, 4> counter(rank, 0); 787 // The number of brackets that have been opened and not closed. 788 unsigned openBrackets = 0; 789 790 auto bumpCounter = [&]() { 791 // Bump the least significant digit. 792 ++counter[rank - 1]; 793 // Iterate backwards bubbling back the increment. 794 for (unsigned i = rank - 1; i > 0; --i) 795 if (counter[i] >= shape[i]) { 796 // Index 'i' is rolled over. Bump (i-1) and close a bracket. 797 counter[i] = 0; 798 ++counter[i - 1]; 799 --openBrackets; 800 os << ']'; 801 } 802 }; 803 804 for (unsigned idx = 0, e = numElements; idx != e; ++idx) { 805 if (idx != 0) 806 os << ", "; 807 while (openBrackets++ < rank) 808 os << '['; 809 openBrackets = rank; 810 printEltFn(attr, os, idx); 811 bumpCounter(); 812 } 813 while (openBrackets-- > 0) 814 os << ']'; 815 } 816 817 void ModulePrinter::printType(Type type) { 818 // Check for an alias for this type. 819 if (state) { 820 StringRef alias = state->getTypeAlias(type); 821 if (!alias.empty()) { 822 os << '!' << alias; 823 return; 824 } 825 } 826 827 switch (type.getKind()) { 828 default: { 829 auto &dialect = type.getDialect(); 830 831 // Ask the dialect to serialize the type to a string. 
832 std::string typeName; 833 { 834 llvm::raw_string_ostream typeNameStr(typeName); 835 dialect.printType(type, typeNameStr); 836 } 837 838 printDialectSymbol(os, "!", dialect.getNamespace(), typeName); 839 return; 840 } 841 case Type::Kind::Opaque: { 842 auto opaqueTy = type.cast<OpaqueType>(); 843 printDialectSymbol(os, "!", opaqueTy.getDialectNamespace(), 844 opaqueTy.getTypeData()); 845 return; 846 } 847 case StandardTypes::Index: 848 os << "index"; 849 return; 850 case StandardTypes::BF16: 851 os << "bf16"; 852 return; 853 case StandardTypes::F16: 854 os << "f16"; 855 return; 856 case StandardTypes::F32: 857 os << "f32"; 858 return; 859 case StandardTypes::F64: 860 os << "f64"; 861 return; 862 863 case StandardTypes::Integer: { 864 auto integer = type.cast<IntegerType>(); 865 os << 'i' << integer.getWidth(); 866 return; 867 } 868 case Type::Kind::Function: { 869 auto func = type.cast<FunctionType>(); 870 os << '('; 871 interleaveComma(func.getInputs(), [&](Type type) { printType(type); }); 872 os << ") -> "; 873 auto results = func.getResults(); 874 if (results.size() == 1 && !results[0].isa<FunctionType>()) 875 os << results[0]; 876 else { 877 os << '('; 878 interleaveComma(results, [&](Type type) { printType(type); }); 879 os << ')'; 880 } 881 return; 882 } 883 case StandardTypes::Vector: { 884 auto v = type.cast<VectorType>(); 885 os << "vector<"; 886 for (auto dim : v.getShape()) 887 os << dim << 'x'; 888 os << v.getElementType() << '>'; 889 return; 890 } 891 case StandardTypes::RankedTensor: { 892 auto v = type.cast<RankedTensorType>(); 893 os << "tensor<"; 894 for (auto dim : v.getShape()) { 895 if (dim < 0) 896 os << '?'; 897 else 898 os << dim; 899 os << 'x'; 900 } 901 os << v.getElementType() << '>'; 902 return; 903 } 904 case StandardTypes::UnrankedTensor: { 905 auto v = type.cast<UnrankedTensorType>(); 906 os << "tensor<*x"; 907 printType(v.getElementType()); 908 os << '>'; 909 return; 910 } 911 case StandardTypes::MemRef: { 912 auto v = 
type.cast<MemRefType>(); 913 os << "memref<"; 914 for (auto dim : v.getShape()) { 915 if (dim < 0) 916 os << '?'; 917 else 918 os << dim; 919 os << 'x'; 920 } 921 printType(v.getElementType()); 922 for (auto map : v.getAffineMaps()) { 923 os << ", "; 924 printAttribute(AffineMapAttr::get(map)); 925 } 926 // Only print the memory space if it is the non-default one. 927 if (v.getMemorySpace()) 928 os << ", " << v.getMemorySpace(); 929 os << '>'; 930 return; 931 } 932 case StandardTypes::Complex: 933 os << "complex<"; 934 printType(type.cast<ComplexType>().getElementType()); 935 os << '>'; 936 return; 937 case StandardTypes::Tuple: { 938 auto tuple = type.cast<TupleType>(); 939 os << "tuple<"; 940 interleaveComma(tuple.getTypes(), [&](Type type) { printType(type); }); 941 os << '>'; 942 return; 943 } 944 case StandardTypes::None: 945 os << "none"; 946 return; 947 } 948 } 949 950 //===----------------------------------------------------------------------===// 951 // Affine expressions and maps 952 //===----------------------------------------------------------------------===// 953 954 void ModulePrinter::printAffineExpr( 955 AffineExpr expr, llvm::function_ref<void(unsigned, bool)> printValueName) { 956 printAffineExprInternal(expr, BindingStrength::Weak, printValueName); 957 } 958 959 void ModulePrinter::printAffineExprInternal( 960 AffineExpr expr, BindingStrength enclosingTightness, 961 llvm::function_ref<void(unsigned, bool)> printValueName) { 962 const char *binopSpelling = nullptr; 963 switch (expr.getKind()) { 964 case AffineExprKind::SymbolId: { 965 unsigned pos = expr.cast<AffineSymbolExpr>().getPosition(); 966 if (printValueName) 967 printValueName(pos, /*isSymbol=*/true); 968 else 969 os << 's' << pos; 970 return; 971 } 972 case AffineExprKind::DimId: { 973 unsigned pos = expr.cast<AffineDimExpr>().getPosition(); 974 if (printValueName) 975 printValueName(pos, /*isSymbol=*/false); 976 else 977 os << 'd' << pos; 978 return; 979 } 980 case 
AffineExprKind::Constant: 981 os << expr.cast<AffineConstantExpr>().getValue(); 982 return; 983 case AffineExprKind::Add: 984 binopSpelling = " + "; 985 break; 986 case AffineExprKind::Mul: 987 binopSpelling = " * "; 988 break; 989 case AffineExprKind::FloorDiv: 990 binopSpelling = " floordiv "; 991 break; 992 case AffineExprKind::CeilDiv: 993 binopSpelling = " ceildiv "; 994 break; 995 case AffineExprKind::Mod: 996 binopSpelling = " mod "; 997 break; 998 } 999 1000 auto binOp = expr.cast<AffineBinaryOpExpr>(); 1001 AffineExpr lhsExpr = binOp.getLHS(); 1002 AffineExpr rhsExpr = binOp.getRHS(); 1003 1004 // Handle tightly binding binary operators. 1005 if (binOp.getKind() != AffineExprKind::Add) { 1006 if (enclosingTightness == BindingStrength::Strong) 1007 os << '('; 1008 1009 // Pretty print multiplication with -1. 1010 auto rhsConst = rhsExpr.dyn_cast<AffineConstantExpr>(); 1011 if (rhsConst && rhsConst.getValue() == -1) { 1012 os << "-"; 1013 printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName); 1014 return; 1015 } 1016 1017 printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName); 1018 1019 os << binopSpelling; 1020 printAffineExprInternal(rhsExpr, BindingStrength::Strong, printValueName); 1021 1022 if (enclosingTightness == BindingStrength::Strong) 1023 os << ')'; 1024 return; 1025 } 1026 1027 // Print out special "pretty" forms for add. 1028 if (enclosingTightness == BindingStrength::Strong) 1029 os << '('; 1030 1031 // Pretty print addition to a product that has a negative operand as a 1032 // subtraction. 
1033 if (auto rhs = rhsExpr.dyn_cast<AffineBinaryOpExpr>()) { 1034 if (rhs.getKind() == AffineExprKind::Mul) { 1035 AffineExpr rrhsExpr = rhs.getRHS(); 1036 if (auto rrhs = rrhsExpr.dyn_cast<AffineConstantExpr>()) { 1037 if (rrhs.getValue() == -1) { 1038 printAffineExprInternal(lhsExpr, BindingStrength::Weak, 1039 printValueName); 1040 os << " - "; 1041 if (rhs.getLHS().getKind() == AffineExprKind::Add) { 1042 printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong, 1043 printValueName); 1044 } else { 1045 printAffineExprInternal(rhs.getLHS(), BindingStrength::Weak, 1046 printValueName); 1047 } 1048 1049 if (enclosingTightness == BindingStrength::Strong) 1050 os << ')'; 1051 return; 1052 } 1053 1054 if (rrhs.getValue() < -1) { 1055 printAffineExprInternal(lhsExpr, BindingStrength::Weak, 1056 printValueName); 1057 os << " - "; 1058 printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong, 1059 printValueName); 1060 os << " * " << -rrhs.getValue(); 1061 if (enclosingTightness == BindingStrength::Strong) 1062 os << ')'; 1063 return; 1064 } 1065 } 1066 } 1067 } 1068 1069 // Pretty print addition to a negative number as a subtraction. 1070 if (auto rhsConst = rhsExpr.dyn_cast<AffineConstantExpr>()) { 1071 if (rhsConst.getValue() < 0) { 1072 printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName); 1073 os << " - " << -rhsConst.getValue(); 1074 if (enclosingTightness == BindingStrength::Strong) 1075 os << ')'; 1076 return; 1077 } 1078 } 1079 1080 printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName); 1081 1082 os << " + "; 1083 printAffineExprInternal(rhsExpr, BindingStrength::Weak, printValueName); 1084 1085 if (enclosingTightness == BindingStrength::Strong) 1086 os << ')'; 1087 } 1088 1089 void ModulePrinter::printAffineConstraint(AffineExpr expr, bool isEq) { 1090 printAffineExprInternal(expr, BindingStrength::Weak); 1091 isEq ? 
os << " == 0" : os << " >= 0"; 1092 } 1093 1094 void ModulePrinter::printAffineMap(AffineMap map) { 1095 // Dimension identifiers. 1096 os << '('; 1097 for (int i = 0; i < (int)map.getNumDims() - 1; ++i) 1098 os << 'd' << i << ", "; 1099 if (map.getNumDims() >= 1) 1100 os << 'd' << map.getNumDims() - 1; 1101 os << ')'; 1102 1103 // Symbolic identifiers. 1104 if (map.getNumSymbols() != 0) { 1105 os << '['; 1106 for (unsigned i = 0; i < map.getNumSymbols() - 1; ++i) 1107 os << 's' << i << ", "; 1108 if (map.getNumSymbols() >= 1) 1109 os << 's' << map.getNumSymbols() - 1; 1110 os << ']'; 1111 } 1112 1113 // Result affine expressions. 1114 os << " -> ("; 1115 interleaveComma(map.getResults(), 1116 [&](AffineExpr expr) { printAffineExpr(expr); }); 1117 os << ')'; 1118 } 1119 1120 void ModulePrinter::printIntegerSet(IntegerSet set) { 1121 // Dimension identifiers. 1122 os << '('; 1123 for (unsigned i = 1; i < set.getNumDims(); ++i) 1124 os << 'd' << i - 1 << ", "; 1125 if (set.getNumDims() >= 1) 1126 os << 'd' << set.getNumDims() - 1; 1127 os << ')'; 1128 1129 // Symbolic identifiers. 1130 if (set.getNumSymbols() != 0) { 1131 os << '['; 1132 for (unsigned i = 0; i < set.getNumSymbols() - 1; ++i) 1133 os << 's' << i << ", "; 1134 if (set.getNumSymbols() >= 1) 1135 os << 's' << set.getNumSymbols() - 1; 1136 os << ']'; 1137 } 1138 1139 // Print constraints. 
1140 os << " : ("; 1141 int numConstraints = set.getNumConstraints(); 1142 for (int i = 1; i < numConstraints; ++i) { 1143 printAffineConstraint(set.getConstraint(i - 1), set.isEq(i - 1)); 1144 os << ", "; 1145 } 1146 if (numConstraints >= 1) 1147 printAffineConstraint(set.getConstraint(numConstraints - 1), 1148 set.isEq(numConstraints - 1)); 1149 os << ')'; 1150 } 1151 1152 //===----------------------------------------------------------------------===// 1153 // Operation printing 1154 //===----------------------------------------------------------------------===// 1155 1156 void ModulePrinter::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs, 1157 ArrayRef<StringRef> elidedAttrs) { 1158 // If there are no attributes, then there is nothing to be done. 1159 if (attrs.empty()) 1160 return; 1161 1162 // Filter out any attributes that shouldn't be included. 1163 SmallVector<NamedAttribute, 8> filteredAttrs; 1164 for (auto attr : attrs) { 1165 // If the caller has requested that this attribute be ignored, then drop it. 1166 if (llvm::any_of(elidedAttrs, 1167 [&](StringRef elided) { return attr.first.is(elided); })) 1168 continue; 1169 1170 // Otherwise add it to our filteredAttrs list. 1171 filteredAttrs.push_back(attr); 1172 } 1173 1174 // If there are no attributes left to print after filtering, then we're done. 1175 if (filteredAttrs.empty()) 1176 return; 1177 1178 // Otherwise, print them all out in braces. 1179 os << " {"; 1180 interleaveComma(filteredAttrs, [&](NamedAttribute attr) { 1181 os << attr.first; 1182 1183 // Pretty printing elides the attribute value for unit attributes. 1184 if (attr.second.isa<UnitAttr>()) 1185 return; 1186 1187 os << " = "; 1188 printAttribute(attr.second); 1189 }); 1190 os << '}'; 1191 } 1192 1193 namespace { 1194 1195 // OperationPrinter contains common functionality for printing operations. 
1196 class OperationPrinter : public ModulePrinter, private OpAsmPrinter { 1197 public: 1198 OperationPrinter(Operation *op, ModulePrinter &other); 1199 OperationPrinter(Region *region, ModulePrinter &other); 1200 1201 // Methods to print operations. 1202 void print(Operation *op); 1203 void print(Block *block, bool printBlockArgs = true, 1204 bool printBlockTerminator = true); 1205 1206 void printOperation(Operation *op); 1207 void printGenericOp(Operation *op) override; 1208 1209 // Implement OpAsmPrinter. 1210 raw_ostream &getStream() const override { return os; } 1211 void printType(Type type) override { ModulePrinter::printType(type); } 1212 void printAttribute(Attribute attr) override { 1213 ModulePrinter::printAttribute(attr); 1214 } 1215 void printOperand(Value *value) override { printValueID(value); } 1216 1217 void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs, 1218 ArrayRef<StringRef> elidedAttrs = {}) override { 1219 return ModulePrinter::printOptionalAttrDict(attrs, elidedAttrs); 1220 }; 1221 1222 enum { nameSentinel = ~0U }; 1223 1224 void printBlockName(Block *block) { 1225 auto id = getBlockID(block); 1226 if (id != ~0U) 1227 os << "^bb" << id; 1228 else 1229 os << "^INVALIDBLOCK"; 1230 } 1231 1232 unsigned getBlockID(Block *block) { 1233 auto it = blockIDs.find(block); 1234 return it != blockIDs.end() ? it->second : ~0U; 1235 } 1236 1237 void printSuccessorAndUseList(Operation *term, unsigned index) override; 1238 1239 /// Print a region. 
1240 void printRegion(Region &blocks, bool printEntryBlockArgs, 1241 bool printBlockTerminators) override { 1242 os << " {\n"; 1243 if (!blocks.empty()) { 1244 auto *entryBlock = &blocks.front(); 1245 print(entryBlock, 1246 printEntryBlockArgs && entryBlock->getNumArguments() != 0, 1247 printBlockTerminators); 1248 for (auto &b : llvm::drop_begin(blocks.getBlocks(), 1)) 1249 print(&b); 1250 } 1251 os.indent(currentIndent) << "}"; 1252 } 1253 1254 /// Renumber the arguments for the specified region to the same names as the 1255 /// SSA values in namesToUse. This may only be used for IsolatedFromAbove 1256 /// operations. If any entry in namesToUse is null, the corresponding 1257 /// argument name is left alone. 1258 void shadowRegionArgs(Region ®ion, ArrayRef<Value *> namesToUse) override; 1259 1260 void printAffineMapOfSSAIds(AffineMapAttr mapAttr, 1261 ArrayRef<Value *> operands) override { 1262 AffineMap map = mapAttr.getValue(); 1263 unsigned numDims = map.getNumDims(); 1264 auto printValueName = [&](unsigned pos, bool isSymbol) { 1265 unsigned index = isSymbol ? numDims + pos : pos; 1266 assert(index < operands.size()); 1267 if (isSymbol) 1268 os << "symbol("; 1269 printValueID(operands[index]); 1270 if (isSymbol) 1271 os << ')'; 1272 }; 1273 1274 interleaveComma(map.getResults(), [&](AffineExpr expr) { 1275 printAffineExpr(expr, printValueName); 1276 }); 1277 } 1278 1279 // Number of spaces used for indenting nested operations. 1280 const static unsigned indentWidth = 2; 1281 1282 protected: 1283 void numberValueID(Value *value); 1284 void numberValuesInRegion(Region ®ion); 1285 void numberValuesInBlock(Block &block); 1286 void printValueID(Value *value, bool printResultNo = true) const { 1287 printValueIDImpl(value, printResultNo, os); 1288 } 1289 1290 private: 1291 void printValueIDImpl(Value *value, bool printResultNo, 1292 raw_ostream &stream) const; 1293 1294 /// Uniques the given value name within the printer. 
If the given name 1295 /// conflicts, it is automatically renamed. 1296 StringRef uniqueValueName(StringRef name); 1297 1298 /// This is the value ID for each SSA value. If this returns ~0, then the 1299 /// valueID has an entry in valueNames. 1300 DenseMap<Value *, unsigned> valueIDs; 1301 DenseMap<Value *, StringRef> valueNames; 1302 1303 /// This is the block ID for each block in the current. 1304 DenseMap<Block *, unsigned> blockIDs; 1305 1306 /// This keeps track of all of the non-numeric names that are in flight, 1307 /// allowing us to check for duplicates. 1308 /// Note: the value of the map is unused. 1309 llvm::ScopedHashTable<StringRef, char> usedNames; 1310 llvm::BumpPtrAllocator usedNameAllocator; 1311 1312 // This is the current indentation level for nested structures. 1313 unsigned currentIndent = 0; 1314 1315 /// This is the next value ID to assign in numbering. 1316 unsigned nextValueID = 0; 1317 /// This is the next ID to assign to a region entry block argument. 1318 unsigned nextArgumentID = 0; 1319 /// This is the next ID to assign when a name conflict is detected. 1320 unsigned nextConflictID = 0; 1321 }; 1322 } // end anonymous namespace 1323 1324 OperationPrinter::OperationPrinter(Operation *op, ModulePrinter &other) 1325 : ModulePrinter(other) { 1326 if (op->getNumResults() != 0) 1327 numberValueID(op->getResult(0)); 1328 for (auto ®ion : op->getRegions()) 1329 numberValuesInRegion(region); 1330 } 1331 1332 OperationPrinter::OperationPrinter(Region *region, ModulePrinter &other) 1333 : ModulePrinter(other) { 1334 numberValuesInRegion(*region); 1335 } 1336 1337 /// Number all of the SSA values in the specified region. 1338 void OperationPrinter::numberValuesInRegion(Region ®ion) { 1339 // Save the current value ids to allow for numbering values in sibling regions 1340 // the same. 
1341 unsigned curValueID = nextValueID; 1342 unsigned curArgumentID = nextArgumentID; 1343 unsigned curConflictID = nextConflictID; 1344 1345 // Push a new used names scope. 1346 llvm::ScopedHashTable<StringRef, char>::ScopeTy usedNamesScope(usedNames); 1347 1348 // Number the values within this region in a breadth-first order. 1349 unsigned nextBlockID = 0; 1350 for (auto &block : region) { 1351 // Each block gets a unique ID, and all of the operations within it get 1352 // numbered as well. 1353 blockIDs[&block] = nextBlockID++; 1354 numberValuesInBlock(block); 1355 } 1356 1357 // After that we traverse the nested regions. 1358 // TODO: Rework this loop to not use recursion. 1359 for (auto &block : region) { 1360 for (auto &op : block) 1361 for (auto &nestedRegion : op.getRegions()) 1362 numberValuesInRegion(nestedRegion); 1363 } 1364 1365 // Restore the original value ids. 1366 nextValueID = curValueID; 1367 nextArgumentID = curArgumentID; 1368 nextConflictID = curConflictID; 1369 } 1370 1371 /// Number all of the SSA values in the specified block, without traversing 1372 /// nested regions. 1373 void OperationPrinter::numberValuesInBlock(Block &block) { 1374 // Number the block arguments. 1375 for (auto *arg : block.getArguments()) 1376 numberValueID(arg); 1377 1378 // We number operation that have results, and we only number the first result. 1379 for (auto &op : block) 1380 if (op.getNumResults() != 0) 1381 numberValueID(op.getResult(0)); 1382 } 1383 1384 void OperationPrinter::numberValueID(Value *value) { 1385 assert(!valueIDs.count(value) && "Value numbered multiple times"); 1386 1387 SmallString<32> specialNameBuffer; 1388 llvm::raw_svector_ostream specialName(specialNameBuffer); 1389 1390 // Check to see if this value requested a special name. 
1391 auto *op = value->getDefiningOp(); 1392 if (state && op) { 1393 if (auto *interface = state->getOpAsmInterface(op->getDialect())) 1394 interface->getOpResultName(op, specialName); 1395 } 1396 1397 if (specialNameBuffer.empty()) { 1398 switch (value->getKind()) { 1399 case Value::Kind::BlockArgument: 1400 // If this is an argument to the entry block of a region, give it an 'arg' 1401 // name. 1402 if (auto *block = cast<BlockArgument>(value)->getOwner()) { 1403 auto *parentRegion = block->getParent(); 1404 if (parentRegion && block == &parentRegion->front()) { 1405 specialName << "arg" << nextArgumentID++; 1406 break; 1407 } 1408 } 1409 // Otherwise number it normally. 1410 valueIDs[value] = nextValueID++; 1411 return; 1412 case Value::Kind::OpResult: 1413 // This is an uninteresting result, give it a boring number and be 1414 // done with it. 1415 valueIDs[value] = nextValueID++; 1416 return; 1417 } 1418 } 1419 1420 // Ok, this value had an interesting name. Remember it with a sentinel. 1421 valueIDs[value] = nameSentinel; 1422 valueNames[value] = uniqueValueName(specialName.str()); 1423 } 1424 1425 /// Uniques the given value name within the printer. If the given name 1426 /// conflicts, it is automatically renamed. 1427 StringRef OperationPrinter::uniqueValueName(StringRef name) { 1428 // Check to see if this name is already unique. 1429 if (!usedNames.count(name)) { 1430 name = name.copy(usedNameAllocator); 1431 } else { 1432 // Otherwise, we had a conflict - probe until we find a unique name. This 1433 // is guaranteed to terminate (and usually in a single iteration) because it 1434 // generates new names by incrementing nextConflictID. 
1435 SmallString<64> probeName(name); 1436 probeName.push_back('_'); 1437 while (1) { 1438 probeName.resize(name.size() + 1); 1439 probeName += llvm::utostr(nextConflictID++); 1440 if (!usedNames.count(probeName)) { 1441 name = StringRef(probeName).copy(usedNameAllocator); 1442 break; 1443 } 1444 } 1445 } 1446 1447 usedNames.insert(name, char()); 1448 return name; 1449 } 1450 1451 void OperationPrinter::print(Block *block, bool printBlockArgs, 1452 bool printBlockTerminator) { 1453 // Print the block label and argument list if requested. 1454 if (printBlockArgs) { 1455 os.indent(currentIndent); 1456 printBlockName(block); 1457 1458 // Print the argument list if non-empty. 1459 if (!block->args_empty()) { 1460 os << '('; 1461 interleaveComma(block->getArguments(), [&](BlockArgument *arg) { 1462 printValueID(arg); 1463 os << ": "; 1464 printType(arg->getType()); 1465 }); 1466 os << ')'; 1467 } 1468 os << ':'; 1469 1470 // Print out some context information about the predecessors of this block. 1471 if (!block->getParent()) { 1472 os << "\t// block is not in a region!"; 1473 } else if (block->hasNoPredecessors()) { 1474 os << "\t// no predecessors"; 1475 } else if (auto *pred = block->getSinglePredecessor()) { 1476 os << "\t// pred: "; 1477 printBlockName(pred); 1478 } else { 1479 // We want to print the predecessors in increasing numeric order, not in 1480 // whatever order the use-list is in, so gather and sort them. 
1481 SmallVector<std::pair<unsigned, Block *>, 4> predIDs; 1482 for (auto *pred : block->getPredecessors()) 1483 predIDs.push_back({getBlockID(pred), pred}); 1484 llvm::array_pod_sort(predIDs.begin(), predIDs.end()); 1485 1486 os << "\t// " << predIDs.size() << " preds: "; 1487 1488 interleaveComma(predIDs, [&](std::pair<unsigned, Block *> pred) { 1489 printBlockName(pred.second); 1490 }); 1491 } 1492 os << '\n'; 1493 } 1494 1495 currentIndent += indentWidth; 1496 auto range = llvm::make_range( 1497 block->getOperations().begin(), 1498 std::prev(block->getOperations().end(), printBlockTerminator ? 0 : 1)); 1499 for (auto &op : range) { 1500 print(&op); 1501 os << '\n'; 1502 } 1503 currentIndent -= indentWidth; 1504 } 1505 1506 void OperationPrinter::print(Operation *op) { 1507 os.indent(currentIndent); 1508 printOperation(op); 1509 printTrailingLocation(op->getLoc()); 1510 } 1511 1512 void OperationPrinter::printValueIDImpl(Value *value, bool printResultNo, 1513 raw_ostream &stream) const { 1514 if (!value) { 1515 stream << "<<NULL>>"; 1516 return; 1517 } 1518 1519 int resultNo = -1; 1520 auto lookupValue = value; 1521 1522 // If this is a reference to the result of a multi-result operation or 1523 // operation, print out the # identifier and make sure to map our lookup 1524 // to the first result of the operation. 
1525 if (auto *result = dyn_cast<OpResult>(value)) { 1526 if (result->getOwner()->getNumResults() != 1) { 1527 resultNo = result->getResultNumber(); 1528 lookupValue = result->getOwner()->getResult(0); 1529 } 1530 } 1531 1532 auto it = valueIDs.find(lookupValue); 1533 if (it == valueIDs.end()) { 1534 stream << "<<INVALID SSA VALUE>>"; 1535 return; 1536 } 1537 1538 stream << '%'; 1539 if (it->second != nameSentinel) { 1540 stream << it->second; 1541 } else { 1542 auto nameIt = valueNames.find(lookupValue); 1543 assert(nameIt != valueNames.end() && "Didn't have a name entry?"); 1544 stream << nameIt->second; 1545 } 1546 1547 if (resultNo != -1 && printResultNo) 1548 stream << '#' << resultNo; 1549 } 1550 1551 /// Renumber the arguments for the specified region to the same names as the 1552 /// SSA values in namesToUse. This may only be used for IsolatedFromAbove 1553 /// operations. If any entry in namesToUse is null, the corresponding 1554 /// argument name is left alone. 1555 void OperationPrinter::shadowRegionArgs(Region ®ion, 1556 ArrayRef<Value *> namesToUse) { 1557 assert(!region.empty() && "cannot shadow arguments of an empty region"); 1558 assert(region.front().getNumArguments() == namesToUse.size() && 1559 "incorrect number of names passed in"); 1560 assert(region.getParentOp()->isKnownIsolatedFromAbove() && 1561 "only KnownIsolatedFromAbove ops can shadow names"); 1562 1563 SmallVector<char, 16> nameStr; 1564 for (unsigned i = 0, e = namesToUse.size(); i != e; ++i) { 1565 auto *nameToUse = namesToUse[i]; 1566 if (nameToUse == nullptr) 1567 continue; 1568 1569 auto *nameToReplace = region.front().getArgument(i); 1570 1571 nameStr.clear(); 1572 llvm::raw_svector_ostream nameStream(nameStr); 1573 printValueIDImpl(nameToUse, /*printResultNo=*/true, nameStream); 1574 1575 // Entry block arguments should already have a pretty "arg" name. 1576 assert(valueIDs[nameToReplace] == nameSentinel); 1577 1578 // Use the name without the leading %. 
1579 auto name = StringRef(nameStream.str()).drop_front(); 1580 1581 // Overwrite the name. 1582 valueNames[nameToReplace] = name.copy(usedNameAllocator); 1583 } 1584 } 1585 1586 void OperationPrinter::printOperation(Operation *op) { 1587 if (size_t numResults = op->getNumResults()) { 1588 printValueID(op->getResult(0), /*printResultNo=*/false); 1589 if (numResults > 1) 1590 os << ':' << numResults; 1591 os << " = "; 1592 } 1593 1594 // TODO(riverriddle): FuncOp cannot be round-tripped currently, as 1595 // FunctionType cannot be used in a TypeAttr. 1596 if (printGenericOpForm && !isa<FuncOp>(op)) 1597 return printGenericOp(op); 1598 1599 // Check to see if this is a known operation. If so, use the registered 1600 // custom printer hook. 1601 if (auto *opInfo = op->getAbstractOperation()) { 1602 opInfo->printAssembly(op, this); 1603 return; 1604 } 1605 1606 // Otherwise print with the generic assembly form. 1607 printGenericOp(op); 1608 } 1609 1610 void OperationPrinter::printGenericOp(Operation *op) { 1611 os << '"'; 1612 printEscapedString(op->getName().getStringRef(), os); 1613 os << "\"("; 1614 1615 // Get the list of operands that are not successor operands. 1616 unsigned totalNumSuccessorOperands = 0; 1617 unsigned numSuccessors = op->getNumSuccessors(); 1618 for (unsigned i = 0; i < numSuccessors; ++i) 1619 totalNumSuccessorOperands += op->getNumSuccessorOperands(i); 1620 unsigned numProperOperands = op->getNumOperands() - totalNumSuccessorOperands; 1621 SmallVector<Value *, 8> properOperands( 1622 op->operand_begin(), std::next(op->operand_begin(), numProperOperands)); 1623 1624 interleaveComma(properOperands, [&](Value *value) { printValueID(value); }); 1625 1626 os << ')'; 1627 1628 // For terminators, print the list of successors and their operands. 
1629 if (numSuccessors != 0) { 1630 os << '['; 1631 for (unsigned i = 0; i < numSuccessors; ++i) { 1632 if (i != 0) 1633 os << ", "; 1634 printSuccessorAndUseList(op, i); 1635 } 1636 os << ']'; 1637 } 1638 1639 // Print regions. 1640 if (op->getNumRegions() != 0) { 1641 os << " ("; 1642 interleaveComma(op->getRegions(), [&](Region ®ion) { 1643 printRegion(region, /*printEntryBlockArgs=*/true, 1644 /*printBlockTerminators=*/true); 1645 }); 1646 os << ')'; 1647 } 1648 1649 auto attrs = op->getAttrs(); 1650 printOptionalAttrDict(attrs); 1651 1652 // Print the type signature of the operation. 1653 os << " : "; 1654 printFunctionalType(op); 1655 } 1656 1657 void OperationPrinter::printSuccessorAndUseList(Operation *term, 1658 unsigned index) { 1659 printBlockName(term->getSuccessor(index)); 1660 1661 auto succOperands = term->getSuccessorOperands(index); 1662 if (succOperands.begin() == succOperands.end()) 1663 return; 1664 1665 os << '('; 1666 interleaveComma(succOperands, 1667 [this](Value *operand) { printValueID(operand); }); 1668 os << " : "; 1669 interleaveComma(succOperands, 1670 [this](Value *operand) { printType(operand->getType()); }); 1671 os << ')'; 1672 } 1673 1674 void ModulePrinter::print(ModuleOp module) { 1675 // Output the aliases at the top level. 1676 if (state) { 1677 state->printAttributeAliases(os); 1678 state->printTypeAliases(os); 1679 } 1680 1681 // Print the module. 
1682 OperationPrinter(module, *this).print(module); 1683 os << '\n'; 1684 } 1685 1686 //===----------------------------------------------------------------------===// 1687 // print and dump methods 1688 //===----------------------------------------------------------------------===// 1689 1690 void Attribute::print(raw_ostream &os) const { 1691 ModulePrinter(os).printAttribute(*this); 1692 } 1693 1694 void Attribute::dump() const { 1695 print(llvm::errs()); 1696 llvm::errs() << "\n"; 1697 } 1698 1699 void Type::print(raw_ostream &os) { ModulePrinter(os).printType(*this); } 1700 1701 void Type::dump() { print(llvm::errs()); } 1702 1703 void AffineMap::dump() const { 1704 print(llvm::errs()); 1705 llvm::errs() << "\n"; 1706 } 1707 1708 void IntegerSet::dump() const { 1709 print(llvm::errs()); 1710 llvm::errs() << "\n"; 1711 } 1712 1713 void AffineExpr::print(raw_ostream &os) const { 1714 if (expr == nullptr) { 1715 os << "null affine expr"; 1716 return; 1717 } 1718 ModulePrinter(os).printAffineExpr(*this); 1719 } 1720 1721 void AffineExpr::dump() const { 1722 print(llvm::errs()); 1723 llvm::errs() << "\n"; 1724 } 1725 1726 void AffineMap::print(raw_ostream &os) const { 1727 if (map == nullptr) { 1728 os << "null affine map"; 1729 return; 1730 } 1731 ModulePrinter(os).printAffineMap(*this); 1732 } 1733 1734 void IntegerSet::print(raw_ostream &os) const { 1735 ModulePrinter(os).printIntegerSet(*this); 1736 } 1737 1738 void Value::print(raw_ostream &os) { 1739 switch (getKind()) { 1740 case Value::Kind::BlockArgument: 1741 // TODO: Improve this. 1742 os << "<block argument>\n"; 1743 return; 1744 case Value::Kind::OpResult: 1745 return getDefiningOp()->print(os); 1746 } 1747 } 1748 1749 void Value::dump() { print(llvm::errs()); } 1750 1751 void Operation::print(raw_ostream &os) { 1752 // Handle top-level operations. 
1753 if (!getParent()) { 1754 ModulePrinter modulePrinter(os); 1755 OperationPrinter(this, modulePrinter).print(this); 1756 return; 1757 } 1758 1759 auto region = getParentRegion(); 1760 if (!region) { 1761 os << "<<UNLINKED INSTRUCTION>>\n"; 1762 return; 1763 } 1764 1765 // Get the top-level region. 1766 while (auto *nextRegion = region->getParentRegion()) 1767 region = nextRegion; 1768 1769 ModuleState state(getContext()); 1770 ModulePrinter modulePrinter(os, &state); 1771 OperationPrinter(region, modulePrinter).print(this); 1772 } 1773 1774 void Operation::dump() { 1775 print(llvm::errs()); 1776 llvm::errs() << "\n"; 1777 } 1778 1779 void Block::print(raw_ostream &os) { 1780 auto region = getParent(); 1781 if (!region) { 1782 os << "<<UNLINKED BLOCK>>\n"; 1783 return; 1784 } 1785 1786 // Get the top-level region. 1787 while (auto *nextRegion = region->getParentRegion()) 1788 region = nextRegion; 1789 1790 ModuleState state(region->getContext()); 1791 ModulePrinter modulePrinter(os, &state); 1792 OperationPrinter(region, modulePrinter).print(this); 1793 } 1794 1795 void Block::dump() { print(llvm::errs()); } 1796 1797 /// Print out the name of the block without printing its body. 1798 void Block::printAsOperand(raw_ostream &os, bool printType) { 1799 auto region = getParent(); 1800 if (!region) { 1801 os << "<<UNLINKED BLOCK>>\n"; 1802 return; 1803 } 1804 1805 // Get the top-level region. 1806 while (auto *nextRegion = region->getParentRegion()) 1807 region = nextRegion; 1808 1809 ModulePrinter modulePrinter(os); 1810 OperationPrinter(region, modulePrinter).printBlockName(this); 1811 } 1812 1813 void ModuleOp::print(raw_ostream &os) { 1814 ModuleState state(getContext()); 1815 state.initialize(*this); 1816 ModulePrinter(os, &state).print(*this); 1817 } 1818 1819 void ModuleOp::dump() { print(llvm::errs()); }