diff --git a/yt/data_objects/construction_data_containers.py b/yt/data_objects/construction_data_containers.py index 4fae89568be..8a195050bbf 100644 --- a/yt/data_objects/construction_data_containers.py +++ b/yt/data_objects/construction_data_containers.py @@ -33,8 +33,9 @@ from yt.geometry import particle_deposit as particle_deposit from yt.geometry.coordinates.cartesian_coordinates import all_data from yt.loaders import load_uniform_grid +from yt.units._numpy_wrapper_functions import uconcatenate from yt.units.unit_object import Unit # type: ignore -from yt.units.yt_array import YTArray, uconcatenate # type: ignore +from yt.units.yt_array import YTArray from yt.utilities.exceptions import ( YTNoAPIKey, YTNotInsideNotebook, diff --git a/yt/data_objects/data_containers.py b/yt/data_objects/data_containers.py index 91b5f6bb187..b4a5318c7e5 100644 --- a/yt/data_objects/data_containers.py +++ b/yt/data_objects/data_containers.py @@ -13,7 +13,8 @@ from yt.fields.field_exceptions import NeedsGridType from yt.frontends.ytdata.utilities import save_as_dataset from yt.funcs import get_output_filename, is_sequence, iter_fields, mylog -from yt.units.yt_array import YTArray, YTQuantity, uconcatenate # type: ignore +from yt.units._numpy_wrapper_functions import uconcatenate +from yt.units.yt_array import YTArray, YTQuantity from yt.utilities.amr_kdtree.api import AMRKDTree from yt.utilities.exceptions import ( YTCouldNotGenerateField, diff --git a/yt/data_objects/selection_objects/ray.py b/yt/data_objects/selection_objects/ray.py index 5d24e7b1330..0faf411130e 100644 --- a/yt/data_objects/selection_objects/ray.py +++ b/yt/data_objects/selection_objects/ray.py @@ -1,5 +1,4 @@ import numpy as np -from unyt import udot, unorm from yt.data_objects.selection_objects.data_selection_objects import ( YTSelectionContainer, @@ -16,6 +15,7 @@ validate_sequence, ) from yt.units import YTArray, YTQuantity +from yt.units._numpy_wrapper_functions import udot, unorm from yt.utilities.lib.pixelization_routines import SPHKernelInterpolationTable from yt.utilities.logger import ytLogger as mylog diff --git a/yt/data_objects/tests/test_chunking.py b/yt/data_objects/tests/test_chunking.py index c0efe200e99..599d958e76a 100644 --- a/yt/data_objects/tests/test_chunking.py +++ b/yt/data_objects/tests/test_chunking.py @@ -1,5 +1,5 @@ from yt.testing import assert_equal, assert_true, fake_random_ds -from yt.units.yt_array import uconcatenate +from yt.units._numpy_wrapper_functions import uconcatenate def _get_dobjs(c): diff --git a/yt/data_objects/tests/test_compose.py b/yt/data_objects/tests/test_compose.py index 9e5cf8615ca..48192abfefa 100644 --- a/yt/data_objects/tests/test_compose.py +++ b/yt/data_objects/tests/test_compose.py @@ -1,7 +1,8 @@ import numpy as np from yt.testing import assert_array_equal, fake_amr_ds, fake_random_ds -from yt.units.yt_array import YTArray, uintersect1d +from yt.units._numpy_wrapper_functions import uintersect1d +from yt.units.yt_array import YTArray def setup(): diff --git a/yt/data_objects/tests/test_rays.py b/yt/data_objects/tests/test_rays.py index afff2904544..b5f847a482b 100644 --- a/yt/data_objects/tests/test_rays.py +++ b/yt/data_objects/tests/test_rays.py @@ -2,7 +2,7 @@ from yt import load from yt.testing import assert_equal, assert_rel_equal, fake_random_ds, requires_file -from yt.units.yt_array import uconcatenate +from yt.units._numpy_wrapper_functions import uconcatenate def test_ray(): diff --git a/yt/fields/particle_fields.py b/yt/fields/particle_fields.py index c1b6ed5c6da..6b9d8ed787b 100644 --- a/yt/fields/particle_fields.py +++ b/yt/fields/particle_fields.py @@ -1,7 +1,7 @@ import numpy as np from yt.fields.derived_field import ValidateParameter, ValidateSpatial -from yt.units.yt_array import uconcatenate, ucross # type: ignore +from yt.units._numpy_wrapper_functions import uconcatenate, ucross from yt.utilities.lib.misc_utilities import ( obtain_position_vector, obtain_relative_velocity_vector, diff --git a/yt/frontends/gadget/io.py b/yt/frontends/gadget/io.py index 3ceb988937f..a68e9b65497 100644 --- a/yt/frontends/gadget/io.py +++ b/yt/frontends/gadget/io.py @@ -6,7 +6,7 @@ import numpy as np from yt.frontends.sph.io import IOHandlerSPH -from yt.units.yt_array import uconcatenate # type: ignore +from yt.units._numpy_wrapper_functions import uconcatenate from yt.utilities.lib.particle_kdtree_tools import generate_smoothing_length from yt.utilities.logger import ytLogger as mylog from yt.utilities.on_demand_imports import _h5py as h5py diff --git a/yt/frontends/ytdata/data_structures.py b/yt/frontends/ytdata/data_structures.py index ab0f6638339..7e622c7394e 100644 --- a/yt/frontends/ytdata/data_structures.py +++ b/yt/frontends/ytdata/data_structures.py @@ -22,8 +22,9 @@ from yt.geometry.grid_geometry_handler import GridIndex from yt.geometry.particle_geometry_handler import ParticleIndex from yt.units import dimensions +from yt.units._numpy_wrapper_functions import uconcatenate from yt.units.unit_registry import UnitRegistry # type: ignore -from yt.units.yt_array import YTQuantity, uconcatenate # type: ignore +from yt.units.yt_array import YTQuantity from yt.utilities.exceptions import GenerationInProgress, YTFieldTypeNotFound from yt.utilities.logger import ytLogger as mylog from yt.utilities.on_demand_imports import _h5py as h5py diff --git a/yt/geometry/coordinates/cartesian_coordinates.py b/yt/geometry/coordinates/cartesian_coordinates.py index b782d5595c8..4568473423a 100644 --- a/yt/geometry/coordinates/cartesian_coordinates.py +++ b/yt/geometry/coordinates/cartesian_coordinates.py @@ -2,7 +2,8 @@ from yt.data_objects.index_subobjects.unstructured_mesh import SemiStructuredMesh from yt.funcs import mylog -from yt.units.yt_array import YTArray, uconcatenate, uvstack # type: ignore +from yt.units._numpy_wrapper_functions import uconcatenate, uvstack +from yt.units.yt_array import YTArray from yt.utilities.lib.pixelization_routines import ( interpolate_sph_grid_gather, normalization_2d_utility, diff --git a/yt/geometry/geometry_handler.py b/yt/geometry/geometry_handler.py index a1f80ec146e..37de3fa5f18 100644 --- a/yt/geometry/geometry_handler.py +++ b/yt/geometry/geometry_handler.py @@ -6,7 +6,8 @@ import numpy as np from yt.config import ytcfg -from yt.units.yt_array import YTArray, uconcatenate # type: ignore +from yt.units._numpy_wrapper_functions import uconcatenate +from yt.units.yt_array import YTArray from yt.utilities.exceptions import YTFieldNotFound from yt.utilities.io_handler import io_registry from yt.utilities.logger import ytLogger as mylog diff --git a/yt/units/__init__.py b/yt/units/__init__.py index 4d838a33ffc..4fb5bd475fc 100644 --- a/yt/units/__init__.py +++ b/yt/units/__init__.py @@ -1,17 +1,8 @@ from unyt.array import ( loadtxt, savetxt, - uconcatenate, - ucross, - udot, - uhstack, - uintersect1d, - unorm, unyt_array, unyt_quantity, - ustack, - uunion1d, - uvstack, ) from unyt.unit_object import Unit, define_unit # NOQA: F401 from unyt.unit_registry import UnitRegistry # NOQA: Ffg401 @@ -22,7 +13,17 @@ from yt.units.unit_symbols import * from yt.units.unit_symbols import _SymbolContainer from yt.utilities.exceptions import YTArrayTooLargeToDisplay - +from yt.units._numpy_wrapper_functions import ( + uconcatenate, + ucross, + udot, + uhstack, + uintersect1d, + unorm, + ustack, + uunion1d, + uvstack, +) YTArray = unyt_array YTQuantity = unyt_quantity diff --git a/yt/units/_numpy_wrapper_functions.py b/yt/units/_numpy_wrapper_functions.py new file mode 100644 index 00000000000..ff330234b0c --- /dev/null +++ b/yt/units/_numpy_wrapper_functions.py @@ -0,0 +1,206 @@ +# This module is not part of the public namespace `yt.units` +# It is home to wrapper functions that are directly copied from unyt 2.9.2 +# We vendor them as a transition step towards unyt 3.0 (in devlopment), +# where these wrapper functions are deprecated and are should be replaced with vanilla numpy API +# FUTURE: +# - require unyt>=3.0 +# - deprecate these functions in yt too + +from unyt import unyt_array, unyt_quantity +import numpy as np + + +def _validate_numpy_wrapper_units(v, arrs): + if not any(isinstance(a, unyt_array) for a in arrs): + return v + if not all(isinstance(a, unyt_array) for a in arrs): + raise RuntimeError("Not all of your arrays are unyt_arrays.") + a1 = arrs[0] + if not all(a.units == a1.units for a in arrs[1:]): + raise RuntimeError("Your arrays must have identical units.") + v.units = a1.units + return v + + +def uconcatenate(arrs, axis=0): + """Concatenate a sequence of arrays. + + This wrapper around numpy.concatenate preserves units. All input arrays + must have the same units. See the documentation of numpy.concatenate for + full details. + + Examples + -------- + >>> from unyt import cm + >>> A = [1, 2, 3]*cm + >>> B = [2, 3, 4]*cm + >>> uconcatenate((A, B)) + unyt_array([1, 2, 3, 2, 3, 4], 'cm') + + """ + v = np.concatenate(arrs, axis=axis) + v = _validate_numpy_wrapper_units(v, arrs) + return v + + +def ucross(arr1, arr2, registry=None, axisa=-1, axisb=-1, axisc=-1, axis=None): + """Applies the cross product to two YT arrays. + + This wrapper around numpy.cross preserves units. + See the documentation of numpy.cross for full + details. + """ + v = np.cross(arr1, arr2, axisa=axisa, axisb=axisb, axisc=axisc, axis=axis) + units = arr1.units * arr2.units + arr = unyt_array(v, units, registry=registry) + return arr + + +def uintersect1d(arr1, arr2, assume_unique=False): + """Find the sorted unique elements of the two input arrays. + + A wrapper around numpy.intersect1d that preserves units. All input arrays + must have the same units. See the documentation of numpy.intersect1d for + full details. + + Examples + -------- + >>> from unyt import cm + >>> A = [1, 2, 3]*cm + >>> B = [2, 3, 4]*cm + >>> uintersect1d(A, B) + unyt_array([2, 3], 'cm') + + """ + v = np.intersect1d(arr1, arr2, assume_unique=assume_unique) + v = _validate_numpy_wrapper_units(v, [arr1, arr2]) + return v + + +def uunion1d(arr1, arr2): + """Find the union of two arrays. + + A wrapper around numpy.intersect1d that preserves units. All input arrays + must have the same units. See the documentation of numpy.intersect1d for + full details. + + Examples + -------- + >>> from unyt import cm + >>> A = [1, 2, 3]*cm + >>> B = [2, 3, 4]*cm + >>> uunion1d(A, B) + unyt_array([1, 2, 3, 4], 'cm') + + """ + v = np.union1d(arr1, arr2) + v = _validate_numpy_wrapper_units(v, [arr1, arr2]) + return v + + +def unorm(data, ord=None, axis=None, keepdims=False): + """Matrix or vector norm that preserves units + + This is a wrapper around np.linalg.norm that preserves units. See + the documentation for that function for descriptions of the keyword + arguments. + + Examples + -------- + >>> from unyt import km + >>> data = [1, 2, 3]*km + >>> print(unorm(data)) + 3.7416573867739413 km + """ + norm = np.linalg.norm(data, ord=ord, axis=axis, keepdims=keepdims) + if norm.shape == (): + return unyt_quantity(norm, data.units) + return unyt_array(norm, data.units) + + +def udot(op1, op2): + """Matrix or vector dot product that preserves units + + This is a wrapper around np.dot that preserves units. + + Examples + -------- + >>> from unyt import km, s + >>> a = np.eye(2)*km + >>> b = (np.ones((2, 2)) * 2)*s + >>> print(udot(a, b)) + [[2. 2.] + [2. 2.]] km*s + """ + dot = np.dot(op1.d, op2.d) + units = op1.units * op2.units + if dot.shape == (): + return unyt_quantity(dot, units) + return unyt_array(dot, units) + + +def uvstack(arrs): + """Stack arrays in sequence vertically (row wise) while preserving units + + This is a wrapper around np.vstack that preserves units. + + Examples + -------- + >>> from unyt import km + >>> a = [1, 2, 3]*km + >>> b = [2, 3, 4]*km + >>> print(uvstack([a, b])) + [[1 2 3] + [2 3 4]] km + """ + v = np.vstack(arrs) + v = _validate_numpy_wrapper_units(v, arrs) + return v + + +def uhstack(arrs): + """Stack arrays in sequence horizontally while preserving units + + This is a wrapper around np.hstack that preserves units. + + Examples + -------- + >>> from unyt import km + >>> a = [1, 2, 3]*km + >>> b = [2, 3, 4]*km + >>> print(uhstack([a, b])) + [1 2 3 2 3 4] km + >>> a = [[1],[2],[3]]*km + >>> b = [[2],[3],[4]]*km + >>> print(uhstack([a, b])) + [[1 2] + [2 3] + [3 4]] km + """ + v = np.hstack(arrs) + v = _validate_numpy_wrapper_units(v, arrs) + return v + + +def ustack(arrs, axis=0): + """Join a sequence of arrays along a new axis while preserving units + + The axis parameter specifies the index of the new axis in the + dimensions of the result. For example, if ``axis=0`` it will be the + first dimension and if ``axis=-1`` it will be the last dimension. + + This is a wrapper around np.stack that preserves units. See the + documentation for np.stack for full details. + + Examples + -------- + >>> from unyt import km + >>> a = [1, 2, 3]*km + >>> b = [2, 3, 4]*km + >>> print(ustack([a, b])) + [[1 2 3] + [2 3 4]] km + """ + v = np.stack(arrs, axis=axis) + v = _validate_numpy_wrapper_units(v, arrs) + return v diff --git a/yt/utilities/particle_generator.py b/yt/utilities/particle_generator.py index 18a2b24a8c7..c9d606e0c0a 100644 --- a/yt/utilities/particle_generator.py +++ b/yt/utilities/particle_generator.py @@ -1,7 +1,7 @@ import numpy as np from yt.funcs import get_pbar -from yt.units.yt_array import uconcatenate # type: ignore +from yt.units._numpy_wrapper_functions import uconcatenate from yt.utilities.lib.particle_mesh_operations import CICSample_3 diff --git a/yt/utilities/tests/test_particle_generator.py b/yt/utilities/tests/test_particle_generator.py index 9cfb40bb4c2..a992c66aa70 100644 --- a/yt/utilities/tests/test_particle_generator.py +++ b/yt/utilities/tests/test_particle_generator.py @@ -2,7 +2,7 @@ from yt.loaders import load_uniform_grid from yt.testing import assert_almost_equal, assert_equal -from yt.units.yt_array import uconcatenate +from yt.units._numpy_wrapper_functions import uconcatenate from yt.utilities.particle_generator import ( FromListParticleGenerator, LatticeParticleGenerator, diff --git a/yt/visualization/volume_rendering/lens.py b/yt/visualization/volume_rendering/lens.py index 55aa0cc71d2..633c57bc534 100644 --- a/yt/visualization/volume_rendering/lens.py +++ b/yt/visualization/volume_rendering/lens.py @@ -1,7 +1,7 @@ import numpy as np from yt.data_objects.image_array import ImageArray -from yt.units.yt_array import uhstack, unorm, uvstack # type: ignore +from yt.units._numpy_wrapper_functions import uhstack, unorm, uvstack from yt.utilities.lib.grid_traversal import arr_fisheye_vectors from yt.utilities.math_utils import get_rotation_matrix from yt.utilities.parallel_tools.parallel_analysis_interface import ( diff --git a/yt/visualization/volume_rendering/old_camera.py b/yt/visualization/volume_rendering/old_camera.py index e0bd2e8ad7e..680408066f1 100644 --- a/yt/visualization/volume_rendering/old_camera.py +++ b/yt/visualization/volume_rendering/old_camera.py @@ -649,7 +649,7 @@ def new_image(self): def get_sampler_args(self, image): rotp = np.concatenate( - [self.orienter.inv_mat.ravel("F"), self.back_center.ravel()] + [self.orienter.inv_mat.ravel("F"), self.back_center.ravel().ndview] ) args = ( np.atleast_3d(rotp), @@ -2125,7 +2125,7 @@ def initialize_source(self): def get_sampler_args(self, image): rotp = np.concatenate( - [self.orienter.inv_mat.ravel("F"), self.back_center.ravel()] + [self.orienter.inv_mat.ravel("F"), self.back_center.ravel().ndview] ) args = ( np.atleast_3d(rotp),