github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/IR/FunctionSupport.cpp (about)

     1  //===- FunctionSupport.cpp - Utility types for function-like ops ----------===//
     2  //
     3  // Copyright 2019 The MLIR Authors.
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //   http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  // =============================================================================
    17  
#include "mlir/IR/FunctionSupport.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpImplementation.h"

#include <algorithm>
    21  
    22  using namespace mlir;
    23  
    24  static ParseResult
    25  parseArgumentList(OpAsmParser *parser, bool allowVariadic,
    26                    SmallVectorImpl<Type> &argTypes,
    27                    SmallVectorImpl<OpAsmParser::OperandType> &argNames,
    28                    SmallVectorImpl<SmallVector<NamedAttribute, 2>> &argAttrs,
    29                    bool &isVariadic) {
    30    if (parser->parseLParen())
    31      return failure();
    32  
    33    // The argument list either has to consistently have ssa-id's followed by
    34    // types, or just be a type list.  It isn't ok to sometimes have SSA ID's and
    35    // sometimes not.
    36    auto parseArgument = [&]() -> ParseResult {
    37      llvm::SMLoc loc = parser->getCurrentLocation();
    38  
    39      // Parse argument name if present.
    40      OpAsmParser::OperandType argument;
    41      Type argumentType;
    42      if (succeeded(parser->parseOptionalRegionArgument(argument)) &&
    43          !argument.name.empty()) {
    44        // Reject this if the preceding argument was missing a name.
    45        if (argNames.empty() && !argTypes.empty())
    46          return parser->emitError(loc,
    47                                   "expected type instead of SSA identifier");
    48        argNames.push_back(argument);
    49  
    50        if (parser->parseColonType(argumentType))
    51          return failure();
    52      } else if (allowVariadic && succeeded(parser->parseOptionalEllipsis())) {
    53        isVariadic = true;
    54        return success();
    55      } else if (!argNames.empty()) {
    56        // Reject this if the preceding argument had a name.
    57        return parser->emitError(loc, "expected SSA identifier");
    58      } else if (parser->parseType(argumentType)) {
    59        return failure();
    60      }
    61  
    62      // Add the argument type.
    63      argTypes.push_back(argumentType);
    64  
    65      // Parse any argument attributes.
    66      SmallVector<NamedAttribute, 2> attrs;
    67      if (parser->parseOptionalAttributeDict(attrs))
    68        return failure();
    69      argAttrs.push_back(attrs);
    70      return success();
    71    };
    72  
    73    // Parse the function arguments.
    74    if (parser->parseOptionalRParen()) {
    75      do {
    76        unsigned numTypedArguments = argTypes.size();
    77        if (parseArgument())
    78          return failure();
    79  
    80        llvm::SMLoc loc = parser->getCurrentLocation();
    81        if (argTypes.size() == numTypedArguments &&
    82            succeeded(parser->parseOptionalComma()))
    83          return parser->emitError(
    84              loc, "variadic arguments must be in the end of the argument list");
    85      } while (succeeded(parser->parseOptionalComma()));
    86      parser->parseRParen();
    87    }
    88  
    89    return success();
    90  }
    91  
    92  /// Parse a function signature, starting with a name and including the
    93  /// parameter list.
    94  static ParseResult parseFunctionSignature(
    95      OpAsmParser *parser, bool allowVariadic,
    96      SmallVectorImpl<OpAsmParser::OperandType> &argNames,
    97      SmallVectorImpl<Type> &argTypes,
    98      SmallVectorImpl<SmallVector<NamedAttribute, 2>> &argAttrs, bool &isVariadic,
    99      SmallVectorImpl<Type> &results) {
   100    if (parseArgumentList(parser, allowVariadic, argTypes, argNames, argAttrs,
   101                          isVariadic))
   102      return failure();
   103    // Parse the return types if present.
   104    return parser->parseOptionalArrowTypeList(results);
   105  }
   106  
   107  /// Parser implementation for function-like operations.  Uses `funcTypeBuilder`
   108  /// to construct the custom function type given lists of input and output types.
   109  ParseResult
   110  mlir::impl::parseFunctionLikeOp(OpAsmParser *parser, OperationState *result,
   111                                  bool allowVariadic,
   112                                  mlir::impl::FuncTypeBuilder funcTypeBuilder) {
   113    SmallVector<OpAsmParser::OperandType, 4> entryArgs;
   114    SmallVector<SmallVector<NamedAttribute, 2>, 4> argAttrs;
   115    SmallVector<Type, 4> argTypes;
   116    SmallVector<Type, 4> results;
   117    auto &builder = parser->getBuilder();
   118  
   119    // Parse the name as a symbol reference attribute.
   120    SymbolRefAttr nameAttr;
   121    if (parser->parseAttribute(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
   122                               result->attributes))
   123      return failure();
   124    // Convert the parsed function attr into a string attr.
   125    result->attributes.back().second = builder.getStringAttr(nameAttr.getValue());
   126  
   127    // Parse the function signature.
   128    auto signatureLocation = parser->getCurrentLocation();
   129    bool isVariadic = false;
   130    if (parseFunctionSignature(parser, allowVariadic, entryArgs, argTypes,
   131                               argAttrs, isVariadic, results))
   132      return failure();
   133  
   134    std::string errorMessage;
   135    if (auto type = funcTypeBuilder(builder, argTypes, results,
   136                                    impl::VariadicFlag(isVariadic), errorMessage))
   137      result->addAttribute(getTypeAttrName(), builder.getTypeAttr(type));
   138    else
   139      return parser->emitError(signatureLocation)
   140             << "failed to construct function type"
   141             << (errorMessage.empty() ? "" : ": ") << errorMessage;
   142  
   143    // If function attributes are present, parse them.
   144    if (succeeded(parser->parseOptionalKeyword("attributes")))
   145      if (parser->parseOptionalAttributeDict(result->attributes))
   146        return failure();
   147  
   148    // Add the attributes to the function arguments.
   149    SmallString<8> argAttrName;
   150    for (unsigned i = 0, e = argTypes.size(); i != e; ++i)
   151      if (!argAttrs[i].empty())
   152        result->addAttribute(getArgAttrName(i, argAttrName),
   153                             builder.getDictionaryAttr(argAttrs[i]));
   154  
   155    // Parse the optional function body.
   156    auto *body = result->addRegion();
   157    if (parser->parseOptionalRegion(*body, entryArgs,
   158                                    entryArgs.empty() ? llvm::ArrayRef<Type>()
   159                                                      : argTypes))
   160      return failure();
   161  
   162    return success();
   163  }
   164  
   165  /// Print the signature of the function-like operation `op`.  Assumes `op` has
   166  /// the FunctionLike trait and passed the verification.
   167  static void printSignature(OpAsmPrinter *p, Operation *op,
   168                             ArrayRef<Type> argTypes, bool isVariadic,
   169                             ArrayRef<Type> results) {
   170    Region &body = op->getRegion(0);
   171    bool isExternal = body.empty();
   172  
   173    *p << '(';
   174    for (unsigned i = 0, e = argTypes.size(); i < e; ++i) {
   175      if (i > 0)
   176        *p << ", ";
   177  
   178      if (!isExternal) {
   179        p->printOperand(body.front().getArgument(i));
   180        *p << ": ";
   181      }
   182  
   183      p->printType(argTypes[i]);
   184      p->printOptionalAttrDict(::mlir::impl::getArgAttrs(op, i));
   185    }
   186  
   187    if (isVariadic) {
   188      if (!argTypes.empty())
   189        *p << ", ";
   190      *p << "...";
   191    }
   192  
   193    *p << ')';
   194    p->printOptionalArrowTypeList(results);
   195  }
   196  
   197  /// Printer implementation for function-like operations.  Accepts lists of
   198  /// argument and result types to use while printing.
   199  void mlir::impl::printFunctionLikeOp(OpAsmPrinter *p, Operation *op,
   200                                       ArrayRef<Type> argTypes, bool isVariadic,
   201                                       ArrayRef<Type> results) {
   202    // Print the operation and the function name.
   203    auto funcName =
   204        op->getAttrOfType<StringAttr>(::mlir::SymbolTable::getSymbolAttrName())
   205            .getValue();
   206    *p << op->getName() << " @" << funcName;
   207  
   208    // Print the signature.
   209    printSignature(p, op, argTypes, isVariadic, results);
   210  
   211    // Print out function attributes, if present.
   212    SmallVector<StringRef, 2> ignoredAttrs = {
   213        ::mlir::SymbolTable::getSymbolAttrName(), getTypeAttrName()};
   214  
   215    // Ignore any argument attributes.
   216    std::vector<SmallString<8>> argAttrStorage;
   217    SmallString<8> argAttrName;
   218    for (unsigned i = 0, e = argTypes.size(); i != e; ++i)
   219      if (op->getAttr(getArgAttrName(i, argAttrName)))
   220        argAttrStorage.emplace_back(argAttrName);
   221    ignoredAttrs.append(argAttrStorage.begin(), argAttrStorage.end());
   222  
   223    auto attrs = op->getAttrs();
   224    if (attrs.size() > ignoredAttrs.size()) {
   225      *p << "\n  attributes ";
   226      p->printOptionalAttrDict(attrs, ignoredAttrs);
   227    }
   228  
   229    // Print the body if this is not an external function.
   230    Region &body = op->getRegion(0);
   231    if (!body.empty())
   232      p->printRegion(body, /*printEntryBlockArgs=*/false,
   233                     /*printBlockTerminators=*/true);
   234  }