github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/IR/Attributes.cpp (about)

     1  //===- Attributes.cpp - MLIR Attribute Classes ----------------------------===//
     2  //
     3  // Copyright 2019 The MLIR Authors.
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //   http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  // =============================================================================
    17  
    18  #include "mlir/IR/Attributes.h"
    19  #include "AttributeDetail.h"
    20  #include "mlir/IR/AffineMap.h"
    21  #include "mlir/IR/Diagnostics.h"
    22  #include "mlir/IR/Dialect.h"
    23  #include "mlir/IR/Function.h"
    24  #include "mlir/IR/IntegerSet.h"
    25  #include "mlir/IR/Types.h"
    26  #include "llvm/ADT/Sequence.h"
    27  #include "llvm/ADT/Twine.h"
    28  
    29  using namespace mlir;
    30  using namespace mlir::detail;
    31  
    32  //===----------------------------------------------------------------------===//
    33  // AttributeStorage
    34  //===----------------------------------------------------------------------===//
    35  
    36  AttributeStorage::AttributeStorage(Type type)
    37      : type(type.getAsOpaquePointer()) {}
    38  AttributeStorage::AttributeStorage() : type(nullptr) {}
    39  
    40  Type AttributeStorage::getType() const {
    41    return Type::getFromOpaquePointer(type);
    42  }
    43  void AttributeStorage::setType(Type newType) {
    44    type = newType.getAsOpaquePointer();
    45  }
    46  
    47  //===----------------------------------------------------------------------===//
    48  // Attribute
    49  //===----------------------------------------------------------------------===//
    50  
    51  /// Return the type of this attribute.
    52  Type Attribute::getType() const { return impl->getType(); }
    53  
    54  /// Return the context this attribute belongs to.
    55  MLIRContext *Attribute::getContext() const { return getType().getContext(); }
    56  
    57  /// Get the dialect this attribute is registered to.
    58  Dialect &Attribute::getDialect() const { return impl->getDialect(); }
    59  
    60  //===----------------------------------------------------------------------===//
    61  // AffineMapAttr
    62  //===----------------------------------------------------------------------===//
    63  
    64  AffineMapAttr AffineMapAttr::get(AffineMap value) {
    65    return Base::get(value.getContext(), StandardAttributes::AffineMap, value);
    66  }
    67  
    68  AffineMap AffineMapAttr::getValue() const { return getImpl()->value; }
    69  
    70  //===----------------------------------------------------------------------===//
    71  // ArrayAttr
    72  //===----------------------------------------------------------------------===//
    73  
    74  ArrayAttr ArrayAttr::get(ArrayRef<Attribute> value, MLIRContext *context) {
    75    return Base::get(context, StandardAttributes::Array, value);
    76  }
    77  
    78  ArrayRef<Attribute> ArrayAttr::getValue() const { return getImpl()->value; }
    79  
    80  //===----------------------------------------------------------------------===//
    81  // BoolAttr
    82  //===----------------------------------------------------------------------===//
    83  
    84  bool BoolAttr::getValue() const { return getImpl()->value; }
    85  
    86  //===----------------------------------------------------------------------===//
    87  // DictionaryAttr
    88  //===----------------------------------------------------------------------===//
    89  
    90  /// Perform a three-way comparison between the names of the specified
    91  /// NamedAttributes.
    92  static int compareNamedAttributes(const NamedAttribute *lhs,
    93                                    const NamedAttribute *rhs) {
    94    return lhs->first.str().compare(rhs->first.str());
    95  }
    96  
    97  DictionaryAttr DictionaryAttr::get(ArrayRef<NamedAttribute> value,
    98                                     MLIRContext *context) {
    99    assert(llvm::all_of(value,
   100                        [](const NamedAttribute &attr) { return attr.second; }) &&
   101           "value cannot have null entries");
   102  
   103    // We need to sort the element list to canonicalize it, but we also don't want
   104    // to do a ton of work in the super common case where the element list is
   105    // already sorted.
   106    SmallVector<NamedAttribute, 8> storage;
   107    switch (value.size()) {
   108    case 0:
   109      break;
   110    case 1:
   111      // A single element is already sorted.
   112      break;
   113    case 2:
   114      assert(value[0].first != value[1].first &&
   115             "DictionaryAttr element names must be unique");
   116  
   117      // Don't invoke a general sort for two element case.
   118      if (value[0].first.strref() > value[1].first.strref()) {
   119        storage.push_back(value[1]);
   120        storage.push_back(value[0]);
   121        value = storage;
   122      }
   123      break;
   124    default:
   125      // Check to see they are sorted already.
   126      bool isSorted = true;
   127      for (unsigned i = 0, e = value.size() - 1; i != e; ++i) {
   128        if (value[i].first.strref() > value[i + 1].first.strref()) {
   129          isSorted = false;
   130          break;
   131        }
   132      }
   133      // If not, do a general sort.
   134      if (!isSorted) {
   135        storage.append(value.begin(), value.end());
   136        llvm::array_pod_sort(storage.begin(), storage.end(),
   137                             compareNamedAttributes);
   138        value = storage;
   139      }
   140  
   141      // Ensure that the attribute elements are unique.
   142      assert(std::adjacent_find(value.begin(), value.end(),
   143                                [](NamedAttribute l, NamedAttribute r) {
   144                                  return l.first == r.first;
   145                                }) == value.end() &&
   146             "DictionaryAttr element names must be unique");
   147    }
   148  
   149    return Base::get(context, StandardAttributes::Dictionary, value);
   150  }
   151  
   152  ArrayRef<NamedAttribute> DictionaryAttr::getValue() const {
   153    return getImpl()->getElements();
   154  }
   155  
   156  /// Return the specified attribute if present, null otherwise.
   157  Attribute DictionaryAttr::get(StringRef name) const {
   158    for (auto elt : getValue())
   159      if (elt.first.is(name))
   160        return elt.second;
   161    return nullptr;
   162  }
   163  Attribute DictionaryAttr::get(Identifier name) const {
   164    for (auto elt : getValue())
   165      if (elt.first == name)
   166        return elt.second;
   167    return nullptr;
   168  }
   169  
   170  DictionaryAttr::iterator DictionaryAttr::begin() const {
   171    return getValue().begin();
   172  }
   173  DictionaryAttr::iterator DictionaryAttr::end() const {
   174    return getValue().end();
   175  }
   176  size_t DictionaryAttr::size() const { return getValue().size(); }
   177  
   178  //===----------------------------------------------------------------------===//
   179  // FloatAttr
   180  //===----------------------------------------------------------------------===//
   181  
   182  FloatAttr FloatAttr::get(Type type, double value) {
   183    return Base::get(type.getContext(), StandardAttributes::Float, type, value);
   184  }
   185  
   186  FloatAttr FloatAttr::getChecked(Type type, double value, Location loc) {
   187    return Base::getChecked(loc, type.getContext(), StandardAttributes::Float,
   188                            type, value);
   189  }
   190  
   191  FloatAttr FloatAttr::get(Type type, const APFloat &value) {
   192    return Base::get(type.getContext(), StandardAttributes::Float, type, value);
   193  }
   194  
   195  FloatAttr FloatAttr::getChecked(Type type, const APFloat &value, Location loc) {
   196    return Base::getChecked(loc, type.getContext(), StandardAttributes::Float,
   197                            type, value);
   198  }
   199  
   200  APFloat FloatAttr::getValue() const { return getImpl()->getValue(); }
   201  
   202  double FloatAttr::getValueAsDouble() const {
   203    return getValueAsDouble(getValue());
   204  }
   205  double FloatAttr::getValueAsDouble(APFloat value) {
   206    if (&value.getSemantics() != &APFloat::IEEEdouble()) {
   207      bool losesInfo = false;
   208      value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven,
   209                    &losesInfo);
   210    }
   211    return value.convertToDouble();
   212  }
   213  
   214  /// Verify construction invariants.
   215  static LogicalResult verifyFloatTypeInvariants(llvm::Optional<Location> loc,
   216                                                 Type type) {
   217    if (!type.isa<FloatType>()) {
   218      if (loc)
   219        emitError(*loc, "expected floating point type");
   220      return failure();
   221    }
   222    return success();
   223  }
   224  
   225  LogicalResult FloatAttr::verifyConstructionInvariants(
   226      llvm::Optional<Location> loc, MLIRContext *ctx, Type type, double value) {
   227    return verifyFloatTypeInvariants(loc, type);
   228  }
   229  
   230  LogicalResult
   231  FloatAttr::verifyConstructionInvariants(llvm::Optional<Location> loc,
   232                                          MLIRContext *ctx, Type type,
   233                                          const APFloat &value) {
   234    // Verify that the type is correct.
   235    if (failed(verifyFloatTypeInvariants(loc, type)))
   236      return failure();
   237  
   238    // Verify that the type semantics match that of the value.
   239    if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) {
   240      if (loc)
   241        emitError(*loc,
   242                  "FloatAttr type doesn't match the type implied by its value");
   243      return failure();
   244    }
   245    return success();
   246  }
   247  
   248  //===----------------------------------------------------------------------===//
   249  // SymbolRefAttr
   250  //===----------------------------------------------------------------------===//
   251  
   252  SymbolRefAttr SymbolRefAttr::get(StringRef value, MLIRContext *ctx) {
   253    return Base::get(ctx, StandardAttributes::SymbolRef, value,
   254                     NoneType::get(ctx));
   255  }
   256  
   257  StringRef SymbolRefAttr::getValue() const { return getImpl()->value; }
   258  
   259  //===----------------------------------------------------------------------===//
   260  // IntegerAttr
   261  //===----------------------------------------------------------------------===//
   262  
   263  IntegerAttr IntegerAttr::get(Type type, const APInt &value) {
   264    return Base::get(type.getContext(), StandardAttributes::Integer, type, value);
   265  }
   266  
   267  IntegerAttr IntegerAttr::get(Type type, int64_t value) {
   268    // This uses 64 bit APInts by default for index type.
   269    if (type.isIndex())
   270      return get(type, APInt(64, value));
   271  
   272    auto intType = type.cast<IntegerType>();
   273    return get(type, APInt(intType.getWidth(), value));
   274  }
   275  
   276  APInt IntegerAttr::getValue() const { return getImpl()->getValue(); }
   277  
   278  int64_t IntegerAttr::getInt() const { return getValue().getSExtValue(); }
   279  
   280  //===----------------------------------------------------------------------===//
   281  // IntegerSetAttr
   282  //===----------------------------------------------------------------------===//
   283  
   284  IntegerSetAttr IntegerSetAttr::get(IntegerSet value) {
   285    return Base::get(value.getConstraint(0).getContext(),
   286                     StandardAttributes::IntegerSet, value);
   287  }
   288  
   289  IntegerSet IntegerSetAttr::getValue() const { return getImpl()->value; }
   290  
   291  //===----------------------------------------------------------------------===//
   292  // OpaqueAttr
   293  //===----------------------------------------------------------------------===//
   294  
   295  OpaqueAttr OpaqueAttr::get(Identifier dialect, StringRef attrData, Type type,
   296                             MLIRContext *context) {
   297    return Base::get(context, StandardAttributes::Opaque, dialect, attrData,
   298                     type);
   299  }
   300  
   301  OpaqueAttr OpaqueAttr::getChecked(Identifier dialect, StringRef attrData,
   302                                    Type type, Location location) {
   303    return Base::getChecked(location, type.getContext(),
   304                            StandardAttributes::Opaque, dialect, attrData, type);
   305  }
   306  
   307  /// Returns the dialect namespace of the opaque attribute.
   308  Identifier OpaqueAttr::getDialectNamespace() const {
   309    return getImpl()->dialectNamespace;
   310  }
   311  
   312  /// Returns the raw attribute data of the opaque attribute.
   313  StringRef OpaqueAttr::getAttrData() const { return getImpl()->attrData; }
   314  
   315  /// Verify the construction of an opaque attribute.
   316  LogicalResult OpaqueAttr::verifyConstructionInvariants(
   317      llvm::Optional<Location> loc, MLIRContext *context, Identifier dialect,
   318      StringRef attrData, Type type) {
   319    if (!Dialect::isValidNamespace(dialect.strref())) {
   320      if (loc)
   321        emitError(*loc) << "invalid dialect namespace '" << dialect << "'";
   322      return failure();
   323    }
   324    return success();
   325  }
   326  
   327  //===----------------------------------------------------------------------===//
   328  // StringAttr
   329  //===----------------------------------------------------------------------===//
   330  
   331  StringAttr StringAttr::get(StringRef bytes, MLIRContext *context) {
   332    return get(bytes, NoneType::get(context));
   333  }
   334  
   335  /// Get an instance of a StringAttr with the given string and Type.
   336  StringAttr StringAttr::get(StringRef bytes, Type type) {
   337    return Base::get(type.getContext(), StandardAttributes::String, bytes, type);
   338  }
   339  
   340  StringRef StringAttr::getValue() const { return getImpl()->value; }
   341  
   342  //===----------------------------------------------------------------------===//
   343  // TypeAttr
   344  //===----------------------------------------------------------------------===//
   345  
   346  TypeAttr TypeAttr::get(Type value) {
   347    return Base::get(value.getContext(), StandardAttributes::Type, value);
   348  }
   349  
   350  Type TypeAttr::getValue() const { return getImpl()->value; }
   351  
   352  //===----------------------------------------------------------------------===//
   353  // ElementsAttr
   354  //===----------------------------------------------------------------------===//
   355  
   356  ShapedType ElementsAttr::getType() const {
   357    return Attribute::getType().cast<ShapedType>();
   358  }
   359  
   360  /// Returns the number of elements held by this attribute.
   361  int64_t ElementsAttr::getNumElements() const {
   362    return getType().getNumElements();
   363  }
   364  
   365  /// Return the value at the given index. If index does not refer to a valid
   366  /// element, then a null attribute is returned.
   367  Attribute ElementsAttr::getValue(ArrayRef<uint64_t> index) const {
   368    switch (getKind()) {
   369    case StandardAttributes::DenseElements:
   370      return cast<DenseElementsAttr>().getValue(index);
   371    case StandardAttributes::OpaqueElements:
   372      return cast<OpaqueElementsAttr>().getValue(index);
   373    case StandardAttributes::SparseElements:
   374      return cast<SparseElementsAttr>().getValue(index);
   375    default:
   376      llvm_unreachable("unknown ElementsAttr kind");
   377    }
   378  }
   379  
   380  /// Return if the given 'index' refers to a valid element in this attribute.
   381  bool ElementsAttr::isValidIndex(ArrayRef<uint64_t> index) const {
   382    auto type = getType();
   383  
   384    // Verify that the rank of the indices matches the held type.
   385    auto rank = type.getRank();
   386    if (rank != static_cast<int64_t>(index.size()))
   387      return false;
   388  
   389    // Verify that all of the indices are within the shape dimensions.
   390    auto shape = type.getShape();
   391    return llvm::all_of(llvm::seq<int>(0, rank), [&](int i) {
   392      return static_cast<int64_t>(index[i]) < shape[i];
   393    });
   394  }
   395  
   396  ElementsAttr ElementsAttr::mapValues(
   397      Type newElementType,
   398      llvm::function_ref<APInt(const APInt &)> mapping) const {
   399    switch (getKind()) {
   400    case StandardAttributes::DenseElements:
   401      return cast<DenseElementsAttr>().mapValues(newElementType, mapping);
   402    default:
   403      llvm_unreachable("unsupported ElementsAttr subtype");
   404    }
   405  }
   406  
   407  ElementsAttr ElementsAttr::mapValues(
   408      Type newElementType,
   409      llvm::function_ref<APInt(const APFloat &)> mapping) const {
   410    switch (getKind()) {
   411    case StandardAttributes::DenseElements:
   412      return cast<DenseElementsAttr>().mapValues(newElementType, mapping);
   413    default:
   414      llvm_unreachable("unsupported ElementsAttr subtype");
   415    }
   416  }
   417  
   418  /// Returns the 1 dimenional flattened row-major index from the given
   419  /// multi-dimensional index.
   420  uint64_t ElementsAttr::getFlattenedIndex(ArrayRef<uint64_t> index) const {
   421    assert(isValidIndex(index) && "expected valid multi-dimensional index");
   422    auto type = getType();
   423  
   424    // Reduce the provided multidimensional index into a flattended 1D row-major
   425    // index.
   426    auto rank = type.getRank();
   427    auto shape = type.getShape();
   428    uint64_t valueIndex = 0;
   429    uint64_t dimMultiplier = 1;
   430    for (int i = rank - 1; i >= 0; --i) {
   431      valueIndex += index[i] * dimMultiplier;
   432      dimMultiplier *= shape[i];
   433    }
   434    return valueIndex;
   435  }
   436  
   437  //===----------------------------------------------------------------------===//
   438  // DenseElementAttr Utilities
   439  //===----------------------------------------------------------------------===//
   440  
   441  static size_t getDenseElementBitwidth(Type eltType) {
   442    // FIXME(b/121118307): using 64 bits for BF16 because it is currently stored
   443    // with double semantics.
   444    return eltType.isBF16() ? 64 : eltType.getIntOrFloatBitWidth();
   445  }
   446  
   447  /// Get the bitwidth of a dense element type within the buffer.
   448  /// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8.
   449  static size_t getDenseElementStorageWidth(size_t origWidth) {
   450    return origWidth == 1 ? origWidth : llvm::alignTo<8>(origWidth);
   451  }
   452  
   453  /// Set a bit to a specific value.
   454  static void setBit(char *rawData, size_t bitPos, bool value) {
   455    if (value)
   456      rawData[bitPos / CHAR_BIT] |= (1 << (bitPos % CHAR_BIT));
   457    else
   458      rawData[bitPos / CHAR_BIT] &= ~(1 << (bitPos % CHAR_BIT));
   459  }
   460  
   461  /// Return the value of the specified bit.
   462  static bool getBit(const char *rawData, size_t bitPos) {
   463    return (rawData[bitPos / CHAR_BIT] & (1 << (bitPos % CHAR_BIT))) != 0;
   464  }
   465  
   466  /// Writes value to the bit position `bitPos` in array `rawData`.
   467  static void writeBits(char *rawData, size_t bitPos, APInt value) {
   468    size_t bitWidth = value.getBitWidth();
   469  
   470    // If the bitwidth is 1 we just toggle the specific bit.
   471    if (bitWidth == 1)
   472      return setBit(rawData, bitPos, value.isOneValue());
   473  
   474    // Otherwise, the bit position is guaranteed to be byte aligned.
   475    assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
   476    std::copy_n(reinterpret_cast<const char *>(value.getRawData()),
   477                llvm::divideCeil(bitWidth, CHAR_BIT),
   478                rawData + (bitPos / CHAR_BIT));
   479  }
   480  
   481  /// Reads the next `bitWidth` bits from the bit position `bitPos` in array
   482  /// `rawData`.
   483  static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) {
   484    // Handle a boolean bit position.
   485    if (bitWidth == 1)
   486      return APInt(1, getBit(rawData, bitPos) ? 1 : 0);
   487  
   488    // Otherwise, the bit position must be 8-bit aligned.
   489    assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
   490    APInt result(bitWidth, 0);
   491    std::copy_n(
   492        rawData + (bitPos / CHAR_BIT), llvm::divideCeil(bitWidth, CHAR_BIT),
   493        const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())));
   494    return result;
   495  }
   496  
   497  /// Returns if 'values' corresponds to a splat, i.e. one element, or has the
   498  /// same element count as 'type'.
   499  template <typename Values>
   500  static bool hasSameElementsOrSplat(ShapedType type, const Values &values) {
   501    return (values.size() == 1) ||
   502           (type.getNumElements() == static_cast<int64_t>(values.size()));
   503  }
   504  
   505  //===----------------------------------------------------------------------===//
   506  // DenseElementAttr Iterators
   507  //===----------------------------------------------------------------------===//
   508  
   509  /// Constructs a new iterator.
   510  DenseElementsAttr::AttributeElementIterator::AttributeElementIterator(
   511      DenseElementsAttr attr, size_t index)
   512      : indexed_accessor_iterator<AttributeElementIterator, const void *,
   513                                  Attribute, Attribute, Attribute>(
   514            attr.getAsOpaquePointer(), index) {}
   515  
   516  /// Accesses the Attribute value at this iterator position.
   517  Attribute DenseElementsAttr::AttributeElementIterator::operator*() const {
   518    auto owner = getFromOpaquePointer(object).cast<DenseElementsAttr>();
   519    Type eltTy = owner.getType().getElementType();
   520    if (auto intEltTy = eltTy.dyn_cast<IntegerType>()) {
   521      if (intEltTy.getWidth() == 1)
   522        return BoolAttr::get((*IntElementIterator(owner, index)).isOneValue(),
   523                             owner.getContext());
   524      return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
   525    }
   526    if (auto floatEltTy = eltTy.dyn_cast<FloatType>()) {
   527      IntElementIterator intIt(owner, index);
   528      FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt);
   529      return FloatAttr::get(eltTy, *floatIt);
   530    }
   531    llvm_unreachable("unexpected element type");
   532  }
   533  
   534  /// Constructs a new iterator.
   535  DenseElementsAttr::BoolElementIterator::BoolElementIterator(
   536      DenseElementsAttr attr, size_t dataIndex)
   537      : DenseElementIndexedIteratorImpl<BoolElementIterator, bool, bool, bool>(
   538            attr.getRawData().data(), attr.isSplat(), dataIndex) {}
   539  
   540  /// Accesses the bool value at this iterator position.
   541  bool DenseElementsAttr::BoolElementIterator::operator*() const {
   542    return getBit(getData(), getDataIndex());
   543  }
   544  
   545  /// Constructs a new iterator.
   546  DenseElementsAttr::IntElementIterator::IntElementIterator(
   547      DenseElementsAttr attr, size_t dataIndex)
   548      : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>(
   549            attr.getRawData().data(), attr.isSplat(), dataIndex),
   550        bitWidth(getDenseElementBitwidth(attr.getType().getElementType())) {}
   551  
   552  /// Accesses the raw APInt value at this iterator position.
   553  APInt DenseElementsAttr::IntElementIterator::operator*() const {
   554    return readBits(getData(),
   555                    getDataIndex() * getDenseElementStorageWidth(bitWidth),
   556                    bitWidth);
   557  }
   558  
   559  DenseElementsAttr::FloatElementIterator::FloatElementIterator(
   560      const llvm::fltSemantics &smt, IntElementIterator it)
   561      : llvm::mapped_iterator<IntElementIterator,
   562                              std::function<APFloat(const APInt &)>>(
   563            it, [&](const APInt &val) { return APFloat(smt, val); }) {}
   564  
   565  //===----------------------------------------------------------------------===//
   566  // DenseElementsAttr
   567  //===----------------------------------------------------------------------===//
   568  
   569  DenseElementsAttr DenseElementsAttr::get(ShapedType type,
   570                                           ArrayRef<Attribute> values) {
   571    assert(type.getElementType().isIntOrFloat() &&
   572           "expected int or float element type");
   573    assert(hasSameElementsOrSplat(type, values));
   574  
   575    auto eltType = type.getElementType();
   576    size_t bitWidth = getDenseElementBitwidth(eltType);
   577    size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
   578  
   579    // Compress the attribute values into a character buffer.
   580    SmallVector<char, 8> data(llvm::divideCeil(storageBitWidth, CHAR_BIT) *
   581                              values.size());
   582    APInt intVal;
   583    for (unsigned i = 0, e = values.size(); i < e; ++i) {
   584      assert(eltType == values[i].getType() &&
   585             "expected attribute value to have element type");
   586  
   587      switch (eltType.getKind()) {
   588      case StandardTypes::BF16:
   589      case StandardTypes::F16:
   590      case StandardTypes::F32:
   591      case StandardTypes::F64:
   592        intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt();
   593        break;
   594      case StandardTypes::Integer:
   595        intVal = values[i].isa<BoolAttr>()
   596                     ? APInt(1, values[i].cast<BoolAttr>().getValue() ? 1 : 0)
   597                     : values[i].cast<IntegerAttr>().getValue();
   598        break;
   599      default:
   600        llvm_unreachable("unexpected element type");
   601      }
   602      assert(intVal.getBitWidth() == bitWidth &&
   603             "expected value to have same bitwidth as element type");
   604      writeBits(data.data(), i * storageBitWidth, intVal);
   605    }
   606    return getRaw(type, data, /*isSplat=*/(values.size() == 1));
   607  }
   608  
   609  DenseElementsAttr DenseElementsAttr::get(ShapedType type,
   610                                           ArrayRef<bool> values) {
   611    assert(hasSameElementsOrSplat(type, values));
   612    assert(type.getElementType().isInteger(1));
   613  
   614    std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT));
   615    for (int i = 0, e = values.size(); i != e; ++i)
   616      setBit(buff.data(), i, values[i]);
   617    return getRaw(type, buff, /*isSplat=*/(values.size() == 1));
   618  }
   619  
   620  /// Constructs a dense integer elements attribute from an array of APInt
   621  /// values. Each APInt value is expected to have the same bitwidth as the
   622  /// element type of 'type'.
   623  DenseElementsAttr DenseElementsAttr::get(ShapedType type,
   624                                           ArrayRef<APInt> values) {
   625    assert(type.getElementType().isa<IntegerType>());
   626    return getRaw(type, values);
   627  }
   628  
   629  // Constructs a dense float elements attribute from an array of APFloat
   630  // values. Each APFloat value is expected to have the same bitwidth as the
   631  // element type of 'type'.
   632  DenseElementsAttr DenseElementsAttr::get(ShapedType type,
   633                                           ArrayRef<APFloat> values) {
   634    assert(type.getElementType().isa<FloatType>());
   635  
   636    // Convert the APFloat values to APInt and create a dense elements attribute.
   637    std::vector<APInt> intValues(values.size());
   638    for (unsigned i = 0, e = values.size(); i != e; ++i)
   639      intValues[i] = values[i].bitcastToAPInt();
   640    return getRaw(type, intValues);
   641  }
   642  
   643  // Constructs a dense elements attribute from an array of raw APInt values.
   644  // Each APInt value is expected to have the same bitwidth as the element type
   645  // of 'type'.
   646  DenseElementsAttr DenseElementsAttr::getRaw(ShapedType type,
   647                                              ArrayRef<APInt> values) {
   648    assert(hasSameElementsOrSplat(type, values));
   649  
   650    size_t bitWidth = getDenseElementBitwidth(type.getElementType());
   651    size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
   652    std::vector<char> elementData(llvm::divideCeil(storageBitWidth, CHAR_BIT) *
   653                                  values.size());
   654    for (unsigned i = 0, e = values.size(); i != e; ++i) {
   655      assert(values[i].getBitWidth() == bitWidth);
   656      writeBits(elementData.data(), i * storageBitWidth, values[i]);
   657    }
   658    return getRaw(type, elementData, /*isSplat=*/(values.size() == 1));
   659  }
   660  
   661  DenseElementsAttr DenseElementsAttr::getRaw(ShapedType type,
   662                                              ArrayRef<char> data, bool isSplat) {
   663    assert((type.isa<RankedTensorType>() || type.isa<VectorType>()) &&
   664           "type must be ranked tensor or vector");
   665    assert(type.hasStaticShape() && "type must have static shape");
   666    return Base::get(type.getContext(), StandardAttributes::DenseElements, type,
   667                     data, isSplat);
   668  }
   669  
   670  /// Check the information for a c++ data type, check if this type is valid for
   671  /// the current attribute. This method is used to verify specific type
   672  /// invariants that the templatized 'getValues' method cannot.
   673  static bool isValidIntOrFloat(ShapedType type, int64_t dataEltSize,
   674                                bool isInt) {
   675    // Make sure that the data element size is the same as the type element width.
   676    if ((dataEltSize * CHAR_BIT) != type.getElementTypeBitWidth())
   677      return false;
   678  
   679    // Check that the element type is valid.
   680    return isInt ? type.getElementType().isa<IntegerType>()
   681                 : type.getElementType().isa<FloatType>();
   682  }
   683  
   684  /// Overload of the 'getRaw' method that asserts that the given type is of
   685  /// integer type. This method is used to verify type invariants that the
   686  /// templatized 'get' method cannot.
   687  DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type,
   688                                                        ArrayRef<char> data,
   689                                                        int64_t dataEltSize,
   690                                                        bool isInt) {
   691    assert(::isValidIntOrFloat(type, dataEltSize, isInt));
   692  
   693    int64_t numElements = data.size() / dataEltSize;
   694    assert(numElements == 1 || numElements == type.getNumElements());
   695    return getRaw(type, data, /*isSplat=*/numElements == 1);
   696  }
   697  
   698  /// A method used to verify specific type invariants that the templatized 'get'
   699  /// method cannot.
   700  bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize,
   701                                            bool isInt) const {
   702    return ::isValidIntOrFloat(getType(), dataEltSize, isInt);
   703  }
   704  
   705  /// Return the raw storage data held by this attribute.
   706  ArrayRef<char> DenseElementsAttr::getRawData() const {
   707    return static_cast<ImplType *>(impl)->data;
   708  }
   709  
   710  /// Returns if this attribute corresponds to a splat, i.e. if all element
   711  /// values are the same.
   712  bool DenseElementsAttr::isSplat() const { return getImpl()->isSplat; }
   713  
   714  /// Return the held element values as a range of Attributes.
   715  auto DenseElementsAttr::getAttributeValues() const
   716      -> llvm::iterator_range<AttributeElementIterator> {
   717    return {attr_value_begin(), attr_value_end()};
   718  }
   719  auto DenseElementsAttr::attr_value_begin() const -> AttributeElementIterator {
   720    return AttributeElementIterator(*this, 0);
   721  }
   722  auto DenseElementsAttr::attr_value_end() const -> AttributeElementIterator {
   723    return AttributeElementIterator(*this, getNumElements());
   724  }
   725  
   726  /// Return the held element values as a range of bool. The element type of
   727  /// this attribute must be of integer type of bitwidth 1.
   728  auto DenseElementsAttr::getBoolValues() const
   729      -> llvm::iterator_range<BoolElementIterator> {
   730    auto eltType = getType().getElementType().dyn_cast<IntegerType>();
   731    assert(eltType && eltType.getWidth() == 1 && "expected i1 integer type");
   732    (void)eltType;
   733    return {BoolElementIterator(*this, 0),
   734            BoolElementIterator(*this, getNumElements())};
   735  }
   736  
   737  /// Return the held element values as a range of APInts. The element type of
   738  /// this attribute must be of integer type.
   739  auto DenseElementsAttr::getIntValues() const
   740      -> llvm::iterator_range<IntElementIterator> {
   741    assert(getType().getElementType().isa<IntegerType>() &&
   742           "expected integer type");
   743    return {raw_int_begin(), raw_int_end()};
   744  }
   745  auto DenseElementsAttr::int_value_begin() const -> IntElementIterator {
   746    assert(getType().getElementType().isa<IntegerType>() &&
   747           "expected integer type");
   748    return raw_int_begin();
   749  }
   750  auto DenseElementsAttr::int_value_end() const -> IntElementIterator {
   751    assert(getType().getElementType().isa<IntegerType>() &&
   752           "expected integer type");
   753    return raw_int_end();
   754  }
   755  
   756  /// Return the held element values as a range of APFloat. The element type of
   757  /// this attribute must be of float type.
   758  auto DenseElementsAttr::getFloatValues() const
   759      -> llvm::iterator_range<FloatElementIterator> {
   760    auto elementType = getType().getElementType().cast<FloatType>();
   761    assert(elementType.isa<FloatType>() && "expected float type");
   762    const auto &elementSemantics = elementType.getFloatSemantics();
   763    return {FloatElementIterator(elementSemantics, raw_int_begin()),
   764            FloatElementIterator(elementSemantics, raw_int_end())};
   765  }
   766  auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator {
   767    return getFloatValues().begin();
   768  }
   769  auto DenseElementsAttr::float_value_end() const -> FloatElementIterator {
   770    return getFloatValues().end();
   771  }
   772  
   773  /// Return a new DenseElementsAttr that has the same data as the current
   774  /// attribute, but has been reshaped to 'newType'. The new type must have the
   775  /// same total number of elements as well as element type.
   776  DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) {
   777    ShapedType curType = getType();
   778    if (curType == newType)
   779      return *this;
   780  
   781    (void)curType;
   782    assert(newType.getElementType() == curType.getElementType() &&
   783           "expected the same element type");
   784    assert(newType.getNumElements() == curType.getNumElements() &&
   785           "expected the same number of elements");
   786    return getRaw(newType, getRawData(), isSplat());
   787  }
   788  
   789  DenseElementsAttr DenseElementsAttr::mapValues(
   790      Type newElementType,
   791      llvm::function_ref<APInt(const APInt &)> mapping) const {
   792    return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping);
   793  }
   794  
   795  DenseElementsAttr DenseElementsAttr::mapValues(
   796      Type newElementType,
   797      llvm::function_ref<APInt(const APFloat &)> mapping) const {
   798    return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping);
   799  }
   800  
   801  //===----------------------------------------------------------------------===//
   802  // DenseFPElementsAttr
   803  //===----------------------------------------------------------------------===//
   804  
   805  template <typename Fn, typename Attr>
   806  static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType,
   807                                  Type newElementType,
   808                                  llvm::SmallVectorImpl<char> &data) {
   809    size_t bitWidth = getDenseElementBitwidth(newElementType);
   810    size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
   811  
   812    ShapedType newArrayType;
   813    if (inType.isa<RankedTensorType>())
   814      newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
   815    else if (inType.isa<UnrankedTensorType>())
   816      newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
   817    else if (inType.isa<VectorType>())
   818      newArrayType = VectorType::get(inType.getShape(), newElementType);
   819    else
   820      assert(newArrayType && "Unhandled tensor type");
   821  
   822    size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements();
   823    data.resize(llvm::divideCeil(storageBitWidth, CHAR_BIT) * numRawElements);
   824  
   825    // Functor used to process a single element value of the attribute.
   826    auto processElt = [&](decltype(*attr.begin()) value, size_t index) {
   827      auto newInt = mapping(value);
   828      assert(newInt.getBitWidth() == bitWidth);
   829      writeBits(data.data(), index * storageBitWidth, newInt);
   830    };
   831  
   832    // Check for the splat case.
   833    if (attr.isSplat()) {
   834      processElt(*attr.begin(), /*index=*/0);
   835      return newArrayType;
   836    }
   837  
   838    // Otherwise, process all of the element values.
   839    uint64_t elementIdx = 0;
   840    for (auto value : attr)
   841      processElt(value, elementIdx++);
   842    return newArrayType;
   843  }
   844  
   845  DenseElementsAttr DenseFPElementsAttr::mapValues(
   846      Type newElementType,
   847      llvm::function_ref<APInt(const APFloat &)> mapping) const {
   848    llvm::SmallVector<char, 8> elementData;
   849    auto newArrayType =
   850        mappingHelper(mapping, *this, getType(), newElementType, elementData);
   851  
   852    return getRaw(newArrayType, elementData, isSplat());
   853  }
   854  
   855  /// Method for supporting type inquiry through isa, cast and dyn_cast.
   856  bool DenseFPElementsAttr::classof(Attribute attr) {
   857    return attr.isa<DenseElementsAttr>() &&
   858           attr.getType().cast<ShapedType>().getElementType().isa<FloatType>();
   859  }
   860  
   861  //===----------------------------------------------------------------------===//
   862  // DenseIntElementsAttr
   863  //===----------------------------------------------------------------------===//
   864  
   865  DenseElementsAttr DenseIntElementsAttr::mapValues(
   866      Type newElementType,
   867      llvm::function_ref<APInt(const APInt &)> mapping) const {
   868    llvm::SmallVector<char, 8> elementData;
   869    auto newArrayType =
   870        mappingHelper(mapping, *this, getType(), newElementType, elementData);
   871  
   872    return getRaw(newArrayType, elementData, isSplat());
   873  }
   874  
   875  /// Method for supporting type inquiry through isa, cast and dyn_cast.
   876  bool DenseIntElementsAttr::classof(Attribute attr) {
   877    return attr.isa<DenseElementsAttr>() &&
   878           attr.getType().cast<ShapedType>().getElementType().isa<IntegerType>();
   879  }
   880  
   881  //===----------------------------------------------------------------------===//
   882  // OpaqueElementsAttr
   883  //===----------------------------------------------------------------------===//
   884  
   885  OpaqueElementsAttr OpaqueElementsAttr::get(Dialect *dialect, ShapedType type,
   886                                             StringRef bytes) {
   887    assert(TensorType::isValidElementType(type.getElementType()) &&
   888           "Input element type should be a valid tensor element type");
   889    return Base::get(type.getContext(), StandardAttributes::OpaqueElements, type,
   890                     dialect, bytes);
   891  }
   892  
   893  StringRef OpaqueElementsAttr::getValue() const { return getImpl()->bytes; }
   894  
   895  /// Return the value at the given index. If index does not refer to a valid
   896  /// element, then a null attribute is returned.
   897  Attribute OpaqueElementsAttr::getValue(ArrayRef<uint64_t> index) const {
   898    assert(isValidIndex(index) && "expected valid multi-dimensional index");
   899    if (Dialect *dialect = getDialect())
   900      return dialect->extractElementHook(*this, index);
   901    return Attribute();
   902  }
   903  
   904  Dialect *OpaqueElementsAttr::getDialect() const { return getImpl()->dialect; }
   905  
   906  bool OpaqueElementsAttr::decode(ElementsAttr &result) {
   907    if (auto *d = getDialect())
   908      return d->decodeHook(*this, result);
   909    return true;
   910  }
   911  
   912  //===----------------------------------------------------------------------===//
   913  // SparseElementsAttr
   914  //===----------------------------------------------------------------------===//
   915  
   916  SparseElementsAttr SparseElementsAttr::get(ShapedType type,
   917                                             DenseElementsAttr indices,
   918                                             DenseElementsAttr values) {
   919    assert(indices.getType().getElementType().isInteger(64) &&
   920           "expected sparse indices to be 64-bit integer values");
   921    assert((type.isa<RankedTensorType>() || type.isa<VectorType>()) &&
   922           "type must be ranked tensor or vector");
   923    assert(type.hasStaticShape() && "type must have static shape");
   924    return Base::get(type.getContext(), StandardAttributes::SparseElements, type,
   925                     indices.cast<DenseIntElementsAttr>(), values);
   926  }
   927  
   928  DenseIntElementsAttr SparseElementsAttr::getIndices() const {
   929    return getImpl()->indices;
   930  }
   931  
   932  DenseElementsAttr SparseElementsAttr::getValues() const {
   933    return getImpl()->values;
   934  }
   935  
