Skip to content

Commit bfe55ed

Browse files
committed
Fix race condition with py::make_key_iterator in free threading
The creation of the iterator class needs to be synchronized.
1 parent 6c83607 commit bfe55ed

File tree

2 files changed

+29
-2
lines changed

2 files changed

+29
-2
lines changed

include/pybind11/detail/internals.h

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,32 @@ class pymutex {
238238
void unlock() { PyMutex_Unlock(&mutex); }
239239
};
240240

241+
// A recursive mutex implementation using PyMutex
242+
class pyrecursive_mutex {
243+
PyMutex mutex;
244+
std::atomic<uintptr_t> owner;
245+
size_t lock_count;
246+
247+
public:
248+
pyrecursive_mutex() : mutex({}), owner(0), lock_count(0) {}
249+
void lock() {
250+
if (owner.load(std::memory_order_relaxed) == _Py_ThreadId()) {
251+
++lock_count;
252+
return;
253+
}
254+
PyMutex_Lock(&mutex);
255+
owner.store(_Py_ThreadId(), std::memory_order_relaxed);
256+
}
257+
void unlock() {
258+
if (lock_count > 0) {
259+
--lock_count;
260+
return;
261+
}
262+
owner.store(0, std::memory_order_relaxed);
263+
PyMutex_Unlock(&mutex);
264+
}
265+
};
266+
241267
// Instance map shards are used to reduce mutex contention in free-threaded Python.
242268
struct instance_map_shard {
243269
instance_map registered_instances;
@@ -271,7 +297,7 @@ class loader_life_support;
271297
/// `PYBIND11_INTERNALS_VERSION` must be incremented.
272298
struct internals {
273299
#ifdef Py_GIL_DISABLED
274-
pymutex mutex;
300+
pyrecursive_mutex mutex;
275301
pymutex exception_translator_mutex;
276302
#endif
277303
#if PYBIND11_INTERNALS_VERSION >= 12
@@ -856,7 +882,7 @@ inline local_internals &get_local_internals() {
856882
}
857883

858884
#ifdef Py_GIL_DISABLED
859-
# define PYBIND11_LOCK_INTERNALS(internals) std::unique_lock<pymutex> lock((internals).mutex)
885+
# define PYBIND11_LOCK_INTERNALS(internals) std::unique_lock<pyrecursive_mutex> lock((internals).mutex)
860886
#else
861887
# define PYBIND11_LOCK_INTERNALS(internals)
862888
#endif

include/pybind11/pybind11.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3173,6 +3173,7 @@ iterator make_iterator_impl(Iterator first, Sentinel last, Extra &&...extra) {
31733173
using state = detail::iterator_state<Access, Policy, Iterator, Sentinel, ValueType, Extra...>;
31743174
// TODO: state captures only the types of Extra, not the values
31753175

3176+
PYBIND11_LOCK_INTERNALS(get_internals());
31763177
if (!detail::get_type_info(typeid(state), false)) {
31773178
class_<state>(handle(), "iterator", pybind11::module_local())
31783179
.def(

0 commit comments

Comments
 (0)