github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/TableGen/Operator.cpp (about) 1 //===- Operator.cpp - Operator class --------------------------------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 // 18 // Operator wrapper to simplify using TableGen Record defining a MLIR Op. 19 // 20 //===----------------------------------------------------------------------===// 21 22 #include "mlir/TableGen/Operator.h" 23 #include "mlir/TableGen/OpTrait.h" 24 #include "mlir/TableGen/Predicate.h" 25 #include "mlir/TableGen/Type.h" 26 #include "llvm/Support/FormatVariadic.h" 27 #include "llvm/TableGen/Error.h" 28 #include "llvm/TableGen/Record.h" 29 30 using namespace mlir; 31 32 using llvm::DagInit; 33 using llvm::DefInit; 34 using llvm::Record; 35 36 tblgen::Operator::Operator(const llvm::Record &def) 37 : dialect(def.getValueAsDef("opDialect")), def(def) { 38 // The first `_` in the op's TableGen def name is treated as separating the 39 // dialect prefix and the op class name. The dialect prefix will be ignored if 40 // not empty. Otherwise, if def name starts with a `_`, the `_` is considered 41 // as part of the class name. 42 StringRef prefix; 43 std::tie(prefix, cppClassName) = def.getName().split('_'); 44 if (prefix.empty()) { 45 // Class name with a leading underscore and without dialect prefix 46 cppClassName = def.getName(); 47 } else if (cppClassName.empty()) { 48 // Class name without dialect prefix 49 cppClassName = prefix; 50 } 51 52 populateOpStructure(); 53 } 54 55 std::string tblgen::Operator::getOperationName() const { 56 auto prefix = dialect.getName(); 57 auto opName = def.getValueAsString("opName"); 58 if (prefix.empty()) 59 return opName; 60 return llvm::formatv("{0}.{1}", prefix, opName); 61 } 62 63 StringRef tblgen::Operator::getDialectName() const { return dialect.getName(); } 64 65 StringRef tblgen::Operator::getCppClassName() const { return cppClassName; } 66 67 std::string tblgen::Operator::getQualCppClassName() const { 68 auto prefix = dialect.getCppNamespace(); 69 if (prefix.empty()) 70 return cppClassName; 71 return llvm::formatv("{0}::{1}", prefix, cppClassName); 72 } 73 74 int tblgen::Operator::getNumResults() const { 75 DagInit *results = def.getValueAsDag("results"); 76 return results->getNumArgs(); 77 } 78 79 StringRef tblgen::Operator::getExtraClassDeclaration() const { 80 constexpr auto attr = "extraClassDeclaration"; 81 if (def.isValueUnset(attr)) 82 return {}; 83 return def.getValueAsString(attr); 84 } 85 86 const llvm::Record &tblgen::Operator::getDef() const { return def; } 87 88 bool tblgen::Operator::isVariadic() const { 89 return getNumVariadicOperands() != 0 || getNumVariadicResults() != 0; 90 } 91 92 bool tblgen::Operator::skipDefaultBuilders() const { 93 return def.getValueAsBit("skipDefaultBuilders"); 94 } 95 96 auto tblgen::Operator::result_begin() -> value_iterator { 97 return results.begin(); 98 } 99 100 auto tblgen::Operator::result_end() -> value_iterator { return results.end(); } 101 102 auto tblgen::Operator::getResults() -> value_range { 103 return {result_begin(), result_end()}; 104 } 105 106 tblgen::TypeConstraint 107 tblgen::Operator::getResultTypeConstraint(int index) const { 108 DagInit *results = def.getValueAsDag("results"); 109 return TypeConstraint(cast<DefInit>(results->getArg(index))); 110 } 111 112 StringRef tblgen::Operator::getResultName(int index) const { 113 DagInit *results = def.getValueAsDag("results"); 114 return results->getArgNameStr(index); 115 } 116 117 unsigned tblgen::Operator::getNumVariadicResults() const { 118 return std::count_if( 119 results.begin(), results.end(), 120 [](const NamedTypeConstraint &c) { return c.constraint.isVariadic(); }); 121 } 122 123 unsigned tblgen::Operator::getNumVariadicOperands() const { 124 return std::count_if( 125 operands.begin(), operands.end(), 126 [](const NamedTypeConstraint &c) { return c.constraint.isVariadic(); }); 127 } 128 129 StringRef tblgen::Operator::getArgName(int index) const { 130 DagInit *argumentValues = def.getValueAsDag("arguments"); 131 return argumentValues->getArgName(index)->getValue(); 132 } 133 134 bool tblgen::Operator::hasTrait(StringRef trait) const { 135 for (auto t : getTraits()) { 136 if (auto opTrait = dyn_cast<tblgen::NativeOpTrait>(&t)) { 137 if (opTrait->getTrait() == trait) 138 return true; 139 } else if (auto opTrait = dyn_cast<tblgen::InternalOpTrait>(&t)) { 140 if (opTrait->getTrait() == trait) 141 return true; 142 } 143 } 144 return false; 145 } 146 147 unsigned tblgen::Operator::getNumRegions() const { return regions.size(); } 148 149 const tblgen::NamedRegion &tblgen::Operator::getRegion(unsigned index) const { 150 return regions[index]; 151 } 152 153 auto tblgen::Operator::trait_begin() const -> const_trait_iterator { 154 return traits.begin(); 155 } 156 auto tblgen::Operator::trait_end() const -> const_trait_iterator { 157 return traits.end(); 158 } 159 auto tblgen::Operator::getTraits() const 160 -> llvm::iterator_range<const_trait_iterator> { 161 return {trait_begin(), trait_end()}; 162 } 163 164 auto tblgen::Operator::attribute_begin() const -> attribute_iterator { 165 return attributes.begin(); 166 } 167 auto tblgen::Operator::attribute_end() const -> attribute_iterator { 168 return attributes.end(); 169 } 170 auto tblgen::Operator::getAttributes() const 171 -> llvm::iterator_range<attribute_iterator> { 172 return {attribute_begin(), attribute_end()}; 173 } 174 175 auto tblgen::Operator::operand_begin() -> value_iterator { 176 return operands.begin(); 177 } 178 auto tblgen::Operator::operand_end() -> value_iterator { 179 return operands.end(); 180 } 181 auto tblgen::Operator::getOperands() -> value_range { 182 return {operand_begin(), operand_end()}; 183 } 184 185 auto tblgen::Operator::getArg(int index) const -> Argument { 186 return arguments[index]; 187 } 188 189 void tblgen::Operator::populateOpStructure() { 190 auto &recordKeeper = def.getRecords(); 191 auto typeConstraintClass = recordKeeper.getClass("TypeConstraint"); 192 auto attrClass = recordKeeper.getClass("Attr"); 193 auto derivedAttrClass = recordKeeper.getClass("DerivedAttr"); 194 numNativeAttributes = 0; 195 196 // The argument ordering is operands, native attributes, derived 197 // attributes. 198 DagInit *argumentValues = def.getValueAsDag("arguments"); 199 unsigned i = 0; 200 // Handle operands and native attributes. 201 for (unsigned e = argumentValues->getNumArgs(); i != e; ++i) { 202 auto arg = argumentValues->getArg(i); 203 auto givenName = argumentValues->getArgNameStr(i); 204 auto argDefInit = dyn_cast<DefInit>(arg); 205 if (!argDefInit) 206 PrintFatalError(def.getLoc(), 207 Twine("undefined type for argument #") + Twine(i)); 208 Record *argDef = argDefInit->getDef(); 209 210 if (argDef->isSubClassOf(typeConstraintClass)) { 211 operands.push_back( 212 NamedTypeConstraint{givenName, TypeConstraint(argDefInit)}); 213 arguments.emplace_back(&operands.back()); 214 } else if (argDef->isSubClassOf(attrClass)) { 215 if (givenName.empty()) 216 PrintFatalError(argDef->getLoc(), "attributes must be named"); 217 if (argDef->isSubClassOf(derivedAttrClass)) 218 PrintFatalError(argDef->getLoc(), 219 "derived attributes not allowed in argument list"); 220 attributes.push_back({givenName, Attribute(argDef)}); 221 arguments.emplace_back(&attributes.back()); 222 ++numNativeAttributes; 223 } else { 224 PrintFatalError(def.getLoc(), "unexpected def type; only defs deriving " 225 "from TypeConstraint or Attr are allowed"); 226 } 227 } 228 229 // Handle derived attributes. 230 for (const auto &val : def.getValues()) { 231 if (auto *record = dyn_cast<llvm::RecordRecTy>(val.getType())) { 232 if (!record->isSubClassOf(attrClass)) 233 continue; 234 if (!record->isSubClassOf(derivedAttrClass)) 235 PrintFatalError(def.getLoc(), 236 "unexpected Attr where only DerivedAttr is allowed"); 237 238 if (record->getClasses().size() != 1) { 239 PrintFatalError( 240 def.getLoc(), 241 "unsupported attribute modelling, only single class expected"); 242 } 243 attributes.push_back( 244 {cast<llvm::StringInit>(val.getNameInit())->getValue(), 245 Attribute(cast<DefInit>(val.getValue()))}); 246 } 247 } 248 249 auto *resultsDag = def.getValueAsDag("results"); 250 auto *outsOp = dyn_cast<DefInit>(resultsDag->getOperator()); 251 if (!outsOp || outsOp->getDef()->getName() != "outs") { 252 PrintFatalError(def.getLoc(), "'results' must have 'outs' directive"); 253 } 254 255 // Handle results. 256 for (unsigned i = 0, e = resultsDag->getNumArgs(); i < e; ++i) { 257 auto name = resultsDag->getArgNameStr(i); 258 auto *resultDef = dyn_cast<DefInit>(resultsDag->getArg(i)); 259 if (!resultDef) { 260 PrintFatalError(def.getLoc(), 261 Twine("undefined type for result #") + Twine(i)); 262 } 263 results.push_back({name, TypeConstraint(resultDef)}); 264 } 265 266 auto traitListInit = def.getValueAsListInit("traits"); 267 if (!traitListInit) 268 return; 269 traits.reserve(traitListInit->size()); 270 for (auto traitInit : *traitListInit) 271 traits.push_back(OpTrait::create(traitInit)); 272 273 // Handle regions 274 auto *regionsDag = def.getValueAsDag("regions"); 275 auto *regionsOp = dyn_cast<DefInit>(regionsDag->getOperator()); 276 if (!regionsOp || regionsOp->getDef()->getName() != "region") { 277 PrintFatalError(def.getLoc(), "'regions' must have 'region' directive"); 278 } 279 280 for (unsigned i = 0, e = regionsDag->getNumArgs(); i < e; ++i) { 281 auto name = regionsDag->getArgNameStr(i); 282 auto *regionInit = dyn_cast<DefInit>(regionsDag->getArg(i)); 283 if (!regionInit) { 284 PrintFatalError(def.getLoc(), 285 Twine("undefined kind for region #") + Twine(i)); 286 } 287 regions.push_back({name, Region(regionInit->getDef())}); 288 } 289 } 290 291 ArrayRef<llvm::SMLoc> tblgen::Operator::getLoc() const { return def.getLoc(); } 292 293 bool tblgen::Operator::hasDescription() const { 294 return def.getValue("description") != nullptr; 295 } 296 297 StringRef tblgen::Operator::getDescription() const { 298 return def.getValueAsString("description"); 299 } 300 301 bool tblgen::Operator::hasSummary() const { 302 return def.getValue("summary") != nullptr; 303 } 304 305 StringRef tblgen::Operator::getSummary() const { 306 return def.getValueAsString("summary"); 307 }