github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Analysis/AffineAnalysis.cpp (about) 1 //===- AffineAnalysis.cpp - Affine structures analysis routines -----------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 // 18 // This file implements miscellaneous analysis routines for affine structures 19 // (expressions, maps, sets), and other utilities relying on such analysis. 
//
//===----------------------------------------------------------------------===//

#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/AffineStructures.h"
#include "mlir/Analysis/Utils.h"
#include "mlir/Dialect/AffineOps/AffineOps.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Operation.h"
#include "mlir/Support/MathExtras.h"
#include "mlir/Support/STLExtras.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "affine-analysis"

using namespace mlir;

using llvm::dbgs;

/// Returns in 'affineApplyOps' the sequence of AffineApplyOp Operations which
/// are reachable via a search starting from 'operands', and ending at operands
/// which are not defined by AffineApplyOps.
// TODO(andydavis) Add a method to AffineApplyOp which forward substitutes
// the AffineApplyOp into any user AffineApplyOps.
void mlir::getReachableAffineApplyOps(
    ArrayRef<Value *> operands, SmallVectorImpl<Operation *> &affineApplyOps) {
  // Explicit DFS state: the traversal is iterative (stack of State) rather
  // than recursive, so deep affine.apply chains cannot overflow the C++ stack.
  struct State {
    // The ssa value for this node in the DFS traversal.
    Value *value;
    // The operand index of 'value' to explore next during DFS traversal.
    unsigned operandIndex;
  };
  SmallVector<State, 4> worklist;
  for (auto *operand : operands) {
    worklist.push_back({operand, 0});
  }

  while (!worklist.empty()) {
    State &state = worklist.back();
    auto *opInst = state.value->getDefiningOp();
    // Note: getDefiningOp will return nullptr if the operand is not an
    // Operation (i.e. block argument), which is a terminator for the search.
    if (!isa_and_nonnull<AffineApplyOp>(opInst)) {
      worklist.pop_back();
      continue;
    }

    if (state.operandIndex == 0) {
      // Pre-Visit: Add 'opInst' to reachable sequence.
      affineApplyOps.push_back(opInst);
    }
    if (state.operandIndex < opInst->getNumOperands()) {
      // Visit: Add next 'affineApplyOp' operand to worklist.
      // Get next operand to visit at 'operandIndex'.
      auto *nextOperand = opInst->getOperand(state.operandIndex);
      // Increment 'operandIndex' in 'state'.
      ++state.operandIndex;
      // Add 'nextOperand' to worklist.
      worklist.push_back({nextOperand, 0});
    } else {
      // Post-visit: done visiting operands of this AffineApplyOp, pop off
      // stack.
      worklist.pop_back();
    }
  }
}

// Builds a system of constraints with dimensional identifiers corresponding to
// the loop IVs of the forOps appearing in that order. Any symbols found in
// the bound operands are added as symbols in the system. Returns failure for
// the yet unimplemented cases.
// TODO(andydavis,bondhugula) Handle non-unit steps through local variables or
// stride information in FlatAffineConstraints. (For eg., by using iv - lb %
// step = 0 and/or by introducing a method in FlatAffineConstraints
// setExprStride(ArrayRef<int64_t> expr, int64_t stride)
LogicalResult mlir::getIndexSet(MutableArrayRef<AffineForOp> forOps,
                                FlatAffineConstraints *domain) {
  SmallVector<Value *, 4> indices;
  extractForInductionVars(forOps, &indices);
  // Reset the domain: one dimension per loop IV in 'indices'; symbols are
  // added below when each loop's bound constraints are merged in.
  domain->reset(forOps.size(), /*numSymbols=*/0, /*numLocals=*/0, indices);
  for (auto forOp : forOps) {
    // Add constraints from forOp's bounds.
    if (failed(domain->addAffineForOpDomain(forOp)))
      return failure();
  }
  return success();
}

// Computes the iteration domain for 'op' and populates 'indexSet', which
// encapsulates the constraints involving loops surrounding 'op' and
// potentially involving any Function symbols. The dimensional identifiers in
// 'indexSet' correspond to the loops surrounding 'op' from outermost to
// innermost.
// TODO(andydavis) Add support to handle IfInsts surrounding 'op'.
static LogicalResult getInstIndexSet(Operation *op,
                                     FlatAffineConstraints *indexSet) {
  // TODO(andydavis) Extend this to gather enclosing IfInsts and consider
  // factoring it out into a utility function.
  SmallVector<AffineForOp, 4> loops;
  getLoopIVs(*op, &loops);
  return getIndexSet(loops, indexSet);
}

