diff --git a/doc/changes/2499.feature b/doc/changes/2499.feature new file mode 100644 index 0000000000..e990850f3e --- /dev/null +++ b/doc/changes/2499.feature @@ -0,0 +1 @@ +Enable mcsolve with jax.grad using numpy_backend \ No newline at end of file diff --git a/qutip/solver/mcsolve.py b/qutip/solver/mcsolve.py index 68f9b8caf8..b33a86c769 100644 --- a/qutip/solver/mcsolve.py +++ b/qutip/solver/mcsolve.py @@ -3,7 +3,7 @@ __all__ = ['mcsolve', "MCSolver"] -import numpy as np +from ..core.numpy_backend import np from numpy.typing import ArrayLike from numpy.random import SeedSequence from ..core import QobjEvo, spre, spost, Qobj, unstack_columns, qzero_like diff --git a/qutip/solver/multitraj.py b/qutip/solver/multitraj.py index 646e38ab45..106f5753da 100644 --- a/qutip/solver/multitraj.py +++ b/qutip/solver/multitraj.py @@ -7,9 +7,9 @@ from time import time from .solver_base import Solver from ..core import QobjEvo, Qobj -import numpy as np +from ..core.numpy_backend import np from numpy.typing import ArrayLike -from numpy.random import SeedSequence +from numpy.random import SeedSequence, default_rng from numbers import Number from typing import Any, Callable import bisect @@ -87,7 +87,7 @@ def __init__(self, rhs, *, options=None): else: raise TypeError("The system should be a QobjEvo") self.options = options - self.seed_sequence = np.random.SeedSequence() + self.seed_sequence = SeedSequence() self._integrator = self._get_integrator() self._state_metadata = {} self.stats = self._initialize_stats() @@ -360,15 +360,15 @@ def _read_seed(self, seed, ntraj): """ if seed is None: seeds = self.seed_sequence.spawn(ntraj) - elif isinstance(seed, np.random.SeedSequence): + elif isinstance(seed, SeedSequence): seeds = seed.spawn(ntraj) elif not isinstance(seed, list): - seeds = np.random.SeedSequence(seed).spawn(ntraj) + seeds = SeedSequence(seed).spawn(ntraj) elif len(seed) >= ntraj: seeds = [ - seed_ if (isinstance(seed_, np.random.SeedSequence) + seed_ if (isinstance(seed_, SeedSequence) or hasattr(seed_, 'random')) - else np.random.SeedSequence(seed_) + else SeedSequence(seed_) for seed_ in seed[:ntraj] ] else: @@ -391,7 +391,7 @@ def _get_generator(self, seed): bit_gen = getattr(np.random, self.options['bitgenerator']) generator = np.random.Generator(bit_gen(seed)) else: - generator = np.random.default_rng(seed) + generator = default_rng(seed) return generator diff --git a/qutip/solver/multitrajresult.py b/qutip/solver/multitrajresult.py index 96ea4e8fdd..75f50f97c5 100644 --- a/qutip/solver/multitrajresult.py +++ b/qutip/solver/multitrajresult.py @@ -5,7 +5,7 @@ """ from typing import TypedDict -import numpy as np +from ..core.numpy_backend import np from copy import copy diff --git a/qutip/solver/result.py b/qutip/solver/result.py index 7b019381a3..7cb77cd5d6 100644 --- a/qutip/solver/result.py +++ b/qutip/solver/result.py @@ -4,7 +4,7 @@ from __future__ import annotations from typing import TypedDict, Any, Callable -import numpy as np +from ..core.numpy_backend import np from numpy.typing import ArrayLike from ..core import Qobj, QobjEvo, expect