github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Parser/Parser.cpp (about)

     1  //===- Parser.cpp - MLIR Parser Implementation ----------------------------===//
     2  //
     3  // Copyright 2019 The MLIR Authors.
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //   http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  // =============================================================================
    17  //
    18  // This file implements the parser for the MLIR textual form.
    19  //
    20  //===----------------------------------------------------------------------===//
    21  
    22  #include "mlir/Parser.h"
    23  #include "Lexer.h"
    24  #include "mlir/Analysis/Verifier.h"
    25  #include "mlir/IR/AffineExpr.h"
    26  #include "mlir/IR/AffineMap.h"
    27  #include "mlir/IR/Attributes.h"
    28  #include "mlir/IR/Builders.h"
    29  #include "mlir/IR/Dialect.h"
    30  #include "mlir/IR/IntegerSet.h"
    31  #include "mlir/IR/Location.h"
    32  #include "mlir/IR/MLIRContext.h"
    33  #include "mlir/IR/Module.h"
    34  #include "mlir/IR/OpImplementation.h"
    35  #include "mlir/IR/StandardTypes.h"
    36  #include "mlir/Support/STLExtras.h"
    37  #include "llvm/ADT/APInt.h"
    38  #include "llvm/ADT/DenseMap.h"
    39  #include "llvm/ADT/StringSet.h"
    40  #include "llvm/ADT/bit.h"
    41  #include "llvm/Support/MemoryBuffer.h"
    42  #include "llvm/Support/PrettyStackTrace.h"
    43  #include "llvm/Support/SMLoc.h"
    44  #include "llvm/Support/SourceMgr.h"
    45  #include <algorithm>
    46  using namespace mlir;
    47  using llvm::MemoryBuffer;
    48  using llvm::SMLoc;
    49  using llvm::SourceMgr;
    50  
    51  namespace {
    52  class Parser;
    53  
    54  //===----------------------------------------------------------------------===//
    55  // ParserState
    56  //===----------------------------------------------------------------------===//
    57  
    58  /// This class refers to all of the state maintained globally by the parser,
    59  /// such as the current lexer position etc. The Parser base class provides
    60  /// methods to access this.
    61  class ParserState {
    62  public:
    63    ParserState(const llvm::SourceMgr &sourceMgr, MLIRContext *ctx)
    64        : context(ctx), lex(sourceMgr, ctx), curToken(lex.lexToken()) {}
    65  
    66    // A map from attribute alias identifier to Attribute.
    67    llvm::StringMap<Attribute> attributeAliasDefinitions;
    68  
    69    // A map from type alias identifier to Type.
    70    llvm::StringMap<Type> typeAliasDefinitions;
    71  
    72  private:
    73    ParserState(const ParserState &) = delete;
    74    void operator=(const ParserState &) = delete;
    75  
    76    friend class Parser;
    77  
    78    // The context we're parsing into.
    79    MLIRContext *const context;
    80  
    81    // The lexer for the source file we're parsing.
    82    Lexer lex;
    83  
    84    // This is the next token that hasn't been consumed yet.
    85    Token curToken;
    86  };
    87  
    88  //===----------------------------------------------------------------------===//
    89  // Parser
    90  //===----------------------------------------------------------------------===//
    91  
    92  /// This class implement support for parsing global entities like types and
    93  /// shared entities like SSA names.  It is intended to be subclassed by
    94  /// specialized subparsers that include state, e.g. when a local symbol table.
    95  class Parser {
    96  public:
    97    Builder builder;
    98  
    99    Parser(ParserState &state) : builder(state.context), state(state) {}
   100  
   101    // Helper methods to get stuff from the parser-global state.
   102    ParserState &getState() const { return state; }
   103    MLIRContext *getContext() const { return state.context; }
   104    const llvm::SourceMgr &getSourceMgr() { return state.lex.getSourceMgr(); }
   105  
   106    /// Parse a comma-separated list of elements up until the specified end token.
   107    ParseResult
   108    parseCommaSeparatedListUntil(Token::Kind rightToken,
   109                                 const std::function<ParseResult()> &parseElement,
   110                                 bool allowEmptyList = true);
   111  
   112    /// Parse a comma separated list of elements that must have at least one entry
   113    /// in it.
   114    ParseResult
   115    parseCommaSeparatedList(const std::function<ParseResult()> &parseElement);
   116  
   117    ParseResult parsePrettyDialectSymbolName(StringRef &prettyName);
   118  
   119    // We have two forms of parsing methods - those that return a non-null
   120    // pointer on success, and those that return a ParseResult to indicate whether
   121    // they returned a failure.  The second class fills in by-reference arguments
   122    // as the results of their action.
   123  
   124    //===--------------------------------------------------------------------===//
   125    // Error Handling
   126    //===--------------------------------------------------------------------===//
   127  
   128    /// Emit an error and return failure.
   129    InFlightDiagnostic emitError(const Twine &message = {}) {
   130      return emitError(state.curToken.getLoc(), message);
   131    }
   132    InFlightDiagnostic emitError(SMLoc loc, const Twine &message = {});
   133  
   134    /// Encode the specified source location information into an attribute for
   135    /// attachment to the IR.
   136    Location getEncodedSourceLocation(llvm::SMLoc loc) {
   137      return state.lex.getEncodedSourceLocation(loc);
   138    }
   139  
   140    //===--------------------------------------------------------------------===//
   141    // Token Parsing
   142    //===--------------------------------------------------------------------===//
   143  
   144    /// Return the current token the parser is inspecting.
   145    const Token &getToken() const { return state.curToken; }
   146    StringRef getTokenSpelling() const { return state.curToken.getSpelling(); }
   147  
   148    /// If the current token has the specified kind, consume it and return true.
   149    /// If not, return false.
   150    bool consumeIf(Token::Kind kind) {
   151      if (state.curToken.isNot(kind))
   152        return false;
   153      consumeToken(kind);
   154      return true;
   155    }
   156  
   157    /// Advance the current lexer onto the next token.
   158    void consumeToken() {
   159      assert(state.curToken.isNot(Token::eof, Token::error) &&
   160             "shouldn't advance past EOF or errors");
   161      state.curToken = state.lex.lexToken();
   162    }
   163  
   164    /// Advance the current lexer onto the next token, asserting what the expected
   165    /// current token is.  This is preferred to the above method because it leads
   166    /// to more self-documenting code with better checking.
   167    void consumeToken(Token::Kind kind) {
   168      assert(state.curToken.is(kind) && "consumed an unexpected token");
   169      consumeToken();
   170    }
   171  
   172    /// Consume the specified token if present and return success.  On failure,
   173    /// output a diagnostic and return failure.
   174    ParseResult parseToken(Token::Kind expectedToken, const Twine &message);
   175  
   176    //===--------------------------------------------------------------------===//
   177    // Type Parsing
   178    //===--------------------------------------------------------------------===//
   179  
   180    ParseResult parseFunctionResultTypes(SmallVectorImpl<Type> &elements);
   181    ParseResult parseTypeListNoParens(SmallVectorImpl<Type> &elements);
   182    ParseResult parseTypeListParens(SmallVectorImpl<Type> &elements);
   183  
   184    /// Parse an arbitrary type.
   185    Type parseType();
   186  
   187    /// Parse a complex type.
   188    Type parseComplexType();
   189  
   190    /// Parse an extended type.
   191    Type parseExtendedType();
   192  
   193    /// Parse a function type.
   194    Type parseFunctionType();
   195  
   196    /// Parse a memref type.
   197    Type parseMemRefType();
   198  
   199    /// Parse a non function type.
   200    Type parseNonFunctionType();
   201  
   202    /// Parse a tensor type.
   203    Type parseTensorType();
   204  
   205    /// Parse a tuple type.
   206    Type parseTupleType();
   207  
   208    /// Parse a vector type.
   209    VectorType parseVectorType();
   210    ParseResult parseDimensionListRanked(SmallVectorImpl<int64_t> &dimensions,
   211                                         bool allowDynamic = true);
   212    ParseResult parseXInDimensionList();
   213  
   214    //===--------------------------------------------------------------------===//
   215    // Attribute Parsing
   216    //===--------------------------------------------------------------------===//
   217  
   218    /// Parse an arbitrary attribute with an optional type.
   219    Attribute parseAttribute(Type type = {});
   220  
   221    /// Parse an attribute dictionary.
   222    ParseResult parseAttributeDict(SmallVectorImpl<NamedAttribute> &attributes);
   223  
   224    /// Parse an extended attribute.
   225    Attribute parseExtendedAttr(Type type);
   226  
   227    /// Parse a float attribute.
   228    Attribute parseFloatAttr(Type type, bool isNegative);
   229  
   230    /// Parse a decimal or a hexadecimal literal, which can be either an integer
   231    /// or a float attribute.
   232    Attribute parseDecOrHexAttr(Type type, bool isNegative);
   233  
   234    /// Parse an opaque elements attribute.
   235    Attribute parseOpaqueElementsAttr();
   236  
   237    /// Parse a dense elements attribute.
   238    Attribute parseDenseElementsAttr();
   239    ShapedType parseElementsLiteralType();
   240  
   241    /// Parse a sparse elements attribute.
   242    Attribute parseSparseElementsAttr();
   243  
   244    //===--------------------------------------------------------------------===//
   245    // Location Parsing
   246    //===--------------------------------------------------------------------===//
   247  
   248    /// Parse an inline location.
   249    ParseResult parseLocation(LocationAttr &loc);
   250  
   251    /// Parse a raw location instance.
   252    ParseResult parseLocationInstance(LocationAttr &loc);
   253  
   254    /// Parse a callsite location instance.
   255    ParseResult parseCallSiteLocation(LocationAttr &loc);
   256  
   257    /// Parse a fused location instance.
   258    ParseResult parseFusedLocation(LocationAttr &loc);
   259  
   260    /// Parse a name or FileLineCol location instance.
   261    ParseResult parseNameOrFileLineColLocation(LocationAttr &loc);
   262  
   263    /// Parse an optional trailing location.
   264    ///
   265    ///   trailing-location     ::= location?
   266    ///
   267    template <typename Owner>
   268    ParseResult parseOptionalTrailingLocation(Owner *owner) {
   269      // If there is a 'loc' we parse a trailing location.
   270      if (!getToken().is(Token::kw_loc))
   271        return success();
   272  
   273      // Parse the location.
   274      LocationAttr directLoc;
   275      if (parseLocation(directLoc))
   276        return failure();
   277      owner->setLoc(directLoc);
   278      return success();
   279    }
   280  
   281    //===--------------------------------------------------------------------===//
   282    // Affine Parsing
   283    //===--------------------------------------------------------------------===//
   284  
   285    ParseResult parseAffineMapOrIntegerSetReference(AffineMap &map,
   286                                                    IntegerSet &set);
   287  
   288    /// Parse an AffineMap where the dim and symbol identifiers are SSA ids.
   289    ParseResult
   290    parseAffineMapOfSSAIds(AffineMap &map,
   291                           llvm::function_ref<ParseResult(bool)> parseElement);
   292  
   293  private:
   294    /// The Parser is subclassed and reinstantiated.  Do not add additional
   295    /// non-trivial state here, add it to the ParserState class.
   296    ParserState &state;
   297  };
   298  } // end anonymous namespace
   299  
   300  //===----------------------------------------------------------------------===//
   301  // Helper methods.
   302  //===----------------------------------------------------------------------===//
   303  
   304  /// Parse a comma separated list of elements that must have at least one entry
   305  /// in it.
   306  ParseResult Parser::parseCommaSeparatedList(
   307      const std::function<ParseResult()> &parseElement) {
   308    // Non-empty case starts with an element.
   309    if (parseElement())
   310      return failure();
   311  
   312    // Otherwise we have a list of comma separated elements.
   313    while (consumeIf(Token::comma)) {
   314      if (parseElement())
   315        return failure();
   316    }
   317    return success();
   318  }
   319  
   320  /// Parse a comma-separated list of elements, terminated with an arbitrary
   321  /// token.  This allows empty lists if allowEmptyList is true.
   322  ///
   323  ///   abstract-list ::= rightToken                  // if allowEmptyList == true
   324  ///   abstract-list ::= element (',' element)* rightToken
   325  ///
   326  ParseResult Parser::parseCommaSeparatedListUntil(
   327      Token::Kind rightToken, const std::function<ParseResult()> &parseElement,
   328      bool allowEmptyList) {
   329    // Handle the empty case.
   330    if (getToken().is(rightToken)) {
   331      if (!allowEmptyList)
   332        return emitError("expected list element");
   333      consumeToken(rightToken);
   334      return success();
   335    }
   336  
   337    if (parseCommaSeparatedList(parseElement) ||
   338        parseToken(rightToken, "expected ',' or '" +
   339                                   Token::getTokenSpelling(rightToken) + "'"))
   340      return failure();
   341  
   342    return success();
   343  }
   344  
   345  /// Parse the body of a pretty dialect symbol, which starts and ends with <>'s,
   346  /// and may be recursive.  Return with the 'prettyName' StringRef encompasing
   347  /// the entire pretty name.
   348  ///
   349  ///   pretty-dialect-sym-body ::= '<' pretty-dialect-sym-contents+ '>'
   350  ///   pretty-dialect-sym-contents ::= pretty-dialect-sym-body
   351  ///                                  | '(' pretty-dialect-sym-contents+ ')'
   352  ///                                  | '[' pretty-dialect-sym-contents+ ']'
   353  ///                                  | '{' pretty-dialect-sym-contents+ '}'
   354  ///                                  | '[^[<({>\])}\0]+'
   355  ///
   356  ParseResult Parser::parsePrettyDialectSymbolName(StringRef &prettyName) {
   357    // Pretty symbol names are a relatively unstructured format that contains a
   358    // series of properly nested punctuation, with anything else in the middle.
   359    // Scan ahead to find it and consume it if successful, otherwise emit an
   360    // error.
   361    auto *curPtr = getTokenSpelling().data();
   362  
   363    SmallVector<char, 8> nestedPunctuation;
   364  
   365    // Scan over the nested punctuation, bailing out on error and consuming until
   366    // we find the end.  We know that we're currently looking at the '<', so we
   367    // can go until we find the matching '>' character.
   368    assert(*curPtr == '<');
   369    do {
   370      char c = *curPtr++;
   371      switch (c) {
   372      case '\0':
   373        // This also handles the EOF case.
   374        return emitError("unexpected nul or EOF in pretty dialect name");
   375      case '<':
   376      case '[':
   377      case '(':
   378      case '{':
   379        nestedPunctuation.push_back(c);
   380        continue;
   381  
   382      case '-':
   383        // The sequence `->` is treated as special token.
   384        if (*curPtr == '>')
   385          ++curPtr;
   386        continue;
   387  
   388      case '>':
   389        if (nestedPunctuation.pop_back_val() != '<')
   390          return emitError("unbalanced '>' character in pretty dialect name");
   391        break;
   392      case ']':
   393        if (nestedPunctuation.pop_back_val() != '[')
   394          return emitError("unbalanced ']' character in pretty dialect name");
   395        break;
   396      case ')':
   397        if (nestedPunctuation.pop_back_val() != '(')
   398          return emitError("unbalanced ')' character in pretty dialect name");
   399        break;
   400      case '}':
   401        if (nestedPunctuation.pop_back_val() != '{')
   402          return emitError("unbalanced '}' character in pretty dialect name");
   403        break;
   404  
   405      default:
   406        continue;
   407      }
   408    } while (!nestedPunctuation.empty());
   409  
   410    // Ok, we succeeded, remember where we stopped, reset the lexer to know it is
   411    // consuming all this stuff, and return.
   412    state.lex.resetPointer(curPtr);
   413  
   414    unsigned length = curPtr - prettyName.begin();
   415    prettyName = StringRef(prettyName.begin(), length);
   416    consumeToken();
   417    return success();
   418  }
   419  
   420  /// Parse an extended dialect symbol.
   421  template <typename Symbol, typename SymbolAliasMap, typename CreateFn>
   422  static Symbol parseExtendedSymbol(Parser &p, Token::Kind identifierTok,
   423                                    SymbolAliasMap &aliases,
   424                                    CreateFn &&createSymbol) {
   425    // Parse the dialect namespace.
   426    StringRef identifier = p.getTokenSpelling().drop_front();
   427    auto loc = p.getToken().getLoc();
   428    p.consumeToken(identifierTok);
   429  
   430    // If there is no '<' token following this, and if the typename contains no
   431    // dot, then we are parsing a symbol alias.
   432    if (p.getToken().isNot(Token::less) && !identifier.contains('.')) {
   433      // Check for an alias for this type.
   434      auto aliasIt = aliases.find(identifier);
   435      if (aliasIt == aliases.end())
   436        return (p.emitError("undefined symbol alias id '" + identifier + "'"),
   437                nullptr);
   438      return aliasIt->second;
   439    }
   440  
   441    // Otherwise, we are parsing a dialect-specific symbol.  If the name contains
   442    // a dot, then this is the "pretty" form.  If not, it is the verbose form that
   443    // looks like <"...">.
   444    std::string symbolData;
   445    auto dialectName = identifier;
   446  
   447    // Handle the verbose form, where "identifier" is a simple dialect name.
   448    if (!identifier.contains('.')) {
   449      // Consume the '<'.
   450      if (p.parseToken(Token::less, "expected '<' in dialect type"))
   451        return nullptr;
   452  
   453      // Parse the symbol specific data.
   454      if (p.getToken().isNot(Token::string))
   455        return (p.emitError("expected string literal data in dialect symbol"),
   456                nullptr);
   457      symbolData = p.getToken().getStringValue();
   458      loc = p.getToken().getLoc();
   459      p.consumeToken(Token::string);
   460  
   461      // Consume the '>'.
   462      if (p.parseToken(Token::greater, "expected '>' in dialect symbol"))
   463        return nullptr;
   464    } else {
   465      // Ok, the dialect name is the part of the identifier before the dot, the
   466      // part after the dot is the dialect's symbol, or the start thereof.
   467      auto dotHalves = identifier.split('.');
   468      dialectName = dotHalves.first;
   469      auto prettyName = dotHalves.second;
   470  
   471      // If the dialect's symbol is followed immediately by a <, then lex the body
   472      // of it into prettyName.
   473      if (p.getToken().is(Token::less) &&
   474          prettyName.bytes_end() == p.getTokenSpelling().bytes_begin()) {
   475        if (p.parsePrettyDialectSymbolName(prettyName))
   476          return nullptr;
   477      }
   478  
   479      symbolData = prettyName.str();
   480    }
   481  
   482    // Call into the provided symbol construction function.
   483    auto encodedLoc = p.getEncodedSourceLocation(loc);
   484    return createSymbol(dialectName, symbolData, encodedLoc);
   485  }
   486  
   487  //===----------------------------------------------------------------------===//
   488  // Error Handling
   489  //===----------------------------------------------------------------------===//
   490  
   491  InFlightDiagnostic Parser::emitError(SMLoc loc, const Twine &message) {
   492    auto diag = mlir::emitError(getEncodedSourceLocation(loc), message);
   493  
   494    // If we hit a parse error in response to a lexer error, then the lexer
   495    // already reported the error.
   496    if (getToken().is(Token::error))
   497      diag.abandon();
   498    return diag;
   499  }
   500  
   501  //===----------------------------------------------------------------------===//
   502  // Token Parsing
   503  //===----------------------------------------------------------------------===//
   504  
   505  /// Consume the specified token if present and return success.  On failure,
   506  /// output a diagnostic and return failure.
   507  ParseResult Parser::parseToken(Token::Kind expectedToken,
   508                                 const Twine &message) {
   509    if (consumeIf(expectedToken))
   510      return success();
   511    return emitError(message);
   512  }
   513  
   514  //===----------------------------------------------------------------------===//
   515  // Type Parsing
   516  //===----------------------------------------------------------------------===//
   517  
   518  /// Parse an arbitrary type.
   519  ///
   520  ///   type ::= function-type
   521  ///          | non-function-type
   522  ///
   523  Type Parser::parseType() {
   524    if (getToken().is(Token::l_paren))
   525      return parseFunctionType();
   526    return parseNonFunctionType();
   527  }
   528  
   529  /// Parse a function result type.
   530  ///
   531  ///   function-result-type ::= type-list-parens
   532  ///                          | non-function-type
   533  ///
   534  ParseResult Parser::parseFunctionResultTypes(SmallVectorImpl<Type> &elements) {
   535    if (getToken().is(Token::l_paren))
   536      return parseTypeListParens(elements);
   537  
   538    Type t = parseNonFunctionType();
   539    if (!t)
   540      return failure();
   541    elements.push_back(t);
   542    return success();
   543  }
   544  
   545  /// Parse a list of types without an enclosing parenthesis.  The list must have
   546  /// at least one member.
   547  ///
   548  ///   type-list-no-parens ::=  type (`,` type)*
   549  ///
   550  ParseResult Parser::parseTypeListNoParens(SmallVectorImpl<Type> &elements) {
   551    auto parseElt = [&]() -> ParseResult {
   552      auto elt = parseType();
   553      elements.push_back(elt);
   554      return elt ? success() : failure();
   555    };
   556  
   557    return parseCommaSeparatedList(parseElt);
   558  }
   559  
   560  /// Parse a parenthesized list of types.
   561  ///
   562  ///   type-list-parens ::= `(` `)`
   563  ///                      | `(` type-list-no-parens `)`
   564  ///
   565  ParseResult Parser::parseTypeListParens(SmallVectorImpl<Type> &elements) {
   566    if (parseToken(Token::l_paren, "expected '('"))
   567      return failure();
   568  
   569    // Handle empty lists.
   570    if (getToken().is(Token::r_paren))
   571      return consumeToken(), success();
   572  
   573    if (parseTypeListNoParens(elements) ||
   574        parseToken(Token::r_paren, "expected ')'"))
   575      return failure();
   576    return success();
   577  }
   578  
   579  /// Parse a complex type.
   580  ///
   581  ///   complex-type ::= `complex` `<` type `>`
   582  ///
   583  Type Parser::parseComplexType() {
   584    consumeToken(Token::kw_complex);
   585  
   586    // Parse the '<'.
   587    if (parseToken(Token::less, "expected '<' in complex type"))
   588      return nullptr;
   589  
   590    auto typeLocation = getEncodedSourceLocation(getToken().getLoc());
   591    auto elementType = parseType();
   592    if (!elementType ||
   593        parseToken(Token::greater, "expected '>' in complex type"))
   594      return nullptr;
   595  
   596    return ComplexType::getChecked(elementType, typeLocation);
   597  }
   598  
   599  /// Parse an extended type.
   600  ///
   601  ///   extended-type ::= (dialect-type | type-alias)
   602  ///   dialect-type  ::= `!` dialect-namespace `<` `"` type-data `"` `>`
   603  ///   dialect-type  ::= `!` alias-name pretty-dialect-attribute-body?
   604  ///   type-alias    ::= `!` alias-name
   605  ///
   606  Type Parser::parseExtendedType() {
   607    return parseExtendedSymbol<Type>(
   608        *this, Token::exclamation_identifier, state.typeAliasDefinitions,
   609        [&](StringRef dialectName, StringRef symbolData, Location loc) -> Type {
   610          // If we found a registered dialect, then ask it to parse the type.
   611          if (auto *dialect = state.context->getRegisteredDialect(dialectName))
   612            return dialect->parseType(symbolData, loc);
   613  
   614          // Otherwise, form a new opaque type.
   615          return OpaqueType::getChecked(
   616              Identifier::get(dialectName, state.context), symbolData,
   617              state.context, loc);
   618        });
   619  }
   620  
   621  /// Parse a function type.
   622  ///
   623  ///   function-type ::= type-list-parens `->` function-result-type
   624  ///
   625  Type Parser::parseFunctionType() {
   626    assert(getToken().is(Token::l_paren));
   627  
   628    SmallVector<Type, 4> arguments, results;
   629    if (parseTypeListParens(arguments) ||
   630        parseToken(Token::arrow, "expected '->' in function type") ||
   631        parseFunctionResultTypes(results))
   632      return nullptr;
   633  
   634    return builder.getFunctionType(arguments, results);
   635  }
   636  
   637  /// Parse a memref type.
   638  ///
   639  ///   memref-type ::= `memref` `<` dimension-list-ranked type
   640  ///                   (`,` semi-affine-map-composition)? (`,` memory-space)? `>`
   641  ///
   642  ///   semi-affine-map-composition ::= (semi-affine-map `,` )* semi-affine-map
   643  ///   memory-space ::= integer-literal /* | TODO: address-space-id */
   644  ///
   645  Type Parser::parseMemRefType() {
   646    consumeToken(Token::kw_memref);
   647  
   648    if (parseToken(Token::less, "expected '<' in memref type"))
   649      return nullptr;
   650  
   651    SmallVector<int64_t, 4> dimensions;
   652    if (parseDimensionListRanked(dimensions))
   653      return nullptr;
   654  
   655    // Parse the element type.
   656    auto typeLoc = getToken().getLoc();
   657    auto elementType = parseType();
   658    if (!elementType)
   659      return nullptr;
   660  
   661    // Parse semi-affine-map-composition.
   662    SmallVector<AffineMap, 2> affineMapComposition;
   663    unsigned memorySpace = 0;
   664    bool parsedMemorySpace = false;
   665  
   666    auto parseElt = [&]() -> ParseResult {
   667      if (getToken().is(Token::integer)) {
   668        // Parse memory space.
   669        if (parsedMemorySpace)
   670          return emitError("multiple memory spaces specified in memref type");
   671        auto v = getToken().getUnsignedIntegerValue();
   672        if (!v.hasValue())
   673          return emitError("invalid memory space in memref type");
   674        memorySpace = v.getValue();
   675        consumeToken(Token::integer);
   676        parsedMemorySpace = true;
   677      } else {
   678        // Parse affine map.
   679        if (parsedMemorySpace)
   680          return emitError("affine map after memory space in memref type");
   681        auto affineMap = parseAttribute();
   682        if (!affineMap)
   683          return failure();
   684  
   685        // Verify that the parsed attribute is an affine map.
   686        if (auto affineMapAttr = affineMap.dyn_cast<AffineMapAttr>())
   687          affineMapComposition.push_back(affineMapAttr.getValue());
   688        else
   689          return emitError("expected affine map in memref type");
   690      }
   691      return success();
   692    };
   693  
   694    // Parse a list of mappings and address space if present.
   695    if (consumeIf(Token::comma)) {
   696      // Parse comma separated list of affine maps, followed by memory space.
   697      if (parseCommaSeparatedListUntil(Token::greater, parseElt,
   698                                       /*allowEmptyList=*/false)) {
   699        return nullptr;
   700      }
   701    } else {
   702      if (parseToken(Token::greater, "expected ',' or '>' in memref type"))
   703        return nullptr;
   704    }
   705  
   706    return MemRefType::getChecked(dimensions, elementType, affineMapComposition,
   707                                  memorySpace, getEncodedSourceLocation(typeLoc));
   708  }
   709  
   710  /// Parse any type except the function type.
   711  ///
   712  ///   non-function-type ::= integer-type
   713  ///                       | index-type
   714  ///                       | float-type
   715  ///                       | extended-type
   716  ///                       | vector-type
   717  ///                       | tensor-type
   718  ///                       | memref-type
   719  ///                       | complex-type
   720  ///                       | tuple-type
   721  ///                       | none-type
   722  ///
   723  ///   index-type ::= `index`
   724  ///   float-type ::= `f16` | `bf16` | `f32` | `f64`
   725  ///   none-type ::= `none`
   726  ///
   727  Type Parser::parseNonFunctionType() {
   728    switch (getToken().getKind()) {
   729    default:
   730      return (emitError("expected non-function type"), nullptr);
   731    case Token::kw_memref:
   732      return parseMemRefType();
   733    case Token::kw_tensor:
   734      return parseTensorType();
   735    case Token::kw_complex:
   736      return parseComplexType();
   737    case Token::kw_tuple:
   738      return parseTupleType();
   739    case Token::kw_vector:
   740      return parseVectorType();
   741    // integer-type
   742    case Token::inttype: {
   743      auto width = getToken().getIntTypeBitwidth();
   744      if (!width.hasValue())
   745        return (emitError("invalid integer width"), nullptr);
   746      auto loc = getEncodedSourceLocation(getToken().getLoc());
   747      consumeToken(Token::inttype);
   748      return IntegerType::getChecked(width.getValue(), builder.getContext(), loc);
   749    }
   750  
   751    // float-type
   752    case Token::kw_bf16:
   753      consumeToken(Token::kw_bf16);
   754      return builder.getBF16Type();
   755    case Token::kw_f16:
   756      consumeToken(Token::kw_f16);
   757      return builder.getF16Type();
   758    case Token::kw_f32:
   759      consumeToken(Token::kw_f32);
   760      return builder.getF32Type();
   761    case Token::kw_f64:
   762      consumeToken(Token::kw_f64);
   763      return builder.getF64Type();
   764  
   765    // index-type
   766    case Token::kw_index:
   767      consumeToken(Token::kw_index);
   768      return builder.getIndexType();
   769  
   770    // none-type
   771    case Token::kw_none:
   772      consumeToken(Token::kw_none);
   773      return builder.getNoneType();
   774  
   775    // extended type
   776    case Token::exclamation_identifier:
   777      return parseExtendedType();
   778    }
   779  }
   780  
   781  /// Parse a tensor type.
   782  ///
   783  ///   tensor-type ::= `tensor` `<` dimension-list type `>`
   784  ///   dimension-list ::= dimension-list-ranked | `*x`
   785  ///
   786  Type Parser::parseTensorType() {
   787    consumeToken(Token::kw_tensor);
   788  
   789    if (parseToken(Token::less, "expected '<' in tensor type"))
   790      return nullptr;
   791  
   792    bool isUnranked;
   793    SmallVector<int64_t, 4> dimensions;
   794  
   795    if (consumeIf(Token::star)) {
   796      // This is an unranked tensor type.
   797      isUnranked = true;
   798  
   799      if (parseXInDimensionList())
   800        return nullptr;
   801  
   802    } else {
   803      isUnranked = false;
   804      if (parseDimensionListRanked(dimensions))
   805        return nullptr;
   806    }
   807  
   808    // Parse the element type.
   809    auto typeLocation = getEncodedSourceLocation(getToken().getLoc());
   810    auto elementType = parseType();
   811    if (!elementType || parseToken(Token::greater, "expected '>' in tensor type"))
   812      return nullptr;
   813  
   814    if (isUnranked)
   815      return UnrankedTensorType::getChecked(elementType, typeLocation);
   816    return RankedTensorType::getChecked(dimensions, elementType, typeLocation);
   817  }
   818  
   819  /// Parse a tuple type.
   820  ///
   821  ///   tuple-type ::= `tuple` `<` (type (`,` type)*)? `>`
   822  ///
   823  Type Parser::parseTupleType() {
   824    consumeToken(Token::kw_tuple);
   825  
   826    // Parse the '<'.
   827    if (parseToken(Token::less, "expected '<' in tuple type"))
   828      return nullptr;
   829  
   830    // Check for an empty tuple by directly parsing '>'.
   831    if (consumeIf(Token::greater))
   832      return TupleType::get(getContext());
   833  
   834    // Parse the element types and the '>'.
   835    SmallVector<Type, 4> types;
   836    if (parseTypeListNoParens(types) ||
   837        parseToken(Token::greater, "expected '>' in tuple type"))
   838      return nullptr;
   839  
   840    return TupleType::get(types, getContext());
   841  }
   842  
   843  /// Parse a vector type.
   844  ///
   845  ///   vector-type ::= `vector` `<` non-empty-static-dimension-list type `>`
   846  ///   non-empty-static-dimension-list ::= decimal-literal `x`
   847  ///                                       static-dimension-list
   848  ///   static-dimension-list ::= (decimal-literal `x`)*
   849  ///
   850  VectorType Parser::parseVectorType() {
   851    consumeToken(Token::kw_vector);
   852  
   853    if (parseToken(Token::less, "expected '<' in vector type"))
   854      return nullptr;
   855  
   856    SmallVector<int64_t, 4> dimensions;
   857    if (parseDimensionListRanked(dimensions, /*allowDynamic=*/false))
   858      return nullptr;
   859    if (dimensions.empty())
   860      return (emitError("expected dimension size in vector type"), nullptr);
   861  
   862    // Parse the element type.
   863    auto typeLoc = getToken().getLoc();
   864    auto elementType = parseType();
   865    if (!elementType || parseToken(Token::greater, "expected '>' in vector type"))
   866      return nullptr;
   867  
   868    return VectorType::getChecked(dimensions, elementType,
   869                                  getEncodedSourceLocation(typeLoc));
   870  }
   871  
   872  /// Parse a dimension list of a tensor or memref type.  This populates the
   873  /// dimension list, using -1 for the `?` dimensions if `allowDynamic` is set and
   874  /// errors out on `?` otherwise.
   875  ///
   876  ///   dimension-list-ranked ::= (dimension `x`)*
   877  ///   dimension ::= `?` | decimal-literal
   878  ///
   879  /// When `allowDynamic` is not set, this is used to parse:
   880  ///
   881  ///   static-dimension-list ::= (decimal-literal `x`)*
   882  ParseResult
   883  Parser::parseDimensionListRanked(SmallVectorImpl<int64_t> &dimensions,
   884                                   bool allowDynamic) {
   885    while (getToken().isAny(Token::integer, Token::question)) {
   886      if (consumeIf(Token::question)) {
   887        if (!allowDynamic)
   888          return emitError("expected static shape");
   889        dimensions.push_back(-1);
   890      } else {
   891        // Hexadecimal integer literals (starting with `0x`) are not allowed in
   892        // aggregate type declarations.  Therefore, `0xf32` should be processed as
   893        // a sequence of separate elements `0`, `x`, `f32`.
   894        if (getTokenSpelling().size() > 1 && getTokenSpelling()[1] == 'x') {
   895          // We can get here only if the token is an integer literal.  Hexadecimal
   896          // integer literals can only start with `0x` (`1x` wouldn't lex as a
   897          // literal, just `1` would, at which point we don't get into this
   898          // branch).
   899          assert(getTokenSpelling()[0] == '0' && "invalid integer literal");
   900          dimensions.push_back(0);
   901          state.lex.resetPointer(getTokenSpelling().data() + 1);
   902          consumeToken();
   903        } else {
   904          // Make sure this integer value is in bound and valid.
   905          auto dimension = getToken().getUnsignedIntegerValue();
   906          if (!dimension.hasValue())
   907            return emitError("invalid dimension");
   908          dimensions.push_back((int64_t)dimension.getValue());
   909          consumeToken(Token::integer);
   910        }
   911      }
   912  
   913      // Make sure we have an 'x' or something like 'xbf32'.
   914      if (parseXInDimensionList())
   915        return failure();
   916    }
   917  
   918    return success();
   919  }
   920  
   921  /// Parse an 'x' token in a dimension list, handling the case where the x is
   922  /// juxtaposed with an element type, as in "xf32", leaving the "f32" as the next
   923  /// token.
   924  ParseResult Parser::parseXInDimensionList() {
   925    if (getToken().isNot(Token::bare_identifier) || getTokenSpelling()[0] != 'x')
   926      return emitError("expected 'x' in dimension list");
   927  
   928    // If we had a prefix of 'x', lex the next token immediately after the 'x'.
   929    if (getTokenSpelling().size() != 1)
   930      state.lex.resetPointer(getTokenSpelling().data() + 1);
   931  
   932    // Consume the 'x'.
   933    consumeToken(Token::bare_identifier);
   934  
   935    return success();
   936  }
   937  
   938  //===----------------------------------------------------------------------===//
   939  // Attribute parsing.
   940  //===----------------------------------------------------------------------===//
   941  
   942  /// Parse an arbitrary attribute.
   943  ///
   944  ///  attribute-value ::= `unit`
   945  ///                    | bool-literal
   946  ///                    | integer-literal (`:` (index-type | integer-type))?
   947  ///                    | float-literal (`:` float-type)?
   948  ///                    | string-literal (`:` type)?
   949  ///                    | type
   950  ///                    | `[` (attribute-value (`,` attribute-value)*)? `]`
   951  ///                    | `{` (attribute-entry (`,` attribute-entry)*)? `}`
   952  ///                    | symbol-ref-id
   953  ///                    | `dense` `<` attribute-value `>` `:`
   954  ///                      (tensor-type | vector-type)
   955  ///                    | `sparse` `<` attribute-value `,` attribute-value `>`
   956  ///                      `:` (tensor-type | vector-type)
   957  ///                    | `opaque` `<` dialect-namespace  `,` hex-string-literal
   958  ///                      `>` `:` (tensor-type | vector-type)
   959  ///                    | extended-attribute
   960  ///
   961  Attribute Parser::parseAttribute(Type type) {
   962    switch (getToken().getKind()) {
   963    // Parse an AffineMap or IntegerSet attribute.
   964    case Token::l_paren: {
   965      // Try to parse an affine map or an integer set reference.
   966      AffineMap map;
   967      IntegerSet set;
   968      if (parseAffineMapOrIntegerSetReference(map, set))
   969        return nullptr;
   970      if (map)
   971        return builder.getAffineMapAttr(map);
   972      assert(set);
   973      return builder.getIntegerSetAttr(set);
   974    }
   975  
   976    // Parse an array attribute.
   977    case Token::l_square: {
   978      consumeToken(Token::l_square);
   979  
   980      SmallVector<Attribute, 4> elements;
   981      auto parseElt = [&]() -> ParseResult {
   982        elements.push_back(parseAttribute());
   983        return elements.back() ? success() : failure();
   984      };
   985  
   986      if (parseCommaSeparatedListUntil(Token::r_square, parseElt))
   987        return nullptr;
   988      return builder.getArrayAttr(elements);
   989    }
   990  
   991    // Parse a boolean attribute.
   992    case Token::kw_false:
   993      consumeToken(Token::kw_false);
   994      return builder.getBoolAttr(false);
   995    case Token::kw_true:
   996      consumeToken(Token::kw_true);
   997      return builder.getBoolAttr(true);
   998  
   999    // Parse a dense elements attribute.
  1000    case Token::kw_dense:
  1001      return parseDenseElementsAttr();
  1002  
  1003    // Parse a dictionary attribute.
  1004    case Token::l_brace: {
  1005      SmallVector<NamedAttribute, 4> elements;
  1006      if (parseAttributeDict(elements))
  1007        return nullptr;
  1008      return builder.getDictionaryAttr(elements);
  1009    }
  1010  
  1011    // Parse an extended attribute, i.e. alias or dialect attribute.
  1012    case Token::hash_identifier:
  1013      return parseExtendedAttr(type);
  1014  
  1015    // Parse floating point and integer attributes.
  1016    case Token::floatliteral:
  1017      return parseFloatAttr(type, /*isNegative=*/false);
  1018    case Token::integer:
  1019      return parseDecOrHexAttr(type, /*isNegative=*/false);
  1020    case Token::minus: {
  1021      consumeToken(Token::minus);
  1022      if (getToken().is(Token::integer))
  1023        return parseDecOrHexAttr(type, /*isNegative=*/true);
  1024      if (getToken().is(Token::floatliteral))
  1025        return parseFloatAttr(type, /*isNegative=*/true);
  1026  
  1027      return (emitError("expected constant integer or floating point value"),
  1028              nullptr);
  1029    }
  1030  
  1031    // Parse a location attribute.
  1032    case Token::kw_loc: {
  1033      LocationAttr attr;
  1034      return failed(parseLocation(attr)) ? Attribute() : attr;
  1035    }
  1036  
  1037    // Parse an opaque elements attribute.
  1038    case Token::kw_opaque:
  1039      return parseOpaqueElementsAttr();
  1040  
  1041    // Parse a sparse elements attribute.
  1042    case Token::kw_sparse:
  1043      return parseSparseElementsAttr();
  1044  
  1045    // Parse a string attribute.
  1046    case Token::string: {
  1047      auto val = getToken().getStringValue();
  1048      consumeToken(Token::string);
  1049      // Parse the optional trailing colon type if one wasn't explicitly provided.
  1050      if (!type && consumeIf(Token::colon) && !(type = parseType()))
  1051        return Attribute();
  1052  
  1053      return type ? StringAttr::get(val, type)
  1054                  : StringAttr::get(val, getContext());
  1055    }
  1056  
  1057    // Parse a symbol reference attribute.
  1058    case Token::at_identifier: {
  1059      auto nameStr = getTokenSpelling();
  1060      consumeToken(Token::at_identifier);
  1061      return builder.getSymbolRefAttr(nameStr.drop_front());
  1062    }
  1063  
  1064    // Parse a 'unit' attribute.
  1065    case Token::kw_unit:
  1066      consumeToken(Token::kw_unit);
  1067      return builder.getUnitAttr();
  1068  
  1069    default:
  1070      // Parse a type attribute.
  1071      if (Type type = parseType())
  1072        return builder.getTypeAttr(type);
  1073      return nullptr;
  1074    }
  1075  }
  1076  
  1077  /// Attribute dictionary.
  1078  ///
  1079  ///   attribute-dict ::= `{` `}`
  1080  ///                    | `{` attribute-entry (`,` attribute-entry)* `}`
  1081  ///   attribute-entry ::= bare-id `=` attribute-value
  1082  ///
  1083  ParseResult
  1084  Parser::parseAttributeDict(SmallVectorImpl<NamedAttribute> &attributes) {
  1085    if (!consumeIf(Token::l_brace))
  1086      return failure();
  1087  
  1088    auto parseElt = [&]() -> ParseResult {
  1089      // We allow keywords as attribute names.
  1090      if (getToken().isNot(Token::bare_identifier, Token::inttype) &&
  1091          !getToken().isKeyword())
  1092        return emitError("expected attribute name");
  1093      Identifier nameId = builder.getIdentifier(getTokenSpelling());
  1094      consumeToken();
  1095  
  1096      // Try to parse the '=' for the attribute value.
  1097      if (!consumeIf(Token::equal)) {
  1098        // If there is no '=', we treat this as a unit attribute.
  1099        attributes.push_back({nameId, builder.getUnitAttr()});
  1100        return success();
  1101      }
  1102  
  1103      auto attr = parseAttribute();
  1104      if (!attr)
  1105        return failure();
  1106  
  1107      attributes.push_back({nameId, attr});
  1108      return success();
  1109    };
  1110  
  1111    if (parseCommaSeparatedListUntil(Token::r_brace, parseElt))
  1112      return failure();
  1113  
  1114    return success();
  1115  }
  1116  
  1117  /// Parse an extended attribute.
  1118  ///
  1119  ///   extended-attribute ::= (dialect-attribute | attribute-alias)
  1120  ///   dialect-attribute  ::= `#` dialect-namespace `<` `"` attr-data `"` `>`
  1121  ///   dialect-attribute  ::= `#` alias-name pretty-dialect-sym-body?
  1122  ///   attribute-alias    ::= `#` alias-name
  1123  ///
  1124  Attribute Parser::parseExtendedAttr(Type type) {
  1125    Attribute attr = parseExtendedSymbol<Attribute>(
  1126        *this, Token::hash_identifier, state.attributeAliasDefinitions,
  1127        [&](StringRef dialectName, StringRef symbolData,
  1128            Location loc) -> Attribute {
  1129          // Parse an optional trailing colon type.
  1130          Type attrType = type;
  1131          if (consumeIf(Token::colon) && !(attrType = parseType()))
  1132            return Attribute();
  1133  
  1134          // If we found a registered dialect, then ask it to parse the attribute.
  1135          if (auto *dialect = state.context->getRegisteredDialect(dialectName))
  1136            return dialect->parseAttribute(symbolData, attrType, loc);
  1137  
  1138          // Otherwise, form a new opaque attribute.
  1139          return OpaqueAttr::getChecked(
  1140              Identifier::get(dialectName, state.context), symbolData,
  1141              attrType ? attrType : NoneType::get(state.context), loc);
  1142        });
  1143  
  1144    // Ensure that the attribute has the same type as requested.
  1145    if (attr && type && attr.getType() != type) {
  1146      emitError("attribute type different than expected: expected ")
  1147          << type << ", but got " << attr.getType();
  1148      return nullptr;
  1149    }
  1150    return attr;
  1151  }
  1152  
  1153  /// Parse a float attribute.
  1154  Attribute Parser::parseFloatAttr(Type type, bool isNegative) {
  1155    auto val = getToken().getFloatingPointValue();
  1156    if (!val.hasValue())
  1157      return (emitError("floating point value too large for attribute"), nullptr);
  1158    consumeToken(Token::floatliteral);
  1159    if (!type) {
  1160      // Default to F64 when no type is specified.
  1161      if (!consumeIf(Token::colon))
  1162        type = builder.getF64Type();
  1163      else if (!(type = parseType()))
  1164        return nullptr;
  1165    }
  1166    if (!type.isa<FloatType>())
  1167      return (emitError("floating point value not valid for specified type"),
  1168              nullptr);
  1169    return FloatAttr::get(type, isNegative ? -val.getValue() : val.getValue());
  1170  }
  1171  
  1172  /// Construct a float attribute bitwise equivalent to the integer literal.
  1173  static FloatAttr buildHexadecimalFloatLiteral(Parser *p, FloatType type,
  1174                                                uint64_t value) {
  1175    int width = type.getIntOrFloatBitWidth();
  1176    APInt apInt(width, value);
  1177    if (apInt != value) {
  1178      p->emitError("hexadecimal float constant out of range for type");
  1179      return nullptr;
  1180    }
  1181    APFloat apFloat(type.getFloatSemantics(), apInt);
  1182    return p->builder.getFloatAttr(type, apFloat);
  1183  }
  1184  
  1185  /// Parse a decimal or a hexadecimal literal, which can be either an integer
  1186  /// or a float attribute.
  1187  Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) {
  1188    auto val = getToken().getUInt64IntegerValue();
  1189    if (!val.hasValue())
  1190      return (emitError("integer constant out of range for attribute"), nullptr);
  1191  
  1192    // Remember if the literal is hexadecimal.
  1193    StringRef spelling = getToken().getSpelling();
  1194    bool isHex = spelling.size() > 1 && spelling[1] == 'x';
  1195  
  1196    consumeToken(Token::integer);
  1197    if (!type) {
  1198      // Default to i64 if not type is specified.
  1199      if (!consumeIf(Token::colon))
  1200        type = builder.getIntegerType(64);
  1201      else if (!(type = parseType()))
  1202        return nullptr;
  1203    }
  1204  
  1205    // Hexadecimal representation of float literals is not supported for bfloat16.
  1206    // When supported, the literal should be unsigned.
  1207    auto floatType = type.dyn_cast<FloatType>();
  1208    if (floatType && !type.isBF16()) {
  1209      if (isNegative) {
  1210        emitError("hexadecimal float literal should not have a leading minus");
  1211        return nullptr;
  1212      }
  1213      if (!isHex) {
  1214        emitError("unexpected decimal integer literal for a float attribute")
  1215                .attachNote()
  1216            << "add a trailing dot to make the literal a float";
  1217        return nullptr;
  1218      }
  1219  
  1220      // Construct a float attribute bitwise equivalent to the integer literal.
  1221      return buildHexadecimalFloatLiteral(this, floatType, *val);
  1222    }
  1223  
  1224    if (!type.isIntOrIndex())
  1225      return (emitError("integer literal not valid for specified type"), nullptr);
  1226  
  1227    // Parse the integer literal.
  1228    int width = type.isIndex() ? 64 : type.getIntOrFloatBitWidth();
  1229    APInt apInt(width, *val, isNegative);
  1230    if (apInt != *val)
  1231      return (emitError("integer constant out of range for attribute"), nullptr);
  1232  
  1233    // Otherwise construct an integer attribute.
  1234    if (isNegative ? (int64_t)-val.getValue() >= 0 : (int64_t)val.getValue() < 0)
  1235      return (emitError("integer constant out of range for attribute"), nullptr);
  1236  
  1237    return builder.getIntegerAttr(type, isNegative ? -apInt : apInt);
  1238  }
  1239  
  1240  /// Parse an opaque elements attribute.
  1241  Attribute Parser::parseOpaqueElementsAttr() {
  1242    consumeToken(Token::kw_opaque);
  1243    if (parseToken(Token::less, "expected '<' after 'opaque'"))
  1244      return nullptr;
  1245  
  1246    if (getToken().isNot(Token::string))
  1247      return (emitError("expected dialect namespace"), nullptr);
  1248  
  1249    auto name = getToken().getStringValue();
  1250    auto *dialect = builder.getContext()->getRegisteredDialect(name);
  1251    // TODO(shpeisman): Allow for having an unknown dialect on an opaque
  1252    // attribute. Otherwise, it can't be roundtripped without having the dialect
  1253    // registered.
  1254    if (!dialect)
  1255      return (emitError("no registered dialect with namespace '" + name + "'"),
  1256              nullptr);
  1257  
  1258    consumeToken(Token::string);
  1259    if (parseToken(Token::comma, "expected ','"))
  1260      return nullptr;
  1261  
  1262    if (getToken().getKind() != Token::string)
  1263      return (emitError("opaque string should start with '0x'"), nullptr);
  1264  
  1265    auto val = getToken().getStringValue();
  1266    if (val.size() < 2 || val[0] != '0' || val[1] != 'x')
  1267      return (emitError("opaque string should start with '0x'"), nullptr);
  1268  
  1269    val = val.substr(2);
  1270    if (!llvm::all_of(val, llvm::isHexDigit))
  1271      return (emitError("opaque string only contains hex digits"), nullptr);
  1272  
  1273    consumeToken(Token::string);
  1274    if (parseToken(Token::greater, "expected '>'") ||
  1275        parseToken(Token::colon, "expected ':'"))
  1276      return nullptr;
  1277  
  1278    auto type = parseElementsLiteralType();
  1279    if (!type)
  1280      return nullptr;
  1281  
  1282    return builder.getOpaqueElementsAttr(dialect, type, llvm::fromHex(val));
  1283  }
  1284  
  1285  namespace {
  1286  class TensorLiteralParser {
  1287  public:
  1288    TensorLiteralParser(Parser &p) : p(p) {}
  1289  
  1290    ParseResult parse() {
  1291      if (p.getToken().is(Token::l_square))
  1292        return parseList(shape);
  1293      return parseElement();
  1294    }
  1295  
  1296    /// Build a dense attribute instance with the parsed elements and the given
  1297    /// shaped type.
  1298    DenseElementsAttr getAttr(llvm::SMLoc loc, ShapedType type);
  1299  
  1300    ArrayRef<int64_t> getShape() const { return shape; }
  1301  
  1302  private:
  1303    enum class ElementKind { Boolean, Integer, Float };
  1304  
  1305    /// Return a string to represent the given element kind.
  1306    const char *getElementKindStr(ElementKind kind) {
  1307      switch (kind) {
  1308      case ElementKind::Boolean:
  1309        return "'boolean'";
  1310      case ElementKind::Integer:
  1311        return "'integer'";
  1312      case ElementKind::Float:
  1313        return "'float'";
  1314      }
  1315      llvm_unreachable("unknown element kind");
  1316    }
  1317  
  1318    /// Build a Dense Integer attribute for the given type.
  1319    DenseElementsAttr getIntAttr(llvm::SMLoc loc, ShapedType type,
  1320                                 IntegerType eltTy);
  1321  
  1322    /// Build a Dense Float attribute for the given type.
  1323    DenseElementsAttr getFloatAttr(llvm::SMLoc loc, ShapedType type,
  1324                                   FloatType eltTy);
  1325  
  1326    /// Parse a single element, returning failure if it isn't a valid element
  1327    /// literal. For example:
  1328    /// parseElement(1) -> Success, 1
  1329    /// parseElement([1]) -> Failure
  1330    ParseResult parseElement();
  1331  
  1332    /// Parse a list of either lists or elements, returning the dimensions of the
  1333    /// parsed sub-tensors in dims. For example:
  1334    ///   parseList([1, 2, 3]) -> Success, [3]
  1335    ///   parseList([[1, 2], [3, 4]]) -> Success, [2, 2]
  1336    ///   parseList([[1, 2], 3]) -> Failure
  1337    ///   parseList([[1, [2, 3]], [4, [5]]]) -> Failure
  1338    ParseResult parseList(llvm::SmallVectorImpl<int64_t> &dims);
  1339  
  1340    Parser &p;
  1341  
  1342    /// The shape inferred from the parsed elements.
  1343    SmallVector<int64_t, 4> shape;
  1344  
  1345    /// Storage used when parsing elements, this is a pair of <is_negated, token>.
  1346    std::vector<std::pair<bool, Token>> storage;
  1347  
  1348    /// A flag that indicates the type of elements that have been parsed.
  1349    llvm::Optional<ElementKind> knownEltKind;
  1350  };
  1351  } // namespace
  1352  
  1353  /// Build a dense attribute instance with the parsed elements and the given
  1354  /// shaped type.
  1355  DenseElementsAttr TensorLiteralParser::getAttr(llvm::SMLoc loc,
  1356                                                 ShapedType type) {
  1357    // Check that the parsed storage size has the same number of elements to the
  1358    // type, or is a known splat.
  1359    if (!shape.empty() && getShape() != type.getShape()) {
  1360      p.emitError(loc) << "inferred shape of elements literal ([" << getShape()
  1361                       << "]) does not match type ([" << type.getShape() << "])";
  1362      return nullptr;
  1363    }
  1364  
  1365    // If the type is an integer, build a set of APInt values from the storage
  1366    // with the correct bitwidth.
  1367    if (auto intTy = type.getElementType().dyn_cast<IntegerType>())
  1368      return getIntAttr(loc, type, intTy);
  1369  
  1370    // Otherwise, this must be a floating point type.
  1371    auto floatTy = type.getElementType().dyn_cast<FloatType>();
  1372    if (!floatTy) {
  1373      p.emitError(loc) << "expected floating-point or integer element type, got "
  1374                       << type.getElementType();
  1375      return nullptr;
  1376    }
  1377    return getFloatAttr(loc, type, floatTy);
  1378  }
  1379  
  1380  /// Build a Dense Integer attribute for the given type.
  1381  DenseElementsAttr TensorLiteralParser::getIntAttr(llvm::SMLoc loc,
  1382                                                    ShapedType type,
  1383                                                    IntegerType eltTy) {
  1384    std::vector<APInt> intElements;
  1385    intElements.reserve(storage.size());
  1386    for (const auto &signAndToken : storage) {
  1387      bool isNegative = signAndToken.first;
  1388      const Token &token = signAndToken.second;
  1389  
  1390      // Check to see if floating point values were parsed.
  1391      if (token.is(Token::floatliteral)) {
  1392        p.emitError() << "expected integer elements, but parsed floating-point";
  1393        return nullptr;
  1394      }
  1395  
  1396      assert(token.isAny(Token::integer, Token::kw_true, Token::kw_false) &&
  1397             "unexpected token type");
  1398      if (token.isAny(Token::kw_true, Token::kw_false)) {
  1399        if (!eltTy.isInteger(1))
  1400          p.emitError() << "expected i1 type for 'true' or 'false' values";
  1401        APInt apInt(eltTy.getWidth(), token.is(Token::kw_true),
  1402                    /*isSigned=*/false);
  1403        intElements.push_back(apInt);
  1404        continue;
  1405      }
  1406  
  1407      // Create APInt values for each element with the correct bitwidth.
  1408      auto val = token.getUInt64IntegerValue();
  1409      if (!val.hasValue() || (isNegative ? (int64_t)-val.getValue() >= 0
  1410                                         : (int64_t)val.getValue() < 0)) {
  1411        p.emitError(token.getLoc(),
  1412                    "integer constant out of range for attribute");
  1413        return nullptr;
  1414      }
  1415      APInt apInt(eltTy.getWidth(), val.getValue(), isNegative);
  1416      if (apInt != val.getValue())
  1417        return (p.emitError("integer constant out of range for type"), nullptr);
  1418      intElements.push_back(isNegative ? -apInt : apInt);
  1419    }
  1420  
  1421    return DenseElementsAttr::get(type, intElements);
  1422  }
  1423  
  1424  /// Build a Dense Float attribute for the given type.
  1425  DenseElementsAttr TensorLiteralParser::getFloatAttr(llvm::SMLoc loc,
  1426                                                      ShapedType type,
  1427                                                      FloatType eltTy) {
  1428    std::vector<Attribute> floatValues;
  1429    floatValues.reserve(storage.size());
  1430    for (const auto &signAndToken : storage) {
  1431      bool isNegative = signAndToken.first;
  1432      const Token &token = signAndToken.second;
  1433  
  1434      // Handle hexadecimal float literals.
  1435      if (token.is(Token::integer) && token.getSpelling().startswith("0x")) {
  1436        if (isNegative) {
  1437          p.emitError(token.getLoc())
  1438              << "hexadecimal float literal should not have a leading minus";
  1439          return nullptr;
  1440        }
  1441        auto val = token.getUInt64IntegerValue();
  1442        if (!val.hasValue()) {
  1443          p.emitError("hexadecimal float constant out of range for attribute");
  1444          return nullptr;
  1445        }
  1446        FloatAttr attr = buildHexadecimalFloatLiteral(&p, eltTy, *val);
  1447        if (!attr)
  1448          return nullptr;
  1449        floatValues.push_back(attr);
  1450        continue;
  1451      }
  1452  
  1453      // Check to see if any decimal integers or booleans were parsed.
  1454      if (!token.is(Token::floatliteral)) {
  1455        p.emitError() << "expected floating-point elements, but parsed integer";
  1456        return nullptr;
  1457      }
  1458  
  1459      // Build the float values from tokens.
  1460      auto val = token.getFloatingPointValue();
  1461      if (!val.hasValue()) {
  1462        p.emitError("floating point value too large for attribute");
  1463        return nullptr;
  1464      }
  1465      floatValues.push_back(FloatAttr::get(eltTy, isNegative ? -*val : *val));
  1466    }
  1467  
  1468    return DenseElementsAttr::get(type, floatValues);
  1469  }
  1470  
  1471  ParseResult TensorLiteralParser::parseElement() {
  1472    switch (p.getToken().getKind()) {
  1473    // Parse a boolean element.
  1474    case Token::kw_true:
  1475    case Token::kw_false:
  1476    case Token::floatliteral:
  1477    case Token::integer:
  1478      storage.emplace_back(/*isNegative=*/false, p.getToken());
  1479      p.consumeToken();
  1480      break;
  1481  
  1482    // Parse a signed integer or a negative floating-point element.
  1483    case Token::minus:
  1484      p.consumeToken(Token::minus);
  1485      if (!p.getToken().isAny(Token::floatliteral, Token::integer))
  1486        return p.emitError("expected integer or floating point literal");
  1487      storage.emplace_back(/*isNegative=*/true, p.getToken());
  1488      p.consumeToken();
  1489      break;
  1490  
  1491    default:
  1492      return p.emitError("expected element literal of primitive type");
  1493    }
  1494  
  1495    return success();
  1496  }
  1497  
  1498  /// Parse a list of either lists or elements, returning the dimensions of the
  1499  /// parsed sub-tensors in dims. For example:
  1500  ///   parseList([1, 2, 3]) -> Success, [3]
  1501  ///   parseList([[1, 2], [3, 4]]) -> Success, [2, 2]
  1502  ///   parseList([[1, 2], 3]) -> Failure
  1503  ///   parseList([[1, [2, 3]], [4, [5]]]) -> Failure
  1504  ParseResult
  1505  TensorLiteralParser::parseList(llvm::SmallVectorImpl<int64_t> &dims) {
  1506    p.consumeToken(Token::l_square);
  1507  
  1508    auto checkDims =
  1509        [&](const llvm::SmallVectorImpl<int64_t> &prevDims,
  1510            const llvm::SmallVectorImpl<int64_t> &newDims) -> ParseResult {
  1511      if (prevDims == newDims)
  1512        return success();
  1513      return p.emitError("tensor literal is invalid; ranks are not consistent "
  1514                         "between elements");
  1515    };
  1516  
  1517    bool first = true;
  1518    llvm::SmallVector<int64_t, 4> newDims;
  1519    unsigned size = 0;
  1520    auto parseCommaSeparatedList = [&]() -> ParseResult {
  1521      llvm::SmallVector<int64_t, 4> thisDims;
  1522      if (p.getToken().getKind() == Token::l_square) {
  1523        if (parseList(thisDims))
  1524          return failure();
  1525      } else if (parseElement()) {
  1526        return failure();
  1527      }
  1528      ++size;
  1529      if (!first)
  1530        return checkDims(newDims, thisDims);
  1531      newDims = thisDims;
  1532      first = false;
  1533      return success();
  1534    };
  1535    if (p.parseCommaSeparatedListUntil(Token::r_square, parseCommaSeparatedList))
  1536      return failure();
  1537  
  1538    // Return the sublists' dimensions with 'size' prepended.
  1539    dims.clear();
  1540    dims.push_back(size);
  1541    dims.append(newDims.begin(), newDims.end());
  1542    return success();
  1543  }
  1544  
  1545  /// Parse a dense elements attribute.
  1546  Attribute Parser::parseDenseElementsAttr() {
  1547    consumeToken(Token::kw_dense);
  1548    if (parseToken(Token::less, "expected '<' after 'dense'"))
  1549      return nullptr;
  1550  
  1551    // Parse the literal data.
  1552    TensorLiteralParser literalParser(*this);
  1553    if (literalParser.parse())
  1554      return nullptr;
  1555  
  1556    if (parseToken(Token::greater, "expected '>'") ||
  1557        parseToken(Token::colon, "expected ':'"))
  1558      return nullptr;
  1559  
  1560    auto typeLoc = getToken().getLoc();
  1561    auto type = parseElementsLiteralType();
  1562    if (!type)
  1563      return nullptr;
  1564    return literalParser.getAttr(typeLoc, type);
  1565  }
  1566  
  1567  /// Shaped type for elements attribute.
  1568  ///
  1569  ///   elements-literal-type ::= vector-type | ranked-tensor-type
  1570  ///
  1571  /// This method also checks the type has static shape.
  1572  ShapedType Parser::parseElementsLiteralType() {
  1573    auto type = parseType();
  1574    if (!type)
  1575      return nullptr;
  1576  
  1577    if (!type.isa<RankedTensorType>() && !type.isa<VectorType>()) {
  1578      emitError("elements literal must be a ranked tensor or vector type");
  1579      return nullptr;
  1580    }
  1581  
  1582    auto sType = type.cast<ShapedType>();
  1583    if (!sType.hasStaticShape())
  1584      return (emitError("elements literal type must have static shape"), nullptr);
  1585  
  1586    return sType;
  1587  }
  1588  
  1589  /// Parse a sparse elements attribute.
  1590  Attribute Parser::parseSparseElementsAttr() {
  1591    consumeToken(Token::kw_sparse);
  1592    if (parseToken(Token::less, "Expected '<' after 'sparse'"))
  1593      return nullptr;
  1594  
  1595    /// Parse indices
  1596    auto indicesLoc = getToken().getLoc();
  1597    TensorLiteralParser indiceParser(*this);
  1598    if (indiceParser.parse())
  1599      return nullptr;
  1600  
  1601    if (parseToken(Token::comma, "expected ','"))
  1602      return nullptr;
  1603  
  1604    /// Parse values.
  1605    auto valuesLoc = getToken().getLoc();
  1606    TensorLiteralParser valuesParser(*this);
  1607    if (valuesParser.parse())
  1608      return nullptr;
  1609  
  1610    if (parseToken(Token::greater, "expected '>'") ||
  1611        parseToken(Token::colon, "expected ':'"))
  1612      return nullptr;
  1613  
  1614    auto type = parseElementsLiteralType();
  1615    if (!type)
  1616      return nullptr;
  1617  
  1618    // If the indices are a splat, i.e. the literal parser parsed an element and
  1619    // not a list, we set the shape explicitly. The indices are represented by a
  1620    // 2-dimensional shape where the second dimension is the rank of the type.
  1621    // Given that the parsed indices is a splat, we know that we only have one
  1622    // indice and thus one for the first dimension.
  1623    auto indiceEltType = builder.getIntegerType(64);
  1624    ShapedType indicesType;
  1625    if (indiceParser.getShape().empty()) {
  1626      indicesType = RankedTensorType::get({1, type.getRank()}, indiceEltType);
  1627    } else {
  1628      // Otherwise, set the shape to the one parsed by the literal parser.
  1629      indicesType = RankedTensorType::get(indiceParser.getShape(), indiceEltType);
  1630    }
  1631    auto indices = indiceParser.getAttr(indicesLoc, indicesType);
  1632  
  1633    // If the values are a splat, set the shape explicitly based on the number of
  1634    // indices. The number of indices is encoded in the first dimension of the
  1635    // indice shape type.
  1636    auto valuesEltType = type.getElementType();
  1637    ShapedType valuesType =
  1638        valuesParser.getShape().empty()
  1639            ? RankedTensorType::get({indicesType.getDimSize(0)}, valuesEltType)
  1640            : RankedTensorType::get(valuesParser.getShape(), valuesEltType);
  1641    auto values = valuesParser.getAttr(valuesLoc, valuesType);
  1642  
  1643    /// Sanity check.
  1644    if (valuesType.getRank() != 1)
  1645      return (emitError("expected 1-d tensor for values"), nullptr);
  1646  
  1647    auto sameShape = (indicesType.getRank() == 1) ||
  1648                     (type.getRank() == indicesType.getDimSize(1));
  1649    auto sameElementNum = indicesType.getDimSize(0) == valuesType.getDimSize(0);
  1650    if (!sameShape || !sameElementNum) {
  1651      emitError() << "expected shape ([" << type.getShape()
  1652                  << "]); inferred shape of indices literal (["
  1653                  << indicesType.getShape()
  1654                  << "]); inferred shape of values literal (["
  1655                  << valuesType.getShape() << "])";
  1656      return nullptr;
  1657    }
  1658  
  1659    // Build the sparse elements attribute by the indices and values.
  1660    return SparseElementsAttr::get(type, indices, values);
  1661  }
  1662  
  1663  //===----------------------------------------------------------------------===//
  1664  // Location parsing.
  1665  //===----------------------------------------------------------------------===//
  1666  
  1667  /// Parse a location.
  1668  ///
  1669  ///   location           ::= `loc` inline-location
  1670  ///   inline-location    ::= '(' location-inst ')'
  1671  ///
  1672  ParseResult Parser::parseLocation(LocationAttr &loc) {
  1673    // Check for 'loc' identifier.
  1674    if (parseToken(Token::kw_loc, "expected 'loc' keyword"))
  1675      return emitError();
  1676  
  1677    // Parse the inline-location.
  1678    if (parseToken(Token::l_paren, "expected '(' in inline location") ||
  1679        parseLocationInstance(loc) ||
  1680        parseToken(Token::r_paren, "expected ')' in inline location"))
  1681      return failure();
  1682    return success();
  1683  }
  1684  
  1685  /// Specific location instances.
  1686  ///
  1687  /// location-inst ::= filelinecol-location |
  1688  ///                   name-location |
  1689  ///                   callsite-location |
  1690  ///                   fused-location |
  1691  ///                   unknown-location
  1692  /// filelinecol-location ::= string-literal ':' integer-literal
  1693  ///                                         ':' integer-literal
  1694  /// name-location ::= string-literal
  1695  /// callsite-location ::= 'callsite' '(' location-inst 'at' location-inst ')'
  1696  /// fused-location ::= fused ('<' attribute-value '>')?
  1697  ///                    '[' location-inst (location-inst ',')* ']'
  1698  /// unknown-location ::= 'unknown'
  1699  ///
  1700  ParseResult Parser::parseCallSiteLocation(LocationAttr &loc) {
  1701    consumeToken(Token::bare_identifier);
  1702  
  1703    // Parse the '('.
  1704    if (parseToken(Token::l_paren, "expected '(' in callsite location"))
  1705      return failure();
  1706  
  1707    // Parse the callee location.
  1708    LocationAttr calleeLoc;
  1709    if (parseLocationInstance(calleeLoc))
  1710      return failure();
  1711  
  1712    // Parse the 'at'.
  1713    if (getToken().isNot(Token::bare_identifier) ||
  1714        getToken().getSpelling() != "at")
  1715      return emitError("expected 'at' in callsite location");
  1716    consumeToken(Token::bare_identifier);
  1717  
  1718    // Parse the caller location.
  1719    LocationAttr callerLoc;
  1720    if (parseLocationInstance(callerLoc))
  1721      return failure();
  1722  
  1723    // Parse the ')'.
  1724    if (parseToken(Token::r_paren, "expected ')' in callsite location"))
  1725      return failure();
  1726  
  1727    // Return the callsite location.
  1728    loc = CallSiteLoc::get(calleeLoc, callerLoc);
  1729    return success();
  1730  }
  1731  
  1732  ParseResult Parser::parseFusedLocation(LocationAttr &loc) {
  1733    consumeToken(Token::bare_identifier);
  1734  
  1735    // Try to parse the optional metadata.
  1736    Attribute metadata;
  1737    if (consumeIf(Token::less)) {
  1738      metadata = parseAttribute();
  1739      if (!metadata)
  1740        return emitError("expected valid attribute metadata");
  1741      // Parse the '>' token.
  1742      if (parseToken(Token::greater,
  1743                     "expected '>' after fused location metadata"))
  1744        return failure();
  1745    }
  1746  
  1747    llvm::SmallVector<Location, 4> locations;
  1748    auto parseElt = [&] {
  1749      LocationAttr newLoc;
  1750      if (parseLocationInstance(newLoc))
  1751        return failure();
  1752      locations.push_back(newLoc);
  1753      return success();
  1754    };
  1755  
  1756    if (parseToken(Token::l_square, "expected '[' in fused location") ||
  1757        parseCommaSeparatedList(parseElt) ||
  1758        parseToken(Token::r_square, "expected ']' in fused location"))
  1759      return failure();
  1760  
  1761    // Return the fused location.
  1762    loc = FusedLoc::get(locations, metadata, getContext());
  1763    return success();
  1764  }
  1765  
  1766  ParseResult Parser::parseNameOrFileLineColLocation(LocationAttr &loc) {
  1767    auto *ctx = getContext();
  1768    auto str = getToken().getStringValue();
  1769    consumeToken(Token::string);
  1770  
  1771    // If the next token is ':' this is a filelinecol location.
  1772    if (consumeIf(Token::colon)) {
  1773      // Parse the line number.
  1774      if (getToken().isNot(Token::integer))
  1775        return emitError("expected integer line number in FileLineColLoc");
  1776      auto line = getToken().getUnsignedIntegerValue();
  1777      if (!line.hasValue())
  1778        return emitError("expected integer line number in FileLineColLoc");
  1779      consumeToken(Token::integer);
  1780  
  1781      // Parse the ':'.
  1782      if (parseToken(Token::colon, "expected ':' in FileLineColLoc"))
  1783        return failure();
  1784  
  1785      // Parse the column number.
  1786      if (getToken().isNot(Token::integer))
  1787        return emitError("expected integer column number in FileLineColLoc");
  1788      auto column = getToken().getUnsignedIntegerValue();
  1789      if (!column.hasValue())
  1790        return emitError("expected integer column number in FileLineColLoc");
  1791      consumeToken(Token::integer);
  1792  
  1793      loc = FileLineColLoc::get(str, line.getValue(), column.getValue(), ctx);
  1794      return success();
  1795    }
  1796  
  1797    // Otherwise, this is a NameLoc.
  1798  
  1799    // Check for a child location.
  1800    if (consumeIf(Token::l_paren)) {
  1801      auto childSourceLoc = getToken().getLoc();
  1802  
  1803      // Parse the child location.
  1804      LocationAttr childLoc;
  1805      if (parseLocationInstance(childLoc))
  1806        return failure();
  1807  
  1808      // The child must not be another NameLoc.
  1809      if (childLoc.isa<NameLoc>())
  1810        return emitError(childSourceLoc,
  1811                         "child of NameLoc cannot be another NameLoc");
  1812      loc = NameLoc::get(Identifier::get(str, ctx), childLoc);
  1813  
  1814      // Parse the closing ')'.
  1815      if (parseToken(Token::r_paren,
  1816                     "expected ')' after child location of NameLoc"))
  1817        return failure();
  1818    } else {
  1819      loc = NameLoc::get(Identifier::get(str, ctx), ctx);
  1820    }
  1821  
  1822    return success();
  1823  }
  1824  
  1825  ParseResult Parser::parseLocationInstance(LocationAttr &loc) {
  1826    // Handle either name or filelinecol locations.
  1827    if (getToken().is(Token::string))
  1828      return parseNameOrFileLineColLocation(loc);
  1829  
  1830    // Bare tokens required for other cases.
  1831    if (!getToken().is(Token::bare_identifier))
  1832      return emitError("expected location instance");
  1833  
  1834    // Check for the 'callsite' signifying a callsite location.
  1835    if (getToken().getSpelling() == "callsite")
  1836      return parseCallSiteLocation(loc);
  1837  
  1838    // If the token is 'fused', then this is a fused location.
  1839    if (getToken().getSpelling() == "fused")
  1840      return parseFusedLocation(loc);
  1841  
  1842    // Check for a 'unknown' for an unknown location.
  1843    if (getToken().getSpelling() == "unknown") {
  1844      consumeToken(Token::bare_identifier);
  1845      loc = UnknownLoc::get(getContext());
  1846      return success();
  1847    }
  1848  
  1849    return emitError("expected location instance");
  1850  }
  1851  
  1852  //===----------------------------------------------------------------------===//
  1853  // Affine parsing.
  1854  //===----------------------------------------------------------------------===//
  1855  
  1856  /// Lower precedence ops (all at the same precedence level). LNoOp is false in
  1857  /// the boolean sense.
  1858  enum AffineLowPrecOp {
  1859    /// Null value.
  1860    LNoOp,
  1861    Add,
  1862    Sub
  1863  };
  1864  
  1865  /// Higher precedence ops - all at the same precedence level. HNoOp is false
  1866  /// in the boolean sense.
  1867  enum AffineHighPrecOp {
  1868    /// Null value.
  1869    HNoOp,
  1870    Mul,
  1871    FloorDiv,
  1872    CeilDiv,
  1873    Mod
  1874  };
  1875  
