"""
Thread-Local Context Variable Utilities.
This module provides a thread-safe context variable management system that allows
developers to build context-dependent behavior without explicitly passing values
through a call stack. It supports context inheritance, scoped variable overrides,
function wrapping for new threads, nested context management, and conditional
context creation.
The module contains the following main public components:
* :class:`ContextVars` - Thread-local context mapping with scoped mutation helpers
* :func:`context` - Accessor for the current thread's :class:`ContextVars`
* :func:`cwrap` - Function wrapper to inherit and extend context in new threads
* :func:`nested_with` - Helper to nest multiple context managers
* :func:`conditional_with` - Conditionally enter a context manager
.. note::
Context instances are stored per thread and reused within the same thread.
Use :func:`cwrap` to pass the current context into new threads.
Example::
>>> from contextlib import contextmanager
>>> from hbutils.reflection import context
>>>
>>> @contextmanager
... def use_mul():
... with context().vars(mul=True):
... yield
>>>
>>> def calc(a, b):
... if context().get('mul', None):
... return a * b
... else:
... return a + b
>>>
>>> print(calc(3, 5))
8
>>> with use_mul():
... print(calc(3, 5))
15
>>> print(calc(3, 5))
8
"""
import collections.abc
from contextlib import ExitStack, contextmanager
from functools import wraps
from multiprocessing import current_process
from threading import current_thread
from typing import Tuple, TypeVar, Iterator, Mapping, Optional, ContextManager, Any, Dict, Callable, List
# Names exported by ``from <this module> import *``.
__all__ = ['context', 'cwrap', 'nested_with', 'conditional_with']
def _get_pid() -> int:
"""
Get the current process ID.
:return: The process ID of the current process.
:rtype: int
"""
return current_process().pid
def _get_tid() -> int:
"""
Get the current thread ID.
:return: The thread identifier of the current thread.
:rtype: int
"""
return current_thread().ident
def _get_context_id() -> Tuple[int, int]:
    """
    Build the key under which the current thread's context is registered.

    :return: A ``(process_id, thread_id)`` pair for the calling thread.
    :rtype: Tuple[int, int]
    """
    pid = _get_pid()
    tid = _get_tid()
    return pid, tid
_global_contexts: Dict[Tuple[int, int], 'ContextVars'] = {}
_KeyType = TypeVar('_KeyType', bound=str)
_ValueType = TypeVar('_ValueType')
[docs]
class ContextVars(collections.abc.Mapping):
"""
Context variable management class.
This class provides a thread-safe way to manage context variables that can be
temporarily modified within a with-block scope. It inherits from
:class:`collections.abc.Mapping` and supports standard mapping operations.
.. note::
This class is inherited from :class:`collections.abc.Mapping`.
Main features of mapping object (such as ``__getitem__``, ``__len__``,
``__iter__``) are supported. See `Collections Abstract Base Classes
<https://docs.python.org/3/library/collections.abc.html#collections-abstract-base-classes>`_.
.. warning::
This object should be singleton on thread level.
It is not recommended constructing manually. Use :func:`context` instead.
"""
[docs]
def __init__(self, **kwargs: _ValueType) -> None:
"""
Initialize a ContextVars instance.
:param kwargs: Initial context variable key-value pairs.
:type kwargs: _ValueType
"""
self._vars: Dict[_KeyType, _ValueType] = dict(kwargs)
@contextmanager
def _with_vars(self, params: Mapping[_KeyType, _ValueType], clear: bool = False) -> Iterator[None]:
"""
Internal context manager for temporarily modifying context variables.
:param params: Dictionary of variables to set in the context.
:type params: Mapping[_KeyType, _ValueType]
:param clear: If True, remove all variables not present in params. Default is False.
:type clear: bool
:yield: None
"""
# Initialize new values.
_origin = dict(self._vars)
self._vars.update(params)
if clear:
for key in list(self._vars.keys()):
if key not in params:
del self._vars[key]
try:
yield
finally:
# De-initialize and recover changed values.
for k in set(_origin.keys()) | set(self._vars.keys()):
if k not in _origin:
del self._vars[k]
else:
self._vars[k] = _origin[k]
[docs]
@contextmanager
def vars(self, **kwargs: _ValueType) -> Iterator[None]:
"""
Add or modify variables in the context within a with-block.
This method temporarily adds or updates context variables for the duration
of the with-block. Original values are restored when exiting the block.
:param kwargs: Context variables to add or modify.
:type kwargs: _ValueType
:yield: None
Examples::
>>> from hbutils.reflection import context
>>>
>>> def var_detect():
... if context().get('var', None):
... print(f'Var detected, its value is {context()["var"]}.')
... else:
... print('Var not detected.')
>>>
>>> var_detect()
Var not detected.
>>> with context().vars(var=1):
... var_detect()
Var detected, its value is 1.
>>> var_detect()
Var not detected.
.. note::
See :func:`context`.
"""
with self._with_vars(kwargs, clear=False):
yield
[docs]
@contextmanager
def inherit(self, context_: 'ContextVars') -> Iterator[None]:
"""
Inherit variables from another context.
This method replaces the current context variables with those from the given
context. Variables not present in the given context will be removed.
:param context_: ContextVars object to inherit from.
:type context_: ContextVars
:yield: None
.. note::
After :meth:`inherit` is used, **the original variables which not present in the given
``context_`` will be removed**. This is different from :meth:`vars`, so attention.
"""
with self._with_vars(context_._vars, clear=True):
yield
[docs]
def __getitem__(self, key: _KeyType) -> _ValueType:
"""
Get a context variable by key.
:param key: The key of the variable to retrieve.
:type key: _KeyType
:return: The value associated with the key.
:rtype: _ValueType
:raises KeyError: If the key is not found in the context.
"""
return self._vars[key]
[docs]
def __len__(self) -> int:
"""
Get the number of variables in the context.
:return: The number of context variables.
:rtype: int
"""
return len(self._vars)
[docs]
def __iter__(self) -> Iterator[_KeyType]:
"""
Iterate over the keys of context variables.
:return: An iterator over the context variable keys.
:rtype: Iterator[_KeyType]
"""
return self._vars.__iter__()
def context() -> ContextVars:
    """
    Get the context object for the current thread.

    Looks up the :class:`ContextVars` registered for the current
    ``(process_id, thread_id)`` pair and, on first use, creates and registers an
    empty one. Repeated calls from the same thread therefore return the same object.

    :return: The :class:`ContextVars` object for the current thread.
    :rtype: ContextVars

    .. note::
        This result is unique within one thread. Registered instances are not
        removed when a thread finishes. Because thread identifiers may be recycled,
        a later thread with the same identifier gets the same instance.
    """
    key = _get_context_id()
    try:
        return _global_contexts[key]
    except KeyError:
        fresh = ContextVars()
        _global_contexts[key] = fresh
        return fresh