/// Return the value of the element at the multi-dimensional position 'index'.
/// A position that is not listed in the sparse indices yields getZeroAttr();
/// a listed position yields the matching entry of getValues().
///
/// The index -> value-offset map below is rebuilt on every call, so each lookup
/// costs time proportional to the number of stored indices.
Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
  assert(isValidIndex(index) && "expected valid multi-dimensional index");
  auto type = getType();

  // The sparse indices are 64-bit integers, so we can reinterpret the raw data
  // as a 1-D index array.
  auto sparseIndices = getIndices();
  auto sparseIndexValues = sparseIndices.getValues<uint64_t>();

  // Check to see if the indices are a splat.
  if (sparseIndices.isSplat()) {
    // Every coordinate of every stored index equals 'splatIndex'. So any
    // position with a coordinate different from 'splatIndex' is not stored,
    // and its value is zero.
    auto splatIndex = *sparseIndexValues.begin();
    if (llvm::any_of(index, [=](uint64_t i) { return i != splatIndex; }))
      return getZeroAttr();

    // If the indices are a splat, we also expect the values to be a splat.
    assert(getValues().isSplat() && "expected splat values");
    return getValues().getSplatValue();
  }

  // Build a mapping between known indices and the offset of the stored element.
  // NOTE(review): this assumes the indices tensor is laid out as
  // [numSparseIndices x rank], so row 'i' (starting at flat offset i * rank)
  // holds the coordinates of value 'i'. Each key is an ArrayRef pointing
  // directly into that storage, which requires getValues<uint64_t>() to yield
  // contiguous elements. try_emplace keeps the first row if a coordinate tuple
  // repeats. For a rank-0 type the begin iterator is still dereferenced even
  // though each row is empty; confirm that case is safe.
  llvm::SmallDenseMap<llvm::ArrayRef<uint64_t>, size_t> mappedIndices;
  auto numSparseIndices = sparseIndices.getType().getDimSize(0);
  size_t rank = type.getRank();
  for (size_t i = 0, e = numSparseIndices; i != e; ++i)
    mappedIndices.try_emplace(
        {&*std::next(sparseIndexValues.begin(), i * rank), rank}, i);

  // Look for the provided index key within the mapped indices. If the provided
  // index is not found, then return a zero attribute.
  auto it = mappedIndices.find(index);
  if (it == mappedIndices.end())
    return getZeroAttr();

  // Otherwise, return the held sparse value element.
  return getValues().getValue(it->second);
}
   976  
   977  /// Get a zero APFloat for the given sparse attribute.
   978  APFloat SparseElementsAttr::getZeroAPFloat() const {
   979    auto eltType = getType().getElementType().cast<FloatType>();
   980    return APFloat(eltType.getFloatSemantics());
   981  }
   982  
   983  /// Get a zero APInt for the given sparse attribute.
   984  APInt SparseElementsAttr::getZeroAPInt() const {
   985    auto eltType = getType().getElementType().cast<IntegerType>();
   986    return APInt::getNullValue(eltType.getWidth());
   987  }
   988  
   989  /// Get a zero attribute for the given attribute type.
   990  Attribute SparseElementsAttr::getZeroAttr() const {
   991    auto eltType = getType().getElementType();
   992  
   993    // Handle floating point elements.
   994    if (eltType.isa<FloatType>())
   995      return FloatAttr::get(eltType, 0);
   996  
   997    // Otherwise, this is an integer.
   998    auto intEltTy = eltType.cast<IntegerType>();
   999    if (intEltTy.getWidth() == 1)
  1000      return BoolAttr::get(false, eltType.getContext());
  1001    return IntegerAttr::get(eltType, 0);
  1002  }
  1003  
