github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/EDSC/Builders.cpp (about)

     1  //===- Builders.cpp - MLIR Declarative Builder Classes --------------------===//
     2  //
     3  // Copyright 2019 The MLIR Authors.
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //   http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  // =============================================================================
    17  
    18  #include "mlir/EDSC/Builders.h"
    19  #include "mlir/Dialect/StandardOps/Ops.h"
    20  #include "mlir/IR/AffineExpr.h"
    21  
    22  #include "llvm/ADT/Optional.h"
    23  
    24  using namespace mlir;
    25  using namespace mlir::edsc;
    26  
    27  mlir::edsc::ScopedContext::ScopedContext(OpBuilder &builder, Location location)
    28      : builder(builder), location(location),
    29        enclosingScopedContext(ScopedContext::getCurrentScopedContext()),
    30        nestedBuilder(nullptr) {
    31    getCurrentScopedContext() = this;
    32  }
    33  
    34  /// Sets the insertion point of the builder to 'newInsertPt' for the duration
    35  /// of the scope. The existing insertion point of the builder is restored on
    36  /// destruction.
    37  mlir::edsc::ScopedContext::ScopedContext(OpBuilder &builder,
    38                                           OpBuilder::InsertPoint newInsertPt,
    39                                           Location location)
    40      : builder(builder), prevBuilderInsertPoint(builder.saveInsertionPoint()),
    41        location(location),
    42        enclosingScopedContext(ScopedContext::getCurrentScopedContext()),
    43        nestedBuilder(nullptr) {
    44    getCurrentScopedContext() = this;
    45    builder.restoreInsertionPoint(newInsertPt);
    46  }
    47  
    48  mlir::edsc::ScopedContext::~ScopedContext() {
    49    assert(!nestedBuilder &&
    50           "Active NestedBuilder must have been exited at this point!");
    51    if (prevBuilderInsertPoint)
    52      builder.restoreInsertionPoint(*prevBuilderInsertPoint);
    53    getCurrentScopedContext() = enclosingScopedContext;
    54  }
    55  
    56  ScopedContext *&mlir::edsc::ScopedContext::getCurrentScopedContext() {
    57    thread_local ScopedContext *context = nullptr;
    58    return context;
    59  }
    60  
    61  OpBuilder &mlir::edsc::ScopedContext::getBuilder() {
    62    assert(ScopedContext::getCurrentScopedContext() &&
    63           "Unexpected Null ScopedContext");
    64    return ScopedContext::getCurrentScopedContext()->builder;
    65  }
    66  
    67  Location mlir::edsc::ScopedContext::getLocation() {
    68    assert(ScopedContext::getCurrentScopedContext() &&
    69           "Unexpected Null ScopedContext");
    70    return ScopedContext::getCurrentScopedContext()->location;
    71  }
    72  
    73  MLIRContext *mlir::edsc::ScopedContext::getContext() {
    74    return getBuilder().getContext();
    75  }
    76  
    77  mlir::edsc::ValueHandle::ValueHandle(index_t cst) {
    78    auto &b = ScopedContext::getBuilder();
    79    auto loc = ScopedContext::getLocation();
    80    v = b.create<ConstantIndexOp>(loc, cst.v).getResult();
    81    t = v->getType();
    82  }
    83  
    84  ValueHandle &mlir::edsc::ValueHandle::operator=(const ValueHandle &other) {
    85    assert(t == other.t && "Wrong type capture");
    86    assert(!v && "ValueHandle has already been captured, use a new name!");
    87    v = other.v;
    88    return *this;
    89  }
    90  
    91  ValueHandle
    92  mlir::edsc::ValueHandle::createComposedAffineApply(AffineMap map,
    93                                                     ArrayRef<Value *> operands) {
    94    Operation *op =
    95        makeComposedAffineApply(ScopedContext::getBuilder(),
    96                                ScopedContext::getLocation(), map, operands)
    97            .getOperation();
    98    assert(op->getNumResults() == 1 && "Not a single result AffineApply");
    99    return ValueHandle(op->getResult(0));
   100  }
   101  
   102  ValueHandle ValueHandle::create(StringRef name, ArrayRef<ValueHandle> operands,
   103                                  ArrayRef<Type> resultTypes,
   104                                  ArrayRef<NamedAttribute> attributes) {
   105    Operation *op =
   106        OperationHandle::create(name, operands, resultTypes, attributes);
   107    if (op->getNumResults() == 1) {
   108      return ValueHandle(op->getResult(0));
   109    }
   110    if (auto f = dyn_cast<AffineForOp>(op)) {
   111      return ValueHandle(f.getInductionVar());
   112    }
   113    llvm_unreachable("unsupported operation, use an OperationHandle instead");
   114  }
   115  
   116  OperationHandle OperationHandle::create(StringRef name,
   117                                          ArrayRef<ValueHandle> operands,
   118                                          ArrayRef<Type> resultTypes,
   119                                          ArrayRef<NamedAttribute> attributes) {
   120    OperationState state(ScopedContext::getLocation(), name);
   121    SmallVector<Value *, 4> ops(operands.begin(), operands.end());
   122    state.addOperands(ops);
   123    state.addTypes(resultTypes);
   124    for (const auto &attr : attributes) {
   125      state.addAttribute(attr.first, attr.second);
   126    }
   127    return OperationHandle(ScopedContext::getBuilder().createOperation(state));
   128  }
   129  
   130  BlockHandle mlir::edsc::BlockHandle::create(ArrayRef<Type> argTypes) {
   131    auto &currentB = ScopedContext::getBuilder();
   132    auto *ib = currentB.getInsertionBlock();
   133    auto ip = currentB.getInsertionPoint();
   134    BlockHandle res;
   135    res.block = ScopedContext::getBuilder().createBlock(ib->getParent());
   136    // createBlock sets the insertion point inside the block.
   137    // We do not want this behavior when using declarative builders with nesting.
   138    currentB.setInsertionPoint(ib, ip);
   139    for (auto t : argTypes) {
   140      res.block->addArgument(t);
   141    }
   142    return res;
   143  }
   144  
   145  static llvm::Optional<ValueHandle> emitStaticFor(ArrayRef<ValueHandle> lbs,
   146                                                   ArrayRef<ValueHandle> ubs,
   147                                                   int64_t step) {
   148    if (lbs.size() != 1 || ubs.size() != 1)
   149      return llvm::Optional<ValueHandle>();
   150  
   151    auto *lbDef = lbs.front().getValue()->getDefiningOp();
   152    auto *ubDef = ubs.front().getValue()->getDefiningOp();
   153    if (!lbDef || !ubDef)
   154      return llvm::Optional<ValueHandle>();
   155  
   156    auto lbConst = dyn_cast<ConstantIndexOp>(lbDef);
   157    auto ubConst = dyn_cast<ConstantIndexOp>(ubDef);
   158    if (!lbConst || !ubConst)
   159      return llvm::Optional<ValueHandle>();
   160  
   161    return ValueHandle::create<AffineForOp>(lbConst.getValue(),
   162                                            ubConst.getValue(), step);
   163  }
   164  
   165  mlir::edsc::LoopBuilder::LoopBuilder(ValueHandle *iv,
   166                                       ArrayRef<ValueHandle> lbHandles,
   167                                       ArrayRef<ValueHandle> ubHandles,
   168                                       int64_t step) {
   169    if (auto res = emitStaticFor(lbHandles, ubHandles, step)) {
   170      *iv = res.getValue();
   171    } else {
   172      SmallVector<Value *, 4> lbs(lbHandles.begin(), lbHandles.end());
   173      SmallVector<Value *, 4> ubs(ubHandles.begin(), ubHandles.end());
   174      *iv = ValueHandle::create<AffineForOp>(
   175          lbs, ScopedContext::getBuilder().getMultiDimIdentityMap(lbs.size()),
   176          ubs, ScopedContext::getBuilder().getMultiDimIdentityMap(ubs.size()),
   177          step);
   178    }
   179    auto *body = getForInductionVarOwner(iv->getValue()).getBody();
   180    enter(body, /*prev=*/1);
   181  }
   182  
   183  ValueHandle
   184  mlir::edsc::LoopBuilder::operator()(llvm::function_ref<void(void)> fun) {
   185    // Call to `exit` must be explicit and asymmetric (cannot happen in the
   186    // destructor) because of ordering wrt comma operator.
   187    /// The particular use case concerns nested blocks:
   188    ///
   189    /// ```c++
   190    ///    For (&i, lb, ub, 1)({
   191    ///      /--- destructor for this `For` is not always called before ...
   192    ///      V
   193    ///      For (&j1, lb, ub, 1)({
   194    ///        some_op_1,
   195    ///      }),
   196    ///      /--- ... this scope is entered, resulting in improperly nested IR.
   197    ///      V
   198    ///      For (&j2, lb, ub, 1)({
   199    ///        some_op_2,
   200    ///      }),
   201    ///    });
   202    /// ```
   203    if (fun)
   204      fun();
   205    exit();
   206    return ValueHandle::null();
   207  }
   208  
   209  mlir::edsc::LoopNestBuilder::LoopNestBuilder(ArrayRef<ValueHandle *> ivs,
   210                                               ArrayRef<ValueHandle> lbs,
   211                                               ArrayRef<ValueHandle> ubs,
   212                                               ArrayRef<int64_t> steps) {
   213    assert(ivs.size() == lbs.size() && "Mismatch in number of arguments");
   214    assert(ivs.size() == ubs.size() && "Mismatch in number of arguments");
   215    assert(ivs.size() == steps.size() && "Mismatch in number of arguments");
   216    for (auto it : llvm::zip(ivs, lbs, ubs, steps)) {
   217      loops.emplace_back(std::get<0>(it), std::get<1>(it), std::get<2>(it),
   218                         std::get<3>(it));
   219    }
   220  }
   221  
   222  ValueHandle
   223  mlir::edsc::LoopNestBuilder::operator()(llvm::function_ref<void(void)> fun) {
   224    if (fun)
   225      fun();
   226    // Iterate on the calling operator() on all the loops in the nest.
   227    // The iteration order is from innermost to outermost because enter/exit needs
   228    // to be asymmetric (i.e. enter() occurs on LoopBuilder construction, exit()
   229    // occurs on calling operator()). The asymmetry is required for properly
   230    // nesting imperfectly nested regions (see LoopBuilder::operator()).
   231    for (auto lit = loops.rbegin(), eit = loops.rend(); lit != eit; ++lit) {
   232      (*lit)();
   233    }
   234    return ValueHandle::null();
   235  }
   236  
   237  mlir::edsc::BlockBuilder::BlockBuilder(BlockHandle bh, Append) {
   238    assert(bh && "Expected already captured BlockHandle");
   239    enter(bh.getBlock());
   240  }
   241  
   242  mlir::edsc::BlockBuilder::BlockBuilder(BlockHandle *bh,
   243                                         ArrayRef<ValueHandle *> args) {
   244    assert(!*bh && "BlockHandle already captures a block, use "
   245                   "the explicit BockBuilder(bh, Append())({}) syntax instead.");
   246    llvm::SmallVector<Type, 8> types;
   247    for (auto *a : args) {
   248      assert(!a->hasValue() &&
   249             "Expected delayed ValueHandle that has not yet captured.");
   250      types.push_back(a->getType());
   251    }
   252    *bh = BlockHandle::create(types);
   253    for (auto it : llvm::zip(args, bh->getBlock()->getArguments())) {
   254      *(std::get<0>(it)) = ValueHandle(std::get<1>(it));
   255    }
   256    enter(bh->getBlock());
   257  }
   258  
   259  /// Only serves as an ordering point between entering nested block and creating
   260  /// stmts.
   261  void mlir::edsc::BlockBuilder::operator()(llvm::function_ref<void(void)> fun) {
   262    // Call to `exit` must be explicit and asymmetric (cannot happen in the
   263    // destructor) because of ordering wrt comma operator.
   264    if (fun)
   265      fun();
   266    exit();
   267  }
   268  
   269  template <typename Op>
   270  static ValueHandle createBinaryHandle(ValueHandle lhs, ValueHandle rhs) {
   271    return ValueHandle::create<Op>(lhs.getValue(), rhs.getValue());
   272  }
   273  
   274  static std::pair<AffineExpr, Value *>
   275  categorizeValueByAffineType(MLIRContext *context, Value *val, unsigned &numDims,
   276                              unsigned &numSymbols) {
   277    AffineExpr d;
   278    Value *resultVal = nullptr;
   279    if (auto constant = dyn_cast_or_null<ConstantIndexOp>(val->getDefiningOp())) {
   280      d = getAffineConstantExpr(constant.getValue(), context);
   281    } else if (isValidSymbol(val) && !isValidDim(val)) {
   282      d = getAffineSymbolExpr(numSymbols++, context);
   283      resultVal = val;
   284    } else {
   285      assert(isValidDim(val) && "Must be a valid Dim");
   286      d = getAffineDimExpr(numDims++, context);
   287      resultVal = val;
   288    }
   289    return std::make_pair(d, resultVal);
   290  }
   291  
   292  static ValueHandle createBinaryIndexHandle(
   293      ValueHandle lhs, ValueHandle rhs,
   294      llvm::function_ref<AffineExpr(AffineExpr, AffineExpr)> affCombiner) {
   295    MLIRContext *context = ScopedContext::getContext();
   296    unsigned numDims = 0, numSymbols = 0;
   297    AffineExpr d0, d1;
   298    Value *v0, *v1;
   299    std::tie(d0, v0) =
   300        categorizeValueByAffineType(context, lhs.getValue(), numDims, numSymbols);
   301    std::tie(d1, v1) =
   302        categorizeValueByAffineType(context, rhs.getValue(), numDims, numSymbols);
   303    SmallVector<Value *, 2> operands;
   304    if (v0) {
   305      operands.push_back(v0);
   306    }
   307    if (v1) {
   308      operands.push_back(v1);
   309    }
   310    auto map = AffineMap::get(numDims, numSymbols, {affCombiner(d0, d1)});
   311    // TODO: createOrFold when available.
   312    return ValueHandle::createComposedAffineApply(map, operands);
   313  }
   314  
   315  template <typename IOp, typename FOp>
   316  static ValueHandle createBinaryHandle(
   317      ValueHandle lhs, ValueHandle rhs,
   318      llvm::function_ref<AffineExpr(AffineExpr, AffineExpr)> affCombiner) {
   319    auto thisType = lhs.getValue()->getType();
   320    auto thatType = rhs.getValue()->getType();
   321    assert(thisType == thatType && "cannot mix types in operators");
   322    (void)thisType;
   323    (void)thatType;
   324    if (thisType.isIndex()) {
   325      return createBinaryIndexHandle(lhs, rhs, affCombiner);
   326    } else if (thisType.isa<IntegerType>()) {
   327      return createBinaryHandle<IOp>(lhs, rhs);
   328    } else if (thisType.isa<FloatType>()) {
   329      return createBinaryHandle<FOp>(lhs, rhs);
   330    } else if (thisType.isa<VectorType>() || thisType.isa<TensorType>()) {
   331      auto aggregateType = thisType.cast<ShapedType>();
   332      if (aggregateType.getElementType().isa<IntegerType>())
   333        return createBinaryHandle<IOp>(lhs, rhs);
   334      else if (aggregateType.getElementType().isa<FloatType>())
   335        return createBinaryHandle<FOp>(lhs, rhs);
   336    }
   337    llvm_unreachable("failed to create a ValueHandle");
   338  }
   339  
   340  ValueHandle mlir::edsc::op::operator+(ValueHandle lhs, ValueHandle rhs) {
   341    return createBinaryHandle<AddIOp, AddFOp>(
   342        lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0 + d1; });
   343  }
   344  
   345  ValueHandle mlir::edsc::op::operator-(ValueHandle lhs, ValueHandle rhs) {
   346    return createBinaryHandle<SubIOp, SubFOp>(
   347        lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0 - d1; });
   348  }
   349  
   350  ValueHandle mlir::edsc::op::operator*(ValueHandle lhs, ValueHandle rhs) {
   351    return createBinaryHandle<MulIOp, MulFOp>(
   352        lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0 * d1; });
   353  }
   354  
   355  ValueHandle mlir::edsc::op::operator/(ValueHandle lhs, ValueHandle rhs) {
   356    return createBinaryHandle<DivISOp, DivFOp>(
   357        lhs, rhs, [](AffineExpr d0, AffineExpr d1) -> AffineExpr {
   358          llvm_unreachable("only exprs of non-index type support operator/");
   359        });
   360  }
   361  
   362  ValueHandle mlir::edsc::op::operator%(ValueHandle lhs, ValueHandle rhs) {
   363    return createBinaryHandle<RemISOp, RemFOp>(
   364        lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0 % d1; });
   365  }
   366  
   367  ValueHandle mlir::edsc::op::floorDiv(ValueHandle lhs, ValueHandle rhs) {
   368    return createBinaryIndexHandle(
   369        lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0.floorDiv(d1); });
   370  }
   371  
   372  ValueHandle mlir::edsc::op::ceilDiv(ValueHandle lhs, ValueHandle rhs) {
   373    return createBinaryIndexHandle(
   374        lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0.ceilDiv(d1); });
   375  }
   376  
   377  ValueHandle mlir::edsc::op::operator!(ValueHandle value) {
   378    assert(value.getType().isInteger(1) && "expected boolean expression");
   379    return ValueHandle::create<ConstantIntOp>(1, 1) - value;
   380  }
   381  
   382  ValueHandle mlir::edsc::op::operator&&(ValueHandle lhs, ValueHandle rhs) {
   383    assert(lhs.getType().isInteger(1) && "expected boolean expression on LHS");
   384    assert(rhs.getType().isInteger(1) && "expected boolean expression on RHS");
   385    return lhs * rhs;
   386  }
   387  
   388  ValueHandle mlir::edsc::op::operator||(ValueHandle lhs, ValueHandle rhs) {
   389    return !(!lhs && !rhs);
   390  }
   391  
   392  static ValueHandle createIComparisonExpr(CmpIPredicate predicate,
   393                                           ValueHandle lhs, ValueHandle rhs) {
   394    auto lhsType = lhs.getType();
   395    auto rhsType = rhs.getType();
   396    (void)lhsType;
   397    (void)rhsType;
   398    assert(lhsType == rhsType && "cannot mix types in operators");
   399    assert((lhsType.isa<IndexType>() || lhsType.isa<IntegerType>()) &&
   400           "only integer comparisons are supported");
   401  
   402    auto op = ScopedContext::getBuilder().create<CmpIOp>(
   403        ScopedContext::getLocation(), predicate, lhs.getValue(), rhs.getValue());
   404    return ValueHandle(op.getResult());
   405  }
   406  
   407  static ValueHandle createFComparisonExpr(CmpFPredicate predicate,
   408                                           ValueHandle lhs, ValueHandle rhs) {
   409    auto lhsType = lhs.getType();
   410    auto rhsType = rhs.getType();
   411    (void)lhsType;
   412    (void)rhsType;
   413    assert(lhsType == rhsType && "cannot mix types in operators");
   414    assert(lhsType.isa<FloatType>() && "only float comparisons are supported");
   415  
   416    auto op = ScopedContext::getBuilder().create<CmpFOp>(
   417        ScopedContext::getLocation(), predicate, lhs.getValue(), rhs.getValue());
   418    return ValueHandle(op.getResult());
   419  }
   420  
   421  // All floating point comparison are ordered through EDSL
   422  ValueHandle mlir::edsc::op::operator==(ValueHandle lhs, ValueHandle rhs) {
   423    auto type = lhs.getType();
   424    return type.isa<FloatType>()
   425               ? createFComparisonExpr(CmpFPredicate::OEQ, lhs, rhs)
   426               : createIComparisonExpr(CmpIPredicate::EQ, lhs, rhs);
   427  }
   428  ValueHandle mlir::edsc::op::operator!=(ValueHandle lhs, ValueHandle rhs) {
   429    auto type = lhs.getType();
   430    return type.isa<FloatType>()
   431               ? createFComparisonExpr(CmpFPredicate::ONE, lhs, rhs)
   432               : createIComparisonExpr(CmpIPredicate::NE, lhs, rhs);
   433  }
   434  ValueHandle mlir::edsc::op::operator<(ValueHandle lhs, ValueHandle rhs) {
   435    auto type = lhs.getType();
   436    return type.isa<FloatType>()
   437               ? createFComparisonExpr(CmpFPredicate::OLT, lhs, rhs)
   438               :
   439               // TODO(ntv,zinenko): signed by default, how about unsigned?
   440               createIComparisonExpr(CmpIPredicate::SLT, lhs, rhs);
   441  }
   442  ValueHandle mlir::edsc::op::operator<=(ValueHandle lhs, ValueHandle rhs) {
   443    auto type = lhs.getType();
   444    return type.isa<FloatType>()
   445               ? createFComparisonExpr(CmpFPredicate::OLE, lhs, rhs)
   446               : createIComparisonExpr(CmpIPredicate::SLE, lhs, rhs);
   447  }
   448  ValueHandle mlir::edsc::op::operator>(ValueHandle lhs, ValueHandle rhs) {
   449    auto type = lhs.getType();
   450    return type.isa<FloatType>()
   451               ? createFComparisonExpr(CmpFPredicate::OGT, lhs, rhs)
   452               : createIComparisonExpr(CmpIPredicate::SGT, lhs, rhs);
   453  }
   454  ValueHandle mlir::edsc::op::operator>=(ValueHandle lhs, ValueHandle rhs) {
   455    auto type = lhs.getType();
   456    return type.isa<FloatType>()
   457               ? createFComparisonExpr(CmpFPredicate::OGE, lhs, rhs)
   458               : createIComparisonExpr(CmpIPredicate::SGE, lhs, rhs);
   459  }