def cwrap(
        func: Callable[..., _ValueType],
        *,
        context_: Optional[ContextVars] = None,
        **vars_: _ValueType
) -> Callable[..., _ValueType]:
    """
    Wrap a function to inherit and extend context variables.

    Each thread has its own :class:`ContextVars` (see :func:`context`), so variables
    set in one thread are not visible in a new thread. On every call, the returned
    wrapper first replaces the executing thread's variables with those of
    ``context_``, then applies ``vars_`` on top, and runs ``func``. The executing
    thread's previous variables are restored afterwards.

    :param func: The function to wrap.
    :type func: Callable[..., _ValueType]
    :param context_: Context to inherit. If ``None``, the context of the thread that
        calls :func:`cwrap` is used. An explicitly passed empty context is honoured.
    :type context_: Optional[ContextVars]
    :param vars_: Additional variables to add after inheriting the context.
    :type vars_: _ValueType
    :return: A wrapped function that executes with the inherited context.
    :rtype: Callable[..., _ValueType]

    .. note::
        The variables of ``context_`` are read each time the wrapped function is
        called, not when :func:`cwrap` is called. Start the thread while the variables
        you want to pass are still set, e.g. inside the same ``with`` block.

    Examples::

        >>> from threading import Thread
        >>> from hbutils.reflection import context, cwrap
        >>>
        >>> def var_detect():
        ...     if context().get('var', None):
        ...         print(f'Var detected, its value is {context()["var"]}.')
        ...     else:
        ...         print('Var not detected.')
        >>>
        >>> with context().vars(var=1):  # no inherit, vars will be lost in thread
        ...     t = Thread(target=var_detect)
        ...     t.start()
        ...     t.join()
        Var not detected.
        >>> with context().vars(var=1):  # with inherit, vars will be kept in thread
        ...     t = Thread(target=cwrap(var_detect))
        ...     t.start()
        ...     t.join()
        Var detected, its value is 1.

    .. note::
        :func:`cwrap` is important when you need to pass the current context into a thread.
        And **it is compatible on all platforms**.

    .. warning::
        :func:`cwrap` **is not compatible with Windows or Python3.8+ on macOS** when creating
        a **new process**. Please pass in direct arguments through the ``args`` argument of
        :class:`Process`. If you insist on using the :func:`context` feature, you need to pass
        the context object into the subprocess.

        For example::

            >>> from contextlib import contextmanager
            >>> from multiprocessing import Process
            >>> from hbutils.reflection import context
            >>>
            >>> @contextmanager
            ... def use_mul():
            ...     with context().vars(mul=True):
            ...         yield
            >>>
            >>> def calc(a, b):
            ...     if context().get('mul', None):
            ...         print(a * b)
            ...     else:
            ...         print(a + b)
            >>>
            >>> def _calc(a, b, ctx=None):
            ...     with context().inherit(ctx if ctx is not None else context()):
            ...         return calc(a, b)
            >>>
            >>> if __name__ == '__main__':
            ...     calc(3, 5)
            ...     with use_mul():
            ...         p = Process(target=_calc, args=(3, 5, context()))
            ...         p.start()
            ...         p.join()
            ...     calc(3, 5)
            8
            15
            8
    """
    # ``ContextVars`` is a Mapping, so an *empty* instance is falsy. Compare with
    # ``None`` (rather than using ``or``) so an explicitly supplied empty context
    # is not silently replaced by the calling thread's context.
    if context_ is None:
        context_ = context()

    @wraps(func)
    def _new_func(*args: Any, **kwargs: Any) -> _ValueType:
        # ``context()`` here is the context of the thread executing the call.
        with context().inherit(context_):
            with context().vars(**vars_):
                return func(*args, **kwargs)

    return _new_func
