github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp (about) 1 //===- LLVMDialect.cpp - LLVM IR Ops and Dialect registration -------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 // 18 // This file defines the types and operation details for the LLVM IR dialect in 19 // MLIR, and the LLVM IR dialect. It also registers the dialect. 20 // 21 //===----------------------------------------------------------------------===// 22 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 23 #include "mlir/IR/Builders.h" 24 #include "mlir/IR/MLIRContext.h" 25 #include "mlir/IR/Module.h" 26 #include "mlir/IR/StandardTypes.h" 27 28 #include "llvm/AsmParser/Parser.h" 29 #include "llvm/IR/Attributes.h" 30 #include "llvm/IR/Function.h" 31 #include "llvm/IR/Type.h" 32 #include "llvm/Support/Mutex.h" 33 #include "llvm/Support/SourceMgr.h" 34 35 using namespace mlir; 36 using namespace mlir::LLVM; 37 38 #include "mlir/Dialect/LLVMIR/LLVMOpsEnums.cpp.inc" 39 40 //===----------------------------------------------------------------------===// 41 // Printing/parsing for LLVM::CmpOp. 42 //===----------------------------------------------------------------------===// 43 static void printICmpOp(OpAsmPrinter *p, ICmpOp &op) { 44 *p << op.getOperationName() << " \"" << stringifyICmpPredicate(op.predicate()) 45 << "\" " << *op.getOperand(0) << ", " << *op.getOperand(1); 46 p->printOptionalAttrDict(op.getAttrs(), {"predicate"}); 47 *p << " : " << op.lhs()->getType(); 48 } 49 50 static void printFCmpOp(OpAsmPrinter *p, FCmpOp &op) { 51 *p << op.getOperationName() << " \"" << stringifyFCmpPredicate(op.predicate()) 52 << "\" " << *op.getOperand(0) << ", " << *op.getOperand(1); 53 p->printOptionalAttrDict(op.getAttrs(), {"predicate"}); 54 *p << " : " << op.lhs()->getType(); 55 } 56 57 // <operation> ::= `llvm.icmp` string-literal ssa-use `,` ssa-use 58 // attribute-dict? `:` type 59 // <operation> ::= `llvm.fcmp` string-literal ssa-use `,` ssa-use 60 // attribute-dict? `:` type 61 template <typename CmpPredicateType> 62 static ParseResult parseCmpOp(OpAsmParser *parser, OperationState *result) { 63 Builder &builder = parser->getBuilder(); 64 65 Attribute predicate; 66 SmallVector<NamedAttribute, 4> attrs; 67 OpAsmParser::OperandType lhs, rhs; 68 Type type; 69 llvm::SMLoc predicateLoc, trailingTypeLoc; 70 if (parser->getCurrentLocation(&predicateLoc) || 71 parser->parseAttribute(predicate, "predicate", attrs) || 72 parser->parseOperand(lhs) || parser->parseComma() || 73 parser->parseOperand(rhs) || parser->parseOptionalAttributeDict(attrs) || 74 parser->parseColon() || parser->getCurrentLocation(&trailingTypeLoc) || 75 parser->parseType(type) || 76 parser->resolveOperand(lhs, type, result->operands) || 77 parser->resolveOperand(rhs, type, result->operands)) 78 return failure(); 79 80 // Replace the string attribute `predicate` with an integer attribute. 81 auto predicateStr = predicate.dyn_cast<StringAttr>(); 82 if (!predicateStr) 83 return parser->emitError(predicateLoc, 84 "expected 'predicate' attribute of string type"); 85 86 int64_t predicateValue = 0; 87 if (std::is_same<CmpPredicateType, ICmpPredicate>()) { 88 Optional<ICmpPredicate> predicate = 89 symbolizeICmpPredicate(predicateStr.getValue()); 90 if (!predicate) 91 return parser->emitError(predicateLoc) 92 << "'" << predicateStr.getValue() 93 << "' is an incorrect value of the 'predicate' attribute"; 94 predicateValue = static_cast<int64_t>(predicate.getValue()); 95 } else { 96 Optional<FCmpPredicate> predicate = 97 symbolizeFCmpPredicate(predicateStr.getValue()); 98 if (!predicate) 99 return parser->emitError(predicateLoc) 100 << "'" << predicateStr.getValue() 101 << "' is an incorrect value of the 'predicate' attribute"; 102 predicateValue = static_cast<int64_t>(predicate.getValue()); 103 } 104 105 attrs[0].second = parser->getBuilder().getI64IntegerAttr(predicateValue); 106 107 // The result type is either i1 or a vector type <? x i1> if the inputs are 108 // vectors. 109 auto *dialect = builder.getContext()->getRegisteredDialect<LLVMDialect>(); 110 auto resultType = LLVMType::getInt1Ty(dialect); 111 auto argType = type.dyn_cast<LLVM::LLVMType>(); 112 if (!argType) 113 return parser->emitError(trailingTypeLoc, "expected LLVM IR dialect type"); 114 if (argType.getUnderlyingType()->isVectorTy()) 115 resultType = LLVMType::getVectorTy( 116 resultType, argType.getUnderlyingType()->getVectorNumElements()); 117 118 result->attributes = attrs; 119 result->addTypes({resultType}); 120 return success(); 121 } 122 123 //===----------------------------------------------------------------------===// 124 // Printing/parsing for LLVM::AllocaOp. 125 //===----------------------------------------------------------------------===// 126 127 static void printAllocaOp(OpAsmPrinter *p, AllocaOp &op) { 128 auto elemTy = op.getType().cast<LLVM::LLVMType>().getPointerElementTy(); 129 130 auto funcTy = FunctionType::get({op.arraySize()->getType()}, {op.getType()}, 131 op.getContext()); 132 133 *p << op.getOperationName() << ' ' << *op.arraySize() << " x " << elemTy; 134 if (op.alignment().hasValue() && op.alignment()->getSExtValue() != 0) 135 p->printOptionalAttrDict(op.getAttrs()); 136 else 137 p->printOptionalAttrDict(op.getAttrs(), {"alignment"}); 138 *p << " : " << funcTy; 139 } 140 141 // <operation> ::= `llvm.alloca` ssa-use `x` type attribute-dict? 142 // `:` type `,` type 143 static ParseResult parseAllocaOp(OpAsmParser *parser, OperationState *result) { 144 SmallVector<NamedAttribute, 4> attrs; 145 OpAsmParser::OperandType arraySize; 146 Type type, elemType; 147 llvm::SMLoc trailingTypeLoc; 148 if (parser->parseOperand(arraySize) || parser->parseKeyword("x") || 149 parser->parseType(elemType) || 150 parser->parseOptionalAttributeDict(attrs) || parser->parseColon() || 151 parser->getCurrentLocation(&trailingTypeLoc) || parser->parseType(type)) 152 return failure(); 153 154 // Extract the result type from the trailing function type. 155 auto funcType = type.dyn_cast<FunctionType>(); 156 if (!funcType || funcType.getNumInputs() != 1 || 157 funcType.getNumResults() != 1) 158 return parser->emitError( 159 trailingTypeLoc, 160 "expected trailing function type with one argument and one result"); 161 162 if (parser->resolveOperand(arraySize, funcType.getInput(0), result->operands)) 163 return failure(); 164 165 result->attributes = attrs; 166 result->addTypes({funcType.getResult(0)}); 167 return success(); 168 } 169 170 //===----------------------------------------------------------------------===// 171 // Printing/parsing for LLVM::GEPOp. 172 //===----------------------------------------------------------------------===// 173 174 static void printGEPOp(OpAsmPrinter *p, GEPOp &op) { 175 SmallVector<Type, 8> types(op.getOperandTypes()); 176 auto funcTy = FunctionType::get(types, op.getType(), op.getContext()); 177 178 *p << op.getOperationName() << ' ' << *op.base() << '['; 179 p->printOperands(std::next(op.operand_begin()), op.operand_end()); 180 *p << ']'; 181 p->printOptionalAttrDict(op.getAttrs()); 182 *p << " : " << funcTy; 183 } 184 185 // <operation> ::= `llvm.getelementptr` ssa-use `[` ssa-use-list `]` 186 // attribute-dict? `:` type 187 static ParseResult parseGEPOp(OpAsmParser *parser, OperationState *result) { 188 SmallVector<NamedAttribute, 4> attrs; 189 OpAsmParser::OperandType base; 190 SmallVector<OpAsmParser::OperandType, 8> indices; 191 Type type; 192 llvm::SMLoc trailingTypeLoc; 193 if (parser->parseOperand(base) || 194 parser->parseOperandList(indices, OpAsmParser::Delimiter::Square) || 195 parser->parseOptionalAttributeDict(attrs) || parser->parseColon() || 196 parser->getCurrentLocation(&trailingTypeLoc) || parser->parseType(type)) 197 return failure(); 198 199 // Deconstruct the trailing function type to extract the types of the base 200 // pointer and result (same type) and the types of the indices. 201 auto funcType = type.dyn_cast<FunctionType>(); 202 if (!funcType || funcType.getNumResults() != 1 || 203 funcType.getNumInputs() == 0) 204 return parser->emitError(trailingTypeLoc, 205 "expected trailing function type with at least " 206 "one argument and one result"); 207 208 if (parser->resolveOperand(base, funcType.getInput(0), result->operands) || 209 parser->resolveOperands(indices, funcType.getInputs().drop_front(), 210 parser->getNameLoc(), result->operands)) 211 return failure(); 212 213 result->attributes = attrs; 214 result->addTypes(funcType.getResults()); 215 return success(); 216 } 217 218 //===----------------------------------------------------------------------===// 219 // Printing/parsing for LLVM::LoadOp. 220 //===----------------------------------------------------------------------===// 221 222 static void printLoadOp(OpAsmPrinter *p, LoadOp &op) { 223 *p << op.getOperationName() << ' ' << *op.addr(); 224 p->printOptionalAttrDict(op.getAttrs()); 225 *p << " : " << op.addr()->getType(); 226 } 227 228 // Extract the pointee type from the LLVM pointer type wrapped in MLIR. Return 229 // the resulting type wrapped in MLIR, or nullptr on error. 230 static Type getLoadStoreElementType(OpAsmParser *parser, Type type, 231 llvm::SMLoc trailingTypeLoc) { 232 auto llvmTy = type.dyn_cast<LLVM::LLVMType>(); 233 if (!llvmTy) 234 return parser->emitError(trailingTypeLoc, "expected LLVM IR dialect type"), 235 nullptr; 236 if (!llvmTy.getUnderlyingType()->isPointerTy()) 237 return parser->emitError(trailingTypeLoc, "expected LLVM pointer type"), 238 nullptr; 239 return llvmTy.getPointerElementTy(); 240 } 241 242 // <operation> ::= `llvm.load` ssa-use attribute-dict? `:` type 243 static ParseResult parseLoadOp(OpAsmParser *parser, OperationState *result) { 244 SmallVector<NamedAttribute, 4> attrs; 245 OpAsmParser::OperandType addr; 246 Type type; 247 llvm::SMLoc trailingTypeLoc; 248 249 if (parser->parseOperand(addr) || parser->parseOptionalAttributeDict(attrs) || 250 parser->parseColon() || parser->getCurrentLocation(&trailingTypeLoc) || 251 parser->parseType(type) || 252 parser->resolveOperand(addr, type, result->operands)) 253 return failure(); 254 255 Type elemTy = getLoadStoreElementType(parser, type, trailingTypeLoc); 256 257 result->attributes = attrs; 258 result->addTypes(elemTy); 259 return success(); 260 } 261 262 //===----------------------------------------------------------------------===// 263 // Printing/parsing for LLVM::StoreOp. 264 //===----------------------------------------------------------------------===// 265 266 static void printStoreOp(OpAsmPrinter *p, StoreOp &op) { 267 *p << op.getOperationName() << ' ' << *op.value() << ", " << *op.addr(); 268 p->printOptionalAttrDict(op.getAttrs()); 269 *p << " : " << op.addr()->getType(); 270 } 271 272 // <operation> ::= `llvm.store` ssa-use `,` ssa-use attribute-dict? `:` type 273 static ParseResult parseStoreOp(OpAsmParser *parser, OperationState *result) { 274 SmallVector<NamedAttribute, 4> attrs; 275 OpAsmParser::OperandType addr, value; 276 Type type; 277 llvm::SMLoc trailingTypeLoc; 278 279 if (parser->parseOperand(value) || parser->parseComma() || 280 parser->parseOperand(addr) || parser->parseOptionalAttributeDict(attrs) || 281 parser->parseColon() || parser->getCurrentLocation(&trailingTypeLoc) || 282 parser->parseType(type)) 283 return failure(); 284 285 Type elemTy = getLoadStoreElementType(parser, type, trailingTypeLoc); 286 if (!elemTy) 287 return failure(); 288 289 if (parser->resolveOperand(value, elemTy, result->operands) || 290 parser->resolveOperand(addr, type, result->operands)) 291 return failure(); 292 293 result->attributes = attrs; 294 return success(); 295 } 296 297 //===----------------------------------------------------------------------===// 298 // Printing/parsing for LLVM::CallOp. 299 //===----------------------------------------------------------------------===// 300 301 static void printCallOp(OpAsmPrinter *p, CallOp &op) { 302 auto callee = op.callee(); 303 bool isDirect = callee.hasValue(); 304 305 // Print the direct callee if present as a function attribute, or an indirect 306 // callee (first operand) otherwise. 307 *p << op.getOperationName() << ' '; 308 if (isDirect) 309 *p << '@' << callee.getValue(); 310 else 311 *p << *op.getOperand(0); 312 313 *p << '('; 314 p->printOperands(llvm::drop_begin(op.getOperands(), isDirect ? 0 : 1)); 315 *p << ')'; 316 317 p->printOptionalAttrDict(op.getAttrs(), {"callee"}); 318 319 // Reconstruct the function MLIR function type from operand and result types. 320 SmallVector<Type, 1> resultTypes(op.getResultTypes()); 321 SmallVector<Type, 8> argTypes( 322 llvm::drop_begin(op.getOperandTypes(), isDirect ? 0 : 1)); 323 324 *p << " : " << FunctionType::get(argTypes, resultTypes, op.getContext()); 325 } 326 327 // <operation> ::= `llvm.call` (function-id | ssa-use) `(` ssa-use-list `)` 328 // attribute-dict? `:` function-type 329 static ParseResult parseCallOp(OpAsmParser *parser, OperationState *result) { 330 SmallVector<NamedAttribute, 4> attrs; 331 SmallVector<OpAsmParser::OperandType, 8> operands; 332 Type type; 333 SymbolRefAttr funcAttr; 334 llvm::SMLoc trailingTypeLoc; 335 336 // Parse an operand list that will, in practice, contain 0 or 1 operand. In 337 // case of an indirect call, there will be 1 operand before `(`. In case of a 338 // direct call, there will be no operands and the parser will stop at the 339 // function identifier without complaining. 340 if (parser->parseOperandList(operands)) 341 return failure(); 342 bool isDirect = operands.empty(); 343 344 // Optionally parse a function identifier. 345 if (isDirect) 346 if (parser->parseAttribute(funcAttr, "callee", attrs)) 347 return failure(); 348 349 if (parser->parseOperandList(operands, OpAsmParser::Delimiter::Paren) || 350 parser->parseOptionalAttributeDict(attrs) || parser->parseColon() || 351 parser->getCurrentLocation(&trailingTypeLoc) || parser->parseType(type)) 352 return failure(); 353 354 auto funcType = type.dyn_cast<FunctionType>(); 355 if (!funcType) 356 return parser->emitError(trailingTypeLoc, "expected function type"); 357 if (isDirect) { 358 // Make sure types match. 359 if (parser->resolveOperands(operands, funcType.getInputs(), 360 parser->getNameLoc(), result->operands)) 361 return failure(); 362 result->addTypes(funcType.getResults()); 363 } else { 364 // Construct the LLVM IR Dialect function type that the first operand 365 // should match. 366 if (funcType.getNumResults() > 1) 367 return parser->emitError(trailingTypeLoc, 368 "expected function with 0 or 1 result"); 369 370 Builder &builder = parser->getBuilder(); 371 auto *llvmDialect = 372 builder.getContext()->getRegisteredDialect<LLVM::LLVMDialect>(); 373 LLVM::LLVMType llvmResultType; 374 if (funcType.getNumResults() == 0) { 375 llvmResultType = LLVM::LLVMType::getVoidTy(llvmDialect); 376 } else { 377 llvmResultType = funcType.getResult(0).dyn_cast<LLVM::LLVMType>(); 378 if (!llvmResultType) 379 return parser->emitError(trailingTypeLoc, 380 "expected result to have LLVM type"); 381 } 382 383 SmallVector<LLVM::LLVMType, 8> argTypes; 384 argTypes.reserve(funcType.getNumInputs()); 385 for (int i = 0, e = funcType.getNumInputs(); i < e; ++i) { 386 auto argType = funcType.getInput(i).dyn_cast<LLVM::LLVMType>(); 387 if (!argType) 388 return parser->emitError(trailingTypeLoc, 389 "expected LLVM types as inputs"); 390 argTypes.push_back(argType); 391 } 392 auto llvmFuncType = LLVM::LLVMType::getFunctionTy(llvmResultType, argTypes, 393 /*isVarArg=*/false); 394 auto wrappedFuncType = llvmFuncType.getPointerTo(); 395 396 auto funcArguments = 397 ArrayRef<OpAsmParser::OperandType>(operands).drop_front(); 398 399 // Make sure that the first operand (indirect callee) matches the wrapped 400 // LLVM IR function type, and that the types of the other call operands 401 // match the types of the function arguments. 402 if (parser->resolveOperand(operands[0], wrappedFuncType, 403 result->operands) || 404 parser->resolveOperands(funcArguments, funcType.getInputs(), 405 parser->getNameLoc(), result->operands)) 406 return failure(); 407 408 result->addTypes(llvmResultType); 409 } 410 411 result->attributes = attrs; 412 return success(); 413 } 414 415 //===----------------------------------------------------------------------===// 416 // Printing/parsing for LLVM::ExtractElementOp. 417 //===----------------------------------------------------------------------===// 418 // Expects vector to be of wrapped LLVM vector type and position to be of 419 // wrapped LLVM i32 type. 420 void LLVM::ExtractElementOp::build(Builder *b, OperationState *result, 421 Value *vector, Value *position, 422 ArrayRef<NamedAttribute> attrs) { 423 auto wrappedVectorType = vector->getType().cast<LLVM::LLVMType>(); 424 auto llvmType = wrappedVectorType.getVectorElementType(); 425 build(b, result, llvmType, vector, position); 426 result->addAttributes(attrs); 427 } 428 429 static void printExtractElementOp(OpAsmPrinter *p, ExtractElementOp &op) { 430 *p << op.getOperationName() << ' ' << *op.vector() << ", " << *op.position(); 431 p->printOptionalAttrDict(op.getAttrs()); 432 *p << " : " << op.vector()->getType(); 433 } 434 435 // <operation> ::= `llvm.extractelement` ssa-use `, ` ssa-use 436 // attribute-dict? `:` type 437 static ParseResult parseExtractElementOp(OpAsmParser *parser, 438 OperationState *result) { 439 llvm::SMLoc loc; 440 OpAsmParser::OperandType vector, position; 441 auto *llvmDialect = parser->getBuilder() 442 .getContext() 443 ->getRegisteredDialect<LLVM::LLVMDialect>(); 444 Type type, i32Type = LLVMType::getInt32Ty(llvmDialect); 445 if (parser->getCurrentLocation(&loc) || parser->parseOperand(vector) || 446 parser->parseComma() || parser->parseOperand(position) || 447 parser->parseOptionalAttributeDict(result->attributes) || 448 parser->parseColonType(type) || 449 parser->resolveOperand(vector, type, result->operands) || 450 parser->resolveOperand(position, i32Type, result->operands)) 451 return failure(); 452 auto wrappedVectorType = type.dyn_cast<LLVM::LLVMType>(); 453 if (!wrappedVectorType || 454 !wrappedVectorType.getUnderlyingType()->isVectorTy()) 455 return parser->emitError( 456 loc, "expected LLVM IR dialect vector type for operand #1"); 457 result->addTypes(wrappedVectorType.getVectorElementType()); 458 return success(); 459 } 460 461 //===----------------------------------------------------------------------===// 462 // Printing/parsing for LLVM::ExtractValueOp. 463 //===----------------------------------------------------------------------===// 464 465 static void printExtractValueOp(OpAsmPrinter *p, ExtractValueOp &op) { 466 *p << op.getOperationName() << ' ' << *op.container() << op.position(); 467 p->printOptionalAttrDict(op.getAttrs(), {"position"}); 468 *p << " : " << op.container()->getType(); 469 } 470 471 // Extract the type at `position` in the wrapped LLVM IR aggregate type 472 // `containerType`. Position is an integer array attribute where each value 473 // is a zero-based position of the element in the aggregate type. Return the 474 // resulting type wrapped in MLIR, or nullptr on error. 475 static LLVM::LLVMType getInsertExtractValueElementType(OpAsmParser *parser, 476 Type containerType, 477 Attribute positionAttr, 478 llvm::SMLoc attributeLoc, 479 llvm::SMLoc typeLoc) { 480 auto wrappedContainerType = containerType.dyn_cast<LLVM::LLVMType>(); 481 if (!wrappedContainerType) 482 return parser->emitError(typeLoc, "expected LLVM IR Dialect type"), nullptr; 483 484 auto positionArrayAttr = positionAttr.dyn_cast<ArrayAttr>(); 485 if (!positionArrayAttr) 486 return parser->emitError(attributeLoc, "expected an array attribute"), 487 nullptr; 488 489 // Infer the element type from the structure type: iteratively step inside the 490 // type by taking the element type, indexed by the position attribute for 491 // stuctures. Check the position index before accessing, it is supposed to be 492 // in bounds. 493 for (Attribute subAttr : positionArrayAttr) { 494 auto positionElementAttr = subAttr.dyn_cast<IntegerAttr>(); 495 if (!positionElementAttr) 496 return parser->emitError(attributeLoc, 497 "expected an array of integer literals"), 498 nullptr; 499 int position = positionElementAttr.getInt(); 500 auto *llvmContainerType = wrappedContainerType.getUnderlyingType(); 501 if (llvmContainerType->isArrayTy()) { 502 if (position < 0 || static_cast<unsigned>(position) >= 503 llvmContainerType->getArrayNumElements()) 504 return parser->emitError(attributeLoc, "position out of bounds"), 505 nullptr; 506 wrappedContainerType = wrappedContainerType.getArrayElementType(); 507 } else if (llvmContainerType->isStructTy()) { 508 if (position < 0 || static_cast<unsigned>(position) >= 509 llvmContainerType->getStructNumElements()) 510 return parser->emitError(attributeLoc, "position out of bounds"), 511 nullptr; 512 wrappedContainerType = 513 wrappedContainerType.getStructElementType(position); 514 } else { 515 return parser->emitError(typeLoc, 516 "expected wrapped LLVM IR structure/array type"), 517 nullptr; 518 } 519 } 520 return wrappedContainerType; 521 } 522 523 // <operation> ::= `llvm.extractvalue` ssa-use 524 // `[` integer-literal (`,` integer-literal)* `]` 525 // attribute-dict? `:` type 526 static ParseResult parseExtractValueOp(OpAsmParser *parser, 527 OperationState *result) { 528 SmallVector<NamedAttribute, 4> attrs; 529 OpAsmParser::OperandType container; 530 Type containerType; 531 Attribute positionAttr; 532 llvm::SMLoc attributeLoc, trailingTypeLoc; 533 534 if (parser->parseOperand(container) || 535 parser->getCurrentLocation(&attributeLoc) || 536 parser->parseAttribute(positionAttr, "position", attrs) || 537 parser->parseOptionalAttributeDict(attrs) || parser->parseColon() || 538 parser->getCurrentLocation(&trailingTypeLoc) || 539 parser->parseType(containerType) || 540 parser->resolveOperand(container, containerType, result->operands)) 541 return failure(); 542 543 auto elementType = getInsertExtractValueElementType( 544 parser, containerType, positionAttr, attributeLoc, trailingTypeLoc); 545 if (!elementType) 546 return failure(); 547 548 result->attributes = attrs; 549 result->addTypes(elementType); 550 return success(); 551 } 552 553 //===----------------------------------------------------------------------===// 554 // Printing/parsing for LLVM::InsertElementOp. 555 //===----------------------------------------------------------------------===// 556 557 static void printInsertElementOp(OpAsmPrinter *p, InsertElementOp &op) { 558 *p << op.getOperationName() << ' ' << *op.vector() << ", " << *op.value() 559 << ", " << *op.position(); 560 p->printOptionalAttrDict(op.getAttrs()); 561 *p << " : " << op.vector()->getType(); 562 } 563 564 // <operation> ::= `llvm.insertelement` ssa-use `,` ssa-use `,` ssa-use 565 // attribute-dict? `:` type 566 static ParseResult parseInsertElementOp(OpAsmParser *parser, 567 OperationState *result) { 568 llvm::SMLoc loc; 569 OpAsmParser::OperandType vector, value, position; 570 auto *llvmDialect = parser->getBuilder() 571 .getContext() 572 ->getRegisteredDialect<LLVM::LLVMDialect>(); 573 Type vectorType, i32Type = LLVMType::getInt32Ty(llvmDialect); 574 if (parser->getCurrentLocation(&loc) || parser->parseOperand(vector) || 575 parser->parseComma() || parser->parseOperand(value) || 576 parser->parseComma() || parser->parseOperand(position) || 577 parser->parseOptionalAttributeDict(result->attributes) || 578 parser->parseColonType(vectorType)) 579 return failure(); 580 581 auto wrappedVectorType = vectorType.dyn_cast<LLVM::LLVMType>(); 582 if (!wrappedVectorType || 583 !wrappedVectorType.getUnderlyingType()->isVectorTy()) 584 return parser->emitError( 585 loc, "expected LLVM IR dialect vector type for operand #1"); 586 auto valueType = wrappedVectorType.getVectorElementType(); 587 if (!valueType) 588 return failure(); 589 590 if (parser->resolveOperand(vector, vectorType, result->operands) || 591 parser->resolveOperand(value, valueType, result->operands) || 592 parser->resolveOperand(position, i32Type, result->operands)) 593 return failure(); 594 595 result->addTypes(vectorType); 596 return success(); 597 } 598 599 //===----------------------------------------------------------------------===// 600 // Printing/parsing for LLVM::InsertValueOp. 601 //===----------------------------------------------------------------------===// 602 603 static void printInsertValueOp(OpAsmPrinter *p, InsertValueOp &op) { 604 *p << op.getOperationName() << ' ' << *op.value() << ", " << *op.container() 605 << op.position(); 606 p->printOptionalAttrDict(op.getAttrs(), {"position"}); 607 *p << " : " << op.container()->getType(); 608 } 609 610 // <operation> ::= `llvm.insertvaluevalue` ssa-use `,` ssa-use 611 // `[` integer-literal (`,` integer-literal)* `]` 612 // attribute-dict? `:` type 613 static ParseResult parseInsertValueOp(OpAsmParser *parser, 614 OperationState *result) { 615 OpAsmParser::OperandType container, value; 616 Type containerType; 617 Attribute positionAttr; 618 llvm::SMLoc attributeLoc, trailingTypeLoc; 619 620 if (parser->parseOperand(value) || parser->parseComma() || 621 parser->parseOperand(container) || 622 parser->getCurrentLocation(&attributeLoc) || 623 parser->parseAttribute(positionAttr, "position", result->attributes) || 624 parser->parseOptionalAttributeDict(result->attributes) || 625 parser->parseColon() || parser->getCurrentLocation(&trailingTypeLoc) || 626 parser->parseType(containerType)) 627 return failure(); 628 629 auto valueType = getInsertExtractValueElementType( 630 parser, containerType, positionAttr, attributeLoc, trailingTypeLoc); 631 if (!valueType) 632 return failure(); 633 634 if (parser->resolveOperand(container, containerType, result->operands) || 635 parser->resolveOperand(value, valueType, result->operands)) 636 return failure(); 637 638 result->addTypes(containerType); 639 return success(); 640 } 641 642 //===----------------------------------------------------------------------===// 643 // Printing/parsing for LLVM::SelectOp. 644 //===----------------------------------------------------------------------===// 645 646 static void printSelectOp(OpAsmPrinter *p, SelectOp &op) { 647 *p << op.getOperationName() << ' ' << *op.condition() << ", " 648 << *op.trueValue() << ", " << *op.falseValue(); 649 p->printOptionalAttrDict(op.getAttrs()); 650 *p << " : " << op.condition()->getType() << ", " << op.trueValue()->getType(); 651 } 652 653 // <operation> ::= `llvm.select` ssa-use `,` ssa-use `,` ssa-use 654 // attribute-dict? `:` type, type 655 static ParseResult parseSelectOp(OpAsmParser *parser, OperationState *result) { 656 OpAsmParser::OperandType condition, trueValue, falseValue; 657 Type conditionType, argType; 658 659 if (parser->parseOperand(condition) || parser->parseComma() || 660 parser->parseOperand(trueValue) || parser->parseComma() || 661 parser->parseOperand(falseValue) || 662 parser->parseOptionalAttributeDict(result->attributes) || 663 parser->parseColonType(conditionType) || parser->parseComma() || 664 parser->parseType(argType)) 665 return failure(); 666 667 if (parser->resolveOperand(condition, conditionType, result->operands) || 668 parser->resolveOperand(trueValue, argType, result->operands) || 669 parser->resolveOperand(falseValue, argType, result->operands)) 670 return failure(); 671 672 result->addTypes(argType); 673 return success(); 674 } 675 676 //===----------------------------------------------------------------------===// 677 // Printing/parsing for LLVM::BrOp. 678 //===----------------------------------------------------------------------===// 679 680 static void printBrOp(OpAsmPrinter *p, BrOp &op) { 681 *p << op.getOperationName() << ' '; 682 p->printSuccessorAndUseList(op.getOperation(), 0); 683 p->printOptionalAttrDict(op.getAttrs()); 684 } 685 686 // <operation> ::= `llvm.br` bb-id (`[` ssa-use-and-type-list `]`)? 687 // attribute-dict? 688 static ParseResult parseBrOp(OpAsmParser *parser, OperationState *result) { 689 Block *dest; 690 SmallVector<Value *, 4> operands; 691 if (parser->parseSuccessorAndUseList(dest, operands) || 692 parser->parseOptionalAttributeDict(result->attributes)) 693 return failure(); 694 695 result->addSuccessor(dest, operands); 696 return success(); 697 } 698 699 //===----------------------------------------------------------------------===// 700 // Printing/parsing for LLVM::CondBrOp. 701 //===----------------------------------------------------------------------===// 702 703 static void printCondBrOp(OpAsmPrinter *p, CondBrOp &op) { 704 *p << op.getOperationName() << ' ' << *op.getOperand(0) << ", "; 705 p->printSuccessorAndUseList(op.getOperation(), 0); 706 *p << ", "; 707 p->printSuccessorAndUseList(op.getOperation(), 1); 708 p->printOptionalAttrDict(op.getAttrs()); 709 } 710 711 // <operation> ::= `llvm.cond_br` ssa-use `,` 712 // bb-id (`[` ssa-use-and-type-list `]`)? `,` 713 // bb-id (`[` ssa-use-and-type-list `]`)? attribute-dict? 714 static ParseResult parseCondBrOp(OpAsmParser *parser, OperationState *result) { 715 Block *trueDest; 716 Block *falseDest; 717 SmallVector<Value *, 4> trueOperands; 718 SmallVector<Value *, 4> falseOperands; 719 OpAsmParser::OperandType condition; 720 721 Builder &builder = parser->getBuilder(); 722 auto *llvmDialect = 723 builder.getContext()->getRegisteredDialect<LLVM::LLVMDialect>(); 724 auto i1Type = LLVM::LLVMType::getInt1Ty(llvmDialect); 725 726 if (parser->parseOperand(condition) || parser->parseComma() || 727 parser->parseSuccessorAndUseList(trueDest, trueOperands) || 728 parser->parseComma() || 729 parser->parseSuccessorAndUseList(falseDest, falseOperands) || 730 parser->parseOptionalAttributeDict(result->attributes) || 731 parser->resolveOperand(condition, i1Type, result->operands)) 732 return failure(); 733 734 result->addSuccessor(trueDest, trueOperands); 735 result->addSuccessor(falseDest, falseOperands); 736 return success(); 737 } 738 739 //===----------------------------------------------------------------------===// 740 // Printing/parsing for LLVM::ReturnOp. 741 //===----------------------------------------------------------------------===// 742 743 static void printReturnOp(OpAsmPrinter *p, ReturnOp &op) { 744 *p << op.getOperationName(); 745 p->printOptionalAttrDict(op.getAttrs()); 746 assert(op.getNumOperands() <= 1); 747 748 if (op.getNumOperands() == 0) 749 return; 750 751 *p << ' ' << *op.getOperand(0) << " : " << op.getOperand(0)->getType(); 752 } 753 754 // <operation> ::= `llvm.return` ssa-use-list attribute-dict? `:` 755 // type-list-no-parens 756 static ParseResult parseReturnOp(OpAsmParser *parser, OperationState *result) { 757 SmallVector<OpAsmParser::OperandType, 1> operands; 758 Type type; 759 760 if (parser->parseOperandList(operands) || 761 parser->parseOptionalAttributeDict(result->attributes)) 762 return failure(); 763 if (operands.empty()) 764 return success(); 765 766 if (parser->parseColonType(type) || 767 parser->resolveOperand(operands[0], type, result->operands)) 768 return failure(); 769 return success(); 770 } 771 772 //===----------------------------------------------------------------------===// 773 // Printing/parsing for LLVM::UndefOp. 774 //===----------------------------------------------------------------------===// 775 776 static void printUndefOp(OpAsmPrinter *p, UndefOp &op) { 777 *p << op.getOperationName(); 778 p->printOptionalAttrDict(op.getAttrs()); 779 *p << " : " << op.res()->getType(); 780 } 781 782 // <operation> ::= `llvm.undef` attribute-dict? : type 783 static ParseResult parseUndefOp(OpAsmParser *parser, OperationState *result) { 784 Type type; 785 786 if (parser->parseOptionalAttributeDict(result->attributes) || 787 parser->parseColonType(type)) 788 return failure(); 789 790 result->addTypes(type); 791 return success(); 792 } 793 794 //===----------------------------------------------------------------------===// 795 // Printer, parser and verifier for LLVM::AddressOfOp. 796 //===----------------------------------------------------------------------===// 797 798 GlobalOp AddressOfOp::getGlobal() { 799 auto module = getParentOfType<ModuleOp>(); 800 assert(module && "unexpected operation outside of a module"); 801 return module.lookupSymbol<LLVM::GlobalOp>(global_name()); 802 } 803 804 static void printAddressOfOp(OpAsmPrinter *p, AddressOfOp op) { 805 *p << op.getOperationName() << " @" << op.global_name(); 806 p->printOptionalAttrDict(op.getAttrs(), {"global_name"}); 807 *p << " : " << op.getResult()->getType(); 808 } 809 810 static ParseResult parseAddressOfOp(OpAsmParser *parser, 811 OperationState *result) { 812 Attribute symRef; 813 Type type; 814 if (parser->parseAttribute(symRef, "global_name", result->attributes) || 815 parser->parseOptionalAttributeDict(result->attributes) || 816 parser->parseColonType(type) || 817 parser->addTypeToList(type, result->types)) 818 return failure(); 819 820 if (!symRef.isa<SymbolRefAttr>()) 821 return parser->emitError(parser->getNameLoc(), "expected symbol reference"); 822 return success(); 823 } 824 825 static LogicalResult verify(AddressOfOp op) { 826 auto global = op.getGlobal(); 827 if (!global) 828 return op.emitOpError("must reference a global defined by 'llvm.global'"); 829 830 if (global.getType().getPointerTo() != op.getResult()->getType()) 831 return op.emitOpError( 832 "the type must be a pointer to the type of the referred global"); 833 834 return success(); 835 } 836 837 //===----------------------------------------------------------------------===// 838 // Printing/parsing for LLVM::ConstantOp. 839 //===----------------------------------------------------------------------===// 840 841 static void printConstantOp(OpAsmPrinter *p, ConstantOp &op) { 842 *p << op.getOperationName() << '(' << op.value() << ')'; 843 p->printOptionalAttrDict(op.getAttrs(), {"value"}); 844 *p << " : " << op.res()->getType(); 845 } 846 847 // <operation> ::= `llvm.constant` `(` attribute `)` attribute-list? : type 848 static ParseResult parseConstantOp(OpAsmParser *parser, 849 OperationState *result) { 850 Attribute valueAttr; 851 Type type; 852 853 if (parser->parseLParen() || 854 parser->parseAttribute(valueAttr, "value", result->attributes) || 855 parser->parseRParen() || 856 parser->parseOptionalAttributeDict(result->attributes) || 857 parser->parseColonType(type)) 858 return failure(); 859 860 result->addTypes(type); 861 return success(); 862 } 863 864 //===----------------------------------------------------------------------===// 865 // Builder, printer and verifier for LLVM::GlobalOp. 866 //===----------------------------------------------------------------------===// 867 868 void GlobalOp::build(Builder *builder, OperationState *result, LLVMType type, 869 bool isConstant, StringRef name, Attribute value, 870 ArrayRef<NamedAttribute> attrs) { 871 result->addAttribute(SymbolTable::getSymbolAttrName(), 872 builder->getStringAttr(name)); 873 result->addAttribute("type", builder->getTypeAttr(type)); 874 if (isConstant) 875 result->addAttribute("constant", builder->getUnitAttr()); 876 result->addAttribute("value", value); 877 result->attributes.append(attrs.begin(), attrs.end()); 878 } 879 880 static void printGlobalOp(OpAsmPrinter *p, GlobalOp op) { 881 *p << op.getOperationName() << ' '; 882 if (op.constant()) 883 *p << "constant "; 884 *p << '@' << op.sym_name() << '('; 885 p->printAttribute(op.value()); 886 *p << ')'; 887 p->printOptionalAttrDict(op.getAttrs(), {SymbolTable::getSymbolAttrName(), 888 "type", "constant", "value"}); 889 890 // Print the trailing type unless it's a string global. 891 if (op.value().isa<StringAttr>()) 892 return; 893 *p << " : "; 894 p->printType(op.type()); 895 } 896 897 // <operation> ::= `llvm.global` `constant`? `@` identifier `(` attribute `)` 898 // attribute-list? (`:` type)? 899 // 900 // The type can be omitted for string attributes, in which case it will be 901 // inferred from the value of the string as [strlen(value) x i8]. 902 static ParseResult parseGlobalOp(OpAsmParser *parser, OperationState *result) { 903 if (succeeded(parser->parseOptionalKeyword("constant"))) 904 result->addAttribute("constant", parser->getBuilder().getUnitAttr()); 905 906 Attribute value; 907 StringAttr name; 908 SmallVector<Type, 1> types; 909 if (parser->parseSymbolName(name, SymbolTable::getSymbolAttrName(), 910 result->attributes) || 911 parser->parseLParen() || 912 parser->parseAttribute(value, "value", result->attributes) || 913 parser->parseRParen() || 914 parser->parseOptionalAttributeDict(result->attributes) || 915 parser->parseOptionalColonTypeList(types)) 916 return failure(); 917 918 if (types.size() > 1) 919 return parser->emitError(parser->getNameLoc(), "expected zero or one type"); 920 921 if (types.empty()) { 922 if (auto strAttr = value.dyn_cast<StringAttr>()) { 923 MLIRContext *context = parser->getBuilder().getContext(); 924 auto *dialect = context->getRegisteredDialect<LLVMDialect>(); 925 auto arrayType = LLVM::LLVMType::getArrayTy( 926 LLVM::LLVMType::getInt8Ty(dialect), strAttr.getValue().size()); 927 types.push_back(arrayType); 928 } else { 929 return parser->emitError(parser->getNameLoc(), 930 "type can only be omitted for string globals"); 931 } 932 } 933 934 result->addAttribute("type", parser->getBuilder().getTypeAttr(types[0])); 935 return success(); 936 } 937 938 static LogicalResult verify(GlobalOp op) { 939 if (!llvm::PointerType::isValidElementType(op.getType().getUnderlyingType())) 940 return op.emitOpError( 941 "expects type to be a valid element type for an LLVM pointer"); 942 if (op.getParentOp() && !isa<ModuleOp>(op.getParentOp())) 943 return op.emitOpError("must appear at the module level"); 944 if (auto strAttr = op.value().dyn_cast<StringAttr>()) { 945 auto type = op.getType(); 946 if (!type.getUnderlyingType()->isArrayTy() || 947 !type.getArrayElementType().getUnderlyingType()->isIntegerTy(8) || 948 type.getArrayNumElements() != strAttr.getValue().size()) 949 return op.emitOpError( 950 "requires an i8 array type of the length equal to that of the string " 951 "attribute"); 952 } 953 return success(); 954 } 955 956 //===----------------------------------------------------------------------===// 957 // Printing/parsing for LLVM::ShuffleVectorOp. 958 //===----------------------------------------------------------------------===// 959 // Expects vector to be of wrapped LLVM vector type and position to be of 960 // wrapped LLVM i32 type. 961 void LLVM::ShuffleVectorOp::build(Builder *b, OperationState *result, Value *v1, 962 Value *v2, ArrayAttr mask, 963 ArrayRef<NamedAttribute> attrs) { 964 auto wrappedContainerType1 = v1->getType().cast<LLVM::LLVMType>(); 965 auto vType = LLVMType::getVectorTy( 966 wrappedContainerType1.getVectorElementType(), mask.size()); 967 build(b, result, vType, v1, v2, mask); 968 result->addAttributes(attrs); 969 } 970 971 static void printShuffleVectorOp(OpAsmPrinter *p, ShuffleVectorOp &op) { 972 *p << op.getOperationName() << ' ' << *op.v1() << ", " << *op.v2() << " " 973 << op.mask(); 974 p->printOptionalAttrDict(op.getAttrs(), {"mask"}); 975 *p << " : " << op.v1()->getType() << ", " << op.v2()->getType(); 976 } 977 978 // <operation> ::= `llvm.shufflevector` ssa-use `, ` ssa-use 979 // `[` integer-literal (`,` integer-literal)* `]` 980 // attribute-dict? `:` type 981 static ParseResult parseShuffleVectorOp(OpAsmParser *parser, 982 OperationState *result) { 983 llvm::SMLoc loc; 984 SmallVector<NamedAttribute, 4> attrs; 985 OpAsmParser::OperandType v1, v2; 986 Attribute maskAttr; 987 Type typeV1, typeV2; 988 if (parser->getCurrentLocation(&loc) || parser->parseOperand(v1) || 989 parser->parseComma() || parser->parseOperand(v2) || 990 parser->parseAttribute(maskAttr, "mask", attrs) || 991 parser->parseOptionalAttributeDict(attrs) || 992 parser->parseColonType(typeV1) || parser->parseComma() || 993 parser->parseType(typeV2) || 994 parser->resolveOperand(v1, typeV1, result->operands) || 995 parser->resolveOperand(v2, typeV2, result->operands)) 996 return failure(); 997 auto wrappedContainerType1 = typeV1.dyn_cast<LLVM::LLVMType>(); 998 if (!wrappedContainerType1 || 999 !wrappedContainerType1.getUnderlyingType()->isVectorTy()) 1000 return parser->emitError( 1001 loc, "expected LLVM IR dialect vector type for operand #1"); 1002 auto vType = 1003 LLVMType::getVectorTy(wrappedContainerType1.getVectorElementType(), 1004 maskAttr.cast<ArrayAttr>().size()); 1005 result->attributes = attrs; 1006 result->addTypes(vType); 1007 return success(); 1008 } 1009 1010 //===----------------------------------------------------------------------===// 1011 // Builder, printer and verifier for LLVM::LLVMFuncOp. 1012 //===----------------------------------------------------------------------===// 1013 1014 void LLVMFuncOp::build(Builder *builder, OperationState *result, StringRef name, 1015 LLVMType type, ArrayRef<NamedAttribute> attrs, 1016 ArrayRef<NamedAttributeList> argAttrs) { 1017 result->addRegion(); 1018 result->addAttribute(SymbolTable::getSymbolAttrName(), 1019 builder->getStringAttr(name)); 1020 result->addAttribute("type", builder->getTypeAttr(type)); 1021 result->attributes.append(attrs.begin(), attrs.end()); 1022 if (argAttrs.empty()) 1023 return; 1024 1025 unsigned numInputs = type.getUnderlyingType()->getFunctionNumParams(); 1026 assert(numInputs == argAttrs.size() && 1027 "expected as many argument attribute lists as arguments"); 1028 SmallString<8> argAttrName; 1029 for (unsigned i = 0; i < numInputs; ++i) 1030 if (auto argDict = argAttrs[i].getDictionary()) 1031 result->addAttribute(getArgAttrName(i, argAttrName), argDict); 1032 } 1033 1034 // Build an LLVM function type from the given lists of input and output types. 1035 // Returns a null type if any of the types provided are non-LLVM types, or if 1036 // there is more than one output type. 1037 static Type buildLLVMFunctionType(Builder &b, ArrayRef<Type> inputs, 1038 ArrayRef<Type> outputs, 1039 impl::VariadicFlag variadicFlag, 1040 std::string &errorMessage) { 1041 if (outputs.size() > 1) { 1042 errorMessage = "expected zero or one function result"; 1043 return {}; 1044 } 1045 1046 // Convert inputs to LLVM types, exit early on error. 1047 SmallVector<LLVMType, 4> llvmInputs; 1048 for (auto t : inputs) { 1049 auto llvmTy = t.dyn_cast<LLVMType>(); 1050 if (!llvmTy) { 1051 errorMessage = "expected LLVM type for function arguments"; 1052 return {}; 1053 } 1054 llvmInputs.push_back(llvmTy); 1055 } 1056 1057 // Get the dialect from the input type, if any exist. Look it up in the 1058 // context otherwise. 1059 LLVMDialect *dialect = 1060 llvmInputs.empty() ? b.getContext()->getRegisteredDialect<LLVMDialect>() 1061 : &llvmInputs.front().getDialect(); 1062 1063 // No output is denoted as "void" in LLVM type system. 1064 LLVMType llvmOutput = outputs.empty() ? LLVMType::getVoidTy(dialect) 1065 : outputs.front().dyn_cast<LLVMType>(); 1066 if (!llvmOutput) { 1067 errorMessage = "expected LLVM type for function results"; 1068 return {}; 1069 } 1070 return LLVMType::getFunctionTy(llvmOutput, llvmInputs, 1071 variadicFlag.isVariadic()); 1072 } 1073 1074 // Print the LLVMFuncOp. Collects argument and result types and passes them 1075 // to the trait printer. Drops "void" result since it cannot be parsed back. 1076 static void printLLVMFuncOp(OpAsmPrinter *p, LLVMFuncOp op) { 1077 LLVMType fnType = op.getType(); 1078 SmallVector<Type, 8> argTypes; 1079 SmallVector<Type, 1> resTypes; 1080 argTypes.reserve(fnType.getFunctionNumParams()); 1081 for (unsigned i = 0, e = fnType.getFunctionNumParams(); i < e; ++i) 1082 argTypes.push_back(fnType.getFunctionParamType(i)); 1083 1084 LLVMType returnType = fnType.getFunctionResultType(); 1085 if (!returnType.getUnderlyingType()->isVoidTy()) 1086 resTypes.push_back(returnType); 1087 1088 impl::printFunctionLikeOp(p, op, argTypes, op.isVarArg(), resTypes); 1089 } 1090 1091 // Hook for OpTrait::FunctionLike, called after verifying that the 'type' 1092 // attribute is present. This can check for preconditions of the 1093 // getNumArguments hook not failing. 1094 LogicalResult LLVMFuncOp::verifyType() { 1095 auto llvmType = getTypeAttr().getValue().dyn_cast_or_null<LLVMType>(); 1096 if (!llvmType || !llvmType.getUnderlyingType()->isFunctionTy()) 1097 return emitOpError("requires '" + getTypeAttrName() + 1098 "' attribute of wrapped LLVM function type"); 1099 1100 return success(); 1101 } 1102 1103 // Hook for OpTrait::FunctionLike, returns the number of function arguments. 1104 // Depends on the type attribute being correct as checked by verifyType 1105 unsigned LLVMFuncOp::getNumFuncArguments() { 1106 return getType().getUnderlyingType()->getFunctionNumParams(); 1107 } 1108 1109 static LogicalResult verify(LLVMFuncOp op) { 1110 if (op.isExternal()) 1111 return success(); 1112 1113 if (op.isVarArg()) 1114 return op.emitOpError("only external functions can be variadic"); 1115 1116 auto *funcType = cast<llvm::FunctionType>(op.getType().getUnderlyingType()); 1117 unsigned numArguments = funcType->getNumParams(); 1118 Block &entryBlock = op.front(); 1119 for (unsigned i = 0; i < numArguments; ++i) { 1120 Type argType = entryBlock.getArgument(i)->getType(); 1121 auto argLLVMType = argType.dyn_cast<LLVMType>(); 1122 if (!argLLVMType) 1123 return op.emitOpError("entry block argument #") 1124 << i << " is not of LLVM type"; 1125 if (funcType->getParamType(i) != argLLVMType.getUnderlyingType()) 1126 return op.emitOpError("the type of entry block argument #") 1127 << i << " does not match the function signature"; 1128 } 1129 1130 return success(); 1131 } 1132 1133 //===----------------------------------------------------------------------===// 1134 // LLVMDialect initialization, type parsing, and registration. 1135 //===----------------------------------------------------------------------===// 1136 1137 namespace mlir { 1138 namespace LLVM { 1139 namespace detail { 1140 struct LLVMDialectImpl { 1141 LLVMDialectImpl() : module("LLVMDialectModule", llvmContext) {} 1142 1143 llvm::LLVMContext llvmContext; 1144 llvm::Module module; 1145 1146 /// A set of LLVMTypes that are cached on construction to avoid any lookups or 1147 /// locking. 1148 LLVMType int1Ty, int8Ty, int16Ty, int32Ty, int64Ty, int128Ty; 1149 LLVMType doubleTy, floatTy, halfTy; 1150 LLVMType voidTy; 1151 1152 /// A smart mutex to lock access to the llvm context. Unlike MLIR, LLVM is not 1153 /// multi-threaded and requires locked access to prevent race conditions. 1154 llvm::sys::SmartMutex<true> mutex; 1155 }; 1156 } // end namespace detail 1157 } // end namespace LLVM 1158 } // end namespace mlir 1159 1160 LLVMDialect::LLVMDialect(MLIRContext *context) 1161 : Dialect(getDialectNamespace(), context), 1162 impl(new detail::LLVMDialectImpl()) { 1163 addTypes<LLVMType>(); 1164 addOperations< 1165 #define GET_OP_LIST 1166 #include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc" 1167 >(); 1168 1169 // Support unknown operations because not all LLVM operations are registered. 1170 allowUnknownOperations(); 1171 1172 // Cache some of the common LLVM types to avoid the need for lookups/locking. 1173 auto &llvmContext = impl->llvmContext; 1174 /// Integer Types. 1175 impl->int1Ty = LLVMType::get(context, llvm::Type::getInt1Ty(llvmContext)); 1176 impl->int8Ty = LLVMType::get(context, llvm::Type::getInt8Ty(llvmContext)); 1177 impl->int16Ty = LLVMType::get(context, llvm::Type::getInt16Ty(llvmContext)); 1178 impl->int32Ty = LLVMType::get(context, llvm::Type::getInt32Ty(llvmContext)); 1179 impl->int64Ty = LLVMType::get(context, llvm::Type::getInt64Ty(llvmContext)); 1180 impl->int128Ty = LLVMType::get(context, llvm::Type::getInt128Ty(llvmContext)); 1181 /// Float Types. 1182 impl->doubleTy = LLVMType::get(context, llvm::Type::getDoubleTy(llvmContext)); 1183 impl->floatTy = LLVMType::get(context, llvm::Type::getFloatTy(llvmContext)); 1184 impl->halfTy = LLVMType::get(context, llvm::Type::getHalfTy(llvmContext)); 1185 /// Other Types. 1186 impl->voidTy = LLVMType::get(context, llvm::Type::getVoidTy(llvmContext)); 1187 } 1188 1189 LLVMDialect::~LLVMDialect() {} 1190 1191 #define GET_OP_CLASSES 1192 #include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc" 1193 1194 llvm::LLVMContext &LLVMDialect::getLLVMContext() { return impl->llvmContext; } 1195 llvm::Module &LLVMDialect::getLLVMModule() { return impl->module; } 1196 1197 /// Parse a type registered to this dialect. 1198 Type LLVMDialect::parseType(StringRef tyData, Location loc) const { 1199 // LLVM is not thread-safe, so lock access to it. 1200 llvm::sys::SmartScopedLock<true> lock(impl->mutex); 1201 1202 llvm::SMDiagnostic errorMessage; 1203 llvm::Type *type = llvm::parseType(tyData, errorMessage, impl->module); 1204 if (!type) 1205 return (emitError(loc, errorMessage.getMessage()), nullptr); 1206 return LLVMType::get(getContext(), type); 1207 } 1208 1209 /// Print a type registered to this dialect. 1210 void LLVMDialect::printType(Type type, raw_ostream &os) const { 1211 auto llvmType = type.dyn_cast<LLVMType>(); 1212 assert(llvmType && "printing wrong type"); 1213 assert(llvmType.getUnderlyingType() && "no underlying LLVM type"); 1214 llvmType.getUnderlyingType()->print(os); 1215 } 1216 1217 /// Verify LLVMIR function argument attributes. 1218 LogicalResult LLVMDialect::verifyRegionArgAttribute(Operation *op, 1219 unsigned regionIdx, 1220 unsigned argIdx, 1221 NamedAttribute argAttr) { 1222 // Check that llvm.noalias is a boolean attribute. 1223 if (argAttr.first == "llvm.noalias" && !argAttr.second.isa<BoolAttr>()) 1224 return op->emitError() 1225 << "llvm.noalias argument attribute of non boolean type"; 1226 return success(); 1227 } 1228 1229 static DialectRegistration<LLVMDialect> llvmDialect; 1230 1231 //===----------------------------------------------------------------------===// 1232 // LLVMType. 1233 //===----------------------------------------------------------------------===// 1234 1235 namespace mlir { 1236 namespace LLVM { 1237 namespace detail { 1238 struct LLVMTypeStorage : public ::mlir::TypeStorage { 1239 LLVMTypeStorage(llvm::Type *ty) : underlyingType(ty) {} 1240 1241 // LLVM types are pointer-unique. 1242 using KeyTy = llvm::Type *; 1243 bool operator==(const KeyTy &key) const { return key == underlyingType; } 1244 1245 static LLVMTypeStorage *construct(TypeStorageAllocator &allocator, 1246 llvm::Type *ty) { 1247 return new (allocator.allocate<LLVMTypeStorage>()) LLVMTypeStorage(ty); 1248 } 1249 1250 llvm::Type *underlyingType; 1251 }; 1252 } // end namespace detail 1253 } // end namespace LLVM 1254 } // end namespace mlir 1255 1256 LLVMType LLVMType::get(MLIRContext *context, llvm::Type *llvmType) { 1257 return Base::get(context, FIRST_LLVM_TYPE, llvmType); 1258 } 1259 1260 /// Get an LLVMType with an llvm type that may cause changes to the underlying 1261 /// llvm context when constructed. 1262 LLVMType LLVMType::getLocked(LLVMDialect *dialect, 1263 llvm::function_ref<llvm::Type *()> typeBuilder) { 1264 // Lock access to the llvm context and build the type. 1265 llvm::sys::SmartScopedLock<true> lock(dialect->impl->mutex); 1266 return get(dialect->getContext(), typeBuilder()); 1267 } 1268 1269 LLVMDialect &LLVMType::getDialect() { 1270 return static_cast<LLVMDialect &>(Type::getDialect()); 1271 } 1272 1273 llvm::Type *LLVMType::getUnderlyingType() const { 1274 return getImpl()->underlyingType; 1275 } 1276 1277 /// Array type utilities. 1278 LLVMType LLVMType::getArrayElementType() { 1279 return get(getContext(), getUnderlyingType()->getArrayElementType()); 1280 } 1281 unsigned LLVMType::getArrayNumElements() { 1282 return getUnderlyingType()->getArrayNumElements(); 1283 } 1284 bool LLVMType::isArrayTy() { return getUnderlyingType()->isArrayTy(); } 1285 1286 /// Vector type utilities. 1287 LLVMType LLVMType::getVectorElementType() { 1288 return get(getContext(), getUnderlyingType()->getVectorElementType()); 1289 } 1290 bool LLVMType::isVectorTy() { return getUnderlyingType()->isVectorTy(); } 1291 1292 /// Function type utilities. 1293 LLVMType LLVMType::getFunctionParamType(unsigned argIdx) { 1294 return get(getContext(), getUnderlyingType()->getFunctionParamType(argIdx)); 1295 } 1296 unsigned LLVMType::getFunctionNumParams() { 1297 return getUnderlyingType()->getFunctionNumParams(); 1298 } 1299 LLVMType LLVMType::getFunctionResultType() { 1300 return get( 1301 getContext(), 1302 llvm::cast<llvm::FunctionType>(getUnderlyingType())->getReturnType()); 1303 } 1304 bool LLVMType::isFunctionTy() { return getUnderlyingType()->isFunctionTy(); } 1305 1306 /// Pointer type utilities. 1307 LLVMType LLVMType::getPointerTo(unsigned addrSpace) { 1308 // Lock access to the dialect as this may modify the LLVM context. 1309 return getLocked(&getDialect(), [=] { 1310 return getUnderlyingType()->getPointerTo(addrSpace); 1311 }); 1312 } 1313 LLVMType LLVMType::getPointerElementTy() { 1314 return get(getContext(), getUnderlyingType()->getPointerElementType()); 1315 } 1316 bool LLVMType::isPointerTy() { return getUnderlyingType()->isPointerTy(); } 1317 1318 /// Struct type utilities. 1319 LLVMType LLVMType::getStructElementType(unsigned i) { 1320 return get(getContext(), getUnderlyingType()->getStructElementType(i)); 1321 } 1322 bool LLVMType::isStructTy() { return getUnderlyingType()->isStructTy(); } 1323 1324 /// Utilities used to generate floating point types. 1325 LLVMType LLVMType::getDoubleTy(LLVMDialect *dialect) { 1326 return dialect->impl->doubleTy; 1327 } 1328 LLVMType LLVMType::getFloatTy(LLVMDialect *dialect) { 1329 return dialect->impl->floatTy; 1330 } 1331 LLVMType LLVMType::getHalfTy(LLVMDialect *dialect) { 1332 return dialect->impl->halfTy; 1333 } 1334 1335 /// Utilities used to generate integer types. 1336 LLVMType LLVMType::getIntNTy(LLVMDialect *dialect, unsigned numBits) { 1337 switch (numBits) { 1338 case 1: 1339 return dialect->impl->int1Ty; 1340 case 8: 1341 return dialect->impl->int8Ty; 1342 case 16: 1343 return dialect->impl->int16Ty; 1344 case 32: 1345 return dialect->impl->int32Ty; 1346 case 64: 1347 return dialect->impl->int64Ty; 1348 case 128: 1349 return dialect->impl->int128Ty; 1350 default: 1351 break; 1352 } 1353 1354 // Lock access to the dialect as this may modify the LLVM context. 1355 return getLocked(dialect, [=] { 1356 return llvm::Type::getIntNTy(dialect->getLLVMContext(), numBits); 1357 }); 1358 } 1359 1360 /// Utilities used to generate other miscellaneous types. 1361 LLVMType LLVMType::getArrayTy(LLVMType elementType, uint64_t numElements) { 1362 // Lock access to the dialect as this may modify the LLVM context. 1363 return getLocked(&elementType.getDialect(), [=] { 1364 return llvm::ArrayType::get(elementType.getUnderlyingType(), numElements); 1365 }); 1366 } 1367 LLVMType LLVMType::getFunctionTy(LLVMType result, ArrayRef<LLVMType> params, 1368 bool isVarArg) { 1369 SmallVector<llvm::Type *, 8> llvmParams; 1370 for (auto param : params) 1371 llvmParams.push_back(param.getUnderlyingType()); 1372 1373 // Lock access to the dialect as this may modify the LLVM context. 1374 return getLocked(&result.getDialect(), [=] { 1375 return llvm::FunctionType::get(result.getUnderlyingType(), llvmParams, 1376 isVarArg); 1377 }); 1378 } 1379 LLVMType LLVMType::getStructTy(LLVMDialect *dialect, 1380 ArrayRef<LLVMType> elements, bool isPacked) { 1381 SmallVector<llvm::Type *, 8> llvmElements; 1382 for (auto elt : elements) 1383 llvmElements.push_back(elt.getUnderlyingType()); 1384 1385 // Lock access to the dialect as this may modify the LLVM context. 1386 return getLocked(dialect, [=] { 1387 return llvm::StructType::get(dialect->getLLVMContext(), llvmElements, 1388 isPacked); 1389 }); 1390 } 1391 LLVMType LLVMType::getVectorTy(LLVMType elementType, unsigned numElements) { 1392 // Lock access to the dialect as this may modify the LLVM context. 1393 return getLocked(&elementType.getDialect(), [=] { 1394 return llvm::VectorType::get(elementType.getUnderlyingType(), numElements); 1395 }); 1396 } 1397 LLVMType LLVMType::getVoidTy(LLVMDialect *dialect) { 1398 return dialect->impl->voidTy; 1399 } 1400 1401 //===----------------------------------------------------------------------===// 1402 // Utility functions. 1403 //===----------------------------------------------------------------------===// 1404 1405 Value *mlir::LLVM::createGlobalString(Location loc, OpBuilder &builder, 1406 StringRef name, StringRef value, 1407 LLVM::LLVMDialect *llvmDialect) { 1408 assert(builder.getInsertionBlock() && 1409 builder.getInsertionBlock()->getParentOp() && 1410 "expected builder to point to a block constained in an op"); 1411 auto module = 1412 builder.getInsertionBlock()->getParentOp()->getParentOfType<ModuleOp>(); 1413 assert(module && "builder points to an op outside of a module"); 1414 1415 // Create the global at the entry of the module. 1416 OpBuilder moduleBuilder(module.getBodyRegion()); 1417 auto type = LLVM::LLVMType::getArrayTy(LLVM::LLVMType::getInt8Ty(llvmDialect), 1418 value.size()); 1419 auto global = moduleBuilder.create<LLVM::GlobalOp>( 1420 loc, type, /*isConstant=*/true, name, builder.getStringAttr(value)); 1421 1422 // Get the pointer to the first character in the global string. 1423 Value *globalPtr = builder.create<LLVM::AddressOfOp>(loc, global); 1424 Value *cst0 = builder.create<LLVM::ConstantOp>( 1425 loc, LLVM::LLVMType::getInt64Ty(llvmDialect), 1426 builder.getIntegerAttr(builder.getIndexType(), 0)); 1427 return builder.create<LLVM::GEPOp>( 1428 loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), globalPtr, 1429 ArrayRef<Value *>({cst0, cst0})); 1430 }