Skip to content

Commit

Permalink
Clean up redundant complex type unions (#7041)
Browse files Browse the repository at this point in the history
* Clean up redundant complex type unions

* format

* fix numpy ufunc, one mypy err

* Additional cleanup
  • Loading branch information
daxfohl authored Feb 7, 2025
1 parent 14d61c8 commit 72e1d20
Show file tree
Hide file tree
Showing 19 changed files with 50 additions and 58 deletions.
2 changes: 1 addition & 1 deletion cirq-core/cirq/interop/quirk/cells/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def classify(e: str) -> Union[str, float]:
return _merge_scientific_float_tokens(g for g in result if g.strip())


_ResolvedToken = Union[sympy.Expr, int, float, complex]
_ResolvedToken = Union[sympy.Expr, complex]


class _CustomQuirkOperationToken:
Expand Down
4 changes: 2 additions & 2 deletions cirq-core/cirq/linalg/combinators.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from numpy.typing import DTypeLike, ArrayLike


def kron(*factors: Union[np.ndarray, complex, float], shape_len: int = 2) -> np.ndarray:
def kron(*factors: Union[np.ndarray, complex], shape_len: int = 2) -> np.ndarray:
"""Computes the kronecker product of a sequence of values.
A *args version of lambda args: functools.reduce(np.kron, args).
Expand Down Expand Up @@ -56,7 +56,7 @@ def kron(*factors: Union[np.ndarray, complex, float], shape_len: int = 2) -> np.
)


def kron_with_controls(*factors: Union[np.ndarray, complex, float]) -> np.ndarray:
def kron_with_controls(*factors: Union[np.ndarray, complex]) -> np.ndarray:
"""Computes the kronecker product of a sequence of values and control tags.
Use `cirq.CONTROL_TAG` to represent controls. Any entry of the output
Expand Down
10 changes: 2 additions & 8 deletions cirq-core/cirq/linalg/decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,13 +283,7 @@ class AxisAngleDecomposition:
rotation axis, and g is the global phase.
"""

def __init__(
self,
*,
angle: float,
axis: Tuple[float, float, float],
global_phase: Union[int, float, complex],
):
def __init__(self, *, angle: float, axis: Tuple[float, float, float], global_phase: complex):
if not np.isclose(np.linalg.norm(axis, 2), 1, atol=1e-8):
raise ValueError('Axis vector must be normalized.')
self.global_phase = complex(global_phase)
Expand Down Expand Up @@ -634,7 +628,7 @@ def scatter_plot_normalized_kak_interaction_coefficients(
ax = cast(mplot3d.axes3d.Axes3D, fig.add_subplot(1, 1, 1, projection='3d'))

def coord_transform(
pts: Union[List[Tuple[int, int, int]], np.ndarray]
pts: Union[List[Tuple[int, int, int]], np.ndarray],
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
if len(pts) == 0:
return np.array([]), np.array([]), np.array([])
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/linalg/tolerance.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def all_near_zero(a: 'ArrayLike', *, atol: float = 1e-8) -> bool:


def all_near_zero_mod(
a: Union[float, complex, Iterable[float], np.ndarray], period: float, *, atol: float = 1e-8
a: Union[float, Iterable[float], np.ndarray], period: float, *, atol: float = 1e-8
) -> bool:
"""Checks if the tensor's elements are all near multiples of the period.
Expand Down
10 changes: 5 additions & 5 deletions cirq-core/cirq/ops/dense_pauli_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def pauli_mask(self) -> np.ndarray:
return self._pauli_mask

@property
def coefficient(self) -> Union[sympy.Expr, complex]:
def coefficient(self) -> 'cirq.TParamValComplex':
"""A complex coefficient or symbol."""
return self._coefficient

Expand Down Expand Up @@ -359,7 +359,7 @@ def __str__(self) -> str:
coef = '+'
elif self.coefficient == -1:
coef = '-'
elif isinstance(self.coefficient, (complex, sympy.Symbol)):
elif isinstance(self.coefficient, (numbers.Complex, sympy.Symbol)):
coef = f'{self.coefficient}*'
else:
coef = f'({self.coefficient})*'
Expand Down Expand Up @@ -403,7 +403,7 @@ def mutable_copy(self) -> 'MutableDensePauliString':
@abc.abstractmethod
def copy(
self,
coefficient: Optional[Union[sympy.Expr, int, float, complex]] = None,
coefficient: Optional['cirq.TParamValComplex'] = None,
pauli_mask: Union[None, str, Iterable[int], np.ndarray] = None,
) -> Self:
"""Returns a copy with possibly modified contents.
Expand Down Expand Up @@ -459,7 +459,7 @@ def frozen(self) -> 'DensePauliString':

def copy(
self,
coefficient: Optional[Union[sympy.Expr, int, float, complex]] = None,
coefficient: Optional['cirq.TParamValComplex'] = None,
pauli_mask: Union[None, str, Iterable[int], np.ndarray] = None,
) -> 'DensePauliString':
if pauli_mask is None and (coefficient is None or coefficient == self.coefficient):
Expand Down Expand Up @@ -559,7 +559,7 @@ def __imul__(self, other):

def copy(
self,
coefficient: Optional[Union[sympy.Expr, int, float, complex]] = None,
coefficient: Optional['cirq.TParamValComplex'] = None,
pauli_mask: Union[None, str, Iterable[int], np.ndarray] = None,
) -> 'MutableDensePauliString':
return MutableDensePauliString(
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/ops/eigen_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ def _parameter_names_(self) -> AbstractSet[str]:

def _resolve_parameters_(self, resolver: 'cirq.ParamResolver', recursive: bool) -> 'EigenGate':
exponent = resolver.value_of(self._exponent, recursive)
if isinstance(exponent, (complex, numbers.Complex)):
if isinstance(exponent, numbers.Complex):
if isinstance(exponent, numbers.Real):
exponent = float(exponent)
else:
Expand Down
5 changes: 3 additions & 2 deletions cirq-core/cirq/ops/identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
"""IdentityGate."""

import numbers
from types import NotImplementedType
from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING, Sequence, Union

Expand Down Expand Up @@ -72,7 +73,7 @@ def num_qubits(self) -> int:
return len(self._qid_shape)

def __pow__(self, power: Any) -> Any:
if isinstance(power, (int, float, complex, sympy.Basic)):
if isinstance(power, (numbers.Complex, sympy.Basic)):
return self
return NotImplemented

Expand Down Expand Up @@ -126,7 +127,7 @@ def _json_dict_(self) -> Dict[str, Any]:
def _mul_with_qubits(self, qubits: Tuple['cirq.Qid', ...], other):
if isinstance(other, raw_types.Operation):
return other
if isinstance(other, (complex, float, int)):
if isinstance(other, numbers.Complex):
from cirq.ops.pauli_string import PauliString

return PauliString(coefficient=other)
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/ops/linear_combinations.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@

UnitPauliStringT = FrozenSet[Tuple[raw_types.Qid, pauli_gates.Pauli]]
PauliSumLike = Union[
int, float, complex, PauliString, 'PauliSum', pauli_string.SingleQubitPauliStringGateOperation
complex, PauliString, 'PauliSum', pauli_string.SingleQubitPauliStringGateOperation
]
document(
PauliSumLike,
Expand Down
29 changes: 13 additions & 16 deletions cirq-core/cirq/ops/pauli_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
Optional,
overload,
Sequence,
SupportsComplex,
Tuple,
TYPE_CHECKING,
TypeVar,
Expand Down Expand Up @@ -271,9 +270,7 @@ def __mul__(self, other: 'cirq.Operation') -> 'cirq.PauliString[Union[TKey, cirq
pass

@overload
def __mul__(
self, other: Union[complex, int, float, numbers.Number]
) -> 'cirq.PauliString[TKey]':
def __mul__(self, other: complex) -> 'cirq.PauliString[TKey]':
pass

def __mul__(self, other):
Expand Down Expand Up @@ -308,10 +305,9 @@ def gate(self) -> 'cirq.DensePauliString':
)

def __rmul__(self, other) -> 'PauliString':
if isinstance(other, numbers.Number):
if isinstance(other, numbers.Complex):
return PauliString(
qubit_pauli_map=self._qubit_pauli_map,
coefficient=self._coefficient * complex(cast(SupportsComplex, other)),
qubit_pauli_map=self._qubit_pauli_map, coefficient=self._coefficient * other
)

if isinstance(other, raw_types.Operation) and isinstance(other.gate, identity.IdentityGate):
Expand All @@ -321,10 +317,9 @@ def __rmul__(self, other) -> 'PauliString':
return NotImplemented

def __truediv__(self, other):
if isinstance(other, numbers.Number):
if isinstance(other, numbers.Complex):
return PauliString(
qubit_pauli_map=self._qubit_pauli_map,
coefficient=self._coefficient / complex(cast(SupportsComplex, other)),
qubit_pauli_map=self._qubit_pauli_map, coefficient=self._coefficient / other
)
return NotImplemented

Expand Down Expand Up @@ -518,7 +513,7 @@ def _unitary_(self) -> Optional[np.ndarray]:
def _apply_unitary_(self, args: 'protocols.ApplyUnitaryArgs'):
if not self._has_unitary_():
return None
assert isinstance(self.coefficient, complex)
assert isinstance(self.coefficient, numbers.Complex)
if self.coefficient != 1:
args.target_tensor *= self.coefficient
return protocols.apply_unitaries([self[q].on(q) for q in self.qubits], self.qubits, args)
Expand Down Expand Up @@ -792,9 +787,11 @@ def __pos__(self) -> 'PauliString':
return self

def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
"""Override behavior of numpy's exp method."""
"""Override numpy behavior."""
if ufunc == np.exp and len(inputs) == 1 and inputs[0] is self:
return math.e**self
if ufunc == np.multiply and len(inputs) == 2 and inputs[1] is self:
return self * inputs[0]
return NotImplemented

def __pow__(self, power):
Expand Down Expand Up @@ -1174,14 +1171,14 @@ def _as_pauli_string(self) -> PauliString:
def __mul__(self, other):
if isinstance(other, SingleQubitPauliStringGateOperation):
return self._as_pauli_string() * other._as_pauli_string()
if isinstance(other, (PauliString, complex, float, int)):
if isinstance(other, (PauliString, numbers.Complex)):
return self._as_pauli_string() * other
if (as_pauli_string := _try_interpret_as_pauli_string(other)) is not None:
return self * as_pauli_string
return NotImplemented

def __rmul__(self, other):
if isinstance(other, (PauliString, complex, float, int)):
if isinstance(other, (PauliString, numbers.Complex)):
return other * self._as_pauli_string()
if (as_pauli_string := _try_interpret_as_pauli_string(other)) is not None:
return as_pauli_string * self
Expand Down Expand Up @@ -1430,8 +1427,8 @@ def _imul_helper(self, other: 'cirq.PAULI_STRING_LIKE', sign: int):
pauli_int = _pauli_like_to_pauli_int(qubit, pauli_gate_like)
phase_log_i += self._imul_atom_helper(cast(TKey, qubit), pauli_int, sign)
self.coefficient *= 1j ** (phase_log_i & 3)
elif isinstance(other, numbers.Number):
self.coefficient *= complex(cast(SupportsComplex, other))
elif isinstance(other, numbers.Complex):
self.coefficient *= other
elif isinstance(other, raw_types.Operation) and isinstance(
other.gate, identity.IdentityGate
):
Expand Down
4 changes: 2 additions & 2 deletions cirq-core/cirq/ops/pauli_string_phasor.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,14 +392,14 @@ def _resolve_parameters_(
) -> 'PauliStringPhasorGate':
exponent_neg = resolver.value_of(self.exponent_neg, recursive)
exponent_pos = resolver.value_of(self.exponent_pos, recursive)
if isinstance(exponent_neg, (complex, numbers.Complex)):
if isinstance(exponent_neg, numbers.Complex):
if isinstance(exponent_neg, numbers.Real):
exponent_neg = float(exponent_neg)
else:
raise ValueError(
f'PauliStringPhasorGate does not support complex exponent {exponent_neg}'
)
if isinstance(exponent_pos, (complex, numbers.Complex)):
if isinstance(exponent_pos, numbers.Complex):
if isinstance(exponent_pos, numbers.Real):
exponent_pos = float(exponent_pos)
else:
Expand Down
2 changes: 2 additions & 0 deletions cirq-core/cirq/ops/pauli_string_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,6 +597,8 @@ def test_numpy_ufunc():
_ = np.exp(cirq.PauliString())
x = np.exp(1j * np.pi * cirq.PauliString())
assert x is not None
x = np.int64(2) * cirq.PauliString()
assert x == 2 * cirq.PauliString()


def test_map_qubits():
Expand Down
4 changes: 2 additions & 2 deletions cirq-core/cirq/ops/phased_x_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,12 +176,12 @@ def _resolve_parameters_(
"""See `cirq.SupportsParameterization`."""
phase_exponent = resolver.value_of(self._phase_exponent, recursive)
exponent = resolver.value_of(self._exponent, recursive)
if isinstance(phase_exponent, (complex, numbers.Complex)):
if isinstance(phase_exponent, numbers.Complex):
if isinstance(phase_exponent, numbers.Real):
phase_exponent = float(phase_exponent)
else:
raise ValueError(f'PhasedXPowGate does not support complex value {phase_exponent}')
if isinstance(exponent, (complex, numbers.Complex)):
if isinstance(exponent, numbers.Complex):
if isinstance(exponent, numbers.Real):
exponent = float(exponent)
else:
Expand Down
6 changes: 3 additions & 3 deletions cirq-core/cirq/ops/phased_x_z_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,17 +229,17 @@ def _resolve_parameters_(
z_exponent = resolver.value_of(self._z_exponent, recursive)
x_exponent = resolver.value_of(self._x_exponent, recursive)
axis_phase_exponent = resolver.value_of(self._axis_phase_exponent, recursive)
if isinstance(z_exponent, (complex, numbers.Complex)):
if isinstance(z_exponent, numbers.Complex):
if isinstance(z_exponent, numbers.Real):
z_exponent = float(z_exponent)
else:
raise ValueError(f'Complex exponent {z_exponent} not allowed in cirq.PhasedXZGate')
if isinstance(x_exponent, (complex, numbers.Complex)):
if isinstance(x_exponent, numbers.Complex):
if isinstance(x_exponent, numbers.Real):
x_exponent = float(x_exponent)
else:
raise ValueError(f'Complex exponent {x_exponent} not allowed in cirq.PhasedXZGate')
if isinstance(axis_phase_exponent, (complex, numbers.Complex)):
if isinstance(axis_phase_exponent, numbers.Complex):
if isinstance(axis_phase_exponent, numbers.Real):
axis_phase_exponent = float(axis_phase_exponent)
else:
Expand Down
6 changes: 2 additions & 4 deletions cirq-core/cirq/ops/projector.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# pylint: disable=wrong-or-nonexistent-copyright-notice
import itertools
import math
from typing import Any, Dict, Iterable, List, Mapping, Optional, Union
from typing import Any, Dict, Iterable, List, Mapping, Optional

import numpy as np
from scipy.sparse import csr_matrix
Expand All @@ -21,9 +21,7 @@ def _check_qids_dimension(qids):
class ProjectorString:
"""Mapping of `cirq.Qid` to measurement values (with a coefficient) representing a projector."""

def __init__(
self, projector_dict: Dict[raw_types.Qid, int], coefficient: Union[int, float, complex] = 1
):
def __init__(self, projector_dict: Dict[raw_types.Qid, int], coefficient: complex = 1):
"""Constructor for ProjectorString
Args:
Expand Down
6 changes: 3 additions & 3 deletions cirq-core/cirq/ops/raw_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,13 +337,13 @@ def __sub__(
def __neg__(self) -> 'cirq.LinearCombinationOfGates':
return self.wrap_in_linear_combination(coefficient=-1)

def __mul__(self, other: Union[complex, float, int]) -> 'cirq.LinearCombinationOfGates':
def __mul__(self, other: complex) -> 'cirq.LinearCombinationOfGates':
return self.wrap_in_linear_combination(coefficient=other)

def __rmul__(self, other: Union[complex, float, int]) -> 'cirq.LinearCombinationOfGates':
def __rmul__(self, other: complex) -> 'cirq.LinearCombinationOfGates':
return self.wrap_in_linear_combination(coefficient=other)

def __truediv__(self, other: Union[complex, float, int]) -> 'cirq.LinearCombinationOfGates':
def __truediv__(self, other: complex) -> 'cirq.LinearCombinationOfGates':
return self.wrap_in_linear_combination(coefficient=1 / other)

def __pow__(self, power):
Expand Down
4 changes: 2 additions & 2 deletions cirq-core/cirq/qis/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
Sequence[int],
# Explicit state vector or state tensor.
np.ndarray,
Sequence[Union[int, float, complex]],
Sequence[complex],
# Product state object
'cirq.ProductState',
]
Expand Down Expand Up @@ -341,7 +341,7 @@ def _infer_qid_shape_from_dimension(dim: int) -> Tuple[int, ...]:
Sequence[int],
# Explicit state vector or state tensor.
np.ndarray,
Sequence[Union[int, float, complex]],
Sequence[complex],
# Product state object
'cirq.ProductState',
# Quantum state object
Expand Down
4 changes: 2 additions & 2 deletions cirq-core/cirq/sim/clifford/stabilizer_state_ch_form.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, List, Sequence, Union
from typing import Any, Dict, List, Sequence
import numpy as np

import cirq
Expand Down Expand Up @@ -101,7 +101,7 @@ def __repr__(self) -> str:
"""Return the CH form representation of the state."""
return f'StabilizerStateChForm(num_qubits={self.n!r})'

def inner_product_of_state_and_x(self, x: int) -> Union[float, complex]:
def inner_product_of_state_and_x(self, x: int) -> complex:
"""Returns the amplitude of x'th element of
the state vector, i.e. <x|psi>"""
if type(x) == int:
Expand Down
4 changes: 2 additions & 2 deletions cirq-core/cirq/study/flatten_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ def value_of(
The unique symbol or value of the parameter as resolved by this
resolver.
"""
if isinstance(value, (int, float, complex, numbers.Complex)):
if isinstance(value, numbers.Complex):
return value
if isinstance(value, str):
value = sympy.Symbol(value)
Expand Down Expand Up @@ -380,7 +380,7 @@ def __repr__(self) -> str:


def _ensure_not_str(
param: Union[sympy.Expr, 'cirq.TParamValComplex', str]
param: Union[sympy.Expr, 'cirq.TParamValComplex', str],
) -> Union[sympy.Expr, 'cirq.TParamValComplex']:
if isinstance(param, str):
return sympy.Symbol(param)
Expand Down
Loading

0 comments on commit 72e1d20

Please sign in to comment.