github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp (about)

     1  //===- Deserializer.cpp - MLIR SPIR-V Deserialization ---------------------===//
     2  //
     3  // Copyright 2019 The MLIR Authors.
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //   http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  // =============================================================================
    17  //
    18  // This file defines the SPIR-V binary to MLIR SPIR-V module deserialization.
    19  //
    20  //===----------------------------------------------------------------------===//
    21  
    22  #include "mlir/Dialect/SPIRV/Serialization.h"
    23  
    24  #include "mlir/Dialect/SPIRV/SPIRVBinaryUtils.h"
    25  #include "mlir/Dialect/SPIRV/SPIRVOps.h"
    26  #include "mlir/Dialect/SPIRV/SPIRVTypes.h"
    27  #include "mlir/IR/Builders.h"
    28  #include "mlir/IR/Location.h"
    29  #include "mlir/Support/LogicalResult.h"
    30  #include "mlir/Support/StringExtras.h"
    31  #include "llvm/ADT/Sequence.h"
    32  #include "llvm/ADT/SetVector.h"
    33  #include "llvm/ADT/SmallVector.h"
    34  #include "llvm/ADT/bit.h"
    35  
    36  using namespace mlir;
    37  
    38  // Decodes a string literal in `words` starting at `wordIndex`. Update the
    39  // latter to point to the position in words after the string literal.
    40  static inline StringRef decodeStringLiteral(ArrayRef<uint32_t> words,
    41                                              unsigned &wordIndex) {
    42    StringRef str(reinterpret_cast<const char *>(words.data() + wordIndex));
    43    wordIndex += str.size() / 4 + 1;
    44    return str;
    45  }
    46  
    47  // Extracts the opcode from the given first word of a SPIR-V instruction.
    48  static inline spirv::Opcode extractOpcode(uint32_t word) {
    49    return static_cast<spirv::Opcode>(word & 0xffff);
    50  }
    51  
    52  namespace {
    53  /// A SPIR-V module serializer.
    54  ///
    55  /// A SPIR-V binary module is a single linear stream of instructions; each
    56  /// instruction is composed of 32-bit words. The first word of an instruction
    57  /// records the total number of words of that instruction using the 16
    58  /// higher-order bits. So this deserializer uses that to get instruction
    59  /// boundary and parse instructions and build a SPIR-V ModuleOp gradually.
    60  ///
    61  // TODO(antiagainst): clean up created ops on errors
    62  class Deserializer {
    63  public:
    64    /// Creates a deserializer for the given SPIR-V `binary` module.
    65    /// The SPIR-V ModuleOp will be created into `context.
    66    explicit Deserializer(ArrayRef<uint32_t> binary, MLIRContext *context);
    67  
    68    /// Deserializes the remembered SPIR-V binary module.
    69    LogicalResult deserialize();
    70  
    71    /// Collects the final SPIR-V ModuleOp.
    72    Optional<spirv::ModuleOp> collect();
    73  
    74  private:
    75    //===--------------------------------------------------------------------===//
    76    // Module structure
    77    //===--------------------------------------------------------------------===//
    78  
    79    /// Initializes the `module` ModuleOp in this deserializer instance.
    80    spirv::ModuleOp createModuleOp();
    81  
    82    /// Processes SPIR-V module header in `binary`.
    83    LogicalResult processHeader();
    84  
    85    /// Processes the SPIR-V OpCapability with `operands` and updates bookkeeping
    86    /// in the deserializer.
    87    LogicalResult processCapability(ArrayRef<uint32_t> operands);
    88  
    89    /// Attaches all collected capabilites to `module` as an attribute.
    90    void attachCapabilities();
    91  
    92    /// Processes the SPIR-V OpExtension with `operands` and updates bookkeeping
    93    /// in the deserializer.
    94    LogicalResult processExtension(ArrayRef<uint32_t> operands);
    95  
    96    /// Attaches all collected extensions to `module` as an attribute.
    97    void attachExtensions();
    98  
    99    /// Processes the SPIR-V OpMemoryModel with `operands` and updates `module`.
   100    LogicalResult processMemoryModel(ArrayRef<uint32_t> operands);
   101  
   102    /// Process SPIR-V OpName with `operands`.
   103    LogicalResult processName(ArrayRef<uint32_t> operands);
   104  
   105    /// Method to process an OpDecorate instruction.
   106    LogicalResult processDecoration(ArrayRef<uint32_t> words);
   107  
   108    // Method to process an OpMemberDecorate instruction.
   109    LogicalResult processMemberDecoration(ArrayRef<uint32_t> words);
   110  
   111    /// Gets the FuncOp associated with a result <id> of OpFunction.
   112    FuncOp getFunction(uint32_t id) { return funcMap.lookup(id); }
   113  
   114    /// Processes the SPIR-V function at the current `offset` into `binary`.
   115    /// The operands to the OpFunction instruction is passed in as ``operands`.
   116    /// This method processes each instruction inside the function and dispatches
   117    /// them to their handler method accordingly.
   118    LogicalResult processFunction(ArrayRef<uint32_t> operands);
   119  
   120    /// Gets the constant's attribute and type associated with the given <id>.
   121    Optional<std::pair<Attribute, Type>> getConstant(uint32_t id);
   122  
   123    /// Returns a symbol to be used for the specialization constant with the given
   124    /// result <id>. This tries to use the specialization constant's OpName if
   125    /// exists; otherwise creates one based on the <id>.
   126    std::string getSpecConstantSymbol(uint32_t id);
   127  
   128    /// Gets the specialization constant with the given result <id>.
   129    spirv::SpecConstantOp getSpecConstant(uint32_t id) {
   130      return specConstMap.lookup(id);
   131    }
   132  
   133    /// Processes the OpVariable instructions at current `offset` into `binary`.
   134    /// It is expected that this method is used for variables that are to be
   135    /// defined at module scope and will be deserialized into a spv.globalVariable
   136    /// instruction.
   137    LogicalResult processGlobalVariable(ArrayRef<uint32_t> operands);
   138  
   139    /// Gets the global variable associated with a result <id> of OpVariable.
   140    spirv::GlobalVariableOp getGlobalVariable(uint32_t id) {
   141      return globalVariableMap.lookup(id);
   142    }
   143  
   144    //===--------------------------------------------------------------------===//
   145    // Type
   146    //===--------------------------------------------------------------------===//
   147  
   148    /// Gets type for a given result <id>.
   149    Type getType(uint32_t id) { return typeMap.lookup(id); }
   150  
   151    /// Returns true if the given `type` is for SPIR-V void type.
   152    bool isVoidType(Type type) const { return type.isa<NoneType>(); }
   153  
   154    /// Processes a SPIR-V type instruction with given `opcode` and `operands` and
   155    /// registers the type into `module`.
   156    LogicalResult processType(spirv::Opcode opcode, ArrayRef<uint32_t> operands);
   157  
   158    LogicalResult processArrayType(ArrayRef<uint32_t> operands);
   159  
   160    LogicalResult processFunctionType(ArrayRef<uint32_t> operands);
   161  
   162    LogicalResult processStructType(ArrayRef<uint32_t> operands);
   163  
   164    //===--------------------------------------------------------------------===//
   165    // Constant
   166    //===--------------------------------------------------------------------===//
   167  
   168    /// Processes a SPIR-V Op{|Spec}Constant instruction with the given
   169    /// `operands`. `isSpec` indicates whether this is a specialization constant.
   170    LogicalResult processConstant(ArrayRef<uint32_t> operands, bool isSpec);
   171  
   172    /// Processes a SPIR-V Op{|Spec}Constant{True|False} instruction with the
   173    /// given `operands`. `isSpec` indicates whether this is a specialization
   174    /// constant.
   175    LogicalResult processConstantBool(bool isTrue, ArrayRef<uint32_t> operands,
   176                                      bool isSpec);
   177  
   178    /// Processes a SPIR-V OpConstantComposite instruction with the given
   179    /// `operands`.
   180    LogicalResult processConstantComposite(ArrayRef<uint32_t> operands);
   181  
   182    /// Processes a SPIR-V OpConstantNull instruction with the given `operands`.
   183    LogicalResult processConstantNull(ArrayRef<uint32_t> operands);
   184  
   185    //===--------------------------------------------------------------------===//
   186    // Control flow
   187    //===--------------------------------------------------------------------===//
   188  
   189    /// Processes a SPIR-V OpLabel instruction with the given `operands`.
   190    LogicalResult processLabel(ArrayRef<uint32_t> operands);
   191  
   192    //===--------------------------------------------------------------------===//
   193    // Instruction
   194    //===--------------------------------------------------------------------===//
   195  
   196    /// Get the Value associated with a result <id>.
   197    ///
   198    /// This method materializes normal constants and inserts "casting" ops
   199    /// (`spv._address_of` and `spv._reference_of`) to turn an symbol into a SSA
   200    /// value for handling uses of module scope constants/variables in functions.
   201    Value *getValue(uint32_t id);
   202  
   203    /// Slices the first instruction out of `binary` and returns its opcode and
   204    /// operands via `opcode` and `operands` respectively. Returns failure if
   205    /// there is no more remaining instructions (`expectedOpcode` will be used to
   206    /// compose the error message) or the next instruction is malformed.
   207    LogicalResult
   208    sliceInstruction(spirv::Opcode &opcode, ArrayRef<uint32_t> &operands,
   209                     Optional<spirv::Opcode> expectedOpcode = llvm::None);
   210  
   211    /// Returns the next instruction's opcode if exists.
   212    Optional<spirv::Opcode> peekOpcode();
   213  
   214    /// Processes a SPIR-V instruction with the given `opcode` and `operands`.
   215    /// This method is the main entrance for handling SPIR-V instruction; it
   216    /// checks the instruction opcode and dispatches to the corresponding handler.
   217    /// Processing of Some instructions (like OpEntryPoint and OpExecutionMode)
   218    /// might need to be defered, since they contain forward references to <id>s
   219    /// in the deserialized binary, but module in SPIR-V dialect expects these to
   220    /// be ssa-uses.
   221    LogicalResult processInstruction(spirv::Opcode opcode,
   222                                     ArrayRef<uint32_t> operands,
   223                                     bool deferInstructions = true);
   224  
   225    /// Method to dispatch to the specialized deserialization function for an
   226    /// operation in SPIR-V dialect that is a mirror of an instruction in the
   227    /// SPIR-V spec. This is auto-generated from ODS. Dispatch is handled for
   228    /// all operations in SPIR-V dialect that have hasOpcode == 1.
   229    LogicalResult dispatchToAutogenDeserialization(spirv::Opcode opcode,
   230                                                   ArrayRef<uint32_t> words);
   231  
   232    /// Method to deserialize an operation in the SPIR-V dialect that is a mirror
   233    /// of an instruction in the SPIR-V spec. This is auto generated if hasOpcode
   234    /// == 1 and autogenSerialization == 1 in ODS.
   235    template <typename OpTy> LogicalResult processOp(ArrayRef<uint32_t> words) {
   236      return emitError(unknownLoc, "unsupported deserialization for ")
   237             << OpTy::getOperationName() << " op";
   238    }
   239  
   240  private:
   241    /// The SPIR-V binary module.
   242    ArrayRef<uint32_t> binary;
   243  
   244    /// The current word offset into the binary module.
   245    unsigned curOffset = 0;
   246  
   247    /// MLIRContext to create SPIR-V ModuleOp into.
   248    MLIRContext *context;
   249  
   250    // TODO(antiagainst): create Location subclass for binary blob
   251    Location unknownLoc;
   252  
   253    /// The SPIR-V ModuleOp.
   254    Optional<spirv::ModuleOp> module;
   255  
   256    OpBuilder opBuilder;
   257  
   258    /// The list of capabilities used by the module.
   259    llvm::SmallSetVector<spirv::Capability, 4> capabilities;
   260  
   261    /// The list of extensions used by the module.
   262    llvm::SmallSetVector<StringRef, 2> extensions;
   263  
   264    // Result <id> to type mapping.
   265    DenseMap<uint32_t, Type> typeMap;
   266  
   267    // Result <id> to constant attribute and type mapping.
   268    ///
   269    /// In the SPIR-V binary format, all constants are placed in the module and
   270    /// shared by instructions at module level and in subsequent functions. But in
   271    /// the SPIR-V dialect, we materialize the constant to where it's used in the
   272    /// function. So when seeing a constant instruction in the binary format, we
   273    /// don't immediately emit a constant op into the module, we keep its value
   274    /// (and type) here. Later when it's used, we materialize the constant.
   275    DenseMap<uint32_t, std::pair<Attribute, Type>> constantMap;
   276  
   277    // Result <id> to variable mapping.
   278    DenseMap<uint32_t, spirv::SpecConstantOp> specConstMap;
   279  
   280    // Result <id> to variable mapping.
   281    DenseMap<uint32_t, spirv::GlobalVariableOp> globalVariableMap;
   282  
   283    // Result <id> to function mapping.
   284    DenseMap<uint32_t, FuncOp> funcMap;
   285  
   286    // Result <id> to value mapping.
   287    DenseMap<uint32_t, Value *> valueMap;
   288  
   289    // Result <id> to name mapping.
   290    DenseMap<uint32_t, StringRef> nameMap;
   291  
   292    // Result <id> to decorations mapping.
   293    DenseMap<uint32_t, NamedAttributeList> decorations;
   294  
   295    // Result <id> to type decorations.
   296    DenseMap<uint32_t, uint32_t> typeDecorations;
   297  
   298    // Result <id> to member decorations.
   299    DenseMap<uint32_t, DenseMap<uint32_t, uint32_t>> memberDecorationMap;
   300  
   301    // List of instructions that are processed in a defered fashion (after an
   302    // initial processing of the entire binary). Some operations like
   303    // OpEntryPoint, and OpExecutionMode use forward references to function
   304    // <id>s. In SPIR-V dialect the corresponding operations (spv.EntryPoint and
   305    // spv.ExecutionMode) need these references resolved. So these instructions
   306    // are deserialized and stored for processing once the entire binary is
   307    // processed.
   308    SmallVector<std::pair<spirv::Opcode, ArrayRef<uint32_t>>, 4>
   309        deferedInstructions;
   310  };
   311  } // namespace
   312  
   313  Deserializer::Deserializer(ArrayRef<uint32_t> binary, MLIRContext *context)
   314      : binary(binary), context(context), unknownLoc(UnknownLoc::get(context)),
   315        module(createModuleOp()),
   316        opBuilder(module->getOperation()->getRegion(0)) {}
   317  
   318  LogicalResult Deserializer::deserialize() {
   319    if (failed(processHeader()))
   320      return failure();
   321  
   322    spirv::Opcode opcode = spirv::Opcode::OpNop;
   323    ArrayRef<uint32_t> operands;
   324    auto binarySize = binary.size();
   325    while (curOffset < binarySize) {
   326      // Slice the next instruction out and populate `opcode` and `operands`.
   327      // Interally this also updates `curOffset`.
   328      if (failed(sliceInstruction(opcode, operands)))
   329        return failure();
   330  
   331      if (failed(processInstruction(opcode, operands)))
   332        return failure();
   333    }
   334  
   335    assert(curOffset == binarySize &&
   336           "deserializer should never index beyond the binary end");
   337  
   338    for (auto &defered : deferedInstructions) {
   339      if (failed(processInstruction(defered.first, defered.second, false))) {
   340        return failure();
   341      }
   342    }
   343  
   344    // Attaches the capabilities/extensions as an attribute to the module.
   345    attachCapabilities();
   346    attachExtensions();
   347  
   348    return success();
   349  }
   350  
   351  Optional<spirv::ModuleOp> Deserializer::collect() { return module; }
   352  
   353  //===----------------------------------------------------------------------===//
   354  // Module structure
   355  //===----------------------------------------------------------------------===//
   356  
   357  spirv::ModuleOp Deserializer::createModuleOp() {
   358    Builder builder(context);
   359    OperationState state(unknownLoc, spirv::ModuleOp::getOperationName());
   360    // TODO(antiagainst): use target environment to select the version
   361    state.addAttribute("major_version", builder.getI32IntegerAttr(1));
   362    state.addAttribute("minor_version", builder.getI32IntegerAttr(0));
   363    spirv::ModuleOp::build(&builder, &state);
   364    return cast<spirv::ModuleOp>(Operation::create(state));
   365  }
   366  
   367  LogicalResult Deserializer::processHeader() {
   368    if (binary.size() < spirv::kHeaderWordCount)
   369      return emitError(unknownLoc,
   370                       "SPIR-V binary module must have a 5-word header");
   371  
   372    if (binary[0] != spirv::kMagicNumber)
   373      return emitError(unknownLoc, "incorrect magic number");
   374  
   375    // TODO(antiagainst): generator number, bound, schema
   376    curOffset = spirv::kHeaderWordCount;
   377    return success();
   378  }
   379  
   380  LogicalResult Deserializer::processCapability(ArrayRef<uint32_t> operands) {
   381    if (operands.size() != 1)
   382      return emitError(unknownLoc, "OpMemoryModel must have one parameter");
   383  
   384    auto cap = spirv::symbolizeCapability(operands[0]);
   385    if (!cap)
   386      return emitError(unknownLoc, "unknown capability: ") << operands[0];
   387  
   388    capabilities.insert(*cap);
   389    return success();
   390  }
   391  
   392  void Deserializer::attachCapabilities() {
   393    if (capabilities.empty())
   394      return;
   395  
   396    SmallVector<StringRef, 2> caps;
   397    caps.reserve(capabilities.size());
   398  
   399    for (auto cap : capabilities) {
   400      caps.push_back(spirv::stringifyCapability(cap));
   401    }
   402  
   403    module->setAttr("capabilities", opBuilder.getStrArrayAttr(caps));
   404  }
   405  
   406  LogicalResult Deserializer::processExtension(ArrayRef<uint32_t> operands) {
   407    if (operands.empty()) {
   408      return emitError(
   409          unknownLoc,
   410          "OpExtension must have a literal string for the extension name");
   411    }
   412  
   413    unsigned wordIndex = 0;
   414    StringRef extName = decodeStringLiteral(operands, wordIndex);
   415    if (wordIndex != operands.size()) {
   416      return emitError(unknownLoc,
   417                       "unexpected trailing words in OpExtension instruction");
   418    }
   419  
   420    extensions.insert(extName);
   421    return success();
   422  }
   423  
   424  void Deserializer::attachExtensions() {
   425    if (extensions.empty())
   426      return;
   427  
   428    module->setAttr("extensions",
   429                    opBuilder.getStrArrayAttr(extensions.getArrayRef()));
   430  }
   431  
   432  LogicalResult Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) {
   433    if (operands.size() != 2)
   434      return emitError(unknownLoc, "OpMemoryModel must have two operands");
   435  
   436    module->setAttr(
   437        "addressing_model",
   438        opBuilder.getI32IntegerAttr(llvm::bit_cast<int32_t>(operands.front())));
   439    module->setAttr(
   440        "memory_model",
   441        opBuilder.getI32IntegerAttr(llvm::bit_cast<int32_t>(operands.back())));
   442  
   443    return success();
   444  }
   445  
   446  LogicalResult Deserializer::processDecoration(ArrayRef<uint32_t> words) {
   447    // TODO : This function should also be auto-generated. For now, since only a
   448    // few decorations are processed/handled in a meaningful manner, going with a
   449    // manual implementation.
   450    if (words.size() < 2) {
   451      return emitError(
   452          unknownLoc, "OpDecorate must have at least result <id> and Decoration");
   453    }
   454    auto decorationName =
   455        stringifyDecoration(static_cast<spirv::Decoration>(words[1]));
   456    if (decorationName.empty()) {
   457      return emitError(unknownLoc, "invalid Decoration code : ") << words[1];
   458    }
   459    auto attrName = convertToSnakeCase(decorationName);
   460    switch (static_cast<spirv::Decoration>(words[1])) {
   461    case spirv::Decoration::DescriptorSet:
   462    case spirv::Decoration::Binding:
   463      if (words.size() != 3) {
   464        return emitError(unknownLoc, "OpDecorate with ")
   465               << decorationName << " needs a single integer literal";
   466      }
   467      decorations[words[0]].set(
   468          opBuilder.getIdentifier(attrName),
   469          opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2])));
   470      break;
   471    case spirv::Decoration::BuiltIn:
   472      if (words.size() != 3) {
   473        return emitError(unknownLoc, "OpDecorate with ")
   474               << decorationName << " needs a single integer literal";
   475      }
   476      decorations[words[0]].set(opBuilder.getIdentifier(attrName),
   477                                opBuilder.getStringAttr(stringifyBuiltIn(
   478                                    static_cast<spirv::BuiltIn>(words[2]))));
   479      break;
   480    case spirv::Decoration::ArrayStride:
   481      if (words.size() != 3) {
   482        return emitError(unknownLoc, "OpDecorate with ")
   483               << decorationName << " needs a single integer literal";
   484      }
   485      typeDecorations[words[0]] = static_cast<uint32_t>(words[2]);
   486      break;
   487    case spirv::Decoration::Block:
   488      if (words.size() != 2) {
   489        return emitError(unknownLoc, "OpDecoration with ")
   490               << decorationName << "needs a single target <id>";
   491      }
   492      // Block decoration does not affect spv.struct type.
   493      break;
   494    default:
   495      return emitError(unknownLoc, "unhandled Decoration : '") << decorationName;
   496    }
   497    return success();
   498  }
   499  
   500  LogicalResult Deserializer::processMemberDecoration(ArrayRef<uint32_t> words) {
   501    // The binary layout of OpMemberDecorate is different comparing to OpDecorate
   502    if (words.size() != 4) {
   503      return emitError(unknownLoc, "OpMemberDecorate must have 4 operands");
   504    }
   505  
   506    switch (static_cast<spirv::Decoration>(words[2])) {
   507    case spirv::Decoration::Offset:
   508      memberDecorationMap[words[0]][words[1]] = words[3];
   509      break;
   510    default:
   511      return emitError(unknownLoc, "unhandled OpMemberDecoration case: ")
   512             << words[2];
   513    }
   514    return success();
   515  }
   516  
   517  LogicalResult Deserializer::processFunction(ArrayRef<uint32_t> operands) {
   518    // Get the result type
   519    if (operands.size() != 4) {
   520      return emitError(unknownLoc, "OpFunction must have 4 parameters");
   521    }
   522    Type resultType = getType(operands[0]);
   523    if (!resultType) {
   524      return emitError(unknownLoc, "undefined result type from <id> ")
   525             << operands[0];
   526    }
   527    if (funcMap.count(operands[1])) {
   528      return emitError(unknownLoc, "duplicate function definition/declaration");
   529    }
   530    auto functionControl = spirv::symbolizeFunctionControl(operands[2]);
   531    if (!functionControl) {
   532      return emitError(unknownLoc, "unknown Function Control: ") << operands[2];
   533    }
   534    if (functionControl.getValue() != spirv::FunctionControl::None) {
   535      /// TODO : Handle different function controls
   536      return emitError(unknownLoc, "unhandled Function Control: '")
   537             << spirv::stringifyFunctionControl(functionControl.getValue())
   538             << "'";
   539    }
   540    Type fnType = getType(operands[3]);
   541    if (!fnType || !fnType.isa<FunctionType>()) {
   542      return emitError(unknownLoc, "unknown function type from <id> ")
   543             << operands[3];
   544    }
   545    auto functionType = fnType.cast<FunctionType>();
   546    if ((isVoidType(resultType) && functionType.getNumResults() != 0) ||
   547        (functionType.getNumResults() == 1 &&
   548         functionType.getResult(0) != resultType)) {
   549      return emitError(unknownLoc, "mismatch in function type ")
   550             << functionType << " and return type " << resultType << " specified";
   551    }
   552  
   553    std::string fnName = nameMap.lookup(operands[1]).str();
   554    if (fnName.empty()) {
   555      fnName = "spirv_fn_" + std::to_string(operands[2]);
   556    }
   557    auto funcOp = opBuilder.create<FuncOp>(unknownLoc, fnName, functionType,
   558                                           ArrayRef<NamedAttribute>());
   559    funcMap[operands[1]] = funcOp;
   560    funcOp.addEntryBlock();
   561  
   562    // Parse the op argument instructions
   563    if (functionType.getNumInputs()) {
   564      for (size_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
   565        auto argType = functionType.getInput(i);
   566        spirv::Opcode opcode = spirv::Opcode::OpNop;
   567        ArrayRef<uint32_t> operands;
   568        if (failed(sliceInstruction(opcode, operands,
   569                                    spirv::Opcode::OpFunctionParameter))) {
   570          return failure();
   571        }
   572        if (opcode != spirv::Opcode::OpFunctionParameter) {
   573          return emitError(
   574                     unknownLoc,
   575                     "missing OpFunctionParameter instruction for argument ")
   576                 << i;
   577        }
   578        if (operands.size() != 2) {
   579          return emitError(
   580              unknownLoc,
   581              "expected result type and result <id> for OpFunctionParameter");
   582        }
   583        auto argDefinedType = getType(operands[0]);
   584        if (!argDefinedType || argDefinedType != argType) {
   585          return emitError(unknownLoc,
   586                           "mismatch in argument type between function type "
   587                           "definition ")
   588                 << functionType << " and argument type definition "
   589                 << argDefinedType << " at argument " << i;
   590        }
   591        if (getValue(operands[1])) {
   592          return emitError(unknownLoc, "duplicate definition of result <id> '")
   593                 << operands[1];
   594        }
   595        auto argValue = funcOp.getArgument(i);
   596        valueMap[operands[1]] = argValue;
   597      }
   598    }
   599  
   600    // Create a new builder for building the body.
   601    OpBuilder funcBody(funcOp.getBody());
   602    std::swap(funcBody, opBuilder);
   603  
   604    // Make sure the first basic block, if exists, starts with an OpLabel
   605    // instruction.
   606    if (auto nextOpcode = peekOpcode()) {
   607      if (*nextOpcode != spirv::Opcode::OpFunctionEnd &&
   608          *nextOpcode != spirv::Opcode::OpLabel)
   609        return emitError(unknownLoc, "a basic block must start with OpLabel");
   610    }
   611  
   612    spirv::Opcode opcode = spirv::Opcode::OpNop;
   613    ArrayRef<uint32_t> instOperands;
   614    while (succeeded(sliceInstruction(opcode, instOperands,
   615                                      spirv::Opcode::OpFunctionEnd)) &&
   616           opcode != spirv::Opcode::OpFunctionEnd) {
   617      if (failed(processInstruction(opcode, instOperands))) {
   618        return failure();
   619      }
   620    }
   621    if (opcode != spirv::Opcode::OpFunctionEnd) {
   622      return failure();
   623    }
   624  
   625    // Process OpFunctionEnd.
   626    if (!instOperands.empty()) {
   627      return emitError(unknownLoc, "unexpected operands for OpFunctionEnd");
   628    }
   629  
   630    std::swap(funcBody, opBuilder);
   631    return success();
   632  }
   633  
   634  Optional<std::pair<Attribute, Type>> Deserializer::getConstant(uint32_t id) {
   635    auto constIt = constantMap.find(id);
   636    if (constIt == constantMap.end())
   637      return llvm::None;
   638    return constIt->getSecond();
   639  }
   640  
   641  std::string Deserializer::getSpecConstantSymbol(uint32_t id) {
   642    auto constName = nameMap.lookup(id).str();
   643    if (constName.empty()) {
   644      constName = "spirv_spec_const_" + std::to_string(id);
   645    }
   646    return constName;
   647  }
   648  
   649  LogicalResult Deserializer::processGlobalVariable(ArrayRef<uint32_t> operands) {
   650    unsigned wordIndex = 0;
   651    if (operands.size() < 3) {
   652      return emitError(
   653          unknownLoc,
   654          "OpVariable needs at least 3 operands, type, <id> and storage class");
   655    }
   656  
   657    // Result Type.
   658    auto type = getType(operands[wordIndex]);
   659    if (!type) {
   660      return emitError(unknownLoc, "unknown result type <id> : ")
   661             << operands[wordIndex];
   662    }
   663    auto ptrType = type.dyn_cast<spirv::PointerType>();
   664    if (!ptrType) {
   665      return emitError(unknownLoc,
   666                       "expected a result type <id> to be a spv.ptr, found : ")
   667             << type;
   668    }
   669    wordIndex++;
   670  
   671    // Result <id>.
   672    auto variableID = operands[wordIndex];
   673    auto variableName = nameMap.lookup(variableID).str();
   674    if (variableName.empty()) {
   675      variableName = "spirv_var_" + std::to_string(variableID);
   676    }
   677    wordIndex++;
   678  
   679    // Storage class.
   680    auto storageClass = static_cast<spirv::StorageClass>(operands[wordIndex]);
   681    if (ptrType.getStorageClass() != storageClass) {
   682      return emitError(unknownLoc, "mismatch in storage class of pointer type ")
   683             << type << " and that specified in OpVariable instruction  : "
   684             << stringifyStorageClass(storageClass);
   685    }
   686    wordIndex++;
   687  
   688    // Initializer.
   689    SymbolRefAttr initializer = nullptr;
   690    if (wordIndex < operands.size()) {
   691      auto initializerOp = getGlobalVariable(operands[wordIndex]);
   692      if (!initializerOp) {
   693        return emitError(unknownLoc, "unknown <id> ")
   694               << operands[wordIndex] << "used as initializer";
   695      }
   696      wordIndex++;
   697      initializer = opBuilder.getSymbolRefAttr(initializerOp.getOperation());
   698    }
   699    if (wordIndex != operands.size()) {
   700      return emitError(unknownLoc,
   701                       "found more operands than expected when deserializing "
   702                       "OpVariable instruction, only ")
   703             << wordIndex << " of " << operands.size() << " processed";
   704    }
   705    auto varOp = opBuilder.create<spirv::GlobalVariableOp>(
   706        unknownLoc, opBuilder.getTypeAttr(type),
   707        opBuilder.getStringAttr(variableName), initializer);
   708  
   709    // Decorations.
   710    if (decorations.count(variableID)) {
   711      for (auto attr : decorations[variableID].getAttrs()) {
   712        varOp.setAttr(attr.first, attr.second);
   713      }
   714    }
   715    globalVariableMap[variableID] = varOp;
   716    return success();
   717  }
   718  
   719  LogicalResult Deserializer::processName(ArrayRef<uint32_t> operands) {
   720    if (operands.size() < 2) {
   721      return emitError(unknownLoc, "OpName needs at least 2 operands");
   722    }
   723    if (!nameMap.lookup(operands[0]).empty()) {
   724      return emitError(unknownLoc, "duplicate name found for result <id> ")
   725             << operands[0];
   726    }
   727    unsigned wordIndex = 1;
   728    StringRef name = decodeStringLiteral(operands, wordIndex);
   729    if (wordIndex != operands.size()) {
   730      return emitError(unknownLoc,
   731                       "unexpected trailing words in OpName instruction");
   732    }
   733    nameMap[operands[0]] = name;
   734    return success();
   735  }
   736  
   737  //===----------------------------------------------------------------------===//
   738  // Type
   739  //===----------------------------------------------------------------------===//
   740  
   741  LogicalResult Deserializer::processType(spirv::Opcode opcode,
   742                                          ArrayRef<uint32_t> operands) {
   743    if (operands.empty()) {
   744      return emitError(unknownLoc, "type instruction with opcode ")
   745             << spirv::stringifyOpcode(opcode) << " needs at least one <id>";
   746    }
   747  
   748    /// TODO: Types might be forward declared in some instructions and need to be
   749    /// handled appropriately.
   750    if (typeMap.count(operands[0])) {
   751      return emitError(unknownLoc, "duplicate definition for result <id> ")
   752             << operands[0];
   753    }
   754  
   755    switch (opcode) {
   756    case spirv::Opcode::OpTypeVoid:
   757      if (operands.size() != 1) {
   758        return emitError(unknownLoc, "OpTypeVoid must have no parameters");
   759      }
   760      typeMap[operands[0]] = opBuilder.getNoneType();
   761      break;
   762    case spirv::Opcode::OpTypeBool:
   763      if (operands.size() != 1) {
   764        return emitError(unknownLoc, "OpTypeBool must have no parameters");
   765      }
   766      typeMap[operands[0]] = opBuilder.getI1Type();
   767      break;
   768    case spirv::Opcode::OpTypeInt:
   769      if (operands.size() != 3) {
   770        return emitError(
   771            unknownLoc, "OpTypeInt must have bitwidth and signedness parameters");
   772      }
   773      if (operands[2] == 0) {
   774        return emitError(unknownLoc, "unhandled unsigned OpTypeInt");
   775      }
   776      typeMap[operands[0]] = opBuilder.getIntegerType(operands[1]);
   777      break;
   778    case spirv::Opcode::OpTypeFloat: {
   779      if (operands.size() != 2) {
   780        return emitError(unknownLoc, "OpTypeFloat must have bitwidth parameter");
   781      }
   782      Type floatTy;
   783      switch (operands[1]) {
   784      case 16:
   785        floatTy = opBuilder.getF16Type();
   786        break;
   787      case 32:
   788        floatTy = opBuilder.getF32Type();
   789        break;
   790      case 64:
   791        floatTy = opBuilder.getF64Type();
   792        break;
   793      default:
   794        return emitError(unknownLoc, "unsupported OpTypeFloat bitwdith: ")
   795               << operands[1];
   796      }
   797      typeMap[operands[0]] = floatTy;
   798    } break;
   799    case spirv::Opcode::OpTypeVector: {
   800      if (operands.size() != 3) {
   801        return emitError(
   802            unknownLoc,
   803            "OpTypeVector must have element type and count parameters");
   804      }
   805      Type elementTy = getType(operands[1]);
   806      if (!elementTy) {
   807        return emitError(unknownLoc, "OpTypeVector references undefined <id> ")
   808               << operands[1];
   809      }
   810      typeMap[operands[0]] = opBuilder.getVectorType({operands[2]}, elementTy);
   811    } break;
   812    case spirv::Opcode::OpTypePointer: {
   813      if (operands.size() != 3) {
   814        return emitError(unknownLoc, "OpTypePointer must have two parameters");
   815      }
   816      auto pointeeType = getType(operands[2]);
   817      if (!pointeeType) {
   818        return emitError(unknownLoc, "unknown OpTypePointer pointee type <id> ")
   819               << operands[2];
   820      }
   821      auto storageClass = static_cast<spirv::StorageClass>(operands[1]);
   822      typeMap[operands[0]] = spirv::PointerType::get(pointeeType, storageClass);
   823    } break;
   824    case spirv::Opcode::OpTypeArray:
   825      return processArrayType(operands);
   826    case spirv::Opcode::OpTypeFunction:
   827      return processFunctionType(operands);
   828    case spirv::Opcode::OpTypeStruct:
   829      return processStructType(operands);
   830    default:
   831      return emitError(unknownLoc, "unhandled type instruction");
   832    }
   833    return success();
   834  }
   835  
   836  LogicalResult Deserializer::processArrayType(ArrayRef<uint32_t> operands) {
   837    if (operands.size() != 3) {
   838      return emitError(unknownLoc,
   839                       "OpTypeArray must have element type and count parameters");
   840    }
   841  
   842    Type elementTy = getType(operands[1]);
   843    if (!elementTy) {
   844      return emitError(unknownLoc, "OpTypeArray references undefined <id> ")
   845             << operands[1];
   846    }
   847  
   848    unsigned count = 0;
   849    // TODO(antiagainst): The count can also come frome a specialization constant.
   850    auto countInfo = getConstant(operands[2]);
   851    if (!countInfo) {
   852      return emitError(unknownLoc, "OpTypeArray count <id> ")
   853             << operands[2] << "can only come from normal constant right now";
   854    }
   855  
   856    if (auto intVal = countInfo->first.dyn_cast<IntegerAttr>()) {
   857      count = intVal.getInt();
   858    } else {
   859      return emitError(unknownLoc, "OpTypeArray count must come from a "
   860                                   "scalar integer constant instruction");
   861    }
   862  
   863    typeMap[operands[0]] = spirv::ArrayType::get(
   864        elementTy, count, typeDecorations.lookup(operands[0]));
   865    return success();
   866  }
   867  
   868  LogicalResult Deserializer::processFunctionType(ArrayRef<uint32_t> operands) {
   869    assert(!operands.empty() && "No operands for processing function type");
   870    if (operands.size() == 1) {
   871      return emitError(unknownLoc, "missing return type for OpTypeFunction");
   872    }
   873    auto returnType = getType(operands[1]);
   874    if (!returnType) {
   875      return emitError(unknownLoc, "unknown return type in OpTypeFunction");
   876    }
   877    SmallVector<Type, 1> argTypes;
   878    for (size_t i = 2, e = operands.size(); i < e; ++i) {
   879      auto ty = getType(operands[i]);
   880      if (!ty) {
   881        return emitError(unknownLoc, "unknown argument type in OpTypeFunction");
   882      }
   883      argTypes.push_back(ty);
   884    }
   885    ArrayRef<Type> returnTypes;
   886    if (!isVoidType(returnType)) {
   887      returnTypes = llvm::makeArrayRef(returnType);
   888    }
   889    typeMap[operands[0]] = FunctionType::get(argTypes, returnTypes, context);
   890    return success();
   891  }
   892  
   893  LogicalResult Deserializer::processStructType(ArrayRef<uint32_t> operands) {
   894    // TODO(ravishankarm) : Regarding to the spec spv.struct must support zero
   895    // amount of members.
   896    if (operands.size() < 2) {
   897      return emitError(unknownLoc, "OpTypeStruct must have at least 2 operand");
   898    }
   899  
   900    SmallVector<Type, 0> memberTypes;
   901    for (auto op : llvm::drop_begin(operands, 1)) {
   902      Type memberType = getType(op);
   903      if (!memberType) {
   904        return emitError(unknownLoc, "OpTypeStruct references undefined <id> ")
   905               << op;
   906      }
   907      memberTypes.push_back(memberType);
   908    }
   909  
   910    SmallVector<spirv::StructType::LayoutInfo, 0> layoutInfo;
   911    // Check for layoutinfo
   912    auto memberDecorationIt = memberDecorationMap.find(operands[0]);
   913    if (memberDecorationIt != memberDecorationMap.end()) {
   914      // Each member must have an offset
   915      const auto &offsetDecorationMap = memberDecorationIt->second;
   916      auto offsetDecorationMapEnd = offsetDecorationMap.end();
   917      for (auto memberIndex : llvm::seq<uint32_t>(0, memberTypes.size())) {
   918        // Check that specific member has an offset
   919        auto offsetIt = offsetDecorationMap.find(memberIndex);
   920        if (offsetIt == offsetDecorationMapEnd) {
   921          return emitError(unknownLoc, "OpTypeStruct with <id> ")
   922                 << operands[0] << " must have an offset for " << memberIndex
   923                 << "-th member";
   924        }
   925        layoutInfo.push_back(
   926            static_cast<spirv::StructType::LayoutInfo>(offsetIt->second));
   927      }
   928    }
   929    typeMap[operands[0]] = spirv::StructType::get(memberTypes, layoutInfo);
   930    return success();
   931  }
   932  
   933  //===----------------------------------------------------------------------===//
   934  // Constant
   935  //===----------------------------------------------------------------------===//
   936  
   937  LogicalResult Deserializer::processConstant(ArrayRef<uint32_t> operands,
   938                                              bool isSpec) {
   939    StringRef opname = isSpec ? "OpSpecConstant" : "OpConstant";
   940  
   941    if (operands.size() < 2) {
   942      return emitError(unknownLoc)
   943             << opname << " must have type <id> and result <id>";
   944    }
   945    if (operands.size() < 3) {
   946      return emitError(unknownLoc)
   947             << opname << " must have at least 1 more parameter";
   948    }
   949  
   950    Type resultType = getType(operands[0]);
   951    if (!resultType) {
   952      return emitError(unknownLoc, "undefined result type from <id> ")
   953             << operands[0];
   954    }
   955  
   956    auto checkOperandSizeForBitwidth = [&](unsigned bitwidth) -> LogicalResult {
   957      if (bitwidth == 64) {
   958        if (operands.size() == 4) {
   959          return success();
   960        }
   961        return emitError(unknownLoc)
   962               << opname << " should have 2 parameters for 64-bit values";
   963      }
   964      if (bitwidth <= 32) {
   965        if (operands.size() == 3) {
   966          return success();
   967        }
   968  
   969        return emitError(unknownLoc)
   970               << opname
   971               << " should have 1 parameter for values with no more than 32 bits";
   972      }
   973      return emitError(unknownLoc, "unsupported OpConstant bitwidth: ")
   974             << bitwidth;
   975    };
   976  
   977    auto resultID = operands[1];
   978  
   979    if (auto intType = resultType.dyn_cast<IntegerType>()) {
   980      auto bitwidth = intType.getWidth();
   981      if (failed(checkOperandSizeForBitwidth(bitwidth))) {
   982        return failure();
   983      }
   984  
   985      APInt value;
   986      if (bitwidth == 64) {
   987        // 64-bit integers are represented with two SPIR-V words. According to
   988        // SPIR-V spec: "When the type’s bit width is larger than one word, the
   989        // literal’s low-order words appear first."
   990        struct DoubleWord {
   991          uint32_t word1;
   992          uint32_t word2;
   993        } words = {operands[2], operands[3]};
   994        value = APInt(64, llvm::bit_cast<uint64_t>(words), /*isSigned=*/true);
   995      } else if (bitwidth <= 32) {
   996        value = APInt(bitwidth, operands[2], /*isSigned=*/true);
   997      }
   998  
   999      auto attr = opBuilder.getIntegerAttr(intType, value);
  1000  
  1001      if (isSpec) {
  1002        auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
  1003        auto op =
  1004            opBuilder.create<spirv::SpecConstantOp>(unknownLoc, symName, attr);
  1005        specConstMap[resultID] = op;
  1006      } else {
  1007        // For normal constants, we just record the attribute (and its type) for
  1008        // later materialization at use sites.
  1009        constantMap.try_emplace(resultID, attr, intType);
  1010      }
  1011  
  1012      return success();
  1013    }
  1014  
  1015    if (auto floatType = resultType.dyn_cast<FloatType>()) {
  1016      auto bitwidth = floatType.getWidth();
  1017      if (failed(checkOperandSizeForBitwidth(bitwidth))) {
  1018        return failure();
  1019      }
  1020  
  1021      APFloat value(0.f);
  1022      if (floatType.isF64()) {
  1023        // Double values are represented with two SPIR-V words. According to
  1024        // SPIR-V spec: "When the type’s bit width is larger than one word, the
  1025        // literal’s low-order words appear first."
  1026        struct DoubleWord {
  1027          uint32_t word1;
  1028          uint32_t word2;
  1029        } words = {operands[2], operands[3]};
  1030        value = APFloat(llvm::bit_cast<double>(words));
  1031      } else if (floatType.isF32()) {
  1032        value = APFloat(llvm::bit_cast<float>(operands[2]));
  1033      } else if (floatType.isF16()) {
  1034        APInt data(16, operands[2]);
  1035        value = APFloat(APFloat::IEEEhalf(), data);
  1036      }
  1037  
  1038      auto attr = opBuilder.getFloatAttr(floatType, value);
  1039      if (isSpec) {
  1040        auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
  1041        auto op =
  1042            opBuilder.create<spirv::SpecConstantOp>(unknownLoc, symName, attr);
  1043        specConstMap[resultID] = op;
  1044      } else {
  1045        // For normal constants, we just record the attribute (and its type) for
  1046        // later materialization at use sites.
  1047        constantMap.try_emplace(resultID, attr, floatType);
  1048      }
  1049  
  1050      return success();
  1051    }
  1052  
  1053    return emitError(unknownLoc, "OpConstant can only generate values of "
  1054                                 "scalar integer or floating-point type");
  1055  }
  1056  
  1057  LogicalResult Deserializer::processConstantBool(bool isTrue,
  1058                                                  ArrayRef<uint32_t> operands,
  1059                                                  bool isSpec) {
  1060    if (operands.size() != 2) {
  1061      return emitError(unknownLoc, "Op")
  1062             << (isSpec ? "Spec" : "") << "Constant"
  1063             << (isTrue ? "True" : "False")
  1064             << " must have type <id> and result <id>";
  1065    }
  1066  
  1067    auto attr = opBuilder.getBoolAttr(isTrue);
  1068    auto resultID = operands[1];
  1069    if (isSpec) {
  1070      auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
  1071      auto op =
  1072          opBuilder.create<spirv::SpecConstantOp>(unknownLoc, symName, attr);
  1073      specConstMap[resultID] = op;
  1074    } else {
  1075      // For normal constants, we just record the attribute (and its type) for
  1076      // later materialization at use sites.
  1077      constantMap.try_emplace(resultID, attr, opBuilder.getI1Type());
  1078    }
  1079  
  1080    return success();
  1081  }
  1082  
  1083  LogicalResult
  1084  Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) {
  1085    if (operands.size() < 2) {
  1086      return emitError(unknownLoc,
  1087                       "OpConstantComposite must have type <id> and result <id>");
  1088    }
  1089    if (operands.size() < 3) {
  1090      return emitError(unknownLoc,
  1091                       "OpConstantComposite must have at least 1 parameter");
  1092    }
  1093  
  1094    Type resultType = getType(operands[0]);
  1095    if (!resultType) {
  1096      return emitError(unknownLoc, "undefined result type from <id> ")
  1097             << operands[0];
  1098    }
  1099  
  1100    SmallVector<Attribute, 4> elements;
  1101    elements.reserve(operands.size() - 2);
  1102    for (unsigned i = 2, e = operands.size(); i < e; ++i) {
  1103      auto elementInfo = getConstant(operands[i]);
  1104      if (!elementInfo) {
  1105        return emitError(unknownLoc, "OpConstantComposite component <id> ")
  1106               << operands[i] << " must come from a normal constant";
  1107      }
  1108      elements.push_back(elementInfo->first);
  1109    }
  1110  
  1111    auto resultID = operands[1];
  1112    if (auto vectorType = resultType.dyn_cast<VectorType>()) {
  1113      auto attr = opBuilder.getDenseElementsAttr(vectorType, elements);
  1114      // For normal constants, we just record the attribute (and its type) for
  1115      // later materialization at use sites.
  1116      constantMap.try_emplace(resultID, attr, resultType);
  1117    } else if (auto arrayType = resultType.dyn_cast<spirv::ArrayType>()) {
  1118      auto attr = opBuilder.getArrayAttr(elements);
  1119      constantMap.try_emplace(resultID, attr, resultType);
  1120    } else {
  1121      return emitError(unknownLoc, "unsupported OpConstantComposite type: ")
  1122             << resultType;
  1123    }
  1124  
  1125    return success();
  1126  }
  1127  
  1128  LogicalResult Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
  1129    if (operands.size() != 2) {
  1130      return emitError(unknownLoc,
  1131                       "OpConstantNull must have type <id> and result <id>");
  1132    }
  1133  
  1134    Type resultType = getType(operands[0]);
  1135    if (!resultType) {
  1136      return emitError(unknownLoc, "undefined result type from <id> ")
  1137             << operands[0];
  1138    }
  1139  
  1140    auto resultID = operands[1];
  1141    if (resultType.isa<IntegerType>() || resultType.isa<FloatType>() ||
  1142        resultType.isa<VectorType>()) {
  1143      auto attr = opBuilder.getZeroAttr(resultType);
  1144      // For normal constants, we just record the attribute (and its type) for
  1145      // later materialization at use sites.
  1146      constantMap.try_emplace(resultID, attr, resultType);
  1147      return success();
  1148    }
  1149  
  1150      return emitError(unknownLoc, "unsupported OpConstantNull type: ")
  1151             << resultType;
  1152  }
  1153  
  1154  //===----------------------------------------------------------------------===//
  1155  // Control flow
  1156  //===----------------------------------------------------------------------===//
  1157  
  1158  LogicalResult Deserializer::processLabel(ArrayRef<uint32_t> operands) {
  1159    if (operands.size() != 1) {
  1160      return emitError(unknownLoc, "OpLabel should only have result <id>");
  1161    }
  1162    // TODO(antiagainst): support basic blocks and control flow properly.
  1163    return success();
  1164  }
  1165  
  1166  //===----------------------------------------------------------------------===//
  1167  // Instruction
  1168  //===----------------------------------------------------------------------===//
  1169  
  1170  Value *Deserializer::getValue(uint32_t id) {
  1171    if (auto constInfo = getConstant(id)) {
  1172      // Materialize a `spv.constant` op at every use site.
  1173      return opBuilder.create<spirv::ConstantOp>(unknownLoc, constInfo->second,
  1174                                                 constInfo->first);
  1175    }
  1176    if (auto varOp = getGlobalVariable(id)) {
  1177      auto addressOfOp = opBuilder.create<spirv::AddressOfOp>(
  1178          unknownLoc, varOp.type(),
  1179          opBuilder.getSymbolRefAttr(varOp.getOperation()));
  1180      return addressOfOp.pointer();
  1181    }
  1182    if (auto constOp = getSpecConstant(id)) {
  1183      auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
  1184          unknownLoc, constOp.default_value().getType(),
  1185          opBuilder.getSymbolRefAttr(constOp.getOperation()));
  1186      return referenceOfOp.reference();
  1187    }
  1188    return valueMap.lookup(id);
  1189  }
  1190  
  1191  LogicalResult
  1192  Deserializer::sliceInstruction(spirv::Opcode &opcode,
  1193                                 ArrayRef<uint32_t> &operands,
  1194                                 Optional<spirv::Opcode> expectedOpcode) {
  1195    auto binarySize = binary.size();
  1196    if (curOffset >= binarySize) {
  1197      return emitError(unknownLoc, "expected ")
  1198             << (expectedOpcode ? spirv::stringifyOpcode(*expectedOpcode)
  1199                                : "more")
  1200             << " instruction";
  1201    }
  1202  
  1203    // For each instruction, get its word count from the first word to slice it
  1204    // from the stream properly, and then dispatch to the instruction handler.
  1205  
  1206    uint32_t wordCount = binary[curOffset] >> 16;
  1207  
  1208    if (wordCount == 0)
  1209      return emitError(unknownLoc, "word count cannot be zero");
  1210  
  1211    uint32_t nextOffset = curOffset + wordCount;
  1212    if (nextOffset > binarySize)
  1213      return emitError(unknownLoc, "insufficient words for the last instruction");
  1214  
  1215    opcode = extractOpcode(binary[curOffset]);
  1216    operands = binary.slice(curOffset + 1, wordCount - 1);
  1217    curOffset = nextOffset;
  1218    return success();
  1219  }
  1220  
  1221  Optional<spirv::Opcode> Deserializer::peekOpcode() {
  1222    if (curOffset >= binary.size())
  1223      return llvm::None;
  1224    return extractOpcode(binary[curOffset]);
  1225  }
  1226  
  1227  LogicalResult Deserializer::processInstruction(spirv::Opcode opcode,
  1228                                                 ArrayRef<uint32_t> operands,
  1229                                                 bool deferInstructions) {
  1230    // First dispatch all the instructions whose opcode does not correspond to
  1231    // those that have a direct mirror in the SPIR-V dialect
  1232    switch (opcode) {
  1233    case spirv::Opcode::OpCapability:
  1234      return processCapability(operands);
  1235    case spirv::Opcode::OpExtension:
  1236      return processExtension(operands);
  1237    case spirv::Opcode::OpMemoryModel:
  1238      return processMemoryModel(operands);
  1239    case spirv::Opcode::OpEntryPoint:
  1240    case spirv::Opcode::OpExecutionMode:
  1241      if (deferInstructions) {
  1242        deferedInstructions.emplace_back(opcode, operands);
  1243        return success();
  1244      }
  1245      break;
  1246    case spirv::Opcode::OpVariable:
  1247      if (isa<spirv::ModuleOp>(opBuilder.getBlock()->getParentOp())) {
  1248        return processGlobalVariable(operands);
  1249      }
  1250      break;
  1251    case spirv::Opcode::OpName:
  1252      return processName(operands);
  1253    case spirv::Opcode::OpTypeVoid:
  1254    case spirv::Opcode::OpTypeBool:
  1255    case spirv::Opcode::OpTypeInt:
  1256    case spirv::Opcode::OpTypeFloat:
  1257    case spirv::Opcode::OpTypeVector:
  1258    case spirv::Opcode::OpTypeArray:
  1259    case spirv::Opcode::OpTypeFunction:
  1260    case spirv::Opcode::OpTypeStruct:
  1261    case spirv::Opcode::OpTypePointer:
  1262      return processType(opcode, operands);
  1263    case spirv::Opcode::OpConstant:
  1264      return processConstant(operands, /*isSpec=*/false);
  1265    case spirv::Opcode::OpSpecConstant:
  1266      return processConstant(operands, /*isSpec=*/true);
  1267    case spirv::Opcode::OpConstantComposite:
  1268      return processConstantComposite(operands);
  1269    case spirv::Opcode::OpConstantTrue:
  1270      return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/false);
  1271    case spirv::Opcode::OpSpecConstantTrue:
  1272      return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/true);
  1273    case spirv::Opcode::OpConstantFalse:
  1274      return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/false);
  1275    case spirv::Opcode::OpSpecConstantFalse:
  1276      return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/true);
  1277    case spirv::Opcode::OpConstantNull:
  1278      return processConstantNull(operands);
  1279    case spirv::Opcode::OpDecorate:
  1280      return processDecoration(operands);
  1281    case spirv::Opcode::OpMemberDecorate:
  1282      return processMemberDecoration(operands);
  1283    case spirv::Opcode::OpFunction:
  1284      return processFunction(operands);
  1285    case spirv::Opcode::OpLabel:
  1286      return processLabel(operands);
  1287    default:
  1288      break;
  1289    }
  1290    return dispatchToAutogenDeserialization(opcode, operands);
  1291  }
  1292  
  1293  namespace {
  1294  
  1295  template <>
  1296  LogicalResult
  1297  Deserializer::processOp<spirv::EntryPointOp>(ArrayRef<uint32_t> words) {
  1298    unsigned wordIndex = 0;
  1299    if (wordIndex >= words.size()) {
  1300      return emitError(unknownLoc,
  1301                       "missing Execution Model specification in OpEntryPoint");
  1302    }
  1303    auto exec_model = opBuilder.getI32IntegerAttr(words[wordIndex++]);
  1304    if (wordIndex >= words.size()) {
  1305      return emitError(unknownLoc, "missing <id> in OpEntryPoint");
  1306    }
  1307    // Get the function <id>
  1308    auto fnID = words[wordIndex++];
  1309    // Get the function name
  1310    auto fnName = decodeStringLiteral(words, wordIndex);
  1311    // Verify that the function <id> matches the fnName
  1312    auto parsedFunc = getFunction(fnID);
  1313    if (!parsedFunc) {
  1314      return emitError(unknownLoc, "no function matching <id> ") << fnID;
  1315    }
  1316    if (parsedFunc.getName() != fnName) {
  1317      return emitError(unknownLoc, "function name mismatch between OpEntryPoint "
  1318                                   "and OpFunction with <id> ")
  1319             << fnID << ": " << fnName << " vs. " << parsedFunc.getName();
  1320    }
  1321    SmallVector<Attribute, 4> interface;
  1322    while (wordIndex < words.size()) {
  1323      auto arg = getGlobalVariable(words[wordIndex]);
  1324      if (!arg) {
  1325        return emitError(unknownLoc, "undefined result <id> ")
  1326               << words[wordIndex] << " while decoding OpEntryPoint";
  1327      }
  1328      interface.push_back(opBuilder.getSymbolRefAttr(arg.getOperation()));
  1329      wordIndex++;
  1330    }
  1331    opBuilder.create<spirv::EntryPointOp>(unknownLoc, exec_model,
  1332                                          opBuilder.getSymbolRefAttr(fnName),
  1333                                          opBuilder.getArrayAttr(interface));
  1334    return success();
  1335  }
  1336  
  1337  template <>
  1338  LogicalResult
  1339  Deserializer::processOp<spirv::ExecutionModeOp>(ArrayRef<uint32_t> words) {
  1340    unsigned wordIndex = 0;
  1341    if (wordIndex >= words.size()) {
  1342      return emitError(unknownLoc,
  1343                       "missing function result <id> in OpExecutionMode");
  1344    }
  1345    // Get the function <id> to get the name of the function
  1346    auto fnID = words[wordIndex++];
  1347    auto fn = getFunction(fnID);
  1348    if (!fn) {
  1349      return emitError(unknownLoc, "no function matching <id> ") << fnID;
  1350    }
  1351    // Get the Execution mode
  1352    if (wordIndex >= words.size()) {
  1353      return emitError(unknownLoc, "missing Execution Mode in OpExecutionMode");
  1354    }
  1355    auto execMode = opBuilder.getI32IntegerAttr(words[wordIndex++]);
  1356  
  1357    // Get the values
  1358    SmallVector<Attribute, 4> attrListElems;
  1359    while (wordIndex < words.size()) {
  1360      attrListElems.push_back(opBuilder.getI32IntegerAttr(words[wordIndex++]));
  1361    }
  1362    auto values = opBuilder.getArrayAttr(attrListElems);
  1363    opBuilder.create<spirv::ExecutionModeOp>(
  1364        unknownLoc, opBuilder.getSymbolRefAttr(fn.getName()), execMode, values);
  1365    return success();
  1366  }
  1367  
  1368  // Pull in auto-generated Deserializer::dispatchToAutogenDeserialization() and
  1369  // various Deserializer::processOp<...>() specializations.
  1370  #define GET_DESERIALIZATION_FNS
  1371  #include "mlir/Dialect/SPIRV/SPIRVSerialization.inc"
  1372  } // namespace
  1373  
  1374  Optional<spirv::ModuleOp> spirv::deserialize(ArrayRef<uint32_t> binary,
  1375                                               MLIRContext *context) {
  1376    Deserializer deserializer(binary, context);
  1377  
  1378    if (failed(deserializer.deserialize()))
  1379      return llvm::None;
  1380  
  1381    return deserializer.collect();
  1382  }