github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Dialect/Traits.cpp

//===- Traits.cpp - Common op traits shared by dialects -------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

#include "mlir/Dialect/Traits.h"
#include "mlir/IR/StandardTypes.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;

bool OpTrait::util::getBroadcastedShape(ArrayRef<int64_t> shape1,
                                        ArrayRef<int64_t> shape2,
                                        SmallVectorImpl<int64_t> &resultShape) {
  // To compute the result broadcasted shape, we compare operand shapes
  // element-wise: starting with the trailing dimensions, and working our
  // way backward. Two dimensions are compatible when
  //   1. they are equal, or
  //   2. one of them is 1.
  // The result shape has the maximum of the two inputs at every
  // dimension index.
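  //
  // For example, broadcasting [1, 7, 3] with [8, 1, 3] gives [8, 7, 3], and
  // broadcasting [4, 3] with [5, 1, 3] gives [5, 4, 3]. Shapes [2, 3] and
  // [4, 3] are incompatible because 2 and 4 are neither equal nor 1.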

  resultShape.clear();
  if (shape1.size() > shape2.size()) {
    std::copy(shape1.begin(), shape1.end(), std::back_inserter(resultShape));
  } else {
    std::copy(shape2.begin(), shape2.end(), std::back_inserter(resultShape));
  }

  auto i1 = shape1.rbegin(), e1 = shape1.rend();
  auto i2 = shape2.rbegin(), e2 = shape2.rend();
  auto iR = resultShape.rbegin();

  // Check that each dimension is consistent.
  for (; i1 != e1 && i2 != e2; ++i1, ++i2, ++iR) {
    if (*i1 == -1 || *i2 == -1) {
      // One or both dimensions are unknown. Follow TensorFlow behavior:
      // - If either dimension is greater than 1, we assume that the program is
      //   correct, and the other dimension will be broadcast to match it.
      // - If either dimension is 1, the other dimension is the output.
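      // For example, pairing -1 with 4 gives 4, pairing -1 with 1 gives -1,
      // and pairing -1 with -1 gives -1.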
      if (*i1 > 1) {
        *iR = *i1;
      } else if (*i2 > 1) {
        *iR = *i2;
      } else if (*i1 == 1) {
        *iR = *i2;
      } else if (*i2 == 1) {
        *iR = *i1;
      } else {
        *iR = -1;
      }
    } else {
      if (*i1 == *i2 || *i2 == 1) {
        *iR = *i1;
      } else if (*i1 == 1) {
        *iR = *i2;
      } else {
        // This dimension of the two operand types is incompatible.
        resultShape.clear();
        return false;
      }
    }
  }

  return true;
}

/// Returns the shape of the given type. Scalars will be considered as having a
/// shape with zero dimensions.
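/// For example, tensor<2x3xf32> yields [2, 3], vector<4xf32> yields [4], and
/// i32 yields an empty shape.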
static ArrayRef<int64_t> getShape(Type type) {
  if (auto sType = type.dyn_cast<ShapedType>())
    return sType.getShape();
  return {};
}

/// Returns the result broadcast composition type from the two given types by
/// following NumPy broadcast semantics. Returned type may have dynamic shape if
/// either of the input types has dynamic shape. Returns null type if the two
/// given types are not broadcast-compatible.
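///
/// For example:
///   (tensor<1x4xf32>, tensor<3x1xf32>) -> tensor<3x4xf32>
///   (vector<2xf32>, f32)               -> vector<2xf32>
///   (tensor<*xf32>, tensor<2x3xf32>)   -> tensor<*xf32>
///   (vector<2xf32>, tensor<2xf32>)     -> null type (vector/tensor mix)
///   (tensor<2xf32>, tensor<2xi32>)     -> null type (element types differ)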
Type OpTrait::util::getBroadcastedType(Type type1, Type type2) {
  // Returns the scalar type out of the given type.
  auto getScalarType = [](Type type) -> Type {
    if (auto shapedType = type.dyn_cast<ShapedType>())
      return shapedType.getElementType();
    return type;
  };

  // Make sure the underlying scalar type is the same.
  auto scalarType = getScalarType(type1);
  if (scalarType != getScalarType(type2))
    return {};

  // If one of the types is an unranked tensor, then the other type must not be
  // a vector, and the result is an unranked tensor type.
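  // For example, broadcasting tensor<*xf32> with tensor<4x?xf32> yields
  // tensor<*xf32>, while tensor<*xf32> with vector<4xf32> yields the null type.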
  if (type1.isa<UnrankedTensorType>() || type2.isa<UnrankedTensorType>()) {
    if (type1.isa<VectorType>() || type2.isa<VectorType>())
      return {};
    return UnrankedTensorType::get(scalarType);
  }

  // Returns the type kind if the given type is a vector or ranked tensor type.
  // Returns llvm::None otherwise.
  auto getCompositeTypeKind =
      [](Type type) -> llvm::Optional<StandardTypes::Kind> {
    if (type.isa<VectorType>() || type.isa<RankedTensorType>())
      return static_cast<StandardTypes::Kind>(type.getKind());
    return llvm::None;
  };

  // Make sure the composite type, if any, is consistent.
  auto compositeKind1 = getCompositeTypeKind(type1);
  auto compositeKind2 = getCompositeTypeKind(type2);
  llvm::Optional<StandardTypes::Kind> resultCompositeKind;

  if (compositeKind1 && compositeKind2) {
    // Disallow mixing vector and tensor.
    if (compositeKind1 != compositeKind2)
      return {};
    resultCompositeKind = compositeKind1;
  } else if (compositeKind1) {
    resultCompositeKind = compositeKind1;
  } else if (compositeKind2) {
    resultCompositeKind = compositeKind2;
  }

  // Get the shape of each type.
  SmallVector<int64_t, 4> resultShape;
  if (!getBroadcastedShape(getShape(type1), getShape(type2), resultShape))
    return {};

  // Compose the final broadcasted type.
  if (resultCompositeKind == StandardTypes::Vector)
    return VectorType::get(resultShape, scalarType);
  if (resultCompositeKind == StandardTypes::RankedTensor)
    return RankedTensorType::get(resultShape, scalarType);
  return scalarType;
}

/// Returns true if the given types contain both vector types and tensor types.
static bool hasBothVectorAndTensorType(ArrayRef<Type> types) {
  return llvm::any_of(types, [](Type t) { return t.isa<VectorType>(); }) &&
         llvm::any_of(types, [](Type t) { return t.isa<TensorType>(); });
}

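/// Returns true if the two given shapes are compatible: they have the same
/// rank and, at each index, the dimensions are either equal or at least one of
/// them is dynamic (-1). For example, [2, -1] is compatible with [2, 3] but
/// not with [3, 3] or [2].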
static bool areCompatibleShapes(ArrayRef<int64_t> shape1,
                                ArrayRef<int64_t> shape2) {
  auto isCompatible = [](int64_t dim1, int64_t dim2) {
    return dim1 == dim2 || dim1 == -1 || dim2 == -1;
  };
  if (shape1.size() != shape2.size())
    return false;
  for (const auto &p : llvm::zip(shape1, shape2))
    if (!isCompatible(std::get<0>(p), std::get<1>(p)))
      return false;
  return true;
}

LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) {
  assert(op->getNumOperands() == 2 &&
         "only support broadcast check on two operands");
  assert(op->getNumResults() == 1 &&
         "only support broadcast check on one result");

  auto type1 = op->getOperand(0)->getType();
  auto type2 = op->getOperand(1)->getType();
  auto retType = op->getResult(0)->getType();

  // We forbid broadcasting vector and tensor.
  if (hasBothVectorAndTensorType({type1, type2, retType}))
    return op->emitError("cannot broadcast vector with tensor");

  if (retType.isa<UnrankedTensorType>())
    return success();

  bool isUnranked1 = type1.isa<UnrankedTensorType>();
  bool isUnranked2 = type2.isa<UnrankedTensorType>();

  // If both operands are unranked, then all result shapes are possible.
  if (isUnranked1 && isUnranked2)
    return success();

  // If one of the operands is unranked, then the known dimensions in the
  // result should be compatible with the ranked operand.
  if (isUnranked1 || isUnranked2) {
    // The result must have at least the rank of the ranked operand, and its
    // trailing dimensions must be compatible with that operand's shape.
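    // For example, with a ranked operand of type tensor<3x4xf32>, a result
    // type of tensor<?x3x4xf32> is accepted, whereas tensor<?x3x5xf32>
    // (mismatched suffix) or tensor<4xf32> (rank too low) is rejected.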
    ArrayRef<int64_t> shape = getShape(!isUnranked1 ? type1 : type2);
    ArrayRef<int64_t> actualSuffix = getShape(retType).take_back(shape.size());
    if (!areCompatibleShapes(actualSuffix, shape))
      return op->emitOpError()
             << "result type " << retType
             << " has shape incompatible with a ranked operand type";
    return success();
  }

  // If both operands are ranked, then the computed broadcasted shape should be
  // compatible with the result shape.
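  // For example, operands tensor<?x3xf32> and tensor<5x1xf32> broadcast to
  // shape [5, 3], so result types such as tensor<5x3xf32> or tensor<?x3xf32>
  // are accepted, while tensor<5x4xf32> is rejected.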
  SmallVector<int64_t, 4> resultShape;
  if (!util::getBroadcastedShape(getShape(type1), getShape(type2), resultShape))
    return op->emitOpError("operands don't have broadcast-compatible shapes");

  if (!areCompatibleShapes(resultShape, getShape(retType)))
    return op->emitOpError() << "result type " << retType
                             << " does not have shape compatible with the one "
                                "computed from the operand types";

  return success();
}