/// Flatten each sparse index in this attribute with getFlattenedIndex and
/// return the results.
///
/// The output follows the order in which the indices are stored; it is not
/// sorted. When the indices are a splat, a single flattened index is returned
/// (every coordinate equals the splat value). NOTE(review): getFlattenedIndex
/// presumably linearizes in row-major order; confirm in the ElementsAttr
/// declaration.
std::vector<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices() const {
  std::vector<ptrdiff_t> flatSparseIndices;

  // The sparse indices are 64-bit integers, so we can reinterpret the raw data
  // as a 1-D index array.
  auto sparseIndices = getIndices();
  auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
  if (sparseIndices.isSplat()) {
    // Build the single index whose coordinates all equal the splat value.
    SmallVector<uint64_t, 8> indices(getType().getRank(),
                                     *sparseIndexValues.begin());
    flatSparseIndices.push_back(getFlattenedIndex(indices));
    return flatSparseIndices;
  }

  // Otherwise, reinterpret each index as an ArrayRef when flattening.
  // NOTE(review): this assumes row 'i' of the indices tensor starts at flat
  // offset i * rank and that getValues<uint64_t>() yields contiguous storage,
  // since the ArrayRef points directly into it.
  auto numSparseIndices = sparseIndices.getType().getDimSize(0);
  size_t rank = getType().getRank();
  for (size_t i = 0, e = numSparseIndices; i != e; ++i)
    flatSparseIndices.push_back(getFlattenedIndex(
        {&*std::next(sparseIndexValues.begin(), i * rank), rank}));
  return flatSparseIndices;
}
  1028  
  1029  //===----------------------------------------------------------------------===//
  1030  // NamedAttributeList
  1031  //===----------------------------------------------------------------------===//
  1032  
  1033  NamedAttributeList::NamedAttributeList(ArrayRef<NamedAttribute> attributes) {
  1034    setAttrs(attributes);
  1035  }
  1036  
  1037  ArrayRef<NamedAttribute> NamedAttributeList::getAttrs() const {
  1038    return attrs ? attrs.getValue() : llvm::None;
  1039  }
  1040  
  1041  /// Replace the held attributes with ones provided in 'newAttrs'.
  1042  void NamedAttributeList::setAttrs(ArrayRef<NamedAttribute> attributes) {
  1043    // Don't create an attribute list if there are no attributes.
  1044    if (attributes.empty())
  1045      attrs = nullptr;
  1046    else
  1047      attrs = DictionaryAttr::get(attributes, attributes[0].second.getContext());
  1048  }
  1049  
  1050  /// Return the specified attribute if present, null otherwise.
  1051  Attribute NamedAttributeList::get(StringRef name) const {
  1052    return attrs ? attrs.get(name) : nullptr;
  1053  }
  1054  
  1055  /// Return the specified attribute if present, null otherwise.
  1056  Attribute NamedAttributeList::get(Identifier name) const {
  1057    return attrs ? attrs.get(name) : nullptr;
  1058  }
  1059  
  1060  /// If the an attribute exists with the specified name, change it to the new
  1061  /// value.  Otherwise, add a new attribute with the specified name/value.
  1062  void NamedAttributeList::set(Identifier name, Attribute value) {
  1063    assert(value && "attributes may never be null");
  1064  
  1065    // If we already have this attribute, replace it.
  1066    auto origAttrs = getAttrs();
  1067    SmallVector<NamedAttribute, 8> newAttrs(origAttrs.begin(), origAttrs.end());
  1068    for (auto &elt : newAttrs)
  1069      if (elt.first == name) {
  1070        elt.second = value;
  1071        attrs = DictionaryAttr::get(newAttrs, value.getContext());
  1072        return;
  1073      }
  1074  
  1075    // Otherwise, add it.
  1076    newAttrs.push_back({name, value});
  1077    attrs = DictionaryAttr::get(newAttrs, value.getContext());
  1078  }
  1079  
  1080  /// Remove the attribute with the specified name if it exists.  The return
  1081  /// value indicates whether the attribute was present or not.
  1082  auto NamedAttributeList::remove(Identifier name) -> RemoveResult {
  1083    auto origAttrs = getAttrs();
  1084    for (unsigned i = 0, e = origAttrs.size(); i != e; ++i) {
  1085      if (origAttrs[i].first == name) {
  1086        // Handle the simple case of removing the only attribute in the list.
  1087        if (e == 1) {
  1088          attrs = nullptr;
  1089          return RemoveResult::Removed;
  1090        }
  1091  
  1092        SmallVector<NamedAttribute, 8> newAttrs;
  1093        newAttrs.reserve(origAttrs.size() - 1);
  1094        newAttrs.append(origAttrs.begin(), origAttrs.begin() + i);
  1095        newAttrs.append(origAttrs.begin() + i + 1, origAttrs.end());
  1096        attrs = DictionaryAttr::get(newAttrs, newAttrs[0].second.getContext());
  1097        return RemoveResult::Removed;
  1098      }
  1099    }
  1100    return RemoveResult::NotFound;
  1101  }