github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/bindings/python/pybind.cpp (about)

     1  //===- pybind.cpp - MLIR Python bindings ----------------------------------===//
     2  //
     3  // Copyright 2019 The MLIR Authors.
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //   http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  // =============================================================================
    17  
    18  #include "llvm/ADT/SmallVector.h"
    19  #include "llvm/ADT/StringRef.h"
    20  #include "llvm/IR/Function.h"
    21  #include "llvm/IR/Module.h"
    22  #include "llvm/Support/TargetSelect.h"
    23  #include "llvm/Support/raw_ostream.h"
    24  #include <cstddef>
    25  #include <unordered_map>
    26  
    27  #include "mlir-c/Core.h"
    28  #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
    29  #include "mlir/EDSC/Builders.h"
    30  #include "mlir/EDSC/Helpers.h"
    31  #include "mlir/EDSC/Intrinsics.h"
    32  #include "mlir/ExecutionEngine/ExecutionEngine.h"
    33  #include "mlir/IR/Attributes.h"
    34  #include "mlir/IR/Function.h"
    35  #include "mlir/IR/Module.h"
    36  #include "mlir/IR/Types.h"
    37  #include "mlir/Pass/Pass.h"
    38  #include "mlir/Pass/PassManager.h"
    39  #include "mlir/Target/LLVMIR.h"
    40  #include "mlir/Transforms/Passes.h"
    41  #include "pybind11/pybind11.h"
    42  #include "pybind11/pytypes.h"
    43  #include "pybind11/stl.h"
    44  
    45  static bool inited = [] {
    46    llvm::InitializeNativeTarget();
    47    llvm::InitializeNativeTargetAsmPrinter();
    48    return true;
    49  }();
    50  
    51  namespace mlir {
    52  namespace edsc {
    53  namespace python {
    54  
    55  namespace py = pybind11;
    56  
    57  struct PythonAttribute;
    58  struct PythonAttributedType;
    59  struct PythonBindable;
    60  struct PythonExpr;
    61  struct PythonFunctionContext;
    62  struct PythonStmt;
    63  struct PythonBlock;
    64  
    65  struct PythonType {
    66    PythonType() : type{nullptr} {}
    67    PythonType(mlir_type_t t) : type{t} {}
    68  
    69    operator mlir_type_t() const { return type; }
    70  
    71    PythonAttributedType attachAttributeDict(
    72        const std::unordered_map<std::string, PythonAttribute> &attrs) const;
    73  
    74    std::string str() {
    75      mlir::Type f = mlir::Type::getFromOpaquePointer(type);
    76      std::string res;
    77      llvm::raw_string_ostream os(res);
    78      f.print(os);
    79      return res;
    80    }
    81  
    82    mlir_type_t type;
    83  };
    84  
    85  struct PythonValueHandle {
    86    PythonValueHandle(PythonType type)
    87        : value(mlir::Type::getFromOpaquePointer(type.type)) {}
    88    PythonValueHandle(const PythonValueHandle &other) = default;
    89    PythonValueHandle(const mlir::edsc::ValueHandle &other) : value(other) {}
    90    operator ValueHandle() const { return value; }
    91    operator ValueHandle &() { return value; }
    92  
    93    std::string str() const {
    94      return std::to_string(reinterpret_cast<intptr_t>(value.getValue()));
    95    }
    96  
    97    PythonValueHandle call(const std::vector<PythonValueHandle> &args) {
    98      assert(value.hasType() && value.getType().isa<FunctionType>() &&
    99             "can only call function-typed values");
   100  
   101      std::vector<Value *> argValues;
   102      argValues.reserve(args.size());
   103      for (auto arg : args)
   104        argValues.push_back(arg.value.getValue());
   105      return ValueHandle::create<CallIndirectOp>(value, argValues);
   106    }
   107  
   108    mlir::edsc::ValueHandle value;
   109  };
   110  
   111  struct PythonFunction {
   112    PythonFunction() : function{nullptr} {}
   113    PythonFunction(mlir_func_t f) : function{f} {}
   114    PythonFunction(mlir::FuncOp f)
   115        : function(const_cast<void *>(f.getAsOpaquePointer())) {}
   116    operator mlir_func_t() { return function; }
   117    std::string str() {
   118      mlir::FuncOp f = mlir::FuncOp::getFromOpaquePointer(function);
   119      std::string res;
   120      llvm::raw_string_ostream os(res);
   121      f.print(os);
   122      return res;
   123    }
   124  
   125    // If the function does not yet have an entry block, i.e. if it is a function
   126    // declaration, add the entry block, transforming the declaration into a
   127    // definition.  Return true if the block was added, false otherwise.
   128    bool define() {
   129      auto f = mlir::FuncOp::getFromOpaquePointer(function);
   130      if (!f.getBlocks().empty())
   131        return false;
   132  
   133      f.addEntryBlock();
   134      return true;
   135    }
   136  
   137    PythonValueHandle arg(unsigned index) {
   138      auto f = mlir::FuncOp::getFromOpaquePointer(function);
   139      assert(index < f.getNumArguments() && "argument index out of bounds");
   140      return PythonValueHandle(ValueHandle(f.getArgument(index)));
   141    }
   142  
   143    mlir_func_t function;
   144  };
   145  
   146  /// Trivial C++ wrappers make use of the EDSC C API.
   147  struct PythonMLIRModule {
   148    PythonMLIRModule()
   149        : mlirContext(),
   150          module(mlir::ModuleOp::create(mlir::UnknownLoc::get(&mlirContext))),
   151          moduleManager(*module) {}
   152  
   153    PythonType makeScalarType(const std::string &mlirElemType,
   154                              unsigned bitwidth) {
   155      return ::makeScalarType(mlir_context_t{&mlirContext}, mlirElemType.c_str(),
   156                              bitwidth);
   157    }
   158    PythonType makeMemRefType(PythonType elemType, std::vector<int64_t> sizes) {
   159      return ::makeMemRefType(mlir_context_t{&mlirContext}, elemType,
   160                              int64_list_t{sizes.data(), sizes.size()});
   161    }
   162    PythonType makeIndexType() {
   163      return ::makeIndexType(mlir_context_t{&mlirContext});
   164    }
   165  
   166    // Declare a function with the given name, input types and their attributes,
   167    // output types, and function attributes, but do not define it.
   168    PythonFunction declareFunction(const std::string &name,
   169                                   const py::list &inputs,
   170                                   const std::vector<PythonType> &outputTypes,
   171                                   const py::kwargs &funcAttributes);
   172  
   173    // Declare a function with the given name, input types and their attributes,
   174    // output types, and function attributes.
   175    PythonFunction makeFunction(const std::string &name, const py::list &inputs,
   176                                const std::vector<PythonType> &outputTypes,
   177                                const py::kwargs &funcAttributes) {
   178      auto declaration =
   179          declareFunction(name, inputs, outputTypes, funcAttributes);
   180      declaration.define();
   181      return declaration;
   182    }
   183  
   184    // Create a custom op given its name and arguments.
   185    PythonExpr op(const std::string &name, PythonType type,
   186                  const py::list &arguments, const py::list &successors,
   187                  py::kwargs attributes);
   188  
   189    // Create an integer attribute.
   190    PythonAttribute integerAttr(PythonType type, int64_t value);
   191  
   192    // Create a boolean attribute.
   193    PythonAttribute boolAttr(bool value);
   194  
   195    void compile() {
   196      PassManager manager;
   197      manager.addPass(mlir::createCanonicalizerPass());
   198      manager.addPass(mlir::createCSEPass());
   199      manager.addPass(mlir::createLowerAffinePass());
   200      manager.addPass(mlir::createConvertToLLVMIRPass());
   201      if (failed(manager.run(*module))) {
   202        llvm::errs() << "conversion to the LLVM IR dialect failed\n";
   203        return;
   204      }
   205  
   206      auto created = mlir::ExecutionEngine::create(*module);
   207      llvm::handleAllErrors(created.takeError(),
   208                            [](const llvm::ErrorInfoBase &b) {
   209                              b.log(llvm::errs());
   210                              assert(false);
   211                            });
   212      engine = std::move(*created);
   213    }
   214  
   215    std::string getIR() {
   216      std::string res;
   217      llvm::raw_string_ostream os(res);
   218      module->print(os);
   219      return res;
   220    }
   221  
   222    uint64_t getEngineAddress() {
   223      assert(engine && "module must be compiled into engine first");
   224      return reinterpret_cast<uint64_t>(reinterpret_cast<void *>(engine.get()));
   225    }
   226  
   227    PythonFunction getNamedFunction(const std::string &name) {
   228      return moduleManager.lookupSymbol<FuncOp>(name);
   229    }
   230  
   231    PythonFunctionContext
   232    makeFunctionContext(const std::string &name, const py::list &inputs,
   233                        const std::vector<PythonType> &outputs,
   234                        const py::kwargs &attributes);
   235  
   236  private:
   237    mlir::MLIRContext mlirContext;
   238    // One single module in a python-exposed MLIRContext for now.
   239    mlir::OwningModuleRef module;
   240    mlir::ModuleManager moduleManager;
   241    std::unique_ptr<mlir::ExecutionEngine> engine;
   242  };
   243  
   244  struct PythonFunctionContext {
   245    PythonFunctionContext(PythonFunction f) : function(f) {}
   246    PythonFunctionContext(PythonMLIRModule &module, const std::string &name,
   247                          const py::list &inputs,
   248                          const std::vector<PythonType> &outputs,
   249                          const py::kwargs &attributes) {
   250      auto function = module.declareFunction(name, inputs, outputs, attributes);
   251      function.define();
   252    }
   253  
   254    PythonFunction enter() {
   255      assert(function.function && "function is not set up");
   256      auto mlirFunc = mlir::FuncOp::getFromOpaquePointer(function.function);
   257      contextBuilder.emplace(mlirFunc.getBody());
   258      context = new mlir::edsc::ScopedContext(*contextBuilder, mlirFunc.getLoc());
   259      return function;
   260    }
   261  
   262    void exit(py::object, py::object, py::object) {
   263      delete context;
   264      context = nullptr;
   265      contextBuilder.reset();
   266    }
   267  
   268    PythonFunction function;
   269    mlir::edsc::ScopedContext *context;
   270    llvm::Optional<OpBuilder> contextBuilder;
   271  };
   272  
   273  PythonFunctionContext PythonMLIRModule::makeFunctionContext(
   274      const std::string &name, const py::list &inputs,
   275      const std::vector<PythonType> &outputs, const py::kwargs &attributes) {
   276    auto func = declareFunction(name, inputs, outputs, attributes);
   277    func.define();
   278    return PythonFunctionContext(func);
   279  }
   280  
   281  struct PythonBlockHandle {
   282    PythonBlockHandle() : value(nullptr) {}
   283    PythonBlockHandle(const PythonBlockHandle &other) = default;
   284    PythonBlockHandle(const mlir::edsc::BlockHandle &other) : value(other) {}
   285    operator mlir::edsc::BlockHandle() const { return value; }
   286  
   287    PythonValueHandle arg(int index) { return arguments[index]; }
   288  
   289    std::string str() {
   290      std::string s;
   291      llvm::raw_string_ostream os(s);
   292      value.getBlock()->print(os);
   293      return os.str();
   294    }
   295  
   296    mlir::edsc::BlockHandle value;
   297    std::vector<mlir::edsc::ValueHandle> arguments;
   298  };
   299  
   300  struct PythonLoopContext {
   301    PythonLoopContext(PythonValueHandle lb, PythonValueHandle ub, int64_t step)
   302        : lb(lb), ub(ub), step(step) {}
   303    PythonLoopContext(const PythonLoopContext &) = delete;
   304    PythonLoopContext(PythonLoopContext &&) = default;
   305    PythonLoopContext &operator=(const PythonLoopContext &) = delete;
   306    PythonLoopContext &operator=(PythonLoopContext &&) = default;
   307    ~PythonLoopContext() { assert(!builder && "did not exit from the context"); }
   308  
   309    PythonValueHandle enter() {
   310      ValueHandle iv(lb.value.getType());
   311      builder = new LoopBuilder(&iv, lb.value, ub.value, step);
   312      return iv;
   313    }
   314  
   315    void exit(py::object, py::object, py::object) {
   316      (*builder)({}); // exit from the builder's scope.
   317      delete builder;
   318      builder = nullptr;
   319    }
   320  
   321    PythonValueHandle lb, ub;
   322    int64_t step;
   323    LoopBuilder *builder = nullptr;
   324  };
   325  
   326  struct PythonLoopNestContext {
   327    PythonLoopNestContext(const std::vector<PythonValueHandle> &lbs,
   328                          const std::vector<PythonValueHandle> &ubs,
   329                          const std::vector<int64_t> steps)
   330        : lbs(lbs), ubs(ubs), steps(steps) {
   331      assert(lbs.size() == ubs.size() && lbs.size() == steps.size() &&
   332             "expected the same number of lower, upper bounds, and steps");
   333    }
   334    PythonLoopNestContext(const PythonLoopNestContext &) = delete;
   335    PythonLoopNestContext(PythonLoopNestContext &&) = default;
   336    PythonLoopNestContext &operator=(const PythonLoopNestContext &) = delete;
   337    PythonLoopNestContext &operator=(PythonLoopNestContext &&) = default;
   338    ~PythonLoopNestContext() {
   339      assert(!builder && "did not exit from the context");
   340    }
   341  
   342    std::vector<PythonValueHandle> enter() {
   343      if (steps.empty())
   344        return {};
   345  
   346      auto type = mlir_type_t(lbs.front().value.getType().getAsOpaquePointer());
   347      std::vector<PythonValueHandle> handles(steps.size(),
   348                                             PythonValueHandle(type));
   349      std::vector<ValueHandle *> handlePtrs;
   350      handlePtrs.reserve(steps.size());
   351      for (auto &h : handles)
   352        handlePtrs.push_back(&h.value);
   353      builder = new LoopNestBuilder(
   354          handlePtrs, std::vector<ValueHandle>(lbs.begin(), lbs.end()),
   355          std::vector<ValueHandle>(ubs.begin(), ubs.end()), steps);
   356      return handles;
   357    }
   358  
   359    void exit(py::object, py::object, py::object) {
   360      (*builder)({}); // exit from the builder's scope.
   361      delete builder;
   362      builder = nullptr;
   363    }
   364  
   365    std::vector<PythonValueHandle> lbs;
   366    std::vector<PythonValueHandle> ubs;
   367    std::vector<int64_t> steps;
   368    LoopNestBuilder *builder = nullptr;
   369  };
   370  
   371  struct PythonBlockAppender {
   372    PythonBlockAppender(const PythonBlockHandle &handle) : handle(handle) {}
   373    PythonBlockHandle handle;
   374  };
   375  
   376  struct PythonBlockContext {
   377  public:
   378    PythonBlockContext() {
   379      createBlockBuilder();
   380      clearBuilder();
   381    }
   382    PythonBlockContext(const std::vector<PythonType> &argTypes) {
   383      handle.arguments.reserve(argTypes.size());
   384      for (const auto &t : argTypes) {
   385        auto type =
   386            Type::getFromOpaquePointer(reinterpret_cast<const void *>(t.type));
   387        handle.arguments.emplace_back(type);
   388      }
   389      createBlockBuilder();
   390      clearBuilder();
   391    }
   392    PythonBlockContext(const PythonBlockAppender &a) : handle(a.handle) {}
   393    PythonBlockContext(const PythonBlockContext &) = delete;
   394    PythonBlockContext(PythonBlockContext &&) = default;
   395    PythonBlockContext &operator=(const PythonBlockContext &) = delete;
   396    PythonBlockContext &operator=(PythonBlockContext &&) = default;
   397    ~PythonBlockContext() {
   398      assert(!builder && "did not exit from the block context");
   399    }
   400  
   401    // EDSC maintain an implicit stack of builders (mostly for keeping track of
   402    // insretion points); every operation gets inserted using the top-of-the-stack
   403    // builder.  Creating a new EDSC Builder automatically puts it on the stack,
   404    // effectively entering the block for it.
   405    void createBlockBuilder() {
   406      if (handle.value.getBlock()) {
   407        builder = new BlockBuilder(handle.value, mlir::edsc::Append());
   408      } else {
   409        std::vector<ValueHandle *> args;
   410        args.reserve(handle.arguments.size());
   411        for (auto &a : handle.arguments)
   412          args.push_back(&a);
   413        builder = new BlockBuilder(&handle.value, args);
   414      }
   415    }
   416  
   417    PythonBlockHandle enter() {
   418      createBlockBuilder();
   419      return handle;
   420    }
   421  
   422    void exit(py::object, py::object, py::object) { clearBuilder(); }
   423  
   424    PythonBlockHandle getHandle() { return handle; }
   425  
   426    // EDSC maintain an implicit stack of builders (mostly for keeping track of
   427    // insretion points); every operation gets inserted using the top-of-the-stack
   428    // builder.  Calling operator() on a builder pops the builder from the stack,
   429    // effectively resetting the insertion point to its position before we entered
   430    // the block.
   431    void clearBuilder() {
   432      (*builder)({}); // exit from the builder's scope.
   433      delete builder;
   434      builder = nullptr;
   435    }
   436  
   437    PythonBlockHandle handle;
   438    BlockBuilder *builder = nullptr;
   439  };
   440  
   441  struct PythonAttribute {
   442    PythonAttribute() : attr(nullptr) {}
   443    PythonAttribute(const mlir_attr_t &a) : attr(a) {}
   444    PythonAttribute(const PythonAttribute &other) = default;
   445    operator mlir_attr_t() { return attr; }
   446  
   447    std::string str() const {
   448      if (!attr)
   449        return "##null attr##";
   450  
   451      std::string res;
   452      llvm::raw_string_ostream os(res);
   453      Attribute::getFromOpaquePointer(reinterpret_cast<const void *>(attr))
   454          .print(os);
   455      return res;
   456    }
   457  
   458    mlir_attr_t attr;
   459  };
   460  
   461  struct PythonAttributedType {
   462    PythonAttributedType() : type(nullptr) {}
   463    PythonAttributedType(mlir_type_t t) : type(t) {}
   464    PythonAttributedType(
   465        PythonType t,
   466        const std::unordered_map<std::string, PythonAttribute> &attributes =
   467            std::unordered_map<std::string, PythonAttribute>())
   468        : type(t), attrs(attributes) {}
   469  
   470    operator mlir_type_t() const { return type.type; }
   471    operator PythonType() const { return type; }
   472  
   473    // Return a vector of named attribute descriptors.  The vector owns the
   474    // mlir_named_attr_t objects it contains, but not the names and attributes
   475    // those objects point to (names and opaque pointers to attributes are owned
   476    // by `this`).
   477    std::vector<mlir_named_attr_t> getNamedAttrs() const {
   478      std::vector<mlir_named_attr_t> result;
   479      result.reserve(attrs.size());
   480      for (const auto &namedAttr : attrs)
   481        result.push_back({namedAttr.first.c_str(), namedAttr.second.attr});
   482      return result;
   483    }
   484  
   485    std::string str() {
   486      mlir::Type t = mlir::Type::getFromOpaquePointer(type);
   487      std::string res;
   488      llvm::raw_string_ostream os(res);
   489      t.print(os);
   490      if (attrs.empty())
   491        return os.str();
   492  
   493      os << '{';
   494      bool first = true;
   495      for (const auto &namedAttr : attrs) {
   496        if (first)
   497          first = false;
   498        else
   499          os << ", ";
   500        os << namedAttr.first << ": " << namedAttr.second.str();
   501      }
   502      os << '}';
   503  
   504      return os.str();
   505    }
   506  
   507  private:
   508    PythonType type;
   509    std::unordered_map<std::string, PythonAttribute> attrs;
   510  };
   511  
   512  struct PythonIndexedValue {
   513    explicit PythonIndexedValue(PythonType type)
   514        : indexed(Type::getFromOpaquePointer(type.type)) {}
   515    explicit PythonIndexedValue(const IndexedValue &other) : indexed(other) {}
   516    PythonIndexedValue(PythonValueHandle handle) : indexed(handle.value) {}
   517    PythonIndexedValue(const PythonIndexedValue &other) = default;
   518  
   519    // Create a new indexed value with the same base as this one but with indices
   520    // provided as arguments.
   521    PythonIndexedValue index(const std::vector<PythonValueHandle> &indices) {
   522      std::vector<ValueHandle> handles(indices.begin(), indices.end());
   523      return PythonIndexedValue(IndexedValue(indexed(handles)));
   524    }
   525  
   526    void store(const std::vector<PythonValueHandle> &indices,
   527               PythonValueHandle value) {
   528      // Uses the overloaded `opreator=` to emit a store.
   529      index(indices).indexed = value.value;
   530    }
   531  
   532    PythonValueHandle load(const std::vector<PythonValueHandle> &indices) {
   533      // Uses the overloaded cast to `ValueHandle` to emit a load.
   534      return static_cast<ValueHandle>(index(indices).indexed);
   535    }
   536  
   537    IndexedValue indexed;
   538  };
   539  
   540  template <typename ListTy, typename PythonTy, typename Ty>
   541  ListTy makeCList(SmallVectorImpl<Ty> &owning, const py::list &list) {
   542    for (auto &inp : list) {
   543      owning.push_back(Ty{inp.cast<PythonTy>()});
   544    }
   545    return ListTy{owning.data(), owning.size()};
   546  }
   547  
   548  static mlir_type_list_t makeCTypes(llvm::SmallVectorImpl<mlir_type_t> &owning,
   549                                     const py::list &types) {
   550    return makeCList<mlir_type_list_t, PythonType>(owning, types);
   551  }
   552  
   553  PythonFunction
   554  PythonMLIRModule::declareFunction(const std::string &name,
   555                                    const py::list &inputs,
   556                                    const std::vector<PythonType> &outputTypes,
   557                                    const py::kwargs &funcAttributes) {
   558  
   559    std::vector<PythonAttributedType> attributedInputs;
   560    attributedInputs.reserve(inputs.size());
   561    for (const auto &in : inputs) {
   562      std::string className = in.get_type().str();
   563      if (className.find(".Type'") != std::string::npos)
   564        attributedInputs.emplace_back(in.cast<PythonType>());
   565      else
   566        attributedInputs.push_back(in.cast<PythonAttributedType>());
   567    }
   568  
   569    // Create the function type.
   570    std::vector<mlir_type_t> ins(attributedInputs.begin(),
   571                                 attributedInputs.end());
   572    std::vector<mlir_type_t> outs(outputTypes.begin(), outputTypes.end());
   573    auto funcType = ::makeFunctionType(
   574        mlir_context_t{&mlirContext}, mlir_type_list_t{ins.data(), ins.size()},
   575        mlir_type_list_t{outs.data(), outs.size()});
   576  
   577    // Build the list of function attributes.
   578    std::vector<mlir::NamedAttribute> attrs;
   579    attrs.reserve(funcAttributes.size());
   580    for (const auto &named : funcAttributes)
   581      attrs.emplace_back(
   582          Identifier::get(std::string(named.first.str()), &mlirContext),
   583          mlir::Attribute::getFromOpaquePointer(reinterpret_cast<const void *>(
   584              named.second.cast<PythonAttribute>().attr)));
   585  
   586    // Build the list of lists of function argument attributes.
   587    std::vector<mlir::NamedAttributeList> inputAttrs;
   588    inputAttrs.reserve(attributedInputs.size());
   589    for (const auto &in : attributedInputs) {
   590      std::vector<mlir::NamedAttribute> inAttrs;
   591      for (const auto &named : in.getNamedAttrs())
   592        inAttrs.emplace_back(Identifier::get(named.name, &mlirContext),
   593                             mlir::Attribute::getFromOpaquePointer(
   594                                 reinterpret_cast<const void *>(named.value)));
   595      inputAttrs.emplace_back(inAttrs);
   596    }
   597  
   598    // Create the function itself.
   599    auto func = mlir::FuncOp::create(
   600        UnknownLoc::get(&mlirContext), name,
   601        mlir::Type::getFromOpaquePointer(funcType).cast<FunctionType>(), attrs,
   602        inputAttrs);
   603    moduleManager.insert(func);
   604    return func;
   605  }
   606  
   607  PythonAttributedType PythonType::attachAttributeDict(
   608      const std::unordered_map<std::string, PythonAttribute> &attrs) const {
   609    return PythonAttributedType(*this, attrs);
   610  }
   611  
   612  PythonAttribute PythonMLIRModule::integerAttr(PythonType type, int64_t value) {
   613    return PythonAttribute(::makeIntegerAttr(type, value));
   614  }
   615  
   616  PythonAttribute PythonMLIRModule::boolAttr(bool value) {
   617    return PythonAttribute(::makeBoolAttr(&mlirContext, value));
   618  }
   619  
   620  PYBIND11_MODULE(pybind, m) {
   621    m.doc() =
   622        "Python bindings for MLIR Embedded Domain-Specific Components (EDSCs)";
   623    m.def("version", []() { return "EDSC Python extensions v1.0"; });
   624  
   625    py::class_<PythonLoopContext>(
   626        m, "LoopContext", "A context for building the body of a 'for' loop")
   627        .def(py::init<PythonValueHandle, PythonValueHandle, int64_t>())
   628        .def("__enter__", &PythonLoopContext::enter)
   629        .def("__exit__", &PythonLoopContext::exit);
   630  
   631    py::class_<PythonLoopNestContext>(m, "LoopNestContext",
   632                                      "A context for building the body of a the "
   633                                      "innermost loop in a nest of 'for' loops")
   634        .def(py::init<const std::vector<PythonValueHandle> &,
   635                      const std::vector<PythonValueHandle> &,
   636                      const std::vector<int64_t> &>())
   637        .def("__enter__", &PythonLoopNestContext::enter)
   638        .def("__exit__", &PythonLoopNestContext::exit);
   639  
   640    m.def("constant_index", [](int64_t val) -> PythonValueHandle {
   641      return ValueHandle(index_t(val));
   642    });
   643    m.def("constant_int", [](int64_t val, int width) -> PythonValueHandle {
   644      return ValueHandle::create<ConstantIntOp>(val, width);
   645    });
   646    m.def("constant_float", [](double val, PythonType type) -> PythonValueHandle {
   647      FloatType floatType =
   648          Type::getFromOpaquePointer(type.type).cast<FloatType>();
   649      assert(floatType);
   650      auto value = APFloat(val);
   651      bool lostPrecision;
   652      value.convert(floatType.getFloatSemantics(), APFloat::rmNearestTiesToEven,
   653                    &lostPrecision);
   654      return ValueHandle::create<ConstantFloatOp>(value, floatType);
   655    });
   656    m.def("constant_function", [](PythonFunction func) -> PythonValueHandle {
   657      auto function = FuncOp::getFromOpaquePointer(func.function);
   658      auto attr = SymbolRefAttr::get(function.getName(), function.getContext());
   659      return ValueHandle::create<ConstantOp>(function.getType(), attr);
   660    });
   661    m.def("appendTo", [](const PythonBlockHandle &handle) {
   662      return PythonBlockAppender(handle);
   663    });
   664    m.def(
   665        "ret",
   666        [](const std::vector<PythonValueHandle> &args) {
   667          std::vector<ValueHandle> values(args.begin(), args.end());
   668          (intrinsics::ret(ArrayRef<ValueHandle>{values})); // vexing parse
   669          return PythonValueHandle(nullptr);
   670        },
   671        py::arg("args") = std::vector<PythonValueHandle>());
   672    m.def(
   673        "br",
   674        [](const PythonBlockHandle &dest,
   675           const std::vector<PythonValueHandle> &args) {
   676          std::vector<ValueHandle> values(args.begin(), args.end());
   677          intrinsics::br(dest, values);
   678          return PythonValueHandle(nullptr);
   679        },
   680        py::arg("dest"), py::arg("args") = std::vector<PythonValueHandle>());
   681    m.def(
   682        "cond_br",
   683        [](PythonValueHandle condition, const PythonBlockHandle &trueDest,
   684           const std::vector<PythonValueHandle> &trueArgs,
   685           const PythonBlockHandle &falseDest,
   686           const std::vector<PythonValueHandle> &falseArgs) -> PythonValueHandle {
   687          std::vector<ValueHandle> trueArguments(trueArgs.begin(),
   688                                                 trueArgs.end());
   689          std::vector<ValueHandle> falseArguments(falseArgs.begin(),
   690                                                  falseArgs.end());
   691          intrinsics::cond_br(condition, trueDest, trueArguments, falseDest,
   692                              falseArguments);
   693          return PythonValueHandle(nullptr);
   694        });
   695    m.def("select",
   696          [](PythonValueHandle condition, PythonValueHandle trueValue,
   697             PythonValueHandle falseValue) -> PythonValueHandle {
   698            return ValueHandle::create<SelectOp>(condition.value, trueValue.value,
   699                                                 falseValue.value);
   700          });
   701    m.def("op",
   702          [](const std::string &name,
   703             const std::vector<PythonValueHandle> &operands,
   704             const std::vector<PythonType> &resultTypes,
   705             const py::kwargs &attributes) -> PythonValueHandle {
   706            std::vector<ValueHandle> operandHandles(operands.begin(),
   707                                                    operands.end());
   708            std::vector<Type> types;
   709            types.reserve(resultTypes.size());
   710            for (auto t : resultTypes)
   711              types.push_back(Type::getFromOpaquePointer(t.type));
   712  
   713            std::vector<NamedAttribute> attrs;
   714            attrs.reserve(attributes.size());
   715            for (const auto &a : attributes) {
   716              std::string name = a.first.str();
   717              auto pyAttr = a.second.cast<PythonAttribute>();
   718              auto cppAttr = Attribute::getFromOpaquePointer(pyAttr.attr);
   719              auto identifier =
   720                  Identifier::get(name, ScopedContext::getContext());
   721              attrs.emplace_back(identifier, cppAttr);
   722            }
   723  
   724            return ValueHandle::create(name, operandHandles, types, attrs);
   725          });
   726  
   727    py::class_<PythonFunction>(m, "Function", "Wrapping class for mlir::FuncOp.")
   728        .def(py::init<PythonFunction>())
   729        .def("__str__", &PythonFunction::str)
   730        .def("define", &PythonFunction::define,
   731             "Adds a body to the function if it does not already have one.  "
   732             "Returns true if the body was added")
   733        .def("arg", &PythonFunction::arg,
   734             "Get the ValueHandle to the indexed argument of the function");
   735  
   736    py::class_<PythonAttribute>(m, "Attribute",
   737                                "Wrapping class for mlir::Attribute")
   738        .def(py::init<PythonAttribute>())
   739        .def("__str__", &PythonAttribute::str);
   740  
   741    py::class_<PythonType>(m, "Type", "Wrapping class for mlir::Type.")
   742        .def(py::init<PythonType>())
   743        .def("__call__", &PythonType::attachAttributeDict,
   744             "Attach the attributes to these type, making it suitable for "
   745             "constructing functions with argument attributes")
   746        .def("__str__", &PythonType::str);
   747  
   748    py::class_<PythonAttributedType>(
   749        m, "AttributedType",
   750        "A class containing a wrapped mlir::Type and a wrapped "
   751        "mlir::NamedAttributeList that are used together, e.g. in function "
   752        "argument declaration")
   753        .def(py::init<PythonAttributedType>())
   754        .def("__str__", &PythonAttributedType::str);
   755  
   756    py::class_<PythonMLIRModule>(
   757        m, "MLIRModule",
   758        "An MLIRModule is the abstraction that owns the allocations to support "
   759        "compilation of a single mlir::ModuleOp into an ExecutionEngine backed "
   760        "by "
   761        "the LLVM ORC JIT. A typical flow consists in creating an MLIRModule, "
   762        "adding functions, compiling the module to obtain an ExecutionEngine on "
   763        "which named functions may be called. For now the only means to retrieve "
   764        "the ExecutionEngine is by calling `get_engine_address`. This mode of "
   765        "execution is limited to passing the pointer to C++ where the function "
   766        "is called. Extending the API to allow calling JIT compiled functions "
   767        "directly require integration with a tensor library (e.g. numpy). This "
   768        "is left as the prerogative of libraries and frameworks for now.")
   769        .def(py::init<>())
   770        .def("boolAttr", &PythonMLIRModule::boolAttr,
   771             "Creates an mlir::BoolAttr with the given value")
   772        .def(
   773            "integerAttr", &PythonMLIRModule::integerAttr,
   774            "Creates an mlir::IntegerAttr of the given type with the given value "
   775            "in the context associated with this MLIR module.")
   776        .def("declare_function", &PythonMLIRModule::declareFunction,
   777             "Declares a new mlir::FuncOp in the current mlir::ModuleOp.  The "
   778             "function arguments can have attributes.  The function has no "
   779             "definition and can be linked to an external library.")
   780        .def("make_function", &PythonMLIRModule::makeFunction,
   781             "Defines a new mlir::FuncOp in the current mlir::ModuleOp.")
   782        .def("function_context", &PythonMLIRModule::makeFunctionContext,
   783             "Defines a new mlir::FuncOp in the mlir::ModuleOp and creates the "
   784             "function context for building the body of the function.")
   785        .def("get_function", &PythonMLIRModule::getNamedFunction,
   786             "Looks up the function with the given name in the module.")
   787        .def(
   788            "make_scalar_type",
   789            [](PythonMLIRModule &instance, const std::string &type,
   790               unsigned bitwidth) {
   791              return instance.makeScalarType(type, bitwidth);
   792            },
   793            py::arg("type"), py::arg("bitwidth") = 0,
   794            "Returns a scalar mlir::Type using the following convention:\n"
   795            "  - makeScalarType(c, \"bf16\") return an "
   796            "`mlir::FloatType::getBF16`\n"
   797            "  - makeScalarType(c, \"f16\") return an `mlir::FloatType::getF16`\n"
   798            "  - makeScalarType(c, \"f32\") return an `mlir::FloatType::getF32`\n"
   799            "  - makeScalarType(c, \"f64\") return an `mlir::FloatType::getF64`\n"
   800            "  - makeScalarType(c, \"index\") return an `mlir::IndexType::get`\n"
   801            "  - makeScalarType(c, \"i\", bitwidth) return an "
   802            "`mlir::IntegerType::get(bitwidth)`\n\n"
   803            " No other combinations are currently supported.")
   804        .def("make_memref_type", &PythonMLIRModule::makeMemRefType,
   805             "Returns an mlir::MemRefType of an elemental scalar. -1 is used to "
   806             "denote symbolic dimensions in the resulting memref shape.")
   807        .def("make_index_type", &PythonMLIRModule::makeIndexType,
   808             "Returns an mlir::IndexType")
   809        .def("compile", &PythonMLIRModule::compile,
   810             "Compiles the mlir::ModuleOp to LLVMIR a creates new opaque "
   811             "ExecutionEngine backed by the ORC JIT.")
   812        .def("get_ir", &PythonMLIRModule::getIR,
   813             "Returns a dump of the MLIR representation of the module. This is "
   814             "used for serde to support out-of-process execution as well as "
   815             "debugging purposes.")
   816        .def("get_engine_address", &PythonMLIRModule::getEngineAddress,
   817             "Returns the address of the compiled ExecutionEngine. This is used "
   818             "for in-process execution.")
   819        .def("__str__", &PythonMLIRModule::getIR,
   820             "Get the string representation of the module");
   821  
   822    py::class_<PythonFunctionContext>(
   823        m, "FunctionContext", "A wrapper around mlir::edsc::ScopedContext")
   824        .def(py::init<PythonFunction>())
   825        .def("__enter__", &PythonFunctionContext::enter)
   826        .def("__exit__", &PythonFunctionContext::exit);
   827  
   828    {
   829      using namespace mlir::edsc::op;
   830      py::class_<PythonValueHandle>(m, "ValueHandle",
   831                                    "A wrapper around mlir::edsc::ValueHandle")
   832          .def(py::init<PythonType>())
   833          .def(py::init<PythonValueHandle>())
   834          .def("__add__",
   835               [](PythonValueHandle lhs, PythonValueHandle rhs)
   836                   -> PythonValueHandle { return lhs.value + rhs.value; })
   837          .def("__sub__",
   838               [](PythonValueHandle lhs, PythonValueHandle rhs)
   839                   -> PythonValueHandle { return lhs.value - rhs.value; })
   840          .def("__mul__",
   841               [](PythonValueHandle lhs, PythonValueHandle rhs)
   842                   -> PythonValueHandle { return lhs.value * rhs.value; })
   843          .def("__div__",
   844               [](PythonValueHandle lhs, PythonValueHandle rhs)
   845                   -> PythonValueHandle { return lhs.value / rhs.value; })
   846          .def("__truediv__",
   847               [](PythonValueHandle lhs, PythonValueHandle rhs)
   848                   -> PythonValueHandle { return lhs.value / rhs.value; })
   849          .def("__floordiv__",
   850               [](PythonValueHandle lhs, PythonValueHandle rhs)
   851                   -> PythonValueHandle { return floorDiv(lhs, rhs); })
   852          .def("__mod__",
   853               [](PythonValueHandle lhs, PythonValueHandle rhs)
   854                   -> PythonValueHandle { return lhs.value % rhs.value; })
   855          .def("__lt__",
   856               [](PythonValueHandle lhs,
   857                  PythonValueHandle rhs) -> PythonValueHandle {
   858                 return ValueHandle::create<CmpIOp>(CmpIPredicate::SLT, lhs.value,
   859                                                    rhs.value);
   860               })
   861          .def("__le__",
   862               [](PythonValueHandle lhs,
   863                  PythonValueHandle rhs) -> PythonValueHandle {
   864                 return ValueHandle::create<CmpIOp>(CmpIPredicate::SLE, lhs.value,
   865                                                    rhs.value);
   866               })
   867          .def("__gt__",
   868               [](PythonValueHandle lhs,
   869                  PythonValueHandle rhs) -> PythonValueHandle {
   870                 return ValueHandle::create<CmpIOp>(CmpIPredicate::SGT, lhs.value,
   871                                                    rhs.value);
   872               })
   873          .def("__ge__",
   874               [](PythonValueHandle lhs,
   875                  PythonValueHandle rhs) -> PythonValueHandle {
   876                 return ValueHandle::create<CmpIOp>(CmpIPredicate::SGE, lhs.value,
   877                                                    rhs.value);
   878               })
   879          .def("__eq__",
   880               [](PythonValueHandle lhs,
   881                  PythonValueHandle rhs) -> PythonValueHandle {
   882                 return ValueHandle::create<CmpIOp>(CmpIPredicate::EQ, lhs.value,
   883                                                    rhs.value);
   884               })
   885          .def("__ne__",
   886               [](PythonValueHandle lhs,
   887                  PythonValueHandle rhs) -> PythonValueHandle {
   888                 return ValueHandle::create<CmpIOp>(CmpIPredicate::NE, lhs.value,
   889                                                    rhs.value);
   890               })
   891          .def("__invert__",
   892               [](PythonValueHandle handle) -> PythonValueHandle {
   893                 return !handle.value;
   894               })
   895          .def("__and__",
   896               [](PythonValueHandle lhs, PythonValueHandle rhs)
   897                   -> PythonValueHandle { return lhs.value && rhs.value; })
   898          .def("__or__",
   899               [](PythonValueHandle lhs, PythonValueHandle rhs)
   900                   -> PythonValueHandle { return lhs.value || rhs.value; })
   901          .def("__call__", &PythonValueHandle::call);
   902    }
   903  
   904    py::class_<PythonBlockAppender>(
   905        m, "BlockAppender",
   906        "A dummy class signaling BlockContext to append IR to the given block "
   907        "instead of creating a new block")
   908        .def(py::init<const PythonBlockHandle &>());
   909    py::class_<PythonBlockHandle>(m, "BlockHandle",
   910                                  "A wrapper around mlir::edsc::BlockHandle")
   911        .def(py::init<PythonBlockHandle>())
   912        .def("arg", &PythonBlockHandle::arg);
   913  
   914    py::class_<PythonBlockContext>(m, "BlockContext",
   915                                   "A wrapper around mlir::edsc::BlockBuilder")
   916        .def(py::init<>())
   917        .def(py::init<const std::vector<PythonType> &>())
   918        .def(py::init<const PythonBlockAppender &>())
   919        .def("__enter__", &PythonBlockContext::enter)
   920        .def("__exit__", &PythonBlockContext::exit)
   921        .def("handle", &PythonBlockContext::getHandle);
   922  
   923    py::class_<PythonIndexedValue>(m, "IndexedValue",
   924                                   "A wrapper around mlir::edsc::IndexedValue")
   925        .def(py::init<PythonValueHandle>())
   926        .def("load", &PythonIndexedValue::load)
   927        .def("store", &PythonIndexedValue::store);
   928  }
   929  
   930  } // namespace python
   931  } // namespace edsc
   932  } // namespace mlir