// ValuePositionMap manages the mapping from Values which represent dimension
// and symbol identifiers from 'src' and 'dst' access functions to positions
// in new space where some Values are kept separate (using addSrc/DstValue)
// and some Values are merged (addSymbolValue).
// Position lookups return the absolute position in the new space which
// has the following format:
//
// [src-dim-identifiers] [dst-dim-identifiers] [symbol-identifiers]
//
// Note: access function non-IV dimension identifiers (that have 'dimension'
// positions in the access function position space) are assigned as symbols
// in the output position space. Convenience access functions which lookup
// an Value in multiple maps are provided (i.e. getSrcDimOrSymPos) to handle
// the common case of resolving positions for all access function operands.
//
// TODO(andydavis) Generalize this: could take a template parameter for
// the number of maps (3 in the current case), and lookups could take indices
// of maps to check. So getSrcDimOrSymPos would be "getPos(value, {0, 2})".
147 class ValuePositionMap { 148 public: 149 void addSrcValue(Value *value) { 150 if (addValueAt(value, &srcDimPosMap, numSrcDims)) 151 ++numSrcDims; 152 } 153 void addDstValue(Value *value) { 154 if (addValueAt(value, &dstDimPosMap, numDstDims)) 155 ++numDstDims; 156 } 157 void addSymbolValue(Value *value) { 158 if (addValueAt(value, &symbolPosMap, numSymbols)) 159 ++numSymbols; 160 } 161 unsigned getSrcDimOrSymPos(Value *value) const { 162 return getDimOrSymPos(value, srcDimPosMap, 0); 163 } 164 unsigned getDstDimOrSymPos(Value *value) const { 165 return getDimOrSymPos(value, dstDimPosMap, numSrcDims); 166 } 167 unsigned getSymPos(Value *value) const { 168 auto it = symbolPosMap.find(value); 169 assert(it != symbolPosMap.end()); 170 return numSrcDims + numDstDims + it->second; 171 } 172 173 unsigned getNumSrcDims() const { return numSrcDims; } 174 unsigned getNumDstDims() const { return numDstDims; } 175 unsigned getNumDims() const { return numSrcDims + numDstDims; } 176 unsigned getNumSymbols() const { return numSymbols; } 177 178 private: 179 bool addValueAt(Value *value, DenseMap<Value *, unsigned> *posMap, 180 unsigned position) { 181 auto it = posMap->find(value); 182 if (it == posMap->end()) { 183 (*posMap)[value] = position; 184 return true; 185 } 186 return false; 187 } 188 unsigned getDimOrSymPos(Value *value, 189 const DenseMap<Value *, unsigned> &dimPosMap, 190 unsigned dimPosOffset) const { 191 auto it = dimPosMap.find(value); 192 if (it != dimPosMap.end()) { 193 return dimPosOffset + it->second; 194 } 195 it = symbolPosMap.find(value); 196 assert(it != symbolPosMap.end()); 197 return numSrcDims + numDstDims + it->second; 198 } 199 200 unsigned numSrcDims = 0; 201 unsigned numDstDims = 0; 202 unsigned numSymbols = 0; 203 DenseMap<Value *, unsigned> srcDimPosMap; 204 DenseMap<Value *, unsigned> dstDimPosMap; 205 DenseMap<Value *, unsigned> symbolPosMap; 206 }; 207 208 // Builds a map from Value to identifier position in a new merged identifier 209 // 
list, which is the result of merging dim/symbol lists from src/dst 210 // iteration domains, the format of which is as follows: 211 // 212 // [src-dim-identifiers, dst-dim-identifiers, symbol-identifiers, const_term] 213 // 214 // This method populates 'valuePosMap' with mappings from operand Values in 215 // 'srcAccessMap'/'dstAccessMap' (as well as those in 'srcDomain'/'dstDomain') 216 // to the position of these values in the merged list. 217 static void buildDimAndSymbolPositionMaps( 218 const FlatAffineConstraints &srcDomain, 219 const FlatAffineConstraints &dstDomain, const AffineValueMap &srcAccessMap, 220 const AffineValueMap &dstAccessMap, ValuePositionMap *valuePosMap, 221 FlatAffineConstraints *dependenceConstraints) { 222 auto updateValuePosMap = [&](ArrayRef<Value *> values, bool isSrc) { 223 for (unsigned i = 0, e = values.size(); i < e; ++i) { 224 auto *value = values[i]; 225 if (!isForInductionVar(values[i])) { 226 assert(isValidSymbol(values[i]) && 227 "access operand has to be either a loop IV or a symbol"); 228 valuePosMap->addSymbolValue(value); 229 } else if (isSrc) { 230 valuePosMap->addSrcValue(value); 231 } else { 232 valuePosMap->addDstValue(value); 233 } 234 } 235 }; 236 237 SmallVector<Value *, 4> srcValues, destValues; 238 srcDomain.getIdValues(0, srcDomain.getNumDimAndSymbolIds(), &srcValues); 239 dstDomain.getIdValues(0, dstDomain.getNumDimAndSymbolIds(), &destValues); 240 // Update value position map with identifiers from src iteration domain. 241 updateValuePosMap(srcValues, /*isSrc=*/true); 242 // Update value position map with identifiers from dst iteration domain. 243 updateValuePosMap(destValues, /*isSrc=*/false); 244 // Update value position map with identifiers from src access function. 245 updateValuePosMap(srcAccessMap.getOperands(), /*isSrc=*/true); 246 // Update value position map with identifiers from dst access function. 
247 updateValuePosMap(dstAccessMap.getOperands(), /*isSrc=*/false); 248 } 249 250 // Sets up dependence constraints columns appropriately, in the format: 251 // [src-dim-ids, dst-dim-ids, symbol-ids, local-ids, const_term] 252 void initDependenceConstraints(const FlatAffineConstraints &srcDomain, 253 const FlatAffineConstraints &dstDomain, 254 const AffineValueMap &srcAccessMap, 255 const AffineValueMap &dstAccessMap, 256 const ValuePositionMap &valuePosMap, 257 FlatAffineConstraints *dependenceConstraints) { 258 // Calculate number of equalities/inequalities and columns required to 259 // initialize FlatAffineConstraints for 'dependenceDomain'. 260 unsigned numIneq = 261 srcDomain.getNumInequalities() + dstDomain.getNumInequalities(); 262 AffineMap srcMap = srcAccessMap.getAffineMap(); 263 assert(srcMap.getNumResults() == dstAccessMap.getAffineMap().getNumResults()); 264 unsigned numEq = srcMap.getNumResults(); 265 unsigned numDims = srcDomain.getNumDimIds() + dstDomain.getNumDimIds(); 266 unsigned numSymbols = valuePosMap.getNumSymbols(); 267 unsigned numLocals = srcDomain.getNumLocalIds() + dstDomain.getNumLocalIds(); 268 unsigned numIds = numDims + numSymbols + numLocals; 269 unsigned numCols = numIds + 1; 270 271 // Set flat affine constraints sizes and reserving space for constraints. 272 dependenceConstraints->reset(numIneq, numEq, numCols, numDims, numSymbols, 273 numLocals); 274 275 // Set values corresponding to dependence constraint identifiers. 276 SmallVector<Value *, 4> srcLoopIVs, dstLoopIVs; 277 srcDomain.getIdValues(0, srcDomain.getNumDimIds(), &srcLoopIVs); 278 dstDomain.getIdValues(0, dstDomain.getNumDimIds(), &dstLoopIVs); 279 280 dependenceConstraints->setIdValues(0, srcLoopIVs.size(), srcLoopIVs); 281 dependenceConstraints->setIdValues( 282 srcLoopIVs.size(), srcLoopIVs.size() + dstLoopIVs.size(), dstLoopIVs); 283 284 // Set values for the symbolic identifier dimensions. 
285 auto setSymbolIds = [&](ArrayRef<Value *> values) { 286 for (auto *value : values) { 287 if (!isForInductionVar(value)) { 288 assert(isValidSymbol(value) && "expected symbol"); 289 dependenceConstraints->setIdValue(valuePosMap.getSymPos(value), value); 290 } 291 } 292 }; 293 294 setSymbolIds(srcAccessMap.getOperands()); 295 setSymbolIds(dstAccessMap.getOperands()); 296 297 SmallVector<Value *, 8> srcSymbolValues, dstSymbolValues; 298 srcDomain.getIdValues(srcDomain.getNumDimIds(), 299 srcDomain.getNumDimAndSymbolIds(), &srcSymbolValues); 300 dstDomain.getIdValues(dstDomain.getNumDimIds(), 301 dstDomain.getNumDimAndSymbolIds(), &dstSymbolValues); 302 setSymbolIds(srcSymbolValues); 303 setSymbolIds(dstSymbolValues); 304 305 for (unsigned i = 0, e = dependenceConstraints->getNumDimAndSymbolIds(); 306 i < e; i++) 307 assert(dependenceConstraints->getIds()[i].hasValue()); 308 } 309 310 // Adds iteration domain constraints from 'srcDomain' and 'dstDomain' into 311 // 'dependenceDomain'. 312 // Uses 'valuePosMap' to determine the position in 'dependenceDomain' to which a 313 // srcDomain/dstDomain Value maps. 314 static void addDomainConstraints(const FlatAffineConstraints &srcDomain, 315 const FlatAffineConstraints &dstDomain, 316 const ValuePositionMap &valuePosMap, 317 FlatAffineConstraints *dependenceDomain) { 318 unsigned depNumDimsAndSymbolIds = dependenceDomain->getNumDimAndSymbolIds(); 319 320 SmallVector<int64_t, 4> cst(dependenceDomain->getNumCols()); 321 322 auto addDomain = [&](bool isSrc, bool isEq, unsigned localOffset) { 323 const FlatAffineConstraints &domain = isSrc ? srcDomain : dstDomain; 324 unsigned numCsts = 325 isEq ? domain.getNumEqualities() : domain.getNumInequalities(); 326 unsigned numDimAndSymbolIds = domain.getNumDimAndSymbolIds(); 327 auto at = [&](unsigned i, unsigned j) -> int64_t { 328 return isEq ? domain.atEq(i, j) : domain.atIneq(i, j); 329 }; 330 auto map = [&](unsigned i) -> int64_t { 331 return isSrc ? 
valuePosMap.getSrcDimOrSymPos(domain.getIdValue(i)) 332 : valuePosMap.getDstDimOrSymPos(domain.getIdValue(i)); 333 }; 334 335 for (unsigned i = 0; i < numCsts; ++i) { 336 // Zero fill. 337 std::fill(cst.begin(), cst.end(), 0); 338 // Set coefficients for identifiers corresponding to domain. 339 for (unsigned j = 0; j < numDimAndSymbolIds; ++j) 340 cst[map(j)] = at(i, j); 341 // Local terms. 342 for (unsigned j = 0, e = domain.getNumLocalIds(); j < e; j++) 343 cst[depNumDimsAndSymbolIds + localOffset + j] = 344 at(i, numDimAndSymbolIds + j); 345 // Set constant term. 346 cst[cst.size() - 1] = at(i, domain.getNumCols() - 1); 347 // Add constraint. 348 if (isEq) 349 dependenceDomain->addEquality(cst); 350 else 351 dependenceDomain->addInequality(cst); 352 } 353 }; 354 355 // Add equalities from src domain. 356 addDomain(/*isSrc=*/true, /*isEq=*/true, /*localOffset=*/0); 357 // Add inequalities from src domain. 358 addDomain(/*isSrc=*/true, /*isEq=*/false, /*localOffset=*/0); 359 // Add equalities from dst domain. 360 addDomain(/*isSrc=*/false, /*isEq=*/true, 361 /*localOffset=*/srcDomain.getNumLocalIds()); 362 // Add inequalities from dst domain. 363 addDomain(/*isSrc=*/false, /*isEq=*/false, 364 /*localOffset=*/srcDomain.getNumLocalIds()); 365 } 366 367 // Adds equality constraints that equate src and dst access functions 368 // represented by 'srcAccessMap' and 'dstAccessMap' for each result. 369 // Requires that 'srcAccessMap' and 'dstAccessMap' have the same results count. 370 // For example, given the following two accesses functions to a 2D memref: 371 // 372 // Source access function: 373 // (a0 * d0 + a1 * s0 + a2, b0 * d0 + b1 * s0 + b2) 374 // 375 // Destination acceses function: 376 // (c0 * d0 + c1 * s0 + c2, f0 * d0 + f1 * s0 + f2) 377 // 378 // This method constructs the following equality constraints in 379 // 'dependenceDomain', by equating the access functions for each result 380 // (i.e. each memref dim). 
Notice that 'd0' for the destination access function 381 // is mapped into 'd0' in the equality constraint: 382 // 383 // d0 d1 s0 c 384 // -- -- -- -- 385 // a0 -c0 (a1 - c1) (a1 - c2) = 0 386 // b0 -f0 (b1 - f1) (b1 - f2) = 0 387 // 388 // Returns failure if any AffineExpr cannot be flattened (due to it being 389 // semi-affine). Returns success otherwise. 390 static LogicalResult 391 addMemRefAccessConstraints(const AffineValueMap &srcAccessMap, 392 const AffineValueMap &dstAccessMap, 393 const ValuePositionMap &valuePosMap, 394 FlatAffineConstraints *dependenceDomain) { 395 AffineMap srcMap = srcAccessMap.getAffineMap(); 396 AffineMap dstMap = dstAccessMap.getAffineMap(); 397 assert(srcMap.getNumResults() == dstMap.getNumResults()); 398 unsigned numResults = srcMap.getNumResults(); 399 400 unsigned srcNumIds = srcMap.getNumDims() + srcMap.getNumSymbols(); 401 ArrayRef<Value *> srcOperands = srcAccessMap.getOperands(); 402 403 unsigned dstNumIds = dstMap.getNumDims() + dstMap.getNumSymbols(); 404 ArrayRef<Value *> dstOperands = dstAccessMap.getOperands(); 405 406 std::vector<SmallVector<int64_t, 8>> srcFlatExprs; 407 std::vector<SmallVector<int64_t, 8>> destFlatExprs; 408 FlatAffineConstraints srcLocalVarCst, destLocalVarCst; 409 // Get flattened expressions for the source destination maps. 
410 if (failed(getFlattenedAffineExprs(srcMap, &srcFlatExprs, &srcLocalVarCst)) || 411 failed(getFlattenedAffineExprs(dstMap, &destFlatExprs, &destLocalVarCst))) 412 return failure(); 413 414 unsigned domNumLocalIds = dependenceDomain->getNumLocalIds(); 415 unsigned srcNumLocalIds = srcLocalVarCst.getNumLocalIds(); 416 unsigned dstNumLocalIds = destLocalVarCst.getNumLocalIds(); 417 unsigned numLocalIdsToAdd = srcNumLocalIds + dstNumLocalIds; 418 for (unsigned i = 0; i < numLocalIdsToAdd; i++) { 419 dependenceDomain->addLocalId(dependenceDomain->getNumLocalIds()); 420 } 421 422 unsigned numDims = dependenceDomain->getNumDimIds(); 423 unsigned numSymbols = dependenceDomain->getNumSymbolIds(); 424 unsigned numSrcLocalIds = srcLocalVarCst.getNumLocalIds(); 425 unsigned newLocalIdOffset = numDims + numSymbols + domNumLocalIds; 426 427 // Equality to add. 428 SmallVector<int64_t, 8> eq(dependenceDomain->getNumCols()); 429 for (unsigned i = 0; i < numResults; ++i) { 430 // Zero fill. 431 std::fill(eq.begin(), eq.end(), 0); 432 433 // Flattened AffineExpr for src result 'i'. 434 const auto &srcFlatExpr = srcFlatExprs[i]; 435 // Set identifier coefficients from src access function. 436 for (unsigned j = 0, e = srcOperands.size(); j < e; ++j) 437 eq[valuePosMap.getSrcDimOrSymPos(srcOperands[j])] = srcFlatExpr[j]; 438 // Local terms. 439 for (unsigned j = 0, e = srcNumLocalIds; j < e; j++) 440 eq[newLocalIdOffset + j] = srcFlatExpr[srcNumIds + j]; 441 // Set constant term. 442 eq[eq.size() - 1] = srcFlatExpr[srcFlatExpr.size() - 1]; 443 444 // Flattened AffineExpr for dest result 'i'. 445 const auto &destFlatExpr = destFlatExprs[i]; 446 // Set identifier coefficients from dst access function. 447 for (unsigned j = 0, e = dstOperands.size(); j < e; ++j) 448 eq[valuePosMap.getDstDimOrSymPos(dstOperands[j])] -= destFlatExpr[j]; 449 // Local terms. 
450 for (unsigned j = 0, e = dstNumLocalIds; j < e; j++) 451 eq[newLocalIdOffset + numSrcLocalIds + j] = -destFlatExpr[dstNumIds + j]; 452 // Set constant term. 453 eq[eq.size() - 1] -= destFlatExpr[destFlatExpr.size() - 1]; 454 455 // Add equality constraint. 456 dependenceDomain->addEquality(eq); 457 } 458 459 // Add equality constraints for any operands that are defined by constant ops. 460 auto addEqForConstOperands = [&](ArrayRef<Value *> operands) { 461 for (unsigned i = 0, e = operands.size(); i < e; ++i) { 462 if (isForInductionVar(operands[i])) 463 continue; 464 auto *symbol = operands[i]; 465 assert(isValidSymbol(symbol)); 466 // Check if the symbol is a constant. 467 if (auto cOp = dyn_cast_or_null<ConstantIndexOp>(symbol->getDefiningOp())) 468 dependenceDomain->setIdToConstant(valuePosMap.getSymPos(symbol), 469 cOp.getValue()); 470 } 471 }; 472 473 // Add equality constraints for any src symbols defined by constant ops. 474 addEqForConstOperands(srcOperands); 475 // Add equality constraints for any dst symbols defined by constant ops. 476 addEqForConstOperands(dstOperands); 477 478 // By construction (see flattener), local var constraints will not have any 479 // equalities. 480 assert(srcLocalVarCst.getNumEqualities() == 0 && 481 destLocalVarCst.getNumEqualities() == 0); 482 // Add inequalities from srcLocalVarCst and destLocalVarCst into the 483 // dependence domain. 484 SmallVector<int64_t, 8> ineq(dependenceDomain->getNumCols()); 485 for (unsigned r = 0, e = srcLocalVarCst.getNumInequalities(); r < e; r++) { 486 std::fill(ineq.begin(), ineq.end(), 0); 487 488 // Set identifier coefficients from src local var constraints. 489 for (unsigned j = 0, e = srcOperands.size(); j < e; ++j) 490 ineq[valuePosMap.getSrcDimOrSymPos(srcOperands[j])] = 491 srcLocalVarCst.atIneq(r, j); 492 // Local terms. 493 for (unsigned j = 0, e = srcNumLocalIds; j < e; j++) 494 ineq[newLocalIdOffset + j] = srcLocalVarCst.atIneq(r, srcNumIds + j); 495 // Set constant term. 
496 ineq[ineq.size() - 1] = 497 srcLocalVarCst.atIneq(r, srcLocalVarCst.getNumCols() - 1); 498 dependenceDomain->addInequality(ineq); 499 } 500 501 for (unsigned r = 0, e = destLocalVarCst.getNumInequalities(); r < e; r++) { 502 std::fill(ineq.begin(), ineq.end(), 0); 503 // Set identifier coefficients from dest local var constraints. 504 for (unsigned j = 0, e = dstOperands.size(); j < e; ++j) 505 ineq[valuePosMap.getDstDimOrSymPos(dstOperands[j])] = 506 destLocalVarCst.atIneq(r, j); 507 // Local terms. 508 for (unsigned j = 0, e = dstNumLocalIds; j < e; j++) 509 ineq[newLocalIdOffset + numSrcLocalIds + j] = 510 destLocalVarCst.atIneq(r, dstNumIds + j); 511 // Set constant term. 512 ineq[ineq.size() - 1] = 513 destLocalVarCst.atIneq(r, destLocalVarCst.getNumCols() - 1); 514 515 dependenceDomain->addInequality(ineq); 516 } 517 return success(); 518 } 519 520 // Returns the number of outer loop common to 'src/dstDomain'. 521 // Loops common to 'src/dst' domains are added to 'commonLoops' if non-null. 522 static unsigned 523 getNumCommonLoops(const FlatAffineConstraints &srcDomain, 524 const FlatAffineConstraints &dstDomain, 525 SmallVectorImpl<AffineForOp> *commonLoops = nullptr) { 526 // Find the number of common loops shared by src and dst accesses. 527 unsigned minNumLoops = 528 std::min(srcDomain.getNumDimIds(), dstDomain.getNumDimIds()); 529 unsigned numCommonLoops = 0; 530 for (unsigned i = 0; i < minNumLoops; ++i) { 531 if (!isForInductionVar(srcDomain.getIdValue(i)) || 532 !isForInductionVar(dstDomain.getIdValue(i)) || 533 srcDomain.getIdValue(i) != dstDomain.getIdValue(i)) 534 break; 535 if (commonLoops != nullptr) 536 commonLoops->push_back(getForInductionVarOwner(srcDomain.getIdValue(i))); 537 ++numCommonLoops; 538 } 539 if (commonLoops != nullptr) 540 assert(commonLoops->size() == numCommonLoops); 541 return numCommonLoops; 542 } 543 544 // Returns Block common to 'srcAccess.opInst' and 'dstAccess.opInst'. 
static Block *getCommonBlock(const MemRefAccess &srcAccess,
                             const MemRefAccess &dstAccess,
                             const FlatAffineConstraints &srcDomain,
                             unsigned numCommonLoops) {
  // With no common loops, the common block is the function-level block
  // enclosing 'srcAccess.opInst': walk up until the parent op is the FuncOp.
  if (numCommonLoops == 0) {
    auto *block = srcAccess.opInst->getBlock();
    while (!llvm::isa<FuncOp>(block->getParentOp())) {
      block = block->getParentOp()->getBlock();
    }
    return block;
  }
  // Otherwise the common block is the body of the innermost common loop,
  // recovered from its induction variable in 'srcDomain'.
  auto *commonForValue = srcDomain.getIdValue(numCommonLoops - 1);
  auto forOp = getForInductionVarOwner(commonForValue);
  assert(forOp && "commonForValue was not an induction variable");
  return forOp.getBody();
}

// Returns true if the ancestor operation of 'srcAccess' appears before the
// ancestor operation of 'dstAccess' in the common ancestral block. Returns
// false otherwise.
// Note that because 'srcAccess' or 'dstAccess' may be nested in conditionals,
// the function is named 'srcAppearsBeforeDstInAncestralBlock'. Note that
// 'numCommonLoops' is the number of contiguous surrounding outer loops.
static bool srcAppearsBeforeDstInAncestralBlock(
    const MemRefAccess &srcAccess, const MemRefAccess &dstAccess,
    const FlatAffineConstraints &srcDomain, unsigned numCommonLoops) {
  // Get Block common to 'srcAccess.opInst' and 'dstAccess.opInst'.
  auto *commonBlock =
      getCommonBlock(srcAccess, dstAccess, srcDomain, numCommonLoops);
  // Check the dominance relationship between the respective ancestors of the
  // src and dst in the Block of the innermost among the common loops.
  auto *srcInst = commonBlock->findAncestorInstInBlock(*srcAccess.opInst);
  assert(srcInst != nullptr);
  auto *dstInst = commonBlock->findAncestorInstInBlock(*dstAccess.opInst);
  assert(dstInst != nullptr);

  // Determine whether dstInst comes after srcInst.
  return srcInst->isBeforeInBlock(dstInst);
}

// Adds ordering constraints to 'dependenceDomain' based on number of loops
// common to 'src/dstDomain' and requested 'loopDepth'.
// Note that 'loopDepth' cannot exceed the number of common loops plus one.
// EX: Given a loop nest of depth 2 with IVs 'i' and 'j':
// *) If 'loopDepth == 1' then one constraint is added: i' >= i + 1
// *) If 'loopDepth == 2' then two constraints are added: i == i' and
//    j' >= j + 1
// *) If 'loopDepth == 3' then two constraints are added: i == i' and j == j'
static void addOrderingConstraints(const FlatAffineConstraints &srcDomain,
                                   const FlatAffineConstraints &dstDomain,
                                   unsigned loopDepth,
                                   FlatAffineConstraints *dependenceDomain) {
  unsigned numCols = dependenceDomain->getNumCols();
  SmallVector<int64_t, 4> eq(numCols);
  unsigned numSrcDims = srcDomain.getNumDimIds();
  unsigned numCommonLoops = getNumCommonLoops(srcDomain, dstDomain);
  unsigned numCommonLoopConstraints = std::min(numCommonLoops, loopDepth);
  for (unsigned i = 0; i < numCommonLoopConstraints; ++i) {
    std::fill(eq.begin(), eq.end(), 0);
    // -src_iv + dst_iv: equal at depths above 'loopDepth', strictly greater
    // (via the -1 constant term on the inequality) at depth 'loopDepth'.
    eq[i] = -1;
    eq[i + numSrcDims] = 1;
    if (i == loopDepth - 1) {
      eq[numCols - 1] = -1;
      dependenceDomain->addInequality(eq);
    } else {
      dependenceDomain->addEquality(eq);
    }
  }
}

// Computes distance and direction vectors in 'dependences', by adding
// variables to 'dependenceDomain' which represent the difference of the IVs,
// eliminating all other variables, and reading off distance vectors from
// equality constraints (if possible), and direction vectors from inequalities.
static void computeDirectionVector(
    const FlatAffineConstraints &srcDomain,
    const FlatAffineConstraints &dstDomain, unsigned loopDepth,
    FlatAffineConstraints *dependenceDomain,
    llvm::SmallVector<DependenceComponent, 2> *dependenceComponents) {
  // Find the number of common loops shared by src and dst accesses.
  SmallVector<AffineForOp, 4> commonLoops;
  unsigned numCommonLoops =
      getNumCommonLoops(srcDomain, dstDomain, &commonLoops);
  if (numCommonLoops == 0)
    return;
  // Compute direction vectors for requested loop depth.
  unsigned numIdsToEliminate = dependenceDomain->getNumIds();
  // Add new variables to 'dependenceDomain' to represent the direction
  // constraints for each shared loop.
  for (unsigned j = 0; j < numCommonLoops; ++j) {
    dependenceDomain->addDimId(j);
  }

  // Add equality constraints for each common loop, setting newly introduced
  // variable at column 'j' to the 'dst' IV minus the 'src' IV.
  SmallVector<int64_t, 4> eq;
  eq.resize(dependenceDomain->getNumCols());
  unsigned numSrcDims = srcDomain.getNumDimIds();
  // Constraint variables format:
  // [num-common-loops][num-src-dim-ids][num-dst-dim-ids][num-symbols][constant]
  for (unsigned j = 0; j < numCommonLoops; ++j) {
    std::fill(eq.begin(), eq.end(), 0);
    // diff_j + src_iv_j - dst_iv_j = 0, i.e. diff_j = dst_iv_j - src_iv_j.
    eq[j] = 1;
    eq[j + numCommonLoops] = 1;
    eq[j + numCommonLoops + numSrcDims] = -1;
    dependenceDomain->addEquality(eq);
  }

  // Eliminate all variables other than the direction variables just added.
  dependenceDomain->projectOut(numCommonLoops, numIdsToEliminate);

  // Scan each common loop variable column and set direction vectors based
  // on eliminated constraint system.
  dependenceComponents->resize(numCommonLoops);
  for (unsigned j = 0; j < numCommonLoops; ++j) {
    (*dependenceComponents)[j].op = commonLoops[j].getOperation();
    // Missing constant bounds are reported as the widest possible range.
    auto lbConst = dependenceDomain->getConstantLowerBound(j);
    (*dependenceComponents)[j].lb =
        lbConst.getValueOr(std::numeric_limits<int64_t>::min());
    auto ubConst = dependenceDomain->getConstantUpperBound(j);
    (*dependenceComponents)[j].ub =
        ubConst.getValueOr(std::numeric_limits<int64_t>::max());
  }
}

// Populates 'accessMap' with composition of AffineApplyOps reachable from
// indices of MemRefAccess.
void MemRefAccess::getAccessMap(AffineValueMap *accessMap) const {
  // Get affine map from AffineLoad/Store.
  AffineMap map;
  if (auto loadOp = dyn_cast<AffineLoadOp>(opInst))
    map = loadOp.getAffineMap();
  else if (auto storeOp = dyn_cast<AffineStoreOp>(opInst))
    map = storeOp.getAffineMap();
  SmallVector<Value *, 8> operands(indices.begin(), indices.end());
  fullyComposeAffineMapAndOperands(&map, &operands);
  map = simplifyAffineMap(map);
  canonicalizeMapAndOperands(&map, &operands);
  accessMap->reset(map, operands);
}

// Builds a flat affine constraint system to check if there exists a dependence
// between memref accesses 'srcAccess' and 'dstAccess'.
// Returns 'NoDependence' if the accesses can be definitively shown not to
// access the same element.
// Returns 'HasDependence' if the accesses do access the same element.
// Returns 'Failure' if an error or unsupported case was encountered.
// If a dependence exists, returns in 'dependenceComponents' a direction
// vector for the dependence, with a component for each loop IV in loops
// common to both accesses (see Dependence in AffineAnalysis.h for details).
//
// The memref access dependence check is comprised of the following steps:
// *) Compute access functions for each access.
// Access functions are computed
//    using AffineValueMaps initialized with the indices from an access, then
//    composed with AffineApplyOps reachable from operands of that access,
//    until operands of the AffineValueMap are loop IVs or symbols.
// *) Build iteration domain constraints for each access. Iteration domain
//    constraints are pairs of inequality constraints representing the
//    upper/lower loop bounds for each AffineForOp in the loop nest associated
//    with each access.
// *) Build dimension and symbol position maps for each access, which map
//    Values from access functions and iteration domains to their position
//    in the merged constraint system built by this method.
//
// This method builds a constraint system with the following column format:
//
//  [src-dim-identifiers, dst-dim-identifiers, symbols, constant]
//
// For example, given the following MLIR code with "source" and
// "destination" accesses to the same memref labeled, and symbols %M, %N, %K:
//
//   affine.for %i0 = 0 to 100 {
//     affine.for %i1 = 0 to 50 {
//       %a0 = affine.apply
//         (d0, d1) -> (d0 * 2 - d1 * 4 + s1, d1 * 3 - s0) (%i0, %i1)[%M, %N]
//       // Source memref access.
//       store %v0, %m[%a0#0, %a0#1] : memref<4x4xf32>
//     }
//   }
//
//   affine.for %i2 = 0 to 100 {
//     affine.for %i3 = 0 to 50 {
//       %a1 = affine.apply
//         (d0, d1) -> (d0 * 7 + d1 * 9 - s1, d1 * 11 + s0) (%i2, %i3)[%K, %M]
//       // Destination memref access.
729 // %v1 = load %m[%a1#0, %a1#1] : memref<4x4xf32> 730 // } 731 // } 732 // 733 // The access functions would be the following: 734 // 735 // src: (%i0 * 2 - %i1 * 4 + %N, %i1 * 3 - %M) 736 // dst: (%i2 * 7 + %i3 * 9 - %M, %i3 * 11 - %K) 737 // 738 // The iteration domains for the src/dst accesses would be the following: 739 // 740 // src: 0 <= %i0 <= 100, 0 <= %i1 <= 50 741 // dst: 0 <= %i2 <= 100, 0 <= %i3 <= 50 742 // 743 // The symbols by both accesses would be assigned to a canonical position order 744 // which will be used in the dependence constraint system: 745 // 746 // symbol name: %M %N %K 747 // symbol pos: 0 1 2 748 // 749 // Equality constraints are built by equating each result of src/destination 750 // access functions. For this example, the following two equality constraints 751 // will be added to the dependence constraint system: 752 // 753 // [src_dim0, src_dim1, dst_dim0, dst_dim1, sym0, sym1, sym2, const] 754 // 2 -4 -7 -9 1 1 0 0 = 0 755 // 0 3 0 -11 -1 0 1 0 = 0 756 // 757 // Inequality constraints from the iteration domain will be meged into 758 // the dependence constraint system 759 // 760 // [src_dim0, src_dim1, dst_dim0, dst_dim1, sym0, sym1, sym2, const] 761 // 1 0 0 0 0 0 0 0 >= 0 762 // -1 0 0 0 0 0 0 100 >= 0 763 // 0 1 0 0 0 0 0 0 >= 0 764 // 0 -1 0 0 0 0 0 50 >= 0 765 // 0 0 1 0 0 0 0 0 >= 0 766 // 0 0 -1 0 0 0 0 100 >= 0 767 // 0 0 0 1 0 0 0 0 >= 0 768 // 0 0 0 -1 0 0 0 50 >= 0 769 // 770 // 771 // TODO(andydavis) Support AffineExprs mod/floordiv/ceildiv. 
772 DependenceResult mlir::checkMemrefAccessDependence( 773 const MemRefAccess &srcAccess, const MemRefAccess &dstAccess, 774 unsigned loopDepth, FlatAffineConstraints *dependenceConstraints, 775 llvm::SmallVector<DependenceComponent, 2> *dependenceComponents, 776 bool allowRAR) { 777 LLVM_DEBUG(llvm::dbgs() << "Checking for dependence at depth: " 778 << Twine(loopDepth) << " between:\n";); 779 LLVM_DEBUG(srcAccess.opInst->dump();); 780 LLVM_DEBUG(dstAccess.opInst->dump();); 781 782 // Return 'NoDependence' if these accesses do not access the same memref. 783 if (srcAccess.memref != dstAccess.memref) 784 return DependenceResult::NoDependence; 785 786 // Return 'NoDependence' if one of these accesses is not an AffineStoreOp. 787 if (!allowRAR && !isa<AffineStoreOp>(srcAccess.opInst) && 788 !isa<AffineStoreOp>(dstAccess.opInst)) 789 return DependenceResult::NoDependence; 790 791 // Get composed access function for 'srcAccess'. 792 AffineValueMap srcAccessMap; 793 srcAccess.getAccessMap(&srcAccessMap); 794 795 // Get composed access function for 'dstAccess'. 796 AffineValueMap dstAccessMap; 797 dstAccess.getAccessMap(&dstAccessMap); 798 799 // Get iteration domain for the 'srcAccess' operation. 800 FlatAffineConstraints srcDomain; 801 if (failed(getInstIndexSet(srcAccess.opInst, &srcDomain))) 802 return DependenceResult::Failure; 803 804 // Get iteration domain for 'dstAccess' operation. 805 FlatAffineConstraints dstDomain; 806 if (failed(getInstIndexSet(dstAccess.opInst, &dstDomain))) 807 return DependenceResult::Failure; 808 809 // Return 'NoDependence' if loopDepth > numCommonLoops and if the ancestor 810 // operation of 'srcAccess' does not properly dominate the ancestor 811 // operation of 'dstAccess' in the same common operation block. 812 // Note: this check is skipped if 'allowRAR' is true, because because RAR 813 // deps can exist irrespective of lexicographic ordering b/w src and dst. 
814 unsigned numCommonLoops = getNumCommonLoops(srcDomain, dstDomain); 815 assert(loopDepth <= numCommonLoops + 1); 816 if (!allowRAR && loopDepth > numCommonLoops && 817 !srcAppearsBeforeDstInAncestralBlock(srcAccess, dstAccess, srcDomain, 818 numCommonLoops)) { 819 return DependenceResult::NoDependence; 820 } 821 // Build dim and symbol position maps for each access from access operand 822 // Value to position in merged contstraint system. 823 ValuePositionMap valuePosMap; 824 buildDimAndSymbolPositionMaps(srcDomain, dstDomain, srcAccessMap, 825 dstAccessMap, &valuePosMap, 826 dependenceConstraints); 827 828 initDependenceConstraints(srcDomain, dstDomain, srcAccessMap, dstAccessMap, 829 valuePosMap, dependenceConstraints); 830 831 assert(valuePosMap.getNumDims() == 832 srcDomain.getNumDimIds() + dstDomain.getNumDimIds()); 833 834 // Create memref access constraint by equating src/dst access functions. 835 // Note that this check is conservative, and will fail in the future when 836 // local variables for mod/div exprs are supported. 837 if (failed(addMemRefAccessConstraints(srcAccessMap, dstAccessMap, valuePosMap, 838 dependenceConstraints))) 839 return DependenceResult::Failure; 840 841 // Add 'src' happens before 'dst' ordering constraints. 842 addOrderingConstraints(srcDomain, dstDomain, loopDepth, 843 dependenceConstraints); 844 // Add src and dst domain constraints. 845 addDomainConstraints(srcDomain, dstDomain, valuePosMap, 846 dependenceConstraints); 847 848 // Return 'NoDependence' if the solution space is empty: no dependence. 849 if (dependenceConstraints->isEmpty()) { 850 return DependenceResult::NoDependence; 851 } 852 853 // Compute dependence direction vector and return true. 
854 if (dependenceComponents != nullptr) { 855 computeDirectionVector(srcDomain, dstDomain, loopDepth, 856 dependenceConstraints, dependenceComponents); 857 } 858 859 LLVM_DEBUG(llvm::dbgs() << "Dependence polyhedron:\n"); 860 LLVM_DEBUG(dependenceConstraints->dump()); 861 return DependenceResult::HasDependence; 862 } 863 864 /// Gathers dependence components for dependences between all ops in loop nest 865 /// rooted at 'forOp' at loop depths in range [1, maxLoopDepth]. 866 void mlir::getDependenceComponents( 867 AffineForOp forOp, unsigned maxLoopDepth, 868 std::vector<llvm::SmallVector<DependenceComponent, 2>> *depCompsVec) { 869 // Collect all load and store ops in loop nest rooted at 'forOp'. 870 SmallVector<Operation *, 8> loadAndStoreOpInsts; 871 forOp.getOperation()->walk([&](Operation *opInst) { 872 if (isa<AffineLoadOp>(opInst) || isa<AffineStoreOp>(opInst)) 873 loadAndStoreOpInsts.push_back(opInst); 874 }); 875 876 unsigned numOps = loadAndStoreOpInsts.size(); 877 for (unsigned d = 1; d <= maxLoopDepth; ++d) { 878 for (unsigned i = 0; i < numOps; ++i) { 879 auto *srcOpInst = loadAndStoreOpInsts[i]; 880 MemRefAccess srcAccess(srcOpInst); 881 for (unsigned j = 0; j < numOps; ++j) { 882 auto *dstOpInst = loadAndStoreOpInsts[j]; 883 MemRefAccess dstAccess(dstOpInst); 884 885 FlatAffineConstraints dependenceConstraints; 886 llvm::SmallVector<DependenceComponent, 2> depComps; 887 // TODO(andydavis,bondhugula) Explore whether it would be profitable 888 // to pre-compute and store deps instead of repeatedly checking. 889 DependenceResult result = checkMemrefAccessDependence( 890 srcAccess, dstAccess, d, &dependenceConstraints, &depComps); 891 if (hasDependence(result)) 892 depCompsVec->push_back(depComps); 893 } 894 } 895 } 896 }