github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp (about) 1 //===- SPIRVOps.cpp - MLIR SPIR-V operations ------------------------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 // 18 // This file defines the operations in the SPIR-V dialect. 19 // 20 //===----------------------------------------------------------------------===// 21 22 #include "mlir/Dialect/SPIRV/SPIRVOps.h" 23 24 #include "mlir/Dialect/SPIRV/SPIRVDialect.h" 25 #include "mlir/Dialect/SPIRV/SPIRVTypes.h" 26 #include "mlir/IR/Builders.h" 27 #include "mlir/IR/Function.h" 28 #include "mlir/IR/OpImplementation.h" 29 #include "mlir/IR/StandardTypes.h" 30 #include "mlir/Support/StringExtras.h" 31 32 using namespace mlir; 33 34 // TODO(antiagainst): generate these strings using ODS. 35 static constexpr const char kAlignmentAttrName[] = "alignment"; 36 static constexpr const char kBranchWeightAttrName[] = "branch_weights"; 37 static constexpr const char kDefaultValueAttrName[] = "default_value"; 38 static constexpr const char kFnNameAttrName[] = "fn"; 39 static constexpr const char kIndicesAttrName[] = "indices"; 40 static constexpr const char kInitializerAttrName[] = "initializer"; 41 static constexpr const char kInterfaceAttrName[] = "interface"; 42 static constexpr const char kSpecConstAttrName[] = "spec_const"; 43 static constexpr const char kTypeAttrName[] = "type"; 44 static constexpr const char kValueAttrName[] = "value"; 45 static constexpr const char kValuesAttrName[] = "values"; 46 static constexpr const char kVariableAttrName[] = "variable"; 47 48 //===----------------------------------------------------------------------===// 49 // Common utility functions 50 //===----------------------------------------------------------------------===// 51 52 template <typename Dst, typename Src> 53 inline Dst bitwiseCast(Src source) noexcept { 54 Dst dest; 55 static_assert(sizeof(source) == sizeof(dest), 56 "bitwiseCast requires same source and destination bitwidth"); 57 std::memcpy(&dest, &source, sizeof(dest)); 58 return dest; 59 } 60 61 static LogicalResult extractValueFromConstOp(Operation *op, 62 int32_t &indexValue) { 63 auto constOp = dyn_cast<spirv::ConstantOp>(op); 64 if (!constOp) { 65 return failure(); 66 } 67 auto valueAttr = constOp.value(); 68 auto integerValueAttr = valueAttr.dyn_cast<IntegerAttr>(); 69 if (!integerValueAttr) { 70 return failure(); 71 } 72 indexValue = integerValueAttr.getInt(); 73 return success(); 74 } 75 76 static ParseResult parseBinaryLogicalOp(OpAsmParser *parser, 77 OperationState *result) { 78 SmallVector<OpAsmParser::OperandType, 2> ops; 79 Type type; 80 if (parser->parseOperandList(ops, 2) || parser->parseColonType(type) || 81 parser->resolveOperands(ops, type, result->operands)) { 82 return failure(); 83 } 84 // Result must be a scalar or vector of boolean type. 85 Type resultType = parser->getBuilder().getIntegerType(1); 86 if (auto opsType = type.dyn_cast<VectorType>()) { 87 resultType = VectorType::get(opsType.getNumElements(), resultType); 88 } 89 result->addTypes(resultType); 90 return success(); 91 } 92 93 template <typename EnumClass> 94 static ParseResult parseEnumAttribute(EnumClass &value, OpAsmParser *parser) { 95 Attribute attrVal; 96 SmallVector<NamedAttribute, 1> attr; 97 auto loc = parser->getCurrentLocation(); 98 if (parser->parseAttribute(attrVal, parser->getBuilder().getNoneType(), 99 spirv::attributeName<EnumClass>(), attr)) { 100 return failure(); 101 } 102 if (!attrVal.isa<StringAttr>()) { 103 return parser->emitError(loc, "expected ") 104 << spirv::attributeName<EnumClass>() 105 << " attribute specified as string"; 106 } 107 auto attrOptional = 108 spirv::symbolizeEnum<EnumClass>()(attrVal.cast<StringAttr>().getValue()); 109 if (!attrOptional) { 110 return parser->emitError(loc, "invalid ") 111 << spirv::attributeName<EnumClass>() 112 << " attribute specification: " << attrVal; 113 } 114 value = attrOptional.getValue(); 115 return success(); 116 } 117 118 template <typename EnumClass> 119 static ParseResult parseEnumAttribute(EnumClass &value, OpAsmParser *parser, 120 OperationState *state) { 121 if (parseEnumAttribute(value, parser)) { 122 return failure(); 123 } 124 state->addAttribute( 125 spirv::attributeName<EnumClass>(), 126 parser->getBuilder().getI32IntegerAttr(bitwiseCast<int32_t>(value))); 127 return success(); 128 } 129 130 static ParseResult parseMemoryAccessAttributes(OpAsmParser *parser, 131 OperationState *state) { 132 // Parse an optional list of attributes staring with '[' 133 if (parser->parseOptionalLSquare()) { 134 // Nothing to do 135 return success(); 136 } 137 138 spirv::MemoryAccess memoryAccessAttr; 139 if (parseEnumAttribute(memoryAccessAttr, parser, state)) { 140 return failure(); 141 } 142 143 if (memoryAccessAttr == spirv::MemoryAccess::Aligned) { 144 // Parse integer attribute for alignment. 145 Attribute alignmentAttr; 146 Type i32Type = parser->getBuilder().getIntegerType(32); 147 if (parser->parseComma() || 148 parser->parseAttribute(alignmentAttr, i32Type, kAlignmentAttrName, 149 state->attributes)) { 150 return failure(); 151 } 152 } 153 return parser->parseRSquare(); 154 } 155 156 // Parses an op that has no inputs and no outputs. 157 static ParseResult parseNoIOOp(OpAsmParser *parser, OperationState *state) { 158 if (parser->parseOptionalAttributeDict(state->attributes)) 159 return failure(); 160 return success(); 161 } 162 163 static void printBinaryLogicalOp(Operation *logicalOp, OpAsmPrinter *printer) { 164 *printer << logicalOp->getName() << ' ' << *logicalOp->getOperand(0) << ", " 165 << *logicalOp->getOperand(1); 166 *printer << " : " << logicalOp->getOperand(0)->getType(); 167 } 168 169 template <typename LoadStoreOpTy> 170 static void 171 printMemoryAccessAttribute(LoadStoreOpTy loadStoreOp, OpAsmPrinter *printer, 172 SmallVectorImpl<StringRef> &elidedAttrs) { 173 // Print optional memory access attribute. 174 if (auto memAccess = loadStoreOp.memory_access()) { 175 elidedAttrs.push_back(spirv::attributeName<spirv::MemoryAccess>()); 176 *printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\""; 177 178 // Print integer alignment attribute. 179 if (auto alignment = loadStoreOp.alignment()) { 180 elidedAttrs.push_back(kAlignmentAttrName); 181 *printer << ", " << alignment; 182 } 183 *printer << "]"; 184 } 185 elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>()); 186 } 187 188 template <typename LoadStoreOpTy> 189 static LogicalResult verifyMemoryAccessAttribute(LoadStoreOpTy loadStoreOp) { 190 // ODS checks for attributes values. Just need to verify that if the 191 // memory-access attribute is Aligned, then the alignment attribute must be 192 // present. 193 auto *op = loadStoreOp.getOperation(); 194 auto memAccessAttr = op->getAttr(spirv::attributeName<spirv::MemoryAccess>()); 195 if (!memAccessAttr) { 196 // Alignment attribute shouldn't be present if memory access attribute is 197 // not present. 198 if (op->getAttr(kAlignmentAttrName)) { 199 return loadStoreOp.emitOpError( 200 "invalid alignment specification without aligned memory access " 201 "specification"); 202 } 203 return success(); 204 } 205 206 auto memAccessVal = memAccessAttr.template cast<IntegerAttr>(); 207 auto memAccess = spirv::symbolizeMemoryAccess(memAccessVal.getInt()); 208 209 if (!memAccess) { 210 return loadStoreOp.emitOpError("invalid memory access specifier: ") 211 << memAccessVal; 212 } 213 214 if (*memAccess == spirv::MemoryAccess::Aligned) { 215 if (!op->getAttr(kAlignmentAttrName)) { 216 return loadStoreOp.emitOpError("missing alignment value"); 217 } 218 } else { 219 if (op->getAttr(kAlignmentAttrName)) { 220 return loadStoreOp.emitOpError( 221 "invalid alignment specification with non-aligned memory access " 222 "specification"); 223 } 224 } 225 return success(); 226 } 227 228 template <typename LoadStoreOpTy> 229 static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value *ptr, 230 Value *val) { 231 // ODS already checks ptr is spirv::PointerType. Just check that the pointee 232 // type of the pointer and the type of the value are the same 233 // 234 // TODO(ravishankarm): Check that the value type satisfies restrictions of 235 // SPIR-V OpLoad/OpStore operations 236 if (val->getType() != 237 ptr->getType().cast<spirv::PointerType>().getPointeeType()) { 238 return op.emitOpError("mismatch in result type and pointer type"); 239 } 240 return success(); 241 } 242 243 // Prints an op that has no inputs and no outputs. 244 static void printNoIOOp(Operation *op, OpAsmPrinter *printer) { 245 *printer << op->getName(); 246 printer->printOptionalAttrDict(op->getAttrs()); 247 } 248 249 static ParseResult parseVariableDecorations(OpAsmParser *parser, 250 OperationState *state) { 251 auto builtInName = 252 convertToSnakeCase(stringifyDecoration(spirv::Decoration::BuiltIn)); 253 if (succeeded(parser->parseOptionalKeyword("bind"))) { 254 Attribute set, binding; 255 // Parse optional descriptor binding 256 auto descriptorSetName = convertToSnakeCase( 257 stringifyDecoration(spirv::Decoration::DescriptorSet)); 258 auto bindingName = 259 convertToSnakeCase(stringifyDecoration(spirv::Decoration::Binding)); 260 Type i32Type = parser->getBuilder().getIntegerType(32); 261 if (parser->parseLParen() || 262 parser->parseAttribute(set, i32Type, descriptorSetName, 263 state->attributes) || 264 parser->parseComma() || 265 parser->parseAttribute(binding, i32Type, bindingName, 266 state->attributes) || 267 parser->parseRParen()) { 268 return failure(); 269 } 270 } else if (succeeded(parser->parseOptionalKeyword(builtInName.c_str()))) { 271 StringAttr builtIn; 272 if (parser->parseLParen() || 273 parser->parseAttribute(builtIn, Type(), builtInName, 274 state->attributes) || 275 parser->parseRParen()) { 276 return failure(); 277 } 278 } 279 280 // Parse other attributes 281 if (parser->parseOptionalAttributeDict(state->attributes)) 282 return failure(); 283 284 return success(); 285 } 286 287 static void printVariableDecorations(Operation *op, OpAsmPrinter *printer, 288 SmallVectorImpl<StringRef> &elidedAttrs) { 289 // Print optional descriptor binding 290 auto descriptorSetName = 291 convertToSnakeCase(stringifyDecoration(spirv::Decoration::DescriptorSet)); 292 auto bindingName = 293 convertToSnakeCase(stringifyDecoration(spirv::Decoration::Binding)); 294 auto descriptorSet = op->getAttrOfType<IntegerAttr>(descriptorSetName); 295 auto binding = op->getAttrOfType<IntegerAttr>(bindingName); 296 if (descriptorSet && binding) { 297 elidedAttrs.push_back(descriptorSetName); 298 elidedAttrs.push_back(bindingName); 299 *printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt() 300 << ")"; 301 } 302 303 // Print BuiltIn attribute if present 304 auto builtInName = 305 convertToSnakeCase(stringifyDecoration(spirv::Decoration::BuiltIn)); 306 if (auto builtin = op->getAttrOfType<StringAttr>(builtInName)) { 307 *printer << " " << builtInName << "(\"" << builtin.getValue() << "\")"; 308 elidedAttrs.push_back(builtInName); 309 } 310 311 printer->printOptionalAttrDict(op->getAttrs(), elidedAttrs); 312 } 313 314 //===----------------------------------------------------------------------===// 315 // spv.AccessChainOp 316 //===----------------------------------------------------------------------===// 317 318 static Type getElementPtrType(Type type, ArrayRef<Value *> indices, 319 Location baseLoc) { 320 if (indices.empty()) { 321 emitError(baseLoc, "'spv.AccessChain' op expected at least " 322 "one index "); 323 return nullptr; 324 } 325 326 auto ptrType = type.dyn_cast<spirv::PointerType>(); 327 if (!ptrType) { 328 emitError(baseLoc, "'spv.AccessChain' op expected a pointer " 329 "to composite type, but provided ") 330 << type; 331 return nullptr; 332 } 333 334 auto resultType = ptrType.getPointeeType(); 335 auto resultStorageClass = ptrType.getStorageClass(); 336 int32_t index = 0; 337 338 for (auto indexSSA : indices) { 339 auto cType = resultType.dyn_cast<spirv::CompositeType>(); 340 if (!cType) { 341 emitError(baseLoc, 342 "'spv.AccessChain' op cannot extract from non-composite type ") 343 << resultType << " with index " << index; 344 return nullptr; 345 } 346 index = 0; 347 if (resultType.isa<spirv::StructType>()) { 348 Operation *op = indexSSA->getDefiningOp(); 349 if (!op) { 350 emitError(baseLoc, "'spv.AccessChain' op index must be an " 351 "integer spv.constant to access " 352 "element of spv.struct"); 353 return nullptr; 354 } 355 356 // TODO(denis0x0D): this should be relaxed to allow 357 // integer literals of other bitwidths. 358 if (failed(extractValueFromConstOp(op, index))) { 359 emitError(baseLoc, 360 "'spv.AccessChain' index must be an integer spv.constant to " 361 "access element of spv.struct, but provided ") 362 << op->getName(); 363 return nullptr; 364 } 365 if (index < 0 || static_cast<uint64_t>(index) >= cType.getNumElements()) { 366 emitError(baseLoc, "'spv.AccessChain' op index ") 367 << index << " out of bounds for " << resultType; 368 return nullptr; 369 } 370 } 371 resultType = cType.getElementType(index); 372 } 373 return spirv::PointerType::get(resultType, resultStorageClass); 374 } 375 376 void spirv::AccessChainOp::build(Builder *builder, OperationState *state, 377 Value *basePtr, ArrayRef<Value *> indices) { 378 auto type = getElementPtrType(basePtr->getType(), indices, state->location); 379 assert(type && "Unable to deduce return type based on basePtr and indices"); 380 build(builder, state, type, basePtr, indices); 381 } 382 383 static ParseResult parseAccessChainOp(OpAsmParser *parser, 384 OperationState *state) { 385 OpAsmParser::OperandType ptrInfo; 386 SmallVector<OpAsmParser::OperandType, 4> indicesInfo; 387 Type type; 388 // TODO(denis0x0D): regarding to the spec an index must be any integer type, 389 // figure out how to use resolveOperand with a range of types and do not 390 // fail on first attempt. 391 Type indicesType = parser->getBuilder().getIntegerType(32); 392 393 if (parser->parseOperand(ptrInfo) || 394 parser->parseOperandList(indicesInfo, OpAsmParser::Delimiter::Square) || 395 parser->parseColonType(type) || 396 parser->resolveOperand(ptrInfo, type, state->operands) || 397 parser->resolveOperands(indicesInfo, indicesType, state->operands)) { 398 return failure(); 399 } 400 401 auto resultType = getElementPtrType( 402 type, llvm::makeArrayRef(state->operands).drop_front(), state->location); 403 if (!resultType) { 404 return failure(); 405 } 406 407 state->addTypes(resultType); 408 return success(); 409 } 410 411 static void print(spirv::AccessChainOp op, OpAsmPrinter *printer) { 412 *printer << spirv::AccessChainOp::getOperationName() << ' ' << *op.base_ptr() 413 << '['; 414 printer->printOperands(op.indices()); 415 *printer << "] : " << op.base_ptr()->getType(); 416 } 417 418 static LogicalResult verify(spirv::AccessChainOp accessChainOp) { 419 SmallVector<Value *, 4> indices(accessChainOp.indices().begin(), 420 accessChainOp.indices().end()); 421 auto resultType = getElementPtrType(accessChainOp.base_ptr()->getType(), 422 indices, accessChainOp.getLoc()); 423 if (!resultType) { 424 return failure(); 425 } 426 427 auto providedResultType = 428 accessChainOp.getType().dyn_cast<spirv::PointerType>(); 429 if (!providedResultType) { 430 return accessChainOp.emitOpError( 431 "result type must be a pointer, but provided") 432 << providedResultType; 433 } 434 435 if (resultType != providedResultType) { 436 return accessChainOp.emitOpError("invalid result type: expected ") 437 << resultType << ", but provided " << providedResultType; 438 } 439 440 return success(); 441 } 442 443 //===----------------------------------------------------------------------===// 444 // spv._address_of 445 //===----------------------------------------------------------------------===// 446 447 static ParseResult parseAddressOfOp(OpAsmParser *parser, 448 OperationState *state) { 449 SymbolRefAttr varRefAttr; 450 Type type; 451 if (parser->parseAttribute(varRefAttr, Type(), kVariableAttrName, 452 state->attributes) || 453 parser->parseColonType(type)) { 454 return failure(); 455 } 456 auto ptrType = type.dyn_cast<spirv::PointerType>(); 457 if (!ptrType) { 458 return parser->emitError(parser->getCurrentLocation(), 459 "expected spv.ptr type"); 460 } 461 state->addTypes(ptrType); 462 return success(); 463 } 464 465 static void print(spirv::AddressOfOp addressOfOp, OpAsmPrinter *printer) { 466 SmallVector<StringRef, 4> elidedAttrs; 467 *printer << spirv::AddressOfOp::getOperationName(); 468 469 // Print symbol name. 470 *printer << " @" << addressOfOp.variable(); 471 472 // Print the type. 473 *printer << " : " << addressOfOp.pointer()->getType(); 474 } 475 476 static LogicalResult verify(spirv::AddressOfOp addressOfOp) { 477 auto moduleOp = addressOfOp.getParentOfType<spirv::ModuleOp>(); 478 auto varOp = 479 moduleOp.lookupSymbol<spirv::GlobalVariableOp>(addressOfOp.variable()); 480 if (!varOp) { 481 return addressOfOp.emitOpError("expected spv.globalVariable symbol"); 482 } 483 if (addressOfOp.pointer()->getType() != varOp.type()) { 484 return addressOfOp.emitOpError( 485 "result type mismatch with the referenced global variable's type"); 486 } 487 return success(); 488 } 489 490 //===----------------------------------------------------------------------===// 491 // spv.BranchOp 492 //===----------------------------------------------------------------------===// 493 494 static ParseResult parseBranchOp(OpAsmParser *parser, OperationState *state) { 495 Block *dest; 496 SmallVector<Value *, 4> destOperands; 497 if (parser->parseSuccessorAndUseList(dest, destOperands)) 498 return failure(); 499 state->addSuccessor(dest, destOperands); 500 return success(); 501 } 502 503 static void print(spirv::BranchOp branchOp, OpAsmPrinter *printer) { 504 *printer << spirv::BranchOp::getOperationName() << ' '; 505 printer->printSuccessorAndUseList(branchOp.getOperation(), /*index=*/0); 506 } 507 508 static LogicalResult verify(spirv::BranchOp branchOp) { 509 auto *op = branchOp.getOperation(); 510 if (op->getNumSuccessors() != 1) 511 branchOp.emitOpError("must have exactly one successor"); 512 513 return success(); 514 } 515 516 //===----------------------------------------------------------------------===// 517 // spv.BranchConditionalOp 518 //===----------------------------------------------------------------------===// 519 520 static ParseResult parseBranchConditionalOp(OpAsmParser *parser, 521 OperationState *state) { 522 auto &builder = parser->getBuilder(); 523 OpAsmParser::OperandType condInfo; 524 Block *dest; 525 SmallVector<Value *, 4> destOperands; 526 527 // Parse the condition. 528 Type boolTy = builder.getI1Type(); 529 if (parser->parseOperand(condInfo) || 530 parser->resolveOperand(condInfo, boolTy, state->operands)) 531 return failure(); 532 533 // Parse the optional branch weights. 534 if (succeeded(parser->parseOptionalLSquare())) { 535 IntegerAttr trueWeight, falseWeight; 536 SmallVector<NamedAttribute, 2> weights; 537 538 auto i32Type = builder.getIntegerType(32); 539 if (parser->parseAttribute(trueWeight, i32Type, "weight", weights) || 540 parser->parseComma() || 541 parser->parseAttribute(falseWeight, i32Type, "weight", weights) || 542 parser->parseRSquare()) 543 return failure(); 544 545 state->addAttribute(kBranchWeightAttrName, 546 builder.getArrayAttr({trueWeight, falseWeight})); 547 } 548 549 // Parse the true branch. 550 if (parser->parseComma() || 551 parser->parseSuccessorAndUseList(dest, destOperands)) 552 return failure(); 553 state->addSuccessor(dest, destOperands); 554 555 // Parse the false branch. 556 destOperands.clear(); 557 if (parser->parseComma() || 558 parser->parseSuccessorAndUseList(dest, destOperands)) 559 return failure(); 560 state->addSuccessor(dest, destOperands); 561 562 return success(); 563 } 564 565 static void print(spirv::BranchConditionalOp branchOp, OpAsmPrinter *printer) { 566 *printer << spirv::BranchConditionalOp::getOperationName() << ' '; 567 printer->printOperand(branchOp.condition()); 568 569 if (auto weights = branchOp.branch_weights()) { 570 *printer << " ["; 571 mlir::interleaveComma( 572 weights->getValue(), printer->getStream(), 573 [&](Attribute a) { *printer << a.cast<IntegerAttr>().getInt(); }); 574 *printer << "]"; 575 } 576 577 *printer << ", "; 578 printer->printSuccessorAndUseList(branchOp.getOperation(), 579 spirv::BranchConditionalOp::kTrueIndex); 580 *printer << ", "; 581 printer->printSuccessorAndUseList(branchOp.getOperation(), 582 spirv::BranchConditionalOp::kFalseIndex); 583 } 584 585 static LogicalResult verify(spirv::BranchConditionalOp branchOp) { 586 auto *op = branchOp.getOperation(); 587 if (op->getNumSuccessors() != 2) 588 return branchOp.emitOpError("must have exactly two successors"); 589 590 if (auto weights = branchOp.branch_weights()) { 591 if (weights->getValue().size() != 2) { 592 return branchOp.emitOpError("must have exactly two branch weights"); 593 } 594 if (llvm::all_of(*weights, [](Attribute attr) { 595 return attr.cast<IntegerAttr>().getValue().isNullValue(); 596 })) 597 return branchOp.emitOpError("branch weights cannot both be zero"); 598 } 599 600 return success(); 601 } 602 603 //===----------------------------------------------------------------------===// 604 // spv.CompositeExtractOp 605 //===----------------------------------------------------------------------===// 606 607 static ParseResult parseCompositeExtractOp(OpAsmParser *parser, 608 OperationState *state) { 609 OpAsmParser::OperandType compositeInfo; 610 Attribute indicesAttr; 611 Type compositeType; 612 llvm::SMLoc attrLocation; 613 int32_t index; 614 615 if (parser->parseOperand(compositeInfo) || 616 parser->getCurrentLocation(&attrLocation) || 617 parser->parseAttribute(indicesAttr, kIndicesAttrName, 618 state->attributes) || 619 parser->parseColonType(compositeType) || 620 parser->resolveOperand(compositeInfo, compositeType, state->operands)) { 621 return failure(); 622 } 623 624 auto indicesArrayAttr = indicesAttr.dyn_cast<ArrayAttr>(); 625 if (!indicesArrayAttr) { 626 return parser->emitError( 627 attrLocation, 628 "expected an 32-bit integer array attribute for 'indices'"); 629 } 630 631 if (!indicesArrayAttr.size()) { 632 return parser->emitError( 633 attrLocation, "expected at least one index for spv.CompositeExtract"); 634 } 635 636 Type resultType = compositeType; 637 for (auto indexAttr : indicesArrayAttr) { 638 if (auto indexIntAttr = indexAttr.dyn_cast<IntegerAttr>()) { 639 index = indexIntAttr.getInt(); 640 } else { 641 return parser->emitError( 642 attrLocation, 643 "expexted an 32-bit integer for index, but found '") 644 << indexAttr << "'"; 645 } 646 647 if (auto cType = resultType.dyn_cast<spirv::CompositeType>()) { 648 if (index < 0 || static_cast<uint64_t>(index) >= cType.getNumElements()) { 649 return parser->emitError(attrLocation, "index ") 650 << index << " out of bounds for " << resultType; 651 } 652 resultType = cType.getElementType(index); 653 } else { 654 return parser->emitError(attrLocation, 655 "cannot extract from non-composite type ") 656 << resultType << " with index " << index; 657 } 658 } 659 660 state->addTypes(resultType); 661 return success(); 662 } 663 664 static void print(spirv::CompositeExtractOp compositeExtractOp, 665 OpAsmPrinter *printer) { 666 *printer << spirv::CompositeExtractOp::getOperationName() << ' ' 667 << *compositeExtractOp.composite() << compositeExtractOp.indices() 668 << " : " << compositeExtractOp.composite()->getType(); 669 } 670 671 static LogicalResult verify(spirv::CompositeExtractOp compExOp) { 672 auto resultType = compExOp.composite()->getType(); 673 auto indicesArrayAttr = compExOp.indices().dyn_cast<ArrayAttr>(); 674 675 if (!indicesArrayAttr.size()) { 676 return compExOp.emitOpError( 677 "expexted at least one index for spv.CompositeExtractOp"); 678 } 679 680 int32_t index; 681 for (auto indexAttr : indicesArrayAttr) { 682 index = indexAttr.dyn_cast<IntegerAttr>().getInt(); 683 if (auto cType = resultType.dyn_cast<spirv::CompositeType>()) { 684 if (index < 0 || static_cast<uint64_t>(index) >= cType.getNumElements()) { 685 return compExOp.emitOpError("index ") 686 << index << " out of bounds for " << resultType; 687 } 688 resultType = cType.getElementType(index); 689 } else { 690 return compExOp.emitError("cannot extract from non-composite type ") 691 << resultType << " with index " << index; 692 } 693 } 694 695 if (resultType != compExOp.getType()) { 696 return compExOp.emitOpError("invalid result type: expected ") 697 << resultType << " but provided " << compExOp.getType(); 698 } 699 700 return success(); 701 } 702 703 //===----------------------------------------------------------------------===// 704 // spv.constant 705 //===----------------------------------------------------------------------===// 706 707 static ParseResult parseConstantOp(OpAsmParser *parser, OperationState *state) { 708 Attribute value; 709 if (parser->parseAttribute(value, kValueAttrName, state->attributes)) 710 return failure(); 711 712 Type type; 713 if (value.getType().isa<NoneType>()) { 714 if (parser->parseColonType(type)) 715 return failure(); 716 } else { 717 type = value.getType(); 718 } 719 720 return parser->addTypeToList(type, state->types); 721 } 722 723 static void print(spirv::ConstantOp constOp, OpAsmPrinter *printer) { 724 *printer << spirv::ConstantOp::getOperationName() << ' ' << constOp.value(); 725 if (constOp.getType().isa<spirv::ArrayType>()) { 726 *printer << " : " << constOp.getType(); 727 } 728 } 729 730 static LogicalResult verify(spirv::ConstantOp constOp) { 731 auto opType = constOp.getType(); 732 auto value = constOp.value(); 733 auto valueType = value.getType(); 734 735 // ODS already generates checks to make sure the result type is valid. We just 736 // need to additionally check that the value's attribute type is consistent 737 // with the result type. 738 switch (value.getKind()) { 739 case StandardAttributes::Bool: 740 case StandardAttributes::Integer: 741 case StandardAttributes::Float: 742 case StandardAttributes::DenseElements: 743 case StandardAttributes::SparseElements: { 744 if (valueType != opType) 745 return constOp.emitOpError("result type (") 746 << opType << ") does not match value type (" << valueType << ")"; 747 return success(); 748 } break; 749 case StandardAttributes::Array: { 750 auto arrayType = opType.dyn_cast<spirv::ArrayType>(); 751 if (!arrayType) 752 return constOp.emitOpError( 753 "must have spv.array result type for array value"); 754 auto elemType = arrayType.getElementType(); 755 for (auto element : value.cast<ArrayAttr>().getValue()) { 756 if (element.getType() != elemType) 757 return constOp.emitOpError( 758 "has array element that are not of result array element type"); 759 } 760 } break; 761 default: 762 return constOp.emitOpError("cannot have value of type ") << valueType; 763 } 764 765 return success(); 766 } 767 768 //===----------------------------------------------------------------------===// 769 // spv.EntryPoint 770 //===----------------------------------------------------------------------===// 771 772 static ParseResult parseEntryPointOp(OpAsmParser *parser, 773 OperationState *state) { 774 spirv::ExecutionModel execModel; 775 SmallVector<OpAsmParser::OperandType, 0> identifiers; 776 SmallVector<Type, 0> idTypes; 777 778 SymbolRefAttr fn; 779 if (parseEnumAttribute(execModel, parser, state) || 780 parser->parseAttribute(fn, Type(), kFnNameAttrName, state->attributes)) { 781 return failure(); 782 } 783 784 if (!parser->parseOptionalComma()) { 785 // Parse the interface variables 786 SmallVector<Attribute, 4> interfaceVars; 787 do { 788 // The name of the interface variable attribute isnt important 789 auto attrName = "var_symbol"; 790 SymbolRefAttr var; 791 SmallVector<NamedAttribute, 1> attrs; 792 if (parser->parseAttribute(var, Type(), attrName, attrs)) { 793 return failure(); 794 } 795 interfaceVars.push_back(var); 796 } while (!parser->parseOptionalComma()); 797 state->addAttribute(kInterfaceAttrName, 798 parser->getBuilder().getArrayAttr(interfaceVars)); 799 } 800 return success(); 801 } 802 803 static void print(spirv::EntryPointOp entryPointOp, OpAsmPrinter *printer) { 804 *printer << spirv::EntryPointOp::getOperationName() << " \"" 805 << stringifyExecutionModel(entryPointOp.execution_model()) << "\" @" 806 << entryPointOp.fn(); 807 if (auto interface = entryPointOp.interface()) { 808 *printer << ", "; 809 mlir::interleaveComma(interface.getValue().getValue(), printer->getStream(), 810 [&](Attribute a) { printer->printAttribute(a); }); 811 } 812 } 813 814 static LogicalResult verify(spirv::EntryPointOp entryPointOp) { 815 // Checks for fn and interface symbol reference are done in spirv::ModuleOp 816 // verification. 817 return success(); 818 } 819 820 //===----------------------------------------------------------------------===// 821 // spv.ExecutionMode 822 //===----------------------------------------------------------------------===// 823 824 static ParseResult parseExecutionModeOp(OpAsmParser *parser, 825 OperationState *state) { 826 spirv::ExecutionMode execMode; 827 Attribute fn; 828 if (parser->parseAttribute(fn, kFnNameAttrName, state->attributes) || 829 parseEnumAttribute(execMode, parser, state)) { 830 return failure(); 831 } 832 833 SmallVector<int32_t, 4> values; 834 Type i32Type = parser->getBuilder().getIntegerType(32); 835 while (!parser->parseOptionalComma()) { 836 SmallVector<NamedAttribute, 1> attr; 837 Attribute value; 838 if (parser->parseAttribute(value, i32Type, "value", attr)) { 839 return failure(); 840 } 841 values.push_back(value.cast<IntegerAttr>().getInt()); 842 } 843 state->addAttribute(kValuesAttrName, 844 parser->getBuilder().getI32ArrayAttr(values)); 845 return success(); 846 } 847 848 static void print(spirv::ExecutionModeOp execModeOp, OpAsmPrinter *printer) { 849 *printer << spirv::ExecutionModeOp::getOperationName() << " @" 850 << execModeOp.fn() << " \"" 851 << stringifyExecutionMode(execModeOp.execution_mode()) << "\""; 852 auto values = execModeOp.values(); 853 if (!values) { 854 return; 855 } 856 *printer << ", "; 857 mlir::interleaveComma( 858 values.getValue().cast<ArrayAttr>(), printer->getStream(), 859 [&](Attribute a) { *printer << a.cast<IntegerAttr>().getInt(); }); 860 } 861 862 //===----------------------------------------------------------------------===// 863 // spv.globalVariable 864 //===----------------------------------------------------------------------===// 865 866 static ParseResult parseGlobalVariableOp(OpAsmParser *parser, 867 OperationState *state) { 868 // Parse variable name. 869 StringAttr nameAttr; 870 if (parser->parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(), 871 state->attributes)) { 872 return failure(); 873 } 874 875 // Parse optional initializer 876 if (succeeded(parser->parseOptionalKeyword(kInitializerAttrName))) { 877 SymbolRefAttr initSymbol; 878 if (parser->parseLParen() || 879 parser->parseAttribute(initSymbol, Type(), kInitializerAttrName, 880 state->attributes) || 881 parser->parseRParen()) 882 return failure(); 883 } 884 885 if (parseVariableDecorations(parser, state)) { 886 return failure(); 887 } 888 889 Type type; 890 auto loc = parser->getCurrentLocation(); 891 if (parser->parseColonType(type)) { 892 return failure(); 893 } 894 if (!type.isa<spirv::PointerType>()) { 895 return parser->emitError(loc, "expected spv.ptr type"); 896 } 897 state->addAttribute(kTypeAttrName, parser->getBuilder().getTypeAttr(type)); 898 899 return success(); 900 } 901 902 static void print(spirv::GlobalVariableOp varOp, OpAsmPrinter *printer) { 903 auto *op = varOp.getOperation(); 904 SmallVector<StringRef, 4> elidedAttrs{ 905 spirv::attributeName<spirv::StorageClass>()}; 906 *printer << spirv::GlobalVariableOp::getOperationName(); 907 908 // Print variable name. 909 *printer << " @" << varOp.sym_name(); 910 elidedAttrs.push_back(SymbolTable::getSymbolAttrName()); 911 912 // Print optional initializer 913 if (auto initializer = varOp.initializer()) { 914 *printer << " " << kInitializerAttrName << "(@" << initializer.getValue() 915 << ")"; 916 elidedAttrs.push_back(kInitializerAttrName); 917 } 918 919 elidedAttrs.push_back(kTypeAttrName); 920 printVariableDecorations(op, printer, elidedAttrs); 921 *printer << " : " << varOp.type(); 922 } 923 924 static LogicalResult verify(spirv::GlobalVariableOp varOp) { 925 // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the 926 // object. It cannot be Generic. It must be the same as the Storage Class 927 // operand of the Result Type." 928 if (varOp.storageClass() == spirv::StorageClass::Generic) 929 return varOp.emitOpError("storage class cannot be 'Generic'"); 930 931 if (auto init = varOp.getAttrOfType<SymbolRefAttr>(kInitializerAttrName)) { 932 auto moduleOp = varOp.getParentOfType<spirv::ModuleOp>(); 933 auto *initOp = moduleOp.lookupSymbol(init.getValue()); 934 // TODO: Currently only variable initialization with specialization 935 // constants and other variables is supported. They could be normal 936 // constants in the module scope as well. 937 if (!initOp || !(isa<spirv::GlobalVariableOp>(initOp) || 938 isa<spirv::SpecConstantOp>(initOp))) { 939 return varOp.emitOpError("initializer must be result of a " 940 "spv.specConstant or spv.globalVariable op"); 941 } 942 } 943 944 return success(); 945 } 946 947 //===----------------------------------------------------------------------===// 948 // spv.LoadOp 949 //===----------------------------------------------------------------------===// 950 951 static ParseResult parseLoadOp(OpAsmParser *parser, OperationState *state) { 952 // Parse the storage class specification 953 spirv::StorageClass storageClass; 954 OpAsmParser::OperandType ptrInfo; 955 Type elementType; 956 if (parseEnumAttribute(storageClass, parser) || 957 parser->parseOperand(ptrInfo) || 958 parseMemoryAccessAttributes(parser, state) || 959 parser->parseOptionalAttributeDict(state->attributes) || 960 parser->parseColon() || parser->parseType(elementType)) { 961 return failure(); 962 } 963 964 auto ptrType = spirv::PointerType::get(elementType, storageClass); 965 if (parser->resolveOperand(ptrInfo, ptrType, state->operands)) { 966 return failure(); 967 } 968 969 state->addTypes(elementType); 970 return success(); 971 } 972 973 static void print(spirv::LoadOp loadOp, OpAsmPrinter *printer) { 974 auto *op = loadOp.getOperation(); 975 SmallVector<StringRef, 4> elidedAttrs; 976 StringRef sc = stringifyStorageClass( 977 loadOp.ptr()->getType().cast<spirv::PointerType>().getStorageClass()); 978 *printer << spirv::LoadOp::getOperationName() << " \"" << sc << "\" "; 979 // Print the pointer operand. 980 printer->printOperand(loadOp.ptr()); 981 982 printMemoryAccessAttribute(loadOp, printer, elidedAttrs); 983 984 printer->printOptionalAttrDict(op->getAttrs(), elidedAttrs); 985 *printer << " : " << loadOp.getType(); 986 } 987 988 static LogicalResult verify(spirv::LoadOp loadOp) { 989 // SPIR-V spec : "Result Type is the type of the loaded object. It must be a 990 // type with fixed size; i.e., it cannot be, nor include, any 991 // OpTypeRuntimeArray types." 992 if (failed(verifyLoadStorePtrAndValTypes(loadOp, loadOp.ptr(), 993 loadOp.value()))) { 994 return failure(); 995 } 996 return verifyMemoryAccessAttribute(loadOp); 997 } 998 999 //===----------------------------------------------------------------------===// 1000 // spv.module 1001 //===----------------------------------------------------------------------===// 1002 1003 void spirv::ModuleOp::build(Builder *builder, OperationState *state) { 1004 ensureTerminator(*state->addRegion(), *builder, state->location); 1005 } 1006 1007 void spirv::ModuleOp::build(Builder *builder, OperationState *state, 1008 IntegerAttr addressing_model, 1009 IntegerAttr memory_model, ArrayAttr capabilities, 1010 ArrayAttr extensions, 1011 ArrayAttr extended_instruction_sets) { 1012 state->addAttribute("addressing_model", addressing_model); 1013 state->addAttribute("memory_model", memory_model); 1014 if (capabilities) 1015 state->addAttribute("capabilities", capabilities); 1016 if (extensions) 1017 state->addAttribute("extensions", extensions); 1018 if (extended_instruction_sets) 1019 state->addAttribute("extended_instruction_sets", extended_instruction_sets); 1020 ensureTerminator(*state->addRegion(), *builder, state->location); 1021 } 1022 1023 static ParseResult parseModuleOp(OpAsmParser *parser, OperationState *state) { 1024 Region *body = state->addRegion(); 1025 1026 // Parse attributes 1027 spirv::AddressingModel addrModel; 1028 spirv::MemoryModel memoryModel; 1029 if (parseEnumAttribute(addrModel, parser, state) || 1030 parseEnumAttribute(memoryModel, parser, state)) { 1031 return failure(); 1032 } 1033 1034 if (parser->parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{})) 1035 return failure(); 1036 1037 if (succeeded(parser->parseOptionalKeyword("attributes"))) { 1038 if (parser->parseOptionalAttributeDict(state->attributes)) 1039 return failure(); 1040 } 1041 1042 spirv::ModuleOp::ensureTerminator(*body, parser->getBuilder(), 1043 state->location); 1044 return success(); 1045 } 1046 1047 static void print(spirv::ModuleOp moduleOp, OpAsmPrinter *printer) { 1048 auto *op = moduleOp.getOperation(); 1049 1050 // Only print out addressing model and memory model in a nicer way if both 1051 // presents. Otherwise, print them in the general form. This helps debugging 1052 // ill-formed ModuleOp. 1053 SmallVector<StringRef, 2> elidedAttrs; 1054 auto addressingModelAttrName = spirv::attributeName<spirv::AddressingModel>(); 1055 auto memoryModelAttrName = spirv::attributeName<spirv::MemoryModel>(); 1056 if (op->getAttr(addressingModelAttrName) && 1057 op->getAttr(memoryModelAttrName)) { 1058 *printer << spirv::ModuleOp::getOperationName() << " \"" 1059 << spirv::stringifyAddressingModel(moduleOp.addressing_model()) 1060 << "\" \"" << spirv::stringifyMemoryModel(moduleOp.memory_model()) 1061 << '"'; 1062 elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName}); 1063 } 1064 1065 printer->printRegion(op->getRegion(0), /*printEntryBlockArgs=*/false, 1066 /*printBlockTerminators=*/false); 1067 1068 bool printAttrDict = 1069 elidedAttrs.size() != 2 || 1070 llvm::any_of(op->getAttrs(), [&addressingModelAttrName, 1071 &memoryModelAttrName](NamedAttribute attr) { 1072 return attr.first != addressingModelAttrName && 1073 attr.first != memoryModelAttrName; 1074 }); 1075 1076 if (printAttrDict) { 1077 *printer << " attributes"; 1078 printer->printOptionalAttrDict(op->getAttrs(), elidedAttrs); 1079 } 1080 } 1081 1082 static LogicalResult verify(spirv::ModuleOp moduleOp) { 1083 auto &op = *moduleOp.getOperation(); 1084 auto *dialect = op.getDialect(); 1085 auto &body = op.getRegion(0).front(); 1086 llvm::DenseMap<std::pair<FuncOp, spirv::ExecutionModel>, spirv::EntryPointOp> 1087 entryPoints; 1088 SymbolTable table(moduleOp); 1089 1090 for (auto &op : body) { 1091 if (op.getDialect() == dialect) { 1092 // For EntryPoint op, check that the function and execution model is not 1093 // duplicated in EntryPointOps. Also verify that the interface specified 1094 // comes from globalVariables here to make this check cheaper. 1095 if (auto entryPointOp = dyn_cast<spirv::EntryPointOp>(op)) { 1096 auto funcOp = table.lookup<FuncOp>(entryPointOp.fn()); 1097 if (!funcOp) { 1098 return entryPointOp.emitError("function '") 1099 << entryPointOp.fn() << "' not found in 'spv.module'"; 1100 } 1101 if (auto interface = entryPointOp.interface()) { 1102 for (auto varRef : interface.getValue().getValue()) { 1103 auto varSymRef = varRef.dyn_cast<SymbolRefAttr>(); 1104 if (!varSymRef) { 1105 return entryPointOp.emitError( 1106 "expected symbol reference for interface " 1107 "specification instead of '") 1108 << varRef; 1109 } 1110 auto variableOp = 1111 table.lookup<spirv::GlobalVariableOp>(varSymRef.getValue()); 1112 if (!variableOp) { 1113 return entryPointOp.emitError("expected spv.globalVariable " 1114 "symbol reference instead of'") 1115 << varSymRef << "'"; 1116 } 1117 } 1118 } 1119 1120 auto key = std::pair<FuncOp, spirv::ExecutionModel>( 1121 funcOp, entryPointOp.execution_model()); 1122 auto entryPtIt = entryPoints.find(key); 1123 if (entryPtIt != entryPoints.end()) { 1124 return entryPointOp.emitError("duplicate of a previous EntryPointOp"); 1125 } 1126 entryPoints[key] = entryPointOp; 1127 } 1128 continue; 1129 } 1130 1131 auto funcOp = dyn_cast<FuncOp>(op); 1132 if (!funcOp) 1133 return op.emitError("'spv.module' can only contain func and spv.* ops"); 1134 1135 if (funcOp.isExternal()) 1136 return op.emitError("'spv.module' cannot contain external functions"); 1137 1138 for (auto &block : funcOp) 1139 for (auto &op : block) { 1140 if (op.getDialect() == dialect) 1141 continue; 1142 1143 if (isa<FuncOp>(op)) 1144 return op.emitError("'spv.module' cannot contain nested functions"); 1145 1146 return op.emitError( 1147 "functions in 'spv.module' can only contain spv.* ops"); 1148 } 1149 } 1150 1151 // Verify capabilities. ODS already guarantees that we have an array of 1152 // string attributes. 1153 if (auto caps = moduleOp.getAttrOfType<ArrayAttr>("capabilities")) { 1154 for (auto cap : caps.getValue()) { 1155 auto capStr = cap.cast<StringAttr>().getValue(); 1156 if (!spirv::symbolizeCapability(capStr)) 1157 return moduleOp.emitOpError("uses unknown capability: ") << capStr; 1158 } 1159 } 1160 1161 // Verify extensions. ODS already guarantees that we have an array of 1162 // string attributes. 1163 if (auto exts = moduleOp.getAttrOfType<ArrayAttr>("extensions")) { 1164 for (auto ext : exts.getValue()) { 1165 auto extStr = ext.cast<StringAttr>().getValue(); 1166 if (!spirv::symbolizeExtension(extStr)) 1167 return moduleOp.emitOpError("uses unknown extension: ") << extStr; 1168 } 1169 } 1170 1171 return success(); 1172 } 1173 1174 //===----------------------------------------------------------------------===// 1175 // spv._reference_of 1176 //===----------------------------------------------------------------------===// 1177 1178 static ParseResult parseReferenceOfOp(OpAsmParser *parser, 1179 OperationState *state) { 1180 SymbolRefAttr constRefAttr; 1181 Type type; 1182 if (parser->parseAttribute(constRefAttr, Type(), kSpecConstAttrName, 1183 state->attributes) || 1184 parser->parseColonType(type)) { 1185 return failure(); 1186 } 1187 return parser->addTypeToList(type, state->types); 1188 } 1189 1190 static void print(spirv::ReferenceOfOp referenceOfOp, OpAsmPrinter *printer) { 1191 *printer << spirv::ReferenceOfOp::getOperationName() << " @" 1192 << referenceOfOp.spec_const() << " : " 1193 << referenceOfOp.reference()->getType(); 1194 } 1195 1196 static LogicalResult verify(spirv::ReferenceOfOp referenceOfOp) { 1197 auto moduleOp = referenceOfOp.getParentOfType<spirv::ModuleOp>(); 1198 auto specConstOp = 1199 moduleOp.lookupSymbol<spirv::SpecConstantOp>(referenceOfOp.spec_const()); 1200 if (!specConstOp) { 1201 return referenceOfOp.emitOpError("expected spv.specConstant symbol"); 1202 } 1203 if (referenceOfOp.reference()->getType() != 1204 specConstOp.default_value().getType()) { 1205 return referenceOfOp.emitOpError("result type mismatch with the referenced " 1206 "specialization constant's type"); 1207 } 1208 return success(); 1209 } 1210 1211 //===----------------------------------------------------------------------===// 1212 // spv.Return 1213 //===----------------------------------------------------------------------===// 1214 1215 static LogicalResult verify(spirv::ReturnOp returnOp) { 1216 auto funcOp = cast<FuncOp>(returnOp.getParentOp()); 1217 auto numOutputs = funcOp.getType().getNumResults(); 1218 if (numOutputs != 0) 1219 return returnOp.emitOpError("cannot be used in functions returning value") 1220 << (numOutputs > 1 ? "s" : ""); 1221 1222 return success(); 1223 } 1224 1225 //===----------------------------------------------------------------------===// 1226 // spv.ReturnValue 1227 //===----------------------------------------------------------------------===// 1228 1229 static ParseResult parseReturnValueOp(OpAsmParser *parser, 1230 OperationState *state) { 1231 OpAsmParser::OperandType retValInfo; 1232 Type retValType; 1233 return failure( 1234 parser->parseOperand(retValInfo) || parser->parseColonType(retValType) || 1235 parser->resolveOperand(retValInfo, retValType, state->operands)); 1236 } 1237 1238 static void print(spirv::ReturnValueOp retValOp, OpAsmPrinter *printer) { 1239 *printer << spirv::ReturnValueOp::getOperationName() << ' '; 1240 printer->printOperand(retValOp.value()); 1241 *printer << " : " << retValOp.value()->getType(); 1242 } 1243 1244 static LogicalResult verify(spirv::ReturnValueOp retValOp) { 1245 auto funcOp = cast<FuncOp>(retValOp.getParentOp()); 1246 auto numFnResults = funcOp.getType().getNumResults(); 1247 if (numFnResults != 1) 1248 return retValOp.emitOpError( 1249 "returns 1 value but enclosing function requires ") 1250 << numFnResults << " results"; 1251 1252 auto operandType = retValOp.value()->getType(); 1253 auto fnResultType = funcOp.getType().getResult(0); 1254 if (operandType != fnResultType) 1255 return retValOp.emitOpError(" return value's type (") 1256 << operandType << ") mismatch with function's result type (" 1257 << fnResultType << ")"; 1258 1259 return success(); 1260 } 1261 1262 //===----------------------------------------------------------------------===// 1263 // spv.specConstant 1264 //===----------------------------------------------------------------------===// 1265 1266 static ParseResult parseSpecConstantOp(OpAsmParser *parser, 1267 OperationState *state) { 1268 StringAttr nameAttr; 1269 Attribute valueAttr; 1270 1271 if (parser->parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(), 1272 state->attributes) || 1273 parser->parseEqual() || 1274 parser->parseAttribute(valueAttr, kDefaultValueAttrName, 1275 state->attributes)) 1276 return failure(); 1277 1278 return success(); 1279 } 1280 1281 static void print(spirv::SpecConstantOp constOp, OpAsmPrinter *printer) { 1282 *printer << spirv::SpecConstantOp::getOperationName() << " @" 1283 << constOp.sym_name() << " = "; 1284 printer->printAttribute(constOp.default_value()); 1285 } 1286 1287 static LogicalResult verify(spirv::SpecConstantOp constOp) { 1288 auto value = constOp.default_value(); 1289 1290 switch (value.getKind()) { 1291 case StandardAttributes::Bool: 1292 case StandardAttributes::Integer: 1293 case StandardAttributes::Float: { 1294 // Make sure bitwidth is allowed. 1295 auto *dialect = static_cast<spirv::SPIRVDialect *>(constOp.getDialect()); 1296 if (!dialect->isValidSPIRVType(value.getType())) 1297 return constOp.emitOpError("default value bitwidth disallowed"); 1298 return success(); 1299 } 1300 default: 1301 return constOp.emitOpError( 1302 "default value can only be a bool, integer, or float scalar"); 1303 } 1304 } 1305 1306 //===----------------------------------------------------------------------===// 1307 // spv.StoreOp 1308 //===----------------------------------------------------------------------===// 1309 1310 static ParseResult parseStoreOp(OpAsmParser *parser, OperationState *state) { 1311 // Parse the storage class specification 1312 spirv::StorageClass storageClass; 1313 SmallVector<OpAsmParser::OperandType, 2> operandInfo; 1314 auto loc = parser->getCurrentLocation(); 1315 Type elementType; 1316 if (parseEnumAttribute(storageClass, parser) || 1317 parser->parseOperandList(operandInfo, 2) || 1318 parseMemoryAccessAttributes(parser, state) || parser->parseColon() || 1319 parser->parseType(elementType)) { 1320 return failure(); 1321 } 1322 1323 auto ptrType = spirv::PointerType::get(elementType, storageClass); 1324 if (parser->resolveOperands(operandInfo, {ptrType, elementType}, loc, 1325 state->operands)) { 1326 return failure(); 1327 } 1328 return success(); 1329 } 1330 1331 static void print(spirv::StoreOp storeOp, OpAsmPrinter *printer) { 1332 auto *op = storeOp.getOperation(); 1333 SmallVector<StringRef, 4> elidedAttrs; 1334 StringRef sc = stringifyStorageClass( 1335 storeOp.ptr()->getType().cast<spirv::PointerType>().getStorageClass()); 1336 *printer << spirv::StoreOp::getOperationName() << " \"" << sc << "\" "; 1337 // Print the pointer operand 1338 printer->printOperand(storeOp.ptr()); 1339 *printer << ", "; 1340 // Print the value operand 1341 printer->printOperand(storeOp.value()); 1342 1343 printMemoryAccessAttribute(storeOp, printer, elidedAttrs); 1344 1345 *printer << " : " << storeOp.value()->getType(); 1346 1347 printer->printOptionalAttrDict(op->getAttrs(), elidedAttrs); 1348 } 1349 1350 static LogicalResult verify(spirv::StoreOp storeOp) { 1351 // SPIR-V spec : "Pointer is the pointer to store through. Its type must be an 1352 // OpTypePointer whose Type operand is the same as the type of Object." 1353 if (failed(verifyLoadStorePtrAndValTypes(storeOp, storeOp.ptr(), 1354 storeOp.value()))) { 1355 return failure(); 1356 } 1357 return verifyMemoryAccessAttribute(storeOp); 1358 } 1359 1360 //===----------------------------------------------------------------------===// 1361 // spv.Variable 1362 //===----------------------------------------------------------------------===// 1363 1364 static ParseResult parseVariableOp(OpAsmParser *parser, OperationState *state) { 1365 // Parse optional initializer 1366 Optional<OpAsmParser::OperandType> initInfo; 1367 if (succeeded(parser->parseOptionalKeyword("init"))) { 1368 initInfo = OpAsmParser::OperandType(); 1369 if (parser->parseLParen() || parser->parseOperand(*initInfo) || 1370 parser->parseRParen()) 1371 return failure(); 1372 } 1373 1374 if (parseVariableDecorations(parser, state)) { 1375 return failure(); 1376 } 1377 1378 // Parse result pointer type 1379 Type type; 1380 if (parser->parseColon()) 1381 return failure(); 1382 auto loc = parser->getCurrentLocation(); 1383 if (parser->parseType(type)) 1384 return failure(); 1385 1386 auto ptrType = type.dyn_cast<spirv::PointerType>(); 1387 if (!ptrType) 1388 return parser->emitError(loc, "expected spv.ptr type"); 1389 state->addTypes(ptrType); 1390 1391 // Resolve the initializer operand 1392 SmallVector<Value *, 1> init; 1393 if (initInfo) { 1394 if (parser->resolveOperand(*initInfo, ptrType.getPointeeType(), init)) 1395 return failure(); 1396 state->addOperands(init); 1397 } 1398 1399 auto attr = parser->getBuilder().getI32IntegerAttr( 1400 bitwiseCast<int32_t>(ptrType.getStorageClass())); 1401 state->addAttribute(spirv::attributeName<spirv::StorageClass>(), attr); 1402 1403 return success(); 1404 } 1405 1406 static void print(spirv::VariableOp varOp, OpAsmPrinter *printer) { 1407 auto *op = varOp.getOperation(); 1408 SmallVector<StringRef, 4> elidedAttrs{ 1409 spirv::attributeName<spirv::StorageClass>()}; 1410 *printer << spirv::VariableOp::getOperationName(); 1411 1412 // Print optional initializer 1413 if (op->getNumOperands() > 0) { 1414 *printer << " init("; 1415 printer->printOperands(varOp.initializer()); 1416 *printer << ")"; 1417 } 1418 1419 printVariableDecorations(op, printer, elidedAttrs); 1420 1421 *printer << " : " << varOp.getType(); 1422 } 1423 1424 static LogicalResult verify(spirv::VariableOp varOp) { 1425 // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the 1426 // object. It cannot be Generic. It must be the same as the Storage Class 1427 // operand of the Result Type." 1428 if (varOp.storage_class() != spirv::StorageClass::Function) { 1429 return varOp.emitOpError( 1430 "can only be used to model function-level variables. Use " 1431 "spv.globalVariable for module-level variables."); 1432 } 1433 1434 auto pointerType = varOp.pointer()->getType().cast<spirv::PointerType>(); 1435 if (varOp.storage_class() != pointerType.getStorageClass()) 1436 return varOp.emitOpError( 1437 "storage class must match result pointer's storage class"); 1438 1439 if (varOp.getNumOperands() != 0) { 1440 // SPIR-V spec: "Initializer must be an <id> from a constant instruction or 1441 // a global (module scope) OpVariable instruction". 1442 auto *initOp = varOp.getOperand(0)->getDefiningOp(); 1443 if (!initOp || !(isa<spirv::ConstantOp>(initOp) || // for normal constant 1444 isa<spirv::ReferenceOfOp>(initOp) || // for spec constant 1445 isa<spirv::AddressOfOp>(initOp))) 1446 return varOp.emitOpError("initializer must be the result of a " 1447 "constant or spv.globalVariable op"); 1448 } 1449 1450 // TODO(antiagainst): generate these strings using ODS. 1451 auto *op = varOp.getOperation(); 1452 auto descriptorSetName = 1453 convertToSnakeCase(stringifyDecoration(spirv::Decoration::DescriptorSet)); 1454 auto bindingName = 1455 convertToSnakeCase(stringifyDecoration(spirv::Decoration::Binding)); 1456 auto builtInName = 1457 convertToSnakeCase(stringifyDecoration(spirv::Decoration::BuiltIn)); 1458 1459 for (const auto &attr : {descriptorSetName, bindingName, builtInName}) { 1460 if (op->getAttr(attr)) 1461 return varOp.emitOpError("cannot have '") 1462 << attr << "' attribute (only allowed in spv.globalVariable)"; 1463 } 1464 1465 return success(); 1466 } 1467 1468 namespace mlir { 1469 namespace spirv { 1470 1471 #define GET_OP_CLASSES 1472 #include "mlir/Dialect/SPIRV/SPIRVOps.cpp.inc" 1473 1474 } // namespace spirv 1475 } // namespace mlir