namespace {
/// This is a specialized parser for affine structures (affine maps, affine
/// expressions, and integer sets), maintaining the state transient to their
/// bodies.
///
/// It is constructed on the caller's ParserState, so it consumes tokens from
/// the same parser state as the Parser that created it.
class AffineParser : public Parser {
public:
  /// `allowParsingSSAIds` permits `%id` and `symbol(%id)` operands inside
  /// affine expressions. `parseElement` is called (with `true` for symbols and
  /// `false` for dims) the first time a given SSA id is encountered; it is
  /// only used when SSA ids are allowed.
  AffineParser(ParserState &state, bool allowParsingSSAIds = false,
               llvm::function_ref<ParseResult(bool)> parseElement = nullptr)
      : Parser(state), allowParsingSSAIds(allowParsingSSAIds),
        parseElement(parseElement), numDimOperands(0), numSymbolOperands(0) {}

  AffineMap parseAffineMapRange(unsigned numDims, unsigned numSymbols);
  ParseResult parseAffineMapOrIntegerSetInline(AffineMap &map, IntegerSet &set);
  IntegerSet parseIntegerSetConstraints(unsigned numDims, unsigned numSymbols);
  ParseResult parseAffineMapOfSSAIds(AffineMap &map);
  void getDimsAndSymbolSSAIds(SmallVectorImpl<StringRef> &dimAndSymbolSSAIds,
                              unsigned &numDims);

private:
  // Binary affine op parsing.
  AffineLowPrecOp consumeIfLowPrecOp();
  AffineHighPrecOp consumeIfHighPrecOp();

