github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp (about) 1 //===- LLVMDialect.cpp - MLIR SPIR-V dialect ------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the SPIR-V dialect in MLIR. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/SPIRV/SPIRVDialect.h" 14 #include "mlir/Dialect/SPIRV/SPIRVOps.h" 15 #include "mlir/Dialect/SPIRV/SPIRVTypes.h" 16 #include "mlir/IR/MLIRContext.h" 17 #include "mlir/IR/StandardTypes.h" 18 #include "mlir/Parser.h" 19 #include "llvm/ADT/DenseMap.h" 20 #include "llvm/ADT/Sequence.h" 21 #include "llvm/ADT/StringExtras.h" 22 #include "llvm/ADT/StringMap.h" 23 #include "llvm/ADT/StringSwitch.h" 24 #include "llvm/Support/raw_ostream.h" 25 26 namespace mlir { 27 namespace spirv { 28 #include "mlir/Dialect/SPIRV/SPIRVOpUtils.inc" 29 } // namespace spirv 30 } // namespace mlir 31 32 using namespace mlir; 33 using namespace mlir::spirv; 34 35 //===----------------------------------------------------------------------===// 36 // SPIR-V Dialect 37 //===----------------------------------------------------------------------===// 38 39 SPIRVDialect::SPIRVDialect(MLIRContext *context) 40 : Dialect(getDialectNamespace(), context) { 41 addTypes<ArrayType, ImageType, PointerType, RuntimeArrayType, StructType>(); 42 43 addOperations< 44 #define GET_OP_LIST 45 #include "mlir/Dialect/SPIRV/SPIRVOps.cpp.inc" 46 >(); 47 48 // Allow unknown operations because SPIR-V is extensible. 49 allowUnknownOperations(); 50 } 51 52 //===----------------------------------------------------------------------===// 53 // Type Parsing 54 //===----------------------------------------------------------------------===// 55 56 // Forward declarations. 57 template <typename ValTy> 58 static Optional<ValTy> parseAndVerify(SPIRVDialect const &dialect, Location loc, 59 StringRef spec); 60 template <> 61 Optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect, Location loc, 62 StringRef spec); 63 64 template <> 65 Optional<uint64_t> parseAndVerify(SPIRVDialect const &dialect, Location loc, 66 StringRef spec); 67 68 // Parses "<number> x" from the beginning of `spec`. 69 static bool parseNumberX(StringRef &spec, int64_t &number) { 70 spec = spec.ltrim(); 71 if (spec.empty() || !llvm::isDigit(spec.front())) 72 return false; 73 74 number = 0; 75 do { 76 number = number * 10 + spec.front() - '0'; 77 spec = spec.drop_front(); 78 } while (!spec.empty() && llvm::isDigit(spec.front())); 79 80 spec = spec.ltrim(); 81 if (!spec.consume_front("x")) 82 return false; 83 84 return true; 85 } 86 87 static bool isValidSPIRVIntType(IntegerType type) { 88 return llvm::is_contained(llvm::ArrayRef<unsigned>({1, 8, 16, 32, 64}), 89 type.getWidth()); 90 } 91 92 static bool isValidSPIRVScalarType(Type type) { 93 if (type.isa<FloatType>()) { 94 return !type.isBF16(); 95 } 96 if (auto intType = type.dyn_cast<IntegerType>()) { 97 return isValidSPIRVIntType(intType); 98 } 99 return false; 100 } 101 102 static bool isValidSPIRVVectorType(VectorType type) { 103 return type.getRank() == 1 && isValidSPIRVScalarType(type.getElementType()) && 104 type.getNumElements() >= 2 && type.getNumElements() <= 4; 105 } 106 107 bool SPIRVDialect::isValidSPIRVType(Type type) const { 108 // Allow SPIR-V dialect types 109 if (&type.getDialect() == this) { 110 return true; 111 } 112 if (isValidSPIRVScalarType(type)) { 113 return true; 114 } 115 if (auto vectorType = type.dyn_cast<VectorType>()) { 116 return isValidSPIRVVectorType(vectorType); 117 } 118 return false; 119 } 120 121 static Type parseAndVerifyType(SPIRVDialect const &dialect, StringRef spec, 122 Location loc) { 123 spec = spec.trim(); 124 auto *context = dialect.getContext(); 125 auto type = mlir::parseType(spec.trim(), context); 126 if (!type) { 127 emitError(loc, "cannot parse type: ") << spec; 128 return Type(); 129 } 130 131 // Allow SPIR-V dialect types 132 if (&type.getDialect() == &dialect) 133 return type; 134 135 // Check other allowed types 136 if (auto t = type.dyn_cast<FloatType>()) { 137 if (type.isBF16()) { 138 emitError(loc, "cannot use 'bf16' to compose SPIR-V types"); 139 return Type(); 140 } 141 } else if (auto t = type.dyn_cast<IntegerType>()) { 142 if (!isValidSPIRVIntType(t)) { 143 emitError(loc, "only 1/8/16/32/64-bit integer type allowed but found ") 144 << type; 145 return Type(); 146 } 147 } else if (auto t = type.dyn_cast<VectorType>()) { 148 if (t.getRank() != 1) { 149 emitError(loc, "only 1-D vector allowed but found ") << t; 150 return Type(); 151 } 152 if (t.getNumElements() > 4) { 153 emitError(loc, 154 "vector length has to be less than or equal to 4 but found ") 155 << t.getNumElements(); 156 return Type(); 157 } 158 } else { 159 emitError(loc, "cannot use ") << type << " to compose SPIR-V types"; 160 return Type(); 161 } 162 163 return type; 164 } 165 166 // element-type ::= integer-type 167 // | floating-point-type 168 // | vector-type 169 // | spirv-type 170 // 171 // array-type ::= `!spv.array<` integer-literal `x` element-type 172 // (`[` integer-literal `]`)? `>` 173 static Type parseArrayType(SPIRVDialect const &dialect, StringRef spec, 174 Location loc) { 175 if (!spec.consume_front("array<") || !spec.consume_back(">")) { 176 emitError(loc, "spv.array delimiter <...> mismatch"); 177 return Type(); 178 } 179 180 int64_t count = 0; 181 spec = spec.trim(); 182 if (!parseNumberX(spec, count)) { 183 emitError(loc, "expected array element count followed by 'x' but found '") 184 << spec << "'"; 185 return Type(); 186 } 187 188 if (spec.trim().empty()) { 189 emitError(loc, "expected element type"); 190 return Type(); 191 } 192 193 ArrayType::LayoutInfo layoutInfo = 0; 194 size_t lastLSquare; 195 196 // Handle case when element type is not a trivial type 197 auto lastRDelimiter = spec.rfind('>'); 198 if (lastRDelimiter != StringRef::npos) { 199 lastLSquare = spec.find('[', lastRDelimiter); 200 } else { 201 lastLSquare = spec.rfind('['); 202 } 203 204 if (lastLSquare != StringRef::npos) { 205 auto layoutSpec = spec.substr(lastLSquare); 206 auto layout = 207 parseAndVerify<ArrayType::LayoutInfo>(dialect, loc, layoutSpec); 208 if (!layout) { 209 return Type(); 210 } 211 212 if (!(layoutInfo = layout.getValue())) { 213 emitError(loc, "ArrayStride must be greater than zero"); 214 return Type(); 215 } 216 spec = spec.substr(0, lastLSquare); 217 } 218 219 Type elementType = parseAndVerifyType(dialect, spec, loc); 220 if (!elementType) 221 return Type(); 222 223 return ArrayType::get(elementType, count, layoutInfo); 224 } 225 226 // TODO(ravishankarm) : Reorder methods to be utilities first and parse*Type 227 // methods in alphabetical order 228 // 229 // storage-class ::= `UniformConstant` 230 // | `Uniform` 231 // | `Workgroup` 232 // | <and other storage classes...> 233 // 234 // pointer-type ::= `!spv.ptr<` element-type `,` storage-class `>` 235 static Type parsePointerType(SPIRVDialect const &dialect, StringRef spec, 236 Location loc) { 237 if (!spec.consume_front("ptr<") || !spec.consume_back(">")) { 238 emitError(loc, "spv.ptr delimiter <...> mismatch"); 239 return Type(); 240 } 241 242 // Split into pointee type and storage class 243 StringRef scSpec, ptSpec; 244 std::tie(ptSpec, scSpec) = spec.rsplit(','); 245 if (scSpec.empty()) { 246 emitError(loc, 247 "expected comma to separate pointee type and storage class in '") 248 << spec << "'"; 249 return Type(); 250 } 251 252 scSpec = scSpec.trim(); 253 auto storageClass = symbolizeStorageClass(scSpec); 254 if (!storageClass) { 255 emitError(loc, "unknown storage class: ") << scSpec; 256 return Type(); 257 } 258 259 if (ptSpec.trim().empty()) { 260 emitError(loc, "expected pointee type"); 261 return Type(); 262 } 263 264 auto pointeeType = parseAndVerifyType(dialect, ptSpec, loc); 265 if (!pointeeType) 266 return Type(); 267 268 return PointerType::get(pointeeType, *storageClass); 269 } 270 271 // runtime-array-type ::= `!spv.rtarray<` element-type `>` 272 static Type parseRuntimeArrayType(SPIRVDialect const &dialect, StringRef spec, 273 Location loc) { 274 if (!spec.consume_front("rtarray<") || !spec.consume_back(">")) { 275 emitError(loc, "spv.rtarray delimiter <...> mismatch"); 276 return Type(); 277 } 278 279 if (spec.trim().empty()) { 280 emitError(loc, "expected element type"); 281 return Type(); 282 } 283 284 Type elementType = parseAndVerifyType(dialect, spec, loc); 285 if (!elementType) 286 return Type(); 287 288 return RuntimeArrayType::get(elementType); 289 } 290 291 // Specialize this function to parse each of the parameters that define an 292 // ImageType. By default it assumes this is an enum type. 293 template <typename ValTy> 294 static Optional<ValTy> parseAndVerify(SPIRVDialect const &dialect, Location loc, 295 StringRef spec) { 296 auto val = spirv::symbolizeEnum<ValTy>()(spec); 297 if (!val) { 298 emitError(loc, "unknown attribute: '") << spec << "'"; 299 } 300 return val; 301 } 302 303 template <> 304 Optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect, Location loc, 305 StringRef spec) { 306 // TODO(ravishankarm): Further verify that the element type can be sampled 307 auto ty = parseAndVerifyType(dialect, spec, loc); 308 if (!ty) { 309 return llvm::None; 310 } 311 return ty; 312 } 313 314 template <> 315 Optional<uint64_t> parseAndVerify(SPIRVDialect const &dialect, Location loc, 316 StringRef spec) { 317 uint64_t offsetVal = std::numeric_limits<uint64_t>::max(); 318 if (!spec.consume_front("[")) { 319 emitError(loc, "expected '[' while parsing layout specification in '") 320 << spec << "'"; 321 return llvm::None; 322 } 323 spec = spec.trim(); 324 if (spec.consumeInteger(10, offsetVal)) { 325 emitError(loc, "expected unsigned integer to specify layout information: '") 326 << spec << "'"; 327 return llvm::None; 328 } 329 spec = spec.trim(); 330 if (!spec.consume_front("]")) { 331 emitError(loc, "missing ']' in decorations spec: '") << spec << "'"; 332 return llvm::None; 333 } 334 if (spec != "") { 335 emitError(loc, "unexpected extra tokens in layout information: '") 336 << spec << "'"; 337 return llvm::None; 338 } 339 return offsetVal; 340 } 341 342 // Functor object to parse a comma separated list of specs. The function 343 // parseAndVerify does the actual parsing and verification of individual 344 // elements. This is a functor since parsing the last element of the list 345 // (termination condition) needs partial specialization. 346 template <typename ParseType, typename... Args> struct parseCommaSeparatedList { 347 Optional<std::tuple<ParseType, Args...>> 348 operator()(SPIRVDialect const &dialect, Location loc, StringRef spec) const { 349 auto numArgs = std::tuple_size<std::tuple<Args...>>::value; 350 StringRef parseSpec, restSpec; 351 std::tie(parseSpec, restSpec) = spec.split(','); 352 353 parseSpec = parseSpec.trim(); 354 if (numArgs != 0 && restSpec.empty()) { 355 emitError(loc, "expected more parameters for image type '") 356 << parseSpec << "'"; 357 return llvm::None; 358 } 359 360 auto parseVal = parseAndVerify<ParseType>(dialect, loc, parseSpec); 361 if (!parseVal) { 362 return llvm::None; 363 } 364 365 auto remainingValues = 366 parseCommaSeparatedList<Args...>{}(dialect, loc, restSpec); 367 if (!remainingValues) { 368 return llvm::None; 369 } 370 return std::tuple_cat(std::tuple<ParseType>(parseVal.getValue()), 371 remainingValues.getValue()); 372 } 373 }; 374 375 // Partial specialization of the function to parse a comma separated list of 376 // specs to parse the last element of the list. 377 template <typename ParseType> struct parseCommaSeparatedList<ParseType> { 378 Optional<std::tuple<ParseType>> 379 operator()(SPIRVDialect const &dialect, Location loc, StringRef spec) const { 380 spec = spec.trim(); 381 auto value = parseAndVerify<ParseType>(dialect, loc, spec); 382 if (!value) { 383 return llvm::None; 384 } 385 return std::tuple<ParseType>(value.getValue()); 386 } 387 }; 388 389 // dim ::= `1D` | `2D` | `3D` | `Cube` | <and other SPIR-V Dim specifiers...> 390 // 391 // depth-info ::= `NoDepth` | `IsDepth` | `DepthUnknown` 392 // 393 // arrayed-info ::= `NonArrayed` | `Arrayed` 394 // 395 // sampling-info ::= `SingleSampled` | `MultiSampled` 396 // 397 // sampler-use-info ::= `SamplerUnknown` | `NeedSampler` | `NoSampler` 398 // 399 // format ::= `Unknown` | `Rgba32f` | <and other SPIR-V Image formats...> 400 // 401 // image-type ::= `!spv.image<` element-type `,` dim `,` depth-info `,` 402 // arrayed-info `,` sampling-info `,` 403 // sampler-use-info `,` format `>` 404 static Type parseImageType(SPIRVDialect const &dialect, StringRef spec, 405 Location loc) { 406 if (!spec.consume_front("image<") || !spec.consume_back(">")) { 407 emitError(loc, "spv.image delimiter <...> mismatch"); 408 return Type(); 409 } 410 411 auto value = 412 parseCommaSeparatedList<Type, Dim, ImageDepthInfo, ImageArrayedInfo, 413 ImageSamplingInfo, ImageSamplerUseInfo, 414 ImageFormat>{}(dialect, loc, spec); 415 if (!value) { 416 return Type(); 417 } 418 419 return ImageType::get(value.getValue()); 420 } 421 422 // Method to parse one member of a struct (including Layout information) 423 static ParseResult 424 parseStructElement(SPIRVDialect const &dialect, StringRef spec, Location loc, 425 SmallVectorImpl<Type> &memberTypes, 426 SmallVectorImpl<StructType::LayoutInfo> &layoutInfo) { 427 // Check for a '[' <layoutInfo> ']' 428 auto lastLSquare = spec.rfind('['); 429 auto typeSpec = spec.substr(0, lastLSquare); 430 auto layoutSpec = (lastLSquare == StringRef::npos ? StringRef("") 431 : spec.substr(lastLSquare)); 432 auto type = parseAndVerify<Type>(dialect, loc, typeSpec); 433 if (!type) { 434 return failure(); 435 } 436 memberTypes.push_back(type.getValue()); 437 if (layoutSpec.empty()) { 438 return success(); 439 } 440 if (layoutInfo.size() != memberTypes.size() - 1) { 441 emitError(loc, "layout specification must be given for all members"); 442 return failure(); 443 } 444 auto layout = 445 parseAndVerify<StructType::LayoutInfo>(dialect, loc, layoutSpec); 446 if (!layout) { 447 return failure(); 448 } 449 layoutInfo.push_back(layout.getValue()); 450 return success(); 451 } 452 453 // Helper method to record the position of the corresponding '>' for every '<' 454 // encountered when parsing the string left to right. The relative position of 455 // '>' w.r.t to the '<' is recorded. 456 static bool 457 computeMatchingRAngles(Location loc, StringRef const &spec, 458 SmallVectorImpl<size_t> &matchingRAngleOffset) { 459 SmallVector<size_t, 4> openBrackets; 460 for (size_t i = 0, e = spec.size(); i != e; ++i) { 461 if (spec[i] == '<') { 462 openBrackets.push_back(i); 463 } else if (spec[i] == '>') { 464 if (openBrackets.empty()) { 465 emitError(loc, "unbalanced '<' in '") << spec << "'"; 466 return false; 467 } 468 matchingRAngleOffset.push_back(i - openBrackets.pop_back_val()); 469 } 470 } 471 return true; 472 } 473 474 static ParseResult 475 parseStructHelper(SPIRVDialect const &dialect, StringRef spec, Location loc, 476 ArrayRef<size_t> matchingRAngleOffset, 477 SmallVectorImpl<Type> &memberTypes, 478 SmallVectorImpl<StructType::LayoutInfo> &layoutInfo) { 479 // Check if the occurrence of ',' or '<' is before. If former, split using 480 // ','. If latter, split using matching '>' to get the entire type 481 // description 482 auto firstComma = spec.find(','); 483 auto firstLAngle = spec.find('<'); 484 if (firstLAngle == StringRef::npos && firstComma == StringRef::npos) { 485 return parseStructElement(dialect, spec, loc, memberTypes, layoutInfo); 486 } 487 if (firstLAngle == StringRef::npos || firstComma < firstLAngle) { 488 // Parse the type before the ',' 489 if (parseStructElement(dialect, spec.substr(0, firstComma), loc, 490 memberTypes, layoutInfo)) { 491 return failure(); 492 } 493 return parseStructHelper(dialect, spec.substr(firstComma + 1).ltrim(), loc, 494 matchingRAngleOffset, memberTypes, layoutInfo); 495 } 496 auto matchingRAngle = matchingRAngleOffset.front() + firstLAngle; 497 // Find the next ',' or '>' 498 auto endLoc = std::min(spec.find(',', matchingRAngle + 1), spec.size()); 499 if (parseStructElement(dialect, spec.substr(0, endLoc), loc, memberTypes, 500 layoutInfo)) { 501 return failure(); 502 } 503 auto rest = spec.substr(endLoc + 1).ltrim(); 504 if (rest.empty()) { 505 return success(); 506 } 507 if (rest.front() == ',') { 508 return parseStructHelper( 509 dialect, rest.drop_front().trim(), loc, 510 ArrayRef<size_t>(std::next(matchingRAngleOffset.begin()), 511 matchingRAngleOffset.end()), 512 memberTypes, layoutInfo); 513 } 514 emitError(loc, "unexpected string : '") << rest << "'"; 515 return failure(); 516 } 517 518 // struct-type ::= `!spv.struct<` spirv-type (` [` integer-literal `]`)? 519 // (`, ` spirv-type ( ` [` integer-literal `] ` )? )* 520 static Type parseStructType(SPIRVDialect const &dialect, StringRef spec, 521 Location loc) { 522 if (!spec.consume_front("struct<") || !spec.consume_back(">")) { 523 emitError(loc, "spv.struct delimiter <...> mismatch"); 524 return Type(); 525 } 526 527 if (spec.trim().empty()) { 528 emitError(loc, "expected SPIR-V type"); 529 return Type(); 530 } 531 532 SmallVector<Type, 4> memberTypes; 533 SmallVector<StructType::LayoutInfo, 4> layoutInfo; 534 SmallVector<size_t, 4> matchingRAngleOffset; 535 if (!computeMatchingRAngles(loc, spec, matchingRAngleOffset) || 536 parseStructHelper(dialect, spec, loc, matchingRAngleOffset, memberTypes, 537 layoutInfo)) { 538 return Type(); 539 } 540 if (layoutInfo.empty()) { 541 return StructType::get(memberTypes); 542 } 543 if (memberTypes.size() != layoutInfo.size()) { 544 emitError(loc, "layout specification must be given for all members"); 545 return Type(); 546 } 547 return StructType::get(memberTypes, layoutInfo); 548 } 549 550 // spirv-type ::= array-type 551 // | element-type 552 // | image-type 553 // | pointer-type 554 // | runtime-array-type 555 // | struct-type 556 Type SPIRVDialect::parseType(StringRef spec, Location loc) const { 557 if (spec.startswith("array")) 558 return parseArrayType(*this, spec, loc); 559 if (spec.startswith("image")) 560 return parseImageType(*this, spec, loc); 561 if (spec.startswith("ptr")) 562 return parsePointerType(*this, spec, loc); 563 if (spec.startswith("rtarray")) 564 return parseRuntimeArrayType(*this, spec, loc); 565 if (spec.startswith("struct")) 566 return parseStructType(*this, spec, loc); 567 568 emitError(loc, "unknown SPIR-V type: ") << spec; 569 return Type(); 570 } 571 572 //===----------------------------------------------------------------------===// 573 // Type Printing 574 //===----------------------------------------------------------------------===// 575 576 static void print(ArrayType type, llvm::raw_ostream &os) { 577 os << "array<" << type.getNumElements() << " x " << type.getElementType(); 578 if (type.hasLayout()) { 579 os << " [" << type.getArrayStride() << "]"; 580 } 581 os << ">"; 582 } 583 584 static void print(RuntimeArrayType type, llvm::raw_ostream &os) { 585 os << "rtarray<" << type.getElementType() << ">"; 586 } 587 588 static void print(PointerType type, llvm::raw_ostream &os) { 589 os << "ptr<" << type.getPointeeType() << ", " 590 << stringifyStorageClass(type.getStorageClass()) << ">"; 591 } 592 593 static void print(ImageType type, llvm::raw_ostream &os) { 594 os << "image<" << type.getElementType() << ", " << stringifyDim(type.getDim()) 595 << ", " << stringifyImageDepthInfo(type.getDepthInfo()) << ", " 596 << stringifyImageArrayedInfo(type.getArrayedInfo()) << ", " 597 << stringifyImageSamplingInfo(type.getSamplingInfo()) << ", " 598 << stringifyImageSamplerUseInfo(type.getSamplerUseInfo()) << ", " 599 << stringifyImageFormat(type.getImageFormat()) << ">"; 600 } 601 602 static void print(StructType type, llvm::raw_ostream &os) { 603 os << "struct<"; 604 auto printMember = [&](unsigned i) { 605 os << type.getElementType(i); 606 if (type.hasLayout()) { 607 os << " [" << type.getOffset(i) << "]"; 608 } 609 }; 610 mlir::interleaveComma(llvm::seq<unsigned>(0, type.getNumElements()), os, 611 printMember); 612 os << ">"; 613 } 614 615 void SPIRVDialect::printType(Type type, llvm::raw_ostream &os) const { 616 switch (type.getKind()) { 617 case TypeKind::Array: 618 print(type.cast<ArrayType>(), os); 619 return; 620 case TypeKind::Pointer: 621 print(type.cast<PointerType>(), os); 622 return; 623 case TypeKind::RuntimeArray: 624 print(type.cast<RuntimeArrayType>(), os); 625 return; 626 case TypeKind::Image: 627 print(type.cast<ImageType>(), os); 628 return; 629 case TypeKind::Struct: 630 print(type.cast<StructType>(), os); 631 return; 632 default: 633 llvm_unreachable("unhandled SPIR-V type"); 634 } 635 }