github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/EDSC/Helpers.cpp (about)

     1  //===- Helpers.cpp - MLIR Declarative Helper Functionality ----------------===//
     2  //
     3  // Copyright 2019 The MLIR Authors.
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //   http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  // =============================================================================
    17  
    18  #include "mlir/EDSC/Helpers.h"
    19  #include "mlir/Dialect/StandardOps/Ops.h"
    20  #include "mlir/IR/AffineExpr.h"
    21  
    22  using namespace mlir;
    23  using namespace mlir::edsc;
    24  
    25  static SmallVector<ValueHandle, 8> getMemRefSizes(Value *memRef) {
    26    MemRefType memRefType = memRef->getType().cast<MemRefType>();
    27  
    28    auto maps = memRefType.getAffineMaps();
    29    (void)maps;
    30    assert((maps.empty() || (maps.size() == 1 && maps[0].isIdentity())) &&
    31           "Layout maps not supported");
    32    SmallVector<ValueHandle, 8> res;
    33    res.reserve(memRefType.getShape().size());
    34    const auto &shape = memRefType.getShape();
    35    for (unsigned idx = 0, n = shape.size(); idx < n; ++idx) {
    36      if (shape[idx] == -1) {
    37        res.push_back(ValueHandle::create<DimOp>(memRef, idx));
    38      } else {
    39        res.push_back(static_cast<index_t>(shape[idx]));
    40      }
    41    }
    42    return res;
    43  }
    44  
    45  mlir::edsc::MemRefView::MemRefView(Value *v) : base(v) {
    46    assert(v->getType().isa<MemRefType>() && "MemRefType expected");
    47  
    48    auto memrefSizeValues = getMemRefSizes(v);
    49    for (auto &size : memrefSizeValues) {
    50      lbs.push_back(static_cast<index_t>(0));
    51      ubs.push_back(size);
    52      steps.push_back(1);
    53    }
    54  }
    55  
    56  mlir::edsc::VectorView::VectorView(Value *v) : base(v) {
    57    auto vectorType = v->getType().cast<VectorType>();
    58  
    59    for (auto s : vectorType.getShape()) {
    60      lbs.push_back(static_cast<index_t>(0));
    61      ubs.push_back(static_cast<index_t>(s));
    62      steps.push_back(1);
    63    }
    64  }