github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/EDSC/Builders.cpp (about) 1 //===- Builders.cpp - MLIR Declarative Builder Classes --------------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 18 #include "mlir/EDSC/Builders.h" 19 #include "mlir/Dialect/StandardOps/Ops.h" 20 #include "mlir/IR/AffineExpr.h" 21 22 #include "llvm/ADT/Optional.h" 23 24 using namespace mlir; 25 using namespace mlir::edsc; 26 27 mlir::edsc::ScopedContext::ScopedContext(OpBuilder &builder, Location location) 28 : builder(builder), location(location), 29 enclosingScopedContext(ScopedContext::getCurrentScopedContext()), 30 nestedBuilder(nullptr) { 31 getCurrentScopedContext() = this; 32 } 33 34 /// Sets the insertion point of the builder to 'newInsertPt' for the duration 35 /// of the scope. The existing insertion point of the builder is restored on 36 /// destruction. 37 mlir::edsc::ScopedContext::ScopedContext(OpBuilder &builder, 38 OpBuilder::InsertPoint newInsertPt, 39 Location location) 40 : builder(builder), prevBuilderInsertPoint(builder.saveInsertionPoint()), 41 location(location), 42 enclosingScopedContext(ScopedContext::getCurrentScopedContext()), 43 nestedBuilder(nullptr) { 44 getCurrentScopedContext() = this; 45 builder.restoreInsertionPoint(newInsertPt); 46 } 47 48 mlir::edsc::ScopedContext::~ScopedContext() { 49 assert(!nestedBuilder && 50 "Active NestedBuilder must have been exited at this point!"); 51 if (prevBuilderInsertPoint) 52 builder.restoreInsertionPoint(*prevBuilderInsertPoint); 53 getCurrentScopedContext() = enclosingScopedContext; 54 } 55 56 ScopedContext *&mlir::edsc::ScopedContext::getCurrentScopedContext() { 57 thread_local ScopedContext *context = nullptr; 58 return context; 59 } 60 61 OpBuilder &mlir::edsc::ScopedContext::getBuilder() { 62 assert(ScopedContext::getCurrentScopedContext() && 63 "Unexpected Null ScopedContext"); 64 return ScopedContext::getCurrentScopedContext()->builder; 65 } 66 67 Location mlir::edsc::ScopedContext::getLocation() { 68 assert(ScopedContext::getCurrentScopedContext() && 69 "Unexpected Null ScopedContext"); 70 return ScopedContext::getCurrentScopedContext()->location; 71 } 72 73 MLIRContext *mlir::edsc::ScopedContext::getContext() { 74 return getBuilder().getContext(); 75 } 76 77 mlir::edsc::ValueHandle::ValueHandle(index_t cst) { 78 auto &b = ScopedContext::getBuilder(); 79 auto loc = ScopedContext::getLocation(); 80 v = b.create<ConstantIndexOp>(loc, cst.v).getResult(); 81 t = v->getType(); 82 } 83 84 ValueHandle &mlir::edsc::ValueHandle::operator=(const ValueHandle &other) { 85 assert(t == other.t && "Wrong type capture"); 86 assert(!v && "ValueHandle has already been captured, use a new name!"); 87 v = other.v; 88 return *this; 89 } 90 91 ValueHandle 92 mlir::edsc::ValueHandle::createComposedAffineApply(AffineMap map, 93 ArrayRef<Value *> operands) { 94 Operation *op = 95 makeComposedAffineApply(ScopedContext::getBuilder(), 96 ScopedContext::getLocation(), map, operands) 97 .getOperation(); 98 assert(op->getNumResults() == 1 && "Not a single result AffineApply"); 99 return ValueHandle(op->getResult(0)); 100 } 101 102 ValueHandle ValueHandle::create(StringRef name, ArrayRef<ValueHandle> operands, 103 ArrayRef<Type> resultTypes, 104 ArrayRef<NamedAttribute> attributes) { 105 Operation *op = 106 OperationHandle::create(name, operands, resultTypes, attributes); 107 if (op->getNumResults() == 1) { 108 return ValueHandle(op->getResult(0)); 109 } 110 if (auto f = dyn_cast<AffineForOp>(op)) { 111 return ValueHandle(f.getInductionVar()); 112 } 113 llvm_unreachable("unsupported operation, use an OperationHandle instead"); 114 } 115 116 OperationHandle OperationHandle::create(StringRef name, 117 ArrayRef<ValueHandle> operands, 118 ArrayRef<Type> resultTypes, 119 ArrayRef<NamedAttribute> attributes) { 120 OperationState state(ScopedContext::getLocation(), name); 121 SmallVector<Value *, 4> ops(operands.begin(), operands.end()); 122 state.addOperands(ops); 123 state.addTypes(resultTypes); 124 for (const auto &attr : attributes) { 125 state.addAttribute(attr.first, attr.second); 126 } 127 return OperationHandle(ScopedContext::getBuilder().createOperation(state)); 128 } 129 130 BlockHandle mlir::edsc::BlockHandle::create(ArrayRef<Type> argTypes) { 131 auto ¤tB = ScopedContext::getBuilder(); 132 auto *ib = currentB.getInsertionBlock(); 133 auto ip = currentB.getInsertionPoint(); 134 BlockHandle res; 135 res.block = ScopedContext::getBuilder().createBlock(ib->getParent()); 136 // createBlock sets the insertion point inside the block. 137 // We do not want this behavior when using declarative builders with nesting. 138 currentB.setInsertionPoint(ib, ip); 139 for (auto t : argTypes) { 140 res.block->addArgument(t); 141 } 142 return res; 143 } 144 145 static llvm::Optional<ValueHandle> emitStaticFor(ArrayRef<ValueHandle> lbs, 146 ArrayRef<ValueHandle> ubs, 147 int64_t step) { 148 if (lbs.size() != 1 || ubs.size() != 1) 149 return llvm::Optional<ValueHandle>(); 150 151 auto *lbDef = lbs.front().getValue()->getDefiningOp(); 152 auto *ubDef = ubs.front().getValue()->getDefiningOp(); 153 if (!lbDef || !ubDef) 154 return llvm::Optional<ValueHandle>(); 155 156 auto lbConst = dyn_cast<ConstantIndexOp>(lbDef); 157 auto ubConst = dyn_cast<ConstantIndexOp>(ubDef); 158 if (!lbConst || !ubConst) 159 return llvm::Optional<ValueHandle>(); 160 161 return ValueHandle::create<AffineForOp>(lbConst.getValue(), 162 ubConst.getValue(), step); 163 } 164 165 mlir::edsc::LoopBuilder::LoopBuilder(ValueHandle *iv, 166 ArrayRef<ValueHandle> lbHandles, 167 ArrayRef<ValueHandle> ubHandles, 168 int64_t step) { 169 if (auto res = emitStaticFor(lbHandles, ubHandles, step)) { 170 *iv = res.getValue(); 171 } else { 172 SmallVector<Value *, 4> lbs(lbHandles.begin(), lbHandles.end()); 173 SmallVector<Value *, 4> ubs(ubHandles.begin(), ubHandles.end()); 174 *iv = ValueHandle::create<AffineForOp>( 175 lbs, ScopedContext::getBuilder().getMultiDimIdentityMap(lbs.size()), 176 ubs, ScopedContext::getBuilder().getMultiDimIdentityMap(ubs.size()), 177 step); 178 } 179 auto *body = getForInductionVarOwner(iv->getValue()).getBody(); 180 enter(body, /*prev=*/1); 181 } 182 183 ValueHandle 184 mlir::edsc::LoopBuilder::operator()(llvm::function_ref<void(void)> fun) { 185 // Call to `exit` must be explicit and asymmetric (cannot happen in the 186 // destructor) because of ordering wrt comma operator. 187 /// The particular use case concerns nested blocks: 188 /// 189 /// ```c++ 190 /// For (&i, lb, ub, 1)({ 191 /// /--- destructor for this `For` is not always called before ... 192 /// V 193 /// For (&j1, lb, ub, 1)({ 194 /// some_op_1, 195 /// }), 196 /// /--- ... this scope is entered, resulting in improperly nested IR. 197 /// V 198 /// For (&j2, lb, ub, 1)({ 199 /// some_op_2, 200 /// }), 201 /// }); 202 /// ``` 203 if (fun) 204 fun(); 205 exit(); 206 return ValueHandle::null(); 207 } 208 209 mlir::edsc::LoopNestBuilder::LoopNestBuilder(ArrayRef<ValueHandle *> ivs, 210 ArrayRef<ValueHandle> lbs, 211 ArrayRef<ValueHandle> ubs, 212 ArrayRef<int64_t> steps) { 213 assert(ivs.size() == lbs.size() && "Mismatch in number of arguments"); 214 assert(ivs.size() == ubs.size() && "Mismatch in number of arguments"); 215 assert(ivs.size() == steps.size() && "Mismatch in number of arguments"); 216 for (auto it : llvm::zip(ivs, lbs, ubs, steps)) { 217 loops.emplace_back(std::get<0>(it), std::get<1>(it), std::get<2>(it), 218 std::get<3>(it)); 219 } 220 } 221 222 ValueHandle 223 mlir::edsc::LoopNestBuilder::operator()(llvm::function_ref<void(void)> fun) { 224 if (fun) 225 fun(); 226 // Iterate on the calling operator() on all the loops in the nest. 227 // The iteration order is from innermost to outermost because enter/exit needs 228 // to be asymmetric (i.e. enter() occurs on LoopBuilder construction, exit() 229 // occurs on calling operator()). The asymmetry is required for properly 230 // nesting imperfectly nested regions (see LoopBuilder::operator()). 231 for (auto lit = loops.rbegin(), eit = loops.rend(); lit != eit; ++lit) { 232 (*lit)(); 233 } 234 return ValueHandle::null(); 235 } 236 237 mlir::edsc::BlockBuilder::BlockBuilder(BlockHandle bh, Append) { 238 assert(bh && "Expected already captured BlockHandle"); 239 enter(bh.getBlock()); 240 } 241 242 mlir::edsc::BlockBuilder::BlockBuilder(BlockHandle *bh, 243 ArrayRef<ValueHandle *> args) { 244 assert(!*bh && "BlockHandle already captures a block, use " 245 "the explicit BockBuilder(bh, Append())({}) syntax instead."); 246 llvm::SmallVector<Type, 8> types; 247 for (auto *a : args) { 248 assert(!a->hasValue() && 249 "Expected delayed ValueHandle that has not yet captured."); 250 types.push_back(a->getType()); 251 } 252 *bh = BlockHandle::create(types); 253 for (auto it : llvm::zip(args, bh->getBlock()->getArguments())) { 254 *(std::get<0>(it)) = ValueHandle(std::get<1>(it)); 255 } 256 enter(bh->getBlock()); 257 } 258 259 /// Only serves as an ordering point between entering nested block and creating 260 /// stmts. 261 void mlir::edsc::BlockBuilder::operator()(llvm::function_ref<void(void)> fun) { 262 // Call to `exit` must be explicit and asymmetric (cannot happen in the 263 // destructor) because of ordering wrt comma operator. 264 if (fun) 265 fun(); 266 exit(); 267 } 268 269 template <typename Op> 270 static ValueHandle createBinaryHandle(ValueHandle lhs, ValueHandle rhs) { 271 return ValueHandle::create<Op>(lhs.getValue(), rhs.getValue()); 272 } 273 274 static std::pair<AffineExpr, Value *> 275 categorizeValueByAffineType(MLIRContext *context, Value *val, unsigned &numDims, 276 unsigned &numSymbols) { 277 AffineExpr d; 278 Value *resultVal = nullptr; 279 if (auto constant = dyn_cast_or_null<ConstantIndexOp>(val->getDefiningOp())) { 280 d = getAffineConstantExpr(constant.getValue(), context); 281 } else if (isValidSymbol(val) && !isValidDim(val)) { 282 d = getAffineSymbolExpr(numSymbols++, context); 283 resultVal = val; 284 } else { 285 assert(isValidDim(val) && "Must be a valid Dim"); 286 d = getAffineDimExpr(numDims++, context); 287 resultVal = val; 288 } 289 return std::make_pair(d, resultVal); 290 } 291 292 static ValueHandle createBinaryIndexHandle( 293 ValueHandle lhs, ValueHandle rhs, 294 llvm::function_ref<AffineExpr(AffineExpr, AffineExpr)> affCombiner) { 295 MLIRContext *context = ScopedContext::getContext(); 296 unsigned numDims = 0, numSymbols = 0; 297 AffineExpr d0, d1; 298 Value *v0, *v1; 299 std::tie(d0, v0) = 300 categorizeValueByAffineType(context, lhs.getValue(), numDims, numSymbols); 301 std::tie(d1, v1) = 302 categorizeValueByAffineType(context, rhs.getValue(), numDims, numSymbols); 303 SmallVector<Value *, 2> operands; 304 if (v0) { 305 operands.push_back(v0); 306 } 307 if (v1) { 308 operands.push_back(v1); 309 } 310 auto map = AffineMap::get(numDims, numSymbols, {affCombiner(d0, d1)}); 311 // TODO: createOrFold when available. 312 return ValueHandle::createComposedAffineApply(map, operands); 313 } 314 315 template <typename IOp, typename FOp> 316 static ValueHandle createBinaryHandle( 317 ValueHandle lhs, ValueHandle rhs, 318 llvm::function_ref<AffineExpr(AffineExpr, AffineExpr)> affCombiner) { 319 auto thisType = lhs.getValue()->getType(); 320 auto thatType = rhs.getValue()->getType(); 321 assert(thisType == thatType && "cannot mix types in operators"); 322 (void)thisType; 323 (void)thatType; 324 if (thisType.isIndex()) { 325 return createBinaryIndexHandle(lhs, rhs, affCombiner); 326 } else if (thisType.isa<IntegerType>()) { 327 return createBinaryHandle<IOp>(lhs, rhs); 328 } else if (thisType.isa<FloatType>()) { 329 return createBinaryHandle<FOp>(lhs, rhs); 330 } else if (thisType.isa<VectorType>() || thisType.isa<TensorType>()) { 331 auto aggregateType = thisType.cast<ShapedType>(); 332 if (aggregateType.getElementType().isa<IntegerType>()) 333 return createBinaryHandle<IOp>(lhs, rhs); 334 else if (aggregateType.getElementType().isa<FloatType>()) 335 return createBinaryHandle<FOp>(lhs, rhs); 336 } 337 llvm_unreachable("failed to create a ValueHandle"); 338 } 339 340 ValueHandle mlir::edsc::op::operator+(ValueHandle lhs, ValueHandle rhs) { 341 return createBinaryHandle<AddIOp, AddFOp>( 342 lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0 + d1; }); 343 } 344 345 ValueHandle mlir::edsc::op::operator-(ValueHandle lhs, ValueHandle rhs) { 346 return createBinaryHandle<SubIOp, SubFOp>( 347 lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0 - d1; }); 348 } 349 350 ValueHandle mlir::edsc::op::operator*(ValueHandle lhs, ValueHandle rhs) { 351 return createBinaryHandle<MulIOp, MulFOp>( 352 lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0 * d1; }); 353 } 354 355 ValueHandle mlir::edsc::op::operator/(ValueHandle lhs, ValueHandle rhs) { 356 return createBinaryHandle<DivISOp, DivFOp>( 357 lhs, rhs, [](AffineExpr d0, AffineExpr d1) -> AffineExpr { 358 llvm_unreachable("only exprs of non-index type support operator/"); 359 }); 360 } 361 362 ValueHandle mlir::edsc::op::operator%(ValueHandle lhs, ValueHandle rhs) { 363 return createBinaryHandle<RemISOp, RemFOp>( 364 lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0 % d1; }); 365 } 366 367 ValueHandle mlir::edsc::op::floorDiv(ValueHandle lhs, ValueHandle rhs) { 368 return createBinaryIndexHandle( 369 lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0.floorDiv(d1); }); 370 } 371 372 ValueHandle mlir::edsc::op::ceilDiv(ValueHandle lhs, ValueHandle rhs) { 373 return createBinaryIndexHandle( 374 lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0.ceilDiv(d1); }); 375 } 376 377 ValueHandle mlir::edsc::op::operator!(ValueHandle value) { 378 assert(value.getType().isInteger(1) && "expected boolean expression"); 379 return ValueHandle::create<ConstantIntOp>(1, 1) - value; 380 } 381 382 ValueHandle mlir::edsc::op::operator&&(ValueHandle lhs, ValueHandle rhs) { 383 assert(lhs.getType().isInteger(1) && "expected boolean expression on LHS"); 384 assert(rhs.getType().isInteger(1) && "expected boolean expression on RHS"); 385 return lhs * rhs; 386 } 387 388 ValueHandle mlir::edsc::op::operator||(ValueHandle lhs, ValueHandle rhs) { 389 return !(!lhs && !rhs); 390 } 391 392 static ValueHandle createIComparisonExpr(CmpIPredicate predicate, 393 ValueHandle lhs, ValueHandle rhs) { 394 auto lhsType = lhs.getType(); 395 auto rhsType = rhs.getType(); 396 (void)lhsType; 397 (void)rhsType; 398 assert(lhsType == rhsType && "cannot mix types in operators"); 399 assert((lhsType.isa<IndexType>() || lhsType.isa<IntegerType>()) && 400 "only integer comparisons are supported"); 401 402 auto op = ScopedContext::getBuilder().create<CmpIOp>( 403 ScopedContext::getLocation(), predicate, lhs.getValue(), rhs.getValue()); 404 return ValueHandle(op.getResult()); 405 } 406 407 static ValueHandle createFComparisonExpr(CmpFPredicate predicate, 408 ValueHandle lhs, ValueHandle rhs) { 409 auto lhsType = lhs.getType(); 410 auto rhsType = rhs.getType(); 411 (void)lhsType; 412 (void)rhsType; 413 assert(lhsType == rhsType && "cannot mix types in operators"); 414 assert(lhsType.isa<FloatType>() && "only float comparisons are supported"); 415 416 auto op = ScopedContext::getBuilder().create<CmpFOp>( 417 ScopedContext::getLocation(), predicate, lhs.getValue(), rhs.getValue()); 418 return ValueHandle(op.getResult()); 419 } 420 421 // All floating point comparison are ordered through EDSL 422 ValueHandle mlir::edsc::op::operator==(ValueHandle lhs, ValueHandle rhs) { 423 auto type = lhs.getType(); 424 return type.isa<FloatType>() 425 ? createFComparisonExpr(CmpFPredicate::OEQ, lhs, rhs) 426 : createIComparisonExpr(CmpIPredicate::EQ, lhs, rhs); 427 } 428 ValueHandle mlir::edsc::op::operator!=(ValueHandle lhs, ValueHandle rhs) { 429 auto type = lhs.getType(); 430 return type.isa<FloatType>() 431 ? createFComparisonExpr(CmpFPredicate::ONE, lhs, rhs) 432 : createIComparisonExpr(CmpIPredicate::NE, lhs, rhs); 433 } 434 ValueHandle mlir::edsc::op::operator<(ValueHandle lhs, ValueHandle rhs) { 435 auto type = lhs.getType(); 436 return type.isa<FloatType>() 437 ? createFComparisonExpr(CmpFPredicate::OLT, lhs, rhs) 438 : 439 // TODO(ntv,zinenko): signed by default, how about unsigned? 440 createIComparisonExpr(CmpIPredicate::SLT, lhs, rhs); 441 } 442 ValueHandle mlir::edsc::op::operator<=(ValueHandle lhs, ValueHandle rhs) { 443 auto type = lhs.getType(); 444 return type.isa<FloatType>() 445 ? createFComparisonExpr(CmpFPredicate::OLE, lhs, rhs) 446 : createIComparisonExpr(CmpIPredicate::SLE, lhs, rhs); 447 } 448 ValueHandle mlir::edsc::op::operator>(ValueHandle lhs, ValueHandle rhs) { 449 auto type = lhs.getType(); 450 return type.isa<FloatType>() 451 ? createFComparisonExpr(CmpFPredicate::OGT, lhs, rhs) 452 : createIComparisonExpr(CmpIPredicate::SGT, lhs, rhs); 453 } 454 ValueHandle mlir::edsc::op::operator>=(ValueHandle lhs, ValueHandle rhs) { 455 auto type = lhs.getType(); 456 return type.isa<FloatType>() 457 ? createFComparisonExpr(CmpFPredicate::OGE, lhs, rhs) 458 : createIComparisonExpr(CmpIPredicate::SGE, lhs, rhs); 459 }