// Source: github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/Support/JitRunner.cpp
//===- jit-runner.cpp - MLIR CPU Execution Driver Library -----------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This is a library that provides a shared implementation for command line
// utilities that execute an MLIR file on the CPU by translating MLIR to LLVM
// IR before JIT-compiling and executing the latter.
//
// The translation can be customized by providing an MLIR to MLIR
// transformation.
24 //===----------------------------------------------------------------------===// 25 26 #include "mlir/Support/JitRunner.h" 27 28 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" 29 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 30 #include "mlir/ExecutionEngine/ExecutionEngine.h" 31 #include "mlir/ExecutionEngine/MemRefUtils.h" 32 #include "mlir/ExecutionEngine/OptUtils.h" 33 #include "mlir/IR/MLIRContext.h" 34 #include "mlir/IR/Module.h" 35 #include "mlir/IR/StandardTypes.h" 36 #include "mlir/Parser.h" 37 #include "mlir/Pass/Pass.h" 38 #include "mlir/Pass/PassManager.h" 39 #include "mlir/Support/FileUtilities.h" 40 #include "mlir/Transforms/Passes.h" 41 42 #include "llvm/ADT/STLExtras.h" 43 #include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" 44 #include "llvm/IR/IRBuilder.h" 45 #include "llvm/IR/LLVMContext.h" 46 #include "llvm/IR/LegacyPassNameParser.h" 47 #include "llvm/IR/Module.h" 48 #include "llvm/Support/CommandLine.h" 49 #include "llvm/Support/FileUtilities.h" 50 #include "llvm/Support/InitLLVM.h" 51 #include "llvm/Support/SourceMgr.h" 52 #include "llvm/Support/StringSaver.h" 53 #include "llvm/Support/TargetSelect.h" 54 #include "llvm/Support/ToolOutputFile.h" 55 #include <numeric> 56 57 using namespace mlir; 58 using llvm::Error; 59 60 static llvm::cl::opt<std::string> inputFilename(llvm::cl::Positional, 61 llvm::cl::desc("<input file>"), 62 llvm::cl::init("-")); 63 static llvm::cl::opt<std::string> 64 initValue("init-value", llvm::cl::desc("Initial value of MemRef elements"), 65 llvm::cl::value_desc("<float value>"), llvm::cl::init("0.0")); 66 static llvm::cl::opt<std::string> 67 mainFuncName("e", llvm::cl::desc("The function to be called"), 68 llvm::cl::value_desc("<function name>"), 69 llvm::cl::init("main")); 70 static llvm::cl::opt<std::string> mainFuncType( 71 "entry-point-result", 72 llvm::cl::desc("Textual description of the function type to be called"), 73 llvm::cl::value_desc("f32 | memrefs | void"), 
llvm::cl::init("memrefs")); 74 75 static llvm::cl::OptionCategory optFlags("opt-like flags"); 76 77 // CLI list of pass information 78 static llvm::cl::list<const llvm::PassInfo *, bool, llvm::PassNameParser> 79 llvmPasses(llvm::cl::desc("LLVM optimizing passes to run"), 80 llvm::cl::cat(optFlags)); 81 82 // CLI variables for -On options. 83 static llvm::cl::opt<bool> optO0("O0", llvm::cl::desc("Run opt O0 passes"), 84 llvm::cl::cat(optFlags)); 85 static llvm::cl::opt<bool> optO1("O1", llvm::cl::desc("Run opt O1 passes"), 86 llvm::cl::cat(optFlags)); 87 static llvm::cl::opt<bool> optO2("O2", llvm::cl::desc("Run opt O2 passes"), 88 llvm::cl::cat(optFlags)); 89 static llvm::cl::opt<bool> optO3("O3", llvm::cl::desc("Run opt O3 passes"), 90 llvm::cl::cat(optFlags)); 91 92 static llvm::cl::OptionCategory clOptionsCategory("linking options"); 93 static llvm::cl::list<std::string> 94 clSharedLibs("shared-libs", llvm::cl::desc("Libraries to link dynamically"), 95 llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated, 96 llvm::cl::cat(clOptionsCategory)); 97 98 // CLI variables for debugging. 99 static llvm::cl::opt<bool> dumpObjectFile( 100 "dump-object-file", 101 llvm::cl::desc("Dump JITted-compiled object to file specified with " 102 "-object-filename (<input file>.o by default).")); 103 104 static llvm::cl::opt<std::string> objectFilename( 105 "object-filename", 106 llvm::cl::desc("Dump JITted-compiled object to file <input file>.o")); 107 108 static OwningModuleRef parseMLIRInput(StringRef inputFilename, 109 MLIRContext *context) { 110 // Set up the input file. 111 std::string errorMessage; 112 auto file = openInputFile(inputFilename, &errorMessage); 113 if (!file) { 114 llvm::errs() << errorMessage << "\n"; 115 return nullptr; 116 } 117 118 llvm::SourceMgr sourceMgr; 119 sourceMgr.AddNewSourceBuffer(std::move(file), llvm::SMLoc()); 120 return OwningModuleRef(parseSourceFile(sourceMgr, context)); 121 } 122 123 // Initialize the relevant subsystems of LLVM. 
124 static void initializeLLVM() { 125 llvm::InitializeNativeTarget(); 126 llvm::InitializeNativeTargetAsmPrinter(); 127 } 128 129 static inline Error make_string_error(const llvm::Twine &message) { 130 return llvm::make_error<llvm::StringError>(message.str(), 131 llvm::inconvertibleErrorCode()); 132 } 133 134 static void printOneMemRef(Type t, void *val) { 135 auto memRefType = t.cast<MemRefType>(); 136 auto shape = memRefType.getShape(); 137 int64_t size = std::accumulate(shape.begin(), shape.end(), 1, 138 std::multiplies<int64_t>()); 139 for (int64_t i = 0; i < size; ++i) { 140 llvm::outs() << reinterpret_cast<StaticFloatMemRef *>(val)->data[i] << ' '; 141 } 142 llvm::outs() << '\n'; 143 } 144 145 static void printMemRefArguments(ArrayRef<Type> argTypes, 146 ArrayRef<Type> resTypes, 147 ArrayRef<void *> args) { 148 auto properArgs = args.take_front(argTypes.size()); 149 for (const auto &kvp : llvm::zip(argTypes, properArgs)) { 150 auto type = std::get<0>(kvp); 151 auto val = std::get<1>(kvp); 152 printOneMemRef(type, val); 153 } 154 155 auto results = args.drop_front(argTypes.size()); 156 for (const auto &kvp : llvm::zip(resTypes, results)) { 157 auto type = std::get<0>(kvp); 158 auto val = std::get<1>(kvp); 159 printOneMemRef(type, val); 160 } 161 } 162 163 // Calls the passes necessary to convert affine and standard dialects to the 164 // LLVM IR dialect. 165 // Currently, these passes are: 166 // - CSE 167 // - canonicalization 168 // - affine to standard lowering 169 // - standard to llvm lowering 170 static LogicalResult convertAffineStandardToLLVMIR(ModuleOp module) { 171 PassManager manager; 172 manager.addPass(mlir::createCanonicalizerPass()); 173 manager.addPass(mlir::createCSEPass()); 174 manager.addPass(mlir::createLowerAffinePass()); 175 manager.addPass(mlir::createConvertToLLVMIRPass()); 176 return manager.run(module); 177 } 178 179 // JIT-compile the given module and run "entryPoint" with "args" as arguments. 
180 static Error 181 compileAndExecute(ModuleOp module, StringRef entryPoint, 182 std::function<llvm::Error(llvm::Module *)> transformer, 183 void **args) { 184 SmallVector<StringRef, 4> libs(clSharedLibs.begin(), clSharedLibs.end()); 185 auto expectedEngine = 186 mlir::ExecutionEngine::create(module, transformer, libs); 187 if (!expectedEngine) 188 return expectedEngine.takeError(); 189 190 auto engine = std::move(*expectedEngine); 191 auto expectedFPtr = engine->lookup(entryPoint); 192 if (!expectedFPtr) 193 return expectedFPtr.takeError(); 194 195 if (dumpObjectFile) 196 engine->dumpToObjectFile(objectFilename.empty() ? inputFilename + ".o" 197 : objectFilename); 198 199 void (*fptr)(void **) = *expectedFPtr; 200 (*fptr)(args); 201 202 return Error::success(); 203 } 204 205 static Error compileAndExecuteVoidFunction( 206 ModuleOp module, StringRef entryPoint, 207 std::function<llvm::Error(llvm::Module *)> transformer) { 208 FuncOp mainFunction = module.lookupSymbol<FuncOp>(entryPoint); 209 if (!mainFunction || mainFunction.getBlocks().empty()) 210 return make_string_error("entry point not found"); 211 void *empty = nullptr; 212 return compileAndExecute(module, entryPoint, transformer, &empty); 213 } 214 215 static Error compileAndExecuteFunctionWithMemRefs( 216 ModuleOp module, StringRef entryPoint, 217 std::function<llvm::Error(llvm::Module *)> transformer) { 218 FuncOp mainFunction = module.lookupSymbol<FuncOp>(entryPoint); 219 if (!mainFunction || mainFunction.getBlocks().empty()) { 220 return make_string_error("entry point not found"); 221 } 222 223 // Store argument and result types of the original function necessary to 224 // pretty print the results, because the function itself will be rewritten 225 // to use the LLVM dialect. 
226 SmallVector<Type, 8> argTypes = 227 llvm::to_vector<8>(mainFunction.getType().getInputs()); 228 SmallVector<Type, 8> resTypes = 229 llvm::to_vector<8>(mainFunction.getType().getResults()); 230 231 float init = std::stof(initValue.getValue()); 232 233 auto expectedArguments = allocateMemRefArguments(mainFunction, init); 234 if (!expectedArguments) 235 return expectedArguments.takeError(); 236 237 if (failed(convertAffineStandardToLLVMIR(module))) 238 return make_string_error("conversion to the LLVM IR dialect failed"); 239 240 if (auto error = compileAndExecute(module, entryPoint, transformer, 241 expectedArguments->data())) 242 return error; 243 244 printMemRefArguments(argTypes, resTypes, *expectedArguments); 245 freeMemRefArguments(*expectedArguments); 246 return Error::success(); 247 } 248 249 static Error compileAndExecuteSingleFloatReturnFunction( 250 ModuleOp module, StringRef entryPoint, 251 std::function<llvm::Error(llvm::Module *)> transformer) { 252 FuncOp mainFunction = module.lookupSymbol<FuncOp>(entryPoint); 253 if (!mainFunction || mainFunction.isExternal()) { 254 return make_string_error("entry point not found"); 255 } 256 257 if (!mainFunction.getType().getInputs().empty()) 258 return make_string_error("function inputs not supported"); 259 260 if (mainFunction.getType().getResults().size() != 1) 261 return make_string_error("only single f32 function result supported"); 262 263 auto t = mainFunction.getType().getResults()[0].dyn_cast<LLVM::LLVMType>(); 264 if (!t) 265 return make_string_error("only single llvm.f32 function result supported"); 266 auto *llvmTy = t.getUnderlyingType(); 267 if (llvmTy != llvmTy->getFloatTy(llvmTy->getContext())) 268 return make_string_error("only single llvm.f32 function result supported"); 269 270 float res; 271 struct { 272 void *data; 273 } data; 274 data.data = &res; 275 if (auto error = 276 compileAndExecute(module, entryPoint, transformer, (void **)&data)) 277 return error; 278 279 // Intentional printing of 
the output so we can test. 280 llvm::outs() << res; 281 282 return Error::success(); 283 } 284 285 // Entry point for all CPU runners. Expects the common argc/argv arguments for 286 // standard C++ main functions and an mlirTransformer. 287 // The latter is applied after parsing the input into MLIR IR and before passing 288 // the MLIR module to the ExecutionEngine. 289 int mlir::JitRunnerMain( 290 int argc, char **argv, 291 llvm::function_ref<LogicalResult(mlir::ModuleOp)> mlirTransformer) { 292 llvm::InitLLVM y(argc, argv); 293 294 initializeLLVM(); 295 mlir::initializeLLVMPasses(); 296 297 llvm::SmallVector<std::reference_wrapper<llvm::cl::opt<bool>>, 4> optFlags{ 298 optO0, optO1, optO2, optO3}; 299 300 llvm::cl::ParseCommandLineOptions(argc, argv, "MLIR CPU execution driver\n"); 301 302 llvm::SmallVector<const llvm::PassInfo *, 4> passes; 303 llvm::Optional<unsigned> optLevel; 304 unsigned optCLIPosition = 0; 305 // Determine if there is an optimization flag present, and its CLI position 306 // (optCLIPosition). 307 for (unsigned j = 0; j < 4; ++j) { 308 auto &flag = optFlags[j].get(); 309 if (flag) { 310 optLevel = j; 311 optCLIPosition = flag.getPosition(); 312 break; 313 } 314 } 315 // Generate vector of pass information, plus the index at which we should 316 // insert any optimization passes in that vector (optPosition). 
317 unsigned optPosition = 0; 318 for (unsigned i = 0, e = llvmPasses.size(); i < e; ++i) { 319 passes.push_back(llvmPasses[i]); 320 if (optCLIPosition < llvmPasses.getPosition(i)) { 321 optPosition = i; 322 optCLIPosition = UINT_MAX; // To ensure we never insert again 323 } 324 } 325 326 MLIRContext context; 327 auto m = parseMLIRInput(inputFilename, &context); 328 if (!m) { 329 llvm::errs() << "could not parse the input IR\n"; 330 return 1; 331 } 332 333 if (mlirTransformer) 334 if (failed(mlirTransformer(m.get()))) 335 return EXIT_FAILURE; 336 337 auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost(); 338 if (!tmBuilderOrError) { 339 llvm::errs() << "Failed to create a JITTargetMachineBuilder for the host\n"; 340 return EXIT_FAILURE; 341 } 342 auto tmOrError = tmBuilderOrError->createTargetMachine(); 343 if (!tmOrError) { 344 llvm::errs() << "Failed to create a TargetMachine for the host\n"; 345 return EXIT_FAILURE; 346 } 347 348 auto transformer = mlir::makeLLVMPassesTransformer( 349 passes, optLevel, /*targetMachine=*/tmOrError->get(), optPosition); 350 351 // Get the function used to compile and execute the module. 352 using CompileAndExecuteFnT = Error (*)( 353 ModuleOp, StringRef, std::function<llvm::Error(llvm::Module *)>); 354 auto compileAndExecuteFn = 355 llvm::StringSwitch<CompileAndExecuteFnT>(mainFuncType.getValue()) 356 .Case("f32", compileAndExecuteSingleFloatReturnFunction) 357 .Case("memrefs", compileAndExecuteFunctionWithMemRefs) 358 .Case("void", compileAndExecuteVoidFunction) 359 .Default(nullptr); 360 361 Error error = 362 compileAndExecuteFn 363 ? 
compileAndExecuteFn(m.get(), mainFuncName.getValue(), transformer) 364 : make_string_error("unsupported function type"); 365 366 int exitCode = EXIT_SUCCESS; 367 llvm::handleAllErrors(std::move(error), 368 [&exitCode](const llvm::ErrorInfoBase &info) { 369 llvm::errs() << "Error: "; 370 info.log(llvm::errs()); 371 llvm::errs() << '\n'; 372 exitCode = EXIT_FAILURE; 373 }); 374 375 return exitCode; 376 }