  // Identifier lists for polyhedral structures.
  ParseResult parseDimIdList(unsigned &numDims);
  ParseResult parseSymbolIdList(unsigned &numSymbols);
  ParseResult parseDimAndOptionalSymbolIdList(unsigned &numDims,
                                              unsigned &numSymbols);
  ParseResult parseIdentifierDefinition(AffineExpr idExpr);

  AffineExpr parseAffineExpr();
  AffineExpr parseParentheticalExpr();
  AffineExpr parseNegateExpression(AffineExpr lhs);
  AffineExpr parseIntegerExpr();
  AffineExpr parseBareIdExpr();
  AffineExpr parseSSAIdExpr(bool isSymbol);
  AffineExpr parseSymbolSSAIdExpr();

  AffineExpr getAffineBinaryOpExpr(AffineHighPrecOp op, AffineExpr lhs,
                                   AffineExpr rhs, SMLoc opLoc);
  AffineExpr getAffineBinaryOpExpr(AffineLowPrecOp op, AffineExpr lhs,
                                   AffineExpr rhs);
  AffineExpr parseAffineOperandExpr(AffineExpr lhs);
  AffineExpr parseAffineLowPrecOpExpr(AffineExpr llhs, AffineLowPrecOp llhsOp);
  AffineExpr parseAffineHighPrecOpExpr(AffineExpr llhs, AffineHighPrecOp llhsOp,
                                       SMLoc llhsOpLoc);
  AffineExpr parseAffineConstraint(bool *isEq);

private:
  // When false, any SSA id inside an affine expression is an error.
  bool allowParsingSSAIds;
  // Callback that parses an SSA operand; its argument is true for symbols.
  // llvm::function_ref does not own the callable, so the callable must
  // outlive this parser.
  llvm::function_ref<ParseResult(bool)> parseElement;
  // Number of dim / symbol positions handed out to SSA ids so far.
  unsigned numDimOperands;
  unsigned numSymbolOperands;
  // Identifiers (bare ids and SSA ids) bound so far, with the affine
  // expression each one denotes.
  SmallVector<std::pair<StringRef, AffineExpr>, 4> dimsAndSymbols;
};
} // end anonymous namespace
  1932  
  1933  /// Create an affine binary high precedence op expression (mul's, div's, mod).
  1934  /// opLoc is the location of the op token to be used to report errors
  1935  /// for non-conforming expressions.
  1936  AffineExpr AffineParser::getAffineBinaryOpExpr(AffineHighPrecOp op,
  1937                                                 AffineExpr lhs, AffineExpr rhs,
  1938                                                 SMLoc opLoc) {
  1939    // TODO: make the error location info accurate.
  1940    switch (op) {
  1941    case Mul:
  1942      if (!lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant()) {
  1943        emitError(opLoc, "non-affine expression: at least one of the multiply "
  1944                         "operands has to be either a constant or symbolic");
  1945        return nullptr;
  1946      }
  1947      return lhs * rhs;
  1948    case FloorDiv:
  1949      if (!rhs.isSymbolicOrConstant()) {
  1950        emitError(opLoc, "non-affine expression: right operand of floordiv "
  1951                         "has to be either a constant or symbolic");
  1952        return nullptr;
  1953      }
  1954      return lhs.floorDiv(rhs);
  1955    case CeilDiv:
  1956      if (!rhs.isSymbolicOrConstant()) {
  1957        emitError(opLoc, "non-affine expression: right operand of ceildiv "
  1958                         "has to be either a constant or symbolic");
  1959        return nullptr;
  1960      }
  1961      return lhs.ceilDiv(rhs);
  1962    case Mod:
  1963      if (!rhs.isSymbolicOrConstant()) {
  1964        emitError(opLoc, "non-affine expression: right operand of mod "
  1965                         "has to be either a constant or symbolic");
  1966        return nullptr;
  1967      }
  1968      return lhs % rhs;
  1969    case HNoOp:
  1970      llvm_unreachable("can't create affine expression for null high prec op");
  1971      return nullptr;
  1972    }
  1973    llvm_unreachable("Unknown AffineHighPrecOp");
  1974  }
  1975  
  1976  /// Create an affine binary low precedence op expression (add, sub).
  1977  AffineExpr AffineParser::getAffineBinaryOpExpr(AffineLowPrecOp op,
  1978                                                 AffineExpr lhs, AffineExpr rhs) {
  1979    switch (op) {
  1980    case AffineLowPrecOp::Add:
  1981      return lhs + rhs;
  1982    case AffineLowPrecOp::Sub:
  1983      return lhs - rhs;
  1984    case AffineLowPrecOp::LNoOp:
  1985      llvm_unreachable("can't create affine expression for null low prec op");
  1986      return nullptr;
  1987    }
  1988    llvm_unreachable("Unknown AffineLowPrecOp");
  1989  }
  1990  
  1991  /// Consume this token if it is a lower precedence affine op (there are only
  1992  /// two precedence levels).
  1993  AffineLowPrecOp AffineParser::consumeIfLowPrecOp() {
  1994    switch (getToken().getKind()) {
  1995    case Token::plus:
  1996      consumeToken(Token::plus);
  1997      return AffineLowPrecOp::Add;
  1998    case Token::minus:
  1999      consumeToken(Token::minus);
  2000      return AffineLowPrecOp::Sub;
  2001    default:
  2002      return AffineLowPrecOp::LNoOp;
  2003    }
  2004  }
  2005  
  2006  /// Consume this token if it is a higher precedence affine op (there are only
  2007  /// two precedence levels)
  2008  AffineHighPrecOp AffineParser::consumeIfHighPrecOp() {
  2009    switch (getToken().getKind()) {
  2010    case Token::star:
  2011      consumeToken(Token::star);
  2012      return Mul;
  2013    case Token::kw_floordiv:
  2014      consumeToken(Token::kw_floordiv);
  2015      return FloorDiv;
  2016    case Token::kw_ceildiv:
  2017      consumeToken(Token::kw_ceildiv);
  2018      return CeilDiv;
  2019    case Token::kw_mod:
  2020      consumeToken(Token::kw_mod);
  2021      return Mod;
  2022    default:
  2023      return HNoOp;
  2024    }
  2025  }
  2026  
  2027  /// Parse a high precedence op expression list: mul, div, and mod are high
  2028  /// precedence binary ops, i.e., parse a
  2029  ///   expr_1 op_1 expr_2 op_2 ... expr_n
  2030  /// where op_1, op_2 are all a AffineHighPrecOp (mul, div, mod).
  2031  /// All affine binary ops are left associative.
  2032  /// Given llhs, returns (llhs llhsOp lhs) op rhs, or (lhs op rhs) if llhs is
  2033  /// null. If no rhs can be found, returns (llhs llhsOp lhs) or lhs if llhs is
  2034  /// null. llhsOpLoc is the location of the llhsOp token that will be used to
  2035  /// report an error for non-conforming expressions.
  2036  AffineExpr AffineParser::parseAffineHighPrecOpExpr(AffineExpr llhs,
  2037                                                     AffineHighPrecOp llhsOp,
  2038                                                     SMLoc llhsOpLoc) {
  2039    AffineExpr lhs = parseAffineOperandExpr(llhs);
  2040    if (!lhs)
  2041      return nullptr;
  2042  
  2043    // Found an LHS. Parse the remaining expression.
  2044    auto opLoc = getToken().getLoc();
  2045    if (AffineHighPrecOp op = consumeIfHighPrecOp()) {
  2046      if (llhs) {
  2047        AffineExpr expr = getAffineBinaryOpExpr(llhsOp, llhs, lhs, opLoc);
  2048        if (!expr)
  2049          return nullptr;
  2050        return parseAffineHighPrecOpExpr(expr, op, opLoc);
  2051      }
  2052      // No LLHS, get RHS
  2053      return parseAffineHighPrecOpExpr(lhs, op, opLoc);
  2054    }
  2055  
  2056    // This is the last operand in this expression.
  2057    if (llhs)
  2058      return getAffineBinaryOpExpr(llhsOp, llhs, lhs, llhsOpLoc);
  2059  
  2060    // No llhs, 'lhs' itself is the expression.
  2061    return lhs;
  2062  }
  2063  
  2064  /// Parse an affine expression inside parentheses.
  2065  ///
  2066  ///   affine-expr ::= `(` affine-expr `)`
  2067  AffineExpr AffineParser::parseParentheticalExpr() {
  2068    if (parseToken(Token::l_paren, "expected '('"))
  2069      return nullptr;
  2070    if (getToken().is(Token::r_paren))
  2071      return (emitError("no expression inside parentheses"), nullptr);
  2072  
  2073    auto expr = parseAffineExpr();
  2074    if (!expr)
  2075      return nullptr;
  2076    if (parseToken(Token::r_paren, "expected ')'"))
  2077      return nullptr;
  2078  
  2079    return expr;
  2080  }
  2081  
  2082  /// Parse the negation expression.
  2083  ///
  2084  ///   affine-expr ::= `-` affine-expr
  2085  AffineExpr AffineParser::parseNegateExpression(AffineExpr lhs) {
  2086    if (parseToken(Token::minus, "expected '-'"))
  2087      return nullptr;
  2088  
  2089    AffineExpr operand = parseAffineOperandExpr(lhs);
  2090    // Since negation has the highest precedence of all ops (including high
  2091    // precedence ops) but lower than parentheses, we are only going to use
  2092    // parseAffineOperandExpr instead of parseAffineExpr here.
  2093    if (!operand)
  2094      // Extra error message although parseAffineOperandExpr would have
  2095      // complained. Leads to a better diagnostic.
  2096      return (emitError("missing operand of negation"), nullptr);
  2097    return (-1) * operand;
  2098  }
  2099  
  2100  /// Parse a bare id that may appear in an affine expression.
  2101  ///
  2102  ///   affine-expr ::= bare-id
  2103  AffineExpr AffineParser::parseBareIdExpr() {
  2104    if (getToken().isNot(Token::bare_identifier))
  2105      return (emitError("expected bare identifier"), nullptr);
  2106  
  2107    StringRef sRef = getTokenSpelling();
  2108    for (auto entry : dimsAndSymbols) {
  2109      if (entry.first == sRef) {
  2110        consumeToken(Token::bare_identifier);
  2111        return entry.second;
  2112      }
  2113    }
  2114  
  2115    return (emitError("use of undeclared identifier"), nullptr);
  2116  }
  2117  
  2118  /// Parse an SSA id which may appear in an affine expression.
  2119  AffineExpr AffineParser::parseSSAIdExpr(bool isSymbol) {
  2120    if (!allowParsingSSAIds)
  2121      return (emitError("unexpected ssa identifier"), nullptr);
  2122    if (getToken().isNot(Token::percent_identifier))
  2123      return (emitError("expected ssa identifier"), nullptr);
  2124    auto name = getTokenSpelling();
  2125    // Check if we already parsed this SSA id.
  2126    for (auto entry : dimsAndSymbols) {
  2127      if (entry.first == name) {
  2128        consumeToken(Token::percent_identifier);
  2129        return entry.second;
  2130      }
  2131    }
  2132    // Parse the SSA id and add an AffineDim/SymbolExpr to represent it.
  2133    if (parseElement(isSymbol))
  2134      return (emitError("failed to parse ssa identifier"), nullptr);
  2135    auto idExpr = isSymbol
  2136                      ? getAffineSymbolExpr(numSymbolOperands++, getContext())
  2137                      : getAffineDimExpr(numDimOperands++, getContext());
  2138    dimsAndSymbols.push_back({name, idExpr});
  2139    return idExpr;
  2140  }
  2141  
  2142  AffineExpr AffineParser::parseSymbolSSAIdExpr() {
  2143    if (parseToken(Token::kw_symbol, "expected symbol keyword") ||
  2144        parseToken(Token::l_paren, "expected '(' at start of SSA symbol"))
  2145      return nullptr;
  2146    AffineExpr symbolExpr = parseSSAIdExpr(/*isSymbol=*/true);
  2147    if (!symbolExpr)
  2148      return nullptr;
  2149    if (parseToken(Token::r_paren, "expected ')' at end of SSA symbol"))
  2150      return nullptr;
  2151    return symbolExpr;
  2152  }
  2153  
  2154  /// Parse a positive integral constant appearing in an affine expression.
  2155  ///
  2156  ///   affine-expr ::= integer-literal
  2157  AffineExpr AffineParser::parseIntegerExpr() {
  2158    auto val = getToken().getUInt64IntegerValue();
  2159    if (!val.hasValue() || (int64_t)val.getValue() < 0)
  2160      return (emitError("constant too large for index"), nullptr);
  2161  
  2162    consumeToken(Token::integer);
  2163    return builder.getAffineConstantExpr((int64_t)val.getValue());
  2164  }
  2165  
  2166  /// Parses an expression that can be a valid operand of an affine expression.
  2167  /// lhs: if non-null, lhs is an affine expression that is the lhs of a binary
  2168  /// operator, the rhs of which is being parsed. This is used to determine
  2169  /// whether an error should be emitted for a missing right operand.
  2170  //  Eg: for an expression without parentheses (like i + j + k + l), each
  2171  //  of the four identifiers is an operand. For i + j*k + l, j*k is not an
  2172  //  operand expression, it's an op expression and will be parsed via
  2173  //  parseAffineHighPrecOpExpression(). However, for i + (j*k) + -l, (j*k) and
  2174  //  -l are valid operands that will be parsed by this function.
  2175  AffineExpr AffineParser::parseAffineOperandExpr(AffineExpr lhs) {
  2176    switch (getToken().getKind()) {
  2177    case Token::bare_identifier:
  2178      return parseBareIdExpr();
  2179    case Token::kw_symbol:
  2180      return parseSymbolSSAIdExpr();
  2181    case Token::percent_identifier:
  2182      return parseSSAIdExpr(/*isSymbol=*/false);
  2183    case Token::integer:
  2184      return parseIntegerExpr();
  2185    case Token::l_paren:
  2186      return parseParentheticalExpr();
  2187    case Token::minus:
  2188      return parseNegateExpression(lhs);
  2189    case Token::kw_ceildiv:
  2190    case Token::kw_floordiv:
  2191    case Token::kw_mod:
  2192    case Token::plus:
  2193    case Token::star:
  2194      if (lhs)
  2195        emitError("missing right operand of binary operator");
  2196      else
  2197        emitError("missing left operand of binary operator");
  2198      return nullptr;
  2199    default:
  2200      if (lhs)
  2201        emitError("missing right operand of binary operator");
  2202      else
  2203        emitError("expected affine expression");
  2204      return nullptr;
  2205    }
  2206  }
  2207  
  2208  /// Parse affine expressions that are bare-id's, integer constants,
  2209  /// parenthetical affine expressions, and affine op expressions that are a
  2210  /// composition of those.
  2211  ///
  2212  /// All binary op's associate from left to right.
  2213  ///
  2214  /// {add, sub} have lower precedence than {mul, div, and mod}.
  2215  ///
  2216  /// Add, sub'are themselves at the same precedence level. Mul, floordiv,
  2217  /// ceildiv, and mod are at the same higher precedence level. Negation has
  2218  /// higher precedence than any binary op.
  2219  ///
  2220  /// llhs: the affine expression appearing on the left of the one being parsed.
  2221  /// This function will return ((llhs llhsOp lhs) op rhs) if llhs is non null,
  2222  /// and lhs op rhs otherwise; if there is no rhs, llhs llhsOp lhs is returned
  2223  /// if llhs is non-null; otherwise lhs is returned. This is to deal with left
  2224  /// associativity.
  2225  ///
  2226  /// Eg: when the expression is e1 + e2*e3 + e4, with e1 as llhs, this function
  2227  /// will return the affine expr equivalent of (e1 + (e2*e3)) + e4, where
  2228  /// (e2*e3) will be parsed using parseAffineHighPrecOpExpr().
  2229  AffineExpr AffineParser::parseAffineLowPrecOpExpr(AffineExpr llhs,
  2230                                                    AffineLowPrecOp llhsOp) {
  2231    AffineExpr lhs;
  2232    if (!(lhs = parseAffineOperandExpr(llhs)))
  2233      return nullptr;
  2234  
  2235    // Found an LHS. Deal with the ops.
  2236    if (AffineLowPrecOp lOp = consumeIfLowPrecOp()) {
  2237      if (llhs) {
  2238        AffineExpr sum = getAffineBinaryOpExpr(llhsOp, llhs, lhs);
  2239        return parseAffineLowPrecOpExpr(sum, lOp);
  2240      }
  2241      // No LLHS, get RHS and form the expression.
  2242      return parseAffineLowPrecOpExpr(lhs, lOp);
  2243    }
  2244    auto opLoc = getToken().getLoc();
  2245    if (AffineHighPrecOp hOp = consumeIfHighPrecOp()) {
  2246      // We have a higher precedence op here. Get the rhs operand for the llhs
  2247      // through parseAffineHighPrecOpExpr.
  2248      AffineExpr highRes = parseAffineHighPrecOpExpr(lhs, hOp, opLoc);
  2249      if (!highRes)
  2250        return nullptr;
  2251  
  2252      // If llhs is null, the product forms the first operand of the yet to be
  2253      // found expression. If non-null, the op to associate with llhs is llhsOp.
  2254      AffineExpr expr =
  2255          llhs ? getAffineBinaryOpExpr(llhsOp, llhs, highRes) : highRes;
  2256  
  2257      // Recurse for subsequent low prec op's after the affine high prec op
  2258      // expression.
  2259      if (AffineLowPrecOp nextOp = consumeIfLowPrecOp())
  2260        return parseAffineLowPrecOpExpr(expr, nextOp);
  2261      return expr;
  2262    }
  2263    // Last operand in the expression list.
  2264    if (llhs)
  2265      return getAffineBinaryOpExpr(llhsOp, llhs, lhs);
  2266    // No llhs, 'lhs' itself is the expression.
  2267    return lhs;
  2268  }
  2269  
  2270  /// Parse an affine expression.
  2271  ///  affine-expr ::= `(` affine-expr `)`
  2272  ///                | `-` affine-expr
  2273  ///                | affine-expr `+` affine-expr
  2274  ///                | affine-expr `-` affine-expr
  2275  ///                | affine-expr `*` affine-expr
  2276  ///                | affine-expr `floordiv` affine-expr
  2277  ///                | affine-expr `ceildiv` affine-expr
  2278  ///                | affine-expr `mod` affine-expr
  2279  ///                | bare-id
  2280  ///                | integer-literal
  2281  ///
  2282  /// Additional conditions are checked depending on the production. For eg.,
  2283  /// one of the operands for `*` has to be either constant/symbolic; the second
  2284  /// operand for floordiv, ceildiv, and mod has to be a positive integer.
  2285  AffineExpr AffineParser::parseAffineExpr() {
  2286    return parseAffineLowPrecOpExpr(nullptr, AffineLowPrecOp::LNoOp);
  2287  }
  2288  
  2289  /// Parse a dim or symbol from the lists appearing before the actual
  2290  /// expressions of the affine map. Update our state to store the
  2291  /// dimensional/symbolic identifier.
  2292  ParseResult AffineParser::parseIdentifierDefinition(AffineExpr idExpr) {
  2293    if (getToken().isNot(Token::bare_identifier))
  2294      return emitError("expected bare identifier");
  2295  
  2296    auto name = getTokenSpelling();
  2297    for (auto entry : dimsAndSymbols) {
  2298      if (entry.first == name)
  2299        return emitError("redefinition of identifier '" + name + "'");
  2300    }
  2301    consumeToken(Token::bare_identifier);
  2302  
  2303    dimsAndSymbols.push_back({name, idExpr});
  2304    return success();
  2305  }
  2306  
  2307  /// Parse the list of dimensional identifiers to an affine map.
  2308  ParseResult AffineParser::parseDimIdList(unsigned &numDims) {
  2309    if (parseToken(Token::l_paren,
  2310                   "expected '(' at start of dimensional identifiers list")) {
  2311      return failure();
  2312    }
  2313  
  2314    auto parseElt = [&]() -> ParseResult {
  2315      auto dimension = getAffineDimExpr(numDims++, getContext());
  2316      return parseIdentifierDefinition(dimension);
  2317    };
  2318    return parseCommaSeparatedListUntil(Token::r_paren, parseElt);
  2319  }
  2320  
  2321  /// Parse the list of symbolic identifiers to an affine map.
  2322  ParseResult AffineParser::parseSymbolIdList(unsigned &numSymbols) {
  2323    consumeToken(Token::l_square);
  2324    auto parseElt = [&]() -> ParseResult {
  2325      auto symbol = getAffineSymbolExpr(numSymbols++, getContext());
  2326      return parseIdentifierDefinition(symbol);
  2327    };
  2328    return parseCommaSeparatedListUntil(Token::r_square, parseElt);
  2329  }
  2330  
  2331  /// Parse the list of symbolic identifiers to an affine map.
  2332  ParseResult
  2333  AffineParser::parseDimAndOptionalSymbolIdList(unsigned &numDims,
  2334                                                unsigned &numSymbols) {
  2335    if (parseDimIdList(numDims)) {
  2336      return failure();
  2337    }
  2338    if (!getToken().is(Token::l_square)) {
  2339      numSymbols = 0;
  2340      return success();
  2341    }
  2342    return parseSymbolIdList(numSymbols);
  2343  }
  2344  
  2345  /// Parses an ambiguous affine map or integer set definition inline.
  2346  ParseResult AffineParser::parseAffineMapOrIntegerSetInline(AffineMap &map,
  2347                                                             IntegerSet &set) {
  2348    unsigned numDims = 0, numSymbols = 0;
  2349  
  2350    // List of dimensional and optional symbol identifiers.
  2351    if (parseDimAndOptionalSymbolIdList(numDims, numSymbols)) {
  2352      return failure();
  2353    }
  2354  
  2355    // This is needed for parsing attributes as we wouldn't know whether we would
  2356    // be parsing an integer set attribute or an affine map attribute.
  2357    bool isArrow = getToken().is(Token::arrow);
  2358    bool isColon = getToken().is(Token::colon);
  2359    if (!isArrow && !isColon) {
  2360      return emitError("expected '->' or ':'");
  2361    } else if (isArrow) {
  2362      parseToken(Token::arrow, "expected '->' or '['");
  2363      map = parseAffineMapRange(numDims, numSymbols);
  2364      return map ? success() : failure();
  2365    } else if (parseToken(Token::colon, "expected ':' or '['")) {
  2366      return failure();
  2367    }
  2368  
  2369    if ((set = parseIntegerSetConstraints(numDims, numSymbols)))
  2370      return success();
  2371  
  2372    return failure();
  2373  }
  2374  
  2375  /// Parse an AffineMap where the dim and symbol identifiers are SSA ids.
  2376  ParseResult AffineParser::parseAffineMapOfSSAIds(AffineMap &map) {
  2377    if (parseToken(Token::l_square, "expected '['"))
  2378      return failure();
  2379  
  2380    SmallVector<AffineExpr, 4> exprs;
  2381    auto parseElt = [&]() -> ParseResult {
  2382      auto elt = parseAffineExpr();
  2383      exprs.push_back(elt);
  2384      return elt ? success() : failure();
  2385    };
  2386  
  2387    // Parse a multi-dimensional affine expression (a comma-separated list of
  2388    // 1-d affine expressions); the list cannot be empty. Grammar:
  2389    // multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `)
  2390    if (parseCommaSeparatedListUntil(Token::r_square, parseElt,
  2391                                     /*allowEmptyList=*/true))
  2392      return failure();
  2393    // Parsed a valid affine map.
  2394    if (exprs.empty())
  2395      map = AffineMap();
  2396    else
  2397      map = builder.getAffineMap(numDimOperands,
  2398                                 dimsAndSymbols.size() - numDimOperands, exprs);
  2399    return success();
  2400  }
  2401  
  2402  /// Parse the range and sizes affine map definition inline.
  2403  ///
  2404  ///  affine-map ::= dim-and-symbol-id-lists `->` multi-dim-affine-expr
  2405  ///
  2406  ///  multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `)
  2407  AffineMap AffineParser::parseAffineMapRange(unsigned numDims,
  2408                                              unsigned numSymbols) {
  2409    parseToken(Token::l_paren, "expected '(' at start of affine map range");
  2410  
  2411    SmallVector<AffineExpr, 4> exprs;
  2412    auto parseElt = [&]() -> ParseResult {
  2413      auto elt = parseAffineExpr();
  2414      ParseResult res = elt ? success() : failure();
  2415      exprs.push_back(elt);
  2416      return res;
  2417    };
  2418  
  2419    // Parse a multi-dimensional affine expression (a comma-separated list of
  2420    // 1-d affine expressions); the list cannot be empty. Grammar:
  2421    // multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `)
  2422    if (parseCommaSeparatedListUntil(Token::r_paren, parseElt, false))
  2423      return AffineMap();
  2424  
  2425    // Parsed a valid affine map.
  2426    return builder.getAffineMap(numDims, numSymbols, exprs);
  2427  }
  2428  
  2429  /// Parse an affine constraint.
  2430  ///  affine-constraint ::= affine-expr `>=` `0`
  2431  ///                      | affine-expr `==` `0`
  2432  ///
  2433  /// isEq is set to true if the parsed constraint is an equality, false if it
  2434  /// is an inequality (greater than or equal).
  2435  ///
  2436  AffineExpr AffineParser::parseAffineConstraint(bool *isEq) {
  2437    AffineExpr expr = parseAffineExpr();
  2438    if (!expr)
  2439      return nullptr;
  2440  
  2441    if (consumeIf(Token::greater) && consumeIf(Token::equal) &&
  2442        getToken().is(Token::integer)) {
  2443      auto dim = getToken().getUnsignedIntegerValue();
  2444      if (dim.hasValue() && dim.getValue() == 0) {
  2445        consumeToken(Token::integer);
  2446        *isEq = false;
  2447        return expr;
  2448      }
  2449      return (emitError("expected '0' after '>='"), nullptr);
  2450    }
  2451  
  2452    if (consumeIf(Token::equal) && consumeIf(Token::equal) &&
  2453        getToken().is(Token::integer)) {
  2454      auto dim = getToken().getUnsignedIntegerValue();
  2455      if (dim.hasValue() && dim.getValue() == 0) {
  2456        consumeToken(Token::integer);
  2457        *isEq = true;
  2458        return expr;
  2459      }
  2460      return (emitError("expected '0' after '=='"), nullptr);
  2461    }
  2462  
  2463    return (emitError("expected '== 0' or '>= 0' at end of affine constraint"),
  2464            nullptr);
  2465  }
  2466  
  2467  /// Parse the constraints that are part of an integer set definition.
  2468  ///  integer-set-inline
  2469  ///                ::= dim-and-symbol-id-lists `:`
  2470  ///                '(' affine-constraint-conjunction? ')'
  2471  ///  affine-constraint-conjunction ::= affine-constraint (`,`
  2472  ///                                       affine-constraint)*
  2473  ///
  2474  IntegerSet AffineParser::parseIntegerSetConstraints(unsigned numDims,
  2475                                                      unsigned numSymbols) {
  2476    if (parseToken(Token::l_paren,
  2477                   "expected '(' at start of integer set constraint list"))
  2478      return IntegerSet();
  2479  
  2480    SmallVector<AffineExpr, 4> constraints;
  2481    SmallVector<bool, 4> isEqs;
  2482    auto parseElt = [&]() -> ParseResult {
  2483      bool isEq;
  2484      auto elt = parseAffineConstraint(&isEq);
  2485      ParseResult res = elt ? success() : failure();
  2486      if (elt) {
  2487        constraints.push_back(elt);
  2488        isEqs.push_back(isEq);
  2489      }
  2490      return res;
  2491    };
  2492  
  2493    // Parse a list of affine constraints (comma-separated).
  2494    if (parseCommaSeparatedListUntil(Token::r_paren, parseElt, true))
  2495      return IntegerSet();
  2496  
  2497    // If no constraints were parsed, then treat this as a degenerate 'true' case.
  2498    if (constraints.empty()) {
  2499      /* 0 == 0 */
  2500      auto zero = getAffineConstantExpr(0, getContext());
  2501      return builder.getIntegerSet(numDims, numSymbols, zero, true);
  2502    }
  2503  
  2504    // Parsed a valid integer set.
  2505    return builder.getIntegerSet(numDims, numSymbols, constraints, isEqs);
  2506  }
  2507  
  2508  /// Parse an ambiguous reference to either and affine map or an integer set.
  2509  ParseResult Parser::parseAffineMapOrIntegerSetReference(AffineMap &map,
  2510                                                          IntegerSet &set) {
  2511    return AffineParser(state).parseAffineMapOrIntegerSetInline(map, set);
  2512  }
  2513  
  2514  /// Parse an AffineMap of SSA ids. The callback 'parseElement' is used to
  2515  /// parse SSA value uses encountered while parsing affine expressions.
  2516  ParseResult Parser::parseAffineMapOfSSAIds(
  2517      AffineMap &map, llvm::function_ref<ParseResult(bool)> parseElement) {
  2518    return AffineParser(state, /*allowParsingSSAIds=*/true, parseElement)
  2519        .parseAffineMapOfSSAIds(map);
  2520  }
  2521  
  2522  //===----------------------------------------------------------------------===//
  2523  // OperationParser
  2524  //===----------------------------------------------------------------------===//
  2525  
  2526  namespace {
  2527  /// This class provides support for parsing operations and regions of
  2528  /// operations.
  2529  class OperationParser : public Parser {
  2530  public:
  2531    OperationParser(ParserState &state, ModuleOp moduleOp)
  2532        : Parser(state), opBuilder(moduleOp.getBodyRegion()), moduleOp(moduleOp) {
  2533    }
  2534  
  2535    ~OperationParser();
  2536  
  2537    /// After parsing is finished, this function must be called to see if there
  2538    /// are any remaining issues.
  2539    ParseResult finalize();
  2540  
  2541    //===--------------------------------------------------------------------===//
  2542    // SSA Value Handling
  2543    //===--------------------------------------------------------------------===//
  2544  
  2545    /// This represents a use of an SSA value in the program.  The first two
  2546    /// entries in the tuple are the name and result number of a reference.  The
  2547    /// third is the location of the reference, which is used in case this ends
  2548    /// up being a use of an undefined value.
  2549    struct SSAUseInfo {
  2550      StringRef name;  // Value name, e.g. %42 or %abc
  2551      unsigned number; // Number, specified with #12
  2552      SMLoc loc;       // Location of first definition or use.
  2553    };
  2554  
  2555    /// Push a new SSA name scope to the parser.
  2556    void pushSSANameScope(bool isIsolated);
  2557  
  2558    /// Pop the last SSA name scope from the parser.
  2559    ParseResult popSSANameScope();
  2560  
  2561    /// Register a definition of a value with the symbol table.
  2562    ParseResult addDefinition(SSAUseInfo useInfo, Value *value);
  2563  
  2564    /// Parse an optional list of SSA uses into 'results'.
  2565    ParseResult parseOptionalSSAUseList(SmallVectorImpl<SSAUseInfo> &results);
  2566  
  2567    /// Parse a single SSA use into 'result'.
  2568    ParseResult parseSSAUse(SSAUseInfo &result);
  2569  
  2570    /// Given a reference to an SSA value and its type, return a reference. This
  2571    /// returns null on failure.
  2572    Value *resolveSSAUse(SSAUseInfo useInfo, Type type);
  2573  
  2574    ParseResult parseSSADefOrUseAndType(
  2575        const std::function<ParseResult(SSAUseInfo, Type)> &action);
  2576  
  2577    ParseResult parseOptionalSSAUseAndTypeList(SmallVectorImpl<Value *> &results);
  2578  
  2579    /// Return the location of the value identified by its name and number if it
  2580    /// has been already reference.
  2581    llvm::Optional<SMLoc> getReferenceLoc(StringRef name, unsigned number) {
  2582      auto &values = isolatedNameScopes.back().values;
  2583      if (!values.count(name) || number >= values[name].size())
  2584        return {};
  2585      if (values[name][number].first)
  2586        return values[name][number].second;
  2587      return {};
  2588    }
  2589  
  2590    //===--------------------------------------------------------------------===//
  2591    // Operation Parsing
  2592    //===--------------------------------------------------------------------===//
  2593  
  2594    /// Parse an operation instance.
  2595    ParseResult parseOperation();
  2596  
  2597    /// Parse a single operation successor and its operand list.
  2598    ParseResult parseSuccessorAndUseList(Block *&dest,
  2599                                         SmallVectorImpl<Value *> &operands);
  2600  
  2601    /// Parse a comma-separated list of operation successors in brackets.
  2602    ParseResult
  2603    parseSuccessors(SmallVectorImpl<Block *> &destinations,
  2604                    SmallVectorImpl<SmallVector<Value *, 4>> &operands);
  2605  
  2606    /// Parse an operation instance that is in the generic form.
  2607    Operation *parseGenericOperation();
  2608  
  2609    /// Parse an operation instance that is in the op-defined custom form.
  2610    Operation *parseCustomOperation();
  2611  
  2612    //===--------------------------------------------------------------------===//
  2613    // Region Parsing
  2614    //===--------------------------------------------------------------------===//
  2615  
  2616    /// Parse a region into 'region' with the provided entry block arguments.
  2617    /// 'isIsolatedNameScope' indicates if the naming scope of this region is
  2618    /// isolated from those above.
  2619    ParseResult parseRegion(Region &region,
  2620                            ArrayRef<std::pair<SSAUseInfo, Type>> entryArguments,
  2621                            bool isIsolatedNameScope = false);
  2622  
  2623    /// Parse a region body into 'region'.
  2624    ParseResult parseRegionBody(Region &region);
  2625  
  2626    //===--------------------------------------------------------------------===//
  2627    // Block Parsing
  2628    //===--------------------------------------------------------------------===//
  2629  
  2630    /// Parse a new block into 'block'.
  2631    ParseResult parseBlock(Block *&block);
  2632  
  2633    /// Parse a list of operations into 'block'.
  2634    ParseResult parseBlockBody(Block *block);
  2635  
  2636    /// Parse a (possibly empty) list of block arguments.
  2637    ParseResult
  2638    parseOptionalBlockArgList(SmallVectorImpl<BlockArgument *> &results,
  2639                              Block *owner);
  2640  
  2641    /// Get the block with the specified name, creating it if it doesn't
  2642    /// already exist.  The location specified is the point of use, which allows
  2643    /// us to diagnose references to blocks that are not defined precisely.
  2644    Block *getBlockNamed(StringRef name, SMLoc loc);
  2645  
  2646    /// Define the block with the specified name. Returns the Block* or nullptr in
  2647    /// the case of redefinition.
  2648    Block *defineBlockNamed(StringRef name, SMLoc loc, Block *existing);
  2649  
  2650  private:
  2651    /// Returns the info for a block at the current scope for the given name.
  2652    std::pair<Block *, SMLoc> &getBlockInfoByName(StringRef name) {
  2653      return blocksByName.back()[name];
  2654    }
  2655  
  2656    /// Insert a new forward reference to the given block.
  2657    void insertForwardRef(Block *block, SMLoc loc) {
  2658      forwardRef.back().try_emplace(block, loc);
  2659    }
  2660  
  2661    /// Erase any forward reference to the given block.
  2662    bool eraseForwardRef(Block *block) { return forwardRef.back().erase(block); }
  2663  
  2664    /// Record that a definition was added at the current scope.
  2665    void recordDefinition(StringRef def);
  2666  
  2667    /// Get the value entry for the given SSA name.
  2668    SmallVectorImpl<std::pair<Value *, SMLoc>> &getSSAValueEntry(StringRef name);
  2669  
  2670    /// Create a forward reference placeholder value with the given location and
  2671    /// result type.
  2672    Value *createForwardRefPlaceholder(SMLoc loc, Type type);
  2673  
  2674    /// Return true if this is a forward reference.
  2675    bool isForwardRefPlaceholder(Value *value) {
  2676      return forwardRefPlaceholders.count(value);
  2677    }
  2678  
  2679    /// This struct represents an isolated SSA name scope. This scope may contain
  2680    /// other nested non-isolated scopes. These scopes are used for operations
  2681    /// that are known to be isolated to allow for reusing names within their
  2682    /// regions, even if those names are used above.
  2683    struct IsolatedSSANameScope {
  2684      /// Record that a definition was added at the current scope.
  2685      void recordDefinition(StringRef def) {
  2686        definitionsPerScope.back().insert(def);
  2687      }
  2688  
  2689      /// Push a nested name scope.
  2690      void pushSSANameScope() { definitionsPerScope.push_back({}); }
  2691  
  2692      /// Pop a nested name scope.
  2693      void popSSANameScope() {
  2694        for (auto &def : definitionsPerScope.pop_back_val())
  2695          values.erase(def.getKey());
  2696      }
  2697  
  2698      /// This keeps track of all of the SSA values we are tracking for each name
  2699      /// scope, indexed by their name. This has one entry per result number.
  2700      llvm::StringMap<SmallVector<std::pair<Value *, SMLoc>, 1>> values;
  2701  
  2702      /// This keeps track of all of the values defined by a specific name scope.
  2703      SmallVector<llvm::StringSet<>, 2> definitionsPerScope;
  2704    };
  2705  
  2706    /// A list of isolated name scopes.
  2707    SmallVector<IsolatedSSANameScope, 2> isolatedNameScopes;
  2708  
  2709    /// This keeps track of the block names as well as the location of the first
  2710    /// reference for each nested name scope. This is used to diagnose invalid
  2711    /// block references and memoize them.
  2712    SmallVector<DenseMap<StringRef, std::pair<Block *, SMLoc>>, 2> blocksByName;
  2713    SmallVector<DenseMap<Block *, SMLoc>, 2> forwardRef;
  2714  
  2715    /// These are all of the placeholders we've made along with the location of
  2716    /// their first reference, to allow checking for use of undefined values.
  2717    DenseMap<Value *, SMLoc> forwardRefPlaceholders;
  2718  
  2719    /// The builder used when creating parsed operation instances.
  2720    OpBuilder opBuilder;
  2721  
  2722    /// The top level module operation.
  2723    ModuleOp moduleOp;
  2724  };
  2725  } // end anonymous namespace
  2726  
  2727  OperationParser::~OperationParser() {
  2728    for (auto &fwd : forwardRefPlaceholders) {
  2729      // Drop all uses of undefined forward declared reference and destroy
  2730      // defining operation.
  2731      fwd.first->dropAllUses();
  2732      fwd.first->getDefiningOp()->destroy();
  2733    }
  2734  }
  2735  
  2736  /// After parsing is finished, this function must be called to see if there are
  2737  /// any remaining issues.
  2738  ParseResult OperationParser::finalize() {
  2739    // Check for any forward references that are left.  If we find any, error
  2740    // out.
  2741    if (!forwardRefPlaceholders.empty()) {
  2742      SmallVector<std::pair<const char *, Value *>, 4> errors;
  2743      // Iteration over the map isn't deterministic, so sort by source location.
  2744      for (auto entry : forwardRefPlaceholders)
  2745        errors.push_back({entry.second.getPointer(), entry.first});
  2746      llvm::array_pod_sort(errors.begin(), errors.end());
  2747  
  2748      for (auto entry : errors) {
  2749        auto loc = SMLoc::getFromPointer(entry.first);
  2750        emitError(loc, "use of undeclared SSA value name");
  2751      }
  2752      return failure();
  2753    }
  2754  
  2755    return success();
  2756  }
  2757  
  2758  //===----------------------------------------------------------------------===//
  2759  // SSA Value Handling
  2760  //===----------------------------------------------------------------------===//
  2761  
  2762  void OperationParser::pushSSANameScope(bool isIsolated) {
  2763    blocksByName.push_back(DenseMap<StringRef, std::pair<Block *, SMLoc>>());
  2764    forwardRef.push_back(DenseMap<Block *, SMLoc>());
  2765  
  2766    // Push back a new name definition scope.
  2767    if (isIsolated)
  2768      isolatedNameScopes.push_back({});
  2769    isolatedNameScopes.back().pushSSANameScope();
  2770  }
  2771  
  2772  ParseResult OperationParser::popSSANameScope() {
  2773    auto forwardRefInCurrentScope = forwardRef.pop_back_val();
  2774  
  2775    // Verify that all referenced blocks were defined.
  2776    if (!forwardRefInCurrentScope.empty()) {
  2777      SmallVector<std::pair<const char *, Block *>, 4> errors;
  2778      // Iteration over the map isn't deterministic, so sort by source location.
  2779      for (auto entry : forwardRefInCurrentScope) {
  2780        errors.push_back({entry.second.getPointer(), entry.first});
  2781        // Add this block to the top-level region to allow for automatic cleanup.
  2782        moduleOp.getOperation()->getRegion(0).push_back(entry.first);
  2783      }
  2784      llvm::array_pod_sort(errors.begin(), errors.end());
  2785  
  2786      for (auto entry : errors) {
  2787        auto loc = SMLoc::getFromPointer(entry.first);
  2788        emitError(loc, "reference to an undefined block");
  2789      }
  2790      return failure();
  2791    }
  2792  
  2793    // Pop the next nested namescope. If there is only one internal namescope,
  2794    // just pop the isolated scope.
  2795    auto &currentNameScope = isolatedNameScopes.back();
  2796    if (currentNameScope.definitionsPerScope.size() == 1)
  2797      isolatedNameScopes.pop_back();
  2798    else
  2799      currentNameScope.popSSANameScope();
  2800  
  2801    blocksByName.pop_back();
  2802    return success();
  2803  }
  2804  
  2805  /// Register a definition of a value with the symbol table.
  2806  ParseResult OperationParser::addDefinition(SSAUseInfo useInfo, Value *value) {
  2807    auto &entries = getSSAValueEntry(useInfo.name);
  2808  
  2809    // Make sure there is a slot for this value.
  2810    if (entries.size() <= useInfo.number)
  2811      entries.resize(useInfo.number + 1);
  2812  
  2813    // If we already have an entry for this, check to see if it was a definition
  2814    // or a forward reference.
  2815    if (auto *existing = entries[useInfo.number].first) {
  2816      if (!isForwardRefPlaceholder(existing)) {
  2817        return emitError(useInfo.loc)
  2818            .append("redefinition of SSA value '", useInfo.name, "'")
  2819            .attachNote(getEncodedSourceLocation(entries[useInfo.number].second))
  2820            .append("previously defined here");
  2821      }
  2822  
  2823      // If it was a forward reference, update everything that used it to use
  2824      // the actual definition instead, delete the forward ref, and remove it
  2825      // from our set of forward references we track.
  2826      existing->replaceAllUsesWith(value);
  2827      existing->getDefiningOp()->destroy();
  2828      forwardRefPlaceholders.erase(existing);
  2829    }
  2830  
  2831    /// Record this definition for the current scope.
  2832    entries[useInfo.number] = {value, useInfo.loc};
  2833    recordDefinition(useInfo.name);
  2834    return success();
  2835  }
  2836  
  2837  /// Parse a (possibly empty) list of SSA operands.
  2838  ///
  2839  ///   ssa-use-list ::= ssa-use (`,` ssa-use)*
  2840  ///   ssa-use-list-opt ::= ssa-use-list?
  2841  ///
  2842  ParseResult
  2843  OperationParser::parseOptionalSSAUseList(SmallVectorImpl<SSAUseInfo> &results) {
  2844    if (getToken().isNot(Token::percent_identifier))
  2845      return success();
  2846    return parseCommaSeparatedList([&]() -> ParseResult {
  2847      SSAUseInfo result;
  2848      if (parseSSAUse(result))
  2849        return failure();
  2850      results.push_back(result);
  2851      return success();
  2852    });
  2853  }
  2854  
  2855  /// Parse a SSA operand for an operation.
  2856  ///
  2857  ///   ssa-use ::= ssa-id
  2858  ///
  2859  ParseResult OperationParser::parseSSAUse(SSAUseInfo &result) {
  2860    result.name = getTokenSpelling();
  2861    result.number = 0;
  2862    result.loc = getToken().getLoc();
  2863    if (parseToken(Token::percent_identifier, "expected SSA operand"))
  2864      return failure();
  2865  
  2866    // If we have an attribute ID, it is a result number.
  2867    if (getToken().is(Token::hash_identifier)) {
  2868      if (auto value = getToken().getHashIdentifierNumber())
  2869        result.number = value.getValue();
  2870      else
  2871        return emitError("invalid SSA value result number");
  2872      consumeToken(Token::hash_identifier);
  2873    }
  2874  
  2875    return success();
  2876  }
  2877  
  2878  /// Given an unbound reference to an SSA value and its type, return the value
  2879  /// it specifies.  This returns null on failure.
  2880  Value *OperationParser::resolveSSAUse(SSAUseInfo useInfo, Type type) {
  2881    auto &entries = getSSAValueEntry(useInfo.name);
  2882  
  2883    // If we have already seen a value of this name, return it.
  2884    if (useInfo.number < entries.size() && entries[useInfo.number].first) {
  2885      auto *result = entries[useInfo.number].first;
  2886      // Check that the type matches the other uses.
  2887      if (result->getType() == type)
  2888        return result;
  2889  
  2890      emitError(useInfo.loc, "use of value '")
  2891          .append(useInfo.name,
  2892                  "' expects different type than prior uses: ", type, " vs ",
  2893                  result->getType())
  2894          .attachNote(getEncodedSourceLocation(entries[useInfo.number].second))
  2895          .append("prior use here");
  2896      return nullptr;
  2897    }
  2898  
  2899    // Make sure we have enough slots for this.
  2900    if (entries.size() <= useInfo.number)
  2901      entries.resize(useInfo.number + 1);
  2902  
  2903    // If the value has already been defined and this is an overly large result
  2904    // number, diagnose that.
  2905    if (entries[0].first && !isForwardRefPlaceholder(entries[0].first))
  2906      return (emitError(useInfo.loc, "reference to invalid result number"),
  2907              nullptr);
  2908  
  2909    // Otherwise, this is a forward reference.  Create a placeholder and remember
  2910    // that we did so.
  2911    auto *result = createForwardRefPlaceholder(useInfo.loc, type);
  2912    entries[useInfo.number].first = result;
  2913    entries[useInfo.number].second = useInfo.loc;
  2914    return result;
  2915  }
  2916  
  2917  /// Parse an SSA use with an associated type.
  2918  ///
  2919  ///   ssa-use-and-type ::= ssa-use `:` type
  2920  ParseResult OperationParser::parseSSADefOrUseAndType(
  2921      const std::function<ParseResult(SSAUseInfo, Type)> &action) {
  2922    SSAUseInfo useInfo;
  2923    if (parseSSAUse(useInfo) ||
  2924        parseToken(Token::colon, "expected ':' and type for SSA operand"))
  2925      return failure();
  2926  
  2927    auto type = parseType();
  2928    if (!type)
  2929      return failure();
  2930  
  2931    return action(useInfo, type);
  2932  }
  2933  
  2934  /// Parse a (possibly empty) list of SSA operands, followed by a colon, then
  2935  /// followed by a type list.
  2936  ///
  2937  ///   ssa-use-and-type-list
  2938  ///     ::= ssa-use-list ':' type-list-no-parens
  2939  ///
  2940  ParseResult OperationParser::parseOptionalSSAUseAndTypeList(
  2941      SmallVectorImpl<Value *> &results) {
  2942    SmallVector<SSAUseInfo, 4> valueIDs;
  2943    if (parseOptionalSSAUseList(valueIDs))
  2944      return failure();
  2945  
  2946    // If there were no operands, then there is no colon or type lists.
  2947    if (valueIDs.empty())
  2948      return success();
  2949  
  2950    SmallVector<Type, 4> types;
  2951    if (parseToken(Token::colon, "expected ':' in operand list") ||
  2952        parseTypeListNoParens(types))
  2953      return failure();
  2954  
  2955    if (valueIDs.size() != types.size())
  2956      return emitError("expected ")
  2957             << valueIDs.size() << " types to match operand list";
  2958  
  2959    results.reserve(valueIDs.size());
  2960    for (unsigned i = 0, e = valueIDs.size(); i != e; ++i) {
  2961      if (auto *value = resolveSSAUse(valueIDs[i], types[i]))
  2962        results.push_back(value);
  2963      else
  2964        return failure();
  2965    }
  2966  
  2967    return success();
  2968  }
  2969  
  2970  /// Record that a definition was added at the current scope.
  2971  void OperationParser::recordDefinition(StringRef def) {
  2972    isolatedNameScopes.back().recordDefinition(def);
  2973  }
  2974  
  2975  /// Get the value entry for the given SSA name.
  2976  SmallVectorImpl<std::pair<Value *, SMLoc>> &
  2977  OperationParser::getSSAValueEntry(StringRef name) {
  2978    return isolatedNameScopes.back().values[name];
  2979  }
  2980  
  2981  /// Create and remember a new placeholder for a forward reference.
  2982  Value *OperationParser::createForwardRefPlaceholder(SMLoc loc, Type type) {
  2983    // Forward references are always created as operations, because we just need
  2984    // something with a def/use chain.
  2985    //
  2986    // We create these placeholders as having an empty name, which we know
  2987    // cannot be created through normal user input, allowing us to distinguish
  2988    // them.
  2989    auto name = OperationName("placeholder", getContext());
  2990    auto *op = Operation::create(
  2991        getEncodedSourceLocation(loc), name, /*operands=*/{}, type,
  2992        /*attributes=*/llvm::None, /*successors=*/{}, /*numRegions=*/0,
  2993        /*resizableOperandList=*/false);
  2994    forwardRefPlaceholders[op->getResult(0)] = loc;
  2995    return op->getResult(0);
  2996  }
  2997  
  2998  //===----------------------------------------------------------------------===//
  2999  // Operation Parsing
  3000  //===----------------------------------------------------------------------===//
  3001  
  3002  /// Parse an operation.
  3003  ///
  3004  ///  operation ::=
  3005  ///    operation-result? string '(' ssa-use-list? ')' attribute-dict?
  3006  ///    `:` function-type trailing-location?
  3007  ///  operation-result ::= ssa-id ((`:` integer-literal) | (`,` ssa-id)*) `=`
  3008  ///
  3009  ParseResult OperationParser::parseOperation() {
  3010    auto loc = getToken().getLoc();
  3011    SmallVector<std::pair<StringRef, SMLoc>, 1> resultIDs;
  3012    size_t numExpectedResults;
  3013    if (getToken().is(Token::percent_identifier)) {
  3014      // Parse the first result id.
  3015      resultIDs.emplace_back(getTokenSpelling(), loc);
  3016      consumeToken(Token::percent_identifier);
  3017  
  3018      // If the next token is a ':', we parse the expected result count.
  3019      if (consumeIf(Token::colon)) {
  3020        // Check that the next token is an integer.
  3021        if (!getToken().is(Token::integer))
  3022          return emitError("expected integer number of results");
  3023  
  3024        // Check that number of results is > 0.
  3025        auto val = getToken().getUInt64IntegerValue();
  3026        if (!val.hasValue() || val.getValue() < 1)
  3027          return emitError("expected named operation to have atleast 1 result");
  3028        consumeToken(Token::integer);
  3029        numExpectedResults = *val;
  3030      } else {
  3031        // Otherwise, this is a comma separated list of result ids.
  3032        if (consumeIf(Token::comma)) {
  3033          auto parseNextResult = [&]() -> ParseResult {
  3034            // Parse the next result id.
  3035            if (!getToken().is(Token::percent_identifier))
  3036              return emitError("expected valid ssa identifier");
  3037  
  3038            resultIDs.emplace_back(getTokenSpelling(), getToken().getLoc());
  3039            consumeToken(Token::percent_identifier);
  3040            return success();
  3041          };
  3042  
  3043          if (parseCommaSeparatedList(parseNextResult))
  3044            return failure();
  3045        }
  3046        numExpectedResults = resultIDs.size();
  3047      }
  3048  
  3049      if (parseToken(Token::equal, "expected '=' after SSA name"))
  3050        return failure();
  3051    }
  3052  
  3053    Operation *op;
  3054    if (getToken().is(Token::bare_identifier) || getToken().isKeyword())
  3055      op = parseCustomOperation();
  3056    else if (getToken().is(Token::string))
  3057      op = parseGenericOperation();
  3058    else
  3059      return emitError("expected operation name in quotes");
  3060  
  3061    // If parsing of the basic operation failed, then this whole thing fails.
  3062    if (!op)
  3063      return failure();
  3064  
  3065    // If the operation had a name, register it.
  3066    if (!resultIDs.empty()) {
  3067      if (op->getNumResults() == 0)
  3068        return emitError(loc, "cannot name an operation with no results");
  3069      if (numExpectedResults != op->getNumResults())
  3070        return emitError(loc, "operation defines ")
  3071               << op->getNumResults() << " results but was provided "
  3072               << numExpectedResults << " to bind";
  3073  
  3074      // If the number of result names matches the number of operation results, we
  3075      // can directly use the provided names.
  3076      if (resultIDs.size() == op->getNumResults()) {
  3077        for (unsigned i = 0, e = op->getNumResults(); i != e; ++i)
  3078          if (addDefinition({resultIDs[i].first, 0, resultIDs[i].second},
  3079                            op->getResult(i)))
  3080            return failure();
  3081      } else {
  3082        // Otherwise, we use the same name for all results.
  3083        StringRef name = resultIDs.front().first;
  3084        for (unsigned i = 0, e = op->getNumResults(); i != e; ++i)
  3085          if (addDefinition({name, i, loc}, op->getResult(i)))
  3086            return failure();
  3087      }
  3088    }
  3089  
  3090    // Try to parse the optional trailing location.
  3091    if (parseOptionalTrailingLocation(op))
  3092      return failure();
  3093  
  3094    return success();
  3095  }
  3096  
  3097  /// Parse a single operation successor and its operand list.
  3098  ///
  3099  ///   successor ::= block-id branch-use-list?
  3100  ///   branch-use-list ::= `(` ssa-use-list ':' type-list-no-parens `)`
  3101  ///
  3102  ParseResult
  3103  OperationParser::parseSuccessorAndUseList(Block *&dest,
  3104                                            SmallVectorImpl<Value *> &operands) {
  3105    // Verify branch is identifier and get the matching block.
  3106    if (!getToken().is(Token::caret_identifier))
  3107      return emitError("expected block name");
  3108    dest = getBlockNamed(getTokenSpelling(), getToken().getLoc());
  3109    consumeToken();
  3110  
  3111    // Handle optional arguments.
  3112    if (consumeIf(Token::l_paren) &&
  3113        (parseOptionalSSAUseAndTypeList(operands) ||
  3114         parseToken(Token::r_paren, "expected ')' to close argument list"))) {
  3115      return failure();
  3116    }
  3117  
  3118    return success();
  3119  }
  3120  
  3121  /// Parse a comma-separated list of operation successors in brackets.
  3122  ///
  3123  ///   successor-list ::= `[` successor (`,` successor )* `]`
  3124  ///
  3125  ParseResult OperationParser::parseSuccessors(
  3126      SmallVectorImpl<Block *> &destinations,
  3127      SmallVectorImpl<SmallVector<Value *, 4>> &operands) {
  3128    if (parseToken(Token::l_square, "expected '['"))
  3129      return failure();
  3130  
  3131    auto parseElt = [this, &destinations, &operands]() {
  3132      Block *dest;
  3133      SmallVector<Value *, 4> destOperands;
  3134      auto res = parseSuccessorAndUseList(dest, destOperands);
  3135      destinations.push_back(dest);
  3136      operands.push_back(destOperands);
  3137      return res;
  3138    };
  3139    return parseCommaSeparatedListUntil(Token::r_square, parseElt,
  3140                                        /*allowEmptyList=*/false);
  3141  }
  3142  
  3143  namespace {
  3144  // RAII-style guard for cleaning up the regions in the operation state before
  3145  // deleting them.  Within the parser, regions may get deleted if parsing failed,
  3146  // and other errors may be present, in praticular undominated uses.  This makes
  3147  // sure such uses are deleted.
  3148  struct CleanupOpStateRegions {
  3149    ~CleanupOpStateRegions() {
  3150      SmallVector<Region *, 4> regionsToClean;
  3151      regionsToClean.reserve(state.regions.size());
  3152      for (auto &region : state.regions)
  3153        if (region)
  3154          for (auto &block : *region)
  3155            block.dropAllDefinedValueUses();
  3156    }
  3157    OperationState &state;
  3158  };
  3159  } // namespace
  3160  
  3161  Operation *OperationParser::parseGenericOperation() {
  3162    // Get location information for the operation.
  3163    auto srcLocation = getEncodedSourceLocation(getToken().getLoc());
  3164  
  3165    auto name = getToken().getStringValue();
  3166    if (name.empty())
  3167      return (emitError("empty operation name is invalid"), nullptr);
  3168    if (name.find('\0') != StringRef::npos)
  3169      return (emitError("null character not allowed in operation name"), nullptr);
  3170  
  3171    consumeToken(Token::string);
  3172  
  3173    OperationState result(srcLocation, name);
  3174  
  3175    // Generic operations have a resizable operation list.
  3176    result.setOperandListToResizable();
  3177  
  3178    // Parse the operand list.
  3179    SmallVector<SSAUseInfo, 8> operandInfos;
  3180  
  3181    if (parseToken(Token::l_paren, "expected '(' to start operand list") ||
  3182        parseOptionalSSAUseList(operandInfos) ||
  3183        parseToken(Token::r_paren, "expected ')' to end operand list")) {
  3184      return nullptr;
  3185    }
  3186  
  3187    // Parse the successor list but don't add successors to the result yet to
  3188    // avoid messing up with the argument order.
  3189    SmallVector<Block *, 2> successors;
  3190    SmallVector<SmallVector<Value *, 4>, 2> successorOperands;
  3191    if (getToken().is(Token::l_square)) {
  3192      // Check if the operation is a known terminator.
  3193      const AbstractOperation *abstractOp = result.name.getAbstractOperation();
  3194      if (abstractOp && !abstractOp->hasProperty(OperationProperty::Terminator))
  3195        return emitError("successors in non-terminator"), nullptr;
  3196      if (parseSuccessors(successors, successorOperands))
  3197        return nullptr;
  3198    }
  3199  
  3200    // Parse the region list.
  3201    CleanupOpStateRegions guard{result};
  3202    if (consumeIf(Token::l_paren)) {
  3203      do {
  3204        // Create temporary regions with the top level region as parent.
  3205        result.regions.emplace_back(new Region(moduleOp));
  3206        if (parseRegion(*result.regions.back(), /*entryArguments=*/{}))
  3207          return nullptr;
  3208      } while (consumeIf(Token::comma));
  3209      if (parseToken(Token::r_paren, "expected ')' to end region list"))
  3210        return nullptr;
  3211    }
  3212  
  3213    if (getToken().is(Token::l_brace)) {
  3214      if (parseAttributeDict(result.attributes))
  3215        return nullptr;
  3216    }
  3217  
  3218    if (parseToken(Token::colon, "expected ':' followed by operation type"))
  3219      return nullptr;
  3220  
  3221    auto typeLoc = getToken().getLoc();
  3222    auto type = parseType();
  3223    if (!type)
  3224      return nullptr;
  3225    auto fnType = type.dyn_cast<FunctionType>();
  3226    if (!fnType)
  3227      return (emitError(typeLoc, "expected function type"), nullptr);
  3228  
  3229    result.addTypes(fnType.getResults());
  3230  
  3231    // Check that we have the right number of types for the operands.
  3232    auto operandTypes = fnType.getInputs();
  3233    if (operandTypes.size() != operandInfos.size()) {
  3234      auto plural = "s"[operandInfos.size() == 1];
  3235      return (emitError(typeLoc, "expected ")
  3236                  << operandInfos.size() << " operand type" << plural
  3237                  << " but had " << operandTypes.size(),
  3238              nullptr);
  3239    }
  3240  
  3241    // Resolve all of the operands.
  3242    for (unsigned i = 0, e = operandInfos.size(); i != e; ++i) {
  3243      result.operands.push_back(resolveSSAUse(operandInfos[i], operandTypes[i]));
  3244      if (!result.operands.back())
  3245        return nullptr;
  3246    }
  3247  
  3248    // Add the sucessors, and their operands after the proper operands.
  3249    for (const auto &succ : llvm::zip(successors, successorOperands)) {
  3250      Block *successor = std::get<0>(succ);
  3251      const SmallVector<Value *, 4> &operands = std::get<1>(succ);
  3252      result.addSuccessor(successor, operands);
  3253    }
  3254  
  3255    return opBuilder.createOperation(result);
  3256  }
  3257  
  3258  namespace {
  3259  class CustomOpAsmParser : public OpAsmParser {
  3260  public:
  3261    CustomOpAsmParser(SMLoc nameLoc, const AbstractOperation *opDefinition,
  3262                      OperationParser &parser)
  3263        : nameLoc(nameLoc), opDefinition(opDefinition), parser(parser) {}
  3264  
  3265    /// Parse an instance of the operation described by 'opDefinition' into the
  3266    /// provided operation state.
  3267    ParseResult parseOperation(OperationState *opState) {
  3268      if (opDefinition->parseAssembly(this, opState))
  3269        return failure();
  3270      return success();
  3271    }
  3272  
  3273    //===--------------------------------------------------------------------===//
  3274    // Utilities
  3275    //===--------------------------------------------------------------------===//
  3276  
  3277    /// Return if any errors were emitted during parsing.
  3278    bool didEmitError() const { return emittedError; }
  3279  
  3280    /// Emit a diagnostic at the specified location and return failure.
  3281    InFlightDiagnostic emitError(llvm::SMLoc loc, const Twine &message) override {
  3282      emittedError = true;
  3283      return parser.emitError(loc, "custom op '" + opDefinition->name + "' " +
  3284                                       message);
  3285    }
  3286  
  3287    llvm::SMLoc getCurrentLocation() override {
  3288      return parser.getToken().getLoc();
  3289    }
  3290  
  3291    Builder &getBuilder() const override { return parser.builder; }
  3292  
  3293    llvm::SMLoc getNameLoc() const override { return nameLoc; }
  3294  
  3295    //===--------------------------------------------------------------------===//
  3296    // Token Parsing
  3297    //===--------------------------------------------------------------------===//
  3298  
  3299    /// Parse a `->` token.
  3300    ParseResult parseArrow() override {
  3301      return parser.parseToken(Token::arrow, "expected '->'");
  3302    }
  3303  
  3304    /// Parses a `->` if present.
  3305    ParseResult parseOptionalArrow() override {
  3306      return success(parser.consumeIf(Token::arrow));
  3307    }
  3308  
  3309    /// Parse a `:` token.
  3310    ParseResult parseColon() override {
  3311      return parser.parseToken(Token::colon, "expected ':'");
  3312    }
  3313  
  3314    /// Parse a `:` token if present.
  3315    ParseResult parseOptionalColon() override {
  3316      return success(parser.consumeIf(Token::colon));
  3317    }
  3318  
  3319    /// Parse a `,` token.
  3320    ParseResult parseComma() override {
  3321      return parser.parseToken(Token::comma, "expected ','");
  3322    }
  3323  
  3324    /// Parse a `,` token if present.
  3325    ParseResult parseOptionalComma() override {
  3326      return success(parser.consumeIf(Token::comma));
  3327    }
  3328  
  3329    /// Parses a `...` if present.
  3330    ParseResult parseOptionalEllipsis() override {
  3331      return success(parser.consumeIf(Token::ellipsis));
  3332    }
  3333  
  3334    /// Parse a `=` token.
  3335    ParseResult parseEqual() override {
  3336      return parser.parseToken(Token::equal, "expected '='");
  3337    }
  3338  
  3339    /// Parse a keyword if present.
  3340    ParseResult parseOptionalKeyword(const char *keyword) override {
  3341      // Check that the current token is a bare identifier or keyword.
  3342      if (parser.getToken().isNot(Token::bare_identifier) &&
  3343          !parser.getToken().isKeyword())
  3344        return failure();
  3345  
  3346      if (parser.getTokenSpelling() == keyword) {
  3347        parser.consumeToken();
  3348        return success();
  3349      }
  3350      return failure();
  3351    }
  3352  
  3353    /// Parse a `(` token.
  3354    ParseResult parseLParen() override {
  3355      return parser.parseToken(Token::l_paren, "expected '('");
  3356    }
  3357  
  3358    /// Parses a '(' if present.
  3359    ParseResult parseOptionalLParen() override {
  3360      return success(parser.consumeIf(Token::l_paren));
  3361    }
  3362  
  3363    /// Parse a `)` token.
  3364    ParseResult parseRParen() override {
  3365      return parser.parseToken(Token::r_paren, "expected ')'");
  3366    }
  3367  
  3368    /// Parses a ')' if present.
  3369    ParseResult parseOptionalRParen() override {
  3370      return success(parser.consumeIf(Token::r_paren));
  3371    }
  3372  
  3373    /// Parse a `[` token.
  3374    ParseResult parseLSquare() override {
  3375      return parser.parseToken(Token::l_square, "expected '['");
  3376    }
  3377  
  3378    /// Parses a '[' if present.
  3379    ParseResult parseOptionalLSquare() override {
  3380      return success(parser.consumeIf(Token::l_square));
  3381    }
  3382  
  3383    /// Parse a `]` token.
  3384    ParseResult parseRSquare() override {
  3385      return parser.parseToken(Token::r_square, "expected ']'");
  3386    }
  3387  
  3388    /// Parses a ']' if present.
  3389    ParseResult parseOptionalRSquare() override {
  3390      return success(parser.consumeIf(Token::r_square));
  3391    }
  3392  
  3393    //===--------------------------------------------------------------------===//
  3394    // Attribute Parsing
  3395    //===--------------------------------------------------------------------===//
  3396  
  3397    /// Parse an arbitrary attribute of a given type and return it in result. This
  3398    /// also adds the attribute to the specified attribute list with the specified
  3399    /// name.
  3400    ParseResult parseAttribute(Attribute &result, Type type, StringRef attrName,
  3401                               SmallVectorImpl<NamedAttribute> &attrs) override {
  3402      result = parser.parseAttribute(type);
  3403      if (!result)
  3404        return failure();
  3405  
  3406      attrs.push_back(parser.builder.getNamedAttr(attrName, result));
  3407      return success();
  3408    }
  3409  
  3410    /// Parse a named dictionary into 'result' if it is present.
  3411    ParseResult
  3412    parseOptionalAttributeDict(SmallVectorImpl<NamedAttribute> &result) override {
  3413      if (parser.getToken().isNot(Token::l_brace))
  3414        return success();
  3415      return parser.parseAttributeDict(result);
  3416    }
  3417  
  3418    //===--------------------------------------------------------------------===//
  3419    // Identifier Parsing
  3420    //===--------------------------------------------------------------------===//
  3421  
  3422    /// Parse an @-identifier and store it (without the '@' symbol) in a string
  3423    /// attribute named 'attrName'.
  3424    ParseResult parseSymbolName(StringAttr &result, StringRef attrName,
  3425                                SmallVectorImpl<NamedAttribute> &attrs) override {
  3426      if (parser.getToken().isNot(Token::at_identifier))
  3427        return failure();
  3428      result = getBuilder().getStringAttr(parser.getTokenSpelling().drop_front());
  3429      attrs.push_back(getBuilder().getNamedAttr(attrName, result));
  3430      parser.consumeToken();
  3431      return success();
  3432    }
  3433  
  3434    //===--------------------------------------------------------------------===//
  3435    // Operand Parsing
  3436    //===--------------------------------------------------------------------===//
  3437  
  3438    /// Parse a single operand.
  3439    ParseResult parseOperand(OperandType &result) override {
  3440      OperationParser::SSAUseInfo useInfo;
  3441      if (parser.parseSSAUse(useInfo))
  3442        return failure();
  3443  
  3444      result = {useInfo.loc, useInfo.name, useInfo.number};
  3445      return success();
  3446    }
  3447  
  3448    /// Parse zero or more SSA comma-separated operand references with a specified
  3449    /// surrounding delimiter, and an optional required operand count.
  3450    ParseResult parseOperandList(SmallVectorImpl<OperandType> &result,
  3451                                 int requiredOperandCount = -1,
  3452                                 Delimiter delimiter = Delimiter::None) override {
  3453      return parseOperandOrRegionArgList(result, /*isOperandList=*/true,
  3454                                         requiredOperandCount, delimiter);
  3455    }
  3456  
  3457    /// Parse zero or more SSA comma-separated operand or region arguments with
  3458    ///  optional surrounding delimiter and required operand count.
  3459    ParseResult
  3460    parseOperandOrRegionArgList(SmallVectorImpl<OperandType> &result,
  3461                                bool isOperandList, int requiredOperandCount = -1,
  3462                                Delimiter delimiter = Delimiter::None) {
  3463      auto startLoc = parser.getToken().getLoc();
  3464  
  3465      // Handle delimiters.
  3466      switch (delimiter) {
  3467      case Delimiter::None:
  3468        // Don't check for the absence of a delimiter if the number of operands
  3469        // is unknown (and hence the operand list could be empty).
  3470        if (requiredOperandCount == -1)
  3471          break;
  3472        // Token already matches an identifier and so can't be a delimiter.
  3473        if (parser.getToken().is(Token::percent_identifier))
  3474          break;
  3475        // Test against known delimiters.
  3476        if (parser.getToken().is(Token::l_paren) ||
  3477            parser.getToken().is(Token::l_square))
  3478          return emitError(startLoc, "unexpected delimiter");
  3479        return emitError(startLoc, "invalid operand");
  3480      case Delimiter::OptionalParen:
  3481        if (parser.getToken().isNot(Token::l_paren))
  3482          return success();
  3483        LLVM_FALLTHROUGH;
  3484      case Delimiter::Paren:
  3485        if (parser.parseToken(Token::l_paren, "expected '(' in operand list"))
  3486          return failure();
  3487        break;
  3488      case Delimiter::OptionalSquare:
  3489        if (parser.getToken().isNot(Token::l_square))
  3490          return success();
  3491        LLVM_FALLTHROUGH;
  3492      case Delimiter::Square:
  3493        if (parser.parseToken(Token::l_square, "expected '[' in operand list"))
  3494          return failure();
  3495        break;
  3496      }
  3497  
  3498      // Check for zero operands.
  3499      if (parser.getToken().is(Token::percent_identifier)) {
  3500        do {
  3501          OperandType operandOrArg;
  3502          if (isOperandList ? parseOperand(operandOrArg)
  3503                            : parseRegionArgument(operandOrArg))
  3504            return failure();
  3505          result.push_back(operandOrArg);
  3506        } while (parser.consumeIf(Token::comma));
  3507      }
  3508  
  3509      // Handle delimiters.   If we reach here, the optional delimiters were
  3510      // present, so we need to parse their closing one.
  3511      switch (delimiter) {
  3512      case Delimiter::None:
  3513        break;
  3514      case Delimiter::OptionalParen:
  3515      case Delimiter::Paren:
  3516        if (parser.parseToken(Token::r_paren, "expected ')' in operand list"))
  3517          return failure();
  3518        break;
  3519      case Delimiter::OptionalSquare:
  3520      case Delimiter::Square:
  3521        if (parser.parseToken(Token::r_square, "expected ']' in operand list"))
  3522          return failure();
  3523        break;
  3524      }
  3525  
  3526      if (requiredOperandCount != -1 &&
  3527          result.size() != static_cast<size_t>(requiredOperandCount))
  3528        return emitError(startLoc, "expected ")
  3529               << requiredOperandCount << " operands";
  3530      return success();
  3531    }
  3532  
  3533    /// Parse zero or more trailing SSA comma-separated trailing operand
  3534    /// references with a specified surrounding delimiter, and an optional
  3535    /// required operand count. A leading comma is expected before the operands.
  3536    ParseResult parseTrailingOperandList(SmallVectorImpl<OperandType> &result,
  3537                                         int requiredOperandCount,
  3538                                         Delimiter delimiter) override {
  3539      if (parser.getToken().is(Token::comma)) {
  3540        parseComma();
  3541        return parseOperandList(result, requiredOperandCount, delimiter);
  3542      }
  3543      if (requiredOperandCount != -1)
  3544        return emitError(parser.getToken().getLoc(), "expected ")
  3545               << requiredOperandCount << " operands";
  3546      return success();
  3547    }
  3548  
  3549    /// Resolve an operand to an SSA value, emitting an error on failure.
  3550    ParseResult resolveOperand(const OperandType &operand, Type type,
  3551                               SmallVectorImpl<Value *> &result) override {
  3552      OperationParser::SSAUseInfo operandInfo = {operand.name, operand.number,
  3553                                                 operand.location};
  3554      if (auto *value = parser.resolveSSAUse(operandInfo, type)) {
  3555        result.push_back(value);
  3556        return success();
  3557      }
  3558      return failure();
  3559    }
  3560  
  3561    /// Parse an AffineMap of SSA ids.
  3562    ParseResult
  3563    parseAffineMapOfSSAIds(SmallVectorImpl<OperandType> &operands,
  3564                           Attribute &mapAttr, StringRef attrName,
  3565                           SmallVectorImpl<NamedAttribute> &attrs) override {
  3566      SmallVector<OperandType, 2> dimOperands;
  3567      SmallVector<OperandType, 1> symOperands;
  3568  
  3569      auto parseElement = [&](bool isSymbol) -> ParseResult {
  3570        OperandType operand;
  3571        if (parseOperand(operand))
  3572          return failure();
  3573        if (isSymbol)
  3574          symOperands.push_back(operand);
  3575        else
  3576          dimOperands.push_back(operand);
  3577        return success();
  3578      };
  3579  
  3580      AffineMap map;
  3581      if (parser.parseAffineMapOfSSAIds(map, parseElement))
  3582        return failure();
  3583      // Add AffineMap attribute.
  3584      if (map) {
  3585        mapAttr = parser.builder.getAffineMapAttr(map);
  3586        attrs.push_back(parser.builder.getNamedAttr(attrName, mapAttr));
  3587      }
  3588  
  3589      // Add dim operands before symbol operands in 'operands'.
  3590      operands.assign(dimOperands.begin(), dimOperands.end());
  3591      operands.append(symOperands.begin(), symOperands.end());
  3592      return success();
  3593    }
  3594  
  3595    //===--------------------------------------------------------------------===//
  3596    // Region Parsing
  3597    //===--------------------------------------------------------------------===//
  3598  
  3599    /// Parse a region that takes `arguments` of `argTypes` types.  This
  3600    /// effectively defines the SSA values of `arguments` and assignes their type.
  3601    ParseResult parseRegion(Region &region, ArrayRef<OperandType> arguments,
  3602                            ArrayRef<Type> argTypes,
  3603                            bool enableNameShadowing) override {
  3604      assert(arguments.size() == argTypes.size() &&
  3605             "mismatching number of arguments and types");
  3606  
  3607      SmallVector<std::pair<OperationParser::SSAUseInfo, Type>, 2>
  3608          regionArguments;
  3609      for (const auto &pair : llvm::zip(arguments, argTypes)) {
  3610        const OperandType &operand = std::get<0>(pair);
  3611        Type type = std::get<1>(pair);
  3612        OperationParser::SSAUseInfo operandInfo = {operand.name, operand.number,
  3613                                                   operand.location};
  3614        regionArguments.emplace_back(operandInfo, type);
  3615      }
  3616  
  3617      // Try to parse the region.
  3618      assert((!enableNameShadowing ||
  3619              opDefinition->hasProperty(OperationProperty::IsolatedFromAbove)) &&
  3620             "name shadowing is only allowed on isolated regions");
  3621      if (parser.parseRegion(region, regionArguments, enableNameShadowing))
  3622        return failure();
  3623      return success();
  3624    }
  3625  
  3626    /// Parses a region if present.
  3627    ParseResult parseOptionalRegion(Region &region,
  3628                                    ArrayRef<OperandType> arguments,
  3629                                    ArrayRef<Type> argTypes,
  3630                                    bool enableNameShadowing) override {
  3631      if (parser.getToken().isNot(Token::l_brace))
  3632        return success();
  3633      return parseRegion(region, arguments, argTypes, enableNameShadowing);
  3634    }
  3635  
  3636    /// Parse a region argument. The type of the argument will be resolved later
  3637    /// by a call to `parseRegion`.
  3638    ParseResult parseRegionArgument(OperandType &argument) override {
  3639      return parseOperand(argument);
  3640    }
  3641  
  3642    /// Parse a region argument if present.
  3643    ParseResult parseOptionalRegionArgument(OperandType &argument) override {
  3644      if (parser.getToken().isNot(Token::percent_identifier))
  3645        return success();
  3646      return parseRegionArgument(argument);
  3647    }
  3648  
  3649    ParseResult
  3650    parseRegionArgumentList(SmallVectorImpl<OperandType> &result,
  3651                            int requiredOperandCount = -1,
  3652                            Delimiter delimiter = Delimiter::None) override {
  3653      return parseOperandOrRegionArgList(result, /*isOperandList=*/false,
  3654                                         requiredOperandCount, delimiter);
  3655    }
  3656  
  3657    //===--------------------------------------------------------------------===//
  3658    // Successor Parsing
  3659    //===--------------------------------------------------------------------===//
  3660  
  3661    /// Parse a single operation successor and its operand list.
  3662    ParseResult
  3663    parseSuccessorAndUseList(Block *&dest,
  3664                             SmallVectorImpl<Value *> &operands) override {
  3665      return parser.parseSuccessorAndUseList(dest, operands);
  3666    }
  3667  
  3668    //===--------------------------------------------------------------------===//
  3669    // Type Parsing
  3670    //===--------------------------------------------------------------------===//
  3671  
  3672    /// Parse a type.
  3673    ParseResult parseType(Type &result) override {
  3674      return failure(!(result = parser.parseType()));
  3675    }
  3676  
  3677    /// Parse an optional arrow followed by a type list.
  3678    ParseResult
  3679    parseOptionalArrowTypeList(SmallVectorImpl<Type> &result) override {
  3680      if (!parser.consumeIf(Token::arrow))
  3681        return success();
  3682      return parser.parseFunctionResultTypes(result);
  3683    }
  3684  
  3685    /// Parse a colon followed by a type.
  3686    ParseResult parseColonType(Type &result) override {
  3687      return failure(parser.parseToken(Token::colon, "expected ':'") ||
  3688                     !(result = parser.parseType()));
  3689    }
  3690  
  3691    /// Parse a colon followed by a type list, which must have at least one type.
  3692    ParseResult parseColonTypeList(SmallVectorImpl<Type> &result) override {
  3693      if (parser.parseToken(Token::colon, "expected ':'"))
  3694        return failure();
  3695      return parser.parseTypeListNoParens(result);
  3696    }
  3697  
  3698    /// Parse an optional colon followed by a type list, which if present must
  3699    /// have at least one type.
  3700    ParseResult
  3701    parseOptionalColonTypeList(SmallVectorImpl<Type> &result) override {
  3702      if (!parser.consumeIf(Token::colon))
  3703        return success();
  3704      return parser.parseTypeListNoParens(result);
  3705    }
  3706  
private:
  /// The source location of the operation name.
  SMLoc nameLoc;

  /// The abstract (registered) information of the operation whose custom
  /// parser is being driven.
  const AbstractOperation *opDefinition;

  /// The main operation parser; the parsing hooks in this class forward to it.
  OperationParser &parser;

  /// A flag that indicates if any errors were emitted during parsing.
  /// NOTE(review): presumably what didEmitError() reports to
  /// parseCustomOperation — confirm against its definition.
  bool emittedError = false;
  3719  };
  3720  } // end anonymous namespace.
  3721  
  3722  Operation *OperationParser::parseCustomOperation() {
  3723    auto opLoc = getToken().getLoc();
  3724    auto opName = getTokenSpelling();
  3725  
  3726    auto *opDefinition = AbstractOperation::lookup(opName, getContext());
  3727    if (!opDefinition && !opName.contains('.')) {
  3728      // If the operation name has no namespace prefix we treat it as a standard
  3729      // operation and prefix it with "std".
  3730      // TODO: Would it be better to just build a mapping of the registered
  3731      // operations in the standard dialect?
  3732      opDefinition =
  3733          AbstractOperation::lookup(Twine("std." + opName).str(), getContext());
  3734    }
  3735  
  3736    if (!opDefinition) {
  3737      emitError(opLoc) << "custom op '" << opName << "' is unknown";
  3738      return nullptr;
  3739    }
  3740  
  3741    consumeToken();
  3742  
  3743    // If the custom op parser crashes, produce some indication to help
  3744    // debugging.
  3745    std::string opNameStr = opName.str();
  3746    llvm::PrettyStackTraceFormat fmt("MLIR Parser: custom op parser '%s'",
  3747                                     opNameStr.c_str());
  3748  
  3749    // Get location information for the operation.
  3750    auto srcLocation = getEncodedSourceLocation(opLoc);
  3751  
  3752    // Have the op implementation take a crack and parsing this.
  3753    OperationState opState(srcLocation, opDefinition->name);
  3754    CleanupOpStateRegions guard{opState};
  3755    CustomOpAsmParser opAsmParser(opLoc, opDefinition, *this);
  3756    if (opAsmParser.parseOperation(&opState))
  3757      return nullptr;
  3758  
  3759    // If it emitted an error, we failed.
  3760    if (opAsmParser.didEmitError())
  3761      return nullptr;
  3762  
  3763    // Otherwise, we succeeded.  Use the state it parsed as our op information.
  3764    return opBuilder.createOperation(opState);
  3765  }
  3766  
  3767  //===----------------------------------------------------------------------===//
  3768  // Region Parsing
  3769  //===----------------------------------------------------------------------===//
  3770  
  3771  /// Region.
  3772  ///
  3773  ///   region ::= '{' region-body
  3774  ///
  3775  ParseResult OperationParser::parseRegion(
  3776      Region &region,
  3777      ArrayRef<std::pair<OperationParser::SSAUseInfo, Type>> entryArguments,
  3778      bool isIsolatedNameScope) {
  3779    // Parse the '{'.
  3780    if (parseToken(Token::l_brace, "expected '{' to begin a region"))
  3781      return failure();
  3782  
  3783    // Check for an empty region.
  3784    if (entryArguments.empty() && consumeIf(Token::r_brace))
  3785      return success();
  3786    auto currentPt = opBuilder.saveInsertionPoint();
  3787  
  3788    // Push a new named value scope.
  3789    pushSSANameScope(isIsolatedNameScope);
  3790  
  3791    // Parse the first block directly to allow for it to be unnamed.
  3792    Block *block = new Block();
  3793  
  3794    // Add arguments to the entry block.
  3795    if (!entryArguments.empty()) {
  3796      for (auto &placeholderArgPair : entryArguments) {
  3797        auto &argInfo = placeholderArgPair.first;
  3798        // Ensure that the argument was not already defined.
  3799        if (auto defLoc = getReferenceLoc(argInfo.name, argInfo.number)) {
  3800          return emitError(argInfo.loc, "region entry argument '" + argInfo.name +
  3801                                            "' is already in use")
  3802                     .attachNote(getEncodedSourceLocation(*defLoc))
  3803                 << "previously referenced here";
  3804        }
  3805        if (addDefinition(placeholderArgPair.first,
  3806                          block->addArgument(placeholderArgPair.second))) {
  3807          delete block;
  3808          return failure();
  3809        }
  3810      }
  3811  
  3812      // If we had named arguments, then don't allow a block name.
  3813      if (getToken().is(Token::caret_identifier))
  3814        return emitError("invalid block name in region with named arguments");
  3815    }
  3816  
  3817    if (parseBlock(block)) {
  3818      delete block;
  3819      return failure();
  3820    }
  3821  
  3822    // Verify that no other arguments were parsed.
  3823    if (!entryArguments.empty() &&
  3824        block->getNumArguments() > entryArguments.size()) {
  3825      delete block;
  3826      return emitError("entry block arguments were already defined");
  3827    }
  3828  
  3829    // Parse the rest of the region.
  3830    region.push_back(block);
  3831    if (parseRegionBody(region))
  3832      return failure();
  3833  
  3834    // Pop the SSA value scope for this region.
  3835    if (popSSANameScope())
  3836      return failure();
  3837  
  3838    // Reset the original insertion point.
  3839    opBuilder.restoreInsertionPoint(currentPt);
  3840    return success();
  3841  }
  3842  
  3843  /// Region.
  3844  ///
  3845  ///   region-body ::= block* '}'
  3846  ///
  3847  ParseResult OperationParser::parseRegionBody(Region &region) {
  3848    // Parse the list of blocks.
  3849    while (!consumeIf(Token::r_brace)) {
  3850      Block *newBlock = nullptr;
  3851      if (parseBlock(newBlock))
  3852        return failure();
  3853      region.push_back(newBlock);
  3854    }
  3855    return success();
  3856  }
  3857  
  3858  //===----------------------------------------------------------------------===//
  3859  // Block Parsing
  3860  //===----------------------------------------------------------------------===//
  3861  
  3862  /// Block declaration.
  3863  ///
  3864  ///   block ::= block-label? operation*
  3865  ///   block-label    ::= block-id block-arg-list? `:`
  3866  ///   block-id       ::= caret-id
  3867  ///   block-arg-list ::= `(` ssa-id-and-type-list? `)`
  3868  ///
  3869  ParseResult OperationParser::parseBlock(Block *&block) {
  3870    // The first block of a region may already exist, if it does the caret
  3871    // identifier is optional.
  3872    if (block && getToken().isNot(Token::caret_identifier))
  3873      return parseBlockBody(block);
  3874  
  3875    SMLoc nameLoc = getToken().getLoc();
  3876    auto name = getTokenSpelling();
  3877    if (parseToken(Token::caret_identifier, "expected block name"))
  3878      return failure();
  3879  
  3880    block = defineBlockNamed(name, nameLoc, block);
  3881  
  3882    // Fail if the block was already defined.
  3883    if (!block)
  3884      return emitError(nameLoc, "redefinition of block '") << name << "'";
  3885  
  3886    // If an argument list is present, parse it.
  3887    if (consumeIf(Token::l_paren)) {
  3888      SmallVector<BlockArgument *, 8> bbArgs;
  3889      if (parseOptionalBlockArgList(bbArgs, block) ||
  3890          parseToken(Token::r_paren, "expected ')' to end argument list"))
  3891        return failure();
  3892    }
  3893  
  3894    if (parseToken(Token::colon, "expected ':' after block name"))
  3895      return failure();
  3896  
  3897    return parseBlockBody(block);
  3898  }
  3899  
  3900  ParseResult OperationParser::parseBlockBody(Block *block) {
  3901    // Set the insertion point to the end of the block to parse.
  3902    opBuilder.setInsertionPointToEnd(block);
  3903  
  3904    // Parse the list of operations that make up the body of the block.
  3905    while (getToken().isNot(Token::caret_identifier, Token::r_brace))
  3906      if (parseOperation())
  3907        return failure();
  3908  
  3909    return success();
  3910  }
  3911  
  3912  /// Get the block with the specified name, creating it if it doesn't already
  3913  /// exist.  The location specified is the point of use, which allows
  3914  /// us to diagnose references to blocks that are not defined precisely.
  3915  Block *OperationParser::getBlockNamed(StringRef name, SMLoc loc) {
  3916    auto &blockAndLoc = getBlockInfoByName(name);
  3917    if (!blockAndLoc.first) {
  3918      blockAndLoc = {new Block(), loc};
  3919      insertForwardRef(blockAndLoc.first, loc);
  3920    }
  3921  
  3922    return blockAndLoc.first;
  3923  }
  3924  
  3925  /// Define the block with the specified name. Returns the Block* or nullptr in
  3926  /// the case of redefinition.
  3927  Block *OperationParser::defineBlockNamed(StringRef name, SMLoc loc,
  3928                                           Block *existing) {
  3929    auto &blockAndLoc = getBlockInfoByName(name);
  3930    if (!blockAndLoc.first) {
  3931      // If the caller provided a block, use it.  Otherwise create a new one.
  3932      if (!existing)
  3933        existing = new Block();
  3934      blockAndLoc.first = existing;
  3935      blockAndLoc.second = loc;
  3936      return blockAndLoc.first;
  3937    }
  3938  
  3939    // Forward declarations are removed once defined, so if we are defining a
  3940    // existing block and it is not a forward declaration, then it is a
  3941    // redeclaration.
  3942    if (!eraseForwardRef(blockAndLoc.first))
  3943      return nullptr;
  3944    return blockAndLoc.first;
  3945  }
  3946  
  3947  /// Parse a (possibly empty) list of SSA operands with types as block arguments.
  3948  ///
  3949  ///   ssa-id-and-type-list ::= ssa-id-and-type (`,` ssa-id-and-type)*
  3950  ///
  3951  ParseResult OperationParser::parseOptionalBlockArgList(
  3952      SmallVectorImpl<BlockArgument *> &results, Block *owner) {
  3953    if (getToken().is(Token::r_brace))
  3954      return success();
  3955  
  3956    // If the block already has arguments, then we're handling the entry block.
  3957    // Parse and register the names for the arguments, but do not add them.
  3958    bool definingExistingArgs = owner->getNumArguments() != 0;
  3959    unsigned nextArgument = 0;
  3960  
  3961    return parseCommaSeparatedList([&]() -> ParseResult {
  3962      return parseSSADefOrUseAndType(
  3963          [&](SSAUseInfo useInfo, Type type) -> ParseResult {
  3964            // If this block did not have existing arguments, define a new one.
  3965            if (!definingExistingArgs)
  3966              return addDefinition(useInfo, owner->addArgument(type));
  3967  
  3968            // Otherwise, ensure that this argument has already been created.
  3969            if (nextArgument >= owner->getNumArguments())
  3970              return emitError("too many arguments specified in argument list");
  3971  
  3972            // Finally, make sure the existing argument has the correct type.
  3973            auto *arg = owner->getArgument(nextArgument++);
  3974            if (arg->getType() != type)
  3975              return emitError("argument and block argument type mismatch");
  3976            return addDefinition(useInfo, arg);
  3977          });
  3978    });
  3979  }
  3980  
  3981  //===----------------------------------------------------------------------===//
  3982  // Top-level entity parsing.
  3983  //===----------------------------------------------------------------------===//
  3984  
  3985  namespace {
  3986  /// This parser handles entities that are only valid at the top level of the
  3987  /// file.
  3988  class ModuleParser : public Parser {
  3989  public:
  3990    explicit ModuleParser(ParserState &state) : Parser(state) {}
  3991  
  3992    ParseResult parseModule(ModuleOp module);
  3993  
  3994  private:
  3995    /// Parse an attribute alias declaration.
  3996    ParseResult parseAttributeAliasDef();
  3997  
  3998    /// Parse an attribute alias declaration.
  3999    ParseResult parseTypeAliasDef();
  4000  };
  4001  } // end anonymous namespace
  4002  
  4003  /// Parses an attribute alias declaration.
  4004  ///
  4005  ///   attribute-alias-def ::= '#' alias-name `=` attribute-value
  4006  ///
  4007  ParseResult ModuleParser::parseAttributeAliasDef() {
  4008    assert(getToken().is(Token::hash_identifier));
  4009    StringRef aliasName = getTokenSpelling().drop_front();
  4010  
  4011    // Check for redefinitions.
  4012    if (getState().attributeAliasDefinitions.count(aliasName) > 0)
  4013      return emitError("redefinition of attribute alias id '" + aliasName + "'");
  4014  
  4015    // Make sure this isn't invading the dialect attribute namespace.
  4016    if (aliasName.contains('.'))
  4017      return emitError("attribute names with a '.' are reserved for "
  4018                       "dialect-defined names");
  4019  
  4020    consumeToken(Token::hash_identifier);
  4021  
  4022    // Parse the '='.
  4023    if (parseToken(Token::equal, "expected '=' in attribute alias definition"))
  4024      return failure();
  4025  
  4026    // Parse the attribute value.
  4027    Attribute attr = parseAttribute();
  4028    if (!attr)
  4029      return failure();
  4030  
  4031    getState().attributeAliasDefinitions[aliasName] = attr;
  4032    return success();
  4033  }
  4034  
  4035  /// Parse a type alias declaration.
  4036  ///
  4037  ///   type-alias-def ::= '!' alias-name `=` 'type' type
  4038  ///
  4039  ParseResult ModuleParser::parseTypeAliasDef() {
  4040    assert(getToken().is(Token::exclamation_identifier));
  4041    StringRef aliasName = getTokenSpelling().drop_front();
  4042  
  4043    // Check for redefinitions.
  4044    if (getState().typeAliasDefinitions.count(aliasName) > 0)
  4045      return emitError("redefinition of type alias id '" + aliasName + "'");
  4046  
  4047    // Make sure this isn't invading the dialect type namespace.
  4048    if (aliasName.contains('.'))
  4049      return emitError("type names with a '.' are reserved for "
  4050                       "dialect-defined names");
  4051  
  4052    consumeToken(Token::exclamation_identifier);
  4053  
  4054    // Parse the '=' and 'type'.
  4055    if (parseToken(Token::equal, "expected '=' in type alias definition") ||
  4056        parseToken(Token::kw_type, "expected 'type' in type alias definition"))
  4057      return failure();
  4058  
  4059    // Parse the type.
  4060    Type aliasedType = parseType();
  4061    if (!aliasedType)
  4062      return failure();
  4063  
  4064    // Register this alias with the parser state.
  4065    getState().typeAliasDefinitions.try_emplace(aliasName, aliasedType);
  4066    return success();
  4067  }
  4068  
