github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp (about)

     1  //===- SPIRVOps.cpp - MLIR SPIR-V operations ------------------------------===//
     2  //
     3  // Copyright 2019 The MLIR Authors.
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //   http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  // =============================================================================
    17  //
    18  // This file defines the operations in the SPIR-V dialect.
    19  //
    20  //===----------------------------------------------------------------------===//
    21  
#include "mlir/Dialect/SPIRV/SPIRVOps.h"

#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Support/StringExtras.h"

#include <cstring>
#include <type_traits>
    31  
    32  using namespace mlir;
    33  
// TODO(antiagainst): generate these strings using ODS.
//
// Attribute name strings shared by the custom parsers, printers and verifiers
// in this file. Each constant holds the key under which the corresponding
// attribute is stored on an operation. Not every constant is referenced in the
// portion of the file visible here; the rest are presumably used by ops defined
// further down (NOTE(review): confirm before removing any of them).
static constexpr const char kAlignmentAttrName[] = "alignment";
static constexpr const char kBranchWeightAttrName[] = "branch_weights";
static constexpr const char kDefaultValueAttrName[] = "default_value";
static constexpr const char kFnNameAttrName[] = "fn";
static constexpr const char kIndicesAttrName[] = "indices";
static constexpr const char kInitializerAttrName[] = "initializer";
static constexpr const char kInterfaceAttrName[] = "interface";
static constexpr const char kSpecConstAttrName[] = "spec_const";
static constexpr const char kTypeAttrName[] = "type";
static constexpr const char kValueAttrName[] = "value";
static constexpr const char kValuesAttrName[] = "values";
static constexpr const char kVariableAttrName[] = "variable";
    47  
    48  //===----------------------------------------------------------------------===//
    49  // Common utility functions
    50  //===----------------------------------------------------------------------===//
    51  
    52  template <typename Dst, typename Src>
    53  inline Dst bitwiseCast(Src source) noexcept {
    54    Dst dest;
    55    static_assert(sizeof(source) == sizeof(dest),
    56                  "bitwiseCast requires same source and destination bitwidth");
    57    std::memcpy(&dest, &source, sizeof(dest));
    58    return dest;
    59  }
    60  
    61  static LogicalResult extractValueFromConstOp(Operation *op,
    62                                               int32_t &indexValue) {
    63    auto constOp = dyn_cast<spirv::ConstantOp>(op);
    64    if (!constOp) {
    65      return failure();
    66    }
    67    auto valueAttr = constOp.value();
    68    auto integerValueAttr = valueAttr.dyn_cast<IntegerAttr>();
    69    if (!integerValueAttr) {
    70      return failure();
    71    }
    72    indexValue = integerValueAttr.getInt();
    73    return success();
    74  }
    75  
    76  static ParseResult parseBinaryLogicalOp(OpAsmParser *parser,
    77                                          OperationState *result) {
    78    SmallVector<OpAsmParser::OperandType, 2> ops;
    79    Type type;
    80    if (parser->parseOperandList(ops, 2) || parser->parseColonType(type) ||
    81        parser->resolveOperands(ops, type, result->operands)) {
    82      return failure();
    83    }
    84    // Result must be a scalar or vector of boolean type.
    85    Type resultType = parser->getBuilder().getIntegerType(1);
    86    if (auto opsType = type.dyn_cast<VectorType>()) {
    87      resultType = VectorType::get(opsType.getNumElements(), resultType);
    88    }
    89    result->addTypes(resultType);
    90    return success();
    91  }
    92  
    93  template <typename EnumClass>
    94  static ParseResult parseEnumAttribute(EnumClass &value, OpAsmParser *parser) {
    95    Attribute attrVal;
    96    SmallVector<NamedAttribute, 1> attr;
    97    auto loc = parser->getCurrentLocation();
    98    if (parser->parseAttribute(attrVal, parser->getBuilder().getNoneType(),
    99                               spirv::attributeName<EnumClass>(), attr)) {
   100      return failure();
   101    }
   102    if (!attrVal.isa<StringAttr>()) {
   103      return parser->emitError(loc, "expected ")
   104             << spirv::attributeName<EnumClass>()
   105             << " attribute specified as string";
   106    }
   107    auto attrOptional =
   108        spirv::symbolizeEnum<EnumClass>()(attrVal.cast<StringAttr>().getValue());
   109    if (!attrOptional) {
   110      return parser->emitError(loc, "invalid ")
   111             << spirv::attributeName<EnumClass>()
   112             << " attribute specification: " << attrVal;
   113    }
   114    value = attrOptional.getValue();
   115    return success();
   116  }
   117  
   118  template <typename EnumClass>
   119  static ParseResult parseEnumAttribute(EnumClass &value, OpAsmParser *parser,
   120                                        OperationState *state) {
   121    if (parseEnumAttribute(value, parser)) {
   122      return failure();
   123    }
   124    state->addAttribute(
   125        spirv::attributeName<EnumClass>(),
   126        parser->getBuilder().getI32IntegerAttr(bitwiseCast<int32_t>(value)));
   127    return success();
   128  }
   129  
   130  static ParseResult parseMemoryAccessAttributes(OpAsmParser *parser,
   131                                                 OperationState *state) {
   132    // Parse an optional list of attributes staring with '['
   133    if (parser->parseOptionalLSquare()) {
   134      // Nothing to do
   135      return success();
   136    }
   137  
   138    spirv::MemoryAccess memoryAccessAttr;
   139    if (parseEnumAttribute(memoryAccessAttr, parser, state)) {
   140      return failure();
   141    }
   142  
   143    if (memoryAccessAttr == spirv::MemoryAccess::Aligned) {
   144      // Parse integer attribute for alignment.
   145      Attribute alignmentAttr;
   146      Type i32Type = parser->getBuilder().getIntegerType(32);
   147      if (parser->parseComma() ||
   148          parser->parseAttribute(alignmentAttr, i32Type, kAlignmentAttrName,
   149                                 state->attributes)) {
   150        return failure();
   151      }
   152    }
   153    return parser->parseRSquare();
   154  }
   155  
   156  // Parses an op that has no inputs and no outputs.
   157  static ParseResult parseNoIOOp(OpAsmParser *parser, OperationState *state) {
   158    if (parser->parseOptionalAttributeDict(state->attributes))
   159      return failure();
   160    return success();
   161  }
   162  
   163  static void printBinaryLogicalOp(Operation *logicalOp, OpAsmPrinter *printer) {
   164    *printer << logicalOp->getName() << ' ' << *logicalOp->getOperand(0) << ", "
   165             << *logicalOp->getOperand(1);
   166    *printer << " : " << logicalOp->getOperand(0)->getType();
   167  }
   168  
   169  template <typename LoadStoreOpTy>
   170  static void
   171  printMemoryAccessAttribute(LoadStoreOpTy loadStoreOp, OpAsmPrinter *printer,
   172                             SmallVectorImpl<StringRef> &elidedAttrs) {
   173    // Print optional memory access attribute.
   174    if (auto memAccess = loadStoreOp.memory_access()) {
   175      elidedAttrs.push_back(spirv::attributeName<spirv::MemoryAccess>());
   176      *printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
   177  
   178      // Print integer alignment attribute.
   179      if (auto alignment = loadStoreOp.alignment()) {
   180        elidedAttrs.push_back(kAlignmentAttrName);
   181        *printer << ", " << alignment;
   182      }
   183      *printer << "]";
   184    }
   185    elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
   186  }
   187  
   188  template <typename LoadStoreOpTy>
   189  static LogicalResult verifyMemoryAccessAttribute(LoadStoreOpTy loadStoreOp) {
   190    // ODS checks for attributes values. Just need to verify that if the
   191    // memory-access attribute is Aligned, then the alignment attribute must be
   192    // present.
   193    auto *op = loadStoreOp.getOperation();
   194    auto memAccessAttr = op->getAttr(spirv::attributeName<spirv::MemoryAccess>());
   195    if (!memAccessAttr) {
   196      // Alignment attribute shouldn't be present if memory access attribute is
   197      // not present.
   198      if (op->getAttr(kAlignmentAttrName)) {
   199        return loadStoreOp.emitOpError(
   200            "invalid alignment specification without aligned memory access "
   201            "specification");
   202      }
   203      return success();
   204    }
   205  
   206    auto memAccessVal = memAccessAttr.template cast<IntegerAttr>();
   207    auto memAccess = spirv::symbolizeMemoryAccess(memAccessVal.getInt());
   208  
   209    if (!memAccess) {
   210      return loadStoreOp.emitOpError("invalid memory access specifier: ")
   211             << memAccessVal;
   212    }
   213  
   214    if (*memAccess == spirv::MemoryAccess::Aligned) {
   215      if (!op->getAttr(kAlignmentAttrName)) {
   216        return loadStoreOp.emitOpError("missing alignment value");
   217      }
   218    } else {
   219      if (op->getAttr(kAlignmentAttrName)) {
   220        return loadStoreOp.emitOpError(
   221            "invalid alignment specification with non-aligned memory access "
   222            "specification");
   223      }
   224    }
   225    return success();
   226  }
   227  
   228  template <typename LoadStoreOpTy>
   229  static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value *ptr,
   230                                                     Value *val) {
   231    // ODS already checks ptr is spirv::PointerType. Just check that the pointee
   232    // type of the pointer and the type of the value are the same
   233    //
   234    // TODO(ravishankarm): Check that the value type satisfies restrictions of
   235    // SPIR-V OpLoad/OpStore operations
   236    if (val->getType() !=
   237        ptr->getType().cast<spirv::PointerType>().getPointeeType()) {
   238      return op.emitOpError("mismatch in result type and pointer type");
   239    }
   240    return success();
   241  }
   242  
   243  // Prints an op that has no inputs and no outputs.
   244  static void printNoIOOp(Operation *op, OpAsmPrinter *printer) {
   245    *printer << op->getName();
   246    printer->printOptionalAttrDict(op->getAttrs());
   247  }
   248  
   249  static ParseResult parseVariableDecorations(OpAsmParser *parser,
   250                                              OperationState *state) {
   251    auto builtInName =
   252        convertToSnakeCase(stringifyDecoration(spirv::Decoration::BuiltIn));
   253    if (succeeded(parser->parseOptionalKeyword("bind"))) {
   254      Attribute set, binding;
   255      // Parse optional descriptor binding
   256      auto descriptorSetName = convertToSnakeCase(
   257          stringifyDecoration(spirv::Decoration::DescriptorSet));
   258      auto bindingName =
   259          convertToSnakeCase(stringifyDecoration(spirv::Decoration::Binding));
   260      Type i32Type = parser->getBuilder().getIntegerType(32);
   261      if (parser->parseLParen() ||
   262          parser->parseAttribute(set, i32Type, descriptorSetName,
   263                                 state->attributes) ||
   264          parser->parseComma() ||
   265          parser->parseAttribute(binding, i32Type, bindingName,
   266                                 state->attributes) ||
   267          parser->parseRParen()) {
   268        return failure();
   269      }
   270    } else if (succeeded(parser->parseOptionalKeyword(builtInName.c_str()))) {
   271      StringAttr builtIn;
   272      if (parser->parseLParen() ||
   273          parser->parseAttribute(builtIn, Type(), builtInName,
   274                                 state->attributes) ||
   275          parser->parseRParen()) {
   276        return failure();
   277      }
   278    }
   279  
   280    // Parse other attributes
   281    if (parser->parseOptionalAttributeDict(state->attributes))
   282      return failure();
   283  
   284    return success();
   285  }
   286  
   287  static void printVariableDecorations(Operation *op, OpAsmPrinter *printer,
   288                                       SmallVectorImpl<StringRef> &elidedAttrs) {
   289    // Print optional descriptor binding
   290    auto descriptorSetName =
   291        convertToSnakeCase(stringifyDecoration(spirv::Decoration::DescriptorSet));
   292    auto bindingName =
   293        convertToSnakeCase(stringifyDecoration(spirv::Decoration::Binding));
   294    auto descriptorSet = op->getAttrOfType<IntegerAttr>(descriptorSetName);
   295    auto binding = op->getAttrOfType<IntegerAttr>(bindingName);
   296    if (descriptorSet && binding) {
   297      elidedAttrs.push_back(descriptorSetName);
   298      elidedAttrs.push_back(bindingName);
   299      *printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt()
   300               << ")";
   301    }
   302  
   303    // Print BuiltIn attribute if present
   304    auto builtInName =
   305        convertToSnakeCase(stringifyDecoration(spirv::Decoration::BuiltIn));
   306    if (auto builtin = op->getAttrOfType<StringAttr>(builtInName)) {
   307      *printer << " " << builtInName << "(\"" << builtin.getValue() << "\")";
   308      elidedAttrs.push_back(builtInName);
   309    }
   310  
   311    printer->printOptionalAttrDict(op->getAttrs(), elidedAttrs);
   312  }
   313  
   314  //===----------------------------------------------------------------------===//
   315  // spv.AccessChainOp
   316  //===----------------------------------------------------------------------===//
   317  
   318  static Type getElementPtrType(Type type, ArrayRef<Value *> indices,
   319                                Location baseLoc) {
   320    if (indices.empty()) {
   321      emitError(baseLoc, "'spv.AccessChain' op expected at least "
   322                         "one index ");
   323      return nullptr;
   324    }
   325  
   326    auto ptrType = type.dyn_cast<spirv::PointerType>();
   327    if (!ptrType) {
   328      emitError(baseLoc, "'spv.AccessChain' op expected a pointer "
   329                         "to composite type, but provided ")
   330          << type;
   331      return nullptr;
   332    }
   333  
   334    auto resultType = ptrType.getPointeeType();
   335    auto resultStorageClass = ptrType.getStorageClass();
   336    int32_t index = 0;
   337  
   338    for (auto indexSSA : indices) {
   339      auto cType = resultType.dyn_cast<spirv::CompositeType>();
   340      if (!cType) {
   341        emitError(baseLoc,
   342                  "'spv.AccessChain' op cannot extract from non-composite type ")
   343            << resultType << " with index " << index;
   344        return nullptr;
   345      }
   346      index = 0;
   347      if (resultType.isa<spirv::StructType>()) {
   348        Operation *op = indexSSA->getDefiningOp();
   349        if (!op) {
   350          emitError(baseLoc, "'spv.AccessChain' op index must be an "
   351                             "integer spv.constant to access "
   352                             "element of spv.struct");
   353          return nullptr;
   354        }
   355  
   356        // TODO(denis0x0D): this should be relaxed to allow
   357        // integer literals of other bitwidths.
   358        if (failed(extractValueFromConstOp(op, index))) {
   359          emitError(baseLoc,
   360                    "'spv.AccessChain' index must be an integer spv.constant to "
   361                    "access element of spv.struct, but provided ")
   362              << op->getName();
   363          return nullptr;
   364        }
   365        if (index < 0 || static_cast<uint64_t>(index) >= cType.getNumElements()) {
   366          emitError(baseLoc, "'spv.AccessChain' op index ")
   367              << index << " out of bounds for " << resultType;
   368          return nullptr;
   369        }
   370      }
   371      resultType = cType.getElementType(index);
   372    }
   373    return spirv::PointerType::get(resultType, resultStorageClass);
   374  }
   375  
   376  void spirv::AccessChainOp::build(Builder *builder, OperationState *state,
   377                                   Value *basePtr, ArrayRef<Value *> indices) {
   378    auto type = getElementPtrType(basePtr->getType(), indices, state->location);
   379    assert(type && "Unable to deduce return type based on basePtr and indices");
   380    build(builder, state, type, basePtr, indices);
   381  }
   382  
   383  static ParseResult parseAccessChainOp(OpAsmParser *parser,
   384                                        OperationState *state) {
   385    OpAsmParser::OperandType ptrInfo;
   386    SmallVector<OpAsmParser::OperandType, 4> indicesInfo;
   387    Type type;
   388    // TODO(denis0x0D): regarding to the spec an index must be any integer type,
   389    // figure out how to use resolveOperand with a range of types and do not
   390    // fail on first attempt.
   391    Type indicesType = parser->getBuilder().getIntegerType(32);
   392  
   393    if (parser->parseOperand(ptrInfo) ||
   394        parser->parseOperandList(indicesInfo, OpAsmParser::Delimiter::Square) ||
   395        parser->parseColonType(type) ||
   396        parser->resolveOperand(ptrInfo, type, state->operands) ||
   397        parser->resolveOperands(indicesInfo, indicesType, state->operands)) {
   398      return failure();
   399    }
   400  
   401    auto resultType = getElementPtrType(
   402        type, llvm::makeArrayRef(state->operands).drop_front(), state->location);
   403    if (!resultType) {
   404      return failure();
   405    }
   406  
   407    state->addTypes(resultType);
   408    return success();
   409  }
   410  
   411  static void print(spirv::AccessChainOp op, OpAsmPrinter *printer) {
   412    *printer << spirv::AccessChainOp::getOperationName() << ' ' << *op.base_ptr()
   413             << '[';
   414    printer->printOperands(op.indices());
   415    *printer << "] : " << op.base_ptr()->getType();
   416  }
   417  
   418  static LogicalResult verify(spirv::AccessChainOp accessChainOp) {
   419    SmallVector<Value *, 4> indices(accessChainOp.indices().begin(),
   420                                    accessChainOp.indices().end());
   421    auto resultType = getElementPtrType(accessChainOp.base_ptr()->getType(),
   422                                        indices, accessChainOp.getLoc());
   423    if (!resultType) {
   424      return failure();
   425    }
   426  
   427    auto providedResultType =
   428        accessChainOp.getType().dyn_cast<spirv::PointerType>();
   429    if (!providedResultType) {
   430      return accessChainOp.emitOpError(
   431                 "result type must be a pointer, but provided")
   432             << providedResultType;
   433    }
   434  
   435    if (resultType != providedResultType) {
   436      return accessChainOp.emitOpError("invalid result type: expected ")
   437             << resultType << ", but provided " << providedResultType;
   438    }
   439  
   440    return success();
   441  }
   442  
   443  //===----------------------------------------------------------------------===//
   444  // spv._address_of
   445  //===----------------------------------------------------------------------===//
   446  
   447  static ParseResult parseAddressOfOp(OpAsmParser *parser,
   448                                      OperationState *state) {
   449    SymbolRefAttr varRefAttr;
   450    Type type;
   451    if (parser->parseAttribute(varRefAttr, Type(), kVariableAttrName,
   452                               state->attributes) ||
   453        parser->parseColonType(type)) {
   454      return failure();
   455    }
   456    auto ptrType = type.dyn_cast<spirv::PointerType>();
   457    if (!ptrType) {
   458      return parser->emitError(parser->getCurrentLocation(),
   459                               "expected spv.ptr type");
   460    }
   461    state->addTypes(ptrType);
   462    return success();
   463  }
   464  
   465  static void print(spirv::AddressOfOp addressOfOp, OpAsmPrinter *printer) {
   466    SmallVector<StringRef, 4> elidedAttrs;
   467    *printer << spirv::AddressOfOp::getOperationName();
   468  
   469    // Print symbol name.
   470    *printer << " @" << addressOfOp.variable();
   471  
   472    // Print the type.
   473    *printer << " : " << addressOfOp.pointer()->getType();
   474  }
   475  
   476  static LogicalResult verify(spirv::AddressOfOp addressOfOp) {
   477    auto moduleOp = addressOfOp.getParentOfType<spirv::ModuleOp>();
   478    auto varOp =
   479        moduleOp.lookupSymbol<spirv::GlobalVariableOp>(addressOfOp.variable());
   480    if (!varOp) {
   481      return addressOfOp.emitOpError("expected spv.globalVariable symbol");
   482    }
   483    if (addressOfOp.pointer()->getType() != varOp.type()) {
   484      return addressOfOp.emitOpError(
   485          "result type mismatch with the referenced global variable's type");
   486    }
   487    return success();
   488  }
   489  
   490  //===----------------------------------------------------------------------===//
   491  // spv.BranchOp
   492  //===----------------------------------------------------------------------===//
   493  
   494  static ParseResult parseBranchOp(OpAsmParser *parser, OperationState *state) {
   495    Block *dest;
   496    SmallVector<Value *, 4> destOperands;
   497    if (parser->parseSuccessorAndUseList(dest, destOperands))
   498      return failure();
   499    state->addSuccessor(dest, destOperands);
   500    return success();
   501  }
   502  
   503  static void print(spirv::BranchOp branchOp, OpAsmPrinter *printer) {
   504    *printer << spirv::BranchOp::getOperationName() << ' ';
   505    printer->printSuccessorAndUseList(branchOp.getOperation(), /*index=*/0);
   506  }
   507  
   508  static LogicalResult verify(spirv::BranchOp branchOp) {
   509    auto *op = branchOp.getOperation();
   510    if (op->getNumSuccessors() != 1)
   511      branchOp.emitOpError("must have exactly one successor");
   512  
   513    return success();
   514  }
   515  
   516  //===----------------------------------------------------------------------===//
   517  // spv.BranchConditionalOp
   518  //===----------------------------------------------------------------------===//
   519  
   520  static ParseResult parseBranchConditionalOp(OpAsmParser *parser,
   521                                              OperationState *state) {
   522    auto &builder = parser->getBuilder();
   523    OpAsmParser::OperandType condInfo;
   524    Block *dest;
   525    SmallVector<Value *, 4> destOperands;
   526  
   527    // Parse the condition.
   528    Type boolTy = builder.getI1Type();
   529    if (parser->parseOperand(condInfo) ||
   530        parser->resolveOperand(condInfo, boolTy, state->operands))
   531      return failure();
   532  
   533    // Parse the optional branch weights.
   534    if (succeeded(parser->parseOptionalLSquare())) {
   535      IntegerAttr trueWeight, falseWeight;
   536      SmallVector<NamedAttribute, 2> weights;
   537  
   538      auto i32Type = builder.getIntegerType(32);
   539      if (parser->parseAttribute(trueWeight, i32Type, "weight", weights) ||
   540          parser->parseComma() ||
   541          parser->parseAttribute(falseWeight, i32Type, "weight", weights) ||
   542          parser->parseRSquare())
   543        return failure();
   544  
   545      state->addAttribute(kBranchWeightAttrName,
   546                          builder.getArrayAttr({trueWeight, falseWeight}));
   547    }
   548  
   549    // Parse the true branch.
   550    if (parser->parseComma() ||
   551        parser->parseSuccessorAndUseList(dest, destOperands))
   552      return failure();
   553    state->addSuccessor(dest, destOperands);
   554  
   555    // Parse the false branch.
   556    destOperands.clear();
   557    if (parser->parseComma() ||
   558        parser->parseSuccessorAndUseList(dest, destOperands))
   559      return failure();
   560    state->addSuccessor(dest, destOperands);
   561  
   562    return success();
   563  }
   564  
   565  static void print(spirv::BranchConditionalOp branchOp, OpAsmPrinter *printer) {
   566    *printer << spirv::BranchConditionalOp::getOperationName() << ' ';
   567    printer->printOperand(branchOp.condition());
   568  
   569    if (auto weights = branchOp.branch_weights()) {
   570      *printer << " [";
   571      mlir::interleaveComma(
   572          weights->getValue(), printer->getStream(),
   573          [&](Attribute a) { *printer << a.cast<IntegerAttr>().getInt(); });
   574      *printer << "]";
   575    }
   576  
   577    *printer << ", ";
   578    printer->printSuccessorAndUseList(branchOp.getOperation(),
   579                                      spirv::BranchConditionalOp::kTrueIndex);
   580    *printer << ", ";
   581    printer->printSuccessorAndUseList(branchOp.getOperation(),
   582                                      spirv::BranchConditionalOp::kFalseIndex);
   583  }
   584  
   585  static LogicalResult verify(spirv::BranchConditionalOp branchOp) {
   586    auto *op = branchOp.getOperation();
   587    if (op->getNumSuccessors() != 2)
   588      return branchOp.emitOpError("must have exactly two successors");
   589  
   590    if (auto weights = branchOp.branch_weights()) {
   591      if (weights->getValue().size() != 2) {
   592        return branchOp.emitOpError("must have exactly two branch weights");
   593      }
   594      if (llvm::all_of(*weights, [](Attribute attr) {
   595            return attr.cast<IntegerAttr>().getValue().isNullValue();
   596          }))
   597        return branchOp.emitOpError("branch weights cannot both be zero");
   598    }
   599  
   600    return success();
   601  }
   602  
   603  //===----------------------------------------------------------------------===//
   604  // spv.CompositeExtractOp
   605  //===----------------------------------------------------------------------===//
   606  
   607  static ParseResult parseCompositeExtractOp(OpAsmParser *parser,
   608                                             OperationState *state) {
   609    OpAsmParser::OperandType compositeInfo;
   610    Attribute indicesAttr;
   611    Type compositeType;
   612    llvm::SMLoc attrLocation;
   613    int32_t index;
   614  
   615    if (parser->parseOperand(compositeInfo) ||
   616        parser->getCurrentLocation(&attrLocation) ||
   617        parser->parseAttribute(indicesAttr, kIndicesAttrName,
   618                               state->attributes) ||
   619        parser->parseColonType(compositeType) ||
   620        parser->resolveOperand(compositeInfo, compositeType, state->operands)) {
   621      return failure();
   622    }
   623  
   624    auto indicesArrayAttr = indicesAttr.dyn_cast<ArrayAttr>();
   625    if (!indicesArrayAttr) {
   626      return parser->emitError(
   627          attrLocation,
   628          "expected an 32-bit integer array attribute for 'indices'");
   629    }
   630  
   631    if (!indicesArrayAttr.size()) {
   632      return parser->emitError(
   633          attrLocation, "expected at least one index for spv.CompositeExtract");
   634    }
   635  
   636    Type resultType = compositeType;
   637    for (auto indexAttr : indicesArrayAttr) {
   638      if (auto indexIntAttr = indexAttr.dyn_cast<IntegerAttr>()) {
   639        index = indexIntAttr.getInt();
   640      } else {
   641        return parser->emitError(
   642                   attrLocation,
   643                   "expexted an 32-bit integer for index, but found '")
   644               << indexAttr << "'";
   645      }
   646  
   647      if (auto cType = resultType.dyn_cast<spirv::CompositeType>()) {
   648        if (index < 0 || static_cast<uint64_t>(index) >= cType.getNumElements()) {
   649          return parser->emitError(attrLocation, "index ")
   650                 << index << " out of bounds for " << resultType;
   651        }
   652        resultType = cType.getElementType(index);
   653      } else {
   654        return parser->emitError(attrLocation,
   655                                 "cannot extract from non-composite type ")
   656               << resultType << " with index " << index;
   657      }
   658    }
   659  
   660    state->addTypes(resultType);
   661    return success();
   662  }
   663  
   664  static void print(spirv::CompositeExtractOp compositeExtractOp,
   665                    OpAsmPrinter *printer) {
   666    *printer << spirv::CompositeExtractOp::getOperationName() << ' '
   667             << *compositeExtractOp.composite() << compositeExtractOp.indices()
   668             << " : " << compositeExtractOp.composite()->getType();
   669  }
   670  
   671  static LogicalResult verify(spirv::CompositeExtractOp compExOp) {
   672    auto resultType = compExOp.composite()->getType();
   673    auto indicesArrayAttr = compExOp.indices().dyn_cast<ArrayAttr>();
   674  
   675    if (!indicesArrayAttr.size()) {
   676      return compExOp.emitOpError(
   677          "expexted at least one index for spv.CompositeExtractOp");
   678    }
   679  
   680    int32_t index;
   681    for (auto indexAttr : indicesArrayAttr) {
   682      index = indexAttr.dyn_cast<IntegerAttr>().getInt();
   683      if (auto cType = resultType.dyn_cast<spirv::CompositeType>()) {
   684        if (index < 0 || static_cast<uint64_t>(index) >= cType.getNumElements()) {
   685          return compExOp.emitOpError("index ")
   686                 << index << " out of bounds for " << resultType;
   687        }
   688        resultType = cType.getElementType(index);
   689      } else {
   690        return compExOp.emitError("cannot extract from non-composite type ")
   691               << resultType << " with index " << index;
   692      }
   693    }
   694  
   695    if (resultType != compExOp.getType()) {
   696      return compExOp.emitOpError("invalid result type: expected ")
   697             << resultType << " but provided " << compExOp.getType();
   698    }
   699  
   700    return success();
   701  }
   702  
   703  //===----------------------------------------------------------------------===//
   704  // spv.constant
   705  //===----------------------------------------------------------------------===//
   706  
   707  static ParseResult parseConstantOp(OpAsmParser *parser, OperationState *state) {
   708    Attribute value;
   709    if (parser->parseAttribute(value, kValueAttrName, state->attributes))
   710      return failure();
   711  
   712    Type type;
   713    if (value.getType().isa<NoneType>()) {
   714      if (parser->parseColonType(type))
   715        return failure();
   716    } else {
   717      type = value.getType();
   718    }
   719  
   720    return parser->addTypeToList(type, state->types);
   721  }
   722  
   723  static void print(spirv::ConstantOp constOp, OpAsmPrinter *printer) {
   724    *printer << spirv::ConstantOp::getOperationName() << ' ' << constOp.value();
   725    if (constOp.getType().isa<spirv::ArrayType>()) {
   726      *printer << " : " << constOp.getType();
   727    }
   728  }
   729  
   730  static LogicalResult verify(spirv::ConstantOp constOp) {
   731    auto opType = constOp.getType();
   732    auto value = constOp.value();
   733    auto valueType = value.getType();
   734  
   735    // ODS already generates checks to make sure the result type is valid. We just
   736    // need to additionally check that the value's attribute type is consistent
   737    // with the result type.
   738    switch (value.getKind()) {
   739    case StandardAttributes::Bool:
   740    case StandardAttributes::Integer:
   741    case StandardAttributes::Float:
   742    case StandardAttributes::DenseElements:
   743    case StandardAttributes::SparseElements: {
   744      if (valueType != opType)
   745        return constOp.emitOpError("result type (")
   746               << opType << ") does not match value type (" << valueType << ")";
   747      return success();
   748    } break;
   749    case StandardAttributes::Array: {
   750      auto arrayType = opType.dyn_cast<spirv::ArrayType>();
   751      if (!arrayType)
   752        return constOp.emitOpError(
   753            "must have spv.array result type for array value");
   754      auto elemType = arrayType.getElementType();
   755      for (auto element : value.cast<ArrayAttr>().getValue()) {
   756        if (element.getType() != elemType)
   757          return constOp.emitOpError(
   758              "has array element that are not of result array element type");
   759      }
   760    } break;
   761    default:
   762      return constOp.emitOpError("cannot have value of type ") << valueType;
   763    }
   764  
   765    return success();
   766  }
   767  
   768  //===----------------------------------------------------------------------===//
   769  // spv.EntryPoint
   770  //===----------------------------------------------------------------------===//
   771  
   772  static ParseResult parseEntryPointOp(OpAsmParser *parser,
   773                                       OperationState *state) {
   774    spirv::ExecutionModel execModel;
   775    SmallVector<OpAsmParser::OperandType, 0> identifiers;
   776    SmallVector<Type, 0> idTypes;
   777  
   778    SymbolRefAttr fn;
   779    if (parseEnumAttribute(execModel, parser, state) ||
   780        parser->parseAttribute(fn, Type(), kFnNameAttrName, state->attributes)) {
   781      return failure();
   782    }
   783  
   784    if (!parser->parseOptionalComma()) {
   785      // Parse the interface variables
   786      SmallVector<Attribute, 4> interfaceVars;
   787      do {
   788        // The name of the interface variable attribute isnt important
   789        auto attrName = "var_symbol";
   790        SymbolRefAttr var;
   791        SmallVector<NamedAttribute, 1> attrs;
   792        if (parser->parseAttribute(var, Type(), attrName, attrs)) {
   793          return failure();
   794        }
   795        interfaceVars.push_back(var);
   796      } while (!parser->parseOptionalComma());
   797      state->addAttribute(kInterfaceAttrName,
   798                          parser->getBuilder().getArrayAttr(interfaceVars));
   799    }
   800    return success();
   801  }
   802  
   803  static void print(spirv::EntryPointOp entryPointOp, OpAsmPrinter *printer) {
   804    *printer << spirv::EntryPointOp::getOperationName() << " \""
   805             << stringifyExecutionModel(entryPointOp.execution_model()) << "\" @"
   806             << entryPointOp.fn();
   807    if (auto interface = entryPointOp.interface()) {
   808      *printer << ", ";
   809      mlir::interleaveComma(interface.getValue().getValue(), printer->getStream(),
   810                            [&](Attribute a) { printer->printAttribute(a); });
   811    }
   812  }
   813  
