github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Transforms/Utils/Utils.cpp

//===- Utils.cpp ---- Misc utilities for code and data transformation -----===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements miscellaneous transformation routines for non-loop IR
// structures.
//
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/Utils.h"

#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/AffineStructures.h"
#include "mlir/Analysis/Dominance.h"
#include "mlir/Analysis/Utils.h"
#include "mlir/Dialect/AffineOps/AffineOps.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Module.h"
#include "mlir/Support/MathExtras.h"
#include "llvm/ADT/DenseMap.h"

using namespace mlir;

/// Return true if this operation dereferences one or more memrefs.
// Temporary utility: will be replaced when this is modeled through
// side-effects/op traits. TODO(b/117228571)
static bool isMemRefDereferencingOp(Operation &op) {
  if (isa<AffineLoadOp>(op) || isa<AffineStoreOp>(op) ||
      isa<AffineDmaStartOp>(op) || isa<AffineDmaWaitOp>(op))
    return true;
  return false;
}

/// Return the AffineMapAttr associated with memory op 'op' on 'memref'.
static NamedAttribute getAffineMapAttrForMemRef(Operation *op, Value *memref) {
  if (auto loadOp = dyn_cast<AffineLoadOp>(op))
    return loadOp.getAffineMapAttrForMemRef(memref);
  else if (auto storeOp = dyn_cast<AffineStoreOp>(op))
    return storeOp.getAffineMapAttrForMemRef(memref);
  else if (auto dmaStart = dyn_cast<AffineDmaStartOp>(op))
    return dmaStart.getAffineMapAttrForMemRef(memref);
  assert(isa<AffineDmaWaitOp>(op));
  return cast<AffineDmaWaitOp>(op).getAffineMapAttrForMemRef(memref);
}

// Perform the replacement in 'op'.
LogicalResult mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef,
                                             Operation *op,
                                             ArrayRef<Value *> extraIndices,
                                             AffineMap indexRemap,
                                             ArrayRef<Value *> extraOperands) {
  unsigned newMemRefRank = newMemRef->getType().cast<MemRefType>().getRank();
  (void)newMemRefRank; // unused in opt mode
  unsigned oldMemRefRank = oldMemRef->getType().cast<MemRefType>().getRank();
  (void)oldMemRefRank;
  if (indexRemap) {
    assert(indexRemap.getNumSymbols() == 0 && "pure dimensional map expected");
    assert(indexRemap.getNumInputs() == extraOperands.size() + oldMemRefRank);
    assert(indexRemap.getNumResults() + extraIndices.size() == newMemRefRank);
  } else {
    assert(oldMemRefRank + extraIndices.size() == newMemRefRank);
  }

  // Assert same elemental type.
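  // (The new memref may differ from the old one in shape, layout map, or
  // memory space, e.g. memref<64xf32> replaced by memref<4x16xf32, 1>, but
  // the element type must be identical.)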
  assert(oldMemRef->getType().cast<MemRefType>().getElementType() ==
         newMemRef->getType().cast<MemRefType>().getElementType());

  if (!isMemRefDereferencingOp(*op))
    // Failure: memref used in a non-dereferencing context (potentially
    // escapes); no replacement in these cases.
    return failure();

  SmallVector<unsigned, 2> usePositions;
  for (const auto &opEntry : llvm::enumerate(op->getOperands())) {
    if (opEntry.value() == oldMemRef)
      usePositions.push_back(opEntry.index());
  }

  // If memref doesn't appear, nothing to do.
  if (usePositions.empty())
    return success();

  if (usePositions.size() > 1) {
    // TODO(mlir-team): extend it for this case when needed (rare).
    assert(false && "multiple dereferencing uses in a single op not supported");
    return failure();
  }

  unsigned memRefOperandPos = usePositions.front();

  OpBuilder builder(op);
  NamedAttribute oldMapAttrPair = getAffineMapAttrForMemRef(op, oldMemRef);
  AffineMap oldMap = oldMapAttrPair.second.cast<AffineMapAttr>().getValue();
  unsigned oldMapNumInputs = oldMap.getNumInputs();
  SmallVector<Value *, 4> oldMapOperands(
      op->operand_begin() + memRefOperandPos + 1,
      op->operand_begin() + memRefOperandPos + 1 + oldMapNumInputs);

  // Apply 'oldMemRefOperands = oldMap(oldMapOperands)'.
  SmallVector<Value *, 4> oldMemRefOperands;
  SmallVector<Value *, 4> affineApplyOps;
  oldMemRefOperands.reserve(oldMemRefRank);
  if (oldMap != builder.getMultiDimIdentityMap(oldMap.getNumDims())) {
    for (auto resultExpr : oldMap.getResults()) {
      auto singleResMap = builder.getAffineMap(
          oldMap.getNumDims(), oldMap.getNumSymbols(), resultExpr);
      auto afOp = builder.create<AffineApplyOp>(op->getLoc(), singleResMap,
                                                oldMapOperands);
      oldMemRefOperands.push_back(afOp);
      affineApplyOps.push_back(afOp);
    }
  } else {
    oldMemRefOperands.append(oldMapOperands.begin(), oldMapOperands.end());
  }

  // Construct new indices as a remap of the old ones if a remapping has been
  // provided. The indices of a memref come right after it, i.e.,
  // at position memRefOperandPos + 1.
  SmallVector<Value *, 4> remapOperands;
  remapOperands.reserve(extraOperands.size() + oldMemRefRank);
  remapOperands.append(extraOperands.begin(), extraOperands.end());
  remapOperands.append(oldMemRefOperands.begin(), oldMemRefOperands.end());

  SmallVector<Value *, 4> remapOutputs;
  remapOutputs.reserve(oldMemRefRank);

  if (indexRemap &&
      indexRemap != builder.getMultiDimIdentityMap(indexRemap.getNumDims())) {
    // Remapped indices.
    for (auto resultExpr : indexRemap.getResults()) {
      auto singleResMap = builder.getAffineMap(
          indexRemap.getNumDims(), indexRemap.getNumSymbols(), resultExpr);
      auto afOp = builder.create<AffineApplyOp>(op->getLoc(), singleResMap,
                                                remapOperands);
      remapOutputs.push_back(afOp);
      affineApplyOps.push_back(afOp);
    }
  } else {
    // No remapping specified.
    remapOutputs.append(remapOperands.begin(), remapOperands.end());
  }

  SmallVector<Value *, 4> newMapOperands;
  newMapOperands.reserve(newMemRefRank);

  // Prepend 'extraIndices' in 'newMapOperands'.
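  // The new access indices thus end up as
  // (extraIndices..., indexRemap(extraOperands..., oldMap(oldIndexOperands)...)),
  // consistent with the rank assertions at the top of this function.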
  for (auto *extraIndex : extraIndices) {
    assert(extraIndex->getDefiningOp()->getNumResults() == 1 &&
           "single result op's expected to generate these indices");
    assert((isValidDim(extraIndex) || isValidSymbol(extraIndex)) &&
           "invalid memory op index");
    newMapOperands.push_back(extraIndex);
  }

  // Append 'remapOutputs' to 'newMapOperands'.
  newMapOperands.append(remapOutputs.begin(), remapOutputs.end());

  // Create new fully composed AffineMap for new op to be created.
  assert(newMapOperands.size() == newMemRefRank);
  auto newMap = builder.getMultiDimIdentityMap(newMemRefRank);
  // TODO(b/136262594) Avoid creating/deleting temporary AffineApplyOps here.
  fullyComposeAffineMapAndOperands(&newMap, &newMapOperands);
  newMap = simplifyAffineMap(newMap);
  canonicalizeMapAndOperands(&newMap, &newMapOperands);
  // Remove any affine.apply's that became dead as a result of composition.
  for (auto *value : affineApplyOps)
    if (value->use_empty())
      value->getDefiningOp()->erase();

  // Construct the new operation using this memref.
  OperationState state(op->getLoc(), op->getName());
  state.setOperandListToResizable(op->hasResizableOperandsList());
  state.operands.reserve(op->getNumOperands() + extraIndices.size());
  // Insert the non-memref operands.
  state.operands.append(op->operand_begin(),
                        op->operand_begin() + memRefOperandPos);
  // Insert the new memref value.
  state.operands.push_back(newMemRef);

  // Insert the new memref map operands.
  state.operands.append(newMapOperands.begin(), newMapOperands.end());

  // Insert the remaining operands unmodified.
  state.operands.append(op->operand_begin() + memRefOperandPos + 1 +
                            oldMapNumInputs,
                        op->operand_end());

  // Result types don't change. Both memref's are of the same elemental type.
  state.types.reserve(op->getNumResults());
  for (auto *result : op->getResults())
    state.types.push_back(result->getType());

  // Add attribute for 'newMap', other Attributes do not change.
  auto newMapAttr = builder.getAffineMapAttr(newMap);
  for (auto namedAttr : op->getAttrs()) {
    if (namedAttr.first == oldMapAttrPair.first) {
      state.attributes.push_back({namedAttr.first, newMapAttr});
    } else {
      state.attributes.push_back(namedAttr);
    }
  }

  // Create the new operation.
  auto *repOp = builder.createOperation(state);
  op->replaceAllUsesWith(repOp);
  op->erase();

  return success();
}

LogicalResult mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef,
                                             ArrayRef<Value *> extraIndices,
                                             AffineMap indexRemap,
                                             ArrayRef<Value *> extraOperands,
                                             Operation *domInstFilter,
                                             Operation *postDomInstFilter) {
  unsigned newMemRefRank = newMemRef->getType().cast<MemRefType>().getRank();
  (void)newMemRefRank; // unused in opt mode
  unsigned oldMemRefRank = oldMemRef->getType().cast<MemRefType>().getRank();
  (void)oldMemRefRank;
  if (indexRemap) {
    assert(indexRemap.getNumSymbols() == 0 && "pure dimensional map expected");
    assert(indexRemap.getNumInputs() == extraOperands.size() + oldMemRefRank);
    assert(indexRemap.getNumResults() + extraIndices.size() == newMemRefRank);
  } else {
    assert(oldMemRefRank + extraIndices.size() == newMemRefRank);
  }

  // Assert same elemental type.
  assert(oldMemRef->getType().cast<MemRefType>().getElementType() ==
         newMemRef->getType().cast<MemRefType>().getElementType());

  std::unique_ptr<DominanceInfo> domInfo;
  std::unique_ptr<PostDominanceInfo> postDomInfo;
  if (domInstFilter)
    domInfo = std::make_unique<DominanceInfo>(
        domInstFilter->getParentOfType<FuncOp>());

  if (postDomInstFilter)
    postDomInfo = std::make_unique<PostDominanceInfo>(
        postDomInstFilter->getParentOfType<FuncOp>());

  // Walk all uses of old memref; collect ops to perform replacement. We use a
  // DenseSet since an operation could potentially have multiple uses of a
  // memref (although rare), and the replacement later is going to erase ops.
  DenseSet<Operation *> opsToReplace;
  for (auto *op : oldMemRef->getUsers()) {
    // Skip this use if it's not dominated by domInstFilter.
    if (domInstFilter && !domInfo->dominates(domInstFilter, op))
      continue;

    // Skip this use if it's not post-dominated by postDomInstFilter.
    if (postDomInstFilter && !postDomInfo->postDominates(postDomInstFilter, op))
      continue;

    // Skip dealloc's - no replacement is necessary, and a memref replacement
    // at other uses doesn't hurt these dealloc's.
    if (isa<DeallocOp>(op))
      continue;

    // Check if the memref was used in a non-dereferencing context. It is fine
    // for the memref to be used in a non-dereferencing way outside of the
    // region where this replacement is happening.
    if (!isMemRefDereferencingOp(*op))
      // Failure: memref used in a non-dereferencing op (potentially escapes);
      // no replacement in these cases.
      return failure();

    // We'll first collect and then replace --- since replacement erases the op
    // that has the use, and that op could be postDomFilter or domFilter itself!
    opsToReplace.insert(op);
  }

  for (auto *op : opsToReplace) {
    if (failed(replaceAllMemRefUsesWith(oldMemRef, newMemRef, op, extraIndices,
                                        indexRemap, extraOperands)))
      assert(false && "memref replacement guaranteed to succeed here");
  }

  return success();
}

/// Given an operation, inserts one or more single-result affine apply
/// operations, results of which are exclusively used by this operation. The
/// operands of these newly created affine apply ops are guaranteed to be loop
/// iterators or terminal symbols of a function.
///
/// Before
///
/// affine.for %i = 0 to #map(%N)
///   %idx = affine.apply (d0) -> (d0 mod 2) (%i)
///   "send"(%idx, %A, ...)
///   "compute"(%idx)
///
/// After
///
/// affine.for %i = 0 to #map(%N)
///   %idx = affine.apply (d0) -> (d0 mod 2) (%i)
///   "send"(%idx, %A, ...)
///   %idx_ = affine.apply (d0) -> (d0 mod 2) (%i)
///   "compute"(%idx_)
///
/// This allows applying different transformations on send and compute (e.g.
/// different shifts/delays).
///
/// Returns with `sliceOps` empty either if none of opInst's operands were the
/// result of an affine.apply (and thus there was no affine computation slice
/// to create), or if all the affine.apply ops supplying operands to opInst had
/// no uses other than opInst itself; otherwise the affine.apply operations
/// created are returned in the output argument `sliceOps`.
void mlir::createAffineComputationSlice(
    Operation *opInst, SmallVectorImpl<AffineApplyOp> *sliceOps) {
  // Collect all operands that are results of affine apply ops.
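  // Operands defined by anything other than an affine.apply (e.g. loop IVs,
  // function arguments, or non-affine ops) are left untouched.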
  SmallVector<Value *, 4> subOperands;
  subOperands.reserve(opInst->getNumOperands());
  for (auto *operand : opInst->getOperands())
    if (isa_and_nonnull<AffineApplyOp>(operand->getDefiningOp()))
      subOperands.push_back(operand);

  // Gather sequence of AffineApplyOps reachable from 'subOperands'.
  SmallVector<Operation *, 4> affineApplyOps;
  getReachableAffineApplyOps(subOperands, affineApplyOps);
  // Skip transforming if there are no affine maps to compose.
  if (affineApplyOps.empty())
    return;

  // Check if all uses of the affine apply ops lie only in this op, in which
  // case there would be nothing to do.
  bool localized = true;
  for (auto *op : affineApplyOps) {
    for (auto *result : op->getResults()) {
      for (auto *user : result->getUsers()) {
        if (user != opInst) {
          localized = false;
          break;
        }
      }
    }
  }
  if (localized)
    return;

  OpBuilder builder(opInst);
  SmallVector<Value *, 4> composedOpOperands(subOperands);
  auto composedMap = builder.getMultiDimIdentityMap(composedOpOperands.size());
  fullyComposeAffineMapAndOperands(&composedMap, &composedOpOperands);

  // Create an affine.apply for each of the map results.
  sliceOps->reserve(composedMap.getNumResults());
  for (auto resultExpr : composedMap.getResults()) {
    auto singleResMap = builder.getAffineMap(
        composedMap.getNumDims(), composedMap.getNumSymbols(), resultExpr);
    sliceOps->push_back(builder.create<AffineApplyOp>(
        opInst->getLoc(), singleResMap, composedOpOperands));
  }

  // Construct the new operands that include the results of the composed
  // affine apply ops above in place of the existing ones (subOperands). They
  // differ from opInst's operands only for the operands in 'subOperands',
  // which are replaced by the corresponding results from 'sliceOps'.
  SmallVector<Value *, 4> newOperands(opInst->getOperands());
  for (unsigned i = 0, e = newOperands.size(); i < e; i++) {
    // Replace the subOperands from among the new operands.
    unsigned j, f;
    for (j = 0, f = subOperands.size(); j < f; j++) {
      if (newOperands[i] == subOperands[j])
        break;
    }
    if (j < subOperands.size()) {
      newOperands[i] = (*sliceOps)[j];
    }
  }
  for (unsigned idx = 0, e = newOperands.size(); idx < e; idx++) {
    opInst->setOperand(idx, newOperands[idx]);
  }
}
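
// Illustrative sketch, not part of the original utility set: a minimal wrapper
// showing how `createAffineComputationSlice` might be driven from a transform.
// The helper name and its use of a bare `Operation *` are assumptions made for
// this example only.
static bool localizeAffineIndexComputations(Operation *op) {
  SmallVector<AffineApplyOp, 4> sliceOps;
  // Give `op` its own copies of the affine.apply computations feeding its
  // operands so they can be shifted/delayed independently of other users.
  createAffineComputationSlice(op, &sliceOps);
  // `sliceOps` stays empty when no operand came from an affine.apply or when
  // the feeding affine.apply ops were already used only by `op`.
  return !sliceOps.empty();
}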