github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp (about) 1 //===- Deserializer.cpp - MLIR SPIR-V Deserialization ---------------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 // 18 // This file defines the SPIR-V binary to MLIR SPIR-V module deseralization. 19 // 20 //===----------------------------------------------------------------------===// 21 22 #include "mlir/Dialect/SPIRV/Serialization.h" 23 24 #include "mlir/Dialect/SPIRV/SPIRVBinaryUtils.h" 25 #include "mlir/Dialect/SPIRV/SPIRVOps.h" 26 #include "mlir/Dialect/SPIRV/SPIRVTypes.h" 27 #include "mlir/IR/Builders.h" 28 #include "mlir/IR/Location.h" 29 #include "mlir/Support/LogicalResult.h" 30 #include "mlir/Support/StringExtras.h" 31 #include "llvm/ADT/Sequence.h" 32 #include "llvm/ADT/SetVector.h" 33 #include "llvm/ADT/SmallVector.h" 34 #include "llvm/ADT/bit.h" 35 36 using namespace mlir; 37 38 // Decodes a string literal in `words` starting at `wordIndex`. Update the 39 // latter to point to the position in words after the string literal. 40 static inline StringRef decodeStringLiteral(ArrayRef<uint32_t> words, 41 unsigned &wordIndex) { 42 StringRef str(reinterpret_cast<const char *>(words.data() + wordIndex)); 43 wordIndex += str.size() / 4 + 1; 44 return str; 45 } 46 47 // Extracts the opcode from the given first word of a SPIR-V instruction. 48 static inline spirv::Opcode extractOpcode(uint32_t word) { 49 return static_cast<spirv::Opcode>(word & 0xffff); 50 } 51 52 namespace { 53 /// A SPIR-V module serializer. 54 /// 55 /// A SPIR-V binary module is a single linear stream of instructions; each 56 /// instruction is composed of 32-bit words. The first word of an instruction 57 /// records the total number of words of that instruction using the 16 58 /// higher-order bits. So this deserializer uses that to get instruction 59 /// boundary and parse instructions and build a SPIR-V ModuleOp gradually. 60 /// 61 // TODO(antiagainst): clean up created ops on errors 62 class Deserializer { 63 public: 64 /// Creates a deserializer for the given SPIR-V `binary` module. 65 /// The SPIR-V ModuleOp will be created into `context. 66 explicit Deserializer(ArrayRef<uint32_t> binary, MLIRContext *context); 67 68 /// Deserializes the remembered SPIR-V binary module. 69 LogicalResult deserialize(); 70 71 /// Collects the final SPIR-V ModuleOp. 72 Optional<spirv::ModuleOp> collect(); 73 74 private: 75 //===--------------------------------------------------------------------===// 76 // Module structure 77 //===--------------------------------------------------------------------===// 78 79 /// Initializes the `module` ModuleOp in this deserializer instance. 80 spirv::ModuleOp createModuleOp(); 81 82 /// Processes SPIR-V module header in `binary`. 83 LogicalResult processHeader(); 84 85 /// Processes the SPIR-V OpCapability with `operands` and updates bookkeeping 86 /// in the deserializer. 87 LogicalResult processCapability(ArrayRef<uint32_t> operands); 88 89 /// Attaches all collected capabilites to `module` as an attribute. 90 void attachCapabilities(); 91 92 /// Processes the SPIR-V OpExtension with `operands` and updates bookkeeping 93 /// in the deserializer. 94 LogicalResult processExtension(ArrayRef<uint32_t> operands); 95 96 /// Attaches all collected extensions to `module` as an attribute. 97 void attachExtensions(); 98 99 /// Processes the SPIR-V OpMemoryModel with `operands` and updates `module`. 100 LogicalResult processMemoryModel(ArrayRef<uint32_t> operands); 101 102 /// Process SPIR-V OpName with `operands`. 103 LogicalResult processName(ArrayRef<uint32_t> operands); 104 105 /// Method to process an OpDecorate instruction. 106 LogicalResult processDecoration(ArrayRef<uint32_t> words); 107 108 // Method to process an OpMemberDecorate instruction. 109 LogicalResult processMemberDecoration(ArrayRef<uint32_t> words); 110 111 /// Gets the FuncOp associated with a result <id> of OpFunction. 112 FuncOp getFunction(uint32_t id) { return funcMap.lookup(id); } 113 114 /// Processes the SPIR-V function at the current `offset` into `binary`. 115 /// The operands to the OpFunction instruction is passed in as ``operands`. 116 /// This method processes each instruction inside the function and dispatches 117 /// them to their handler method accordingly. 118 LogicalResult processFunction(ArrayRef<uint32_t> operands); 119 120 /// Gets the constant's attribute and type associated with the given <id>. 121 Optional<std::pair<Attribute, Type>> getConstant(uint32_t id); 122 123 /// Returns a symbol to be used for the specialization constant with the given 124 /// result <id>. This tries to use the specialization constant's OpName if 125 /// exists; otherwise creates one based on the <id>. 126 std::string getSpecConstantSymbol(uint32_t id); 127 128 /// Gets the specialization constant with the given result <id>. 129 spirv::SpecConstantOp getSpecConstant(uint32_t id) { 130 return specConstMap.lookup(id); 131 } 132 133 /// Processes the OpVariable instructions at current `offset` into `binary`. 134 /// It is expected that this method is used for variables that are to be 135 /// defined at module scope and will be deserialized into a spv.globalVariable 136 /// instruction. 137 LogicalResult processGlobalVariable(ArrayRef<uint32_t> operands); 138 139 /// Gets the global variable associated with a result <id> of OpVariable. 140 spirv::GlobalVariableOp getGlobalVariable(uint32_t id) { 141 return globalVariableMap.lookup(id); 142 } 143 144 //===--------------------------------------------------------------------===// 145 // Type 146 //===--------------------------------------------------------------------===// 147 148 /// Gets type for a given result <id>. 149 Type getType(uint32_t id) { return typeMap.lookup(id); } 150 151 /// Returns true if the given `type` is for SPIR-V void type. 152 bool isVoidType(Type type) const { return type.isa<NoneType>(); } 153 154 /// Processes a SPIR-V type instruction with given `opcode` and `operands` and 155 /// registers the type into `module`. 156 LogicalResult processType(spirv::Opcode opcode, ArrayRef<uint32_t> operands); 157 158 LogicalResult processArrayType(ArrayRef<uint32_t> operands); 159 160 LogicalResult processFunctionType(ArrayRef<uint32_t> operands); 161 162 LogicalResult processStructType(ArrayRef<uint32_t> operands); 163 164 //===--------------------------------------------------------------------===// 165 // Constant 166 //===--------------------------------------------------------------------===// 167 168 /// Processes a SPIR-V Op{|Spec}Constant instruction with the given 169 /// `operands`. `isSpec` indicates whether this is a specialization constant. 170 LogicalResult processConstant(ArrayRef<uint32_t> operands, bool isSpec); 171 172 /// Processes a SPIR-V Op{|Spec}Constant{True|False} instruction with the 173 /// given `operands`. `isSpec` indicates whether this is a specialization 174 /// constant. 175 LogicalResult processConstantBool(bool isTrue, ArrayRef<uint32_t> operands, 176 bool isSpec); 177 178 /// Processes a SPIR-V OpConstantComposite instruction with the given 179 /// `operands`. 180 LogicalResult processConstantComposite(ArrayRef<uint32_t> operands); 181 182 /// Processes a SPIR-V OpConstantNull instruction with the given `operands`. 183 LogicalResult processConstantNull(ArrayRef<uint32_t> operands); 184 185 //===--------------------------------------------------------------------===// 186 // Control flow 187 //===--------------------------------------------------------------------===// 188 189 /// Processes a SPIR-V OpLabel instruction with the given `operands`. 190 LogicalResult processLabel(ArrayRef<uint32_t> operands); 191 192 //===--------------------------------------------------------------------===// 193 // Instruction 194 //===--------------------------------------------------------------------===// 195 196 /// Get the Value associated with a result <id>. 197 /// 198 /// This method materializes normal constants and inserts "casting" ops 199 /// (`spv._address_of` and `spv._reference_of`) to turn an symbol into a SSA 200 /// value for handling uses of module scope constants/variables in functions. 201 Value *getValue(uint32_t id); 202 203 /// Slices the first instruction out of `binary` and returns its opcode and 204 /// operands via `opcode` and `operands` respectively. Returns failure if 205 /// there is no more remaining instructions (`expectedOpcode` will be used to 206 /// compose the error message) or the next instruction is malformed. 207 LogicalResult 208 sliceInstruction(spirv::Opcode &opcode, ArrayRef<uint32_t> &operands, 209 Optional<spirv::Opcode> expectedOpcode = llvm::None); 210 211 /// Returns the next instruction's opcode if exists. 212 Optional<spirv::Opcode> peekOpcode(); 213 214 /// Processes a SPIR-V instruction with the given `opcode` and `operands`. 215 /// This method is the main entrance for handling SPIR-V instruction; it 216 /// checks the instruction opcode and dispatches to the corresponding handler. 217 /// Processing of Some instructions (like OpEntryPoint and OpExecutionMode) 218 /// might need to be defered, since they contain forward references to <id>s 219 /// in the deserialized binary, but module in SPIR-V dialect expects these to 220 /// be ssa-uses. 221 LogicalResult processInstruction(spirv::Opcode opcode, 222 ArrayRef<uint32_t> operands, 223 bool deferInstructions = true); 224 225 /// Method to dispatch to the specialized deserialization function for an 226 /// operation in SPIR-V dialect that is a mirror of an instruction in the 227 /// SPIR-V spec. This is auto-generated from ODS. Dispatch is handled for 228 /// all operations in SPIR-V dialect that have hasOpcode == 1. 229 LogicalResult dispatchToAutogenDeserialization(spirv::Opcode opcode, 230 ArrayRef<uint32_t> words); 231 232 /// Method to deserialize an operation in the SPIR-V dialect that is a mirror 233 /// of an instruction in the SPIR-V spec. This is auto generated if hasOpcode 234 /// == 1 and autogenSerialization == 1 in ODS. 235 template <typename OpTy> LogicalResult processOp(ArrayRef<uint32_t> words) { 236 return emitError(unknownLoc, "unsupported deserialization for ") 237 << OpTy::getOperationName() << " op"; 238 } 239 240 private: 241 /// The SPIR-V binary module. 242 ArrayRef<uint32_t> binary; 243 244 /// The current word offset into the binary module. 245 unsigned curOffset = 0; 246 247 /// MLIRContext to create SPIR-V ModuleOp into. 248 MLIRContext *context; 249 250 // TODO(antiagainst): create Location subclass for binary blob 251 Location unknownLoc; 252 253 /// The SPIR-V ModuleOp. 254 Optional<spirv::ModuleOp> module; 255 256 OpBuilder opBuilder; 257 258 /// The list of capabilities used by the module. 259 llvm::SmallSetVector<spirv::Capability, 4> capabilities; 260 261 /// The list of extensions used by the module. 262 llvm::SmallSetVector<StringRef, 2> extensions; 263 264 // Result <id> to type mapping. 265 DenseMap<uint32_t, Type> typeMap; 266 267 // Result <id> to constant attribute and type mapping. 268 /// 269 /// In the SPIR-V binary format, all constants are placed in the module and 270 /// shared by instructions at module level and in subsequent functions. But in 271 /// the SPIR-V dialect, we materialize the constant to where it's used in the 272 /// function. So when seeing a constant instruction in the binary format, we 273 /// don't immediately emit a constant op into the module, we keep its value 274 /// (and type) here. Later when it's used, we materialize the constant. 275 DenseMap<uint32_t, std::pair<Attribute, Type>> constantMap; 276 277 // Result <id> to variable mapping. 278 DenseMap<uint32_t, spirv::SpecConstantOp> specConstMap; 279 280 // Result <id> to variable mapping. 281 DenseMap<uint32_t, spirv::GlobalVariableOp> globalVariableMap; 282 283 // Result <id> to function mapping. 284 DenseMap<uint32_t, FuncOp> funcMap; 285 286 // Result <id> to value mapping. 287 DenseMap<uint32_t, Value *> valueMap; 288 289 // Result <id> to name mapping. 290 DenseMap<uint32_t, StringRef> nameMap; 291 292 // Result <id> to decorations mapping. 293 DenseMap<uint32_t, NamedAttributeList> decorations; 294 295 // Result <id> to type decorations. 296 DenseMap<uint32_t, uint32_t> typeDecorations; 297 298 // Result <id> to member decorations. 299 DenseMap<uint32_t, DenseMap<uint32_t, uint32_t>> memberDecorationMap; 300 301 // List of instructions that are processed in a defered fashion (after an 302 // initial processing of the entire binary). Some operations like 303 // OpEntryPoint, and OpExecutionMode use forward references to function 304 // <id>s. In SPIR-V dialect the corresponding operations (spv.EntryPoint and 305 // spv.ExecutionMode) need these references resolved. So these instructions 306 // are deserialized and stored for processing once the entire binary is 307 // processed. 308 SmallVector<std::pair<spirv::Opcode, ArrayRef<uint32_t>>, 4> 309 deferedInstructions; 310 }; 311 } // namespace 312 313 Deserializer::Deserializer(ArrayRef<uint32_t> binary, MLIRContext *context) 314 : binary(binary), context(context), unknownLoc(UnknownLoc::get(context)), 315 module(createModuleOp()), 316 opBuilder(module->getOperation()->getRegion(0)) {} 317 318 LogicalResult Deserializer::deserialize() { 319 if (failed(processHeader())) 320 return failure(); 321 322 spirv::Opcode opcode = spirv::Opcode::OpNop; 323 ArrayRef<uint32_t> operands; 324 auto binarySize = binary.size(); 325 while (curOffset < binarySize) { 326 // Slice the next instruction out and populate `opcode` and `operands`. 327 // Interally this also updates `curOffset`. 328 if (failed(sliceInstruction(opcode, operands))) 329 return failure(); 330 331 if (failed(processInstruction(opcode, operands))) 332 return failure(); 333 } 334 335 assert(curOffset == binarySize && 336 "deserializer should never index beyond the binary end"); 337 338 for (auto &defered : deferedInstructions) { 339 if (failed(processInstruction(defered.first, defered.second, false))) { 340 return failure(); 341 } 342 } 343 344 // Attaches the capabilities/extensions as an attribute to the module. 345 attachCapabilities(); 346 attachExtensions(); 347 348 return success(); 349 } 350 351 Optional<spirv::ModuleOp> Deserializer::collect() { return module; } 352 353 //===----------------------------------------------------------------------===// 354 // Module structure 355 //===----------------------------------------------------------------------===// 356 357 spirv::ModuleOp Deserializer::createModuleOp() { 358 Builder builder(context); 359 OperationState state(unknownLoc, spirv::ModuleOp::getOperationName()); 360 // TODO(antiagainst): use target environment to select the version 361 state.addAttribute("major_version", builder.getI32IntegerAttr(1)); 362 state.addAttribute("minor_version", builder.getI32IntegerAttr(0)); 363 spirv::ModuleOp::build(&builder, &state); 364 return cast<spirv::ModuleOp>(Operation::create(state)); 365 } 366 367 LogicalResult Deserializer::processHeader() { 368 if (binary.size() < spirv::kHeaderWordCount) 369 return emitError(unknownLoc, 370 "SPIR-V binary module must have a 5-word header"); 371 372 if (binary[0] != spirv::kMagicNumber) 373 return emitError(unknownLoc, "incorrect magic number"); 374 375 // TODO(antiagainst): generator number, bound, schema 376 curOffset = spirv::kHeaderWordCount; 377 return success(); 378 } 379 380 LogicalResult Deserializer::processCapability(ArrayRef<uint32_t> operands) { 381 if (operands.size() != 1) 382 return emitError(unknownLoc, "OpMemoryModel must have one parameter"); 383 384 auto cap = spirv::symbolizeCapability(operands[0]); 385 if (!cap) 386 return emitError(unknownLoc, "unknown capability: ") << operands[0]; 387 388 capabilities.insert(*cap); 389 return success(); 390 } 391 392 void Deserializer::attachCapabilities() { 393 if (capabilities.empty()) 394 return; 395 396 SmallVector<StringRef, 2> caps; 397 caps.reserve(capabilities.size()); 398 399 for (auto cap : capabilities) { 400 caps.push_back(spirv::stringifyCapability(cap)); 401 } 402 403 module->setAttr("capabilities", opBuilder.getStrArrayAttr(caps)); 404 } 405 406 LogicalResult Deserializer::processExtension(ArrayRef<uint32_t> operands) { 407 if (operands.empty()) { 408 return emitError( 409 unknownLoc, 410 "OpExtension must have a literal string for the extension name"); 411 } 412 413 unsigned wordIndex = 0; 414 StringRef extName = decodeStringLiteral(operands, wordIndex); 415 if (wordIndex != operands.size()) { 416 return emitError(unknownLoc, 417 "unexpected trailing words in OpExtension instruction"); 418 } 419 420 extensions.insert(extName); 421 return success(); 422 } 423 424 void Deserializer::attachExtensions() { 425 if (extensions.empty()) 426 return; 427 428 module->setAttr("extensions", 429 opBuilder.getStrArrayAttr(extensions.getArrayRef())); 430 } 431 432 LogicalResult Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) { 433 if (operands.size() != 2) 434 return emitError(unknownLoc, "OpMemoryModel must have two operands"); 435 436 module->setAttr( 437 "addressing_model", 438 opBuilder.getI32IntegerAttr(llvm::bit_cast<int32_t>(operands.front()))); 439 module->setAttr( 440 "memory_model", 441 opBuilder.getI32IntegerAttr(llvm::bit_cast<int32_t>(operands.back()))); 442 443 return success(); 444 } 445 446 LogicalResult Deserializer::processDecoration(ArrayRef<uint32_t> words) { 447 // TODO : This function should also be auto-generated. For now, since only a 448 // few decorations are processed/handled in a meaningful manner, going with a 449 // manual implementation. 450 if (words.size() < 2) { 451 return emitError( 452 unknownLoc, "OpDecorate must have at least result <id> and Decoration"); 453 } 454 auto decorationName = 455 stringifyDecoration(static_cast<spirv::Decoration>(words[1])); 456 if (decorationName.empty()) { 457 return emitError(unknownLoc, "invalid Decoration code : ") << words[1]; 458 } 459 auto attrName = convertToSnakeCase(decorationName); 460 switch (static_cast<spirv::Decoration>(words[1])) { 461 case spirv::Decoration::DescriptorSet: 462 case spirv::Decoration::Binding: 463 if (words.size() != 3) { 464 return emitError(unknownLoc, "OpDecorate with ") 465 << decorationName << " needs a single integer literal"; 466 } 467 decorations[words[0]].set( 468 opBuilder.getIdentifier(attrName), 469 opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2]))); 470 break; 471 case spirv::Decoration::BuiltIn: 472 if (words.size() != 3) { 473 return emitError(unknownLoc, "OpDecorate with ") 474 << decorationName << " needs a single integer literal"; 475 } 476 decorations[words[0]].set(opBuilder.getIdentifier(attrName), 477 opBuilder.getStringAttr(stringifyBuiltIn( 478 static_cast<spirv::BuiltIn>(words[2])))); 479 break; 480 case spirv::Decoration::ArrayStride: 481 if (words.size() != 3) { 482 return emitError(unknownLoc, "OpDecorate with ") 483 << decorationName << " needs a single integer literal"; 484 } 485 typeDecorations[words[0]] = static_cast<uint32_t>(words[2]); 486 break; 487 case spirv::Decoration::Block: 488 if (words.size() != 2) { 489 return emitError(unknownLoc, "OpDecoration with ") 490 << decorationName << "needs a single target <id>"; 491 } 492 // Block decoration does not affect spv.struct type. 493 break; 494 default: 495 return emitError(unknownLoc, "unhandled Decoration : '") << decorationName; 496 } 497 return success(); 498 } 499 500 LogicalResult Deserializer::processMemberDecoration(ArrayRef<uint32_t> words) { 501 // The binary layout of OpMemberDecorate is different comparing to OpDecorate 502 if (words.size() != 4) { 503 return emitError(unknownLoc, "OpMemberDecorate must have 4 operands"); 504 } 505 506 switch (static_cast<spirv::Decoration>(words[2])) { 507 case spirv::Decoration::Offset: 508 memberDecorationMap[words[0]][words[1]] = words[3]; 509 break; 510 default: 511 return emitError(unknownLoc, "unhandled OpMemberDecoration case: ") 512 << words[2]; 513 } 514 return success(); 515 } 516 517 LogicalResult Deserializer::processFunction(ArrayRef<uint32_t> operands) { 518 // Get the result type 519 if (operands.size() != 4) { 520 return emitError(unknownLoc, "OpFunction must have 4 parameters"); 521 } 522 Type resultType = getType(operands[0]); 523 if (!resultType) { 524 return emitError(unknownLoc, "undefined result type from <id> ") 525 << operands[0]; 526 } 527 if (funcMap.count(operands[1])) { 528 return emitError(unknownLoc, "duplicate function definition/declaration"); 529 } 530 auto functionControl = spirv::symbolizeFunctionControl(operands[2]); 531 if (!functionControl) { 532 return emitError(unknownLoc, "unknown Function Control: ") << operands[2]; 533 } 534 if (functionControl.getValue() != spirv::FunctionControl::None) { 535 /// TODO : Handle different function controls 536 return emitError(unknownLoc, "unhandled Function Control: '") 537 << spirv::stringifyFunctionControl(functionControl.getValue()) 538 << "'"; 539 } 540 Type fnType = getType(operands[3]); 541 if (!fnType || !fnType.isa<FunctionType>()) { 542 return emitError(unknownLoc, "unknown function type from <id> ") 543 << operands[3]; 544 } 545 auto functionType = fnType.cast<FunctionType>(); 546 if ((isVoidType(resultType) && functionType.getNumResults() != 0) || 547 (functionType.getNumResults() == 1 && 548 functionType.getResult(0) != resultType)) { 549 return emitError(unknownLoc, "mismatch in function type ") 550 << functionType << " and return type " << resultType << " specified"; 551 } 552 553 std::string fnName = nameMap.lookup(operands[1]).str(); 554 if (fnName.empty()) { 555 fnName = "spirv_fn_" + std::to_string(operands[2]); 556 } 557 auto funcOp = opBuilder.create<FuncOp>(unknownLoc, fnName, functionType, 558 ArrayRef<NamedAttribute>()); 559 funcMap[operands[1]] = funcOp; 560 funcOp.addEntryBlock(); 561 562 // Parse the op argument instructions 563 if (functionType.getNumInputs()) { 564 for (size_t i = 0, e = functionType.getNumInputs(); i != e; ++i) { 565 auto argType = functionType.getInput(i); 566 spirv::Opcode opcode = spirv::Opcode::OpNop; 567 ArrayRef<uint32_t> operands; 568 if (failed(sliceInstruction(opcode, operands, 569 spirv::Opcode::OpFunctionParameter))) { 570 return failure(); 571 } 572 if (opcode != spirv::Opcode::OpFunctionParameter) { 573 return emitError( 574 unknownLoc, 575 "missing OpFunctionParameter instruction for argument ") 576 << i; 577 } 578 if (operands.size() != 2) { 579 return emitError( 580 unknownLoc, 581 "expected result type and result <id> for OpFunctionParameter"); 582 } 583 auto argDefinedType = getType(operands[0]); 584 if (!argDefinedType || argDefinedType != argType) { 585 return emitError(unknownLoc, 586 "mismatch in argument type between function type " 587 "definition ") 588 << functionType << " and argument type definition " 589 << argDefinedType << " at argument " << i; 590 } 591 if (getValue(operands[1])) { 592 return emitError(unknownLoc, "duplicate definition of result <id> '") 593 << operands[1]; 594 } 595 auto argValue = funcOp.getArgument(i); 596 valueMap[operands[1]] = argValue; 597 } 598 } 599 600 // Create a new builder for building the body. 601 OpBuilder funcBody(funcOp.getBody()); 602 std::swap(funcBody, opBuilder); 603 604 // Make sure the first basic block, if exists, starts with an OpLabel 605 // instruction. 606 if (auto nextOpcode = peekOpcode()) { 607 if (*nextOpcode != spirv::Opcode::OpFunctionEnd && 608 *nextOpcode != spirv::Opcode::OpLabel) 609 return emitError(unknownLoc, "a basic block must start with OpLabel"); 610 } 611 612 spirv::Opcode opcode = spirv::Opcode::OpNop; 613 ArrayRef<uint32_t> instOperands; 614 while (succeeded(sliceInstruction(opcode, instOperands, 615 spirv::Opcode::OpFunctionEnd)) && 616 opcode != spirv::Opcode::OpFunctionEnd) { 617 if (failed(processInstruction(opcode, instOperands))) { 618 return failure(); 619 } 620 } 621 if (opcode != spirv::Opcode::OpFunctionEnd) { 622 return failure(); 623 } 624 625 // Process OpFunctionEnd. 626 if (!instOperands.empty()) { 627 return emitError(unknownLoc, "unexpected operands for OpFunctionEnd"); 628 } 629 630 std::swap(funcBody, opBuilder); 631 return success(); 632 } 633 634 Optional<std::pair<Attribute, Type>> Deserializer::getConstant(uint32_t id) { 635 auto constIt = constantMap.find(id); 636 if (constIt == constantMap.end()) 637 return llvm::None; 638 return constIt->getSecond(); 639 } 640 641 std::string Deserializer::getSpecConstantSymbol(uint32_t id) { 642 auto constName = nameMap.lookup(id).str(); 643 if (constName.empty()) { 644 constName = "spirv_spec_const_" + std::to_string(id); 645 } 646 return constName; 647 } 648 649 LogicalResult Deserializer::processGlobalVariable(ArrayRef<uint32_t> operands) { 650 unsigned wordIndex = 0; 651 if (operands.size() < 3) { 652 return emitError( 653 unknownLoc, 654 "OpVariable needs at least 3 operands, type, <id> and storage class"); 655 } 656 657 // Result Type. 658 auto type = getType(operands[wordIndex]); 659 if (!type) { 660 return emitError(unknownLoc, "unknown result type <id> : ") 661 << operands[wordIndex]; 662 } 663 auto ptrType = type.dyn_cast<spirv::PointerType>(); 664 if (!ptrType) { 665 return emitError(unknownLoc, 666 "expected a result type <id> to be a spv.ptr, found : ") 667 << type; 668 } 669 wordIndex++; 670 671 // Result <id>. 672 auto variableID = operands[wordIndex]; 673 auto variableName = nameMap.lookup(variableID).str(); 674 if (variableName.empty()) { 675 variableName = "spirv_var_" + std::to_string(variableID); 676 } 677 wordIndex++; 678 679 // Storage class. 680 auto storageClass = static_cast<spirv::StorageClass>(operands[wordIndex]); 681 if (ptrType.getStorageClass() != storageClass) { 682 return emitError(unknownLoc, "mismatch in storage class of pointer type ") 683 << type << " and that specified in OpVariable instruction : " 684 << stringifyStorageClass(storageClass); 685 } 686 wordIndex++; 687 688 // Initializer. 689 SymbolRefAttr initializer = nullptr; 690 if (wordIndex < operands.size()) { 691 auto initializerOp = getGlobalVariable(operands[wordIndex]); 692 if (!initializerOp) { 693 return emitError(unknownLoc, "unknown <id> ") 694 << operands[wordIndex] << "used as initializer"; 695 } 696 wordIndex++; 697 initializer = opBuilder.getSymbolRefAttr(initializerOp.getOperation()); 698 } 699 if (wordIndex != operands.size()) { 700 return emitError(unknownLoc, 701 "found more operands than expected when deserializing " 702 "OpVariable instruction, only ") 703 << wordIndex << " of " << operands.size() << " processed"; 704 } 705 auto varOp = opBuilder.create<spirv::GlobalVariableOp>( 706 unknownLoc, opBuilder.getTypeAttr(type), 707 opBuilder.getStringAttr(variableName), initializer); 708 709 // Decorations. 710 if (decorations.count(variableID)) { 711 for (auto attr : decorations[variableID].getAttrs()) { 712 varOp.setAttr(attr.first, attr.second); 713 } 714 } 715 globalVariableMap[variableID] = varOp; 716 return success(); 717 } 718 719 LogicalResult Deserializer::processName(ArrayRef<uint32_t> operands) { 720 if (operands.size() < 2) { 721 return emitError(unknownLoc, "OpName needs at least 2 operands"); 722 } 723 if (!nameMap.lookup(operands[0]).empty()) { 724 return emitError(unknownLoc, "duplicate name found for result <id> ") 725 << operands[0]; 726 } 727 unsigned wordIndex = 1; 728 StringRef name = decodeStringLiteral(operands, wordIndex); 729 if (wordIndex != operands.size()) { 730 return emitError(unknownLoc, 731 "unexpected trailing words in OpName instruction"); 732 } 733 nameMap[operands[0]] = name; 734 return success(); 735 } 736 737 //===----------------------------------------------------------------------===// 738 // Type 739 //===----------------------------------------------------------------------===// 740 741 LogicalResult Deserializer::processType(spirv::Opcode opcode, 742 ArrayRef<uint32_t> operands) { 743 if (operands.empty()) { 744 return emitError(unknownLoc, "type instruction with opcode ") 745 << spirv::stringifyOpcode(opcode) << " needs at least one <id>"; 746 } 747 748 /// TODO: Types might be forward declared in some instructions and need to be 749 /// handled appropriately. 750 if (typeMap.count(operands[0])) { 751 return emitError(unknownLoc, "duplicate definition for result <id> ") 752 << operands[0]; 753 } 754 755 switch (opcode) { 756 case spirv::Opcode::OpTypeVoid: 757 if (operands.size() != 1) { 758 return emitError(unknownLoc, "OpTypeVoid must have no parameters"); 759 } 760 typeMap[operands[0]] = opBuilder.getNoneType(); 761 break; 762 case spirv::Opcode::OpTypeBool: 763 if (operands.size() != 1) { 764 return emitError(unknownLoc, "OpTypeBool must have no parameters"); 765 } 766 typeMap[operands[0]] = opBuilder.getI1Type(); 767 break; 768 case spirv::Opcode::OpTypeInt: 769 if (operands.size() != 3) { 770 return emitError( 771 unknownLoc, "OpTypeInt must have bitwidth and signedness parameters"); 772 } 773 if (operands[2] == 0) { 774 return emitError(unknownLoc, "unhandled unsigned OpTypeInt"); 775 } 776 typeMap[operands[0]] = opBuilder.getIntegerType(operands[1]); 777 break; 778 case spirv::Opcode::OpTypeFloat: { 779 if (operands.size() != 2) { 780 return emitError(unknownLoc, "OpTypeFloat must have bitwidth parameter"); 781 } 782 Type floatTy; 783 switch (operands[1]) { 784 case 16: 785 floatTy = opBuilder.getF16Type(); 786 break; 787 case 32: 788 floatTy = opBuilder.getF32Type(); 789 break; 790 case 64: 791 floatTy = opBuilder.getF64Type(); 792 break; 793 default: 794 return emitError(unknownLoc, "unsupported OpTypeFloat bitwdith: ") 795 << operands[1]; 796 } 797 typeMap[operands[0]] = floatTy; 798 } break; 799 case spirv::Opcode::OpTypeVector: { 800 if (operands.size() != 3) { 801 return emitError( 802 unknownLoc, 803 "OpTypeVector must have element type and count parameters"); 804 } 805 Type elementTy = getType(operands[1]); 806 if (!elementTy) { 807 return emitError(unknownLoc, "OpTypeVector references undefined <id> ") 808 << operands[1]; 809 } 810 typeMap[operands[0]] = opBuilder.getVectorType({operands[2]}, elementTy); 811 } break; 812 case spirv::Opcode::OpTypePointer: { 813 if (operands.size() != 3) { 814 return emitError(unknownLoc, "OpTypePointer must have two parameters"); 815 } 816 auto pointeeType = getType(operands[2]); 817 if (!pointeeType) { 818 return emitError(unknownLoc, "unknown OpTypePointer pointee type <id> ") 819 << operands[2]; 820 } 821 auto storageClass = static_cast<spirv::StorageClass>(operands[1]); 822 typeMap[operands[0]] = spirv::PointerType::get(pointeeType, storageClass); 823 } break; 824 case spirv::Opcode::OpTypeArray: 825 return processArrayType(operands); 826 case spirv::Opcode::OpTypeFunction: 827 return processFunctionType(operands); 828 case spirv::Opcode::OpTypeStruct: 829 return processStructType(operands); 830 default: 831 return emitError(unknownLoc, "unhandled type instruction"); 832 } 833 return success(); 834 } 835 836 LogicalResult Deserializer::processArrayType(ArrayRef<uint32_t> operands) { 837 if (operands.size() != 3) { 838 return emitError(unknownLoc, 839 "OpTypeArray must have element type and count parameters"); 840 } 841 842 Type elementTy = getType(operands[1]); 843 if (!elementTy) { 844 return emitError(unknownLoc, "OpTypeArray references undefined <id> ") 845 << operands[1]; 846 } 847 848 unsigned count = 0; 849 // TODO(antiagainst): The count can also come frome a specialization constant. 850 auto countInfo = getConstant(operands[2]); 851 if (!countInfo) { 852 return emitError(unknownLoc, "OpTypeArray count <id> ") 853 << operands[2] << "can only come from normal constant right now"; 854 } 855 856 if (auto intVal = countInfo->first.dyn_cast<IntegerAttr>()) { 857 count = intVal.getInt(); 858 } else { 859 return emitError(unknownLoc, "OpTypeArray count must come from a " 860 "scalar integer constant instruction"); 861 } 862 863 typeMap[operands[0]] = spirv::ArrayType::get( 864 elementTy, count, typeDecorations.lookup(operands[0])); 865 return success(); 866 } 867 868 LogicalResult Deserializer::processFunctionType(ArrayRef<uint32_t> operands) { 869 assert(!operands.empty() && "No operands for processing function type"); 870 if (operands.size() == 1) { 871 return emitError(unknownLoc, "missing return type for OpTypeFunction"); 872 } 873 auto returnType = getType(operands[1]); 874 if (!returnType) { 875 return emitError(unknownLoc, "unknown return type in OpTypeFunction"); 876 } 877 SmallVector<Type, 1> argTypes; 878 for (size_t i = 2, e = operands.size(); i < e; ++i) { 879 auto ty = getType(operands[i]); 880 if (!ty) { 881 return emitError(unknownLoc, "unknown argument type in OpTypeFunction"); 882 } 883 argTypes.push_back(ty); 884 } 885 ArrayRef<Type> returnTypes; 886 if (!isVoidType(returnType)) { 887 returnTypes = llvm::makeArrayRef(returnType); 888 } 889 typeMap[operands[0]] = FunctionType::get(argTypes, returnTypes, context); 890 return success(); 891 } 892 893 LogicalResult Deserializer::processStructType(ArrayRef<uint32_t> operands) { 894 // TODO(ravishankarm) : Regarding to the spec spv.struct must support zero 895 // amount of members. 896 if (operands.size() < 2) { 897 return emitError(unknownLoc, "OpTypeStruct must have at least 2 operand"); 898 } 899 900 SmallVector<Type, 0> memberTypes; 901 for (auto op : llvm::drop_begin(operands, 1)) { 902 Type memberType = getType(op); 903 if (!memberType) { 904 return emitError(unknownLoc, "OpTypeStruct references undefined <id> ") 905 << op; 906 } 907 memberTypes.push_back(memberType); 908 } 909 910 SmallVector<spirv::StructType::LayoutInfo, 0> layoutInfo; 911 // Check for layoutinfo 912 auto memberDecorationIt = memberDecorationMap.find(operands[0]); 913 if (memberDecorationIt != memberDecorationMap.end()) { 914 // Each member must have an offset 915 const auto &offsetDecorationMap = memberDecorationIt->second; 916 auto offsetDecorationMapEnd = offsetDecorationMap.end(); 917 for (auto memberIndex : llvm::seq<uint32_t>(0, memberTypes.size())) { 918 // Check that specific member has an offset 919 auto offsetIt = offsetDecorationMap.find(memberIndex); 920 if (offsetIt == offsetDecorationMapEnd) { 921 return emitError(unknownLoc, "OpTypeStruct with <id> ") 922 << operands[0] << " must have an offset for " << memberIndex 923 << "-th member"; 924 } 925 layoutInfo.push_back( 926 static_cast<spirv::StructType::LayoutInfo>(offsetIt->second)); 927 } 928 } 929 typeMap[operands[0]] = spirv::StructType::get(memberTypes, layoutInfo); 930 return success(); 931 } 932 933 //===----------------------------------------------------------------------===// 934 // Constant 935 //===----------------------------------------------------------------------===// 936 937 LogicalResult Deserializer::processConstant(ArrayRef<uint32_t> operands, 938 bool isSpec) { 939 StringRef opname = isSpec ? "OpSpecConstant" : "OpConstant"; 940 941 if (operands.size() < 2) { 942 return emitError(unknownLoc) 943 << opname << " must have type <id> and result <id>"; 944 } 945 if (operands.size() < 3) { 946 return emitError(unknownLoc) 947 << opname << " must have at least 1 more parameter"; 948 } 949 950 Type resultType = getType(operands[0]); 951 if (!resultType) { 952 return emitError(unknownLoc, "undefined result type from <id> ") 953 << operands[0]; 954 } 955 956 auto checkOperandSizeForBitwidth = [&](unsigned bitwidth) -> LogicalResult { 957 if (bitwidth == 64) { 958 if (operands.size() == 4) { 959 return success(); 960 } 961 return emitError(unknownLoc) 962 << opname << " should have 2 parameters for 64-bit values"; 963 } 964 if (bitwidth <= 32) { 965 if (operands.size() == 3) { 966 return success(); 967 } 968 969 return emitError(unknownLoc) 970 << opname 971 << " should have 1 parameter for values with no more than 32 bits"; 972 } 973 return emitError(unknownLoc, "unsupported OpConstant bitwidth: ") 974 << bitwidth; 975 }; 976 977 auto resultID = operands[1]; 978 979 if (auto intType = resultType.dyn_cast<IntegerType>()) { 980 auto bitwidth = intType.getWidth(); 981 if (failed(checkOperandSizeForBitwidth(bitwidth))) { 982 return failure(); 983 } 984 985 APInt value; 986 if (bitwidth == 64) { 987 // 64-bit integers are represented with two SPIR-V words. According to 988 // SPIR-V spec: "When the type’s bit width is larger than one word, the 989 // literal’s low-order words appear first." 990 struct DoubleWord { 991 uint32_t word1; 992 uint32_t word2; 993 } words = {operands[2], operands[3]}; 994 value = APInt(64, llvm::bit_cast<uint64_t>(words), /*isSigned=*/true); 995 } else if (bitwidth <= 32) { 996 value = APInt(bitwidth, operands[2], /*isSigned=*/true); 997 } 998 999 auto attr = opBuilder.getIntegerAttr(intType, value); 1000 1001 if (isSpec) { 1002 auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID)); 1003 auto op = 1004 opBuilder.create<spirv::SpecConstantOp>(unknownLoc, symName, attr); 1005 specConstMap[resultID] = op; 1006 } else { 1007 // For normal constants, we just record the attribute (and its type) for 1008 // later materialization at use sites. 1009 constantMap.try_emplace(resultID, attr, intType); 1010 } 1011 1012 return success(); 1013 } 1014 1015 if (auto floatType = resultType.dyn_cast<FloatType>()) { 1016 auto bitwidth = floatType.getWidth(); 1017 if (failed(checkOperandSizeForBitwidth(bitwidth))) { 1018 return failure(); 1019 } 1020 1021 APFloat value(0.f); 1022 if (floatType.isF64()) { 1023 // Double values are represented with two SPIR-V words. According to 1024 // SPIR-V spec: "When the type’s bit width is larger than one word, the 1025 // literal’s low-order words appear first." 1026 struct DoubleWord { 1027 uint32_t word1; 1028 uint32_t word2; 1029 } words = {operands[2], operands[3]}; 1030 value = APFloat(llvm::bit_cast<double>(words)); 1031 } else if (floatType.isF32()) { 1032 value = APFloat(llvm::bit_cast<float>(operands[2])); 1033 } else if (floatType.isF16()) { 1034 APInt data(16, operands[2]); 1035 value = APFloat(APFloat::IEEEhalf(), data); 1036 } 1037 1038 auto attr = opBuilder.getFloatAttr(floatType, value); 1039 if (isSpec) { 1040 auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID)); 1041 auto op = 1042 opBuilder.create<spirv::SpecConstantOp>(unknownLoc, symName, attr); 1043 specConstMap[resultID] = op; 1044 } else { 1045 // For normal constants, we just record the attribute (and its type) for 1046 // later materialization at use sites. 1047 constantMap.try_emplace(resultID, attr, floatType); 1048 } 1049 1050 return success(); 1051 } 1052 1053 return emitError(unknownLoc, "OpConstant can only generate values of " 1054 "scalar integer or floating-point type"); 1055 } 1056 1057 LogicalResult Deserializer::processConstantBool(bool isTrue, 1058 ArrayRef<uint32_t> operands, 1059 bool isSpec) { 1060 if (operands.size() != 2) { 1061 return emitError(unknownLoc, "Op") 1062 << (isSpec ? "Spec" : "") << "Constant" 1063 << (isTrue ? "True" : "False") 1064 << " must have type <id> and result <id>"; 1065 } 1066 1067 auto attr = opBuilder.getBoolAttr(isTrue); 1068 auto resultID = operands[1]; 1069 if (isSpec) { 1070 auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID)); 1071 auto op = 1072 opBuilder.create<spirv::SpecConstantOp>(unknownLoc, symName, attr); 1073 specConstMap[resultID] = op; 1074 } else { 1075 // For normal constants, we just record the attribute (and its type) for 1076 // later materialization at use sites. 1077 constantMap.try_emplace(resultID, attr, opBuilder.getI1Type()); 1078 } 1079 1080 return success(); 1081 } 1082 1083 LogicalResult 1084 Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) { 1085 if (operands.size() < 2) { 1086 return emitError(unknownLoc, 1087 "OpConstantComposite must have type <id> and result <id>"); 1088 } 1089 if (operands.size() < 3) { 1090 return emitError(unknownLoc, 1091 "OpConstantComposite must have at least 1 parameter"); 1092 } 1093 1094 Type resultType = getType(operands[0]); 1095 if (!resultType) { 1096 return emitError(unknownLoc, "undefined result type from <id> ") 1097 << operands[0]; 1098 } 1099 1100 SmallVector<Attribute, 4> elements; 1101 elements.reserve(operands.size() - 2); 1102 for (unsigned i = 2, e = operands.size(); i < e; ++i) { 1103 auto elementInfo = getConstant(operands[i]); 1104 if (!elementInfo) { 1105 return emitError(unknownLoc, "OpConstantComposite component <id> ") 1106 << operands[i] << " must come from a normal constant"; 1107 } 1108 elements.push_back(elementInfo->first); 1109 } 1110 1111 auto resultID = operands[1]; 1112 if (auto vectorType = resultType.dyn_cast<VectorType>()) { 1113 auto attr = opBuilder.getDenseElementsAttr(vectorType, elements); 1114 // For normal constants, we just record the attribute (and its type) for 1115 // later materialization at use sites. 1116 constantMap.try_emplace(resultID, attr, resultType); 1117 } else if (auto arrayType = resultType.dyn_cast<spirv::ArrayType>()) { 1118 auto attr = opBuilder.getArrayAttr(elements); 1119 constantMap.try_emplace(resultID, attr, resultType); 1120 } else { 1121 return emitError(unknownLoc, "unsupported OpConstantComposite type: ") 1122 << resultType; 1123 } 1124 1125 return success(); 1126 } 1127 1128 LogicalResult Deserializer::processConstantNull(ArrayRef<uint32_t> operands) { 1129 if (operands.size() != 2) { 1130 return emitError(unknownLoc, 1131 "OpConstantNull must have type <id> and result <id>"); 1132 } 1133 1134 Type resultType = getType(operands[0]); 1135 if (!resultType) { 1136 return emitError(unknownLoc, "undefined result type from <id> ") 1137 << operands[0]; 1138 } 1139 1140 auto resultID = operands[1]; 1141 if (resultType.isa<IntegerType>() || resultType.isa<FloatType>() || 1142 resultType.isa<VectorType>()) { 1143 auto attr = opBuilder.getZeroAttr(resultType); 1144 // For normal constants, we just record the attribute (and its type) for 1145 // later materialization at use sites. 1146 constantMap.try_emplace(resultID, attr, resultType); 1147 return success(); 1148 } 1149 1150 return emitError(unknownLoc, "unsupported OpConstantNull type: ") 1151 << resultType; 1152 } 1153 1154 //===----------------------------------------------------------------------===// 1155 // Control flow 1156 //===----------------------------------------------------------------------===// 1157 1158 LogicalResult Deserializer::processLabel(ArrayRef<uint32_t> operands) { 1159 if (operands.size() != 1) { 1160 return emitError(unknownLoc, "OpLabel should only have result <id>"); 1161 } 1162 // TODO(antiagainst): support basic blocks and control flow properly. 1163 return success(); 1164 } 1165 1166 //===----------------------------------------------------------------------===// 1167 // Instruction 1168 //===----------------------------------------------------------------------===// 1169 1170 Value *Deserializer::getValue(uint32_t id) { 1171 if (auto constInfo = getConstant(id)) { 1172 // Materialize a `spv.constant` op at every use site. 1173 return opBuilder.create<spirv::ConstantOp>(unknownLoc, constInfo->second, 1174 constInfo->first); 1175 } 1176 if (auto varOp = getGlobalVariable(id)) { 1177 auto addressOfOp = opBuilder.create<spirv::AddressOfOp>( 1178 unknownLoc, varOp.type(), 1179 opBuilder.getSymbolRefAttr(varOp.getOperation())); 1180 return addressOfOp.pointer(); 1181 } 1182 if (auto constOp = getSpecConstant(id)) { 1183 auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>( 1184 unknownLoc, constOp.default_value().getType(), 1185 opBuilder.getSymbolRefAttr(constOp.getOperation())); 1186 return referenceOfOp.reference(); 1187 } 1188 return valueMap.lookup(id); 1189 } 1190 1191 LogicalResult 1192 Deserializer::sliceInstruction(spirv::Opcode &opcode, 1193 ArrayRef<uint32_t> &operands, 1194 Optional<spirv::Opcode> expectedOpcode) { 1195 auto binarySize = binary.size(); 1196 if (curOffset >= binarySize) { 1197 return emitError(unknownLoc, "expected ") 1198 << (expectedOpcode ? spirv::stringifyOpcode(*expectedOpcode) 1199 : "more") 1200 << " instruction"; 1201 } 1202 1203 // For each instruction, get its word count from the first word to slice it 1204 // from the stream properly, and then dispatch to the instruction handler. 1205 1206 uint32_t wordCount = binary[curOffset] >> 16; 1207 1208 if (wordCount == 0) 1209 return emitError(unknownLoc, "word count cannot be zero"); 1210 1211 uint32_t nextOffset = curOffset + wordCount; 1212 if (nextOffset > binarySize) 1213 return emitError(unknownLoc, "insufficient words for the last instruction"); 1214 1215 opcode = extractOpcode(binary[curOffset]); 1216 operands = binary.slice(curOffset + 1, wordCount - 1); 1217 curOffset = nextOffset; 1218 return success(); 1219 } 1220 1221 Optional<spirv::Opcode> Deserializer::peekOpcode() { 1222 if (curOffset >= binary.size()) 1223 return llvm::None; 1224 return extractOpcode(binary[curOffset]); 1225 } 1226 1227 LogicalResult Deserializer::processInstruction(spirv::Opcode opcode, 1228 ArrayRef<uint32_t> operands, 1229 bool deferInstructions) { 1230 // First dispatch all the instructions whose opcode does not correspond to 1231 // those that have a direct mirror in the SPIR-V dialect 1232 switch (opcode) { 1233 case spirv::Opcode::OpCapability: 1234 return processCapability(operands); 1235 case spirv::Opcode::OpExtension: 1236 return processExtension(operands); 1237 case spirv::Opcode::OpMemoryModel: 1238 return processMemoryModel(operands); 1239 case spirv::Opcode::OpEntryPoint: 1240 case spirv::Opcode::OpExecutionMode: 1241 if (deferInstructions) { 1242 deferedInstructions.emplace_back(opcode, operands); 1243 return success(); 1244 } 1245 break; 1246 case spirv::Opcode::OpVariable: 1247 if (isa<spirv::ModuleOp>(opBuilder.getBlock()->getParentOp())) { 1248 return processGlobalVariable(operands); 1249 } 1250 break; 1251 case spirv::Opcode::OpName: 1252 return processName(operands); 1253 case spirv::Opcode::OpTypeVoid: 1254 case spirv::Opcode::OpTypeBool: 1255 case spirv::Opcode::OpTypeInt: 1256 case spirv::Opcode::OpTypeFloat: 1257 case spirv::Opcode::OpTypeVector: 1258 case spirv::Opcode::OpTypeArray: 1259 case spirv::Opcode::OpTypeFunction: 1260 case spirv::Opcode::OpTypeStruct: 1261 case spirv::Opcode::OpTypePointer: 1262 return processType(opcode, operands); 1263 case spirv::Opcode::OpConstant: 1264 return processConstant(operands, /*isSpec=*/false); 1265 case spirv::Opcode::OpSpecConstant: 1266 return processConstant(operands, /*isSpec=*/true); 1267 case spirv::Opcode::OpConstantComposite: 1268 return processConstantComposite(operands); 1269 case spirv::Opcode::OpConstantTrue: 1270 return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/false); 1271 case spirv::Opcode::OpSpecConstantTrue: 1272 return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/true); 1273 case spirv::Opcode::OpConstantFalse: 1274 return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/false); 1275 case spirv::Opcode::OpSpecConstantFalse: 1276 return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/true); 1277 case spirv::Opcode::OpConstantNull: 1278 return processConstantNull(operands); 1279 case spirv::Opcode::OpDecorate: 1280 return processDecoration(operands); 1281 case spirv::Opcode::OpMemberDecorate: 1282 return processMemberDecoration(operands); 1283 case spirv::Opcode::OpFunction: 1284 return processFunction(operands); 1285 case spirv::Opcode::OpLabel: 1286 return processLabel(operands); 1287 default: 1288 break; 1289 } 1290 return dispatchToAutogenDeserialization(opcode, operands); 1291 } 1292 1293 namespace { 1294 1295 template <> 1296 LogicalResult 1297 Deserializer::processOp<spirv::EntryPointOp>(ArrayRef<uint32_t> words) { 1298 unsigned wordIndex = 0; 1299 if (wordIndex >= words.size()) { 1300 return emitError(unknownLoc, 1301 "missing Execution Model specification in OpEntryPoint"); 1302 } 1303 auto exec_model = opBuilder.getI32IntegerAttr(words[wordIndex++]); 1304 if (wordIndex >= words.size()) { 1305 return emitError(unknownLoc, "missing <id> in OpEntryPoint"); 1306 } 1307 // Get the function <id> 1308 auto fnID = words[wordIndex++]; 1309 // Get the function name 1310 auto fnName = decodeStringLiteral(words, wordIndex); 1311 // Verify that the function <id> matches the fnName 1312 auto parsedFunc = getFunction(fnID); 1313 if (!parsedFunc) { 1314 return emitError(unknownLoc, "no function matching <id> ") << fnID; 1315 } 1316 if (parsedFunc.getName() != fnName) { 1317 return emitError(unknownLoc, "function name mismatch between OpEntryPoint " 1318 "and OpFunction with <id> ") 1319 << fnID << ": " << fnName << " vs. " << parsedFunc.getName(); 1320 } 1321 SmallVector<Attribute, 4> interface; 1322 while (wordIndex < words.size()) { 1323 auto arg = getGlobalVariable(words[wordIndex]); 1324 if (!arg) { 1325 return emitError(unknownLoc, "undefined result <id> ") 1326 << words[wordIndex] << " while decoding OpEntryPoint"; 1327 } 1328 interface.push_back(opBuilder.getSymbolRefAttr(arg.getOperation())); 1329 wordIndex++; 1330 } 1331 opBuilder.create<spirv::EntryPointOp>(unknownLoc, exec_model, 1332 opBuilder.getSymbolRefAttr(fnName), 1333 opBuilder.getArrayAttr(interface)); 1334 return success(); 1335 } 1336 1337 template <> 1338 LogicalResult 1339 Deserializer::processOp<spirv::ExecutionModeOp>(ArrayRef<uint32_t> words) { 1340 unsigned wordIndex = 0; 1341 if (wordIndex >= words.size()) { 1342 return emitError(unknownLoc, 1343 "missing function result <id> in OpExecutionMode"); 1344 } 1345 // Get the function <id> to get the name of the function 1346 auto fnID = words[wordIndex++]; 1347 auto fn = getFunction(fnID); 1348 if (!fn) { 1349 return emitError(unknownLoc, "no function matching <id> ") << fnID; 1350 } 1351 // Get the Execution mode 1352 if (wordIndex >= words.size()) { 1353 return emitError(unknownLoc, "missing Execution Mode in OpExecutionMode"); 1354 } 1355 auto execMode = opBuilder.getI32IntegerAttr(words[wordIndex++]); 1356 1357 // Get the values 1358 SmallVector<Attribute, 4> attrListElems; 1359 while (wordIndex < words.size()) { 1360 attrListElems.push_back(opBuilder.getI32IntegerAttr(words[wordIndex++])); 1361 } 1362 auto values = opBuilder.getArrayAttr(attrListElems); 1363 opBuilder.create<spirv::ExecutionModeOp>( 1364 unknownLoc, opBuilder.getSymbolRefAttr(fn.getName()), execMode, values); 1365 return success(); 1366 } 1367 1368 // Pull in auto-generated Deserializer::dispatchToAutogenDeserialization() and 1369 // various Deserializer::processOp<...>() specializations. 1370 #define GET_DESERIALIZATION_FNS 1371 #include "mlir/Dialect/SPIRV/SPIRVSerialization.inc" 1372 } // namespace 1373 1374 Optional<spirv::ModuleOp> spirv::deserialize(ArrayRef<uint32_t> binary, 1375 MLIRContext *context) { 1376 Deserializer deserializer(binary, context); 1377 1378 if (failed(deserializer.deserialize())) 1379 return llvm::None; 1380 1381 return deserializer.collect(); 1382 }