github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Dialect/Traits.cpp (about) 1 //===- Traits.cpp - Common op traits shared by dialects -------------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 18 #include "mlir/Dialect/Traits.h" 19 #include "mlir/IR/StandardTypes.h" 20 #include "llvm/Support/FormatVariadic.h" 21 22 using namespace mlir; 23 24 bool OpTrait::util::getBroadcastedShape(ArrayRef<int64_t> shape1, 25 ArrayRef<int64_t> shape2, 26 SmallVectorImpl<int64_t> &resultShape) { 27 // To compute the result broadcasted shape, we compare operand shapes 28 // element-wise: starting with the trailing dimensions, and working the 29 // way backward. Two dimensions are compatible when 30 // 1. they are equal, or 31 // 2. one of them is 1 32 // The result shape has the maximum among the two inputs at every 33 // dimension index. 34 35 resultShape.clear(); 36 if (shape1.size() > shape2.size()) { 37 std::copy(shape1.begin(), shape1.end(), std::back_inserter(resultShape)); 38 } else { 39 std::copy(shape2.begin(), shape2.end(), std::back_inserter(resultShape)); 40 } 41 42 auto i1 = shape1.rbegin(), e1 = shape1.rend(); 43 auto i2 = shape2.rbegin(), e2 = shape2.rend(); 44 auto iR = resultShape.rbegin(); 45 46 // Check each dimension is consistent. 47 for (; i1 != e1 && i2 != e2; ++i1, ++i2, ++iR) { 48 if (*i1 == -1 || *i2 == -1) { 49 // One or both dimensions is unknown. Follow TensorFlow behavior: 50 // - If either dimension is greater than 1, we assume that the program is 51 // correct, and the other dimension will be broadcast to match it. 52 // - If either dimension is 1, the other dimension is the output. 53 if (*i1 > 1) { 54 *iR = *i1; 55 } else if (*i2 > 1) { 56 *iR = *i2; 57 } else if (*i1 == 1) { 58 *iR = *i2; 59 } else if (*i2 == 1) { 60 *iR = *i1; 61 } else { 62 *iR = -1; 63 } 64 } else { 65 if (*i1 == *i2 || *i2 == 1) { 66 *iR = *i1; 67 } else if (*i1 == 1) { 68 *iR = *i2; 69 } else { 70 // This dimension of the two operand types is incompatible. 71 resultShape.clear(); 72 return false; 73 } 74 } 75 } 76 77 return true; 78 } 79 80 /// Returns the shape of the given type. Scalars will be considered as having a 81 /// shape with zero dimensions. 82 static ArrayRef<int64_t> getShape(Type type) { 83 if (auto sType = type.dyn_cast<ShapedType>()) 84 return sType.getShape(); 85 return {}; 86 } 87 88 /// Returns the result broadcast composition type from the two given types by 89 /// following NumPy broadcast semantics. Returned type may have dynamic shape if 90 /// either of the input types has dynamic shape. Returns null type if the two 91 /// given types are not broadcast-compatible. 92 Type OpTrait::util::getBroadcastedType(Type type1, Type type2) { 93 // Returns the scalar type out of the given type. 94 auto getScalarType = [](Type type) -> Type { 95 if (auto shapedType = type.dyn_cast<ShapedType>()) 96 return shapedType.getElementType(); 97 return type; 98 }; 99 100 // Make sure underlying scalar type is the same. 101 auto scalarType = getScalarType(type1); 102 if (scalarType != getScalarType(type2)) 103 return {}; 104 105 // If one of the types is unranked tensor, then the other type shouldn't be 106 // vector and the result should have unranked tensor type. 107 if (type1.isa<UnrankedTensorType>() || type2.isa<UnrankedTensorType>()) { 108 if (type1.isa<VectorType>() || type2.isa<VectorType>()) 109 return {}; 110 return UnrankedTensorType::get(scalarType); 111 } 112 113 // Returns the type kind if the given type is a vector or ranked tensor type. 114 // Returns llvm::None otherwise. 115 auto getCompositeTypeKind = 116 [](Type type) -> llvm::Optional<StandardTypes::Kind> { 117 if (type.isa<VectorType>() || type.isa<RankedTensorType>()) 118 return static_cast<StandardTypes::Kind>(type.getKind()); 119 return llvm::None; 120 }; 121 122 // Make sure the composite type, if has, is consistent. 123 auto compositeKind1 = getCompositeTypeKind(type1); 124 auto compositeKind2 = getCompositeTypeKind(type2); 125 llvm::Optional<StandardTypes::Kind> resultCompositeKind; 126 127 if (compositeKind1 && compositeKind2) { 128 // Disallow mixing vector and tensor. 129 if (compositeKind1 != compositeKind2) 130 return {}; 131 resultCompositeKind = compositeKind1; 132 } else if (compositeKind1) { 133 resultCompositeKind = compositeKind1; 134 } else if (compositeKind2) { 135 resultCompositeKind = compositeKind2; 136 } 137 138 // Get the shape of each type. 139 SmallVector<int64_t, 4> resultShape; 140 if (!getBroadcastedShape(getShape(type1), getShape(type2), resultShape)) 141 return {}; 142 143 // Compose the final broadcasted type 144 if (resultCompositeKind == StandardTypes::Vector) 145 return VectorType::get(resultShape, scalarType); 146 if (resultCompositeKind == StandardTypes::RankedTensor) 147 return RankedTensorType::get(resultShape, scalarType); 148 return scalarType; 149 } 150 151 /// Returns true if the given types has both vector types and tensor types. 152 static bool hasBothVectorAndTensorType(ArrayRef<Type> types) { 153 return llvm::any_of(types, [](Type t) { return t.isa<VectorType>(); }) && 154 llvm::any_of(types, [](Type t) { return t.isa<TensorType>(); }); 155 } 156 157 static bool areCompatibleShapes(ArrayRef<int64_t> shape1, 158 ArrayRef<int64_t> shape2) { 159 auto isCompatible = [](int64_t dim1, int64_t dim2) { 160 return dim1 == dim2 || dim1 == -1 || dim2 == -1; 161 }; 162 if (shape1.size() != shape2.size()) 163 return false; 164 for (const auto &p : llvm::zip(shape1, shape2)) 165 if (!isCompatible(std::get<0>(p), std::get<1>(p))) 166 return false; 167 return true; 168 } 169 170 LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) { 171 assert(op->getNumOperands() == 2 && 172 "only support broadcast check on two operands"); 173 assert(op->getNumResults() == 1 && 174 "only support broadcast check on one result"); 175 176 auto type1 = op->getOperand(0)->getType(); 177 auto type2 = op->getOperand(1)->getType(); 178 auto retType = op->getResult(0)->getType(); 179 180 // We forbid broadcasting vector and tensor. 181 if (hasBothVectorAndTensorType({type1, type2, retType})) 182 return op->emitError("cannot broadcast vector with tensor"); 183 184 if (retType.isa<UnrankedTensorType>()) 185 return success(); 186 187 bool isUnranked1 = type1.isa<UnrankedTensorType>(); 188 bool isUnranked2 = type2.isa<UnrankedTensorType>(); 189 190 // If both operands are unranked, then all result shapes are possible. 191 if (isUnranked1 && isUnranked2) 192 return success(); 193 194 // If one of the operands is unranked, then the known dimensions in the result 195 // should be compatible with the other shaped operand. 196 if (isUnranked1 || isUnranked2) { 197 // Result should have higher rank than the shaped operand's rank and then 198 // the result's trailing dimensions should be compatible with the operand 199 // shape. 200 ArrayRef<int64_t> shape = getShape(!isUnranked1 ? type1 : type2); 201 ArrayRef<int64_t> actualSuffix = getShape(retType).take_back(shape.size()); 202 if (!areCompatibleShapes(actualSuffix, shape)) 203 return op->emitOpError() 204 << "result type " << retType 205 << " has shape incompatible with a ranked operand type"; 206 return success(); 207 } 208 209 // If both operands are shaped, then the computed broadcasted shape should be 210 // compatible with the result shape. 211 SmallVector<int64_t, 4> resultShape; 212 if (!util::getBroadcastedShape(getShape(type1), getShape(type2), resultShape)) 213 return op->emitOpError("operands don't have broadcast-compatible shapes"); 214 215 if (!areCompatibleShapes(resultShape, getShape(retType))) 216 return op->emitOpError() << "result type " << retType 217 << " does not have shape compatible with the one " 218 "computed from the operand types"; 219 220 return success(); 221 }