/// This is the top-level module parser.
///
/// Parses the whole source buffer into 'module'. Each loop iteration
/// dispatches on the current top-level token:
/// - '#' identifiers start attribute alias definitions;
/// - '!' identifiers start type alias definitions;
/// - an error token aborts, because the lexer has already reported it;
/// - anything else is parsed as an operation by an OperationParser built over
///   'module'.
/// Parsing ends at end of file.
ParseResult ModuleParser::parseModule(ModuleOp module) {
  OperationParser opParser(getState(), module);

  // Module itself is a name scope.
  opParser.pushSSANameScope(/*isIsolated=*/true);

  while (1) {
    switch (getToken().getKind()) {
    default:
      // Parse a top-level operation.
      if (opParser.parseOperation())
        return failure();
      break;

    // If we got to the end of the file, then we're done.
    case Token::eof: {
      // NOTE(review): finalize() presumably reports anything left unresolved
      // (e.g. undefined forward references) — confirm in OperationParser.
      if (opParser.finalize())
        return failure();

      // Handle the case where the top level module was explicitly defined.
      auto &bodyBlocks = module.getBodyRegion().getBlocks();
      auto &operations = bodyBlocks.front().getOperations();
      assert(!operations.empty() && "expected a valid module terminator");

      // Check that the first operation is a module, and it is the only
      // non-terminator operation, i.e. the body holds exactly two operations:
      // the nested module followed by the terminator.
      ModuleOp nested = dyn_cast<ModuleOp>(operations.front());
      if (nested && std::next(operations.begin(), 2) == operations.end()) {
        // Merge the data of the nested module operation into 'module': adopt
        // its location and attributes, and move its body blocks to the end of
        // this module's region.
        module.setLoc(nested.getLoc());
        module.setAttrs(nested.getOperation()->getAttrList());
        bodyBlocks.splice(bodyBlocks.end(), nested.getBodyRegion().getBlocks());

        // Erase the original module body. This front block still holds the
        // nested module op, whose region is now empty after the splice.
        bodyBlocks.pop_front();
      }

      // Close the module-level name scope opened above.
      return opParser.popSSANameScope();
    }

    // If we got an error token, then the lexer already emitted an error, just
    // stop.  Someday we could introduce error recovery if there was demand
    // for it.
    case Token::error:
      return failure();

    // Parse an attribute alias.
    case Token::hash_identifier:
      if (parseAttributeAliasDef())
        return failure();
      break;

    // Parse a type alias.
    case Token::exclamation_identifier:
      if (parseTypeAliasDef())
        return failure();
      break;
    }
  }
}
  4130  
  4131  //===----------------------------------------------------------------------===//
  4132  
  4133  /// This parses the file specified by the indicated SourceMgr and returns an
  4134  /// MLIR module if it was valid.  If not, it emits diagnostics and returns
  4135  /// null.
  4136  OwningModuleRef mlir::parseSourceFile(const llvm::SourceMgr &sourceMgr,
  4137                                        MLIRContext *context) {
  4138    auto sourceBuf = sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID());
  4139  
  4140    // This is the result module we are parsing into.
  4141    OwningModuleRef module(ModuleOp::create(FileLineColLoc::get(
  4142        sourceBuf->getBufferIdentifier(), /*line=*/0, /*column=*/0, context)));
  4143  
  4144    ParserState state(sourceMgr, context);
  4145    if (ModuleParser(state).parseModule(*module))
  4146      return nullptr;
  4147  
  4148    // Make sure the parse module has no other structural problems detected by
  4149    // the verifier.
  4150    if (failed(verify(*module)))
  4151      return nullptr;
  4152  
  4153    return module;
  4154  }
  4155  
