github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Support/JitRunner.cpp (about)

     1  //===- jit-runner.cpp - MLIR CPU Execution Driver Library -----------------===//
     2  //
     3  // Copyright 2019 The MLIR Authors.
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //   http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  // =============================================================================
    17  //
    18  // This is a library that provides a shared implementation for command line
    19  // utilities that execute an MLIR file on the CPU by translating MLIR to LLVM
    20  // IR before JIT-compiling and executing the latter.
    21  //
    22  // The translation can be customized by providing an MLIR to MLIR
    23  // transformation.
    24  //===----------------------------------------------------------------------===//
    25  
#include "mlir/Support/JitRunner.h"

#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/ExecutionEngine/ExecutionEngine.h"
#include "mlir/ExecutionEngine/MemRefUtils.h"
#include "mlir/ExecutionEngine/OptUtils.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Parser.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/FileUtilities.h"
#include "mlir/Transforms/Passes.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/LegacyPassNameParser.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/FileUtilities.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/StringSaver.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/ToolOutputFile.h"

#include <cerrno>
#include <cstdlib>
#include <numeric>
    56  
    57  using namespace mlir;
    58  using llvm::Error;
    59  
    60  static llvm::cl::opt<std::string> inputFilename(llvm::cl::Positional,
    61                                                  llvm::cl::desc("<input file>"),
    62                                                  llvm::cl::init("-"));
    63  static llvm::cl::opt<std::string>
    64      initValue("init-value", llvm::cl::desc("Initial value of MemRef elements"),
    65                llvm::cl::value_desc("<float value>"), llvm::cl::init("0.0"));
    66  static llvm::cl::opt<std::string>
    67      mainFuncName("e", llvm::cl::desc("The function to be called"),
    68                   llvm::cl::value_desc("<function name>"),
    69                   llvm::cl::init("main"));
    70  static llvm::cl::opt<std::string> mainFuncType(
    71      "entry-point-result",
    72      llvm::cl::desc("Textual description of the function type to be called"),
    73      llvm::cl::value_desc("f32 | memrefs | void"), llvm::cl::init("memrefs"));
    74  
    75  static llvm::cl::OptionCategory optFlags("opt-like flags");
    76  
    77  // CLI list of pass information
    78  static llvm::cl::list<const llvm::PassInfo *, bool, llvm::PassNameParser>
    79      llvmPasses(llvm::cl::desc("LLVM optimizing passes to run"),
    80                 llvm::cl::cat(optFlags));
    81  
    82  // CLI variables for -On options.
    83  static llvm::cl::opt<bool> optO0("O0", llvm::cl::desc("Run opt O0 passes"),
    84                                   llvm::cl::cat(optFlags));
    85  static llvm::cl::opt<bool> optO1("O1", llvm::cl::desc("Run opt O1 passes"),
    86                                   llvm::cl::cat(optFlags));
    87  static llvm::cl::opt<bool> optO2("O2", llvm::cl::desc("Run opt O2 passes"),
    88                                   llvm::cl::cat(optFlags));
    89  static llvm::cl::opt<bool> optO3("O3", llvm::cl::desc("Run opt O3 passes"),
    90                                   llvm::cl::cat(optFlags));
    91  
    92  static llvm::cl::OptionCategory clOptionsCategory("linking options");
    93  static llvm::cl::list<std::string>
    94      clSharedLibs("shared-libs", llvm::cl::desc("Libraries to link dynamically"),
    95                   llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated,
    96                   llvm::cl::cat(clOptionsCategory));
    97  
    98  // CLI variables for debugging.
    99  static llvm::cl::opt<bool> dumpObjectFile(
   100      "dump-object-file",
   101      llvm::cl::desc("Dump JITted-compiled object to file specified with "
   102                     "-object-filename (<input file>.o by default)."));
   103  
   104  static llvm::cl::opt<std::string> objectFilename(
   105      "object-filename",
   106      llvm::cl::desc("Dump JITted-compiled object to file <input file>.o"));
   107  
   108  static OwningModuleRef parseMLIRInput(StringRef inputFilename,
   109                                        MLIRContext *context) {
   110    // Set up the input file.
   111    std::string errorMessage;
   112    auto file = openInputFile(inputFilename, &errorMessage);
   113    if (!file) {
   114      llvm::errs() << errorMessage << "\n";
   115      return nullptr;
   116    }
   117  
   118    llvm::SourceMgr sourceMgr;
   119    sourceMgr.AddNewSourceBuffer(std::move(file), llvm::SMLoc());
   120    return OwningModuleRef(parseSourceFile(sourceMgr, context));
   121  }
   122  
   123  // Initialize the relevant subsystems of LLVM.
   124  static void initializeLLVM() {
   125    llvm::InitializeNativeTarget();
   126    llvm::InitializeNativeTargetAsmPrinter();
   127  }
   128  
   129  static inline Error make_string_error(const llvm::Twine &message) {
   130    return llvm::make_error<llvm::StringError>(message.str(),
   131                                               llvm::inconvertibleErrorCode());
   132  }
   133  
   134  static void printOneMemRef(Type t, void *val) {
   135    auto memRefType = t.cast<MemRefType>();
   136    auto shape = memRefType.getShape();
   137    int64_t size = std::accumulate(shape.begin(), shape.end(), 1,
   138                                   std::multiplies<int64_t>());
   139    for (int64_t i = 0; i < size; ++i) {
   140      llvm::outs() << reinterpret_cast<StaticFloatMemRef *>(val)->data[i] << ' ';
   141    }
   142    llvm::outs() << '\n';
   143  }
   144  
   145  static void printMemRefArguments(ArrayRef<Type> argTypes,
   146                                   ArrayRef<Type> resTypes,
   147                                   ArrayRef<void *> args) {
   148    auto properArgs = args.take_front(argTypes.size());
   149    for (const auto &kvp : llvm::zip(argTypes, properArgs)) {
   150      auto type = std::get<0>(kvp);
   151      auto val = std::get<1>(kvp);
   152      printOneMemRef(type, val);
   153    }
   154  
   155    auto results = args.drop_front(argTypes.size());
   156    for (const auto &kvp : llvm::zip(resTypes, results)) {
   157      auto type = std::get<0>(kvp);
   158      auto val = std::get<1>(kvp);
   159      printOneMemRef(type, val);
   160    }
   161  }
   162  
   163  // Calls the passes necessary to convert affine and standard dialects to the
   164  // LLVM IR dialect.
   165  // Currently, these passes are:
   166  // - CSE
   167  // - canonicalization
   168  // - affine to standard lowering
   169  // - standard to llvm lowering
   170  static LogicalResult convertAffineStandardToLLVMIR(ModuleOp module) {
   171    PassManager manager;
   172    manager.addPass(mlir::createCanonicalizerPass());
   173    manager.addPass(mlir::createCSEPass());
   174    manager.addPass(mlir::createLowerAffinePass());
   175    manager.addPass(mlir::createConvertToLLVMIRPass());
   176    return manager.run(module);
   177  }
   178  
   179  // JIT-compile the given module and run "entryPoint" with "args" as arguments.
   180  static Error
   181  compileAndExecute(ModuleOp module, StringRef entryPoint,
   182                    std::function<llvm::Error(llvm::Module *)> transformer,
   183                    void **args) {
   184    SmallVector<StringRef, 4> libs(clSharedLibs.begin(), clSharedLibs.end());
   185    auto expectedEngine =
   186        mlir::ExecutionEngine::create(module, transformer, libs);
   187    if (!expectedEngine)
   188      return expectedEngine.takeError();
   189  
   190    auto engine = std::move(*expectedEngine);
   191    auto expectedFPtr = engine->lookup(entryPoint);
   192    if (!expectedFPtr)
   193      return expectedFPtr.takeError();
   194  
   195    if (dumpObjectFile)
   196      engine->dumpToObjectFile(objectFilename.empty() ? inputFilename + ".o"
   197                                                      : objectFilename);
   198  
   199    void (*fptr)(void **) = *expectedFPtr;
   200    (*fptr)(args);
   201  
   202    return Error::success();
   203  }
   204  
   205  static Error compileAndExecuteVoidFunction(
   206      ModuleOp module, StringRef entryPoint,
   207      std::function<llvm::Error(llvm::Module *)> transformer) {
   208    FuncOp mainFunction = module.lookupSymbol<FuncOp>(entryPoint);
   209    if (!mainFunction || mainFunction.getBlocks().empty())
   210      return make_string_error("entry point not found");
   211    void *empty = nullptr;
   212    return compileAndExecute(module, entryPoint, transformer, &empty);
   213  }
   214  
   215  static Error compileAndExecuteFunctionWithMemRefs(
   216      ModuleOp module, StringRef entryPoint,
   217      std::function<llvm::Error(llvm::Module *)> transformer) {
   218    FuncOp mainFunction = module.lookupSymbol<FuncOp>(entryPoint);
   219    if (!mainFunction || mainFunction.getBlocks().empty()) {
   220      return make_string_error("entry point not found");
   221    }
   222  
   223    // Store argument and result types of the original function necessary to
   224    // pretty print the results, because the function itself will be rewritten
   225    // to use the LLVM dialect.
   226    SmallVector<Type, 8> argTypes =
   227        llvm::to_vector<8>(mainFunction.getType().getInputs());
   228    SmallVector<Type, 8> resTypes =
   229        llvm::to_vector<8>(mainFunction.getType().getResults());
   230  
   231    float init = std::stof(initValue.getValue());
   232  
   233    auto expectedArguments = allocateMemRefArguments(mainFunction, init);
   234    if (!expectedArguments)
   235      return expectedArguments.takeError();
   236  
   237    if (failed(convertAffineStandardToLLVMIR(module)))
   238      return make_string_error("conversion to the LLVM IR dialect failed");
   239  
   240    if (auto error = compileAndExecute(module, entryPoint, transformer,
   241                                       expectedArguments->data()))
   242      return error;
   243  
   244    printMemRefArguments(argTypes, resTypes, *expectedArguments);
   245    freeMemRefArguments(*expectedArguments);
   246    return Error::success();
   247  }
   248  
   249  static Error compileAndExecuteSingleFloatReturnFunction(
   250      ModuleOp module, StringRef entryPoint,
   251      std::function<llvm::Error(llvm::Module *)> transformer) {
   252    FuncOp mainFunction = module.lookupSymbol<FuncOp>(entryPoint);
   253    if (!mainFunction || mainFunction.isExternal()) {
   254      return make_string_error("entry point not found");
   255    }
   256  
   257    if (!mainFunction.getType().getInputs().empty())
   258      return make_string_error("function inputs not supported");
   259  
   260    if (mainFunction.getType().getResults().size() != 1)
   261      return make_string_error("only single f32 function result supported");
   262  
   263    auto t = mainFunction.getType().getResults()[0].dyn_cast<LLVM::LLVMType>();
   264    if (!t)
   265      return make_string_error("only single llvm.f32 function result supported");
   266    auto *llvmTy = t.getUnderlyingType();
   267    if (llvmTy != llvmTy->getFloatTy(llvmTy->getContext()))
   268      return make_string_error("only single llvm.f32 function result supported");
   269  
   270    float res;
   271    struct {
   272      void *data;
   273    } data;
   274    data.data = &res;
   275    if (auto error =
   276            compileAndExecute(module, entryPoint, transformer, (void **)&data))
   277      return error;
   278  
   279    // Intentional printing of the output so we can test.
   280    llvm::outs() << res;
   281  
   282    return Error::success();
   283  }
   284  
   285  // Entry point for all CPU runners. Expects the common argc/argv arguments for
   286  // standard C++ main functions and an mlirTransformer.
   287  // The latter is applied after parsing the input into MLIR IR and before passing
   288  // the MLIR module to the ExecutionEngine.
   289  int mlir::JitRunnerMain(
   290      int argc, char **argv,
   291      llvm::function_ref<LogicalResult(mlir::ModuleOp)> mlirTransformer) {
   292    llvm::InitLLVM y(argc, argv);
   293  
   294    initializeLLVM();
   295    mlir::initializeLLVMPasses();
   296  
   297    llvm::SmallVector<std::reference_wrapper<llvm::cl::opt<bool>>, 4> optFlags{
   298        optO0, optO1, optO2, optO3};
   299  
   300    llvm::cl::ParseCommandLineOptions(argc, argv, "MLIR CPU execution driver\n");
   301  
   302    llvm::SmallVector<const llvm::PassInfo *, 4> passes;
   303    llvm::Optional<unsigned> optLevel;
   304    unsigned optCLIPosition = 0;
   305    // Determine if there is an optimization flag present, and its CLI position
   306    // (optCLIPosition).
   307    for (unsigned j = 0; j < 4; ++j) {
   308      auto &flag = optFlags[j].get();
   309      if (flag) {
   310        optLevel = j;
   311        optCLIPosition = flag.getPosition();
   312        break;
   313      }
   314    }
   315    // Generate vector of pass information, plus the index at which we should
   316    // insert any optimization passes in that vector (optPosition).
   317    unsigned optPosition = 0;
   318    for (unsigned i = 0, e = llvmPasses.size(); i < e; ++i) {
   319      passes.push_back(llvmPasses[i]);
   320      if (optCLIPosition < llvmPasses.getPosition(i)) {
   321        optPosition = i;
   322        optCLIPosition = UINT_MAX; // To ensure we never insert again
   323      }
   324    }
   325  
   326    MLIRContext context;
   327    auto m = parseMLIRInput(inputFilename, &context);
   328    if (!m) {
   329      llvm::errs() << "could not parse the input IR\n";
   330      return 1;
   331    }
   332  
   333    if (mlirTransformer)
   334      if (failed(mlirTransformer(m.get())))
   335        return EXIT_FAILURE;
   336  
   337    auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost();
   338    if (!tmBuilderOrError) {
   339      llvm::errs() << "Failed to create a JITTargetMachineBuilder for the host\n";
   340      return EXIT_FAILURE;
   341    }
   342    auto tmOrError = tmBuilderOrError->createTargetMachine();
   343    if (!tmOrError) {
   344      llvm::errs() << "Failed to create a TargetMachine for the host\n";
   345      return EXIT_FAILURE;
   346    }
   347  
   348    auto transformer = mlir::makeLLVMPassesTransformer(
   349        passes, optLevel, /*targetMachine=*/tmOrError->get(), optPosition);
   350  
   351    // Get the function used to compile and execute the module.
   352    using CompileAndExecuteFnT = Error (*)(
   353        ModuleOp, StringRef, std::function<llvm::Error(llvm::Module *)>);
   354    auto compileAndExecuteFn =
   355        llvm::StringSwitch<CompileAndExecuteFnT>(mainFuncType.getValue())
   356            .Case("f32", compileAndExecuteSingleFloatReturnFunction)
   357            .Case("memrefs", compileAndExecuteFunctionWithMemRefs)
   358            .Case("void", compileAndExecuteVoidFunction)
   359            .Default(nullptr);
   360  
   361    Error error =
   362        compileAndExecuteFn
   363            ? compileAndExecuteFn(m.get(), mainFuncName.getValue(), transformer)
   364            : make_string_error("unsupported function type");
   365  
   366    int exitCode = EXIT_SUCCESS;
   367    llvm::handleAllErrors(std::move(error),
   368                          [&exitCode](const llvm::ErrorInfoBase &info) {
   369                            llvm::errs() << "Error: ";
   370                            info.log(llvm::errs());
   371                            llvm::errs() << '\n';
   372                            exitCode = EXIT_FAILURE;
   373                          });
   374  
   375    return exitCode;
   376  }