github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/bindings/python/pybind.cpp (about) 1 //===- pybind.cpp - MLIR Python bindings ----------------------------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 18 #include "llvm/ADT/SmallVector.h" 19 #include "llvm/ADT/StringRef.h" 20 #include "llvm/IR/Function.h" 21 #include "llvm/IR/Module.h" 22 #include "llvm/Support/TargetSelect.h" 23 #include "llvm/Support/raw_ostream.h" 24 #include <cstddef> 25 #include <unordered_map> 26 27 #include "mlir-c/Core.h" 28 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" 29 #include "mlir/EDSC/Builders.h" 30 #include "mlir/EDSC/Helpers.h" 31 #include "mlir/EDSC/Intrinsics.h" 32 #include "mlir/ExecutionEngine/ExecutionEngine.h" 33 #include "mlir/IR/Attributes.h" 34 #include "mlir/IR/Function.h" 35 #include "mlir/IR/Module.h" 36 #include "mlir/IR/Types.h" 37 #include "mlir/Pass/Pass.h" 38 #include "mlir/Pass/PassManager.h" 39 #include "mlir/Target/LLVMIR.h" 40 #include "mlir/Transforms/Passes.h" 41 #include "pybind11/pybind11.h" 42 #include "pybind11/pytypes.h" 43 #include "pybind11/stl.h" 44 45 static bool inited = [] { 46 llvm::InitializeNativeTarget(); 47 llvm::InitializeNativeTargetAsmPrinter(); 48 return true; 49 }(); 50 51 namespace mlir { 52 namespace edsc { 53 namespace python { 54 55 namespace py = pybind11; 56 57 struct PythonAttribute; 58 struct PythonAttributedType; 59 struct PythonBindable; 60 struct PythonExpr; 61 struct PythonFunctionContext; 62 struct PythonStmt; 63 struct PythonBlock; 64 65 struct PythonType { 66 PythonType() : type{nullptr} {} 67 PythonType(mlir_type_t t) : type{t} {} 68 69 operator mlir_type_t() const { return type; } 70 71 PythonAttributedType attachAttributeDict( 72 const std::unordered_map<std::string, PythonAttribute> &attrs) const; 73 74 std::string str() { 75 mlir::Type f = mlir::Type::getFromOpaquePointer(type); 76 std::string res; 77 llvm::raw_string_ostream os(res); 78 f.print(os); 79 return res; 80 } 81 82 mlir_type_t type; 83 }; 84 85 struct PythonValueHandle { 86 PythonValueHandle(PythonType type) 87 : value(mlir::Type::getFromOpaquePointer(type.type)) {} 88 PythonValueHandle(const PythonValueHandle &other) = default; 89 PythonValueHandle(const mlir::edsc::ValueHandle &other) : value(other) {} 90 operator ValueHandle() const { return value; } 91 operator ValueHandle &() { return value; } 92 93 std::string str() const { 94 return std::to_string(reinterpret_cast<intptr_t>(value.getValue())); 95 } 96 97 PythonValueHandle call(const std::vector<PythonValueHandle> &args) { 98 assert(value.hasType() && value.getType().isa<FunctionType>() && 99 "can only call function-typed values"); 100 101 std::vector<Value *> argValues; 102 argValues.reserve(args.size()); 103 for (auto arg : args) 104 argValues.push_back(arg.value.getValue()); 105 return ValueHandle::create<CallIndirectOp>(value, argValues); 106 } 107 108 mlir::edsc::ValueHandle value; 109 }; 110 111 struct PythonFunction { 112 PythonFunction() : function{nullptr} {} 113 PythonFunction(mlir_func_t f) : function{f} {} 114 PythonFunction(mlir::FuncOp f) 115 : function(const_cast<void *>(f.getAsOpaquePointer())) {} 116 operator mlir_func_t() { return function; } 117 std::string str() { 118 mlir::FuncOp f = mlir::FuncOp::getFromOpaquePointer(function); 119 std::string res; 120 llvm::raw_string_ostream os(res); 121 f.print(os); 122 return res; 123 } 124 125 // If the function does not yet have an entry block, i.e. if it is a function 126 // declaration, add the entry block, transforming the declaration into a 127 // definition. Return true if the block was added, false otherwise. 128 bool define() { 129 auto f = mlir::FuncOp::getFromOpaquePointer(function); 130 if (!f.getBlocks().empty()) 131 return false; 132 133 f.addEntryBlock(); 134 return true; 135 } 136 137 PythonValueHandle arg(unsigned index) { 138 auto f = mlir::FuncOp::getFromOpaquePointer(function); 139 assert(index < f.getNumArguments() && "argument index out of bounds"); 140 return PythonValueHandle(ValueHandle(f.getArgument(index))); 141 } 142 143 mlir_func_t function; 144 }; 145 146 /// Trivial C++ wrappers make use of the EDSC C API. 147 struct PythonMLIRModule { 148 PythonMLIRModule() 149 : mlirContext(), 150 module(mlir::ModuleOp::create(mlir::UnknownLoc::get(&mlirContext))), 151 moduleManager(*module) {} 152 153 PythonType makeScalarType(const std::string &mlirElemType, 154 unsigned bitwidth) { 155 return ::makeScalarType(mlir_context_t{&mlirContext}, mlirElemType.c_str(), 156 bitwidth); 157 } 158 PythonType makeMemRefType(PythonType elemType, std::vector<int64_t> sizes) { 159 return ::makeMemRefType(mlir_context_t{&mlirContext}, elemType, 160 int64_list_t{sizes.data(), sizes.size()}); 161 } 162 PythonType makeIndexType() { 163 return ::makeIndexType(mlir_context_t{&mlirContext}); 164 } 165 166 // Declare a function with the given name, input types and their attributes, 167 // output types, and function attributes, but do not define it. 168 PythonFunction declareFunction(const std::string &name, 169 const py::list &inputs, 170 const std::vector<PythonType> &outputTypes, 171 const py::kwargs &funcAttributes); 172 173 // Declare a function with the given name, input types and their attributes, 174 // output types, and function attributes. 175 PythonFunction makeFunction(const std::string &name, const py::list &inputs, 176 const std::vector<PythonType> &outputTypes, 177 const py::kwargs &funcAttributes) { 178 auto declaration = 179 declareFunction(name, inputs, outputTypes, funcAttributes); 180 declaration.define(); 181 return declaration; 182 } 183 184 // Create a custom op given its name and arguments. 185 PythonExpr op(const std::string &name, PythonType type, 186 const py::list &arguments, const py::list &successors, 187 py::kwargs attributes); 188 189 // Create an integer attribute. 190 PythonAttribute integerAttr(PythonType type, int64_t value); 191 192 // Create a boolean attribute. 193 PythonAttribute boolAttr(bool value); 194 195 void compile() { 196 PassManager manager; 197 manager.addPass(mlir::createCanonicalizerPass()); 198 manager.addPass(mlir::createCSEPass()); 199 manager.addPass(mlir::createLowerAffinePass()); 200 manager.addPass(mlir::createConvertToLLVMIRPass()); 201 if (failed(manager.run(*module))) { 202 llvm::errs() << "conversion to the LLVM IR dialect failed\n"; 203 return; 204 } 205 206 auto created = mlir::ExecutionEngine::create(*module); 207 llvm::handleAllErrors(created.takeError(), 208 [](const llvm::ErrorInfoBase &b) { 209 b.log(llvm::errs()); 210 assert(false); 211 }); 212 engine = std::move(*created); 213 } 214 215 std::string getIR() { 216 std::string res; 217 llvm::raw_string_ostream os(res); 218 module->print(os); 219 return res; 220 } 221 222 uint64_t getEngineAddress() { 223 assert(engine && "module must be compiled into engine first"); 224 return reinterpret_cast<uint64_t>(reinterpret_cast<void *>(engine.get())); 225 } 226 227 PythonFunction getNamedFunction(const std::string &name) { 228 return moduleManager.lookupSymbol<FuncOp>(name); 229 } 230 231 PythonFunctionContext 232 makeFunctionContext(const std::string &name, const py::list &inputs, 233 const std::vector<PythonType> &outputs, 234 const py::kwargs &attributes); 235 236 private: 237 mlir::MLIRContext mlirContext; 238 // One single module in a python-exposed MLIRContext for now. 239 mlir::OwningModuleRef module; 240 mlir::ModuleManager moduleManager; 241 std::unique_ptr<mlir::ExecutionEngine> engine; 242 }; 243 244 struct PythonFunctionContext { 245 PythonFunctionContext(PythonFunction f) : function(f) {} 246 PythonFunctionContext(PythonMLIRModule &module, const std::string &name, 247 const py::list &inputs, 248 const std::vector<PythonType> &outputs, 249 const py::kwargs &attributes) { 250 auto function = module.declareFunction(name, inputs, outputs, attributes); 251 function.define(); 252 } 253 254 PythonFunction enter() { 255 assert(function.function && "function is not set up"); 256 auto mlirFunc = mlir::FuncOp::getFromOpaquePointer(function.function); 257 contextBuilder.emplace(mlirFunc.getBody()); 258 context = new mlir::edsc::ScopedContext(*contextBuilder, mlirFunc.getLoc()); 259 return function; 260 } 261 262 void exit(py::object, py::object, py::object) { 263 delete context; 264 context = nullptr; 265 contextBuilder.reset(); 266 } 267 268 PythonFunction function; 269 mlir::edsc::ScopedContext *context; 270 llvm::Optional<OpBuilder> contextBuilder; 271 }; 272 273 PythonFunctionContext PythonMLIRModule::makeFunctionContext( 274 const std::string &name, const py::list &inputs, 275 const std::vector<PythonType> &outputs, const py::kwargs &attributes) { 276 auto func = declareFunction(name, inputs, outputs, attributes); 277 func.define(); 278 return PythonFunctionContext(func); 279 } 280 281 struct PythonBlockHandle { 282 PythonBlockHandle() : value(nullptr) {} 283 PythonBlockHandle(const PythonBlockHandle &other) = default; 284 PythonBlockHandle(const mlir::edsc::BlockHandle &other) : value(other) {} 285 operator mlir::edsc::BlockHandle() const { return value; } 286 287 PythonValueHandle arg(int index) { return arguments[index]; } 288 289 std::string str() { 290 std::string s; 291 llvm::raw_string_ostream os(s); 292 value.getBlock()->print(os); 293 return os.str(); 294 } 295 296 mlir::edsc::BlockHandle value; 297 std::vector<mlir::edsc::ValueHandle> arguments; 298 }; 299 300 struct PythonLoopContext { 301 PythonLoopContext(PythonValueHandle lb, PythonValueHandle ub, int64_t step) 302 : lb(lb), ub(ub), step(step) {} 303 PythonLoopContext(const PythonLoopContext &) = delete; 304 PythonLoopContext(PythonLoopContext &&) = default; 305 PythonLoopContext &operator=(const PythonLoopContext &) = delete; 306 PythonLoopContext &operator=(PythonLoopContext &&) = default; 307 ~PythonLoopContext() { assert(!builder && "did not exit from the context"); } 308 309 PythonValueHandle enter() { 310 ValueHandle iv(lb.value.getType()); 311 builder = new LoopBuilder(&iv, lb.value, ub.value, step); 312 return iv; 313 } 314 315 void exit(py::object, py::object, py::object) { 316 (*builder)({}); // exit from the builder's scope. 317 delete builder; 318 builder = nullptr; 319 } 320 321 PythonValueHandle lb, ub; 322 int64_t step; 323 LoopBuilder *builder = nullptr; 324 }; 325 326 struct PythonLoopNestContext { 327 PythonLoopNestContext(const std::vector<PythonValueHandle> &lbs, 328 const std::vector<PythonValueHandle> &ubs, 329 const std::vector<int64_t> steps) 330 : lbs(lbs), ubs(ubs), steps(steps) { 331 assert(lbs.size() == ubs.size() && lbs.size() == steps.size() && 332 "expected the same number of lower, upper bounds, and steps"); 333 } 334 PythonLoopNestContext(const PythonLoopNestContext &) = delete; 335 PythonLoopNestContext(PythonLoopNestContext &&) = default; 336 PythonLoopNestContext &operator=(const PythonLoopNestContext &) = delete; 337 PythonLoopNestContext &operator=(PythonLoopNestContext &&) = default; 338 ~PythonLoopNestContext() { 339 assert(!builder && "did not exit from the context"); 340 } 341 342 std::vector<PythonValueHandle> enter() { 343 if (steps.empty()) 344 return {}; 345 346 auto type = mlir_type_t(lbs.front().value.getType().getAsOpaquePointer()); 347 std::vector<PythonValueHandle> handles(steps.size(), 348 PythonValueHandle(type)); 349 std::vector<ValueHandle *> handlePtrs; 350 handlePtrs.reserve(steps.size()); 351 for (auto &h : handles) 352 handlePtrs.push_back(&h.value); 353 builder = new LoopNestBuilder( 354 handlePtrs, std::vector<ValueHandle>(lbs.begin(), lbs.end()), 355 std::vector<ValueHandle>(ubs.begin(), ubs.end()), steps); 356 return handles; 357 } 358 359 void exit(py::object, py::object, py::object) { 360 (*builder)({}); // exit from the builder's scope. 361 delete builder; 362 builder = nullptr; 363 } 364 365 std::vector<PythonValueHandle> lbs; 366 std::vector<PythonValueHandle> ubs; 367 std::vector<int64_t> steps; 368 LoopNestBuilder *builder = nullptr; 369 }; 370 371 struct PythonBlockAppender { 372 PythonBlockAppender(const PythonBlockHandle &handle) : handle(handle) {} 373 PythonBlockHandle handle; 374 }; 375 376 struct PythonBlockContext { 377 public: 378 PythonBlockContext() { 379 createBlockBuilder(); 380 clearBuilder(); 381 } 382 PythonBlockContext(const std::vector<PythonType> &argTypes) { 383 handle.arguments.reserve(argTypes.size()); 384 for (const auto &t : argTypes) { 385 auto type = 386 Type::getFromOpaquePointer(reinterpret_cast<const void *>(t.type)); 387 handle.arguments.emplace_back(type); 388 } 389 createBlockBuilder(); 390 clearBuilder(); 391 } 392 PythonBlockContext(const PythonBlockAppender &a) : handle(a.handle) {} 393 PythonBlockContext(const PythonBlockContext &) = delete; 394 PythonBlockContext(PythonBlockContext &&) = default; 395 PythonBlockContext &operator=(const PythonBlockContext &) = delete; 396 PythonBlockContext &operator=(PythonBlockContext &&) = default; 397 ~PythonBlockContext() { 398 assert(!builder && "did not exit from the block context"); 399 } 400 401 // EDSC maintain an implicit stack of builders (mostly for keeping track of 402 // insretion points); every operation gets inserted using the top-of-the-stack 403 // builder. Creating a new EDSC Builder automatically puts it on the stack, 404 // effectively entering the block for it. 405 void createBlockBuilder() { 406 if (handle.value.getBlock()) { 407 builder = new BlockBuilder(handle.value, mlir::edsc::Append()); 408 } else { 409 std::vector<ValueHandle *> args; 410 args.reserve(handle.arguments.size()); 411 for (auto &a : handle.arguments) 412 args.push_back(&a); 413 builder = new BlockBuilder(&handle.value, args); 414 } 415 } 416 417 PythonBlockHandle enter() { 418 createBlockBuilder(); 419 return handle; 420 } 421 422 void exit(py::object, py::object, py::object) { clearBuilder(); } 423 424 PythonBlockHandle getHandle() { return handle; } 425 426 // EDSC maintain an implicit stack of builders (mostly for keeping track of 427 // insretion points); every operation gets inserted using the top-of-the-stack 428 // builder. Calling operator() on a builder pops the builder from the stack, 429 // effectively resetting the insertion point to its position before we entered 430 // the block. 431 void clearBuilder() { 432 (*builder)({}); // exit from the builder's scope. 433 delete builder; 434 builder = nullptr; 435 } 436 437 PythonBlockHandle handle; 438 BlockBuilder *builder = nullptr; 439 }; 440 441 struct PythonAttribute { 442 PythonAttribute() : attr(nullptr) {} 443 PythonAttribute(const mlir_attr_t &a) : attr(a) {} 444 PythonAttribute(const PythonAttribute &other) = default; 445 operator mlir_attr_t() { return attr; } 446 447 std::string str() const { 448 if (!attr) 449 return "##null attr##"; 450 451 std::string res; 452 llvm::raw_string_ostream os(res); 453 Attribute::getFromOpaquePointer(reinterpret_cast<const void *>(attr)) 454 .print(os); 455 return res; 456 } 457 458 mlir_attr_t attr; 459 }; 460 461 struct PythonAttributedType { 462 PythonAttributedType() : type(nullptr) {} 463 PythonAttributedType(mlir_type_t t) : type(t) {} 464 PythonAttributedType( 465 PythonType t, 466 const std::unordered_map<std::string, PythonAttribute> &attributes = 467 std::unordered_map<std::string, PythonAttribute>()) 468 : type(t), attrs(attributes) {} 469 470 operator mlir_type_t() const { return type.type; } 471 operator PythonType() const { return type; } 472 473 // Return a vector of named attribute descriptors. The vector owns the 474 // mlir_named_attr_t objects it contains, but not the names and attributes 475 // those objects point to (names and opaque pointers to attributes are owned 476 // by `this`). 477 std::vector<mlir_named_attr_t> getNamedAttrs() const { 478 std::vector<mlir_named_attr_t> result; 479 result.reserve(attrs.size()); 480 for (const auto &namedAttr : attrs) 481 result.push_back({namedAttr.first.c_str(), namedAttr.second.attr}); 482 return result; 483 } 484 485 std::string str() { 486 mlir::Type t = mlir::Type::getFromOpaquePointer(type); 487 std::string res; 488 llvm::raw_string_ostream os(res); 489 t.print(os); 490 if (attrs.empty()) 491 return os.str(); 492 493 os << '{'; 494 bool first = true; 495 for (const auto &namedAttr : attrs) { 496 if (first) 497 first = false; 498 else 499 os << ", "; 500 os << namedAttr.first << ": " << namedAttr.second.str(); 501 } 502 os << '}'; 503 504 return os.str(); 505 } 506 507 private: 508 PythonType type; 509 std::unordered_map<std::string, PythonAttribute> attrs; 510 }; 511 512 struct PythonIndexedValue { 513 explicit PythonIndexedValue(PythonType type) 514 : indexed(Type::getFromOpaquePointer(type.type)) {} 515 explicit PythonIndexedValue(const IndexedValue &other) : indexed(other) {} 516 PythonIndexedValue(PythonValueHandle handle) : indexed(handle.value) {} 517 PythonIndexedValue(const PythonIndexedValue &other) = default; 518 519 // Create a new indexed value with the same base as this one but with indices 520 // provided as arguments. 521 PythonIndexedValue index(const std::vector<PythonValueHandle> &indices) { 522 std::vector<ValueHandle> handles(indices.begin(), indices.end()); 523 return PythonIndexedValue(IndexedValue(indexed(handles))); 524 } 525 526 void store(const std::vector<PythonValueHandle> &indices, 527 PythonValueHandle value) { 528 // Uses the overloaded `opreator=` to emit a store. 529 index(indices).indexed = value.value; 530 } 531 532 PythonValueHandle load(const std::vector<PythonValueHandle> &indices) { 533 // Uses the overloaded cast to `ValueHandle` to emit a load. 534 return static_cast<ValueHandle>(index(indices).indexed); 535 } 536 537 IndexedValue indexed; 538 }; 539 540 template <typename ListTy, typename PythonTy, typename Ty> 541 ListTy makeCList(SmallVectorImpl<Ty> &owning, const py::list &list) { 542 for (auto &inp : list) { 543 owning.push_back(Ty{inp.cast<PythonTy>()}); 544 } 545 return ListTy{owning.data(), owning.size()}; 546 } 547 548 static mlir_type_list_t makeCTypes(llvm::SmallVectorImpl<mlir_type_t> &owning, 549 const py::list &types) { 550 return makeCList<mlir_type_list_t, PythonType>(owning, types); 551 } 552 553 PythonFunction 554 PythonMLIRModule::declareFunction(const std::string &name, 555 const py::list &inputs, 556 const std::vector<PythonType> &outputTypes, 557 const py::kwargs &funcAttributes) { 558 559 std::vector<PythonAttributedType> attributedInputs; 560 attributedInputs.reserve(inputs.size()); 561 for (const auto &in : inputs) { 562 std::string className = in.get_type().str(); 563 if (className.find(".Type'") != std::string::npos) 564 attributedInputs.emplace_back(in.cast<PythonType>()); 565 else 566 attributedInputs.push_back(in.cast<PythonAttributedType>()); 567 } 568 569 // Create the function type. 570 std::vector<mlir_type_t> ins(attributedInputs.begin(), 571 attributedInputs.end()); 572 std::vector<mlir_type_t> outs(outputTypes.begin(), outputTypes.end()); 573 auto funcType = ::makeFunctionType( 574 mlir_context_t{&mlirContext}, mlir_type_list_t{ins.data(), ins.size()}, 575 mlir_type_list_t{outs.data(), outs.size()}); 576 577 // Build the list of function attributes. 578 std::vector<mlir::NamedAttribute> attrs; 579 attrs.reserve(funcAttributes.size()); 580 for (const auto &named : funcAttributes) 581 attrs.emplace_back( 582 Identifier::get(std::string(named.first.str()), &mlirContext), 583 mlir::Attribute::getFromOpaquePointer(reinterpret_cast<const void *>( 584 named.second.cast<PythonAttribute>().attr))); 585 586 // Build the list of lists of function argument attributes. 587 std::vector<mlir::NamedAttributeList> inputAttrs; 588 inputAttrs.reserve(attributedInputs.size()); 589 for (const auto &in : attributedInputs) { 590 std::vector<mlir::NamedAttribute> inAttrs; 591 for (const auto &named : in.getNamedAttrs()) 592 inAttrs.emplace_back(Identifier::get(named.name, &mlirContext), 593 mlir::Attribute::getFromOpaquePointer( 594 reinterpret_cast<const void *>(named.value))); 595 inputAttrs.emplace_back(inAttrs); 596 } 597 598 // Create the function itself. 599 auto func = mlir::FuncOp::create( 600 UnknownLoc::get(&mlirContext), name, 601 mlir::Type::getFromOpaquePointer(funcType).cast<FunctionType>(), attrs, 602 inputAttrs); 603 moduleManager.insert(func); 604 return func; 605 } 606 607 PythonAttributedType PythonType::attachAttributeDict( 608 const std::unordered_map<std::string, PythonAttribute> &attrs) const { 609 return PythonAttributedType(*this, attrs); 610 } 611 612 PythonAttribute PythonMLIRModule::integerAttr(PythonType type, int64_t value) { 613 return PythonAttribute(::makeIntegerAttr(type, value)); 614 } 615 616 PythonAttribute PythonMLIRModule::boolAttr(bool value) { 617 return PythonAttribute(::makeBoolAttr(&mlirContext, value)); 618 } 619 620 PYBIND11_MODULE(pybind, m) { 621 m.doc() = 622 "Python bindings for MLIR Embedded Domain-Specific Components (EDSCs)"; 623 m.def("version", []() { return "EDSC Python extensions v1.0"; }); 624 625 py::class_<PythonLoopContext>( 626 m, "LoopContext", "A context for building the body of a 'for' loop") 627 .def(py::init<PythonValueHandle, PythonValueHandle, int64_t>()) 628 .def("__enter__", &PythonLoopContext::enter) 629 .def("__exit__", &PythonLoopContext::exit); 630 631 py::class_<PythonLoopNestContext>(m, "LoopNestContext", 632 "A context for building the body of a the " 633 "innermost loop in a nest of 'for' loops") 634 .def(py::init<const std::vector<PythonValueHandle> &, 635 const std::vector<PythonValueHandle> &, 636 const std::vector<int64_t> &>()) 637 .def("__enter__", &PythonLoopNestContext::enter) 638 .def("__exit__", &PythonLoopNestContext::exit); 639 640 m.def("constant_index", [](int64_t val) -> PythonValueHandle { 641 return ValueHandle(index_t(val)); 642 }); 643 m.def("constant_int", [](int64_t val, int width) -> PythonValueHandle { 644 return ValueHandle::create<ConstantIntOp>(val, width); 645 }); 646 m.def("constant_float", [](double val, PythonType type) -> PythonValueHandle { 647 FloatType floatType = 648 Type::getFromOpaquePointer(type.type).cast<FloatType>(); 649 assert(floatType); 650 auto value = APFloat(val); 651 bool lostPrecision; 652 value.convert(floatType.getFloatSemantics(), APFloat::rmNearestTiesToEven, 653 &lostPrecision); 654 return ValueHandle::create<ConstantFloatOp>(value, floatType); 655 }); 656 m.def("constant_function", [](PythonFunction func) -> PythonValueHandle { 657 auto function = FuncOp::getFromOpaquePointer(func.function); 658 auto attr = SymbolRefAttr::get(function.getName(), function.getContext()); 659 return ValueHandle::create<ConstantOp>(function.getType(), attr); 660 }); 661 m.def("appendTo", [](const PythonBlockHandle &handle) { 662 return PythonBlockAppender(handle); 663 }); 664 m.def( 665 "ret", 666 [](const std::vector<PythonValueHandle> &args) { 667 std::vector<ValueHandle> values(args.begin(), args.end()); 668 (intrinsics::ret(ArrayRef<ValueHandle>{values})); // vexing parse 669 return PythonValueHandle(nullptr); 670 }, 671 py::arg("args") = std::vector<PythonValueHandle>()); 672 m.def( 673 "br", 674 [](const PythonBlockHandle &dest, 675 const std::vector<PythonValueHandle> &args) { 676 std::vector<ValueHandle> values(args.begin(), args.end()); 677 intrinsics::br(dest, values); 678 return PythonValueHandle(nullptr); 679 }, 680 py::arg("dest"), py::arg("args") = std::vector<PythonValueHandle>()); 681 m.def( 682 "cond_br", 683 [](PythonValueHandle condition, const PythonBlockHandle &trueDest, 684 const std::vector<PythonValueHandle> &trueArgs, 685 const PythonBlockHandle &falseDest, 686 const std::vector<PythonValueHandle> &falseArgs) -> PythonValueHandle { 687 std::vector<ValueHandle> trueArguments(trueArgs.begin(), 688 trueArgs.end()); 689 std::vector<ValueHandle> falseArguments(falseArgs.begin(), 690 falseArgs.end()); 691 intrinsics::cond_br(condition, trueDest, trueArguments, falseDest, 692 falseArguments); 693 return PythonValueHandle(nullptr); 694 }); 695 m.def("select", 696 [](PythonValueHandle condition, PythonValueHandle trueValue, 697 PythonValueHandle falseValue) -> PythonValueHandle { 698 return ValueHandle::create<SelectOp>(condition.value, trueValue.value, 699 falseValue.value); 700 }); 701 m.def("op", 702 [](const std::string &name, 703 const std::vector<PythonValueHandle> &operands, 704 const std::vector<PythonType> &resultTypes, 705 const py::kwargs &attributes) -> PythonValueHandle { 706 std::vector<ValueHandle> operandHandles(operands.begin(), 707 operands.end()); 708 std::vector<Type> types; 709 types.reserve(resultTypes.size()); 710 for (auto t : resultTypes) 711 types.push_back(Type::getFromOpaquePointer(t.type)); 712 713 std::vector<NamedAttribute> attrs; 714 attrs.reserve(attributes.size()); 715 for (const auto &a : attributes) { 716 std::string name = a.first.str(); 717 auto pyAttr = a.second.cast<PythonAttribute>(); 718 auto cppAttr = Attribute::getFromOpaquePointer(pyAttr.attr); 719 auto identifier = 720 Identifier::get(name, ScopedContext::getContext()); 721 attrs.emplace_back(identifier, cppAttr); 722 } 723 724 return ValueHandle::create(name, operandHandles, types, attrs); 725 }); 726 727 py::class_<PythonFunction>(m, "Function", "Wrapping class for mlir::FuncOp.") 728 .def(py::init<PythonFunction>()) 729 .def("__str__", &PythonFunction::str) 730 .def("define", &PythonFunction::define, 731 "Adds a body to the function if it does not already have one. " 732 "Returns true if the body was added") 733 .def("arg", &PythonFunction::arg, 734 "Get the ValueHandle to the indexed argument of the function"); 735 736 py::class_<PythonAttribute>(m, "Attribute", 737 "Wrapping class for mlir::Attribute") 738 .def(py::init<PythonAttribute>()) 739 .def("__str__", &PythonAttribute::str); 740 741 py::class_<PythonType>(m, "Type", "Wrapping class for mlir::Type.") 742 .def(py::init<PythonType>()) 743 .def("__call__", &PythonType::attachAttributeDict, 744 "Attach the attributes to these type, making it suitable for " 745 "constructing functions with argument attributes") 746 .def("__str__", &PythonType::str); 747 748 py::class_<PythonAttributedType>( 749 m, "AttributedType", 750 "A class containing a wrapped mlir::Type and a wrapped " 751 "mlir::NamedAttributeList that are used together, e.g. in function " 752 "argument declaration") 753 .def(py::init<PythonAttributedType>()) 754 .def("__str__", &PythonAttributedType::str); 755 756 py::class_<PythonMLIRModule>( 757 m, "MLIRModule", 758 "An MLIRModule is the abstraction that owns the allocations to support " 759 "compilation of a single mlir::ModuleOp into an ExecutionEngine backed " 760 "by " 761 "the LLVM ORC JIT. A typical flow consists in creating an MLIRModule, " 762 "adding functions, compiling the module to obtain an ExecutionEngine on " 763 "which named functions may be called. For now the only means to retrieve " 764 "the ExecutionEngine is by calling `get_engine_address`. This mode of " 765 "execution is limited to passing the pointer to C++ where the function " 766 "is called. Extending the API to allow calling JIT compiled functions " 767 "directly require integration with a tensor library (e.g. numpy). This " 768 "is left as the prerogative of libraries and frameworks for now.") 769 .def(py::init<>()) 770 .def("boolAttr", &PythonMLIRModule::boolAttr, 771 "Creates an mlir::BoolAttr with the given value") 772 .def( 773 "integerAttr", &PythonMLIRModule::integerAttr, 774 "Creates an mlir::IntegerAttr of the given type with the given value " 775 "in the context associated with this MLIR module.") 776 .def("declare_function", &PythonMLIRModule::declareFunction, 777 "Declares a new mlir::FuncOp in the current mlir::ModuleOp. The " 778 "function arguments can have attributes. The function has no " 779 "definition and can be linked to an external library.") 780 .def("make_function", &PythonMLIRModule::makeFunction, 781 "Defines a new mlir::FuncOp in the current mlir::ModuleOp.") 782 .def("function_context", &PythonMLIRModule::makeFunctionContext, 783 "Defines a new mlir::FuncOp in the mlir::ModuleOp and creates the " 784 "function context for building the body of the function.") 785 .def("get_function", &PythonMLIRModule::getNamedFunction, 786 "Looks up the function with the given name in the module.") 787 .def( 788 "make_scalar_type", 789 [](PythonMLIRModule &instance, const std::string &type, 790 unsigned bitwidth) { 791 return instance.makeScalarType(type, bitwidth); 792 }, 793 py::arg("type"), py::arg("bitwidth") = 0, 794 "Returns a scalar mlir::Type using the following convention:\n" 795 " - makeScalarType(c, \"bf16\") return an " 796 "`mlir::FloatType::getBF16`\n" 797 " - makeScalarType(c, \"f16\") return an `mlir::FloatType::getF16`\n" 798 " - makeScalarType(c, \"f32\") return an `mlir::FloatType::getF32`\n" 799 " - makeScalarType(c, \"f64\") return an `mlir::FloatType::getF64`\n" 800 " - makeScalarType(c, \"index\") return an `mlir::IndexType::get`\n" 801 " - makeScalarType(c, \"i\", bitwidth) return an " 802 "`mlir::IntegerType::get(bitwidth)`\n\n" 803 " No other combinations are currently supported.") 804 .def("make_memref_type", &PythonMLIRModule::makeMemRefType, 805 "Returns an mlir::MemRefType of an elemental scalar. -1 is used to " 806 "denote symbolic dimensions in the resulting memref shape.") 807 .def("make_index_type", &PythonMLIRModule::makeIndexType, 808 "Returns an mlir::IndexType") 809 .def("compile", &PythonMLIRModule::compile, 810 "Compiles the mlir::ModuleOp to LLVMIR a creates new opaque " 811 "ExecutionEngine backed by the ORC JIT.") 812 .def("get_ir", &PythonMLIRModule::getIR, 813 "Returns a dump of the MLIR representation of the module. This is " 814 "used for serde to support out-of-process execution as well as " 815 "debugging purposes.") 816 .def("get_engine_address", &PythonMLIRModule::getEngineAddress, 817 "Returns the address of the compiled ExecutionEngine. This is used " 818 "for in-process execution.") 819 .def("__str__", &PythonMLIRModule::getIR, 820 "Get the string representation of the module"); 821 822 py::class_<PythonFunctionContext>( 823 m, "FunctionContext", "A wrapper around mlir::edsc::ScopedContext") 824 .def(py::init<PythonFunction>()) 825 .def("__enter__", &PythonFunctionContext::enter) 826 .def("__exit__", &PythonFunctionContext::exit); 827 828 { 829 using namespace mlir::edsc::op; 830 py::class_<PythonValueHandle>(m, "ValueHandle", 831 "A wrapper around mlir::edsc::ValueHandle") 832 .def(py::init<PythonType>()) 833 .def(py::init<PythonValueHandle>()) 834 .def("__add__", 835 [](PythonValueHandle lhs, PythonValueHandle rhs) 836 -> PythonValueHandle { return lhs.value + rhs.value; }) 837 .def("__sub__", 838 [](PythonValueHandle lhs, PythonValueHandle rhs) 839 -> PythonValueHandle { return lhs.value - rhs.value; }) 840 .def("__mul__", 841 [](PythonValueHandle lhs, PythonValueHandle rhs) 842 -> PythonValueHandle { return lhs.value * rhs.value; }) 843 .def("__div__", 844 [](PythonValueHandle lhs, PythonValueHandle rhs) 845 -> PythonValueHandle { return lhs.value / rhs.value; }) 846 .def("__truediv__", 847 [](PythonValueHandle lhs, PythonValueHandle rhs) 848 -> PythonValueHandle { return lhs.value / rhs.value; }) 849 .def("__floordiv__", 850 [](PythonValueHandle lhs, PythonValueHandle rhs) 851 -> PythonValueHandle { return floorDiv(lhs, rhs); }) 852 .def("__mod__", 853 [](PythonValueHandle lhs, PythonValueHandle rhs) 854 -> PythonValueHandle { return lhs.value % rhs.value; }) 855 .def("__lt__", 856 [](PythonValueHandle lhs, 857 PythonValueHandle rhs) -> PythonValueHandle { 858 return ValueHandle::create<CmpIOp>(CmpIPredicate::SLT, lhs.value, 859 rhs.value); 860 }) 861 .def("__le__", 862 [](PythonValueHandle lhs, 863 PythonValueHandle rhs) -> PythonValueHandle { 864 return ValueHandle::create<CmpIOp>(CmpIPredicate::SLE, lhs.value, 865 rhs.value); 866 }) 867 .def("__gt__", 868 [](PythonValueHandle lhs, 869 PythonValueHandle rhs) -> PythonValueHandle { 870 return ValueHandle::create<CmpIOp>(CmpIPredicate::SGT, lhs.value, 871 rhs.value); 872 }) 873 .def("__ge__", 874 [](PythonValueHandle lhs, 875 PythonValueHandle rhs) -> PythonValueHandle { 876 return ValueHandle::create<CmpIOp>(CmpIPredicate::SGE, lhs.value, 877 rhs.value); 878 }) 879 .def("__eq__", 880 [](PythonValueHandle lhs, 881 PythonValueHandle rhs) -> PythonValueHandle { 882 return ValueHandle::create<CmpIOp>(CmpIPredicate::EQ, lhs.value, 883 rhs.value); 884 }) 885 .def("__ne__", 886 [](PythonValueHandle lhs, 887 PythonValueHandle rhs) -> PythonValueHandle { 888 return ValueHandle::create<CmpIOp>(CmpIPredicate::NE, lhs.value, 889 rhs.value); 890 }) 891 .def("__invert__", 892 [](PythonValueHandle handle) -> PythonValueHandle { 893 return !handle.value; 894 }) 895 .def("__and__", 896 [](PythonValueHandle lhs, PythonValueHandle rhs) 897 -> PythonValueHandle { return lhs.value && rhs.value; }) 898 .def("__or__", 899 [](PythonValueHandle lhs, PythonValueHandle rhs) 900 -> PythonValueHandle { return lhs.value || rhs.value; }) 901 .def("__call__", &PythonValueHandle::call); 902 } 903 904 py::class_<PythonBlockAppender>( 905 m, "BlockAppender", 906 "A dummy class signaling BlockContext to append IR to the given block " 907 "instead of creating a new block") 908 .def(py::init<const PythonBlockHandle &>()); 909 py::class_<PythonBlockHandle>(m, "BlockHandle", 910 "A wrapper around mlir::edsc::BlockHandle") 911 .def(py::init<PythonBlockHandle>()) 912 .def("arg", &PythonBlockHandle::arg); 913 914 py::class_<PythonBlockContext>(m, "BlockContext", 915 "A wrapper around mlir::edsc::BlockBuilder") 916 .def(py::init<>()) 917 .def(py::init<const std::vector<PythonType> &>()) 918 .def(py::init<const PythonBlockAppender &>()) 919 .def("__enter__", &PythonBlockContext::enter) 920 .def("__exit__", &PythonBlockContext::exit) 921 .def("handle", &PythonBlockContext::getHandle); 922 923 py::class_<PythonIndexedValue>(m, "IndexedValue", 924 "A wrapper around mlir::edsc::IndexedValue") 925 .def(py::init<PythonValueHandle>()) 926 .def("load", &PythonIndexedValue::load) 927 .def("store", &PythonIndexedValue::store); 928 } 929 930 } // namespace python 931 } // namespace edsc 932 } // namespace mlir