github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/IR/Builders.cpp (about)

     1  //===- Builders.cpp - Helpers for constructing MLIR Classes ---------------===//
     2  //
     3  // Copyright 2019 The MLIR Authors.
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //   http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  // =============================================================================
    17  
    18  #include "mlir/IR/Builders.h"
    19  #include "mlir/IR/AffineExpr.h"
    20  #include "mlir/IR/AffineMap.h"
    21  #include "mlir/IR/Attributes.h"
    22  #include "mlir/IR/IntegerSet.h"
    23  #include "mlir/IR/Location.h"
    24  #include "mlir/IR/Module.h"
    25  #include "mlir/IR/StandardTypes.h"
    26  #include "mlir/Support/Functional.h"
    27  using namespace mlir;
    28  
    29  Builder::Builder(ModuleOp module) : context(module.getContext()) {}
    30  
    31  Identifier Builder::getIdentifier(StringRef str) {
    32    return Identifier::get(str, context);
    33  }
    34  
    35  //===----------------------------------------------------------------------===//
    36  // Locations.
    37  //===----------------------------------------------------------------------===//
    38  
    39  Location Builder::getUnknownLoc() { return UnknownLoc::get(context); }
    40  
    41  Location Builder::getFileLineColLoc(Identifier filename, unsigned line,
    42                                      unsigned column) {
    43    return FileLineColLoc::get(filename, line, column, context);
    44  }
    45  
    46  Location Builder::getFusedLoc(ArrayRef<Location> locs, Attribute metadata) {
    47    return FusedLoc::get(locs, metadata, context);
    48  }
    49  
    50  //===----------------------------------------------------------------------===//
    51  // Types.
    52  //===----------------------------------------------------------------------===//
    53  
    54  FloatType Builder::getBF16Type() { return FloatType::getBF16(context); }
    55  
    56  FloatType Builder::getF16Type() { return FloatType::getF16(context); }
    57  
    58  FloatType Builder::getF32Type() { return FloatType::getF32(context); }
    59  
    60  FloatType Builder::getF64Type() { return FloatType::getF64(context); }
    61  
    62  IndexType Builder::getIndexType() { return IndexType::get(context); }
    63  
    64  IntegerType Builder::getI1Type() { return IntegerType::get(1, context); }
    65  
    66  IntegerType Builder::getIntegerType(unsigned width) {
    67    return IntegerType::get(width, context);
    68  }
    69  
    70  FunctionType Builder::getFunctionType(ArrayRef<Type> inputs,
    71                                        ArrayRef<Type> results) {
    72    return FunctionType::get(inputs, results, context);
    73  }
    74  
    75  MemRefType Builder::getMemRefType(ArrayRef<int64_t> shape, Type elementType,
    76                                    ArrayRef<AffineMap> affineMapComposition,
    77                                    unsigned memorySpace) {
    78    return MemRefType::get(shape, elementType, affineMapComposition, memorySpace);
    79  }
    80  
    81  VectorType Builder::getVectorType(ArrayRef<int64_t> shape, Type elementType) {
    82    return VectorType::get(shape, elementType);
    83  }
    84  
    85  RankedTensorType Builder::getTensorType(ArrayRef<int64_t> shape,
    86                                          Type elementType) {
    87    return RankedTensorType::get(shape, elementType);
    88  }
    89  
    90  UnrankedTensorType Builder::getTensorType(Type elementType) {
    91    return UnrankedTensorType::get(elementType);
    92  }
    93  
    94  TupleType Builder::getTupleType(ArrayRef<Type> elementTypes) {
    95    return TupleType::get(elementTypes, context);
    96  }
    97  
    98  NoneType Builder::getNoneType() { return NoneType::get(context); }
    99  
   100  //===----------------------------------------------------------------------===//
   101  // Attributes.
   102  //===----------------------------------------------------------------------===//
   103  
   104  NamedAttribute Builder::getNamedAttr(StringRef name, Attribute val) {
   105    return NamedAttribute(getIdentifier(name), val);
   106  }
   107  
   108  UnitAttr Builder::getUnitAttr() { return UnitAttr::get(context); }
   109  
   110  BoolAttr Builder::getBoolAttr(bool value) {
   111    return BoolAttr::get(value, context);
   112  }
   113  
   114  DictionaryAttr Builder::getDictionaryAttr(ArrayRef<NamedAttribute> value) {
   115    return DictionaryAttr::get(value, context);
   116  }
   117  
   118  IntegerAttr Builder::getI64IntegerAttr(int64_t value) {
   119    return IntegerAttr::get(getIntegerType(64), APInt(64, value));
   120  }
   121  
   122  IntegerAttr Builder::getI32IntegerAttr(int32_t value) {
   123    return IntegerAttr::get(getIntegerType(32), APInt(32, value));
   124  }
   125  
   126  IntegerAttr Builder::getIntegerAttr(Type type, int64_t value) {
   127    if (type.isIndex())
   128      return IntegerAttr::get(type, APInt(64, value));
   129    return IntegerAttr::get(type, APInt(type.getIntOrFloatBitWidth(), value));
   130  }
   131  
   132  IntegerAttr Builder::getIntegerAttr(Type type, const APInt &value) {
   133    return IntegerAttr::get(type, value);
   134  }
   135  
   136  FloatAttr Builder::getF64FloatAttr(double value) {
   137    return FloatAttr::get(getF64Type(), APFloat(value));
   138  }
   139  
   140  FloatAttr Builder::getF32FloatAttr(float value) {
   141    return FloatAttr::get(getF32Type(), APFloat(value));
   142  }
   143  
   144  FloatAttr Builder::getF16FloatAttr(float value) {
   145    return FloatAttr::get(getF16Type(), value);
   146  }
   147  
   148  FloatAttr Builder::getFloatAttr(Type type, double value) {
   149    return FloatAttr::get(type, value);
   150  }
   151  
   152  FloatAttr Builder::getFloatAttr(Type type, const APFloat &value) {
   153    return FloatAttr::get(type, value);
   154  }
   155  
   156  StringAttr Builder::getStringAttr(StringRef bytes) {
   157    return StringAttr::get(bytes, context);
   158  }
   159  
   160  StringAttr Builder::getStringAttr(StringRef bytes, Type type) {
   161    return StringAttr::get(bytes, type);
   162  }
   163  
   164  ArrayAttr Builder::getArrayAttr(ArrayRef<Attribute> value) {
   165    return ArrayAttr::get(value, context);
   166  }
   167  
   168  AffineMapAttr Builder::getAffineMapAttr(AffineMap map) {
   169    return AffineMapAttr::get(map);
   170  }
   171  
   172  IntegerSetAttr Builder::getIntegerSetAttr(IntegerSet set) {
   173    return IntegerSetAttr::get(set);
   174  }
   175  
   176  TypeAttr Builder::getTypeAttr(Type type) { return TypeAttr::get(type); }
   177  
   178  SymbolRefAttr Builder::getSymbolRefAttr(Operation *value) {
   179    auto symName =
   180        value->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
   181    assert(symName && "value does not have a valid symbol name");
   182    return getSymbolRefAttr(symName.getValue());
   183  }
   184  SymbolRefAttr Builder::getSymbolRefAttr(StringRef value) {
   185    return SymbolRefAttr::get(value, getContext());
   186  }
   187  
   188  ElementsAttr Builder::getDenseElementsAttr(ShapedType type,
   189                                             ArrayRef<Attribute> values) {
   190    return DenseElementsAttr::get(type, values);
   191  }
   192  
   193  ElementsAttr Builder::getDenseIntElementsAttr(ShapedType type,
   194                                                ArrayRef<int64_t> values) {
   195    return DenseIntElementsAttr::get(type, values);
   196  }
   197  
   198  ElementsAttr Builder::getSparseElementsAttr(ShapedType type,
   199                                              DenseIntElementsAttr indices,
   200                                              DenseElementsAttr values) {
   201    return SparseElementsAttr::get(type, indices, values);
   202  }
   203  
   204  ElementsAttr Builder::getOpaqueElementsAttr(Dialect *dialect, ShapedType type,
   205                                              StringRef bytes) {
   206    return OpaqueElementsAttr::get(dialect, type, bytes);
   207  }
   208  
   209  ArrayAttr Builder::getI32ArrayAttr(ArrayRef<int32_t> values) {
   210    auto attrs = functional::map(
   211        [this](int32_t v) -> Attribute { return getI32IntegerAttr(v); }, values);
   212    return getArrayAttr(attrs);
   213  }
   214  
   215  ArrayAttr Builder::getI64ArrayAttr(ArrayRef<int64_t> values) {
   216    auto attrs = functional::map(
   217        [this](int64_t v) -> Attribute { return getI64IntegerAttr(v); }, values);
   218    return getArrayAttr(attrs);
   219  }
   220  
   221  ArrayAttr Builder::getIndexArrayAttr(ArrayRef<int64_t> values) {
   222    auto attrs = functional::map(
   223        [this](int64_t v) -> Attribute {
   224          return getIntegerAttr(IndexType::get(getContext()), v);
   225        },
   226        values);
   227    return getArrayAttr(attrs);
   228  }
   229  
   230  ArrayAttr Builder::getF32ArrayAttr(ArrayRef<float> values) {
   231    auto attrs = functional::map(
   232        [this](float v) -> Attribute { return getF32FloatAttr(v); }, values);
   233    return getArrayAttr(attrs);
   234  }
   235  
   236  ArrayAttr Builder::getF64ArrayAttr(ArrayRef<double> values) {
   237    auto attrs = functional::map(
   238        [this](double v) -> Attribute { return getF64FloatAttr(v); }, values);
   239    return getArrayAttr(attrs);
   240  }
   241  
   242  ArrayAttr Builder::getStrArrayAttr(ArrayRef<StringRef> values) {
   243    auto attrs = functional::map(
   244        [this](StringRef v) -> Attribute { return getStringAttr(v); }, values);
   245    return getArrayAttr(attrs);
   246  }
   247  
   248  ArrayAttr Builder::getAffineMapArrayAttr(ArrayRef<AffineMap> values) {
   249    auto attrs = functional::map(
   250        [this](AffineMap v) -> Attribute { return getAffineMapAttr(v); }, values);
   251    return getArrayAttr(attrs);
   252  }
   253  
   254  Attribute Builder::getZeroAttr(Type type) {
   255    switch (type.getKind()) {
   256    case StandardTypes::F16:
   257      return getF16FloatAttr(0);
   258    case StandardTypes::F32:
   259      return getF32FloatAttr(0);
   260    case StandardTypes::F64:
   261      return getF64FloatAttr(0);
   262    case StandardTypes::Integer: {
   263      auto width = type.cast<IntegerType>().getWidth();
   264      if (width == 1)
   265        return getBoolAttr(false);
   266      return getIntegerAttr(type, APInt(width, 0));
   267    }
   268    case StandardTypes::Vector:
   269    case StandardTypes::RankedTensor: {
   270      auto vtType = type.cast<ShapedType>();
   271      auto element = getZeroAttr(vtType.getElementType());
   272      if (!element)
   273        return {};
   274      return getDenseElementsAttr(vtType, element);
   275    }
   276    default:
   277      break;
   278    }
   279    return {};
   280  }
   281  
   282  //===----------------------------------------------------------------------===//
   283  // Affine Expressions, Affine Maps, and Integer Sets.
   284  //===----------------------------------------------------------------------===//
   285  
   286  AffineMap Builder::getAffineMap(unsigned dimCount, unsigned symbolCount,
   287                                  ArrayRef<AffineExpr> results) {
   288    return AffineMap::get(dimCount, symbolCount, results);
   289  }
   290  
   291  AffineExpr Builder::getAffineDimExpr(unsigned position) {
   292    return mlir::getAffineDimExpr(position, context);
   293  }
   294  
   295  AffineExpr Builder::getAffineSymbolExpr(unsigned position) {
   296    return mlir::getAffineSymbolExpr(position, context);
   297  }
   298  
   299  AffineExpr Builder::getAffineConstantExpr(int64_t constant) {
   300    return mlir::getAffineConstantExpr(constant, context);
   301  }
   302  
   303  IntegerSet Builder::getIntegerSet(unsigned dimCount, unsigned symbolCount,
   304                                    ArrayRef<AffineExpr> constraints,
   305                                    ArrayRef<bool> isEq) {
   306    return IntegerSet::get(dimCount, symbolCount, constraints, isEq);
   307  }
   308  
   309  AffineMap Builder::getEmptyAffineMap() { return AffineMap::get(context); }
   310  
   311  AffineMap Builder::getConstantAffineMap(int64_t val) {
   312    return AffineMap::get(/*dimCount=*/0, /*symbolCount=*/0,
   313                          {getAffineConstantExpr(val)});
   314  }
   315  
   316  AffineMap Builder::getDimIdentityMap() {
   317    return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0,
   318                          {getAffineDimExpr(0)});
   319  }
   320  
   321  AffineMap Builder::getMultiDimIdentityMap(unsigned rank) {
   322    SmallVector<AffineExpr, 4> dimExprs;
   323    dimExprs.reserve(rank);
   324    for (unsigned i = 0; i < rank; ++i)
   325      dimExprs.push_back(getAffineDimExpr(i));
   326    return AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, dimExprs);
   327  }
   328  
   329  AffineMap Builder::getSymbolIdentityMap() {
   330    return AffineMap::get(/*dimCount=*/0, /*symbolCount=*/1,
   331                          {getAffineSymbolExpr(0)});
   332  }
   333  
   334  AffineMap Builder::getSingleDimShiftAffineMap(int64_t shift) {
   335    // expr = d0 + shift.
   336    auto expr = getAffineDimExpr(0) + shift;
   337    return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, {expr});
   338  }
   339  
   340  AffineMap Builder::getShiftedAffineMap(AffineMap map, int64_t shift) {
   341    SmallVector<AffineExpr, 4> shiftedResults;
   342    shiftedResults.reserve(map.getNumResults());
   343    for (auto resultExpr : map.getResults()) {
   344      shiftedResults.push_back(resultExpr + shift);
   345    }
   346    return AffineMap::get(map.getNumDims(), map.getNumSymbols(), shiftedResults);
   347  }
   348  
   349  //===----------------------------------------------------------------------===//
   350  // OpBuilder.
   351  //===----------------------------------------------------------------------===//
   352  
   353  OpBuilder::~OpBuilder() {}
   354  
   355  /// Add new block and set the insertion point to the end of it. The block is
   356  /// inserted at the provided insertion point of 'parent'.
   357  Block *OpBuilder::createBlock(Region *parent, Region::iterator insertPt) {
   358    assert(parent && "expected valid parent region");
   359    if (insertPt == Region::iterator())
   360      insertPt = parent->end();
   361  
   362    Block *b = new Block();
   363    parent->getBlocks().insert(insertPt, b);
   364    setInsertionPointToEnd(b);
   365    return b;
   366  }
   367  
   368  /// Add new block and set the insertion point to the end of it.  The block is
   369  /// placed before 'insertBefore'.
   370  Block *OpBuilder::createBlock(Block *insertBefore) {
   371    assert(insertBefore && "expected valid insertion block");
   372    return createBlock(insertBefore->getParent(), Region::iterator(insertBefore));
   373  }
   374  
   375  /// Create an operation given the fields represented as an OperationState.
   376  Operation *OpBuilder::createOperation(const OperationState &state) {
   377    assert(block && "createOperation() called without setting builder's block");
   378    auto *op = Operation::create(state);
   379    insert(op);
   380    return op;
   381  }
   382  
   383  /// Attempts to fold the given operation and places new results within
   384  /// 'results'.
   385  void OpBuilder::tryFold(Operation *op, SmallVectorImpl<Value *> &results) {
   386    results.reserve(op->getNumResults());
   387    SmallVector<OpFoldResult, 4> foldResults;
   388  
   389    // Returns if the given fold result corresponds to a valid existing value.
   390    auto isValidValue = [](OpFoldResult result) {
   391      return result.dyn_cast<Value *>();
   392    };
   393  
   394    // Check if the fold failed, or did not result in only existing values.
   395    SmallVector<Attribute, 4> constOperands(op->getNumOperands());
   396    if (failed(op->fold(constOperands, foldResults)) || foldResults.empty() ||
   397        !llvm::all_of(foldResults, isValidValue)) {
   398      // Simply return the existing operation results.
   399      results.assign(op->result_begin(), op->result_end());
   400      return;
   401    }
   402  
   403    // Populate the results with the folded results and remove the original op.
   404    llvm::transform(foldResults, std::back_inserter(results),
   405                    [](OpFoldResult result) { return result.get<Value *>(); });
   406    op->erase();
   407  }
   408  
   409  /// Insert the given operation at the current insertion point.
   410  void OpBuilder::insert(Operation *op) {
   411    if (block)
   412      block->getOperations().insert(insertPoint, op);
   413  }