github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/IR/FunctionSupport.cpp (about) 1 //===- FunctionSupport.cpp - Utility types for function-like ops ----------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 18 #include "mlir/IR/FunctionSupport.h" 19 #include "mlir/IR/Builders.h" 20 #include "mlir/IR/OpImplementation.h" 21 22 using namespace mlir; 23 24 static ParseResult 25 parseArgumentList(OpAsmParser *parser, bool allowVariadic, 26 SmallVectorImpl<Type> &argTypes, 27 SmallVectorImpl<OpAsmParser::OperandType> &argNames, 28 SmallVectorImpl<SmallVector<NamedAttribute, 2>> &argAttrs, 29 bool &isVariadic) { 30 if (parser->parseLParen()) 31 return failure(); 32 33 // The argument list either has to consistently have ssa-id's followed by 34 // types, or just be a type list. It isn't ok to sometimes have SSA ID's and 35 // sometimes not. 36 auto parseArgument = [&]() -> ParseResult { 37 llvm::SMLoc loc = parser->getCurrentLocation(); 38 39 // Parse argument name if present. 40 OpAsmParser::OperandType argument; 41 Type argumentType; 42 if (succeeded(parser->parseOptionalRegionArgument(argument)) && 43 !argument.name.empty()) { 44 // Reject this if the preceding argument was missing a name. 45 if (argNames.empty() && !argTypes.empty()) 46 return parser->emitError(loc, 47 "expected type instead of SSA identifier"); 48 argNames.push_back(argument); 49 50 if (parser->parseColonType(argumentType)) 51 return failure(); 52 } else if (allowVariadic && succeeded(parser->parseOptionalEllipsis())) { 53 isVariadic = true; 54 return success(); 55 } else if (!argNames.empty()) { 56 // Reject this if the preceding argument had a name. 57 return parser->emitError(loc, "expected SSA identifier"); 58 } else if (parser->parseType(argumentType)) { 59 return failure(); 60 } 61 62 // Add the argument type. 63 argTypes.push_back(argumentType); 64 65 // Parse any argument attributes. 66 SmallVector<NamedAttribute, 2> attrs; 67 if (parser->parseOptionalAttributeDict(attrs)) 68 return failure(); 69 argAttrs.push_back(attrs); 70 return success(); 71 }; 72 73 // Parse the function arguments. 74 if (parser->parseOptionalRParen()) { 75 do { 76 unsigned numTypedArguments = argTypes.size(); 77 if (parseArgument()) 78 return failure(); 79 80 llvm::SMLoc loc = parser->getCurrentLocation(); 81 if (argTypes.size() == numTypedArguments && 82 succeeded(parser->parseOptionalComma())) 83 return parser->emitError( 84 loc, "variadic arguments must be in the end of the argument list"); 85 } while (succeeded(parser->parseOptionalComma())); 86 parser->parseRParen(); 87 } 88 89 return success(); 90 } 91 92 /// Parse a function signature, starting with a name and including the 93 /// parameter list. 94 static ParseResult parseFunctionSignature( 95 OpAsmParser *parser, bool allowVariadic, 96 SmallVectorImpl<OpAsmParser::OperandType> &argNames, 97 SmallVectorImpl<Type> &argTypes, 98 SmallVectorImpl<SmallVector<NamedAttribute, 2>> &argAttrs, bool &isVariadic, 99 SmallVectorImpl<Type> &results) { 100 if (parseArgumentList(parser, allowVariadic, argTypes, argNames, argAttrs, 101 isVariadic)) 102 return failure(); 103 // Parse the return types if present. 104 return parser->parseOptionalArrowTypeList(results); 105 } 106 107 /// Parser implementation for function-like operations. Uses `funcTypeBuilder` 108 /// to construct the custom function type given lists of input and output types. 109 ParseResult 110 mlir::impl::parseFunctionLikeOp(OpAsmParser *parser, OperationState *result, 111 bool allowVariadic, 112 mlir::impl::FuncTypeBuilder funcTypeBuilder) { 113 SmallVector<OpAsmParser::OperandType, 4> entryArgs; 114 SmallVector<SmallVector<NamedAttribute, 2>, 4> argAttrs; 115 SmallVector<Type, 4> argTypes; 116 SmallVector<Type, 4> results; 117 auto &builder = parser->getBuilder(); 118 119 // Parse the name as a symbol reference attribute. 120 SymbolRefAttr nameAttr; 121 if (parser->parseAttribute(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(), 122 result->attributes)) 123 return failure(); 124 // Convert the parsed function attr into a string attr. 125 result->attributes.back().second = builder.getStringAttr(nameAttr.getValue()); 126 127 // Parse the function signature. 128 auto signatureLocation = parser->getCurrentLocation(); 129 bool isVariadic = false; 130 if (parseFunctionSignature(parser, allowVariadic, entryArgs, argTypes, 131 argAttrs, isVariadic, results)) 132 return failure(); 133 134 std::string errorMessage; 135 if (auto type = funcTypeBuilder(builder, argTypes, results, 136 impl::VariadicFlag(isVariadic), errorMessage)) 137 result->addAttribute(getTypeAttrName(), builder.getTypeAttr(type)); 138 else 139 return parser->emitError(signatureLocation) 140 << "failed to construct function type" 141 << (errorMessage.empty() ? "" : ": ") << errorMessage; 142 143 // If function attributes are present, parse them. 144 if (succeeded(parser->parseOptionalKeyword("attributes"))) 145 if (parser->parseOptionalAttributeDict(result->attributes)) 146 return failure(); 147 148 // Add the attributes to the function arguments. 149 SmallString<8> argAttrName; 150 for (unsigned i = 0, e = argTypes.size(); i != e; ++i) 151 if (!argAttrs[i].empty()) 152 result->addAttribute(getArgAttrName(i, argAttrName), 153 builder.getDictionaryAttr(argAttrs[i])); 154 155 // Parse the optional function body. 156 auto *body = result->addRegion(); 157 if (parser->parseOptionalRegion(*body, entryArgs, 158 entryArgs.empty() ? llvm::ArrayRef<Type>() 159 : argTypes)) 160 return failure(); 161 162 return success(); 163 } 164 165 /// Print the signature of the function-like operation `op`. Assumes `op` has 166 /// the FunctionLike trait and passed the verification. 167 static void printSignature(OpAsmPrinter *p, Operation *op, 168 ArrayRef<Type> argTypes, bool isVariadic, 169 ArrayRef<Type> results) { 170 Region &body = op->getRegion(0); 171 bool isExternal = body.empty(); 172 173 *p << '('; 174 for (unsigned i = 0, e = argTypes.size(); i < e; ++i) { 175 if (i > 0) 176 *p << ", "; 177 178 if (!isExternal) { 179 p->printOperand(body.front().getArgument(i)); 180 *p << ": "; 181 } 182 183 p->printType(argTypes[i]); 184 p->printOptionalAttrDict(::mlir::impl::getArgAttrs(op, i)); 185 } 186 187 if (isVariadic) { 188 if (!argTypes.empty()) 189 *p << ", "; 190 *p << "..."; 191 } 192 193 *p << ')'; 194 p->printOptionalArrowTypeList(results); 195 } 196 197 /// Printer implementation for function-like operations. Accepts lists of 198 /// argument and result types to use while printing. 199 void mlir::impl::printFunctionLikeOp(OpAsmPrinter *p, Operation *op, 200 ArrayRef<Type> argTypes, bool isVariadic, 201 ArrayRef<Type> results) { 202 // Print the operation and the function name. 203 auto funcName = 204 op->getAttrOfType<StringAttr>(::mlir::SymbolTable::getSymbolAttrName()) 205 .getValue(); 206 *p << op->getName() << " @" << funcName; 207 208 // Print the signature. 209 printSignature(p, op, argTypes, isVariadic, results); 210 211 // Print out function attributes, if present. 212 SmallVector<StringRef, 2> ignoredAttrs = { 213 ::mlir::SymbolTable::getSymbolAttrName(), getTypeAttrName()}; 214 215 // Ignore any argument attributes. 216 std::vector<SmallString<8>> argAttrStorage; 217 SmallString<8> argAttrName; 218 for (unsigned i = 0, e = argTypes.size(); i != e; ++i) 219 if (op->getAttr(getArgAttrName(i, argAttrName))) 220 argAttrStorage.emplace_back(argAttrName); 221 ignoredAttrs.append(argAttrStorage.begin(), argAttrStorage.end()); 222 223 auto attrs = op->getAttrs(); 224 if (attrs.size() > ignoredAttrs.size()) { 225 *p << "\n attributes "; 226 p->printOptionalAttrDict(attrs, ignoredAttrs); 227 } 228 229 // Print the body if this is not an external function. 230 Region &body = op->getRegion(0); 231 if (!body.empty()) 232 p->printRegion(body, /*printEntryBlockArgs=*/false, 233 /*printBlockTerminators=*/true); 234 }