hbutils.random.state

Overview:

Random state and seed management.

The following random sources are supported out of the box: Python's native default random instance (random._inst), torch, numpy, and faker (its default random instance). Other random sources can be registered with register_random_source() and register_random_instance().

register_random_source

hbutils.random.state.register_random_source(name: str, seed: Callable[[int], None], getstate: Callable[[], T], setstate: Callable[[T], None])

Register a random source by providing its name, a seed function, a state-getting function, and a state-setting function.

Parameters:
  • name (str) – Name of random source.

  • seed (_SEED_FUNC) – Seed function, format: seed(x).

  • getstate (_GETSTATE_FUNC) – State getting function, format: getstate().

  • setstate (_SETSTATE_FUNC) – State setting function, format: setstate(state).

Raises:

NameError – If the name already exists in registered random sources.

Examples::
>>> import random
>>> from hbutils.random import global_seed, register_random_source
>>>
>>> rnd = random.Random()  # custom random object
>>> global_seed(0)  # rnd is not registered yet, so it is not seeded
>>> rnd.random()
0.4563765178328746
>>> global_seed(0)
>>> rnd.random()  # differs from the previous draw
0.06875325446897462
>>>
>>> register_random_source('custom_random', rnd.seed, rnd.getstate, rnd.setstate)
>>> global_seed(0)  # rnd is now seeded as well
>>> rnd.random()
0.8444218515250481
>>> global_seed(0)
>>> rnd.random()  # same as the previous draw
0.8444218515250481
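
The three callables need not come from a random.Random object. As a minimal sketch of the contract described above (one seed function, one state getter, one state setter, and a NameError on a duplicate name), the following uses a hypothetical TinyLCG generator and a stand-in register() function. Both are illustrative assumptions and not hbutils code.

```python
from typing import Callable, Dict, Tuple

# Stand-in registry for illustration only; hbutils keeps its own internally.
_SOURCES: Dict[str, Tuple[Callable, Callable, Callable]] = {}


def register(name, seed, getstate, setstate):
    """Hypothetical stand-in mirroring the documented contract."""
    if name in _SOURCES:
        raise NameError(f'Random source {name!r} already exists.')
    _SOURCES[name] = (seed, getstate, setstate)


class TinyLCG:
    """Hypothetical linear congruential generator with a single-int state."""

    def __init__(self):
        self._state = 1

    def seed(self, x: int) -> None:
        self._state = x % 2 ** 31

    def getstate(self) -> int:
        return self._state

    def setstate(self, state: int) -> None:
        self._state = state

    def random(self) -> float:
        # Advance the state, then scale it into [0, 1).
        self._state = (1103515245 * self._state + 12345) % 2 ** 31
        return self._state / 2 ** 31


lcg = TinyLCG()
register('tiny_lcg', lcg.seed, lcg.getstate, lcg.setstate)

# Seeding every registered source with the same value makes draws repeatable.
for seed_fn, _, _ in _SOURCES.values():
    seed_fn(0)
first = lcg.random()
for seed_fn, _, _ in _SOURCES.values():
    seed_fn(0)
second = lcg.random()
```

With hbutils installed, the same three bound methods (lcg.seed, lcg.getstate, lcg.setstate) would be the arguments passed to register_random_source().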

register_random_instance

hbutils.random.state.register_random_instance(name: str, rnd: Random)

Register a custom random instance as a random source.

Parameters:
  • name (str) – Name of random source.

  • rnd (random.Random) – Custom random instance, should be an instance of random.Random.

Examples::
>>> import random
>>> from hbutils.random import global_seed, register_random_instance
>>>
>>> rnd = random.Random()  # custom random object
>>> global_seed(0)  # rnd is not registered yet, so it is not seeded
>>> rnd.random()
0.48936053503964005
>>> global_seed(0)
>>> rnd.random()  # differs from the previous draw
0.4113361070387721
>>>
>>> register_random_instance('custom_random', rnd)
>>> global_seed(0)  # rnd is now seeded as well
>>> rnd.random()
0.8444218515250481
>>> global_seed(0)
>>> rnd.random()  # same as the previous draw
0.8444218515250481
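
Comparing this example with the one for register_random_source() suggests that registering an instance amounts to registering its three bound methods. The sketch below illustrates that reading with the standard library only. The helper callables_of() is hypothetical, and whether hbutils implements the function this way is an assumption.

```python
import random
from typing import Callable, Tuple


def callables_of(rnd: random.Random) -> Tuple[Callable, Callable, Callable]:
    """Hypothetical helper: the (seed, getstate, setstate) triple of an instance."""
    return rnd.seed, rnd.getstate, rnd.setstate


rnd = random.Random()
seed, getstate, setstate = callables_of(rnd)

seed(0)                  # what a global seed call would do for this source
value = rnd.random()

saved = getstate()       # snapshot before drawing
a = rnd.random()
setstate(saved)          # restore, so the next draw repeats
b = rnd.random()
```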

get_global_state

hbutils.random.state.get_global_state() → Dict[str, T]

Get states of all registered random sources.

Returns:

A dictionary mapping random source names to their current states.

Return type:

Dict[str, T]

Examples::
>>> import random
>>> import numpy as np
>>> import torch
>>> from faker import Faker
>>>
>>> from hbutils.random import get_global_state, set_global_state
>>>
>>> _ = random.randint(0, 100)  # just do something
>>> _ = random.random()
>>> _ = torch.randn(2, 3)
>>> _ = np.random.randn(2, 3)
>>>
>>> states = get_global_state()
>>> random.randint(0, 100)  # first time's result
99
>>> random.random()
0.4656250864192085
>>> torch.randn(2, 3)
tensor([[ 0.8886, -0.3602,  1.3071],
        [-0.0187, -0.5980, -0.5469]])
>>> np.random.randn(2, 3)
array([[ 1.24249156,  0.71018699, -0.53496231],
       [ 0.78748336, -0.01407442, -0.6607438 ]])
>>> Faker().sentence(5)
'New pass crime most.'
>>>
>>> set_global_state(states)
>>> random.randint(0, 100)  # same as the first time
99
>>> random.random()
0.4656250864192085
>>> torch.randn(2, 3)
tensor([[ 0.8886, -0.3602,  1.3071],
        [-0.0187, -0.5980, -0.5469]])
>>> np.random.randn(2, 3)
array([[ 1.24249156,  0.71018699, -0.53496231],
       [ 0.78748336, -0.01407442, -0.6607438 ]])
>>> Faker().sentence(5)
'New pass crime most.'
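
The save-and-restore round trip can be imitated on a small scale without torch, numpy or faker. In this sketch two standard-library generators stand in for registered sources. The names sources, snapshot() and restore() are hypothetical and only mirror the documented return shape (a dictionary from source name to state).

```python
import random

# Two independent generators stand in for registered random sources.
sources = {'a': random.Random(1), 'b': random.Random(2)}


def snapshot():
    """Collect the current state of every source, keyed by name."""
    return {name: rnd.getstate() for name, rnd in sources.items()}


def restore(states):
    """Put each source back into the state recorded under its name."""
    for name, state in states.items():
        sources[name].setstate(state)


states = snapshot()
first = [sources['a'].random(), sources['b'].randint(0, 100)]
restore(states)
second = [sources['a'].random(), sources['b'].randint(0, 100)]
```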

set_global_state

hbutils.random.state.set_global_state(states: Mapping[str, T])

Set states of registered random sources.

Parameters:

states (Mapping[str, T]) – A mapping of random source names to their states to be restored.

Note

A warning is issued in two cases: when states contains an entry for a random source that is not registered, and when a registered random source has no entry in states.

Examples::

See get_global_state().
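
The two warning cases in the note can be sketched as follows. This is an illustration under stated assumptions: set_states() and the sources registry are hypothetical, and the warning category and message wording are guesses rather than what hbutils emits.

```python
import random
import warnings

sources = {'native': random.Random(0)}  # hypothetical registry of sources


def set_states(states):
    """Restore matching sources and warn about names present on one side only."""
    for name in states:
        if name not in sources:
            warnings.warn(f'State given for unknown random source {name!r}.')
    for name, rnd in sources.items():
        if name in states:
            rnd.setstate(states[name])
        else:
            warnings.warn(f'No state given for random source {name!r}.')


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    set_states({'missing': None})  # unknown name, and 'native' is left out

messages = [str(w.message) for w in caught]
```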

keep_global_state

hbutils.random.state.keep_global_state()

Context manager to preserve all random states during execution.

This context manager saves the current state of all registered random sources, executes the code block, and then restores the saved states regardless of whether the code block completes successfully or raises an exception.

Yields:

None

Examples::
>>> import torch
>>> from hbutils.random import global_seed, keep_global_state
>>>
>>> global_seed(0)
>>> torch.randn(2, 3)  # before value 1
tensor([[ 1.5410, -0.2934, -2.1788],
        [ 0.5684, -1.0845, -1.3986]])
>>> torch.randn(2, 3)  # after value 1
tensor([[ 0.4033,  0.8380, -0.7193],
        [-0.4033, -0.5966,  0.1820]])
>>>
>>> global_seed(0)
>>> torch.randn(2, 3)  # before value 2, same as 1
tensor([[ 1.5410, -0.2934, -2.1788],
        [ 0.5684, -1.0845, -1.3986]])
>>> with keep_global_state():  # do anything you want here
...     _ = torch.randn(100, 200, 2)
...     _ = torch.randint(20, 30, (30, 40))
>>> torch.randn(2, 3)  # after value 2, same as 1
tensor([[ 0.4033,  0.8380, -0.7193],
        [-0.4033, -0.5966,  0.1820]])
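
The save, run, restore behaviour described above, including restoration after an exception, can be sketched with the standard library alone. The keep_state() function below is a hypothetical single-source analogue that covers only the native random module. It is not the hbutils implementation.

```python
import random
from contextlib import contextmanager


@contextmanager
def keep_state():
    """Save the native random state, run the block, then always restore it."""
    saved = random.getstate()
    try:
        yield
    finally:
        random.setstate(saved)  # runs on normal exit and on exceptions


random.seed(0)
expected = random.random()   # first draw after seeding

random.seed(0)
try:
    with keep_state():
        _ = [random.random() for i in range(100)]
        raise RuntimeError('simulated failure')
except RuntimeError:
    pass
actual = random.random()     # unaffected by the draws inside the block
```

Placing the restore step in a finally clause is what makes the guarantee hold whether or not the block raises.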

global_seed

hbutils.random.state.global_seed(seed: int)

Set seed for all registered random sources.

This function applies the same seed value to all registered random sources, ensuring reproducible random number generation across different libraries.

Parameters:

seed (int) – Random seed value to be applied to all random sources.

Examples::

See keep_global_state() and register_random_instance().

seedable_func

hbutils.random.state.seedable_func(func: Callable) → Callable

Decorator to add seed support to a function.

This decorator wraps a function to add an optional seed keyword argument. When provided, the seed is applied to all registered random sources before executing the function, enabling reproducible results.

Parameters:

func (Callable) – Function to be decorated.

Returns:

Wrapped function with an additional seed keyword argument.

Return type:

Callable

Examples::
>>> import torch
>>> from hbutils.random import seedable_func
>>>
>>> @seedable_func
... def get_random_value(mean, std):
...     return torch.randn((2, 3)) * std + mean
>>>
>>> get_random_value(2, 3)
tensor([[-0.0844,  5.2530,  3.4248],
        [ 4.4923,  0.0492,  3.6731]])
>>> get_random_value(2, 3)  # not the same
tensor([[2.7600, 2.5135, 0.6484],
        [1.2459, 0.1020, 2.5905]])
>>>
>>> get_random_value(2, 3, seed=0)
tensor([[ 6.6230,  1.1197, -4.5364],
        [ 3.7053, -1.2536, -2.1958]])
>>> get_random_value(2, 3, seed=1)
tensor([[3.9841, 2.8008, 2.1850],
        [3.8640, 0.6443, 1.5016]])
>>> get_random_value(2, 3, seed=0)  # repeatable
tensor([[ 6.6230,  1.1197, -4.5364],
        [ 3.7053, -1.2536, -2.1958]])
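
A decorator of this kind can be sketched in a few lines. The version below is a hypothetical analogue named seedable() that seeds only the native random module, where the documented decorator seeds all registered sources. It is meant to show the shape of the wrapper, not the hbutils source.

```python
import random
from functools import wraps


def seedable(func):
    """Hypothetical decorator adding an optional ``seed`` keyword argument."""

    @wraps(func)
    def wrapper(*args, seed=None, **kwargs):
        if seed is not None:
            random.seed(seed)  # stands in for seeding all registered sources
        return func(*args, **kwargs)

    return wrapper


@seedable
def noisy(mean, std):
    return random.gauss(mean, std)


x = noisy(2, 3, seed=0)
y = noisy(2, 3, seed=1)
z = noisy(2, 3, seed=0)  # repeats x because the seed is the same
```

Making seed keyword-only keeps it from colliding with the positional arguments of the wrapped function.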