github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Transforms/Utils/FoldUtils.cpp (about) 1 //===- FoldUtils.cpp ---- Fold Utilities ----------------------------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 // 18 // This file defines various operation fold utilities. These utilities are 19 // intended to be used by passes to unify and simply their logic. 20 // 21 //===----------------------------------------------------------------------===// 22 23 #include "mlir/Transforms/FoldUtils.h" 24 25 #include "mlir/Dialect/StandardOps/Ops.h" 26 #include "mlir/IR/Builders.h" 27 #include "mlir/IR/Matchers.h" 28 #include "mlir/IR/Operation.h" 29 30 using namespace mlir; 31 32 /// Given an operation, find the parent region that folded constants should be 33 /// inserted into. 34 static Region *getInsertionRegion(Operation *op) { 35 while (Region *region = op->getParentRegion()) { 36 // Insert in this region for any of the following scenarios: 37 // * The parent is unregistered, or is known to be isolated from above. 38 // * The parent is a top-level operation. 39 auto *parentOp = region->getParentOp(); 40 if (!parentOp->isRegistered() || parentOp->isKnownIsolatedFromAbove() || 41 !parentOp->getBlock()) 42 return region; 43 // Traverse up the parent looking for an insertion region. 44 op = parentOp; 45 } 46 llvm_unreachable("expected valid insertion region"); 47 } 48 49 /// A utility function used to materialize a constant for a given attribute and 50 /// type. On success, a valid constant value is returned. Otherwise, null is 51 /// returned 52 static Operation *materializeConstant(Dialect *dialect, OpBuilder &builder, 53 Attribute value, Type type, 54 Location loc) { 55 auto insertPt = builder.getInsertionPoint(); 56 (void)insertPt; 57 58 // Ask the dialect to materialize a constant operation for this value. 59 if (auto *constOp = dialect->materializeConstant(builder, value, type, loc)) { 60 assert(insertPt == builder.getInsertionPoint()); 61 assert(matchPattern(constOp, m_Constant(&value))); 62 return constOp; 63 } 64 65 // If the dialect is unable to materialize a constant, check to see if the 66 // standard constant can be used. 67 if (ConstantOp::isBuildableWith(value, type)) 68 return builder.create<ConstantOp>(loc, type, value); 69 return nullptr; 70 } 71 72 //===----------------------------------------------------------------------===// 73 // OperationFolder 74 //===----------------------------------------------------------------------===// 75 76 LogicalResult OperationFolder::tryToFold( 77 Operation *op, 78 llvm::function_ref<void(Operation *)> processGeneratedConstants, 79 llvm::function_ref<void(Operation *)> preReplaceAction) { 80 // If this is a unique'd constant, return failure as we know that it has 81 // already been folded. 82 if (referencedDialects.count(op)) 83 return failure(); 84 85 // Try to fold the operation. 86 SmallVector<Value *, 8> results; 87 if (failed(tryToFold(op, results, processGeneratedConstants))) 88 return failure(); 89 90 // Constant folding succeeded. We will start replacing this op's uses and 91 // eventually erase this op. Invoke the callback provided by the caller to 92 // perform any pre-replacement action. 93 if (preReplaceAction) 94 preReplaceAction(op); 95 96 // Check to see if the operation was just updated in place. 97 if (results.empty()) 98 return success(); 99 100 // Otherwise, replace all of the result values and erase the operation. 101 for (unsigned i = 0, e = results.size(); i != e; ++i) 102 op->getResult(i)->replaceAllUsesWith(results[i]); 103 op->erase(); 104 return success(); 105 } 106 107 /// Notifies that the given constant `op` should be remove from this 108 /// OperationFolder's internal bookkeeping. 109 void OperationFolder::notifyRemoval(Operation *op) { 110 // Check to see if this operation is uniqued within the folder. 111 auto it = referencedDialects.find(op); 112 if (it == referencedDialects.end()) 113 return; 114 115 // Get the constant value for this operation, this is the value that was used 116 // to unique the operation internally. 117 Attribute constValue; 118 matchPattern(op, m_Constant(&constValue)); 119 assert(constValue); 120 121 // Get the constant map that this operation was uniqued in. 122 auto &uniquedConstants = foldScopes[getInsertionRegion(op)]; 123 124 // Erase all of the references to this operation. 125 auto type = op->getResult(0)->getType(); 126 for (auto *dialect : it->second) 127 uniquedConstants.erase(std::make_tuple(dialect, constValue, type)); 128 referencedDialects.erase(it); 129 } 130 131 /// Tries to perform folding on the given `op`. If successful, populates 132 /// `results` with the results of the folding. 133 LogicalResult OperationFolder::tryToFold( 134 Operation *op, SmallVectorImpl<Value *> &results, 135 llvm::function_ref<void(Operation *)> processGeneratedConstants) { 136 SmallVector<Attribute, 8> operandConstants; 137 SmallVector<OpFoldResult, 8> foldResults; 138 139 // Check to see if any operands to the operation is constant and whether 140 // the operation knows how to constant fold itself. 141 operandConstants.assign(op->getNumOperands(), Attribute()); 142 for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i) 143 matchPattern(op->getOperand(i), m_Constant(&operandConstants[i])); 144 145 // If this is a commutative binary operation with a constant on the left 146 // side move it to the right side. 147 if (operandConstants.size() == 2 && operandConstants[0] && 148 !operandConstants[1] && op->isCommutative()) { 149 std::swap(op->getOpOperand(0), op->getOpOperand(1)); 150 std::swap(operandConstants[0], operandConstants[1]); 151 } 152 153 // Attempt to constant fold the operation. 154 if (failed(op->fold(operandConstants, foldResults))) 155 return failure(); 156 157 // Check to see if the operation was just updated in place. 158 if (foldResults.empty()) 159 return success(); 160 assert(foldResults.size() == op->getNumResults()); 161 162 // Create a builder to insert new operations into the entry block of the 163 // insertion region. 164 auto *insertionRegion = getInsertionRegion(op); 165 auto &entry = insertionRegion->front(); 166 OpBuilder builder(&entry, entry.begin()); 167 168 // Get the constant map for the insertion region of this operation. 169 auto &uniquedConstants = foldScopes[insertionRegion]; 170 171 // Create the result constants and replace the results. 172 auto *dialect = op->getDialect(); 173 for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) { 174 assert(!foldResults[i].isNull() && "expected valid OpFoldResult"); 175 176 // Check if the result was an SSA value. 177 if (auto *repl = foldResults[i].dyn_cast<Value *>()) { 178 results.emplace_back(repl); 179 continue; 180 } 181 182 // Check to see if there is a canonicalized version of this constant. 183 auto *res = op->getResult(i); 184 Attribute attrRepl = foldResults[i].get<Attribute>(); 185 if (auto *constOp = 186 tryGetOrCreateConstant(uniquedConstants, dialect, builder, attrRepl, 187 res->getType(), op->getLoc())) { 188 results.push_back(constOp->getResult(0)); 189 continue; 190 } 191 // If materialization fails, cleanup any operations generated for the 192 // previous results and return failure. 193 for (Operation &op : llvm::make_early_inc_range( 194 llvm::make_range(entry.begin(), builder.getInsertionPoint()))) { 195 notifyRemoval(&op); 196 op.erase(); 197 } 198 return failure(); 199 } 200 201 // Process any newly generated operations. 202 if (processGeneratedConstants) { 203 for (auto i = entry.begin(), e = builder.getInsertionPoint(); i != e; ++i) 204 processGeneratedConstants(&*i); 205 } 206 207 return success(); 208 } 209 210 /// Try to get or create a new constant entry. On success this returns the 211 /// constant operation value, nullptr otherwise. 212 Operation *OperationFolder::tryGetOrCreateConstant( 213 ConstantMap &uniquedConstants, Dialect *dialect, OpBuilder &builder, 214 Attribute value, Type type, Location loc) { 215 // Check if an existing mapping already exists. 216 auto constKey = std::make_tuple(dialect, value, type); 217 auto *&constInst = uniquedConstants[constKey]; 218 if (constInst) 219 return constInst; 220 221 // If one doesn't exist, try to materialize one. 222 if (!(constInst = materializeConstant(dialect, builder, value, type, loc))) 223 return nullptr; 224 225 // Check to see if the generated constant is in the expected dialect. 226 auto *newDialect = constInst->getDialect(); 227 if (newDialect == dialect) { 228 referencedDialects[constInst].push_back(dialect); 229 return constInst; 230 } 231 232 // If it isn't, then we also need to make sure that the mapping for the new 233 // dialect is valid. 234 auto newKey = std::make_tuple(newDialect, value, type); 235 236 // If an existing operation in the new dialect already exists, delete the 237 // materialized operation in favor of the existing one. 238 if (auto *existingOp = uniquedConstants.lookup(newKey)) { 239 constInst->erase(); 240 referencedDialects[existingOp].push_back(dialect); 241 return constInst = existingOp; 242 } 243 244 // Otherwise, update the new dialect to the materialized operation. 245 referencedDialects[constInst].assign({dialect, newDialect}); 246 auto newIt = uniquedConstants.insert({newKey, constInst}); 247 return newIt.first->second; 248 }