github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/ExecutionEngine/MemRefUtils.cpp (about)

     1  //===- MemRefUtils.cpp - MLIR runtime utilities for memrefs ---------------===//
     2  //
     3  // Copyright 2019 The MLIR Authors.
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //   http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  // =============================================================================
    17  //
    18  // This is a set of utilities for working with objects of memref type in a JIT
    19  // context using the MLIR execution engine.
    20  //
    21  //===----------------------------------------------------------------------===//
    22  
    23  #include "mlir/ExecutionEngine/MemRefUtils.h"
    24  #include "mlir/IR/Function.h"
    25  #include "mlir/IR/StandardTypes.h"
    26  #include "mlir/Support/LLVM.h"
    27  
    28  #include "llvm/Support/Error.h"
    29  #include <numeric>
    30  
    31  using namespace mlir;
    32  
    33  static inline llvm::Error make_string_error(const llvm::Twine &message) {
    34    return llvm::make_error<llvm::StringError>(message.str(),
    35                                               llvm::inconvertibleErrorCode());
    36  }
    37  
    38  static llvm::Expected<StaticFloatMemRef *>
    39  allocMemRefDescriptor(Type type, bool allocateData = true,
    40                        float initialValue = 0.0) {
    41    auto memRefType = type.dyn_cast<MemRefType>();
    42    if (!memRefType)
    43      return make_string_error("non-memref argument not supported");
    44    if (!memRefType.hasStaticShape())
    45      return make_string_error("memref with dynamic shapes not supported");
    46  
    47    auto elementType = memRefType.getElementType();
    48    if (!elementType.isF32())
    49      return make_string_error(
    50          "memref with element other than f32 not supported");
    51  
    52    auto *descriptor =
    53        reinterpret_cast<StaticFloatMemRef *>(malloc(sizeof(StaticFloatMemRef)));
    54    if (!allocateData) {
    55      descriptor->data = nullptr;
    56      return descriptor;
    57    }
    58  
    59    auto shape = memRefType.getShape();
    60    int64_t size = std::accumulate(shape.begin(), shape.end(), 1,
    61                                   std::multiplies<int64_t>());
    62    descriptor->data = reinterpret_cast<float *>(malloc(sizeof(float) * size));
    63    for (int64_t i = 0; i < size; ++i) {
    64      descriptor->data[i] = initialValue;
    65    }
    66    return descriptor;
    67  }
    68  
    69  llvm::Expected<SmallVector<void *, 8>>
    70  mlir::allocateMemRefArguments(FuncOp func, float initialValue) {
    71    SmallVector<void *, 8> args;
    72    args.reserve(func.getNumArguments());
    73    for (const auto &arg : func.getArguments()) {
    74      auto descriptor =
    75          allocMemRefDescriptor(arg->getType(),
    76                                /*allocateData=*/true, initialValue);
    77      if (!descriptor)
    78        return descriptor.takeError();
    79      args.push_back(*descriptor);
    80    }
    81  
    82    if (func.getType().getNumResults() > 1)
    83      return make_string_error("functions with more than 1 result not supported");
    84  
    85    for (Type resType : func.getType().getResults()) {
    86      auto descriptor = allocMemRefDescriptor(resType, /*allocateData=*/false);
    87      if (!descriptor)
    88        return descriptor.takeError();
    89      args.push_back(*descriptor);
    90    }
    91  
    92    return args;
    93  }
    94  
    95  // Because the function can return the same descriptor as passed in arguments,
    96  // we check that we don't attempt to free the underlying data twice.
    97  void mlir::freeMemRefArguments(ArrayRef<void *> args) {
    98    llvm::DenseSet<void *> dataPointers;
    99    for (void *arg : args) {
   100      float *dataPtr = reinterpret_cast<StaticFloatMemRef *>(arg)->data;
   101      if (dataPointers.count(dataPtr) == 0) {
   102        free(dataPtr);
   103        dataPointers.insert(dataPtr);
   104      }
   105      free(arg);
   106    }
   107  }