/// This parses the file specified by the indicated filename and returns an
/// MLIR module if it was valid.  If not, the error message is emitted through
/// the error handler registered in the context, and a null pointer is returned.
///
/// A SourceMgr local to this call is used, so the file's buffer is released
/// when the function returns.
OwningModuleRef mlir::parseSourceFile(StringRef filename,
                                      MLIRContext *context) {
  llvm::SourceMgr sourceMgr;
  return parseSourceFile(filename, sourceMgr, context);
}
  4164  
  4165  /// This parses the file specified by the indicated filename using the provided
  4166  /// SourceMgr and returns an MLIR module if it was valid.  If not, the error
  4167  /// message is emitted through the error handler registered in the context, and
  4168  /// a null pointer is returned.
  4169  OwningModuleRef mlir::parseSourceFile(StringRef filename,
  4170                                        llvm::SourceMgr &sourceMgr,
  4171                                        MLIRContext *context) {
  4172    if (sourceMgr.getNumBuffers() != 0) {
  4173      // TODO(b/136086478): Extend to support multiple buffers.
  4174      emitError(mlir::UnknownLoc::get(context),
  4175                "only main buffer parsed at the moment");
  4176      return nullptr;
  4177    }
  4178    auto file_or_err = llvm::MemoryBuffer::getFileOrSTDIN(filename);
  4179    if (std::error_code error = file_or_err.getError()) {
  4180      emitError(mlir::UnknownLoc::get(context),
  4181                "could not open input file " + filename);
  4182      return nullptr;
  4183    }
  4184  
  4185    // Load the MLIR module.
  4186    sourceMgr.AddNewSourceBuffer(std::move(*file_or_err), llvm::SMLoc());
  4187    return parseSourceFile(sourceMgr, context);
  4188  }
  4189  
  4190  /// This parses the program string to a MLIR module if it was valid. If not,
  4191  /// it emits diagnostics and returns null.
  4192  OwningModuleRef mlir::parseSourceString(StringRef moduleStr,
  4193                                          MLIRContext *context) {
  4194    auto memBuffer = MemoryBuffer::getMemBuffer(moduleStr);
  4195    if (!memBuffer)
  4196      return nullptr;
  4197  
  4198    SourceMgr sourceMgr;
  4199    sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
  4200    return parseSourceFile(sourceMgr, context);
  4201  }
  4202  
  4203  Type mlir::parseType(llvm::StringRef typeStr, MLIRContext *context) {
  4204    SourceMgr sourceMgr;
  4205    auto memBuffer =
  4206        MemoryBuffer::getMemBuffer(typeStr, /*BufferName=*/"<mlir_type_buffer>",
  4207                                   /*RequiresNullTerminator=*/false);
  4208    sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
  4209    SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, context);
  4210    ParserState state(sourceMgr, context);
  4211    Parser parser(state);
  4212    auto start = parser.getToken().getLoc();
  4213    auto ty = parser.parseType();
  4214    if (!ty)
  4215      return Type();
  4216  
  4217    auto end = parser.getToken().getLoc();
  4218    auto read = end.getPointer() - start.getPointer();
  4219    // Make sure that the parsing of type consumes the entire string
  4220    if (static_cast<size_t>(read) < typeStr.size()) {
  4221      parser.emitError("unexpected additional tokens: '")
  4222          << typeStr.substr(read) << "' after parsing type: " << ty;
  4223      return Type();
  4224    }
  4225    return ty;
  4226  }