Skip to content

Commit

Permalink
change numpy to jax-numpy to enable grad functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
rochisha0 committed Jun 10, 2024
1 parent c874c4a commit ce6328b
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
6 changes: 3 additions & 3 deletions qutip/core/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
'hellinger_dist', 'hilbert_dist', 'average_gate_fidelity',
'process_fidelity', 'unitarity', 'dnorm']

import numpy as np
import jax.numpy as np
from scipy import linalg as la
import scipy.sparse as sp
from .superop_reps import to_choi, _to_superpauli, to_super, kraus_to_choi
Expand Down Expand Up @@ -80,7 +80,7 @@ def fidelity(A, B):
# even for positive semidefinite matrices, small negative eigenvalues
# can be reported.
eig_vals = (sqrtmA * B * sqrtmA).eigenenergies()
return float(np.real(np.sqrt(eig_vals[eig_vals > 0]).sum()))
return np.float64(np.real(np.sqrt(eig_vals[eig_vals > 0]).sum()))


def _hilbert_space_dims(oper):
Expand Down Expand Up @@ -288,7 +288,7 @@ def tracedist(A, B, sparse=False, tol=0):
diff = A - B
diff = diff.dag() * diff
vals = diff.eigenenergies(sparse=sparse, tol=tol)
return float(np.real(0.5 * np.sum(np.sqrt(np.abs(vals)))))
return np.float64(np.real(0.5 * np.sum(np.sqrt(np.abs(vals)))))


def hilbert_dist(A, B):
Expand Down
8 changes: 4 additions & 4 deletions qutip/entropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
'concurrence', 'entropy_conditional', 'entangling_power',
'entropy_relative']

from numpy import conj, e, inf, imag, inner, real, sort, sqrt
from numpy.lib.scimath import log, log2
from jax.numpy import conj, e, inf, imag, inner, real, sort, sqrt, float64
from jax.numpy import log, log2
from .partial_transpose import partial_transpose
from . import (ptrace, tensor, sigmay, ket2dm,
expand_operator)
Expand Down Expand Up @@ -45,7 +45,7 @@ def entropy_vn(rho, base=e, sparse=False):
logvals = log(nzvals)
else:
raise ValueError("Base must be 2 or e.")
return float(real(-sum(nzvals * logvals)))
return float64(real(-sum(nzvals * logvals)))


def entropy_linear(rho):
Expand All @@ -71,7 +71,7 @@ def entropy_linear(rho):
"""
if rho.type == 'ket' or rho.type == 'bra':
rho = ket2dm(rho)
return float(real(1.0 - (rho ** 2).tr()))
return float64(real(1.0 - (rho ** 2).tr()))


def concurrence(rho):
Expand Down

0 comments on commit ce6328b

Please sign in to comment.