//===- MemRefUtils.cpp - MLIR runtime utilities for memrefs ---------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This is a set of utilities for working with objects of memref type in a JIT
// context using the MLIR execution engine.
20 // 21 //===----------------------------------------------------------------------===// 22 23 #include "mlir/ExecutionEngine/MemRefUtils.h" 24 #include "mlir/IR/Function.h" 25 #include "mlir/IR/StandardTypes.h" 26 #include "mlir/Support/LLVM.h" 27 28 #include "llvm/Support/Error.h" 29 #include <numeric> 30 31 using namespace mlir; 32 33 static inline llvm::Error make_string_error(const llvm::Twine &message) { 34 return llvm::make_error<llvm::StringError>(message.str(), 35 llvm::inconvertibleErrorCode()); 36 } 37 38 static llvm::Expected<StaticFloatMemRef *> 39 allocMemRefDescriptor(Type type, bool allocateData = true, 40 float initialValue = 0.0) { 41 auto memRefType = type.dyn_cast<MemRefType>(); 42 if (!memRefType) 43 return make_string_error("non-memref argument not supported"); 44 if (!memRefType.hasStaticShape()) 45 return make_string_error("memref with dynamic shapes not supported"); 46 47 auto elementType = memRefType.getElementType(); 48 if (!elementType.isF32()) 49 return make_string_error( 50 "memref with element other than f32 not supported"); 51 52 auto *descriptor = 53 reinterpret_cast<StaticFloatMemRef *>(malloc(sizeof(StaticFloatMemRef))); 54 if (!allocateData) { 55 descriptor->data = nullptr; 56 return descriptor; 57 } 58 59 auto shape = memRefType.getShape(); 60 int64_t size = std::accumulate(shape.begin(), shape.end(), 1, 61 std::multiplies<int64_t>()); 62 descriptor->data = reinterpret_cast<float *>(malloc(sizeof(float) * size)); 63 for (int64_t i = 0; i < size; ++i) { 64 descriptor->data[i] = initialValue; 65 } 66 return descriptor; 67 } 68 69 llvm::Expected<SmallVector<void *, 8>> 70 mlir::allocateMemRefArguments(FuncOp func, float initialValue) { 71 SmallVector<void *, 8> args; 72 args.reserve(func.getNumArguments()); 73 for (const auto &arg : func.getArguments()) { 74 auto descriptor = 75 allocMemRefDescriptor(arg->getType(), 76 /*allocateData=*/true, initialValue); 77 if (!descriptor) 78 return descriptor.takeError(); 79 
args.push_back(*descriptor); 80 } 81 82 if (func.getType().getNumResults() > 1) 83 return make_string_error("functions with more than 1 result not supported"); 84 85 for (Type resType : func.getType().getResults()) { 86 auto descriptor = allocMemRefDescriptor(resType, /*allocateData=*/false); 87 if (!descriptor) 88 return descriptor.takeError(); 89 args.push_back(*descriptor); 90 } 91 92 return args; 93 } 94 95 // Because the function can return the same descriptor as passed in arguments, 96 // we check that we don't attempt to free the underlying data twice. 97 void mlir::freeMemRefArguments(ArrayRef<void *> args) { 98 llvm::DenseSet<void *> dataPointers; 99 for (void *arg : args) { 100 float *dataPtr = reinterpret_cast<StaticFloatMemRef *>(arg)->data; 101 if (dataPointers.count(dataPtr) == 0) { 102 free(dataPtr); 103 dataPointers.insert(dataPtr); 104 } 105 free(arg); 106 } 107 }