| | import random |
| | import sys |
| |
|
| | import numpy as np |
| | import torch |
| |
|
| |
|
class RandomState:
    '''Snapshot of the global Python, NumPy and PyTorch RNG states.

    The snapshot is taken when the object is constructed and can be written
    back any number of times with :meth:`restore`.
    '''

    def __init__(self):
        '''Record the current state of every global generator.'''
        self.random_mod_state = random.getstate()
        self.np_state = np.random.get_state()
        self.torch_cpu_state = torch.get_rng_state()
        # One entry per CUDA device reported by torch, indexed by device
        # number; the list is empty when torch reports no devices.
        device_indices = range(torch.cuda.device_count())
        self.torch_gpu_states = list(
            map(torch.cuda.get_rng_state, device_indices)
        )

    def restore(self):
        '''Write the recorded states back into the global generators.'''
        random.setstate(self.random_mod_state)
        np.random.set_state(self.np_state)
        torch.set_rng_state(self.torch_cpu_state)
        # List position doubles as the CUDA device index (see __init__).
        for device_index, gpu_state in enumerate(self.torch_gpu_states):
            torch.cuda.set_rng_state(gpu_state, device_index)
| |
|
| |
|
class RandomContext:
    '''Run a block of code on a private RNG stream.

    Saves and restores the state of the PyTorch, NumPy and Python RNGs.
    Constructing the context seeds a private set of generator states without
    disturbing the caller's global state. Entering the context swaps the
    private states in; leaving it saves the advanced private states for the
    next entry and puts the caller's states back. Re-entering the same
    context therefore continues its private stream where it left off.

    Args:
        seed: Seed passed to ``random.seed``, ``np.random.seed`` and
            ``torch.manual_seed``. With ``None``, Python and NumPy choose
            their own seeds and torch is seeded with a value drawn from the
            freshly seeded Python generator.
    '''

    def __init__(self, seed=None):
        outside_state = RandomState()
        try:
            random.seed(seed)
            np.random.seed(seed)
            if seed is None:
                # torch.manual_seed needs an explicit integer, so draw one
                # from the Python generator that was just re-seeded.
                torch.manual_seed(
                    random.randint(-sys.maxsize - 1, sys.maxsize)
                )
            else:
                torch.manual_seed(seed)

            self.inside_state = RandomState()
        finally:
            # Seeding mutates the global generators; always hand the
            # caller's state back, even when one of the seed calls rejects
            # the seed after an earlier one has already been applied.
            outside_state.restore()

        # Holds the caller's state only while the context is active.
        self.outside_state = None
        self._active = False

    def __enter__(self):
        '''Swap in the private RNG states and return this context.

        Raises:
            RuntimeError: If the context is already active (nested entry).
        '''
        if self._active:
            raise RuntimeError('RandomContext can be active only once')

        # Remember the caller's state so __exit__ can put it back.
        self.outside_state = RandomState()
        self.inside_state.restore()
        self._active = True
        # Returning self makes ``with RandomContext(...) as ctx`` usable.
        return self

    def __exit__(self, exception_type, exception_value, traceback):
        '''Save the private RNG states and restore the caller's.

        Returns ``None``, so an exception raised inside the ``with`` body
        propagates to the caller.
        '''
        try:
            # Keep the advanced private state for the next entry.
            self.inside_state = RandomState()
        finally:
            # Restore the caller's state even if the snapshot above fails.
            self.outside_state.restore()
            self.outside_state = None
            self._active = False
| |
|