// Source: github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
//===- GPUDialect.cpp - MLIR Dialect for GPU Kernels implementation -------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements the GPU kernel-related dialect and its operations.
19 // 20 //===----------------------------------------------------------------------===// 21 22 #include "mlir/Dialect/GPU/GPUDialect.h" 23 #include "mlir/Dialect/StandardOps/Ops.h" 24 #include "mlir/IR/Builders.h" 25 #include "mlir/IR/Function.h" 26 #include "mlir/IR/Module.h" 27 #include "mlir/IR/OpImplementation.h" 28 #include "mlir/IR/PatternMatch.h" 29 #include "mlir/IR/StandardTypes.h" 30 31 using namespace mlir; 32 using namespace mlir::gpu; 33 34 StringRef GPUDialect::getDialectName() { return "gpu"; } 35 36 bool GPUDialect::isKernel(FuncOp function) { 37 UnitAttr isKernelAttr = 38 function.getAttrOfType<UnitAttr>(getKernelFuncAttrName()); 39 return static_cast<bool>(isKernelAttr); 40 } 41 42 GPUDialect::GPUDialect(MLIRContext *context) 43 : Dialect(getDialectName(), context) { 44 addOperations<LaunchOp, LaunchFuncOp, 45 #define GET_OP_LIST 46 #include "mlir/Dialect/GPU/GPUOps.cpp.inc" 47 >(); 48 } 49 50 template <typename T> static LogicalResult verifyIndexOp(T op) { 51 auto dimension = op.dimension(); 52 if (dimension != "x" && dimension != "y" && dimension != "z") 53 return op.emitError("dimension \"") << dimension << "\" is invalid"; 54 return success(); 55 } 56 57 #define GET_OP_CLASSES 58 #include "mlir/Dialect/GPU/GPUOps.cpp.inc" 59 60 //===----------------------------------------------------------------------===// 61 // LaunchOp 62 //===----------------------------------------------------------------------===// 63 64 static SmallVector<Type, 4> getValueTypes(ArrayRef<Value *> values) { 65 SmallVector<Type, 4> types; 66 types.reserve(values.size()); 67 for (Value *v : values) 68 types.push_back(v->getType()); 69 return types; 70 } 71 72 void LaunchOp::build(Builder *builder, OperationState *result, Value *gridSizeX, 73 Value *gridSizeY, Value *gridSizeZ, Value *blockSizeX, 74 Value *blockSizeY, Value *blockSizeZ, 75 ArrayRef<Value *> operands) { 76 // Add grid and block sizes as op operands, followed by the data operands. 
77 result->addOperands( 78 {gridSizeX, gridSizeY, gridSizeZ, blockSizeX, blockSizeY, blockSizeZ}); 79 result->addOperands(operands); 80 81 // Create a kernel body region with kNumConfigRegionAttributes + N arguments, 82 // where the first kNumConfigRegionAttributes arguments have `index` type and 83 // the rest have the same types as the data operands. 84 Region *kernelRegion = result->addRegion(); 85 Block *body = new Block(); 86 body->addArguments( 87 std::vector<Type>(kNumConfigRegionAttributes, builder->getIndexType())); 88 body->addArguments(getValueTypes(operands)); 89 kernelRegion->push_back(body); 90 } 91 92 Region &LaunchOp::getBody() { return getOperation()->getRegion(0); } 93 94 KernelDim3 LaunchOp::getBlockIds() { 95 assert(!getBody().getBlocks().empty() && "FuncOp body must not be empty."); 96 auto args = getBody().getBlocks().front().getArguments(); 97 return KernelDim3{args[0], args[1], args[2]}; 98 } 99 100 KernelDim3 LaunchOp::getThreadIds() { 101 assert(!getBody().getBlocks().empty() && "FuncOp body must not be empty."); 102 auto args = getBody().getBlocks().front().getArguments(); 103 return KernelDim3{args[3], args[4], args[5]}; 104 } 105 106 KernelDim3 LaunchOp::getGridSize() { 107 assert(!getBody().getBlocks().empty() && "FuncOp body must not be empty."); 108 auto args = getBody().getBlocks().front().getArguments(); 109 return KernelDim3{args[6], args[7], args[8]}; 110 } 111 112 KernelDim3 LaunchOp::getBlockSize() { 113 assert(!getBody().getBlocks().empty() && "FuncOp body must not be empty."); 114 auto args = getBody().getBlocks().front().getArguments(); 115 return KernelDim3{args[9], args[10], args[11]}; 116 } 117 118 LaunchOp::operand_range LaunchOp::getKernelOperandValues() { 119 return llvm::drop_begin(getOperands(), kNumConfigOperands); 120 } 121 122 LaunchOp::operand_type_range LaunchOp::getKernelOperandTypes() { 123 return llvm::drop_begin(getOperandTypes(), kNumConfigOperands); 124 } 125 126 KernelDim3 
LaunchOp::getGridSizeOperandValues() { 127 return KernelDim3{getOperand(0), getOperand(1), getOperand(2)}; 128 } 129 130 KernelDim3 LaunchOp::getBlockSizeOperandValues() { 131 return KernelDim3{getOperand(3), getOperand(4), getOperand(5)}; 132 } 133 134 llvm::iterator_range<Block::args_iterator> LaunchOp::getKernelArguments() { 135 auto args = getBody().getBlocks().front().getArguments(); 136 return llvm::drop_begin(args, LaunchOp::kNumConfigRegionAttributes); 137 } 138 139 LogicalResult LaunchOp::verify() { 140 // Kernel launch takes kNumConfigOperands leading operands for grid/block 141 // sizes and transforms them into kNumConfigRegionAttributes region arguments 142 // for block/thread identifiers and grid/block sizes. 143 if (!getBody().empty()) { 144 Block &entryBlock = getBody().front(); 145 if (entryBlock.getNumArguments() != kNumConfigOperands + getNumOperands()) 146 return emitError("unexpected number of region arguments"); 147 } 148 149 // Block terminators without successors are expected to exit the kernel region 150 // and must be `gpu.launch`. 151 for (Block &block : getBody()) { 152 if (block.empty()) 153 continue; 154 if (block.back().getNumSuccessors() != 0) 155 continue; 156 if (!isa<gpu::Return>(&block.back())) { 157 return block.back() 158 .emitError("expected 'gpu.terminator' or a terminator with " 159 "successors") 160 .attachNote(getLoc()) 161 << "in '" << getOperationName() << "' body region"; 162 } 163 } 164 165 return success(); 166 } 167 168 // Pretty-print the kernel grid/block size assignment as 169 // (%iter-x, %iter-y, %iter-z) in 170 // (%size-x = %ssa-use, %size-y = %ssa-use, %size-z = %ssa-use) 171 // where %size-* and %iter-* will correspond to the body region arguments. 
172 static void printSizeAssignment(OpAsmPrinter *p, KernelDim3 size, 173 ArrayRef<Value *> operands, KernelDim3 ids) { 174 *p << '(' << *ids.x << ", " << *ids.y << ", " << *ids.z << ") in ("; 175 *p << *size.x << " = " << *operands[0] << ", "; 176 *p << *size.y << " = " << *operands[1] << ", "; 177 *p << *size.z << " = " << *operands[2] << ')'; 178 } 179 180 void LaunchOp::print(OpAsmPrinter *p) { 181 SmallVector<Value *, 12> operandContainer(operand_begin(), operand_end()); 182 ArrayRef<Value *> operands(operandContainer); 183 184 // Print the launch configuration. 185 *p << getOperationName() << ' ' << getBlocksKeyword(); 186 printSizeAssignment(p, getGridSize(), operands.take_front(3), getBlockIds()); 187 *p << ' ' << getThreadsKeyword(); 188 printSizeAssignment(p, getBlockSize(), operands.slice(3, 3), getThreadIds()); 189 190 // From now on, the first kNumConfigOperands operands corresponding to grid 191 // and block sizes are irrelevant, so we can drop them. 192 operands = operands.drop_front(kNumConfigOperands); 193 194 // Print the data argument remapping. 195 if (!getBody().empty() && !operands.empty()) { 196 *p << ' ' << getArgsKeyword() << '('; 197 for (unsigned i = 0, e = operands.size(); i < e; ++i) { 198 if (i != 0) 199 *p << ", "; 200 *p << *getBody().front().getArgument(kNumConfigRegionAttributes + i) 201 << " = " << *operands[i]; 202 } 203 *p << ") "; 204 } 205 206 // Print the types of data arguments. 207 if (!operands.empty()) { 208 *p << ": "; 209 for (unsigned i = 0, e = operands.size(); i < e; ++i) { 210 if (i != 0) 211 *p << ", "; 212 *p << operands[i]->getType(); 213 } 214 } 215 216 p->printRegion(getBody(), /*printEntryBlockArgs=*/false); 217 p->printOptionalAttrDict(getAttrs()); 218 } 219 220 // Parse the size assignment blocks for blocks and threads. 
These have the form 221 // (%region_arg, %region_arg, %region_arg) in 222 // (%region_arg = %operand, %region_arg = %operand, %region_arg = %operand) 223 // where %region_arg are percent-identifiers for the region arguments to be 224 // introduced futher (SSA defs), and %operand are percent-identifiers for the 225 // SSA value uses. 226 static ParseResult 227 parseSizeAssignment(OpAsmParser *parser, 228 MutableArrayRef<OpAsmParser::OperandType> sizes, 229 MutableArrayRef<OpAsmParser::OperandType> regionSizes, 230 MutableArrayRef<OpAsmParser::OperandType> indices) { 231 assert(indices.size() == 3 && "space for three indices expected"); 232 SmallVector<OpAsmParser::OperandType, 3> args; 233 if (parser->parseRegionArgumentList(args, /*requiredOperandCount=*/3, 234 OpAsmParser::Delimiter::Paren) || 235 parser->parseKeyword("in") || parser->parseLParen()) 236 return failure(); 237 std::move(args.begin(), args.end(), indices.begin()); 238 239 for (int i = 0; i < 3; ++i) { 240 if (i != 0 && parser->parseComma()) 241 return failure(); 242 if (parser->parseRegionArgument(regionSizes[i]) || parser->parseEqual() || 243 parser->parseOperand(sizes[i])) 244 return failure(); 245 } 246 247 return parser->parseRParen(); 248 } 249 250 // Parses a Launch operation. 251 // operation ::= `gpu.launch` `blocks` `(` ssa-id-list `)` `in` ssa-reassignment 252 // `threads` `(` ssa-id-list `)` `in` ssa-reassignment 253 // (`args` ssa-reassignment `:` type-list)? 254 // region attr-dict? 255 // ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)` 256 ParseResult LaunchOp::parse(OpAsmParser *parser, OperationState *result) { 257 // Sizes of the grid and block. 258 SmallVector<OpAsmParser::OperandType, kNumConfigOperands> sizes( 259 kNumConfigOperands); 260 MutableArrayRef<OpAsmParser::OperandType> sizesRef(sizes); 261 262 // Actual (data) operands passed to the kernel. 263 SmallVector<OpAsmParser::OperandType, 4> dataOperands; 264 265 // Region arguments to be created. 
266 SmallVector<OpAsmParser::OperandType, 16> regionArgs( 267 kNumConfigRegionAttributes); 268 MutableArrayRef<OpAsmParser::OperandType> regionArgsRef(regionArgs); 269 270 // Parse the size assignment segments: the first segment assigns grid siezs 271 // and defines values for block identifiers; the second segment assigns block 272 // sies and defines values for thread identifiers. In the region argument 273 // list, identifiers preceed sizes, and block-related values preceed 274 // thread-related values. 275 if (parser->parseKeyword(getBlocksKeyword().data()) || 276 parseSizeAssignment(parser, sizesRef.take_front(3), 277 regionArgsRef.slice(6, 3), 278 regionArgsRef.slice(0, 3)) || 279 parser->parseKeyword(getThreadsKeyword().data()) || 280 parseSizeAssignment(parser, sizesRef.drop_front(3), 281 regionArgsRef.slice(9, 3), 282 regionArgsRef.slice(3, 3)) || 283 parser->resolveOperands(sizes, parser->getBuilder().getIndexType(), 284 result->operands)) 285 return failure(); 286 287 // If kernel argument renaming segment is present, parse it. When present, 288 // the segment should have at least one element. If this segment is present, 289 // so is the trailing type list. Parse it as well and use the parsed types 290 // to resolve the operands passed to the kernel arguments. 
291 SmallVector<Type, 4> dataTypes; 292 if (!parser->parseOptionalKeyword(getArgsKeyword().data())) { 293 llvm::SMLoc argsLoc = parser->getCurrentLocation(); 294 295 regionArgs.push_back({}); 296 dataOperands.push_back({}); 297 if (parser->parseLParen() || 298 parser->parseRegionArgument(regionArgs.back()) || 299 parser->parseEqual() || parser->parseOperand(dataOperands.back())) 300 return failure(); 301 302 while (!parser->parseOptionalComma()) { 303 regionArgs.push_back({}); 304 dataOperands.push_back({}); 305 if (parser->parseRegionArgument(regionArgs.back()) || 306 parser->parseEqual() || parser->parseOperand(dataOperands.back())) 307 return failure(); 308 } 309 310 if (parser->parseRParen() || parser->parseColonTypeList(dataTypes) || 311 parser->resolveOperands(dataOperands, dataTypes, argsLoc, 312 result->operands)) 313 return failure(); 314 } 315 316 // Introduce the body region and parse it. The region has 317 // kNumConfigRegionAttributes leading arguments that correspond to 318 // block/thread identifiers and grid/block sizes, all of the `index` type. 319 // Follow the actual kernel arguments. 320 Type index = parser->getBuilder().getIndexType(); 321 dataTypes.insert(dataTypes.begin(), kNumConfigRegionAttributes, index); 322 Region *body = result->addRegion(); 323 return failure(parser->parseRegion(*body, regionArgs, dataTypes) || 324 parser->parseOptionalAttributeDict(result->attributes)); 325 } 326 327 void LaunchOp::eraseKernelArgument(unsigned index) { 328 Block &entryBlock = getBody().front(); 329 assert(index < entryBlock.getNumArguments() - kNumConfigRegionAttributes && 330 "kernel argument index overflow"); 331 entryBlock.eraseArgument(kNumConfigRegionAttributes + index); 332 getOperation()->eraseOperand(kNumConfigOperands + index); 333 } 334 335 namespace { 336 // Clone any known constants passed as operands to the kernel into its body. 
337 class PropagateConstantBounds : public OpRewritePattern<LaunchOp> { 338 using OpRewritePattern<LaunchOp>::OpRewritePattern; 339 340 PatternMatchResult matchAndRewrite(LaunchOp launchOp, 341 PatternRewriter &rewriter) const override { 342 auto oringInsertionPoint = rewriter.saveInsertionPoint(); 343 rewriter.setInsertionPointToStart(&launchOp.getBody().front()); 344 345 // Traverse operands passed to kernel and check if some of them are known 346 // constants. If so, clone the constant operation inside the kernel region 347 // and use it instead of passing the value from the parent region. Perform 348 // the traversal in the inverse order to simplify index arithmetics when 349 // dropping arguments. 350 SmallVector<Value *, 8> operands(launchOp.getKernelOperandValues().begin(), 351 launchOp.getKernelOperandValues().end()); 352 SmallVector<Value *, 8> kernelArgs(launchOp.getKernelArguments().begin(), 353 launchOp.getKernelArguments().end()); 354 bool found = false; 355 for (unsigned i = operands.size(); i > 0; --i) { 356 unsigned index = i - 1; 357 Value *operand = operands[index]; 358 if (!isa_and_nonnull<ConstantOp>(operand->getDefiningOp())) { 359 continue; 360 } 361 362 found = true; 363 Value *internalConstant = 364 rewriter.clone(*operand->getDefiningOp())->getResult(0); 365 Value *kernelArg = kernelArgs[index]; 366 kernelArg->replaceAllUsesWith(internalConstant); 367 launchOp.eraseKernelArgument(index); 368 } 369 rewriter.restoreInsertionPoint(oringInsertionPoint); 370 371 if (!found) 372 return matchFailure(); 373 374 rewriter.updatedRootInPlace(launchOp); 375 return matchSuccess(); 376 } 377 }; 378 } // end namespace 379 380 void LaunchOp::getCanonicalizationPatterns(OwningRewritePatternList &results, 381 MLIRContext *context) { 382 results.insert<PropagateConstantBounds>(context); 383 } 384 385 //===----------------------------------------------------------------------===// 386 // LaunchFuncOp 387 
//===----------------------------------------------------------------------===// 388 389 void LaunchFuncOp::build(Builder *builder, OperationState *result, 390 FuncOp kernelFunc, Value *gridSizeX, Value *gridSizeY, 391 Value *gridSizeZ, Value *blockSizeX, Value *blockSizeY, 392 Value *blockSizeZ, ArrayRef<Value *> kernelOperands) { 393 // Add grid and block sizes as op operands, followed by the data operands. 394 result->addOperands( 395 {gridSizeX, gridSizeY, gridSizeZ, blockSizeX, blockSizeY, blockSizeZ}); 396 result->addOperands(kernelOperands); 397 result->addAttribute(getKernelAttrName(), 398 builder->getSymbolRefAttr(kernelFunc)); 399 } 400 401 void LaunchFuncOp::build(Builder *builder, OperationState *result, 402 FuncOp kernelFunc, KernelDim3 gridSize, 403 KernelDim3 blockSize, 404 ArrayRef<Value *> kernelOperands) { 405 build(builder, result, kernelFunc, gridSize.x, gridSize.y, gridSize.z, 406 blockSize.x, blockSize.y, blockSize.z, kernelOperands); 407 } 408 409 StringRef LaunchFuncOp::kernel() { 410 return getAttrOfType<SymbolRefAttr>(getKernelAttrName()).getValue(); 411 } 412 413 unsigned LaunchFuncOp::getNumKernelOperands() { 414 return getNumOperands() - kNumConfigOperands; 415 } 416 417 Value *LaunchFuncOp::getKernelOperand(unsigned i) { 418 return getOperation()->getOperand(i + kNumConfigOperands); 419 } 420 421 KernelDim3 LaunchFuncOp::getGridSizeOperandValues() { 422 return KernelDim3{getOperand(0), getOperand(1), getOperand(2)}; 423 } 424 425 KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() { 426 return KernelDim3{getOperand(3), getOperand(4), getOperand(5)}; 427 } 428 429 LogicalResult LaunchFuncOp::verify() { 430 auto kernelAttr = this->getAttr(getKernelAttrName()); 431 if (!kernelAttr) { 432 return emitOpError("attribute 'kernel' must be specified"); 433 } else if (!kernelAttr.isa<SymbolRefAttr>()) { 434 return emitOpError("attribute 'kernel' must be a function"); 435 } 436 437 auto module = getParentOfType<ModuleOp>(); 438 FuncOp kernelFunc 
= module.lookupSymbol<FuncOp>(kernel()); 439 if (!kernelFunc) 440 return emitError() << "kernel function '" << kernelAttr << "' is undefined"; 441 442 if (!kernelFunc.getAttrOfType<mlir::UnitAttr>( 443 GPUDialect::getKernelFuncAttrName())) { 444 return emitError("kernel function is missing the '") 445 << GPUDialect::getKernelFuncAttrName() << "' attribute"; 446 } 447 unsigned numKernelFuncArgs = kernelFunc.getNumArguments(); 448 if (getNumKernelOperands() != numKernelFuncArgs) { 449 return emitOpError("got ") 450 << getNumKernelOperands() << " kernel operands but expected " 451 << numKernelFuncArgs; 452 } 453 auto functionType = kernelFunc.getType(); 454 for (unsigned i = 0; i < numKernelFuncArgs; ++i) { 455 if (getKernelOperand(i)->getType() != functionType.getInput(i)) { 456 return emitOpError("type of function argument ") 457 << i << " does not match"; 458 } 459 } 460 return success(); 461 }