// Op-local verifier for spv.EntryPoint. It performs no checks of its own and
// always succeeds.
static LogicalResult verify(spirv::EntryPointOp entryPointOp) {
  // Checks for fn and interface symbol reference are done in spirv::ModuleOp
  // verification.
  return success();
}
   819  
   820  //===----------------------------------------------------------------------===//
   821  // spv.ExecutionMode
   822  //===----------------------------------------------------------------------===//
   823  
   824  static ParseResult parseExecutionModeOp(OpAsmParser *parser,
   825                                          OperationState *state) {
   826    spirv::ExecutionMode execMode;
   827    Attribute fn;
   828    if (parser->parseAttribute(fn, kFnNameAttrName, state->attributes) ||
   829        parseEnumAttribute(execMode, parser, state)) {
   830      return failure();
   831    }
   832  
   833    SmallVector<int32_t, 4> values;
   834    Type i32Type = parser->getBuilder().getIntegerType(32);
   835    while (!parser->parseOptionalComma()) {
   836      SmallVector<NamedAttribute, 1> attr;
   837      Attribute value;
   838      if (parser->parseAttribute(value, i32Type, "value", attr)) {
   839        return failure();
   840      }
   841      values.push_back(value.cast<IntegerAttr>().getInt());
   842    }
   843    state->addAttribute(kValuesAttrName,
   844                        parser->getBuilder().getI32ArrayAttr(values));
   845    return success();
   846  }
   847  
   848  static void print(spirv::ExecutionModeOp execModeOp, OpAsmPrinter *printer) {
   849    *printer << spirv::ExecutionModeOp::getOperationName() << " @"
   850             << execModeOp.fn() << " \""
   851             << stringifyExecutionMode(execModeOp.execution_mode()) << "\"";
   852    auto values = execModeOp.values();
   853    if (!values) {
   854      return;
   855    }
   856    *printer << ", ";
   857    mlir::interleaveComma(
   858        values.getValue().cast<ArrayAttr>(), printer->getStream(),
   859        [&](Attribute a) { *printer << a.cast<IntegerAttr>().getInt(); });
   860  }
   861  
   862  //===----------------------------------------------------------------------===//
   863  // spv.globalVariable
   864  //===----------------------------------------------------------------------===//
   865  
   866  static ParseResult parseGlobalVariableOp(OpAsmParser *parser,
   867                                           OperationState *state) {
   868    // Parse variable name.
   869    StringAttr nameAttr;
   870    if (parser->parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
   871                                state->attributes)) {
   872      return failure();
   873    }
   874  
   875    // Parse optional initializer
   876    if (succeeded(parser->parseOptionalKeyword(kInitializerAttrName))) {
   877      SymbolRefAttr initSymbol;
   878      if (parser->parseLParen() ||
   879          parser->parseAttribute(initSymbol, Type(), kInitializerAttrName,
   880                                 state->attributes) ||
   881          parser->parseRParen())
   882        return failure();
   883    }
   884  
   885    if (parseVariableDecorations(parser, state)) {
   886      return failure();
   887    }
   888  
   889    Type type;
   890    auto loc = parser->getCurrentLocation();
   891    if (parser->parseColonType(type)) {
   892      return failure();
   893    }
   894    if (!type.isa<spirv::PointerType>()) {
   895      return parser->emitError(loc, "expected spv.ptr type");
   896    }
   897    state->addAttribute(kTypeAttrName, parser->getBuilder().getTypeAttr(type));
   898  
   899    return success();
   900  }
   901  
   902  static void print(spirv::GlobalVariableOp varOp, OpAsmPrinter *printer) {
   903    auto *op = varOp.getOperation();
   904    SmallVector<StringRef, 4> elidedAttrs{
   905        spirv::attributeName<spirv::StorageClass>()};
   906    *printer << spirv::GlobalVariableOp::getOperationName();
   907  
   908    // Print variable name.
   909    *printer << " @" << varOp.sym_name();
   910    elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
   911  
   912    // Print optional initializer
   913    if (auto initializer = varOp.initializer()) {
   914      *printer << " " << kInitializerAttrName << "(@" << initializer.getValue()
   915               << ")";
   916      elidedAttrs.push_back(kInitializerAttrName);
   917    }
   918  
   919    elidedAttrs.push_back(kTypeAttrName);
   920    printVariableDecorations(op, printer, elidedAttrs);
   921    *printer << " : " << varOp.type();
   922  }
   923  
   924  static LogicalResult verify(spirv::GlobalVariableOp varOp) {
   925    // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
   926    // object. It cannot be Generic. It must be the same as the Storage Class
   927    // operand of the Result Type."
   928    if (varOp.storageClass() == spirv::StorageClass::Generic)
   929      return varOp.emitOpError("storage class cannot be 'Generic'");
   930  
   931    if (auto init = varOp.getAttrOfType<SymbolRefAttr>(kInitializerAttrName)) {
   932      auto moduleOp = varOp.getParentOfType<spirv::ModuleOp>();
   933      auto *initOp = moduleOp.lookupSymbol(init.getValue());
   934      // TODO: Currently only variable initialization with specialization
   935      // constants and other variables is supported. They could be normal
   936      // constants in the module scope as well.
   937      if (!initOp || !(isa<spirv::GlobalVariableOp>(initOp) ||
   938                       isa<spirv::SpecConstantOp>(initOp))) {
   939        return varOp.emitOpError("initializer must be result of a "
   940                                 "spv.specConstant or spv.globalVariable op");
   941      }
   942    }
   943  
   944    return success();
   945  }
   946  
   947  //===----------------------------------------------------------------------===//
   948  // spv.LoadOp
   949  //===----------------------------------------------------------------------===//
   950  
   951  static ParseResult parseLoadOp(OpAsmParser *parser, OperationState *state) {
   952    // Parse the storage class specification
   953    spirv::StorageClass storageClass;
   954    OpAsmParser::OperandType ptrInfo;
   955    Type elementType;
   956    if (parseEnumAttribute(storageClass, parser) ||
   957        parser->parseOperand(ptrInfo) ||
   958        parseMemoryAccessAttributes(parser, state) ||
   959        parser->parseOptionalAttributeDict(state->attributes) ||
   960        parser->parseColon() || parser->parseType(elementType)) {
   961      return failure();
   962    }
   963  
   964    auto ptrType = spirv::PointerType::get(elementType, storageClass);
   965    if (parser->resolveOperand(ptrInfo, ptrType, state->operands)) {
   966      return failure();
   967    }
   968  
   969    state->addTypes(elementType);
   970    return success();
   971  }
   972  
   973  static void print(spirv::LoadOp loadOp, OpAsmPrinter *printer) {
   974    auto *op = loadOp.getOperation();
   975    SmallVector<StringRef, 4> elidedAttrs;
   976    StringRef sc = stringifyStorageClass(
   977        loadOp.ptr()->getType().cast<spirv::PointerType>().getStorageClass());
   978    *printer << spirv::LoadOp::getOperationName() << " \"" << sc << "\" ";
   979    // Print the pointer operand.
   980    printer->printOperand(loadOp.ptr());
   981  
   982    printMemoryAccessAttribute(loadOp, printer, elidedAttrs);
   983  
   984    printer->printOptionalAttrDict(op->getAttrs(), elidedAttrs);
   985    *printer << " : " << loadOp.getType();
   986  }
   987  
   988  static LogicalResult verify(spirv::LoadOp loadOp) {
   989    // SPIR-V spec : "Result Type is the type of the loaded object. It must be a
   990    // type with fixed size; i.e., it cannot be, nor include, any
   991    // OpTypeRuntimeArray types."
   992    if (failed(verifyLoadStorePtrAndValTypes(loadOp, loadOp.ptr(),
   993                                             loadOp.value()))) {
   994      return failure();
   995    }
   996    return verifyMemoryAccessAttribute(loadOp);
   997  }
   998  
   999  //===----------------------------------------------------------------------===//
  1000  // spv.module
  1001  //===----------------------------------------------------------------------===//
  1002  
  1003  void spirv::ModuleOp::build(Builder *builder, OperationState *state) {
  1004    ensureTerminator(*state->addRegion(), *builder, state->location);
  1005  }
  1006  
  1007  void spirv::ModuleOp::build(Builder *builder, OperationState *state,
  1008                              IntegerAttr addressing_model,
  1009                              IntegerAttr memory_model, ArrayAttr capabilities,
  1010                              ArrayAttr extensions,
  1011                              ArrayAttr extended_instruction_sets) {
  1012    state->addAttribute("addressing_model", addressing_model);
  1013    state->addAttribute("memory_model", memory_model);
  1014    if (capabilities)
  1015      state->addAttribute("capabilities", capabilities);
  1016    if (extensions)
  1017      state->addAttribute("extensions", extensions);
  1018    if (extended_instruction_sets)
  1019      state->addAttribute("extended_instruction_sets", extended_instruction_sets);
  1020    ensureTerminator(*state->addRegion(), *builder, state->location);
  1021  }
  1022  
  1023  static ParseResult parseModuleOp(OpAsmParser *parser, OperationState *state) {
  1024    Region *body = state->addRegion();
  1025  
  1026    // Parse attributes
  1027    spirv::AddressingModel addrModel;
  1028    spirv::MemoryModel memoryModel;
  1029    if (parseEnumAttribute(addrModel, parser, state) ||
  1030        parseEnumAttribute(memoryModel, parser, state)) {
  1031      return failure();
  1032    }
  1033  
  1034    if (parser->parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
  1035      return failure();
  1036  
  1037    if (succeeded(parser->parseOptionalKeyword("attributes"))) {
  1038      if (parser->parseOptionalAttributeDict(state->attributes))
  1039        return failure();
  1040    }
  1041  
  1042    spirv::ModuleOp::ensureTerminator(*body, parser->getBuilder(),
  1043                                      state->location);
  1044    return success();
  1045  }
  1046  
  1047  static void print(spirv::ModuleOp moduleOp, OpAsmPrinter *printer) {
  1048    auto *op = moduleOp.getOperation();
  1049  
  1050    // Only print out addressing model and memory model in a nicer way if both
  1051    // presents. Otherwise, print them in the general form. This helps debugging
  1052    // ill-formed ModuleOp.
  1053    SmallVector<StringRef, 2> elidedAttrs;
  1054    auto addressingModelAttrName = spirv::attributeName<spirv::AddressingModel>();
  1055    auto memoryModelAttrName = spirv::attributeName<spirv::MemoryModel>();
  1056    if (op->getAttr(addressingModelAttrName) &&
  1057        op->getAttr(memoryModelAttrName)) {
  1058      *printer << spirv::ModuleOp::getOperationName() << " \""
  1059               << spirv::stringifyAddressingModel(moduleOp.addressing_model())
  1060               << "\" \"" << spirv::stringifyMemoryModel(moduleOp.memory_model())
  1061               << '"';
  1062      elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName});
  1063    }
  1064  
  1065    printer->printRegion(op->getRegion(0), /*printEntryBlockArgs=*/false,
  1066                         /*printBlockTerminators=*/false);
  1067  
  1068    bool printAttrDict =
  1069        elidedAttrs.size() != 2 ||
  1070        llvm::any_of(op->getAttrs(), [&addressingModelAttrName,
  1071                                      &memoryModelAttrName](NamedAttribute attr) {
  1072          return attr.first != addressingModelAttrName &&
  1073                 attr.first != memoryModelAttrName;
  1074        });
  1075  
  1076    if (printAttrDict) {
  1077      *printer << " attributes";
  1078      printer->printOptionalAttrDict(op->getAttrs(), elidedAttrs);
  1079    }
  1080  }
  1081  
  1082  static LogicalResult verify(spirv::ModuleOp moduleOp) {
  1083    auto &op = *moduleOp.getOperation();
  1084    auto *dialect = op.getDialect();
  1085    auto &body = op.getRegion(0).front();
  1086    llvm::DenseMap<std::pair<FuncOp, spirv::ExecutionModel>, spirv::EntryPointOp>
  1087        entryPoints;
  1088    SymbolTable table(moduleOp);
  1089  
  1090    for (auto &op : body) {
  1091      if (op.getDialect() == dialect) {
  1092        // For EntryPoint op, check that the function and execution model is not
  1093        // duplicated in EntryPointOps. Also verify that the interface specified
  1094        // comes from globalVariables here to make this check cheaper.
  1095        if (auto entryPointOp = dyn_cast<spirv::EntryPointOp>(op)) {
  1096          auto funcOp = table.lookup<FuncOp>(entryPointOp.fn());
  1097          if (!funcOp) {
  1098            return entryPointOp.emitError("function '")
  1099                   << entryPointOp.fn() << "' not found in 'spv.module'";
  1100          }
  1101          if (auto interface = entryPointOp.interface()) {
  1102            for (auto varRef : interface.getValue().getValue()) {
  1103              auto varSymRef = varRef.dyn_cast<SymbolRefAttr>();
  1104              if (!varSymRef) {
  1105                return entryPointOp.emitError(
  1106                           "expected symbol reference for interface "
  1107                           "specification instead of '")
  1108                       << varRef;
  1109              }
  1110              auto variableOp =
  1111                  table.lookup<spirv::GlobalVariableOp>(varSymRef.getValue());
  1112              if (!variableOp) {
  1113                return entryPointOp.emitError("expected spv.globalVariable "
  1114                                              "symbol reference instead of'")
  1115                       << varSymRef << "'";
  1116              }
  1117            }
  1118          }
  1119  
  1120          auto key = std::pair<FuncOp, spirv::ExecutionModel>(
  1121              funcOp, entryPointOp.execution_model());
  1122          auto entryPtIt = entryPoints.find(key);
  1123          if (entryPtIt != entryPoints.end()) {
  1124            return entryPointOp.emitError("duplicate of a previous EntryPointOp");
  1125          }
  1126          entryPoints[key] = entryPointOp;
  1127        }
  1128        continue;
  1129      }
  1130  
  1131      auto funcOp = dyn_cast<FuncOp>(op);
  1132      if (!funcOp)
  1133        return op.emitError("'spv.module' can only contain func and spv.* ops");
  1134  
  1135      if (funcOp.isExternal())
  1136        return op.emitError("'spv.module' cannot contain external functions");
  1137  
  1138      for (auto &block : funcOp)
  1139        for (auto &op : block) {
  1140          if (op.getDialect() == dialect)
  1141            continue;
  1142  
  1143          if (isa<FuncOp>(op))
  1144            return op.emitError("'spv.module' cannot contain nested functions");
  1145  
  1146          return op.emitError(
  1147              "functions in 'spv.module' can only contain spv.* ops");
  1148        }
  1149    }
  1150  
  1151    // Verify capabilities. ODS already guarantees that we have an array of
  1152    // string attributes.
  1153    if (auto caps = moduleOp.getAttrOfType<ArrayAttr>("capabilities")) {
  1154      for (auto cap : caps.getValue()) {
  1155        auto capStr = cap.cast<StringAttr>().getValue();
  1156        if (!spirv::symbolizeCapability(capStr))
  1157          return moduleOp.emitOpError("uses unknown capability: ") << capStr;
  1158      }
  1159    }
  1160  
  1161    // Verify extensions. ODS already guarantees that we have an array of
  1162    // string attributes.
  1163    if (auto exts = moduleOp.getAttrOfType<ArrayAttr>("extensions")) {
  1164      for (auto ext : exts.getValue()) {
  1165        auto extStr = ext.cast<StringAttr>().getValue();
  1166        if (!spirv::symbolizeExtension(extStr))
  1167          return moduleOp.emitOpError("uses unknown extension: ") << extStr;
  1168      }
  1169    }
  1170  
  1171    return success();
  1172  }
  1173  
  1174  //===----------------------------------------------------------------------===//
  1175  // spv._reference_of
  1176  //===----------------------------------------------------------------------===//
  1177  
  1178  static ParseResult parseReferenceOfOp(OpAsmParser *parser,
  1179                                        OperationState *state) {
  1180    SymbolRefAttr constRefAttr;
  1181    Type type;
  1182    if (parser->parseAttribute(constRefAttr, Type(), kSpecConstAttrName,
  1183                               state->attributes) ||
  1184        parser->parseColonType(type)) {
  1185      return failure();
  1186    }
  1187    return parser->addTypeToList(type, state->types);
  1188  }
  1189  
  1190  static void print(spirv::ReferenceOfOp referenceOfOp, OpAsmPrinter *printer) {
  1191    *printer << spirv::ReferenceOfOp::getOperationName() << " @"
  1192             << referenceOfOp.spec_const() << " : "
  1193             << referenceOfOp.reference()->getType();
  1194  }
  1195  
  1196  static LogicalResult verify(spirv::ReferenceOfOp referenceOfOp) {
  1197    auto moduleOp = referenceOfOp.getParentOfType<spirv::ModuleOp>();
  1198    auto specConstOp =
  1199        moduleOp.lookupSymbol<spirv::SpecConstantOp>(referenceOfOp.spec_const());
  1200    if (!specConstOp) {
  1201      return referenceOfOp.emitOpError("expected spv.specConstant symbol");
  1202    }
  1203    if (referenceOfOp.reference()->getType() !=
  1204        specConstOp.default_value().getType()) {
  1205      return referenceOfOp.emitOpError("result type mismatch with the referenced "
  1206                                       "specialization constant's type");
  1207    }
  1208    return success();
  1209  }
  1210  
  1211  //===----------------------------------------------------------------------===//
  1212  // spv.Return
  1213  //===----------------------------------------------------------------------===//
  1214  
  1215  static LogicalResult verify(spirv::ReturnOp returnOp) {
  1216    auto funcOp = cast<FuncOp>(returnOp.getParentOp());
  1217    auto numOutputs = funcOp.getType().getNumResults();
  1218    if (numOutputs != 0)
  1219      return returnOp.emitOpError("cannot be used in functions returning value")
  1220             << (numOutputs > 1 ? "s" : "");
  1221  
  1222    return success();
  1223  }
  1224  
  1225  //===----------------------------------------------------------------------===//
  1226  // spv.ReturnValue
  1227  //===----------------------------------------------------------------------===//
  1228  
  1229  static ParseResult parseReturnValueOp(OpAsmParser *parser,
  1230                                        OperationState *state) {
  1231    OpAsmParser::OperandType retValInfo;
  1232    Type retValType;
  1233    return failure(
  1234        parser->parseOperand(retValInfo) || parser->parseColonType(retValType) ||
  1235        parser->resolveOperand(retValInfo, retValType, state->operands));
  1236  }
  1237  
  1238  static void print(spirv::ReturnValueOp retValOp, OpAsmPrinter *printer) {
  1239    *printer << spirv::ReturnValueOp::getOperationName() << ' ';
  1240    printer->printOperand(retValOp.value());
  1241    *printer << " : " << retValOp.value()->getType();
  1242  }
  1243  
  1244  static LogicalResult verify(spirv::ReturnValueOp retValOp) {
  1245    auto funcOp = cast<FuncOp>(retValOp.getParentOp());
  1246    auto numFnResults = funcOp.getType().getNumResults();
  1247    if (numFnResults != 1)
  1248      return retValOp.emitOpError(
  1249                 "returns 1 value but enclosing function requires ")
  1250             << numFnResults << " results";
  1251  
  1252    auto operandType = retValOp.value()->getType();
  1253    auto fnResultType = funcOp.getType().getResult(0);
  1254    if (operandType != fnResultType)
  1255      return retValOp.emitOpError(" return value's type (")
  1256             << operandType << ") mismatch with function's result type ("
  1257             << fnResultType << ")";
  1258  
  1259    return success();
  1260  }
  1261  
  1262  //===----------------------------------------------------------------------===//
  1263  // spv.specConstant
  1264  //===----------------------------------------------------------------------===//
  1265  
  1266  static ParseResult parseSpecConstantOp(OpAsmParser *parser,
  1267                                         OperationState *state) {
  1268    StringAttr nameAttr;
  1269    Attribute valueAttr;
  1270  
  1271    if (parser->parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
  1272                                state->attributes) ||
  1273        parser->parseEqual() ||
  1274        parser->parseAttribute(valueAttr, kDefaultValueAttrName,
  1275                               state->attributes))
  1276      return failure();
  1277  
  1278    return success();
  1279  }
  1280  
  1281  static void print(spirv::SpecConstantOp constOp, OpAsmPrinter *printer) {
  1282    *printer << spirv::SpecConstantOp::getOperationName() << " @"
  1283             << constOp.sym_name() << " = ";
  1284    printer->printAttribute(constOp.default_value());
  1285  }
  1286  
  1287  static LogicalResult verify(spirv::SpecConstantOp constOp) {
  1288    auto value = constOp.default_value();
  1289  
  1290    switch (value.getKind()) {
  1291    case StandardAttributes::Bool:
  1292    case StandardAttributes::Integer:
  1293    case StandardAttributes::Float: {
  1294      // Make sure bitwidth is allowed.
  1295      auto *dialect = static_cast<spirv::SPIRVDialect *>(constOp.getDialect());
  1296      if (!dialect->isValidSPIRVType(value.getType()))
  1297        return constOp.emitOpError("default value bitwidth disallowed");
  1298      return success();
  1299    }
  1300    default:
  1301      return constOp.emitOpError(
  1302          "default value can only be a bool, integer, or float scalar");
  1303    }
  1304  }
  1305  
  1306  //===----------------------------------------------------------------------===//
  1307  // spv.StoreOp
  1308  //===----------------------------------------------------------------------===//
  1309  
  1310  static ParseResult parseStoreOp(OpAsmParser *parser, OperationState *state) {
  1311    // Parse the storage class specification
  1312    spirv::StorageClass storageClass;
  1313    SmallVector<OpAsmParser::OperandType, 2> operandInfo;
  1314    auto loc = parser->getCurrentLocation();
  1315    Type elementType;
  1316    if (parseEnumAttribute(storageClass, parser) ||
  1317        parser->parseOperandList(operandInfo, 2) ||
  1318        parseMemoryAccessAttributes(parser, state) || parser->parseColon() ||
  1319        parser->parseType(elementType)) {
  1320      return failure();
  1321    }
  1322  
  1323    auto ptrType = spirv::PointerType::get(elementType, storageClass);
  1324    if (parser->resolveOperands(operandInfo, {ptrType, elementType}, loc,
  1325                                state->operands)) {
  1326      return failure();
  1327    }
  1328    return success();
  1329  }
  1330  
  1331  static void print(spirv::StoreOp storeOp, OpAsmPrinter *printer) {
  1332    auto *op = storeOp.getOperation();
  1333    SmallVector<StringRef, 4> elidedAttrs;
  1334    StringRef sc = stringifyStorageClass(
  1335        storeOp.ptr()->getType().cast<spirv::PointerType>().getStorageClass());
  1336    *printer << spirv::StoreOp::getOperationName() << " \"" << sc << "\" ";
  1337    // Print the pointer operand
  1338    printer->printOperand(storeOp.ptr());
  1339    *printer << ", ";
  1340    // Print the value operand
  1341    printer->printOperand(storeOp.value());
  1342  
  1343    printMemoryAccessAttribute(storeOp, printer, elidedAttrs);
  1344  
  1345    *printer << " : " << storeOp.value()->getType();
  1346  
  1347    printer->printOptionalAttrDict(op->getAttrs(), elidedAttrs);
  1348  }
  1349  
  1350  static LogicalResult verify(spirv::StoreOp storeOp) {
  1351    // SPIR-V spec : "Pointer is the pointer to store through. Its type must be an
  1352    // OpTypePointer whose Type operand is the same as the type of Object."
  1353    if (failed(verifyLoadStorePtrAndValTypes(storeOp, storeOp.ptr(),
  1354                                             storeOp.value()))) {
  1355      return failure();
  1356    }
  1357    return verifyMemoryAccessAttribute(storeOp);
  1358  }
  1359  
  1360  //===----------------------------------------------------------------------===//
  1361  // spv.Variable
  1362  //===----------------------------------------------------------------------===//
  1363  
  1364  static ParseResult parseVariableOp(OpAsmParser *parser, OperationState *state) {
  1365    // Parse optional initializer
  1366    Optional<OpAsmParser::OperandType> initInfo;
  1367    if (succeeded(parser->parseOptionalKeyword("init"))) {
  1368      initInfo = OpAsmParser::OperandType();
  1369      if (parser->parseLParen() || parser->parseOperand(*initInfo) ||
  1370          parser->parseRParen())
  1371        return failure();
  1372    }
  1373  
  1374    if (parseVariableDecorations(parser, state)) {
  1375      return failure();
  1376    }
  1377  
  1378    // Parse result pointer type
  1379    Type type;
  1380    if (parser->parseColon())
  1381      return failure();
  1382    auto loc = parser->getCurrentLocation();
  1383    if (parser->parseType(type))
  1384      return failure();
  1385  
  1386    auto ptrType = type.dyn_cast<spirv::PointerType>();
  1387    if (!ptrType)
  1388      return parser->emitError(loc, "expected spv.ptr type");
  1389    state->addTypes(ptrType);
  1390  
  1391    // Resolve the initializer operand
  1392    SmallVector<Value *, 1> init;
  1393    if (initInfo) {
  1394      if (parser->resolveOperand(*initInfo, ptrType.getPointeeType(), init))
  1395        return failure();
  1396      state->addOperands(init);
  1397    }
  1398  
  1399    auto attr = parser->getBuilder().getI32IntegerAttr(
  1400        bitwiseCast<int32_t>(ptrType.getStorageClass()));
  1401    state->addAttribute(spirv::attributeName<spirv::StorageClass>(), attr);
  1402  
  1403    return success();
  1404  }
  1405  
  1406  static void print(spirv::VariableOp varOp, OpAsmPrinter *printer) {
  1407    auto *op = varOp.getOperation();
  1408    SmallVector<StringRef, 4> elidedAttrs{
  1409        spirv::attributeName<spirv::StorageClass>()};
  1410    *printer << spirv::VariableOp::getOperationName();
  1411  
  1412    // Print optional initializer
  1413    if (op->getNumOperands() > 0) {
  1414      *printer << " init(";
  1415      printer->printOperands(varOp.initializer());
  1416      *printer << ")";
  1417    }
  1418  
  1419    printVariableDecorations(op, printer, elidedAttrs);
  1420  
  1421    *printer << " : " << varOp.getType();
  1422  }
  1423  
  1424  static LogicalResult verify(spirv::VariableOp varOp) {
  1425    // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
  1426    // object. It cannot be Generic. It must be the same as the Storage Class
  1427    // operand of the Result Type."
  1428    if (varOp.storage_class() != spirv::StorageClass::Function) {
  1429      return varOp.emitOpError(
  1430          "can only be used to model function-level variables. Use "
  1431          "spv.globalVariable for module-level variables.");
  1432    }
  1433  
  1434    auto pointerType = varOp.pointer()->getType().cast<spirv::PointerType>();
  1435    if (varOp.storage_class() != pointerType.getStorageClass())
  1436      return varOp.emitOpError(
  1437          "storage class must match result pointer's storage class");
  1438  
  1439    if (varOp.getNumOperands() != 0) {
  1440      // SPIR-V spec: "Initializer must be an <id> from a constant instruction or
  1441      // a global (module scope) OpVariable instruction".
  1442      auto *initOp = varOp.getOperand(0)->getDefiningOp();
  1443      if (!initOp || !(isa<spirv::ConstantOp>(initOp) ||    // for normal constant
  1444                       isa<spirv::ReferenceOfOp>(initOp) || // for spec constant
  1445                       isa<spirv::AddressOfOp>(initOp)))
  1446        return varOp.emitOpError("initializer must be the result of a "
  1447                                 "constant or spv.globalVariable op");
  1448    }
  1449  
  1450    // TODO(antiagainst): generate these strings using ODS.
  1451    auto *op = varOp.getOperation();
  1452    auto descriptorSetName =
  1453        convertToSnakeCase(stringifyDecoration(spirv::Decoration::DescriptorSet));
  1454    auto bindingName =
  1455        convertToSnakeCase(stringifyDecoration(spirv::Decoration::Binding));
  1456    auto builtInName =
  1457        convertToSnakeCase(stringifyDecoration(spirv::Decoration::BuiltIn));
  1458  
  1459    for (const auto &attr : {descriptorSetName, bindingName, builtInName}) {
  1460      if (op->getAttr(attr))
  1461        return varOp.emitOpError("cannot have '")
  1462               << attr << "' attribute (only allowed in spv.globalVariable)";
  1463    }
  1464  
  1465    return success();
  1466  }
  1467  
// Pull the op class definitions from the included .inc file into the
// mlir::spirv namespace. GET_OP_CLASSES is defined immediately before the
// include, presumably to select which part of the file is emitted.
namespace mlir {
namespace spirv {

#define GET_OP_CLASSES
#include "mlir/Dialect/SPIRV/SPIRVOps.cpp.inc"

} // namespace spirv
} // namespace mlir