github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/ExecutionEngine/ExecutionEngine.cpp (about)

     1  //===- ExecutionEngine.cpp - MLIR Execution engine and utils --------------===//
     2  //
     3  // Copyright 2019 The MLIR Authors.
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //   http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  // =============================================================================
    17  //
    18  // This file implements the execution engine for MLIR modules based on LLVM Orc
    19  // JIT engine.
    20  //
    21  //===----------------------------------------------------------------------===//
#include "mlir/ExecutionEngine/ExecutionEngine.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Module.h"
#include "mlir/Support/FileUtilities.h"
#include "mlir/Target/LLVMIR.h"

#include "llvm/Bitcode/BitcodeReader.h"
#include "llvm/Bitcode/BitcodeWriter.h"
#include "llvm/ExecutionEngine/ObjectCache.h"
#include "llvm/ExecutionEngine/Orc/CompileUtils.h"
#include "llvm/ExecutionEngine/Orc/ExecutionUtils.h"
#include "llvm/ExecutionEngine/Orc/IRCompileLayer.h"
#include "llvm/ExecutionEngine/Orc/IRTransformLayer.h"
#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h"
#include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h"
#include "llvm/ExecutionEngine/SectionMemoryManager.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/TargetRegistry.h"
#include "llvm/Support/ToolOutputFile.h"
#include "llvm/Target/TargetMachine.h"
    42  
    43  using namespace mlir;
    44  using llvm::dbgs;
    45  using llvm::Error;
    46  using llvm::errs;
    47  using llvm::Expected;
    48  using llvm::LLVMContext;
    49  using llvm::MemoryBuffer;
    50  using llvm::MemoryBufferRef;
    51  using llvm::Module;
    52  using llvm::SectionMemoryManager;
    53  using llvm::StringError;
    54  using llvm::Triple;
    55  using llvm::orc::DynamicLibrarySearchGenerator;
    56  using llvm::orc::ExecutionSession;
    57  using llvm::orc::IRCompileLayer;
    58  using llvm::orc::JITTargetMachineBuilder;
    59  using llvm::orc::RTDyldObjectLinkingLayer;
    60  using llvm::orc::ThreadSafeModule;
    61  using llvm::orc::TMOwningSimpleCompiler;
    62  
    63  // Wrap a string into an llvm::StringError.
    64  static inline Error make_string_error(const llvm::Twine &message) {
    65    return llvm::make_error<StringError>(message.str(),
    66                                         llvm::inconvertibleErrorCode());
    67  }
    68  
    69  namespace mlir {
    70  
    71  void SimpleObjectCache::notifyObjectCompiled(const Module *M,
    72                                               MemoryBufferRef ObjBuffer) {
    73    cachedObjects[M->getModuleIdentifier()] = MemoryBuffer::getMemBufferCopy(
    74        ObjBuffer.getBuffer(), ObjBuffer.getBufferIdentifier());
    75  }
    76  
    77  std::unique_ptr<MemoryBuffer> SimpleObjectCache::getObject(const Module *M) {
    78    auto I = cachedObjects.find(M->getModuleIdentifier());
    79    if (I == cachedObjects.end()) {
    80      dbgs() << "No object for " << M->getModuleIdentifier()
    81             << " in cache. Compiling.\n";
    82      return nullptr;
    83    }
    84    dbgs() << "Object for " << M->getModuleIdentifier()
    85           << " loaded from cache.\n";
    86    return MemoryBuffer::getMemBuffer(I->second->getMemBufferRef());
    87  }
    88  
    89  void SimpleObjectCache::dumpToObjectFile(llvm::StringRef outputFilename) {
    90    // Set up the output file.
    91    std::string errorMessage;
    92    auto file = openOutputFile(outputFilename, &errorMessage);
    93    if (!file) {
    94      llvm::errs() << errorMessage << "\n";
    95      return;
    96    }
    97  
    98    // Dump the object generated for a single module to the output file.
    99    assert(cachedObjects.size() == 1 && "Expected only one object entry.");
   100    auto &cachedObject = cachedObjects.begin()->second;
   101    file->os() << cachedObject->getBuffer();
   102    file->keep();
   103  }
   104  
   105  void ExecutionEngine::dumpToObjectFile(llvm::StringRef filename) {
   106    cache->dumpToObjectFile(filename);
   107  }
   108  
   109  // Setup LLVM target triple from the current machine.
   110  bool ExecutionEngine::setupTargetTriple(Module *llvmModule) {
   111    // Setup the machine properties from the current architecture.
   112    auto targetTriple = llvm::sys::getDefaultTargetTriple();
   113    std::string errorMessage;
   114    auto target = llvm::TargetRegistry::lookupTarget(targetTriple, errorMessage);
   115    if (!target) {
   116      errs() << "NO target: " << errorMessage << "\n";
   117      return true;
   118    }
   119    auto machine =
   120        target->createTargetMachine(targetTriple, "generic", "", {}, {});
   121    llvmModule->setDataLayout(machine->createDataLayout());
   122    llvmModule->setTargetTriple(targetTriple);
   123    return false;
   124  }
   125  
   126  static std::string makePackedFunctionName(StringRef name) {
   127    return "_mlir_" + name.str();
   128  }
   129  
   130  // For each function in the LLVM module, define an interface function that wraps
   131  // all the arguments of the original function and all its results into an i8**
   132  // pointer to provide a unified invocation interface.
   133  void packFunctionArguments(Module *module) {
   134    auto &ctx = module->getContext();
   135    llvm::IRBuilder<> builder(ctx);
   136    llvm::DenseSet<llvm::Function *> interfaceFunctions;
   137    for (auto &func : module->getFunctionList()) {
   138      if (func.isDeclaration()) {
   139        continue;
   140      }
   141      if (interfaceFunctions.count(&func)) {
   142        continue;
   143      }
   144  
   145      // Given a function `foo(<...>)`, define the interface function
   146      // `mlir_foo(i8**)`.
   147      auto newType = llvm::FunctionType::get(
   148          builder.getVoidTy(), builder.getInt8PtrTy()->getPointerTo(),
   149          /*isVarArg=*/false);
   150      auto newName = makePackedFunctionName(func.getName());
   151      auto funcCst = module->getOrInsertFunction(newName, newType);
   152      llvm::Function *interfaceFunc =
   153          llvm::cast<llvm::Function>(funcCst.getCallee());
   154      interfaceFunctions.insert(interfaceFunc);
   155  
   156      // Extract the arguments from the type-erased argument list and cast them to
   157      // the proper types.
   158      auto bb = llvm::BasicBlock::Create(ctx);
   159      bb->insertInto(interfaceFunc);
   160      builder.SetInsertPoint(bb);
   161      llvm::Value *argList = interfaceFunc->arg_begin();
   162      llvm::SmallVector<llvm::Value *, 8> args;
   163      args.reserve(llvm::size(func.args()));
   164      for (auto &indexedArg : llvm::enumerate(func.args())) {
   165        llvm::Value *argIndex = llvm::Constant::getIntegerValue(
   166            builder.getInt64Ty(), llvm::APInt(64, indexedArg.index()));
   167        llvm::Value *argPtrPtr = builder.CreateGEP(argList, argIndex);
   168        llvm::Value *argPtr = builder.CreateLoad(argPtrPtr);
   169        argPtr = builder.CreateBitCast(
   170            argPtr, indexedArg.value().getType()->getPointerTo());
   171        llvm::Value *arg = builder.CreateLoad(argPtr);
   172        args.push_back(arg);
   173      }
   174  
   175      // Call the implementation function with the extracted arguments.
   176      llvm::Value *result = builder.CreateCall(&func, args);
   177  
   178      // Assuming the result is one value, potentially of type `void`.
   179      if (!result->getType()->isVoidTy()) {
   180        llvm::Value *retIndex = llvm::Constant::getIntegerValue(
   181            builder.getInt64Ty(), llvm::APInt(64, llvm::size(func.args())));
   182        llvm::Value *retPtrPtr = builder.CreateGEP(argList, retIndex);
   183        llvm::Value *retPtr = builder.CreateLoad(retPtrPtr);
   184        retPtr = builder.CreateBitCast(retPtr, result->getType()->getPointerTo());
   185        builder.CreateStore(result, retPtr);
   186      }
   187  
   188      // The interface function returns void.
   189      builder.CreateRetVoid();
   190    }
   191  }
   192  
   193  ExecutionEngine::ExecutionEngine(bool enableObjectCache)
   194      : cache(enableObjectCache ? nullptr : new SimpleObjectCache()) {}
   195  
   196  Expected<std::unique_ptr<ExecutionEngine>> ExecutionEngine::create(
   197      ModuleOp m, std::function<Error(llvm::Module *)> transformer,
   198      ArrayRef<StringRef> sharedLibPaths, bool enableObjectCache) {
   199    auto engine = std::make_unique<ExecutionEngine>(enableObjectCache);
   200  
   201    std::unique_ptr<llvm::LLVMContext> ctx(new llvm::LLVMContext);
   202    auto llvmModule = translateModuleToLLVMIR(m);
   203    if (!llvmModule)
   204      return make_string_error("could not convert to LLVM IR");
   205    // FIXME: the triple should be passed to the translation or dialect conversion
   206    // instead of this.  Currently, the LLVM module created above has no triple
   207    // associated with it.
   208    setupTargetTriple(llvmModule.get());
   209    packFunctionArguments(llvmModule.get());
   210  
   211    // Clone module in a new LLVMContext since translateModuleToLLVMIR buries
   212    // ownership too deeply.
   213    // TODO(zinenko): Reevaluate model of ownership of LLVMContext in LLVMDialect.
   214    SmallVector<char, 1> buffer;
   215    {
   216      llvm::raw_svector_ostream os(buffer);
   217      WriteBitcodeToFile(*llvmModule, os);
   218    }
   219    llvm::MemoryBufferRef bufferRef(llvm::StringRef(buffer.data(), buffer.size()),
   220                                    "cloned module buffer");
   221    auto expectedModule = parseBitcodeFile(bufferRef, *ctx);
   222    if (!expectedModule)
   223      return expectedModule.takeError();
   224    std::unique_ptr<Module> deserModule = std::move(*expectedModule);
   225  
   226    // Callback to create the object layer with symbol resolution to current
   227    // process and dynamically linked libraries.
   228    auto objectLinkingLayerCreator = [&](ExecutionSession &session,
   229                                         const Triple &TT) {
   230      auto objectLayer = std::make_unique<RTDyldObjectLinkingLayer>(
   231          session, []() { return std::make_unique<SectionMemoryManager>(); });
   232      auto dataLayout = deserModule->getDataLayout();
   233  
   234      // Resolve symbols that are statically linked in the current process.
   235      session.getMainJITDylib().addGenerator(
   236          cantFail(DynamicLibrarySearchGenerator::GetForCurrentProcess(
   237              dataLayout.getGlobalPrefix())));
   238  
   239      // Resolve symbols from shared libraries.
   240      for (auto libPath : sharedLibPaths) {
   241        auto mb = llvm::MemoryBuffer::getFile(libPath);
   242        if (!mb) {
   243          errs() << "Fail to create MemoryBuffer for: " << libPath << "\n";
   244          continue;
   245        }
   246        auto &JD = session.createJITDylib(libPath);
   247        auto loaded = DynamicLibrarySearchGenerator::Load(
   248            libPath.data(), dataLayout.getGlobalPrefix());
   249        if (!loaded) {
   250          errs() << "Could not load: " << libPath << "\n";
   251          continue;
   252        }
   253        JD.addGenerator(std::move(*loaded));
   254        cantFail(objectLayer->add(JD, std::move(mb.get())));
   255      }
   256  
   257      return objectLayer;
   258    };
   259  
   260    // Callback to inspect the cache and recompile on demand. This follows Lang's
   261    // LLJITWithObjectCache example.
   262    auto compileFunctionCreator = [&](JITTargetMachineBuilder JTMB)
   263        -> Expected<IRCompileLayer::CompileFunction> {
   264      auto TM = JTMB.createTargetMachine();
   265      if (!TM)
   266        return TM.takeError();
   267      return IRCompileLayer::CompileFunction(
   268          TMOwningSimpleCompiler(std::move(*TM), engine->cache.get()));
   269    };
   270  
   271    // Create the LLJIT by calling the LLJITBuilder with 2 callbacks.
   272    auto jit =
   273        cantFail(llvm::orc::LLJITBuilder()
   274                     .setCompileFunctionCreator(compileFunctionCreator)
   275                     .setObjectLinkingLayerCreator(objectLinkingLayerCreator)
   276                     .create());
   277  
   278    // Add a ThreadSafemodule to the engine and return.
   279    ThreadSafeModule tsm(std::move(deserModule), std::move(ctx));
   280    cantFail(jit->addIRModule(std::move(tsm)));
   281    engine->jit = std::move(jit);
   282  
   283    return std::move(engine);
   284  }
   285  
   286  Expected<void (*)(void **)> ExecutionEngine::lookup(StringRef name) const {
   287    auto expectedSymbol = jit->lookup(makePackedFunctionName(name));
   288    if (!expectedSymbol)
   289      return expectedSymbol.takeError();
   290    auto rawFPtr = expectedSymbol->getAddress();
   291    auto fptr = reinterpret_cast<void (*)(void **)>(rawFPtr);
   292    if (!fptr)
   293      return make_string_error("looked up function is null");
   294    return fptr;
   295  }
   296  
   297  Error ExecutionEngine::invoke(StringRef name, MutableArrayRef<void *> args) {
   298    auto expectedFPtr = lookup(name);
   299    if (!expectedFPtr)
   300      return expectedFPtr.takeError();
   301    auto fptr = *expectedFPtr;
   302  
   303    (*fptr)(args.data());
   304  
   305    return Error::success();
   306  }
   307  } // end namespace mlir