github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Dialect/VectorOps/VectorOps.cpp (about) 1 //===- VectorOps.cpp - MLIR Super Vectorizer Operations -------------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 // 18 // This file implements convenience types for working with super-vectorization 19 // operations, in particular super-vector loads and stores. 20 // 21 //===----------------------------------------------------------------------===// 22 23 #include "mlir/Dialect/VectorOps/VectorOps.h" 24 #include "mlir/IR/AffineExpr.h" 25 #include "mlir/IR/AffineMap.h" 26 #include "mlir/IR/Builders.h" 27 #include "mlir/IR/OpImplementation.h" 28 #include "mlir/IR/TypeUtilities.h" 29 #include "mlir/Support/LLVM.h" 30 31 using namespace mlir; 32 using namespace mlir::vector; 33 34 //===----------------------------------------------------------------------===// 35 // VectorOpsDialect 36 //===----------------------------------------------------------------------===// 37 38 mlir::vector::VectorOpsDialect::VectorOpsDialect(MLIRContext *context) 39 : Dialect(getDialectNamespace(), context) { 40 addOperations<VectorTransferReadOp, VectorTransferWriteOp, 41 VectorTypeCastOp>(); 42 addOperations< 43 #define GET_OP_LIST 44 #include "mlir/Dialect/VectorOps/VectorOps.cpp.inc" 45 >(); 46 } 47 48 //===----------------------------------------------------------------------===// 49 // ExtractElementOp 50 //===----------------------------------------------------------------------===// 51 52 static void print(OpAsmPrinter *p, ExtractElementOp op) { 53 *p << op.getOperationName() << " " << *op.vector() << op.position(); 54 p->printOptionalAttrDict(op.getAttrs(), {"position"}); 55 *p << " : " << op.vector()->getType(); 56 } 57 58 static ParseResult parseExtractElementOp(OpAsmParser *parser, 59 OperationState *result) { 60 llvm::SMLoc attributeLoc, typeLoc; 61 SmallVector<NamedAttribute, 4> attrs; 62 OpAsmParser::OperandType vector; 63 Type type; 64 Attribute attr; 65 if (parser->parseOperand(vector) || 66 parser->getCurrentLocation(&attributeLoc) || 67 parser->parseAttribute(attr, "position", attrs) || 68 parser->parseOptionalAttributeDict(attrs) || 69 parser->getCurrentLocation(&typeLoc) || parser->parseColonType(type)) 70 return failure(); 71 72 auto vectorType = type.dyn_cast<VectorType>(); 73 if (!vectorType) 74 return parser->emitError(typeLoc, "expected vector type"); 75 76 auto positionAttr = attr.dyn_cast<ArrayAttr>(); 77 if (!positionAttr || 78 static_cast<int64_t>(positionAttr.size()) > vectorType.getRank()) 79 return parser->emitError( 80 attributeLoc, 81 "expected position attribute of rank smaller than vector"); 82 83 Type resType = 84 (static_cast<int64_t>(positionAttr.size()) == vectorType.getRank()) 85 ? vectorType.getElementType() 86 : VectorType::get( 87 vectorType.getShape().drop_front(positionAttr.size()), 88 vectorType.getElementType()); 89 90 result->attributes = attrs; 91 return failure(parser->resolveOperand(vector, type, result->operands) || 92 parser->addTypeToList(resType, result->types)); 93 } 94 95 static LogicalResult verify(ExtractElementOp op) { 96 auto positionAttr = op.position().getValue(); 97 if (positionAttr.empty()) 98 return op.emitOpError("expected non-empty position attribute"); 99 if (positionAttr.size() > static_cast<unsigned>(op.getVectorType().getRank())) 100 return op.emitOpError( 101 "expected position attribute of rank smaller than vector"); 102 for (auto en : llvm::enumerate(positionAttr)) { 103 auto attr = en.value().dyn_cast<IntegerAttr>(); 104 if (!attr || attr.getInt() < 0 || 105 attr.getInt() > op.getVectorType().getDimSize(en.index())) 106 return op.emitOpError("expected position attribute #") 107 << (en.index() + 1) 108 << " to be a positive integer smaller than the corresponding " 109 "vector dimension"; 110 } 111 return success(); 112 } 113 //===----------------------------------------------------------------------===// 114 // OuterProductOp 115 //===----------------------------------------------------------------------===// 116 117 static void print(OpAsmPrinter *p, OuterProductOp op) { 118 *p << op.getOperationName() << " " << *op.lhs() << ", " << *op.rhs(); 119 if (llvm::size(op.acc()) > 0) 120 *p << ", " << **op.acc().begin(); 121 *p << " : " << op.lhs()->getType() << ", " << op.rhs()->getType(); 122 } 123 124 static ParseResult parseOuterProductOp(OpAsmParser *parser, 125 OperationState *result) { 126 SmallVector<OpAsmParser::OperandType, 3> operandsInfo; 127 Type tLHS, tRHS; 128 if (parser->parseOperandList(operandsInfo) || parser->parseColonType(tLHS) || 129 parser->parseComma() || parser->parseType(tRHS)) 130 return failure(); 131 if (operandsInfo.size() < 2) 132 return parser->emitError(parser->getNameLoc(), 133 "expected at least 2 operands"); 134 VectorType vLHS = tLHS.dyn_cast<VectorType>(); 135 VectorType vRHS = tRHS.dyn_cast<VectorType>(); 136 if (!vLHS || !vRHS) 137 return parser->emitError(parser->getNameLoc(), "expected 2 vector types"); 138 VectorType resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)}, 139 vLHS.getElementType()); 140 return failure( 141 parser->resolveOperand(operandsInfo[0], tLHS, result->operands) || 142 parser->resolveOperand(operandsInfo[1], tRHS, result->operands) || 143 (operandsInfo.size() > 2 && 144 parser->resolveOperand(operandsInfo[2], resType, result->operands)) || 145 parser->addTypeToList(resType, result->types)); 146 } 147 148 static LogicalResult verify(OuterProductOp op) { 149 VectorType vLHS = op.getOperandVectorTypeLHS(), 150 vRHS = op.getOperandVectorTypeRHS(), 151 vACC = op.getOperandVectorTypeACC(), vRES = op.getVectorType(); 152 if (vLHS.getRank() != 1) 153 return op.emitOpError("expected 1-d vector for operand #1"); 154 if (vRHS.getRank() != 1) 155 return op.emitOpError("expected 1-d vector for operand #2"); 156 if (vRES.getRank() != 2) 157 return op.emitOpError("expected 2-d vector result"); 158 if (vLHS.getDimSize(0) != vRES.getDimSize(0)) 159 return op.emitOpError("expected #1 operand dim to match result dim #1"); 160 if (vRHS.getDimSize(0) != vRES.getDimSize(1)) 161 return op.emitOpError("expected #2 operand dim to match result dim #2"); 162 if (vACC && vACC != vRES) 163 return op.emitOpError("expected operand #3 of same type as result type"); 164 return success(); 165 } 166 167 //===----------------------------------------------------------------------===// 168 // VectorTransferReadOp 169 //===----------------------------------------------------------------------===// 170 template <typename EmitFun> 171 static LogicalResult verifyPermutationMap(AffineMap permutationMap, 172 EmitFun emitOpError) { 173 SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false); 174 for (auto expr : permutationMap.getResults()) { 175 auto dim = expr.dyn_cast<AffineDimExpr>(); 176 auto zero = expr.dyn_cast<AffineConstantExpr>(); 177 if (zero) { 178 if (zero.getValue() != 0) { 179 return emitOpError( 180 "requires a projected permutation_map (at most one dim or the zero " 181 "constant can appear in each result)"); 182 } 183 continue; 184 } 185 if (!dim) { 186 return emitOpError("requires a projected permutation_map (at most one " 187 "dim or the zero constant can appear in each result)"); 188 } 189 if (seen[dim.getPosition()]) { 190 return emitOpError( 191 "requires a permutation_map that is a permutation (found one dim " 192 "used more than once)"); 193 } 194 seen[dim.getPosition()] = true; 195 } 196 return success(); 197 } 198 199 void VectorTransferReadOp::build(Builder *builder, OperationState *result, 200 VectorType vectorType, Value *srcMemRef, 201 ArrayRef<Value *> srcIndices, 202 AffineMap permutationMap, 203 Optional<Value *> paddingValue) { 204 result->addOperands(srcMemRef); 205 result->addOperands(srcIndices); 206 if (paddingValue) { 207 result->addOperands({*paddingValue}); 208 } 209 result->addAttribute(getPermutationMapAttrName(), 210 builder->getAffineMapAttr(permutationMap)); 211 result->addTypes(vectorType); 212 } 213 214 auto VectorTransferReadOp::getIndices() -> operand_range { 215 auto begin = getOperation()->operand_begin() + Offsets::FirstIndexOffset; 216 auto end = begin + getMemRefType().getRank(); 217 return {begin, end}; 218 } 219 220 Optional<Value *> VectorTransferReadOp::getPaddingValue() { 221 auto memRefRank = getMemRefType().getRank(); 222 if (getNumOperands() <= Offsets::FirstIndexOffset + memRefRank) { 223 return None; 224 } 225 return Optional<Value *>(getOperand(Offsets::FirstIndexOffset + memRefRank)); 226 } 227 228 AffineMap VectorTransferReadOp::getPermutationMap() { 229 return getAttrOfType<AffineMapAttr>(getPermutationMapAttrName()).getValue(); 230 } 231 232 void VectorTransferReadOp::print(OpAsmPrinter *p) { 233 *p << getOperationName() << " "; 234 p->printOperand(getMemRef()); 235 *p << "["; 236 p->printOperands(getIndices()); 237 *p << "]"; 238 auto optionalPaddingValue = getPaddingValue(); 239 if (optionalPaddingValue) { 240 *p << ", ("; 241 p->printOperand(*optionalPaddingValue); 242 *p << ")"; 243 } 244 p->printOptionalAttrDict(getAttrs()); 245 *p << " : " << getMemRefType(); 246 *p << ", " << getResultType(); 247 } 248 249 ParseResult VectorTransferReadOp::parse(OpAsmParser *parser, 250 OperationState *result) { 251 OpAsmParser::OperandType memrefInfo; 252 SmallVector<OpAsmParser::OperandType, 8> indexInfo; 253 SmallVector<OpAsmParser::OperandType, 8> paddingInfo; 254 SmallVector<Type, 2> types; 255 256 // Parsing with support for optional paddingValue. 257 if (parser->parseOperand(memrefInfo) || 258 parser->parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) || 259 parser->parseTrailingOperandList(paddingInfo, 260 OpAsmParser::Delimiter::Paren) || 261 parser->parseOptionalAttributeDict(result->attributes) || 262 parser->parseColonTypeList(types)) 263 return failure(); 264 265 // Resolution. 266 if (types.size() != 2) 267 return parser->emitError(parser->getNameLoc(), "expected 2 types"); 268 MemRefType memrefType = types[0].dyn_cast<MemRefType>(); 269 if (!memrefType) 270 return parser->emitError(parser->getNameLoc(), "memRef type expected"); 271 VectorType vectorType = types[1].dyn_cast<VectorType>(); 272 if (!vectorType) 273 return parser->emitError(parser->getNameLoc(), "vector type expected"); 274 275 // Extract optional paddingValue. 276 // At this point, indexInfo may contain the optional paddingValue, pop it 277 // out. 278 if (static_cast<int64_t>(indexInfo.size()) != memrefType.getRank()) 279 return parser->emitError(parser->getNameLoc(), 280 "expected " + Twine(memrefType.getRank()) + 281 " indices to the memref"); 282 if (paddingInfo.size() > 1) 283 return parser->emitError(parser->getNameLoc(), 284 "expected at most one padding value"); 285 Type paddingType; 286 bool hasOptionalPaddingValue = !paddingInfo.empty(); 287 if (hasOptionalPaddingValue) { 288 paddingType = vectorType.getElementType(); 289 } 290 auto indexType = parser->getBuilder().getIndexType(); 291 return failure( 292 parser->resolveOperand(memrefInfo, memrefType, result->operands) || 293 parser->resolveOperands(indexInfo, indexType, result->operands) || 294 (hasOptionalPaddingValue && 295 parser->resolveOperand(paddingInfo[0], paddingType, result->operands)) || 296 parser->addTypeToList(vectorType, result->types)); 297 } 298 299 LogicalResult VectorTransferReadOp::verify() { 300 // Consistency of memref type in function type. 301 if (llvm::empty(getOperands())) { 302 return emitOpError( 303 "requires at least a memref operand followed by 'rank' indices"); 304 } 305 if (!getMemRef()->getType().isa<MemRefType>()) { 306 return emitOpError("requires a memref as first operand"); 307 } 308 // Consistency of vector type in function type. 309 if (!getResult()->getType().isa<VectorType>()) { 310 return emitOpError("should have a vector result type in function type: " 311 "memref_type<...xelemental_type>, vector_type"); 312 } 313 // Consistency of elemental types in memref and vector. 314 MemRefType memrefType = getMemRefType(); 315 VectorType vectorType = getResultType(); 316 if (memrefType.getElementType() != vectorType.getElementType()) 317 return emitOpError( 318 "requires memref and vector types of the same elemental type"); 319 // Consistency of number of input types. 320 auto optionalPaddingValue = getPaddingValue(); 321 unsigned expectedNumOperands = Offsets::FirstIndexOffset + 322 memrefType.getRank() + 323 (optionalPaddingValue ? 1 : 0); 324 // Checks on the actual operands and their types. 325 if (getNumOperands() != expectedNumOperands) { 326 return emitOpError("expects ") 327 << expectedNumOperands << " operands (of which " 328 << memrefType.getRank() << " indices)"; 329 } 330 // Consistency of padding value with vector type. 331 if (optionalPaddingValue) { 332 auto paddingValue = *optionalPaddingValue; 333 auto elementalType = paddingValue->getType(); 334 if (!VectorType::isValidElementType(elementalType)) { 335 return emitOpError("requires valid padding vector elemental type"); 336 } 337 if (elementalType != vectorType.getElementType()) { 338 return emitOpError( 339 "requires formal padding and vector of the same elemental type"); 340 } 341 } 342 // Consistency of indices types. 343 unsigned numIndices = 0; 344 for (auto *idx : getIndices()) { 345 if (!idx->getType().isIndex()) { 346 return emitOpError( 347 "index to vector.transfer_read must have 'index' type"); 348 } 349 ++numIndices; 350 } 351 if (numIndices != memrefType.getRank()) { 352 return emitOpError("requires at least a memref operand followed by ") 353 << memrefType.getRank() << " indices"; 354 } 355 356 // Consistency of AffineMap attribute. 357 if (!getAttrOfType<AffineMapAttr>(getPermutationMapAttrName())) { 358 return emitOpError("requires an AffineMapAttr named 'permutation_map'"); 359 } 360 auto permutationMap = getPermutationMap(); 361 if (permutationMap.getNumSymbols() != 0) { 362 return emitOpError("requires a permutation_map without symbols"); 363 } 364 if (permutationMap.getNumInputs() != memrefType.getRank()) { 365 return emitOpError("requires a permutation_map with input dims of the " 366 "same rank as the memref type"); 367 } 368 if (permutationMap.getNumResults() != vectorType.getRank()) { 369 return emitOpError("requires a permutation_map with result dims of the " 370 "same rank as the vector type (") 371 << permutationMap.getNumResults() << " vs " << vectorType.getRank(); 372 } 373 return verifyPermutationMap(permutationMap, 374 [this](Twine t) { return emitOpError(t); }); 375 } 376 377 //===----------------------------------------------------------------------===// 378 // VectorTransferWriteOp 379 //===----------------------------------------------------------------------===// 380 void VectorTransferWriteOp::build(Builder *builder, OperationState *result, 381 Value *srcVector, Value *dstMemRef, 382 ArrayRef<Value *> dstIndices, 383 AffineMap permutationMap) { 384 result->addOperands({srcVector, dstMemRef}); 385 result->addOperands(dstIndices); 386 result->addAttribute(getPermutationMapAttrName(), 387 builder->getAffineMapAttr(permutationMap)); 388 } 389 390 auto VectorTransferWriteOp::getIndices() -> operand_range { 391 auto begin = getOperation()->operand_begin() + Offsets::FirstIndexOffset; 392 auto end = begin + getMemRefType().getRank(); 393 return {begin, end}; 394 } 395 396 AffineMap VectorTransferWriteOp::getPermutationMap() { 397 return getAttrOfType<AffineMapAttr>(getPermutationMapAttrName()).getValue(); 398 } 399 400 void VectorTransferWriteOp::print(OpAsmPrinter *p) { 401 *p << getOperationName(); 402 *p << " " << *getVector(); 403 *p << ", " << *getMemRef(); 404 *p << "["; 405 p->printOperands(getIndices()); 406 *p << "]"; 407 p->printOptionalAttrDict(getAttrs()); 408 *p << " : "; 409 p->printType(getVectorType()); 410 *p << ", "; 411 p->printType(getMemRefType()); 412 } 413 414 ParseResult VectorTransferWriteOp::parse(OpAsmParser *parser, 415 OperationState *result) { 416 OpAsmParser::OperandType storeValueInfo; 417 OpAsmParser::OperandType memrefInfo; 418 SmallVector<OpAsmParser::OperandType, 4> indexInfo; 419 SmallVector<Type, 2> types; 420 auto indexType = parser->getBuilder().getIndexType(); 421 if (parser->parseOperand(storeValueInfo) || parser->parseComma() || 422 parser->parseOperand(memrefInfo) || 423 parser->parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) || 424 parser->parseOptionalAttributeDict(result->attributes) || 425 parser->parseColonTypeList(types)) 426 return failure(); 427 428 if (types.size() != 2) 429 return parser->emitError(parser->getNameLoc(), "expected 2 types"); 430 VectorType vectorType = types[Offsets::VectorOffset].dyn_cast<VectorType>(); 431 if (!vectorType) 432 return parser->emitError(parser->getNameLoc(), "vector type expected"); 433 MemRefType memrefType = types[Offsets::MemRefOffset].dyn_cast<MemRefType>(); 434 if (!memrefType) 435 return parser->emitError(parser->getNameLoc(), "memRef type expected"); 436 437 return failure( 438 parser->resolveOperands(storeValueInfo, vectorType, result->operands) || 439 parser->resolveOperands(memrefInfo, memrefType, result->operands) || 440 parser->resolveOperands(indexInfo, indexType, result->operands)); 441 } 442 443 LogicalResult VectorTransferWriteOp::verify() { 444 // Consistency of memref type in function type. 445 if (llvm::empty(getOperands())) { 446 return emitOpError( 447 "requires at least a memref operand followed by 'rank' indices"); 448 } 449 if (!getMemRef()->getType().isa<MemRefType>()) { 450 return emitOpError("requires a memref first operand"); 451 } 452 // Consistency of vector type in function type. 453 if (!getVector()->getType().isa<VectorType>()) { 454 return emitOpError("should have a vector input type in function type: " 455 "(vector_type, memref_type [, elemental_type]) -> ()"); 456 } 457 // Consistency of elemental types in memref and vector. 458 MemRefType memrefType = getMemRefType(); 459 VectorType vectorType = getVectorType(); 460 if (memrefType.getElementType() != vectorType.getElementType()) 461 return emitOpError( 462 "requires memref and vector types of the same elemental type"); 463 // Consistency of number of input types. 464 unsigned expectedNumOperands = 465 Offsets::FirstIndexOffset + memrefType.getRank(); 466 // Checks on the actual operands and their types. 467 if (getNumOperands() != expectedNumOperands) { 468 return emitOpError() << "expects " << expectedNumOperands 469 << " operands (of which " << memrefType.getRank() 470 << " indices)"; 471 } 472 // Consistency of indices types. 473 unsigned numIndices = 0; 474 for (auto *idx : getIndices()) { 475 if (!idx->getType().isIndex()) { 476 return emitOpError( 477 "index to vector.transfer_write must have 'index' type"); 478 } 479 numIndices++; 480 } 481 if (numIndices != memrefType.getRank()) { 482 return emitOpError("requires at least a memref operand followed by ") 483 << memrefType.getRank() << " indices"; 484 } 485 486 // Consistency of AffineMap attribute. 487 if (!getAttrOfType<AffineMapAttr>(getPermutationMapAttrName())) { 488 return emitOpError("requires an AffineMapAttr named 'permutation_map'"); 489 } 490 auto permutationMap = getPermutationMap(); 491 if (permutationMap.getNumSymbols() != 0) { 492 return emitOpError("requires a permutation_map without symbols"); 493 } 494 if (permutationMap.getNumInputs() != memrefType.getRank()) { 495 return emitOpError("requires a permutation_map with input dims of the " 496 "same rank as the memref type"); 497 } 498 if (permutationMap.getNumResults() != vectorType.getRank()) { 499 return emitOpError("requires a permutation_map with result dims of the " 500 "same rank as the vector type (") 501 << permutationMap.getNumResults() << " vs " << vectorType.getRank(); 502 } 503 return verifyPermutationMap(permutationMap, 504 [this](Twine t) { return emitOpError(t); }); 505 } 506 507 //===----------------------------------------------------------------------===// 508 // VectorTypeCastOp 509 //===----------------------------------------------------------------------===// 510 void VectorTypeCastOp::build(Builder *builder, OperationState *result, 511 Value *srcVector, Type dstType) { 512 result->addOperands(srcVector); 513 result->addTypes(dstType); 514 } 515 516 ParseResult VectorTypeCastOp::parse(OpAsmParser *parser, 517 OperationState *result) { 518 OpAsmParser::OperandType operand; 519 Type srcType, dstType; 520 return failure(parser->parseOperand(operand) || 521 parser->parseOptionalAttributeDict(result->attributes) || 522 parser->parseColonType(srcType) || parser->parseComma() || 523 parser->parseType(dstType) || 524 parser->addTypeToList(dstType, result->types) || 525 parser->resolveOperand(operand, srcType, result->operands)); 526 } 527 528 void VectorTypeCastOp::print(OpAsmPrinter *p) { 529 *p << getOperationName() << ' ' << *getOperand() << " : " 530 << getOperand()->getType() << ", " << getType(); 531 } 532 533 LogicalResult VectorTypeCastOp::verify() { 534 auto dstMemrefType = getType().dyn_cast<MemRefType>(); 535 if (!dstMemrefType) 536 return emitOpError("expects target type to be a memref type"); 537 auto dstVectorType = dstMemrefType.getElementType().dyn_cast<VectorType>(); 538 if (!dstVectorType) 539 return emitOpError( 540 "expects vector as an element of the target memref type"); 541 if (!dstMemrefType.hasStaticShape()) 542 return emitOpError("does not support dynamic shapes"); 543 544 if (!getOperand()->getType().isa<MemRefType>()) 545 return emitOpError("expects source type to be a memref type"); 546 547 return success(); 548 } 549 550 namespace mlir { 551 552 #define GET_OP_CLASSES 553 #include "mlir/Dialect/VectorOps/VectorOps.cpp.inc" 554 555 } // namespace mlir