def _yield_nested_for(
contexts: List[ContextManager[Any]],
depth: int,
items: List[Any]
) -> Iterator[Tuple[Any, ...]]:
"""
Internal recursive generator for nested context management.
:param contexts: List of context managers to nest.
:type contexts: list
:param depth: Current recursion depth.
:type depth: int
:param items: Accumulated items from entered contexts.
:type items: list
:yield: Tuple of items from all entered contexts.
:rtype: tuple
"""
if depth >= len(contexts):
yield tuple(items)
else:
with contexts[depth] as current_item:
items.append(current_item)
yield from _yield_nested_for(contexts, depth + 1, items)
@contextmanager
def nested_with(*contexts: ContextManager[Any]) -> Iterator[Tuple[Any, ...]]:
    """
    Enter and exit multiple context managers in a nested fashion.

    The managers are entered in the given order and exited in reverse order (LIFO),
    exactly as if they were written as nested ``with`` statements. This includes
    exception propagation and suppression. If entering one of them fails, the
    ones already entered are exited.

    :param contexts: Variable number of context managers to nest.
    :type contexts: ContextManager
    :return: A context manager that yields a tuple of values from all nested contexts
        (an empty tuple when no managers are given).
    :rtype: ContextManager[Tuple[Any, ...]]

    Examples::

        >>> import os.path
        >>> import pathlib
        >>> import tempfile
        >>> from contextlib import contextmanager
        >>> from hbutils.reflection import nested_with
        >>>
        >>> @contextmanager
        ... def opent(x):
        ...     with tempfile.TemporaryDirectory() as td:
        ...         pathlib.Path(os.path.join(td, f'{x}.txt')).write_text(f'this is {x}!')
        ...         yield td
        >>>
        >>> with opent(1) as d:
        ...     print(os.listdir(d))
        ...     print(pathlib.Path(f'{d}/1.txt').read_text())
        ['1.txt']
        this is 1!
        >>> with nested_with(*map(opent, range(5))) as ds:
        ...     for d in ds:
        ...         print(d)
        ...         print(os.path.exists(d), os.listdir(d))
        ...         print(pathlib.Path(f'{d}/{os.listdir(d)[0]}').read_text())
        /tmp/tmp3u1984br
        True ['0.txt']
        this is 0!
        /tmp/tmp0yx56hv0
        True ['1.txt']
        this is 1!
        /tmp/tmpu_33drm3
        True ['2.txt']
        this is 2!
        /tmp/tmpqal_vzgi
        True ['3.txt']
        this is 3!
        /tmp/tmpy99_wwtt
        True ['4.txt']
        this is 4!
        >>> for d in ds:
        ...     print(d)
        ...     print(os.path.exists(d))
        /tmp/tmp3u1984br
        False
        /tmp/tmp0yx56hv0
        False
        /tmp/tmpu_33drm3
        False
        /tmp/tmpqal_vzgi
        False
        /tmp/tmpy99_wwtt
        False
    """
    # ExitStack gives the same semantics as literally nested ``with`` statements
    # without one level of generator recursion per manager, so long argument
    # lists cannot exceed the interpreter's recursion limit.
    with ExitStack() as stack:
        yield tuple(stack.enter_context(ctx) for ctx in contexts)
@contextmanager
def conditional_with(ctx: ContextManager[Any], cond: bool) -> Iterator[Optional[Any]]:
    """
    Conditionally enter a context manager.

    When ``cond`` is truthy, ``ctx`` is entered and its value is yielded; otherwise
    ``ctx`` is left untouched (neither entered nor exited) and ``None`` is yielded.

    :param ctx: The context manager to conditionally enter.
    :type ctx: ContextManager
    :param cond: Condition deciding whether ``ctx`` is entered.
    :type cond: bool
    :yield: The value from the context manager if ``cond`` is true, otherwise ``None``.
    :rtype: Optional[Any]

    Examples::

        Here is an example of conditionally creating a temporary directory.

        >>> import os.path
        >>>
        >>> from hbutils.reflection import conditional_with
        >>> from hbutils.system import TemporaryDirectory
        >>>
        >>> with conditional_with(TemporaryDirectory(), cond=True) as td:
        ...     print('td:', td)
        ...     print('exist:', os.path.exists(td))
        ...     print('isdir:', os.path.isdir(td))
        ...
        td: /tmp/tmp07lpb9ah
        exist: True
        isdir: True
        >>> with conditional_with(TemporaryDirectory(), cond=False) as td:
        ...     print('td:', td)
        ...
        td: None
    """
    if not cond:
        # Condition not met: skip ``ctx`` entirely.
        yield None
        return

    with ctx as entered:
        yield entered