github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/IR/Builders.cpp (about) 1 //===- Builders.cpp - Helpers for constructing MLIR Classes ---------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 18 #include "mlir/IR/Builders.h" 19 #include "mlir/IR/AffineExpr.h" 20 #include "mlir/IR/AffineMap.h" 21 #include "mlir/IR/Attributes.h" 22 #include "mlir/IR/IntegerSet.h" 23 #include "mlir/IR/Location.h" 24 #include "mlir/IR/Module.h" 25 #include "mlir/IR/StandardTypes.h" 26 #include "mlir/Support/Functional.h" 27 using namespace mlir; 28 29 Builder::Builder(ModuleOp module) : context(module.getContext()) {} 30 31 Identifier Builder::getIdentifier(StringRef str) { 32 return Identifier::get(str, context); 33 } 34 35 //===----------------------------------------------------------------------===// 36 // Locations. 37 //===----------------------------------------------------------------------===// 38 39 Location Builder::getUnknownLoc() { return UnknownLoc::get(context); } 40 41 Location Builder::getFileLineColLoc(Identifier filename, unsigned line, 42 unsigned column) { 43 return FileLineColLoc::get(filename, line, column, context); 44 } 45 46 Location Builder::getFusedLoc(ArrayRef<Location> locs, Attribute metadata) { 47 return FusedLoc::get(locs, metadata, context); 48 } 49 50 //===----------------------------------------------------------------------===// 51 // Types. 52 //===----------------------------------------------------------------------===// 53 54 FloatType Builder::getBF16Type() { return FloatType::getBF16(context); } 55 56 FloatType Builder::getF16Type() { return FloatType::getF16(context); } 57 58 FloatType Builder::getF32Type() { return FloatType::getF32(context); } 59 60 FloatType Builder::getF64Type() { return FloatType::getF64(context); } 61 62 IndexType Builder::getIndexType() { return IndexType::get(context); } 63 64 IntegerType Builder::getI1Type() { return IntegerType::get(1, context); } 65 66 IntegerType Builder::getIntegerType(unsigned width) { 67 return IntegerType::get(width, context); 68 } 69 70 FunctionType Builder::getFunctionType(ArrayRef<Type> inputs, 71 ArrayRef<Type> results) { 72 return FunctionType::get(inputs, results, context); 73 } 74 75 MemRefType Builder::getMemRefType(ArrayRef<int64_t> shape, Type elementType, 76 ArrayRef<AffineMap> affineMapComposition, 77 unsigned memorySpace) { 78 return MemRefType::get(shape, elementType, affineMapComposition, memorySpace); 79 } 80 81 VectorType Builder::getVectorType(ArrayRef<int64_t> shape, Type elementType) { 82 return VectorType::get(shape, elementType); 83 } 84 85 RankedTensorType Builder::getTensorType(ArrayRef<int64_t> shape, 86 Type elementType) { 87 return RankedTensorType::get(shape, elementType); 88 } 89 90 UnrankedTensorType Builder::getTensorType(Type elementType) { 91 return UnrankedTensorType::get(elementType); 92 } 93 94 TupleType Builder::getTupleType(ArrayRef<Type> elementTypes) { 95 return TupleType::get(elementTypes, context); 96 } 97 98 NoneType Builder::getNoneType() { return NoneType::get(context); } 99 100 //===----------------------------------------------------------------------===// 101 // Attributes. 102 //===----------------------------------------------------------------------===// 103 104 NamedAttribute Builder::getNamedAttr(StringRef name, Attribute val) { 105 return NamedAttribute(getIdentifier(name), val); 106 } 107 108 UnitAttr Builder::getUnitAttr() { return UnitAttr::get(context); } 109 110 BoolAttr Builder::getBoolAttr(bool value) { 111 return BoolAttr::get(value, context); 112 } 113 114 DictionaryAttr Builder::getDictionaryAttr(ArrayRef<NamedAttribute> value) { 115 return DictionaryAttr::get(value, context); 116 } 117 118 IntegerAttr Builder::getI64IntegerAttr(int64_t value) { 119 return IntegerAttr::get(getIntegerType(64), APInt(64, value)); 120 } 121 122 IntegerAttr Builder::getI32IntegerAttr(int32_t value) { 123 return IntegerAttr::get(getIntegerType(32), APInt(32, value)); 124 } 125 126 IntegerAttr Builder::getIntegerAttr(Type type, int64_t value) { 127 if (type.isIndex()) 128 return IntegerAttr::get(type, APInt(64, value)); 129 return IntegerAttr::get(type, APInt(type.getIntOrFloatBitWidth(), value)); 130 } 131 132 IntegerAttr Builder::getIntegerAttr(Type type, const APInt &value) { 133 return IntegerAttr::get(type, value); 134 } 135 136 FloatAttr Builder::getF64FloatAttr(double value) { 137 return FloatAttr::get(getF64Type(), APFloat(value)); 138 } 139 140 FloatAttr Builder::getF32FloatAttr(float value) { 141 return FloatAttr::get(getF32Type(), APFloat(value)); 142 } 143 144 FloatAttr Builder::getF16FloatAttr(float value) { 145 return FloatAttr::get(getF16Type(), value); 146 } 147 148 FloatAttr Builder::getFloatAttr(Type type, double value) { 149 return FloatAttr::get(type, value); 150 } 151 152 FloatAttr Builder::getFloatAttr(Type type, const APFloat &value) { 153 return FloatAttr::get(type, value); 154 } 155 156 StringAttr Builder::getStringAttr(StringRef bytes) { 157 return StringAttr::get(bytes, context); 158 } 159 160 StringAttr Builder::getStringAttr(StringRef bytes, Type type) { 161 return StringAttr::get(bytes, type); 162 } 163 164 ArrayAttr Builder::getArrayAttr(ArrayRef<Attribute> value) { 165 return ArrayAttr::get(value, context); 166 } 167 168 AffineMapAttr Builder::getAffineMapAttr(AffineMap map) { 169 return AffineMapAttr::get(map); 170 } 171 172 IntegerSetAttr Builder::getIntegerSetAttr(IntegerSet set) { 173 return IntegerSetAttr::get(set); 174 } 175 176 TypeAttr Builder::getTypeAttr(Type type) { return TypeAttr::get(type); } 177 178 SymbolRefAttr Builder::getSymbolRefAttr(Operation *value) { 179 auto symName = 180 value->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName()); 181 assert(symName && "value does not have a valid symbol name"); 182 return getSymbolRefAttr(symName.getValue()); 183 } 184 SymbolRefAttr Builder::getSymbolRefAttr(StringRef value) { 185 return SymbolRefAttr::get(value, getContext()); 186 } 187 188 ElementsAttr Builder::getDenseElementsAttr(ShapedType type, 189 ArrayRef<Attribute> values) { 190 return DenseElementsAttr::get(type, values); 191 } 192 193 ElementsAttr Builder::getDenseIntElementsAttr(ShapedType type, 194 ArrayRef<int64_t> values) { 195 return DenseIntElementsAttr::get(type, values); 196 } 197 198 ElementsAttr Builder::getSparseElementsAttr(ShapedType type, 199 DenseIntElementsAttr indices, 200 DenseElementsAttr values) { 201 return SparseElementsAttr::get(type, indices, values); 202 } 203 204 ElementsAttr Builder::getOpaqueElementsAttr(Dialect *dialect, ShapedType type, 205 StringRef bytes) { 206 return OpaqueElementsAttr::get(dialect, type, bytes); 207 } 208 209 ArrayAttr Builder::getI32ArrayAttr(ArrayRef<int32_t> values) { 210 auto attrs = functional::map( 211 [this](int32_t v) -> Attribute { return getI32IntegerAttr(v); }, values); 212 return getArrayAttr(attrs); 213 } 214 215 ArrayAttr Builder::getI64ArrayAttr(ArrayRef<int64_t> values) { 216 auto attrs = functional::map( 217 [this](int64_t v) -> Attribute { return getI64IntegerAttr(v); }, values); 218 return getArrayAttr(attrs); 219 } 220 221 ArrayAttr Builder::getIndexArrayAttr(ArrayRef<int64_t> values) { 222 auto attrs = functional::map( 223 [this](int64_t v) -> Attribute { 224 return getIntegerAttr(IndexType::get(getContext()), v); 225 }, 226 values); 227 return getArrayAttr(attrs); 228 } 229 230 ArrayAttr Builder::getF32ArrayAttr(ArrayRef<float> values) { 231 auto attrs = functional::map( 232 [this](float v) -> Attribute { return getF32FloatAttr(v); }, values); 233 return getArrayAttr(attrs); 234 } 235 236 ArrayAttr Builder::getF64ArrayAttr(ArrayRef<double> values) { 237 auto attrs = functional::map( 238 [this](double v) -> Attribute { return getF64FloatAttr(v); }, values); 239 return getArrayAttr(attrs); 240 } 241 242 ArrayAttr Builder::getStrArrayAttr(ArrayRef<StringRef> values) { 243 auto attrs = functional::map( 244 [this](StringRef v) -> Attribute { return getStringAttr(v); }, values); 245 return getArrayAttr(attrs); 246 } 247 248 ArrayAttr Builder::getAffineMapArrayAttr(ArrayRef<AffineMap> values) { 249 auto attrs = functional::map( 250 [this](AffineMap v) -> Attribute { return getAffineMapAttr(v); }, values); 251 return getArrayAttr(attrs); 252 } 253 254 Attribute Builder::getZeroAttr(Type type) { 255 switch (type.getKind()) { 256 case StandardTypes::F16: 257 return getF16FloatAttr(0); 258 case StandardTypes::F32: 259 return getF32FloatAttr(0); 260 case StandardTypes::F64: 261 return getF64FloatAttr(0); 262 case StandardTypes::Integer: { 263 auto width = type.cast<IntegerType>().getWidth(); 264 if (width == 1) 265 return getBoolAttr(false); 266 return getIntegerAttr(type, APInt(width, 0)); 267 } 268 case StandardTypes::Vector: 269 case StandardTypes::RankedTensor: { 270 auto vtType = type.cast<ShapedType>(); 271 auto element = getZeroAttr(vtType.getElementType()); 272 if (!element) 273 return {}; 274 return getDenseElementsAttr(vtType, element); 275 } 276 default: 277 break; 278 } 279 return {}; 280 } 281 282 //===----------------------------------------------------------------------===// 283 // Affine Expressions, Affine Maps, and Integet Sets. 284 //===----------------------------------------------------------------------===// 285 286 AffineMap Builder::getAffineMap(unsigned dimCount, unsigned symbolCount, 287 ArrayRef<AffineExpr> results) { 288 return AffineMap::get(dimCount, symbolCount, results); 289 } 290 291 AffineExpr Builder::getAffineDimExpr(unsigned position) { 292 return mlir::getAffineDimExpr(position, context); 293 } 294 295 AffineExpr Builder::getAffineSymbolExpr(unsigned position) { 296 return mlir::getAffineSymbolExpr(position, context); 297 } 298 299 AffineExpr Builder::getAffineConstantExpr(int64_t constant) { 300 return mlir::getAffineConstantExpr(constant, context); 301 } 302 303 IntegerSet Builder::getIntegerSet(unsigned dimCount, unsigned symbolCount, 304 ArrayRef<AffineExpr> constraints, 305 ArrayRef<bool> isEq) { 306 return IntegerSet::get(dimCount, symbolCount, constraints, isEq); 307 } 308 309 AffineMap Builder::getEmptyAffineMap() { return AffineMap::get(context); } 310 311 AffineMap Builder::getConstantAffineMap(int64_t val) { 312 return AffineMap::get(/*dimCount=*/0, /*symbolCount=*/0, 313 {getAffineConstantExpr(val)}); 314 } 315 316 AffineMap Builder::getDimIdentityMap() { 317 return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, 318 {getAffineDimExpr(0)}); 319 } 320 321 AffineMap Builder::getMultiDimIdentityMap(unsigned rank) { 322 SmallVector<AffineExpr, 4> dimExprs; 323 dimExprs.reserve(rank); 324 for (unsigned i = 0; i < rank; ++i) 325 dimExprs.push_back(getAffineDimExpr(i)); 326 return AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, dimExprs); 327 } 328 329 AffineMap Builder::getSymbolIdentityMap() { 330 return AffineMap::get(/*dimCount=*/0, /*symbolCount=*/1, 331 {getAffineSymbolExpr(0)}); 332 } 333 334 AffineMap Builder::getSingleDimShiftAffineMap(int64_t shift) { 335 // expr = d0 + shift. 336 auto expr = getAffineDimExpr(0) + shift; 337 return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, {expr}); 338 } 339 340 AffineMap Builder::getShiftedAffineMap(AffineMap map, int64_t shift) { 341 SmallVector<AffineExpr, 4> shiftedResults; 342 shiftedResults.reserve(map.getNumResults()); 343 for (auto resultExpr : map.getResults()) { 344 shiftedResults.push_back(resultExpr + shift); 345 } 346 return AffineMap::get(map.getNumDims(), map.getNumSymbols(), shiftedResults); 347 } 348 349 //===----------------------------------------------------------------------===// 350 // OpBuilder. 351 //===----------------------------------------------------------------------===// 352 353 OpBuilder::~OpBuilder() {} 354 355 /// Add new block and set the insertion point to the end of it. The block is 356 /// inserted at the provided insertion point of 'parent'. 357 Block *OpBuilder::createBlock(Region *parent, Region::iterator insertPt) { 358 assert(parent && "expected valid parent region"); 359 if (insertPt == Region::iterator()) 360 insertPt = parent->end(); 361 362 Block *b = new Block(); 363 parent->getBlocks().insert(insertPt, b); 364 setInsertionPointToEnd(b); 365 return b; 366 } 367 368 /// Add new block and set the insertion point to the end of it. The block is 369 /// placed before 'insertBefore'. 370 Block *OpBuilder::createBlock(Block *insertBefore) { 371 assert(insertBefore && "expected valid insertion block"); 372 return createBlock(insertBefore->getParent(), Region::iterator(insertBefore)); 373 } 374 375 /// Create an operation given the fields represented as an OperationState. 376 Operation *OpBuilder::createOperation(const OperationState &state) { 377 assert(block && "createOperation() called without setting builder's block"); 378 auto *op = Operation::create(state); 379 insert(op); 380 return op; 381 } 382 383 /// Attempts to fold the given operation and places new results within 384 /// 'results'. 385 void OpBuilder::tryFold(Operation *op, SmallVectorImpl<Value *> &results) { 386 results.reserve(op->getNumResults()); 387 SmallVector<OpFoldResult, 4> foldResults; 388 389 // Returns if the given fold result corresponds to a valid existing value. 390 auto isValidValue = [](OpFoldResult result) { 391 return result.dyn_cast<Value *>(); 392 }; 393 394 // Check if the fold failed, or did not result in only existing values. 395 SmallVector<Attribute, 4> constOperands(op->getNumOperands()); 396 if (failed(op->fold(constOperands, foldResults)) || foldResults.empty() || 397 !llvm::all_of(foldResults, isValidValue)) { 398 // Simply return the existing operation results. 399 results.assign(op->result_begin(), op->result_end()); 400 return; 401 } 402 403 // Populate the results with the folded results and remove the original op. 404 llvm::transform(foldResults, std::back_inserter(results), 405 [](OpFoldResult result) { return result.get<Value *>(); }); 406 op->erase(); 407 } 408 409 /// Insert the given operation at the current insertion point. 410 void OpBuilder::insert(Operation *op) { 411 if (block) 412 block->getOperations().insert(insertPoint, op); 413 }