github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/IR/Attributes.cpp (about) 1 //===- Attributes.cpp - MLIR Affine Expr Classes --------------------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 18 #include "mlir/IR/Attributes.h" 19 #include "AttributeDetail.h" 20 #include "mlir/IR/AffineMap.h" 21 #include "mlir/IR/Diagnostics.h" 22 #include "mlir/IR/Dialect.h" 23 #include "mlir/IR/Function.h" 24 #include "mlir/IR/IntegerSet.h" 25 #include "mlir/IR/Types.h" 26 #include "llvm/ADT/Sequence.h" 27 #include "llvm/ADT/Twine.h" 28 29 using namespace mlir; 30 using namespace mlir::detail; 31 32 //===----------------------------------------------------------------------===// 33 // AttributeStorage 34 //===----------------------------------------------------------------------===// 35 36 AttributeStorage::AttributeStorage(Type type) 37 : type(type.getAsOpaquePointer()) {} 38 AttributeStorage::AttributeStorage() : type(nullptr) {} 39 40 Type AttributeStorage::getType() const { 41 return Type::getFromOpaquePointer(type); 42 } 43 void AttributeStorage::setType(Type newType) { 44 type = newType.getAsOpaquePointer(); 45 } 46 47 //===----------------------------------------------------------------------===// 48 // Attribute 49 //===----------------------------------------------------------------------===// 50 51 /// Return the type of this attribute. 52 Type Attribute::getType() const { return impl->getType(); } 53 54 /// Return the context this attribute belongs to. 55 MLIRContext *Attribute::getContext() const { return getType().getContext(); } 56 57 /// Get the dialect this attribute is registered to. 58 Dialect &Attribute::getDialect() const { return impl->getDialect(); } 59 60 //===----------------------------------------------------------------------===// 61 // AffineMapAttr 62 //===----------------------------------------------------------------------===// 63 64 AffineMapAttr AffineMapAttr::get(AffineMap value) { 65 return Base::get(value.getContext(), StandardAttributes::AffineMap, value); 66 } 67 68 AffineMap AffineMapAttr::getValue() const { return getImpl()->value; } 69 70 //===----------------------------------------------------------------------===// 71 // ArrayAttr 72 //===----------------------------------------------------------------------===// 73 74 ArrayAttr ArrayAttr::get(ArrayRef<Attribute> value, MLIRContext *context) { 75 return Base::get(context, StandardAttributes::Array, value); 76 } 77 78 ArrayRef<Attribute> ArrayAttr::getValue() const { return getImpl()->value; } 79 80 //===----------------------------------------------------------------------===// 81 // BoolAttr 82 //===----------------------------------------------------------------------===// 83 84 bool BoolAttr::getValue() const { return getImpl()->value; } 85 86 //===----------------------------------------------------------------------===// 87 // DictionaryAttr 88 //===----------------------------------------------------------------------===// 89 90 /// Perform a three-way comparison between the names of the specified 91 /// NamedAttributes. 92 static int compareNamedAttributes(const NamedAttribute *lhs, 93 const NamedAttribute *rhs) { 94 return lhs->first.str().compare(rhs->first.str()); 95 } 96 97 DictionaryAttr DictionaryAttr::get(ArrayRef<NamedAttribute> value, 98 MLIRContext *context) { 99 assert(llvm::all_of(value, 100 [](const NamedAttribute &attr) { return attr.second; }) && 101 "value cannot have null entries"); 102 103 // We need to sort the element list to canonicalize it, but we also don't want 104 // to do a ton of work in the super common case where the element list is 105 // already sorted. 106 SmallVector<NamedAttribute, 8> storage; 107 switch (value.size()) { 108 case 0: 109 break; 110 case 1: 111 // A single element is already sorted. 112 break; 113 case 2: 114 assert(value[0].first != value[1].first && 115 "DictionaryAttr element names must be unique"); 116 117 // Don't invoke a general sort for two element case. 118 if (value[0].first.strref() > value[1].first.strref()) { 119 storage.push_back(value[1]); 120 storage.push_back(value[0]); 121 value = storage; 122 } 123 break; 124 default: 125 // Check to see they are sorted already. 126 bool isSorted = true; 127 for (unsigned i = 0, e = value.size() - 1; i != e; ++i) { 128 if (value[i].first.strref() > value[i + 1].first.strref()) { 129 isSorted = false; 130 break; 131 } 132 } 133 // If not, do a general sort. 134 if (!isSorted) { 135 storage.append(value.begin(), value.end()); 136 llvm::array_pod_sort(storage.begin(), storage.end(), 137 compareNamedAttributes); 138 value = storage; 139 } 140 141 // Ensure that the attribute elements are unique. 142 assert(std::adjacent_find(value.begin(), value.end(), 143 [](NamedAttribute l, NamedAttribute r) { 144 return l.first == r.first; 145 }) == value.end() && 146 "DictionaryAttr element names must be unique"); 147 } 148 149 return Base::get(context, StandardAttributes::Dictionary, value); 150 } 151 152 ArrayRef<NamedAttribute> DictionaryAttr::getValue() const { 153 return getImpl()->getElements(); 154 } 155 156 /// Return the specified attribute if present, null otherwise. 157 Attribute DictionaryAttr::get(StringRef name) const { 158 for (auto elt : getValue()) 159 if (elt.first.is(name)) 160 return elt.second; 161 return nullptr; 162 } 163 Attribute DictionaryAttr::get(Identifier name) const { 164 for (auto elt : getValue()) 165 if (elt.first == name) 166 return elt.second; 167 return nullptr; 168 } 169 170 DictionaryAttr::iterator DictionaryAttr::begin() const { 171 return getValue().begin(); 172 } 173 DictionaryAttr::iterator DictionaryAttr::end() const { 174 return getValue().end(); 175 } 176 size_t DictionaryAttr::size() const { return getValue().size(); } 177 178 //===----------------------------------------------------------------------===// 179 // FloatAttr 180 //===----------------------------------------------------------------------===// 181 182 FloatAttr FloatAttr::get(Type type, double value) { 183 return Base::get(type.getContext(), StandardAttributes::Float, type, value); 184 } 185 186 FloatAttr FloatAttr::getChecked(Type type, double value, Location loc) { 187 return Base::getChecked(loc, type.getContext(), StandardAttributes::Float, 188 type, value); 189 } 190 191 FloatAttr FloatAttr::get(Type type, const APFloat &value) { 192 return Base::get(type.getContext(), StandardAttributes::Float, type, value); 193 } 194 195 FloatAttr FloatAttr::getChecked(Type type, const APFloat &value, Location loc) { 196 return Base::getChecked(loc, type.getContext(), StandardAttributes::Float, 197 type, value); 198 } 199 200 APFloat FloatAttr::getValue() const { return getImpl()->getValue(); } 201 202 double FloatAttr::getValueAsDouble() const { 203 return getValueAsDouble(getValue()); 204 } 205 double FloatAttr::getValueAsDouble(APFloat value) { 206 if (&value.getSemantics() != &APFloat::IEEEdouble()) { 207 bool losesInfo = false; 208 value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, 209 &losesInfo); 210 } 211 return value.convertToDouble(); 212 } 213 214 /// Verify construction invariants. 215 static LogicalResult verifyFloatTypeInvariants(llvm::Optional<Location> loc, 216 Type type) { 217 if (!type.isa<FloatType>()) { 218 if (loc) 219 emitError(*loc, "expected floating point type"); 220 return failure(); 221 } 222 return success(); 223 } 224 225 LogicalResult FloatAttr::verifyConstructionInvariants( 226 llvm::Optional<Location> loc, MLIRContext *ctx, Type type, double value) { 227 return verifyFloatTypeInvariants(loc, type); 228 } 229 230 LogicalResult 231 FloatAttr::verifyConstructionInvariants(llvm::Optional<Location> loc, 232 MLIRContext *ctx, Type type, 233 const APFloat &value) { 234 // Verify that the type is correct. 235 if (failed(verifyFloatTypeInvariants(loc, type))) 236 return failure(); 237 238 // Verify that the type semantics match that of the value. 239 if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) { 240 if (loc) 241 emitError(*loc, 242 "FloatAttr type doesn't match the type implied by its value"); 243 return failure(); 244 } 245 return success(); 246 } 247 248 //===----------------------------------------------------------------------===// 249 // SymbolRefAttr 250 //===----------------------------------------------------------------------===// 251 252 SymbolRefAttr SymbolRefAttr::get(StringRef value, MLIRContext *ctx) { 253 return Base::get(ctx, StandardAttributes::SymbolRef, value, 254 NoneType::get(ctx)); 255 } 256 257 StringRef SymbolRefAttr::getValue() const { return getImpl()->value; } 258 259 //===----------------------------------------------------------------------===// 260 // IntegerAttr 261 //===----------------------------------------------------------------------===// 262 263 IntegerAttr IntegerAttr::get(Type type, const APInt &value) { 264 return Base::get(type.getContext(), StandardAttributes::Integer, type, value); 265 } 266 267 IntegerAttr IntegerAttr::get(Type type, int64_t value) { 268 // This uses 64 bit APInts by default for index type. 269 if (type.isIndex()) 270 return get(type, APInt(64, value)); 271 272 auto intType = type.cast<IntegerType>(); 273 return get(type, APInt(intType.getWidth(), value)); 274 } 275 276 APInt IntegerAttr::getValue() const { return getImpl()->getValue(); } 277 278 int64_t IntegerAttr::getInt() const { return getValue().getSExtValue(); } 279 280 //===----------------------------------------------------------------------===// 281 // IntegerSetAttr 282 //===----------------------------------------------------------------------===// 283 284 IntegerSetAttr IntegerSetAttr::get(IntegerSet value) { 285 return Base::get(value.getConstraint(0).getContext(), 286 StandardAttributes::IntegerSet, value); 287 } 288 289 IntegerSet IntegerSetAttr::getValue() const { return getImpl()->value; } 290 291 //===----------------------------------------------------------------------===// 292 // OpaqueAttr 293 //===----------------------------------------------------------------------===// 294 295 OpaqueAttr OpaqueAttr::get(Identifier dialect, StringRef attrData, Type type, 296 MLIRContext *context) { 297 return Base::get(context, StandardAttributes::Opaque, dialect, attrData, 298 type); 299 } 300 301 OpaqueAttr OpaqueAttr::getChecked(Identifier dialect, StringRef attrData, 302 Type type, Location location) { 303 return Base::getChecked(location, type.getContext(), 304 StandardAttributes::Opaque, dialect, attrData, type); 305 } 306 307 /// Returns the dialect namespace of the opaque attribute. 308 Identifier OpaqueAttr::getDialectNamespace() const { 309 return getImpl()->dialectNamespace; 310 } 311 312 /// Returns the raw attribute data of the opaque attribute. 313 StringRef OpaqueAttr::getAttrData() const { return getImpl()->attrData; } 314 315 /// Verify the construction of an opaque attribute. 316 LogicalResult OpaqueAttr::verifyConstructionInvariants( 317 llvm::Optional<Location> loc, MLIRContext *context, Identifier dialect, 318 StringRef attrData, Type type) { 319 if (!Dialect::isValidNamespace(dialect.strref())) { 320 if (loc) 321 emitError(*loc) << "invalid dialect namespace '" << dialect << "'"; 322 return failure(); 323 } 324 return success(); 325 } 326 327 //===----------------------------------------------------------------------===// 328 // StringAttr 329 //===----------------------------------------------------------------------===// 330 331 StringAttr StringAttr::get(StringRef bytes, MLIRContext *context) { 332 return get(bytes, NoneType::get(context)); 333 } 334 335 /// Get an instance of a StringAttr with the given string and Type. 336 StringAttr StringAttr::get(StringRef bytes, Type type) { 337 return Base::get(type.getContext(), StandardAttributes::String, bytes, type); 338 } 339 340 StringRef StringAttr::getValue() const { return getImpl()->value; } 341 342 //===----------------------------------------------------------------------===// 343 // TypeAttr 344 //===----------------------------------------------------------------------===// 345 346 TypeAttr TypeAttr::get(Type value) { 347 return Base::get(value.getContext(), StandardAttributes::Type, value); 348 } 349 350 Type TypeAttr::getValue() const { return getImpl()->value; } 351 352 //===----------------------------------------------------------------------===// 353 // ElementsAttr 354 //===----------------------------------------------------------------------===// 355 356 ShapedType ElementsAttr::getType() const { 357 return Attribute::getType().cast<ShapedType>(); 358 } 359 360 /// Returns the number of elements held by this attribute. 361 int64_t ElementsAttr::getNumElements() const { 362 return getType().getNumElements(); 363 } 364 365 /// Return the value at the given index. If index does not refer to a valid 366 /// element, then a null attribute is returned. 367 Attribute ElementsAttr::getValue(ArrayRef<uint64_t> index) const { 368 switch (getKind()) { 369 case StandardAttributes::DenseElements: 370 return cast<DenseElementsAttr>().getValue(index); 371 case StandardAttributes::OpaqueElements: 372 return cast<OpaqueElementsAttr>().getValue(index); 373 case StandardAttributes::SparseElements: 374 return cast<SparseElementsAttr>().getValue(index); 375 default: 376 llvm_unreachable("unknown ElementsAttr kind"); 377 } 378 } 379 380 /// Return if the given 'index' refers to a valid element in this attribute. 381 bool ElementsAttr::isValidIndex(ArrayRef<uint64_t> index) const { 382 auto type = getType(); 383 384 // Verify that the rank of the indices matches the held type. 385 auto rank = type.getRank(); 386 if (rank != static_cast<int64_t>(index.size())) 387 return false; 388 389 // Verify that all of the indices are within the shape dimensions. 390 auto shape = type.getShape(); 391 return llvm::all_of(llvm::seq<int>(0, rank), [&](int i) { 392 return static_cast<int64_t>(index[i]) < shape[i]; 393 }); 394 } 395 396 ElementsAttr ElementsAttr::mapValues( 397 Type newElementType, 398 llvm::function_ref<APInt(const APInt &)> mapping) const { 399 switch (getKind()) { 400 case StandardAttributes::DenseElements: 401 return cast<DenseElementsAttr>().mapValues(newElementType, mapping); 402 default: 403 llvm_unreachable("unsupported ElementsAttr subtype"); 404 } 405 } 406 407 ElementsAttr ElementsAttr::mapValues( 408 Type newElementType, 409 llvm::function_ref<APInt(const APFloat &)> mapping) const { 410 switch (getKind()) { 411 case StandardAttributes::DenseElements: 412 return cast<DenseElementsAttr>().mapValues(newElementType, mapping); 413 default: 414 llvm_unreachable("unsupported ElementsAttr subtype"); 415 } 416 } 417 418 /// Returns the 1 dimenional flattened row-major index from the given 419 /// multi-dimensional index. 420 uint64_t ElementsAttr::getFlattenedIndex(ArrayRef<uint64_t> index) const { 421 assert(isValidIndex(index) && "expected valid multi-dimensional index"); 422 auto type = getType(); 423 424 // Reduce the provided multidimensional index into a flattended 1D row-major 425 // index. 426 auto rank = type.getRank(); 427 auto shape = type.getShape(); 428 uint64_t valueIndex = 0; 429 uint64_t dimMultiplier = 1; 430 for (int i = rank - 1; i >= 0; --i) { 431 valueIndex += index[i] * dimMultiplier; 432 dimMultiplier *= shape[i]; 433 } 434 return valueIndex; 435 } 436 437 //===----------------------------------------------------------------------===// 438 // DenseElementAttr Utilities 439 //===----------------------------------------------------------------------===// 440 441 static size_t getDenseElementBitwidth(Type eltType) { 442 // FIXME(b/121118307): using 64 bits for BF16 because it is currently stored 443 // with double semantics. 444 return eltType.isBF16() ? 64 : eltType.getIntOrFloatBitWidth(); 445 } 446 447 /// Get the bitwidth of a dense element type within the buffer. 448 /// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8. 449 static size_t getDenseElementStorageWidth(size_t origWidth) { 450 return origWidth == 1 ? origWidth : llvm::alignTo<8>(origWidth); 451 } 452 453 /// Set a bit to a specific value. 454 static void setBit(char *rawData, size_t bitPos, bool value) { 455 if (value) 456 rawData[bitPos / CHAR_BIT] |= (1 << (bitPos % CHAR_BIT)); 457 else 458 rawData[bitPos / CHAR_BIT] &= ~(1 << (bitPos % CHAR_BIT)); 459 } 460 461 /// Return the value of the specified bit. 462 static bool getBit(const char *rawData, size_t bitPos) { 463 return (rawData[bitPos / CHAR_BIT] & (1 << (bitPos % CHAR_BIT))) != 0; 464 } 465 466 /// Writes value to the bit position `bitPos` in array `rawData`. 467 static void writeBits(char *rawData, size_t bitPos, APInt value) { 468 size_t bitWidth = value.getBitWidth(); 469 470 // If the bitwidth is 1 we just toggle the specific bit. 471 if (bitWidth == 1) 472 return setBit(rawData, bitPos, value.isOneValue()); 473 474 // Otherwise, the bit position is guaranteed to be byte aligned. 475 assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned"); 476 std::copy_n(reinterpret_cast<const char *>(value.getRawData()), 477 llvm::divideCeil(bitWidth, CHAR_BIT), 478 rawData + (bitPos / CHAR_BIT)); 479 } 480 481 /// Reads the next `bitWidth` bits from the bit position `bitPos` in array 482 /// `rawData`. 483 static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) { 484 // Handle a boolean bit position. 485 if (bitWidth == 1) 486 return APInt(1, getBit(rawData, bitPos) ? 1 : 0); 487 488 // Otherwise, the bit position must be 8-bit aligned. 489 assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned"); 490 APInt result(bitWidth, 0); 491 std::copy_n( 492 rawData + (bitPos / CHAR_BIT), llvm::divideCeil(bitWidth, CHAR_BIT), 493 const_cast<char *>(reinterpret_cast<const char *>(result.getRawData()))); 494 return result; 495 } 496 497 /// Returns if 'values' corresponds to a splat, i.e. one element, or has the 498 /// same element count as 'type'. 499 template <typename Values> 500 static bool hasSameElementsOrSplat(ShapedType type, const Values &values) { 501 return (values.size() == 1) || 502 (type.getNumElements() == static_cast<int64_t>(values.size())); 503 } 504 505 //===----------------------------------------------------------------------===// 506 // DenseElementAttr Iterators 507 //===----------------------------------------------------------------------===// 508 509 /// Constructs a new iterator. 510 DenseElementsAttr::AttributeElementIterator::AttributeElementIterator( 511 DenseElementsAttr attr, size_t index) 512 : indexed_accessor_iterator<AttributeElementIterator, const void *, 513 Attribute, Attribute, Attribute>( 514 attr.getAsOpaquePointer(), index) {} 515 516 /// Accesses the Attribute value at this iterator position. 517 Attribute DenseElementsAttr::AttributeElementIterator::operator*() const { 518 auto owner = getFromOpaquePointer(object).cast<DenseElementsAttr>(); 519 Type eltTy = owner.getType().getElementType(); 520 if (auto intEltTy = eltTy.dyn_cast<IntegerType>()) { 521 if (intEltTy.getWidth() == 1) 522 return BoolAttr::get((*IntElementIterator(owner, index)).isOneValue(), 523 owner.getContext()); 524 return IntegerAttr::get(eltTy, *IntElementIterator(owner, index)); 525 } 526 if (auto floatEltTy = eltTy.dyn_cast<FloatType>()) { 527 IntElementIterator intIt(owner, index); 528 FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt); 529 return FloatAttr::get(eltTy, *floatIt); 530 } 531 llvm_unreachable("unexpected element type"); 532 } 533 534 /// Constructs a new iterator. 535 DenseElementsAttr::BoolElementIterator::BoolElementIterator( 536 DenseElementsAttr attr, size_t dataIndex) 537 : DenseElementIndexedIteratorImpl<BoolElementIterator, bool, bool, bool>( 538 attr.getRawData().data(), attr.isSplat(), dataIndex) {} 539 540 /// Accesses the bool value at this iterator position. 541 bool DenseElementsAttr::BoolElementIterator::operator*() const { 542 return getBit(getData(), getDataIndex()); 543 } 544 545 /// Constructs a new iterator. 546 DenseElementsAttr::IntElementIterator::IntElementIterator( 547 DenseElementsAttr attr, size_t dataIndex) 548 : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>( 549 attr.getRawData().data(), attr.isSplat(), dataIndex), 550 bitWidth(getDenseElementBitwidth(attr.getType().getElementType())) {} 551 552 /// Accesses the raw APInt value at this iterator position. 553 APInt DenseElementsAttr::IntElementIterator::operator*() const { 554 return readBits(getData(), 555 getDataIndex() * getDenseElementStorageWidth(bitWidth), 556 bitWidth); 557 } 558 559 DenseElementsAttr::FloatElementIterator::FloatElementIterator( 560 const llvm::fltSemantics &smt, IntElementIterator it) 561 : llvm::mapped_iterator<IntElementIterator, 562 std::function<APFloat(const APInt &)>>( 563 it, [&](const APInt &val) { return APFloat(smt, val); }) {} 564 565 //===----------------------------------------------------------------------===// 566 // DenseElementsAttr 567 //===----------------------------------------------------------------------===// 568 569 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 570 ArrayRef<Attribute> values) { 571 assert(type.getElementType().isIntOrFloat() && 572 "expected int or float element type"); 573 assert(hasSameElementsOrSplat(type, values)); 574 575 auto eltType = type.getElementType(); 576 size_t bitWidth = getDenseElementBitwidth(eltType); 577 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); 578 579 // Compress the attribute values into a character buffer. 580 SmallVector<char, 8> data(llvm::divideCeil(storageBitWidth, CHAR_BIT) * 581 values.size()); 582 APInt intVal; 583 for (unsigned i = 0, e = values.size(); i < e; ++i) { 584 assert(eltType == values[i].getType() && 585 "expected attribute value to have element type"); 586 587 switch (eltType.getKind()) { 588 case StandardTypes::BF16: 589 case StandardTypes::F16: 590 case StandardTypes::F32: 591 case StandardTypes::F64: 592 intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt(); 593 break; 594 case StandardTypes::Integer: 595 intVal = values[i].isa<BoolAttr>() 596 ? APInt(1, values[i].cast<BoolAttr>().getValue() ? 1 : 0) 597 : values[i].cast<IntegerAttr>().getValue(); 598 break; 599 default: 600 llvm_unreachable("unexpected element type"); 601 } 602 assert(intVal.getBitWidth() == bitWidth && 603 "expected value to have same bitwidth as element type"); 604 writeBits(data.data(), i * storageBitWidth, intVal); 605 } 606 return getRaw(type, data, /*isSplat=*/(values.size() == 1)); 607 } 608 609 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 610 ArrayRef<bool> values) { 611 assert(hasSameElementsOrSplat(type, values)); 612 assert(type.getElementType().isInteger(1)); 613 614 std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT)); 615 for (int i = 0, e = values.size(); i != e; ++i) 616 setBit(buff.data(), i, values[i]); 617 return getRaw(type, buff, /*isSplat=*/(values.size() == 1)); 618 } 619 620 /// Constructs a dense integer elements attribute from an array of APInt 621 /// values. Each APInt value is expected to have the same bitwidth as the 622 /// element type of 'type'. 623 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 624 ArrayRef<APInt> values) { 625 assert(type.getElementType().isa<IntegerType>()); 626 return getRaw(type, values); 627 } 628 629 // Constructs a dense float elements attribute from an array of APFloat 630 // values. Each APFloat value is expected to have the same bitwidth as the 631 // element type of 'type'. 632 DenseElementsAttr DenseElementsAttr::get(ShapedType type, 633 ArrayRef<APFloat> values) { 634 assert(type.getElementType().isa<FloatType>()); 635 636 // Convert the APFloat values to APInt and create a dense elements attribute. 637 std::vector<APInt> intValues(values.size()); 638 for (unsigned i = 0, e = values.size(); i != e; ++i) 639 intValues[i] = values[i].bitcastToAPInt(); 640 return getRaw(type, intValues); 641 } 642 643 // Constructs a dense elements attribute from an array of raw APInt values. 644 // Each APInt value is expected to have the same bitwidth as the element type 645 // of 'type'. 646 DenseElementsAttr DenseElementsAttr::getRaw(ShapedType type, 647 ArrayRef<APInt> values) { 648 assert(hasSameElementsOrSplat(type, values)); 649 650 size_t bitWidth = getDenseElementBitwidth(type.getElementType()); 651 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); 652 std::vector<char> elementData(llvm::divideCeil(storageBitWidth, CHAR_BIT) * 653 values.size()); 654 for (unsigned i = 0, e = values.size(); i != e; ++i) { 655 assert(values[i].getBitWidth() == bitWidth); 656 writeBits(elementData.data(), i * storageBitWidth, values[i]); 657 } 658 return getRaw(type, elementData, /*isSplat=*/(values.size() == 1)); 659 } 660 661 DenseElementsAttr DenseElementsAttr::getRaw(ShapedType type, 662 ArrayRef<char> data, bool isSplat) { 663 assert((type.isa<RankedTensorType>() || type.isa<VectorType>()) && 664 "type must be ranked tensor or vector"); 665 assert(type.hasStaticShape() && "type must have static shape"); 666 return Base::get(type.getContext(), StandardAttributes::DenseElements, type, 667 data, isSplat); 668 } 669 670 /// Check the information for a c++ data type, check if this type is valid for 671 /// the current attribute. This method is used to verify specific type 672 /// invariants that the templatized 'getValues' method cannot. 673 static bool isValidIntOrFloat(ShapedType type, int64_t dataEltSize, 674 bool isInt) { 675 // Make sure that the data element size is the same as the type element width. 676 if ((dataEltSize * CHAR_BIT) != type.getElementTypeBitWidth()) 677 return false; 678 679 // Check that the element type is valid. 680 return isInt ? type.getElementType().isa<IntegerType>() 681 : type.getElementType().isa<FloatType>(); 682 } 683 684 /// Overload of the 'getRaw' method that asserts that the given type is of 685 /// integer type. This method is used to verify type invariants that the 686 /// templatized 'get' method cannot. 687 DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type, 688 ArrayRef<char> data, 689 int64_t dataEltSize, 690 bool isInt) { 691 assert(::isValidIntOrFloat(type, dataEltSize, isInt)); 692 693 int64_t numElements = data.size() / dataEltSize; 694 assert(numElements == 1 || numElements == type.getNumElements()); 695 return getRaw(type, data, /*isSplat=*/numElements == 1); 696 } 697 698 /// A method used to verify specific type invariants that the templatized 'get' 699 /// method cannot. 700 bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, 701 bool isInt) const { 702 return ::isValidIntOrFloat(getType(), dataEltSize, isInt); 703 } 704 705 /// Return the raw storage data held by this attribute. 706 ArrayRef<char> DenseElementsAttr::getRawData() const { 707 return static_cast<ImplType *>(impl)->data; 708 } 709 710 /// Returns if this attribute corresponds to a splat, i.e. if all element 711 /// values are the same. 712 bool DenseElementsAttr::isSplat() const { return getImpl()->isSplat; } 713 714 /// Return the held element values as a range of Attributes. 715 auto DenseElementsAttr::getAttributeValues() const 716 -> llvm::iterator_range<AttributeElementIterator> { 717 return {attr_value_begin(), attr_value_end()}; 718 } 719 auto DenseElementsAttr::attr_value_begin() const -> AttributeElementIterator { 720 return AttributeElementIterator(*this, 0); 721 } 722 auto DenseElementsAttr::attr_value_end() const -> AttributeElementIterator { 723 return AttributeElementIterator(*this, getNumElements()); 724 } 725 726 /// Return the held element values as a range of bool. The element type of 727 /// this attribute must be of integer type of bitwidth 1. 728 auto DenseElementsAttr::getBoolValues() const 729 -> llvm::iterator_range<BoolElementIterator> { 730 auto eltType = getType().getElementType().dyn_cast<IntegerType>(); 731 assert(eltType && eltType.getWidth() == 1 && "expected i1 integer type"); 732 (void)eltType; 733 return {BoolElementIterator(*this, 0), 734 BoolElementIterator(*this, getNumElements())}; 735 } 736 737 /// Return the held element values as a range of APInts. The element type of 738 /// this attribute must be of integer type. 739 auto DenseElementsAttr::getIntValues() const 740 -> llvm::iterator_range<IntElementIterator> { 741 assert(getType().getElementType().isa<IntegerType>() && 742 "expected integer type"); 743 return {raw_int_begin(), raw_int_end()}; 744 } 745 auto DenseElementsAttr::int_value_begin() const -> IntElementIterator { 746 assert(getType().getElementType().isa<IntegerType>() && 747 "expected integer type"); 748 return raw_int_begin(); 749 } 750 auto DenseElementsAttr::int_value_end() const -> IntElementIterator { 751 assert(getType().getElementType().isa<IntegerType>() && 752 "expected integer type"); 753 return raw_int_end(); 754 } 755 756 /// Return the held element values as a range of APFloat. The element type of 757 /// this attribute must be of float type. 758 auto DenseElementsAttr::getFloatValues() const 759 -> llvm::iterator_range<FloatElementIterator> { 760 auto elementType = getType().getElementType().cast<FloatType>(); 761 assert(elementType.isa<FloatType>() && "expected float type"); 762 const auto &elementSemantics = elementType.getFloatSemantics(); 763 return {FloatElementIterator(elementSemantics, raw_int_begin()), 764 FloatElementIterator(elementSemantics, raw_int_end())}; 765 } 766 auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator { 767 return getFloatValues().begin(); 768 } 769 auto DenseElementsAttr::float_value_end() const -> FloatElementIterator { 770 return getFloatValues().end(); 771 } 772 773 /// Return a new DenseElementsAttr that has the same data as the current 774 /// attribute, but has been reshaped to 'newType'. The new type must have the 775 /// same total number of elements as well as element type. 776 DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) { 777 ShapedType curType = getType(); 778 if (curType == newType) 779 return *this; 780 781 (void)curType; 782 assert(newType.getElementType() == curType.getElementType() && 783 "expected the same element type"); 784 assert(newType.getNumElements() == curType.getNumElements() && 785 "expected the same number of elements"); 786 return getRaw(newType, getRawData(), isSplat()); 787 } 788 789 DenseElementsAttr DenseElementsAttr::mapValues( 790 Type newElementType, 791 llvm::function_ref<APInt(const APInt &)> mapping) const { 792 return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping); 793 } 794 795 DenseElementsAttr DenseElementsAttr::mapValues( 796 Type newElementType, 797 llvm::function_ref<APInt(const APFloat &)> mapping) const { 798 return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping); 799 } 800 801 //===----------------------------------------------------------------------===// 802 // DenseFPElementsAttr 803 //===----------------------------------------------------------------------===// 804 805 template <typename Fn, typename Attr> 806 static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType, 807 Type newElementType, 808 llvm::SmallVectorImpl<char> &data) { 809 size_t bitWidth = getDenseElementBitwidth(newElementType); 810 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); 811 812 ShapedType newArrayType; 813 if (inType.isa<RankedTensorType>()) 814 newArrayType = RankedTensorType::get(inType.getShape(), newElementType); 815 else if (inType.isa<UnrankedTensorType>()) 816 newArrayType = RankedTensorType::get(inType.getShape(), newElementType); 817 else if (inType.isa<VectorType>()) 818 newArrayType = VectorType::get(inType.getShape(), newElementType); 819 else 820 assert(newArrayType && "Unhandled tensor type"); 821 822 size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements(); 823 data.resize(llvm::divideCeil(storageBitWidth, CHAR_BIT) * numRawElements); 824 825 // Functor used to process a single element value of the attribute. 826 auto processElt = [&](decltype(*attr.begin()) value, size_t index) { 827 auto newInt = mapping(value); 828 assert(newInt.getBitWidth() == bitWidth); 829 writeBits(data.data(), index * storageBitWidth, newInt); 830 }; 831 832 // Check for the splat case. 833 if (attr.isSplat()) { 834 processElt(*attr.begin(), /*index=*/0); 835 return newArrayType; 836 } 837 838 // Otherwise, process all of the element values. 839 uint64_t elementIdx = 0; 840 for (auto value : attr) 841 processElt(value, elementIdx++); 842 return newArrayType; 843 } 844 845 DenseElementsAttr DenseFPElementsAttr::mapValues( 846 Type newElementType, 847 llvm::function_ref<APInt(const APFloat &)> mapping) const { 848 llvm::SmallVector<char, 8> elementData; 849 auto newArrayType = 850 mappingHelper(mapping, *this, getType(), newElementType, elementData); 851 852 return getRaw(newArrayType, elementData, isSplat()); 853 } 854 855 /// Method for supporting type inquiry through isa, cast and dyn_cast. 856 bool DenseFPElementsAttr::classof(Attribute attr) { 857 return attr.isa<DenseElementsAttr>() && 858 attr.getType().cast<ShapedType>().getElementType().isa<FloatType>(); 859 } 860 861 //===----------------------------------------------------------------------===// 862 // DenseIntElementsAttr 863 //===----------------------------------------------------------------------===// 864 865 DenseElementsAttr DenseIntElementsAttr::mapValues( 866 Type newElementType, 867 llvm::function_ref<APInt(const APInt &)> mapping) const { 868 llvm::SmallVector<char, 8> elementData; 869 auto newArrayType = 870 mappingHelper(mapping, *this, getType(), newElementType, elementData); 871 872 return getRaw(newArrayType, elementData, isSplat()); 873 } 874 875 /// Method for supporting type inquiry through isa, cast and dyn_cast. 876 bool DenseIntElementsAttr::classof(Attribute attr) { 877 return attr.isa<DenseElementsAttr>() && 878 attr.getType().cast<ShapedType>().getElementType().isa<IntegerType>(); 879 } 880 881 //===----------------------------------------------------------------------===// 882 // OpaqueElementsAttr 883 //===----------------------------------------------------------------------===// 884 885 OpaqueElementsAttr OpaqueElementsAttr::get(Dialect *dialect, ShapedType type, 886 StringRef bytes) { 887 assert(TensorType::isValidElementType(type.getElementType()) && 888 "Input element type should be a valid tensor element type"); 889 return Base::get(type.getContext(), StandardAttributes::OpaqueElements, type, 890 dialect, bytes); 891 } 892 893 StringRef OpaqueElementsAttr::getValue() const { return getImpl()->bytes; } 894 895 /// Return the value at the given index. If index does not refer to a valid 896 /// element, then a null attribute is returned. 897 Attribute OpaqueElementsAttr::getValue(ArrayRef<uint64_t> index) const { 898 assert(isValidIndex(index) && "expected valid multi-dimensional index"); 899 if (Dialect *dialect = getDialect()) 900 return dialect->extractElementHook(*this, index); 901 return Attribute(); 902 } 903 904 Dialect *OpaqueElementsAttr::getDialect() const { return getImpl()->dialect; } 905 906 bool OpaqueElementsAttr::decode(ElementsAttr &result) { 907 if (auto *d = getDialect()) 908 return d->decodeHook(*this, result); 909 return true; 910 } 911 912 //===----------------------------------------------------------------------===// 913 // SparseElementsAttr 914 //===----------------------------------------------------------------------===// 915 916 SparseElementsAttr SparseElementsAttr::get(ShapedType type, 917 DenseElementsAttr indices, 918 DenseElementsAttr values) { 919 assert(indices.getType().getElementType().isInteger(64) && 920 "expected sparse indices to be 64-bit integer values"); 921 assert((type.isa<RankedTensorType>() || type.isa<VectorType>()) && 922 "type must be ranked tensor or vector"); 923 assert(type.hasStaticShape() && "type must have static shape"); 924 return Base::get(type.getContext(), StandardAttributes::SparseElements, type, 925 indices.cast<DenseIntElementsAttr>(), values); 926 } 927 928 DenseIntElementsAttr SparseElementsAttr::getIndices() const { 929 return getImpl()->indices; 930 } 931 932 DenseElementsAttr SparseElementsAttr::getValues() const { 933 return getImpl()->values; 934 } 935 936 /// Return the value of the element at the given index. 937 Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const { 938 assert(isValidIndex(index) && "expected valid multi-dimensional index"); 939 auto type = getType(); 940 941 // The sparse indices are 64-bit integers, so we can reinterpret the raw data 942 // as a 1-D index array. 943 auto sparseIndices = getIndices(); 944 auto sparseIndexValues = sparseIndices.getValues<uint64_t>(); 945 946 // Check to see if the indices are a splat. 947 if (sparseIndices.isSplat()) { 948 // If the index is also not a splat of the index value, we know that the 949 // value is zero. 950 auto splatIndex = *sparseIndexValues.begin(); 951 if (llvm::any_of(index, [=](uint64_t i) { return i != splatIndex; })) 952 return getZeroAttr(); 953 954 // If the indices are a splat, we also expect the values to be a splat. 955 assert(getValues().isSplat() && "expected splat values"); 956 return getValues().getSplatValue(); 957 } 958 959 // Build a mapping between known indices and the offset of the stored element. 960 llvm::SmallDenseMap<llvm::ArrayRef<uint64_t>, size_t> mappedIndices; 961 auto numSparseIndices = sparseIndices.getType().getDimSize(0); 962 size_t rank = type.getRank(); 963 for (size_t i = 0, e = numSparseIndices; i != e; ++i) 964 mappedIndices.try_emplace( 965 {&*std::next(sparseIndexValues.begin(), i * rank), rank}, i); 966 967 // Look for the provided index key within the mapped indices. If the provided 968 // index is not found, then return a zero attribute. 969 auto it = mappedIndices.find(index); 970 if (it == mappedIndices.end()) 971 return getZeroAttr(); 972 973 // Otherwise, return the held sparse value element. 974 return getValues().getValue(it->second); 975 } 976 977 /// Get a zero APFloat for the given sparse attribute. 978 APFloat SparseElementsAttr::getZeroAPFloat() const { 979 auto eltType = getType().getElementType().cast<FloatType>(); 980 return APFloat(eltType.getFloatSemantics()); 981 } 982 983 /// Get a zero APInt for the given sparse attribute. 984 APInt SparseElementsAttr::getZeroAPInt() const { 985 auto eltType = getType().getElementType().cast<IntegerType>(); 986 return APInt::getNullValue(eltType.getWidth()); 987 } 988 989 /// Get a zero attribute for the given attribute type. 990 Attribute SparseElementsAttr::getZeroAttr() const { 991 auto eltType = getType().getElementType(); 992 993 // Handle floating point elements. 994 if (eltType.isa<FloatType>()) 995 return FloatAttr::get(eltType, 0); 996 997 // Otherwise, this is an integer. 998 auto intEltTy = eltType.cast<IntegerType>(); 999 if (intEltTy.getWidth() == 1) 1000 return BoolAttr::get(false, eltType.getContext()); 1001 return IntegerAttr::get(eltType, 0); 1002 } 1003 1004 /// Flatten, and return, all of the sparse indices in this attribute in 1005 /// row-major order. 1006 std::vector<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices() const { 1007 std::vector<ptrdiff_t> flatSparseIndices; 1008 1009 // The sparse indices are 64-bit integers, so we can reinterpret the raw data 1010 // as a 1-D index array. 1011 auto sparseIndices = getIndices(); 1012 auto sparseIndexValues = sparseIndices.getValues<uint64_t>(); 1013 if (sparseIndices.isSplat()) { 1014 SmallVector<uint64_t, 8> indices(getType().getRank(), 1015 *sparseIndexValues.begin()); 1016 flatSparseIndices.push_back(getFlattenedIndex(indices)); 1017 return flatSparseIndices; 1018 } 1019 1020 // Otherwise, reinterpret each index as an ArrayRef when flattening. 1021 auto numSparseIndices = sparseIndices.getType().getDimSize(0); 1022 size_t rank = getType().getRank(); 1023 for (size_t i = 0, e = numSparseIndices; i != e; ++i) 1024 flatSparseIndices.push_back(getFlattenedIndex( 1025 {&*std::next(sparseIndexValues.begin(), i * rank), rank})); 1026 return flatSparseIndices; 1027 } 1028 1029 //===----------------------------------------------------------------------===// 1030 // NamedAttributeList 1031 //===----------------------------------------------------------------------===// 1032 1033 NamedAttributeList::NamedAttributeList(ArrayRef<NamedAttribute> attributes) { 1034 setAttrs(attributes); 1035 } 1036 1037 ArrayRef<NamedAttribute> NamedAttributeList::getAttrs() const { 1038 return attrs ? attrs.getValue() : llvm::None; 1039 } 1040 1041 /// Replace the held attributes with ones provided in 'newAttrs'. 1042 void NamedAttributeList::setAttrs(ArrayRef<NamedAttribute> attributes) { 1043 // Don't create an attribute list if there are no attributes. 1044 if (attributes.empty()) 1045 attrs = nullptr; 1046 else 1047 attrs = DictionaryAttr::get(attributes, attributes[0].second.getContext()); 1048 } 1049 1050 /// Return the specified attribute if present, null otherwise. 1051 Attribute NamedAttributeList::get(StringRef name) const { 1052 return attrs ? attrs.get(name) : nullptr; 1053 } 1054 1055 /// Return the specified attribute if present, null otherwise. 1056 Attribute NamedAttributeList::get(Identifier name) const { 1057 return attrs ? attrs.get(name) : nullptr; 1058 } 1059 1060 /// If the an attribute exists with the specified name, change it to the new 1061 /// value. Otherwise, add a new attribute with the specified name/value. 1062 void NamedAttributeList::set(Identifier name, Attribute value) { 1063 assert(value && "attributes may never be null"); 1064 1065 // If we already have this attribute, replace it. 1066 auto origAttrs = getAttrs(); 1067 SmallVector<NamedAttribute, 8> newAttrs(origAttrs.begin(), origAttrs.end()); 1068 for (auto &elt : newAttrs) 1069 if (elt.first == name) { 1070 elt.second = value; 1071 attrs = DictionaryAttr::get(newAttrs, value.getContext()); 1072 return; 1073 } 1074 1075 // Otherwise, add it. 1076 newAttrs.push_back({name, value}); 1077 attrs = DictionaryAttr::get(newAttrs, value.getContext()); 1078 } 1079 1080 /// Remove the attribute with the specified name if it exists. The return 1081 /// value indicates whether the attribute was present or not. 1082 auto NamedAttributeList::remove(Identifier name) -> RemoveResult { 1083 auto origAttrs = getAttrs(); 1084 for (unsigned i = 0, e = origAttrs.size(); i != e; ++i) { 1085 if (origAttrs[i].first == name) { 1086 // Handle the simple case of removing the only attribute in the list. 1087 if (e == 1) { 1088 attrs = nullptr; 1089 return RemoveResult::Removed; 1090 } 1091 1092 SmallVector<NamedAttribute, 8> newAttrs; 1093 newAttrs.reserve(origAttrs.size() - 1); 1094 newAttrs.append(origAttrs.begin(), origAttrs.begin() + i); 1095 newAttrs.append(origAttrs.begin() + i + 1, origAttrs.end()); 1096 attrs = DictionaryAttr::get(newAttrs, newAttrs[0].second.getContext()); 1097 return RemoveResult::Removed; 1098 } 1099 } 1100 return RemoveResult::NotFound; 1101 }