github.com/kaydxh/golang@v0.0.131/pkg/gocv/cgo/third_path/pybind11/tests/test_embed/test_interpreter.cpp (about)

     1  #include <pybind11/embed.h>
     2  
     3  // Silence MSVC C++17 deprecation warning from Catch regarding std::uncaught_exceptions (up to
     4  // catch 2.0.1; this should be fixed in the next catch release after 2.0.1).
     5  PYBIND11_WARNING_DISABLE_MSVC(4996)
     6  
     7  #include <catch.hpp>
     8  #include <cstdlib>
     9  #include <fstream>
    10  #include <functional>
    11  #include <thread>
    12  #include <utility>
    13  
    14  namespace py = pybind11;
    15  using namespace py::literals;
    16  
    17  size_t get_sys_path_size() {
    18      auto sys_path = py::module::import("sys").attr("path");
    19      return py::len(sys_path);
    20  }
    21  
    22  class Widget {
    23  public:
    24      explicit Widget(std::string message) : message(std::move(message)) {}
    25      virtual ~Widget() = default;
    26  
    27      std::string the_message() const { return message; }
    28      virtual int the_answer() const = 0;
    29      virtual std::string argv0() const = 0;
    30  
    31  private:
    32      std::string message;
    33  };
    34  
    35  class PyWidget final : public Widget {
    36      using Widget::Widget;
    37  
    38      int the_answer() const override { PYBIND11_OVERRIDE_PURE(int, Widget, the_answer); }
    39      std::string argv0() const override { PYBIND11_OVERRIDE_PURE(std::string, Widget, argv0); }
    40  };
    41  
    42  class test_override_cache_helper {
    43  
    44  public:
    45      virtual int func() { return 0; }
    46  
    47      test_override_cache_helper() = default;
    48      virtual ~test_override_cache_helper() = default;
    49      // Non-copyable
    50      test_override_cache_helper &operator=(test_override_cache_helper const &Right) = delete;
    51      test_override_cache_helper(test_override_cache_helper const &Copy) = delete;
    52  };
    53  
    54  class test_override_cache_helper_trampoline : public test_override_cache_helper {
    55      int func() override { PYBIND11_OVERRIDE(int, test_override_cache_helper, func); }
    56  };
    57  
    58  PYBIND11_EMBEDDED_MODULE(widget_module, m) {
    59      py::class_<Widget, PyWidget>(m, "Widget")
    60          .def(py::init<std::string>())
    61          .def_property_readonly("the_message", &Widget::the_message);
    62  
    63      m.def("add", [](int i, int j) { return i + j; });
    64  }
    65  
    66  PYBIND11_EMBEDDED_MODULE(trampoline_module, m) {
    67      py::class_<test_override_cache_helper,
    68                 test_override_cache_helper_trampoline,
    69                 std::shared_ptr<test_override_cache_helper>>(m, "test_override_cache_helper")
    70          .def(py::init_alias<>())
    71          .def("func", &test_override_cache_helper::func);
    72  }
    73  
    74  PYBIND11_EMBEDDED_MODULE(throw_exception, ) { throw std::runtime_error("C++ Error"); }
    75  
    76  PYBIND11_EMBEDDED_MODULE(throw_error_already_set, ) {
    77      auto d = py::dict();
    78      d["missing"].cast<py::object>();
    79  }
    80  
    81  TEST_CASE("PYTHONPATH is used to update sys.path") {
    82      // The setup for this TEST_CASE is in catch.cpp!
    83      auto sys_path = py::str(py::module_::import("sys").attr("path")).cast<std::string>();
    84      REQUIRE_THAT(sys_path,
    85                   Catch::Matchers::Contains("pybind11_test_embed_PYTHONPATH_2099743835476552"));
    86  }
    87  
    88  TEST_CASE("Pass classes and data between modules defined in C++ and Python") {
    89      auto module_ = py::module_::import("test_interpreter");
    90      REQUIRE(py::hasattr(module_, "DerivedWidget"));
    91  
    92      auto locals = py::dict("hello"_a = "Hello, World!", "x"_a = 5, **module_.attr("__dict__"));
    93      py::exec(R"(
    94          widget = DerivedWidget("{} - {}".format(hello, x))
    95          message = widget.the_message
    96      )",
    97               py::globals(),
    98               locals);
    99      REQUIRE(locals["message"].cast<std::string>() == "Hello, World! - 5");
   100  
   101      auto py_widget = module_.attr("DerivedWidget")("The question");
   102      auto message = py_widget.attr("the_message");
   103      REQUIRE(message.cast<std::string>() == "The question");
   104  
   105      const auto &cpp_widget = py_widget.cast<const Widget &>();
   106      REQUIRE(cpp_widget.the_answer() == 42);
   107  }
   108  
   109  TEST_CASE("Override cache") {
   110      auto module_ = py::module_::import("test_trampoline");
   111      REQUIRE(py::hasattr(module_, "func"));
   112      REQUIRE(py::hasattr(module_, "func2"));
   113  
   114      auto locals = py::dict(**module_.attr("__dict__"));
   115  
   116      int i = 0;
   117      for (; i < 1500; ++i) {
   118          std::shared_ptr<test_override_cache_helper> p_obj;
   119          std::shared_ptr<test_override_cache_helper> p_obj2;
   120  
   121          py::object loc_inst = locals["func"]();
   122          p_obj = py::cast<std::shared_ptr<test_override_cache_helper>>(loc_inst);
   123  
   124          int ret = p_obj->func();
   125  
   126          REQUIRE(ret == 42);
   127  
   128          loc_inst = locals["func2"]();
   129  
   130          p_obj2 = py::cast<std::shared_ptr<test_override_cache_helper>>(loc_inst);
   131  
   132          p_obj2->func();
   133      }
   134  }
   135  
   136  TEST_CASE("Import error handling") {
   137      REQUIRE_NOTHROW(py::module_::import("widget_module"));
   138      REQUIRE_THROWS_WITH(py::module_::import("throw_exception"), "ImportError: C++ Error");
   139      REQUIRE_THROWS_WITH(py::module_::import("throw_error_already_set"),
   140                          Catch::Contains("ImportError: initialization failed"));
   141  
   142      auto locals = py::dict("is_keyerror"_a = false, "message"_a = "not set");
   143      py::exec(R"(
   144          try:
   145              import throw_error_already_set
   146          except ImportError as e:
   147              is_keyerror = type(e.__cause__) == KeyError
   148              message = str(e.__cause__)
   149      )",
   150               py::globals(),
   151               locals);
   152      REQUIRE(locals["is_keyerror"].cast<bool>() == true);
   153      REQUIRE(locals["message"].cast<std::string>() == "'missing'");
   154  }
   155  
   156  TEST_CASE("There can be only one interpreter") {
   157      static_assert(std::is_move_constructible<py::scoped_interpreter>::value, "");
   158      static_assert(!std::is_move_assignable<py::scoped_interpreter>::value, "");
   159      static_assert(!std::is_copy_constructible<py::scoped_interpreter>::value, "");
   160      static_assert(!std::is_copy_assignable<py::scoped_interpreter>::value, "");
   161  
   162      REQUIRE_THROWS_WITH(py::initialize_interpreter(), "The interpreter is already running");
   163      REQUIRE_THROWS_WITH(py::scoped_interpreter(), "The interpreter is already running");
   164  
   165      py::finalize_interpreter();
   166      REQUIRE_NOTHROW(py::scoped_interpreter());
   167      {
   168          auto pyi1 = py::scoped_interpreter();
   169          auto pyi2 = std::move(pyi1);
   170      }
   171      py::initialize_interpreter();
   172  }
   173  
   174  #if PY_VERSION_HEX >= PYBIND11_PYCONFIG_SUPPORT_PY_VERSION_HEX
   175  TEST_CASE("Custom PyConfig") {
   176      py::finalize_interpreter();
   177      PyConfig config;
   178      PyConfig_InitPythonConfig(&config);
   179      REQUIRE_NOTHROW(py::scoped_interpreter{&config});
   180      {
   181          py::scoped_interpreter p{&config};
   182          REQUIRE(py::module_::import("widget_module").attr("add")(1, 41).cast<int>() == 42);
   183      }
   184      py::initialize_interpreter();
   185  }
   186  
   187  TEST_CASE("Custom PyConfig with argv") {
   188      py::finalize_interpreter();
   189      {
   190          PyConfig config;
   191          PyConfig_InitIsolatedConfig(&config);
   192          char *argv[] = {strdup("a.out")};
   193          py::scoped_interpreter argv_scope{&config, 1, argv};
   194          std::free(argv[0]);
   195          auto module = py::module::import("test_interpreter");
   196          auto py_widget = module.attr("DerivedWidget")("The question");
   197          const auto &cpp_widget = py_widget.cast<const Widget &>();
   198          REQUIRE(cpp_widget.argv0() == "a.out");
   199      }
   200      py::initialize_interpreter();
   201  }
   202  #endif
   203  
   204  TEST_CASE("Add program dir to path pre-PyConfig") {
   205      py::finalize_interpreter();
   206      size_t path_size_add_program_dir_to_path_false = 0;
   207      {
   208          py::scoped_interpreter scoped_interp{true, 0, nullptr, false};
   209          path_size_add_program_dir_to_path_false = get_sys_path_size();
   210      }
   211      {
   212          py::scoped_interpreter scoped_interp{};
   213          REQUIRE(get_sys_path_size() == path_size_add_program_dir_to_path_false + 1);
   214      }
   215      py::initialize_interpreter();
   216  }
   217  
   218  #if PY_VERSION_HEX >= PYBIND11_PYCONFIG_SUPPORT_PY_VERSION_HEX
   219  TEST_CASE("Add program dir to path using PyConfig") {
   220      py::finalize_interpreter();
   221      size_t path_size_add_program_dir_to_path_false = 0;
   222      {
   223          PyConfig config;
   224          PyConfig_InitPythonConfig(&config);
   225          py::scoped_interpreter scoped_interp{&config, 0, nullptr, false};
   226          path_size_add_program_dir_to_path_false = get_sys_path_size();
   227      }
   228      {
   229          PyConfig config;
   230          PyConfig_InitPythonConfig(&config);
   231          py::scoped_interpreter scoped_interp{&config};
   232          REQUIRE(get_sys_path_size() == path_size_add_program_dir_to_path_false + 1);
   233      }
   234      py::initialize_interpreter();
   235  }
   236  #endif
   237  
   238  bool has_pybind11_internals_builtin() {
   239      auto builtins = py::handle(PyEval_GetBuiltins());
   240      return builtins.contains(PYBIND11_INTERNALS_ID);
   241  };
   242  
   243  bool has_pybind11_internals_static() {
   244      auto **&ipp = py::detail::get_internals_pp();
   245      return (ipp != nullptr) && (*ipp != nullptr);
   246  }
   247  
// Checks that finalizing and re-initializing the interpreter clears pybind11's internals, that
// they can be recreated on demand, and that C++ and Python modules work again after a restart.
TEST_CASE("Restart the interpreter") {
    // Verify pre-restart state.
    REQUIRE(py::module_::import("widget_module").attr("add")(1, 2).cast<int>() == 3);
    REQUIRE(has_pybind11_internals_builtin());
    REQUIRE(has_pybind11_internals_static());
    REQUIRE(py::module_::import("external_module").attr("A")(123).attr("value").cast<int>()
            == 123);

    // local and foreign module internals should point to the same internals:
    REQUIRE(reinterpret_cast<uintptr_t>(*py::detail::get_internals_pp())
            == py::module_::import("external_module").attr("internals_at")().cast<uintptr_t>());

    // Restart the interpreter.
    py::finalize_interpreter();
    REQUIRE(Py_IsInitialized() == 0);

    py::initialize_interpreter();
    REQUIRE(Py_IsInitialized() == 1);

    // Internals are deleted after a restart.
    REQUIRE_FALSE(has_pybind11_internals_builtin());
    REQUIRE_FALSE(has_pybind11_internals_static());
    // An explicit get_internals() call recreates both the builtins entry and the static pointer.
    pybind11::detail::get_internals();
    REQUIRE(has_pybind11_internals_builtin());
    REQUIRE(has_pybind11_internals_static());
    // After recreation, external_module still reports the same internals address.
    REQUIRE(reinterpret_cast<uintptr_t>(*py::detail::get_internals_pp())
            == py::module_::import("external_module").attr("internals_at")().cast<uintptr_t>());

    // Make sure that internals which are first created during finalization (no get_internals()
    // call before finalize) still get destroyed.
    py::finalize_interpreter();
    py::initialize_interpreter();
    bool ran = false;
    // The capsule is stored on __main__; its destructor calls get_internals() and sets `ran`.
    // The REQUIRE_FALSE(ran)/REQUIRE(ran) pair below shows it runs during finalize_interpreter().
    py::module_::import("__main__").attr("internals_destroy_test")
        = py::capsule(&ran, [](void *ran) {
              py::detail::get_internals();
              *static_cast<bool *>(ran) = true;
          });
    REQUIRE_FALSE(has_pybind11_internals_builtin());
    REQUIRE_FALSE(has_pybind11_internals_static());
    REQUIRE_FALSE(ran);
    py::finalize_interpreter();
    REQUIRE(ran);
    py::initialize_interpreter();
    // Internals created inside the capsule destructor must not carry over to the new interpreter.
    REQUIRE_FALSE(has_pybind11_internals_builtin());
    REQUIRE_FALSE(has_pybind11_internals_static());

    // C++ modules can be reloaded.
    auto cpp_module = py::module_::import("widget_module");
    REQUIRE(cpp_module.attr("add")(1, 2).cast<int>() == 3);

    // C++ type information is reloaded and can be used in python modules.
    auto py_module = py::module_::import("test_interpreter");
    auto py_widget = py_module.attr("DerivedWidget")("Hello after restart");
    REQUIRE(py_widget.attr("the_message").cast<std::string>() == "Hello after restart");
}
   304  
   305  TEST_CASE("Subinterpreter") {
   306      // Add tags to the modules in the main interpreter and test the basics.
   307      py::module_::import("__main__").attr("main_tag") = "main interpreter";
   308      {
   309          auto m = py::module_::import("widget_module");
   310          m.attr("extension_module_tag") = "added to module in main interpreter";
   311  
   312          REQUIRE(m.attr("add")(1, 2).cast<int>() == 3);
   313      }
   314      REQUIRE(has_pybind11_internals_builtin());
   315      REQUIRE(has_pybind11_internals_static());
   316  
   317      /// Create and switch to a subinterpreter.
   318      auto *main_tstate = PyThreadState_Get();
   319      auto *sub_tstate = Py_NewInterpreter();
   320  
   321      // Subinterpreters get their own copy of builtins. detail::get_internals() still
   322      // works by returning from the static variable, i.e. all interpreters share a single
   323      // global pybind11::internals;
   324      REQUIRE_FALSE(has_pybind11_internals_builtin());
   325      REQUIRE(has_pybind11_internals_static());
   326  
   327      // Modules tags should be gone.
   328      REQUIRE_FALSE(py::hasattr(py::module_::import("__main__"), "tag"));
   329      {
   330          auto m = py::module_::import("widget_module");
   331          REQUIRE_FALSE(py::hasattr(m, "extension_module_tag"));
   332  
   333          // Function bindings should still work.
   334          REQUIRE(m.attr("add")(1, 2).cast<int>() == 3);
   335      }
   336  
   337      // Restore main interpreter.
   338      Py_EndInterpreter(sub_tstate);
   339      PyThreadState_Swap(main_tstate);
   340  
   341      REQUIRE(py::hasattr(py::module_::import("__main__"), "main_tag"));
   342      REQUIRE(py::hasattr(py::module_::import("widget_module"), "extension_module_tag"));
   343  }
   344  
   345  TEST_CASE("Execution frame") {
   346      // When the interpreter is embedded, there is no execution frame, but `py::exec`
   347      // should still function by using reasonable globals: `__main__.__dict__`.
   348      py::exec("var = dict(number=42)");
   349      REQUIRE(py::globals()["var"]["number"].cast<int>() == 42);
   350  }
   351  
   352  TEST_CASE("Threads") {
   353      // Restart interpreter to ensure threads are not initialized
   354      py::finalize_interpreter();
   355      py::initialize_interpreter();
   356      REQUIRE_FALSE(has_pybind11_internals_static());
   357  
   358      constexpr auto num_threads = 10;
   359      auto locals = py::dict("count"_a = 0);
   360  
   361      {
   362          py::gil_scoped_release gil_release{};
   363  
   364          auto threads = std::vector<std::thread>();
   365          for (auto i = 0; i < num_threads; ++i) {
   366              threads.emplace_back([&]() {
   367                  py::gil_scoped_acquire gil{};
   368                  locals["count"] = locals["count"].cast<int>() + 1;
   369              });
   370          }
   371  
   372          for (auto &thread : threads) {
   373              thread.join();
   374          }
   375      }
   376  
   377      REQUIRE(locals["count"].cast<int>() == num_threads);
   378  }
   379  
   380  // Scope exit utility https://stackoverflow.com/a/36644501/7255855
   381  struct scope_exit {
   382      std::function<void()> f_;
   383      explicit scope_exit(std::function<void()> f) noexcept : f_(std::move(f)) {}
   384      ~scope_exit() {
   385          if (f_) {
   386              f_();
   387          }
   388      }
   389  };
   390  
   391  TEST_CASE("Reload module from file") {
   392      // Disable generation of cached bytecode (.pyc files) for this test, otherwise
   393      // Python might pick up an old version from the cache instead of the new versions
   394      // of the .py files generated below
   395      auto sys = py::module_::import("sys");
   396      bool dont_write_bytecode = sys.attr("dont_write_bytecode").cast<bool>();
   397      sys.attr("dont_write_bytecode") = true;
   398      // Reset the value at scope exit
   399      scope_exit reset_dont_write_bytecode(
   400          [&]() { sys.attr("dont_write_bytecode") = dont_write_bytecode; });
   401  
   402      std::string module_name = "test_module_reload";
   403      std::string module_file = module_name + ".py";
   404  
   405      // Create the module .py file
   406      std::ofstream test_module(module_file);
   407      test_module << "def test():\n";
   408      test_module << "    return 1\n";
   409      test_module.close();
   410      // Delete the file at scope exit
   411      scope_exit delete_module_file([&]() { std::remove(module_file.c_str()); });
   412  
   413      // Import the module from file
   414      auto module_ = py::module_::import(module_name.c_str());
   415      int result = module_.attr("test")().cast<int>();
   416      REQUIRE(result == 1);
   417  
   418      // Update the module .py file with a small change
   419      test_module.open(module_file);
   420      test_module << "def test():\n";
   421      test_module << "    return 2\n";
   422      test_module.close();
   423  
   424      // Reload the module
   425      module_.reload();
   426      result = module_.attr("test")().cast<int>();
   427      REQUIRE(result == 2);
   428  }
   429  
   430  TEST_CASE("sys.argv gets initialized properly") {
   431      py::finalize_interpreter();
   432      {
   433          py::scoped_interpreter default_scope;
   434          auto module = py::module::import("test_interpreter");
   435          auto py_widget = module.attr("DerivedWidget")("The question");
   436          const auto &cpp_widget = py_widget.cast<const Widget &>();
   437          REQUIRE(cpp_widget.argv0().empty());
   438      }
   439  
   440      {
   441          char *argv[] = {strdup("a.out")};
   442          py::scoped_interpreter argv_scope(true, 1, argv);
   443          std::free(argv[0]);
   444          auto module = py::module::import("test_interpreter");
   445          auto py_widget = module.attr("DerivedWidget")("The question");
   446          const auto &cpp_widget = py_widget.cast<const Widget &>();
   447          REQUIRE(cpp_widget.argv0() == "a.out");
   448      }
   449      py::initialize_interpreter();
   450  }
   451  
   452  TEST_CASE("make_iterator can be called before then after finalizing an interpreter") {
   453      // Reproduction of issue #2101 (https://github.com/pybind/pybind11/issues/2101)
   454      py::finalize_interpreter();
   455  
   456      std::vector<int> container;
   457      {
   458          pybind11::scoped_interpreter g;
   459          auto iter = pybind11::make_iterator(container.begin(), container.end());
   460      }
   461  
   462      REQUIRE_NOTHROW([&]() {
   463          pybind11::scoped_interpreter g;
   464          auto iter = pybind11::make_iterator(container.begin(), container.end());
   465      }());
   466  
   467      py::initialize_interpreter();
   468  }